diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..841836f17 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "verl"] + path = verl + url = https://github.com/volcengine/verl.git diff --git a/docker/Apptainerfile.rocm b/docker/Apptainerfile.rocm deleted file mode 100644 index 025962187..000000000 --- a/docker/Apptainerfile.rocm +++ /dev/null @@ -1,57 +0,0 @@ -Bootstrap: docker - -# Support - Traing: fsdp; Inference: vllm -# FROM: rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 -# Support - Traing: fsdp; Inference: vllm, sglang -FROM lmsysorg/sglang:v0.4.5-rocm630 - -%environment - export PYTORCH_ROCM_ARCH="gfx90a;gfx942" - - export HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" - export CFLAGS="-D__HIP_PLATFORM_AMD__" - export CXXFLAGS="-D__HIP_PLATFORM_AMD__" - -%post - # Create source directory - mkdir -p /opt/src - - # Uninstall and reinstall vllm - pip uninstall -y vllm - cd /opt/src - git clone -b v0.6.3 https://github.com/vllm-project/vllm.git - cd vllm - MAX_JOBS=$(nproc) python3 setup.py install - cd /opt - rm -rf /opt/src/vllm - - # Install dependencies - pip install "tensordict<0.6" --no-deps - pip install accelerate \ - codetiming \ - datasets \ - dill \ - hydra-core \ - liger-kernel \ - numpy \ - pandas \ - peft \ - "pyarrow>=15.0.0" \ - pylatexenc \ - "ray[data,train,tune,serve]" \ - torchdata \ - transformers \ - wandb \ - orjson \ - pybind11 - - # Clone and install verl from GitHub - cd /opt - git clone https://github.com/volcengine/verl.git - cd verl - # Uncomment to use a specific version - # git checkout v0.3.0.post0 - pip install -e . --no-deps - - # Install torch_memory_saver - pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps \ No newline at end of file diff --git a/docker/Dockerfile.extention.awsefa b/docker/Dockerfile.extention.awsefa deleted file mode 100644 index 10be07697..000000000 --- a/docker/Dockerfile.extention.awsefa +++ /dev/null @@ -1,55 +0,0 @@ -# Base Image support aws EFA -# Build Image with frameworks based on this -FROM verlai/verl:app-verl0.5-sglang0.4.6.post5-mcore0.12.1 - -# For aws instances with EFA net interface (Sagemaker AI Pod) -# install EFA driver: -######## AWS EFA ############ -ENV NCCL_VERSION=2.25.1-1 -ENV DEBIAN_FRONTEND=noninteractive -ENV EFA_INSTALLER_VERSION=1.40.0 -ENV AWS_OFI_NCCL_VERSION=1.14.2 -ENV FI_EFA_SET_CUDA_SYNC_MEMOPS=0 -ENV FI_PROVIDER=efa - -RUN apt update && apt install -y linux-image-generic libhwloc-dev - -RUN cd /tmp && \ - curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz && \ - tar -xf aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz && \ - cd aws-efa-installer && \ - ./efa_installer.sh -y -g --skip-kmod --skip-limit-conf --no-verify && \ - ldconfig && \ - rm -rf /tmp/aws-efa-installer /var/lib/apt/lists/* - -# NCCL EFA Plugin -RUN cd /tmp && \ - curl -LO https://github.com/aws/aws-ofi-nccl/archive/refs/tags/v${AWS_OFI_NCCL_VERSION}.tar.gz && \ - tar -xzf /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \ - rm /tmp/v${AWS_OFI_NCCL_VERSION}.tar.gz && \ - mv aws-ofi-nccl-${AWS_OFI_NCCL_VERSION} aws-ofi-nccl && \ - cd /tmp/aws-ofi-nccl && \ - ./autogen.sh && \ - ./configure --prefix=/opt/amazon/efa \ - --with-libfabric=/opt/amazon/efa \ - --with-cuda=/usr/local/cuda \ - --enable-platform-aws \ - --with-mpi=/opt/amazon/openmpi && \ - make -j$(nproc) install && \ - rm -rf /tmp/aws-ofi/nccl - -# NCCL -RUN echo "/usr/local/lib" >> /etc/ld.so.conf.d/local.conf && \ - echo "/opt/amazon/openmpi/lib" >> /etc/ld.so.conf.d/efa.conf && \ - ldconfig - -ENV OMPI_MCA_pml=^cm,ucx \ - OMPI_MCA_btl=tcp,self \ - OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent \ - OPAL_PREFIX=/opt/amazon/openmpi \ - NCCL_SOCKET_IFNAME=^docker,lo,veth_def_agent \ - FI_EFA_USE_HUGE_PAGE=0 - -# docker build -t verl:awsefa --label "commit=$(git rev-parse --short HEAD)" . -# on aws: -# docker run --ipc=host --privileged --name verldev --gpus all --network=host --shm-size=1800gb -itd verl:awsefa diff --git a/docker/Dockerfile.ngc.vllm b/docker/Dockerfile.ngc.vllm deleted file mode 100644 index 7f29f8a55..000000000 --- a/docker/Dockerfile.ngc.vllm +++ /dev/null @@ -1,47 +0,0 @@ -# docker buildx build --platform linux/x86_64 -t "verlai/verl:ngc-th2.4.0-cu124-vllm0.6.3-ray2.4-te1.7-v0.0.6" -f docker/Dockerfile.ngc.vllm . --builder cloud-verlai-verl-builder --progress=plain --push -FROM nvcr.io/nvidia/pytorch:24.05-py3 - -# uninstall nv-pytorch fork -RUN pip3 uninstall pytorch-quantization \ - pytorch-triton \ - torch \ - torch-tensorrt \ - torchvision \ - xgboost transformer_engine flash_attn \ - apex megatron-core -y - -RUN pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124 - -# =============== Megatron dependencies (optional) ================= -# install apex, set MAX_JOBS to avoid OOMs -RUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \ - --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" \ - git+https://github.com/NVIDIA/apex -# =============== End of Megatron dependencies (optional) ================= - -RUN pip3 install --no-cache-dir \ - accelerate \ - codetiming \ - datasets \ - dill \ - hydra-core \ - numpy \ - 'pandas' \ - 'peft' \ - 'pyarrow>=15.0.0' \ - 'pybind11' \ - 'pylatexenc' \ - 'ray>=2.10' \ - 'tensordict<0.6' \ - 'transformers' \ - 'vllm==0.6.3.post1' \ - 'wandb' - -# full dependencies -RUN pip3 install pytest pre-commit py-spy pyext liger-kernel - -# =============== Megatron dependencies (optional) ================= -# install Transformer Engine, which requires FA 2.5.8. Do it in a separate step for docker cache -RUN MAX_JOBS=4 NINJA_FLAGS="-j4" pip3 install flash-attn==2.5.8 --no-cache-dir --no-build-isolation -RUN MAX_JOBS=1 NINJA_FLAGS="-j1" TE_BUILD_WITH_NINJA=0 pip3 install git+https://github.com/eric-haibin-lin/TransformerEngine.git@v1.7.0 -# =============== End of Megatron dependencies (optional) ================= diff --git a/docker/Dockerfile.ngc.vllm0.8 b/docker/Dockerfile.ngc.vllm0.8 deleted file mode 100644 index 127839fe7..000000000 --- a/docker/Dockerfile.ngc.vllm0.8 +++ /dev/null @@ -1,75 +0,0 @@ -# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html -FROM nvcr.io/nvidia/pytorch:24.08-py3 - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Define installation arguments -ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ -ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Set apt source -RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ - { \ - echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ - } > /etc/apt/sources.list - -# Install systemctl -RUN apt-get update && \ - apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ - apt-get clean - -# Install tini -RUN apt-get update && \ - apt-get install -y tini && \ - apt-get clean - -# Change pip source -RUN pip config set global.index-url "${PIP_INDEX}" && \ - pip config set global.extra-index-url "${PIP_INDEX}" && \ - python -m pip install --upgrade pip - -# Uninstall nv-pytorch fork -RUN pip uninstall -y torch torchvision torchaudio \ - pytorch-quantization pytorch-triton torch-tensorrt \ - xgboost transformer_engine flash_attn apex megatron-core grpcio - -# Install torch-2.6.0+cu124 + vllm-0.8.3 -# torch-2.6.0+cu124: cxx11abi=False -# torch-2.6.0+cu126: cxx11abi=True -# see https://github.com/flashinfer-ai/flashinfer/issues/911 -RUN pip install --no-cache-dir "vllm==0.8.3" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata \ - "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=15.0.0" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \ - pytest py-spy pyext pre-commit ruff - -# Install flash-attn-2.7.4.post1 (cxx11abi=False) -RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ - pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - -# Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False) -# vllm-0.8.3 does not support flashinfer>=0.2.3 -# see https://github.com/vllm-project/vllm/pull/15777 -RUN wget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl - -# Fix packages -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -# Install verl -RUN pip install --no-cache-dir verl[vllm] -U - -# Reset pip config -RUN pip config unset global.index-url && \ - pip config unset global.extra-index-url diff --git a/docker/Dockerfile.ngc.vllm0.8.sagemaker b/docker/Dockerfile.ngc.vllm0.8.sagemaker deleted file mode 100644 index 0a7806340..000000000 --- a/docker/Dockerfile.ngc.vllm0.8.sagemaker +++ /dev/null @@ -1,46 +0,0 @@ -# Using a pre-built image from AWS DLC which contains the current version of python (3.10) and supported cuda version (12.1) -FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:2.1.0-transformers4.36.0-gpu-py310-cu121-ubuntu20.04 - -# uninstall nv-pytorch fork -RUN pip3 uninstall -y pytorch-quantization \ - pytorch-triton torch torch-tensorrt torchvision \ - xgboost transformer_engine flash_attn apex megatron-core - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install systemctl -RUN apt-get update && \ - apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ - apt-get clean - -# Install tini -RUN apt-get update && \ - apt-get install -y tini && \ - apt-get clean - -# Install torch-2.6.0 + vllm-0.8.2 -RUN pip install --no-cache-dir vllm==0.8.2 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata==0.11.0 \ - transformers>=4.49.0 accelerate datasets peft hf-transfer \ - ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \ - pytest pre-commit py-spy pyext ruff - -# Install flash_attn-2.7.4.post1 -RUN pip uninstall -y transformer-engine flash-attn && \ - pip install flash-attn==2.7.4.post1 --no-build-isolation - -# Fix cv2 -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \ - pip install --no-cache-dir --upgrade optree>=0.13.0 - -# Install verl -RUN pip install --no-cache-dir verl[vllm] -U - -# Reset pip config -RUN pip config unset global.index-url && \ - pip config unset global.extra-index-url diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm deleted file mode 100644 index e8c209cd0..000000000 --- a/docker/Dockerfile.rocm +++ /dev/null @@ -1,321 +0,0 @@ -# FROM "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_release-2.7_575e247" -FROM "rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04" - -SHELL ["/bin/bash", "-ceuxo", "pipefail"] - -ENV MAX_JOBS=512 - -ENV PATH="/usr/local/python3.12/bin:$PATH" -RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \ - ln -sf /usr/bin/pip3.12 /usr/bin/pip - -############################################ -############################################ -RUN apt-get update -RUN apt-get install -y pkg-config liblzma-dev -############################################ -############################################ - - -########################################### -##########Install TransformerEngine######## -########################################### -WORKDIR /workspace/ -# transformer-engine install -# https://github.com/ROCm/TransformerEngine - -RUN rm -rf TransformerEngine -RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git -WORKDIR /workspace/TransformerEngine -RUN git checkout 236178e5 -# git checkout bb061ade -# git checkout 864405c - -ENV NVTE_FRAMEWORK=pytorch -ENV NVTE_ROCM_ARCH=gfx942 -ENV NVTE_USE_HIPBLASLT=1 -ENV NVTE_USE_ROCM=1 - -# export CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}" -ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" - - -# ENV NVTE_BUILD_MAX_JOBS=$(MAX_JOBS) - -RUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv - -WORKDIR /workspace/ -########################################### -########################################### -########################################### - - - - - -#################################################################################### -################Install vllm - sglang require vllm 0.6.7 dependency################# -#################################################################################### -#### Require vllm 0.6.7 - checkout 113274a0 -WORKDIR /workspace/ -RUN rm -rf vllm -RUN pip uninstall -y vllm -# Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html -RUN git clone https://github.com/ROCm/vllm.git -# git clone https://github.com/vllm-project/vllm.git -WORKDIR /workspace/vllm -RUN git checkout 113274a0 -ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" -#ENV MAX_JOBS=512 -ENV MAX_JOBS=${MAX_JOBS} -RUN pip install "boto3>=1.26.0" -RUN pip install setuptools_scm -# will add src into py. You can delete the repo -RUN python3 setup.py install -WORKDIR /workspace/ -#################################################################################### -#################################################################################### -#################################################################################### - - - -########################################### -############For hack docker################ -########################################### -RUN pip install setuptools==75.8.0 -########################################### -########################################### -########################################### - - - -########################################### -############build sgalng################### -########################################### -# Set environment variables -ENV BASE_DIR=/sgl-workspace -ENV BUILD_TYPE=all -ENV SGL_REPO=https://github.com/sgl-project/sglang -ENV SGL_BRANCH=v0.4.6.post5 -ENV TRITON_REPO=https://github.com/ROCm/triton.git -ENV TRITON_COMMIT=improve_fa_decode_3.0.0 -ENV AITER_REPO=https://github.com/ROCm/aiter.git -ENV AITER_COMMIT=v0.1.2 -# v0.1.2 version - commit id: 9d11f47 -# ENV AITER_COMMIT=9d11f47 - -ENV HIP_FORCE_DEV_KERNARG=1 -ENV HSA_NO_SCRATCH_RECLAIM=1 -ENV SGLANG_SET_CPU_AFFINITY=1 -ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 -ENV NCCL_MIN_NCHANNELS=112 -ENV MOE_PADDING=1 -ENV VLLM_FP8_PADDING=1 -ENV VLLM_FP8_ACT_PADDING=1 -ENV VLLM_FP8_WEIGHT_PADDING=1 -ENV VLLM_FP8_REDUCE_CONV=1 -ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 -ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 -ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" -ENV AMDGPU_TARGETS=gfx942 -ENV ROCM_ARCH=gfx942 -ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" - -# Switch to working directory -WORKDIR /sgl-workspace - -# Clean and create directory -RUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace - -# Clone and build sglang -RUN git clone ${SGL_REPO} \ - && cd sglang \ - && git checkout ${SGL_BRANCH} || echo "Using default branch" \ - && cd sgl-kernel \ - && rm -f pyproject.toml \ - && mv pyproject_rocm.toml pyproject.toml \ - && python setup_rocm.py install \ - && cd .. \ - && if [ "$BUILD_TYPE" = "srt" ]; then \ - python -m pip --no-cache-dir install -e "python[srt_hip]"; \ - else \ - python -m pip --no-cache-dir install -e "python[all_hip]"; \ - fi \ - && cd /sgl-workspace \ - && cp -r /sgl-workspace/sglang /sglang \ - && python -m pip cache purge - -# Install common Python packages -RUN pip install IPython orjson python-multipart torchao pybind11 - -# Rebuild Triton -RUN pip uninstall -y triton || true \ - && git clone ${TRITON_REPO} \ - && cd triton \ - && git checkout ${TRITON_COMMIT} \ - && cd python \ - && python3 setup.py install \ - && cd /sgl-workspace - - -# ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1" -# ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" - -# Build aiter -#version: Commit 9d11f47 - # && git checkout ${AITER_COMMIT} \ -RUN pip uninstall -y aiter || true -RUN git clone ${AITER_REPO} \ - && cd aiter \ - && git checkout ${AITER_COMMIT} \ - && git submodule sync \ - && git submodule update --init --recursive \ - && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \ - && cd /sgl-workspace - # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \ - # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \ - -# Copy MI300X config -RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \ - /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ - -type f -name '*MI300X*' | \ - xargs -I {} sh -c 'vf_config=$(echo "$1" | sed "s/MI300X/MI300X_VF/"); cp "$1" "$vf_config"' -- {} - -# Environment setup complete. -RUN echo "Environment setup complete." - -WORKDIR /workspace/ -########################################### -########################################### -########################################### - - - - - - -########################################### -###############vllm v0.8.5################# -########################################### -# ENV GITHUB_USERNAME=yushengsu-thu -# ENV GITHUB_MAIL=yushengsu@gmail.com - -# RUN git config --global user.name "${GITHUB_USERNAME}" \ -# && git config --global user.email "${GITHUB_MAIL}" - -WORKDIR /workspace/ - -ENV VLLM_TARGET_DEVICE=rocm -ENV ROCM_PATH=/opt/rocm -ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev - -# Find the repo path in: DockerFile/Dockerfile.rocm_yang -# RUN git clone https://github.com/RLFoundation/vllm-patch.git -RUN pip uninstall -y vllm || true -RUN rm -rf vllm-patch -RUN git clone https://github.com/RLFoundation/vllm-patch.git \ - && cd vllm-patch \ - && git checkout v0.8.5-sleep-numa \ - && rm -rf build/ dist/ *.egg-info \ - && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \ - && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py install - # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py develop - -WORKDIR /workspace/ -########################################### -########################################### -########################################### - - - - -######################################### -#### Install megatron-core############### -######################################### -RUN pip uninstall -y megatron-core && \ - git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git && \ - cd Megatron-LM-amd_version && \ - pip install -vvv -e . && \ - cd /workspace/ -######################################### -######################################### -######################################### - - - - -####################################### -################apex################### -####################################### -WORKDIR /workspace/ -RUN pip uninstall -y apex && \ - git clone https://github.com/ROCm/apex.git && \ - cd apex && \ - python setup.py install && \ - cd /workspace/ -####################################### -####################################### -####################################### - - - - -################################################################################ -###########################Add torch_memory_saver############################### -################################################################################ -# Set environment variables -ENV HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" -ENV CFLAGS="-D__HIP_PLATFORM_AMD__" -ENV CXXFLAGS="-D__HIP_PLATFORM_AMD__" -RUN pip install "git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa" -################################################################################ -################################################################################ -################################################################################ - - - -######################################## -######Install ray####################### -######################################## -# need to add this patch: https://github.com/ray-project/ray/pull/53531/files -RUN pip uninstall ray -y -RUN pip install "ray[data,train,tune,serve]>=2.47.0" -######################################## -######################################## -######################################## - - - -########################################## -#######Install other dependencies######### -########################################## -RUN pip install "tensordict==0.6.2" --no-deps && \ - pip install accelerate \ - codetiming \ - datasets \ - dill \ - hydra-core \ - liger-kernel \ - numpy \ - pandas \ - peft \ - "pyarrow>=15.0.0" \ - pylatexenc \ - torchdata \ - wandb \ - orjson \ - pybind11 - -WORKDIR /workspace/ -RUN git clone https://github.com/volcengine/verl.git && \ - cd verl && \ - pip install -e . -########################################## -########################################## -########################################## - - - -WORKDIR /workspace/ - -CMD ["/usr/bin/bash"] diff --git a/docker/Dockerfile.rocm_verl-0.3.0.post1 b/docker/Dockerfile.rocm_verl-0.3.0.post1 deleted file mode 100644 index 185096d9d..000000000 --- a/docker/Dockerfile.rocm_verl-0.3.0.post1 +++ /dev/null @@ -1,58 +0,0 @@ -# Build the docker in the repo dir: -# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 . -# docker images # you can find your built docker - - -# Support - Traing: fsdp; Inference: vllm -# FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 -# Support - Traing: fsdp; Inference: vllm, sglang -FROM lmsysorg/sglang:v0.4.6.post5-rocm630 - -# Set working directory -# WORKDIR $PWD/app - -# Set environment variables -ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" - -ENV HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" -ENV CFLAGS="-D__HIP_PLATFORM_AMD__" -ENV CXXFLAGS="-D__HIP_PLATFORM_AMD__" - -# Install vllm -RUN pip uninstall -y vllm && \ - rm -rf vllm && \ - git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \ - cd vllm && \ - MAX_JOBS=$(nproc) python3 setup.py install && \ - cd .. && \ - rm -rf vllm - -# Copy the entire project directory -COPY . . - -# Install dependencies -RUN pip install "tensordict==0.6.2" --no-deps && \ - pip install accelerate \ - codetiming \ - datasets \ - dill \ - hydra-core \ - liger-kernel \ - numpy \ - pandas \ - peft \ - "pyarrow>=15.0.0" \ - pylatexenc \ - "ray[data,train,tune,serve]<2.45.0" \ - torchdata \ - transformers \ - wandb \ - orjson \ - pybind11 - -RUN git clone https://github.com/volcengine/verl.git && \ - cd verl && \ - pip install -e . - -# Install torch_memory_saver -RUN pip install git+https://github.com/ExtremeViscent/torch_memory_saver.git --no-deps diff --git a/docker/Dockerfile.rocm_verl-0.4.1 b/docker/Dockerfile.rocm_verl-0.4.1 deleted file mode 100644 index b6d30521b..000000000 --- a/docker/Dockerfile.rocm_verl-0.4.1 +++ /dev/null @@ -1,322 +0,0 @@ -# FROM "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-rel-6.4:94_ubuntu22.04_py3.10_pytorch_release-2.7_575e247" -FROM "rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04" - -SHELL ["/bin/bash", "-ceuxo", "pipefail"] - -ENV MAX_JOBS=512 - -ENV PATH="/usr/local/python3.12/bin:$PATH" -RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \ - ln -sf /usr/bin/pip3.12 /usr/bin/pip - -############################################ -############################################ -RUN apt-get update -RUN apt-get install -y pkg-config liblzma-dev -############################################ -############################################ - - -########################################### -##########Install TransformerEngine######## -########################################### -WORKDIR /workspace/ -# transformer-engine install -# https://github.com/ROCm/TransformerEngine - -RUN rm -rf TransformerEngine -RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git -WORKDIR /workspace/TransformerEngine -RUN git checkout 236178e5 -# git checkout bb061ade -# git checkout 864405c - -ENV NVTE_FRAMEWORK=pytorch -ENV NVTE_ROCM_ARCH=gfx942 -ENV NVTE_USE_HIPBLASLT=1 -ENV NVTE_USE_ROCM=1 - -# export CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}" -ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" - - -# ENV NVTE_BUILD_MAX_JOBS=$(MAX_JOBS) - -RUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv - -WORKDIR /workspace/ -########################################### -########################################### -########################################### - - - - - -#################################################################################### -################Install vllm - sglang require vllm 0.6.7 dependency################# -#################################################################################### -#### Require vllm 0.6.7 - checkout 113274a0 -WORKDIR /workspace/ -RUN rm -rf vllm -RUN pip uninstall -y vllm -# Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html -RUN git clone https://github.com/ROCm/vllm.git -# git clone https://github.com/vllm-project/vllm.git -WORKDIR /workspace/vllm -RUN git checkout 113274a0 -ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" -#ENV MAX_JOBS=512 -ENV MAX_JOBS=${MAX_JOBS} -RUN pip install "boto3>=1.26.0" -RUN pip install setuptools_scm -# will add src into py. You can delete the repo -RUN python3 setup.py install -WORKDIR /workspace/ -#################################################################################### -#################################################################################### -#################################################################################### - - - -########################################### -############For hack docker################ -########################################### -RUN pip install setuptools==75.8.0 -########################################### -########################################### -########################################### - - - -########################################### -############build sgalng################### -########################################### -# Set environment variables -ENV BASE_DIR=/sgl-workspace -ENV BUILD_TYPE=all -ENV SGL_REPO=https://github.com/sgl-project/sglang -ENV SGL_BRANCH=v0.4.6.post5 -ENV TRITON_REPO=https://github.com/ROCm/triton.git -ENV TRITON_COMMIT=improve_fa_decode_3.0.0 -ENV AITER_REPO=https://github.com/ROCm/aiter.git -ENV AITER_COMMIT=v0.1.2 -# v0.1.2 version - commit id: 9d11f47 -# ENV AITER_COMMIT=9d11f47 - -ENV HIP_FORCE_DEV_KERNARG=1 -ENV HSA_NO_SCRATCH_RECLAIM=1 -ENV SGLANG_SET_CPU_AFFINITY=1 -ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 -ENV NCCL_MIN_NCHANNELS=112 -ENV MOE_PADDING=1 -ENV VLLM_FP8_PADDING=1 -ENV VLLM_FP8_ACT_PADDING=1 -ENV VLLM_FP8_WEIGHT_PADDING=1 -ENV VLLM_FP8_REDUCE_CONV=1 -ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 -ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 -ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" -ENV AMDGPU_TARGETS=gfx942 -ENV ROCM_ARCH=gfx942 -ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" - -# Switch to working directory -WORKDIR /sgl-workspace - -# Clean and create directory -RUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace - -# Clone and build sglang -RUN git clone ${SGL_REPO} \ - && cd sglang \ - && git checkout ${SGL_BRANCH} || echo "Using default branch" \ - && cd sgl-kernel \ - && rm -f pyproject.toml \ - && mv pyproject_rocm.toml pyproject.toml \ - && python setup_rocm.py install \ - && cd .. \ - && if [ "$BUILD_TYPE" = "srt" ]; then \ - python -m pip --no-cache-dir install -e "python[srt_hip]"; \ - else \ - python -m pip --no-cache-dir install -e "python[all_hip]"; \ - fi \ - && cd /sgl-workspace \ - && cp -r /sgl-workspace/sglang /sglang \ - && python -m pip cache purge - -# Install common Python packages -RUN pip install IPython orjson python-multipart torchao pybind11 - -# Rebuild Triton -RUN pip uninstall -y triton || true \ - && git clone ${TRITON_REPO} \ - && cd triton \ - && git checkout ${TRITON_COMMIT} \ - && cd python \ - && python3 setup.py install \ - && cd /sgl-workspace - - -# ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1" -# ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" - -# Build aiter -#version: Commit 9d11f47 - # && git checkout ${AITER_COMMIT} \ -RUN pip uninstall -y aiter || true -RUN git clone ${AITER_REPO} \ - && cd aiter \ - && git checkout ${AITER_COMMIT} \ - && git submodule sync \ - && git submodule update --init --recursive \ - && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \ - && cd /sgl-workspace - # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \ - # && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop \ - -# Copy MI300X config -RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \ - /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ - -type f -name '*MI300X*' | \ - xargs -I {} sh -c 'vf_config=$(echo "$1" | sed "s/MI300X/MI300X_VF/"); cp "$1" "$vf_config"' -- {} - -# Environment setup complete. -RUN echo "Environment setup complete." - -WORKDIR /workspace/ -########################################### -########################################### -########################################### - - - - - - -########################################### -###############vllm v0.8.5################# -########################################### -# ENV GITHUB_USERNAME=yushengsu-thu -# ENV GITHUB_MAIL=yushengsu@gmail.com - -# RUN git config --global user.name "${GITHUB_USERNAME}" \ -# && git config --global user.email "${GITHUB_MAIL}" - -WORKDIR /workspace/ - -ENV VLLM_TARGET_DEVICE=rocm -ENV ROCM_PATH=/opt/rocm -ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev - -# Find the repo path in: DockerFile/Dockerfile.rocm_yang -# RUN git clone https://github.com/RLFoundation/vllm-patch.git -RUN pip uninstall -y vllm || true -RUN rm -rf vllm-patch -RUN git clone https://github.com/RLFoundation/vllm-patch.git \ - && cd vllm-patch \ - && git checkout v0.8.5-sleep-numa \ - && rm -rf build/ dist/ *.egg-info \ - && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \ - && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py install - # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py develop - -WORKDIR /workspace/ -########################################### -########################################### -########################################### - - - - -######################################### -#### Install megatron-core############### -######################################### -RUN pip uninstall -y megatron-core && \ - git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git && \ - cd Megatron-LM-amd_version && \ - pip install -vvv -e . && \ - cd /workspace/ -######################################### -######################################### -######################################### - - - - -####################################### -################apex################### -####################################### -WORKDIR /workspace/ -RUN pip uninstall -y apex && \ - git clone https://github.com/ROCm/apex.git && \ - cd apex && \ - python setup.py install && \ - cd /workspace/ -####################################### -####################################### -####################################### - - - - -################################################################################ -###########################Add torch_memory_saver############################### -################################################################################ -# Set environment variables -ENV HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" -ENV CFLAGS="-D__HIP_PLATFORM_AMD__" -ENV CXXFLAGS="-D__HIP_PLATFORM_AMD__" -RUN pip install "git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa" -################################################################################ -################################################################################ -################################################################################ - - - -######################################## -######Install ray####################### -######################################## -# need to add this patch: https://github.com/ray-project/ray/pull/53531/files -RUN pip uninstall ray -y -RUN pip install "ray[data,train,tune,serve]>=2.47.0" -######################################## -######################################## -######################################## - - - -########################################## -#######Install other dependencies######### -########################################## -RUN pip install "tensordict==0.6.2" --no-deps && \ - pip install accelerate \ - codetiming \ - datasets \ - dill \ - hydra-core \ - liger-kernel \ - numpy \ - pandas \ - peft \ - "pyarrow>=15.0.0" \ - pylatexenc \ - torchdata \ - wandb \ - orjson \ - pybind11 - -WORKDIR /workspace/ -RUN git clone https://github.com/volcengine/verl.git && \ - cd verl && \ - pip install -e . -########################################## -########################################## -########################################## - - - -WORKDIR /workspace/ - -CMD ["/usr/bin/bash"] -CMD ["/usr/bin/bash"] diff --git a/docker/Dockerfile.sglang b/docker/Dockerfile.sglang deleted file mode 100644 index 11ad4a77d..000000000 --- a/docker/Dockerfile.sglang +++ /dev/null @@ -1,55 +0,0 @@ -# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10) -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html -FROM nvcr.io/nvidia/pytorch:24.08-py3 - -# Define environments -ENV MAX_JOBS=32 -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" - -# Define installation arguments -ARG APT_SOURCE=https://mirrors.ustc.edu.cn/ubuntu/ - -# Set apt source -RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ - { \ - echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ - } > /etc/apt/sources.list - -# Install systemctl -RUN apt-get update && \ - apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ - apt-get clean - -# Install tini -RUN apt-get update && \ - apt-get install -y tini && \ - apt-get clean - -# Change pip source -ARG PIP_INDEX=https://mirrors.aliyun.com/pypi/simple/ - -RUN pip config set global.index-url "${PIP_INDEX}" && \ - pip config set global.extra-index-url "${PIP_INDEX}" && \ - python -m pip install --upgrade pip - -# Install sglang-0.4.6.post5 and torch-memory-saver -RUN pip uninstall -y cuda-python && pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir - -# Install torch-2.6.0 -RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \ - transformers>=4.49.0 accelerate datasets peft hf_transfer \ - ray[default] codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb liger-kernel \ - pytest pre-commit py-spy pyext - -# Install flash_attn-2.7.4.post1 -RUN pip uninstall -y transformer-engine flash-attn && \ - wget -v https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ - pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - -# Fix cv2 -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6 diff --git a/docker/Dockerfile.vemlp.vllm.te b/docker/Dockerfile.vemlp.vllm.te deleted file mode 100644 index 361fb2084..000000000 --- a/docker/Dockerfile.vemlp.vllm.te +++ /dev/null @@ -1,41 +0,0 @@ -# docker buildx build --platform linux/x86_64 -t "verlai/verl:$TAG" -f docker/$FILE . - -# the one in docker.io is an alias for the one veturbo -# FROM vemlp-cn-beijing.cr.volces.com/veturbo/pytorch:2.4-cu124 -FROM docker.io/haibinlin/verl:v0.0.5-th2.4.0-cu124-base - -# only config pip index with https://pypi.tuna.tsinghua.edu.cn/simple if needed -# unset for now -RUN pip3 config unset global.index-url - -# transformers 4.47.0 contains the following bug: -# AttributeError: 'Gemma2Attention' object has no attribute '_flash_attn_uses_top_left_mask' -RUN pip3 install --no-cache-dir \ - torch==2.4.0 \ - accelerate \ - codetiming \ - dill \ - hydra-core \ - numpy \ - pybind11 \ - tensordict \ - "transformers <= 4.46.0" - -RUN pip3 install --no-cache-dir flash-attn==2.7.0.post2 --no-build-isolation - -# vllm depends on ray -RUN pip3 install --no-cache-dir vllm==0.6.3 ray==2.10 - -# install apex -RUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \ - --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" \ - git+https://github.com/NVIDIA/apex - -# install Transformer Engine -# - flash-attn pinned to 2.5.3 by TransformerEngine, switch to eric-haibin-lin/TransformerEngine.git@v1.7.0 to relax version req -# - install with: MAX_JOBS=1 NINJA_FLAGS="-j1" TE_BUILD_WITH_NINJA=0 to avoid OOM -# - cudnn is required by TransformerEngine -# RUN CUDNN_PATH=/opt/conda/lib/python3.11/site-packages/nvidia/cudnn \ -# pip3 install git+https://github.com/eric-haibin-lin/TransformerEngine.git@v1.7.0 -RUN MAX_JOBS=1 NINJA_FLAGS="-j1" pip3 install flash-attn==2.5.3 --no-cache-dir --no-build-isolation -RUN MAX_JOBS=1 NINJA_FLAGS="-j1" pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@v1.7 diff --git a/docker/Dockerfile.vllm.sglang.megatron.deepseek b/docker/Dockerfile.vllm.sglang.megatron.deepseek deleted file mode 100644 index 784537180..000000000 --- a/docker/Dockerfile.vllm.sglang.megatron.deepseek +++ /dev/null @@ -1,115 +0,0 @@ -# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html -FROM nvcr.io/nvidia/pytorch:24.08-py3 - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Define installation arguments -ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ -ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Set apt source -RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ - { \ - echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ - } > /etc/apt/sources.list - -# Install systemctl -RUN apt-get update && \ - apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ - apt-get clean - -# Install tini -RUN apt-get update && \ - apt-get install -y tini aria2 && \ - apt-get clean - -# Change pip source -RUN pip config set global.index-url "${PIP_INDEX}" && \ - pip config set global.extra-index-url "${PIP_INDEX}" && \ - python -m pip install --upgrade pip - -# Uninstall nv-pytorch fork -RUN pip uninstall -y torch torchvision torchaudio \ - pytorch-quantization pytorch-triton torch-tensorrt \ - xgboost transformer_engine flash_attn apex megatron-core grpcio - -# Reinstall CUDA 12.4 -RUN aria2c https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \ - mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 - -RUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ - dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ - cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ && \ - apt-get update && \ - apt-get -y install cuda-toolkit-12-4 && \ - rm cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ - update-alternatives --set cuda /usr/local/cuda-12.4 && \ - rm -rf /usr/local/cuda-12.6 - -# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post5 -# torch-2.6.0+cu124: cxx11abi=False -# torch-2.6.0+cu126: cxx11abi=True -# see https://github.com/flashinfer-ai/flashinfer/issues/911 -# Install sglang-0.4.6.post1 and torch-memory-saver -RUN pip install --resume-retries 999 "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install --resume-retries 999 torch-memory-saver --no-cache-dir - -RUN pip install --resume-retries 999 --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata - -RUN pip install --resume-retries 999 --no-cache-dir "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=15.0.0" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile \ - pytest py-spy pyext pre-commit ruff - -# Install flash-attn-2.7.4.post1 (cxx11abi=False) -RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ - pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - -# Fix packages -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -# Install cudnn -RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ - dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ - cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \ - apt-get update && \ - apt-get -y install cudnn-cuda-12 && \ - rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install Apex -RUN git clone https://github.com/NVIDIA/apex.git && \ - cd apex && \ - pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 - -# Fix opencv -RUN pip install opencv-python - -RUN pip install opencv-fixer && \ - python -c "from opencv_fixer import AutoFix; AutoFix()" - -# Install verl - -# Reset pip config -RUN pip config unset global.index-url && \ - pip config unset global.extra-index-url - - RUN apt-get update && \ - apt-get install -y aria2 libfreeimage3 libfreeimage-dev zlib1g \ No newline at end of file diff --git a/docker/Dockfile.ngc.vllm0.8 b/docker/Dockfile.ngc.vllm0.8 deleted file mode 100644 index ae585524d..000000000 --- a/docker/Dockfile.ngc.vllm0.8 +++ /dev/null @@ -1,59 +0,0 @@ -# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10) -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html -FROM nvcr.io/nvidia/pytorch:24.08-py3 - -# uninstall nv-pytorch fork -RUN pip3 uninstall -y pytorch-quantization \ - pytorch-triton torch torch-tensorrt torchvision \ - xgboost transformer_engine flash_attn apex megatron-core - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Define installation arguments -ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ -ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Set apt source -RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ - { \ - echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ - } > /etc/apt/sources.list - -# Install systemctl -RUN apt-get update && \ - apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ - apt-get clean - -# Install tini -RUN apt-get update && \ - apt-get install -y tini && \ - apt-get clean - -# Change pip source -RUN pip config set global.index-url "${PIP_INDEX}" && \ - pip config set global.extra-index-url "${PIP_INDEX}" && \ - python -m pip install --upgrade pip - -# Install torch-2.6.0 + vllm-0.8.1 -RUN pip install --no-cache-dir vllm==0.8.1 torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \ - transformers>=4.49.0 accelerate datasets peft hf-transfer \ - ray codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \ - pytest yapf py-spy pyext pre-commit ruff - -# Install flash_attn-2.7.4.post1 -RUN pip uninstall -y transformer-engine flash-attn && \ - wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ - pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - -# Fix cv2 -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --no-cache-dir nvidia-ml-py>=12.560.30 opencv-python-headless==4.8.0.74 fastapi==0.115.6 && \ - pip install -U optree>=0.13.0 diff --git a/docker/README.md b/docker/README.md deleted file mode 100644 index 1d19e8341..000000000 --- a/docker/README.md +++ /dev/null @@ -1,79 +0,0 @@ -# Dockerfiles of verl - -We provide pre-built Docker images for quick setup. And from this version, we utilize a new image release hierarchy for productivity and stability. - -The image types are divided into three large categories: - -- **Base Image**: Without inference and training frameworks, only basic dependencies are installed. Can directly install vllm or SGLang on top of it, without need of reinstall torch or CUDA. -- **Application Image**: Stable version with inference and training frameworks installed. -- **Preview Image**: Unstable version with the latest frameworks and features. - -The first two types of images are hosted on dockerhub [verlai/verl](https://hub.docker.com/r/verlai/verl) repository, while the preview images are hosted on community repository. - -> The image versions are mapped with verl releases, for example, image with tag ``verl0.4`` is built for verl release ``v0.4.x``. - -## Base Image - -The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``verl[version]-[packages]/Dockerfile.base``. - -The base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions. - -The update of base image is not frequent, and the app image can be built on top of it without reinstalling base packages. - -## Application Image - -From this version, we divide images built for vLLM and SGLang as the divergence of dependent packages like FlashInfer. - -There are four types of application images available: - -- **vLLM with FSDP and Megatron**: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1`` -- **SGLang with FSDP and Megatron**: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1`` -- **Preview version of SGLang with FSDP and Megatron, CUDA 12.6**: ``verlai/verl:app-verl0.5-sglang0.4.8-mcore0.12.1`` -- **Preview version of SGLang with FSDP and Megatron, CUDA 12.8**: ``verlai/verl:app-preview-verl0.5-sglang0.4.8-mcore0.12.1`` - -For Megatron 0.13.0, we offer preview images, to use latest codes, just replace ``mcore0.12.1`` with ``mcore0.13.0-preview`` in the above image tag. - -The latest vLLM support is coming soon. - -Docker images with Megatron backends are runnable with large language model like ``Qwen/Qwen3-235B-A22B``, ``deepseek-ai/DeepSeek-V3-0324`` post-training. Refer to the :doc:`Large Language Model Post-Training documentation<../perf/dpsk>` for more details. - -Application images can be updated frequently, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.app.[frameworks]``. Based on the base image, it is easy to build your own application image with the desired inference and training frameworks. - -## Community Image - -For vLLM with FSDP, please refer to [hiyouga/verl](https://hub.docker.com/r/hiyouga/verl) repository and the latest version is ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``. - -For SGLang with FSDP, please refer to [ocss884/verl-sglang](https://hub.docker.com/r/ocss884/verl-sglang) repository and the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group. - -See files under ``docker/`` for NGC-based image or if you want to build your own. - -Note that For aws instances with EFA net interface (Sagemaker AI Pod), you need to install EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa`` - -## Installation from Docker - -After pulling the desired Docker image and installing desired inference and training frameworks, you can run it with the following steps: - -1. Launch the desired Docker image and attach into it: - -```sh -docker create --runtime=nvidia --gpus all --net=host --shm-size="10g" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl sleep infinity -docker start verl -docker exec -it verl bash -``` - -2. If you use the images provided, you only need to install verl itself without dependencies: - -```sh -# install the nightly version (recommended) -git clone https://github.com/volcengine/verl && cd verl -pip3 install --no-deps -e . -``` - -[Optional] If you hope to switch between different frameworks, you can install verl with the following command: - -```sh -# install the nightly version (recommended) -git clone https://github.com/volcengine/verl && cd verl -pip3 install -e .[vllm] -pip3 install -e .[sglang] -``` diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12 b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12 deleted file mode 100644 index eaa12611e..000000000 --- a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12 +++ /dev/null @@ -1,41 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install sglang-0.4.6.post5 and torch-memory-saver -RUN pip install --resume-retries 999 "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir - -# Some sglang operations in 0.4.6.post5 require vllm -# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer -RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 - -# Fix for transformers 4.53.0 -RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep deleted file mode 100644 index dc6907610..000000000 --- a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.12.deepep +++ /dev/null @@ -1,82 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install sglang-0.4.6.post5 and torch-memory-saver -RUN pip install --resume-retries 999 "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir - -# Some sglang operations in 0.4.6.post5 require vllm -# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer -RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 - -# Fix for transformers 4.53.0 -RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge - -# Install DeepEP -## the dependency of IBGDA -RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so - -## Clone and build deepep and deepep-nvshmem -RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ - git clone https://github.com/deepseek-ai/DeepEP.git && \ - cd DeepEP && git checkout a84a248 - -# Prepare nvshmem -RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ - tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ - cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch - -ENV CUDA_HOME=/usr/local/cuda -### Set MPI environment variables. Having errors when not set. -ENV CPATH=/usr/local/mpi/include:$CPATH -ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH -ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH -ENV GDRCOPY_HOME=/workspace/gdrcopy - -## Build deepep-nvshmem -RUN cd deepep-nvshmem && \ - NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install - -ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install -ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH -ENV PATH=$NVSHMEM_DIR/bin:$PATH - -## Build deepep -RUN cd DeepEP && \ - python setup.py install \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview deleted file mode 100644 index 0e0bdd43f..000000000 --- a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.sglang.vllm.mcore0.13.preview +++ /dev/null @@ -1,82 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install sglang-0.4.6.post5 and torch-memory-saver -RUN pip install --resume-retries 999 "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir - -# Some sglang operations in 0.4.6.post5 require vllm -# [Warning] vllm can have some packages not compatible with sglang, for example, flashinfer -RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0 - -# Fix for transformers 4.53.0 -RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge - -# Install DeepEP -## the dependency of IBGDA -RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so - -## Clone and build deepep and deepep-nvshmem -RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ - git clone https://github.com/deepseek-ai/DeepEP.git && \ - cd DeepEP && git checkout a84a248 - -# Prepare nvshmem -RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ - tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ - cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch - -ENV CUDA_HOME=/usr/local/cuda -### Set MPI environment variables. Having errors when not set. -ENV CPATH=/usr/local/mpi/include:$CPATH -ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH -ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH -ENV GDRCOPY_HOME=/workspace/gdrcopy - -## Build deepep-nvshmem -RUN cd deepep-nvshmem && \ - NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install - -ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install -ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH -ENV PATH=$NVSHMEM_DIR/bin:$PATH - -## Build deepep -RUN cd DeepEP && \ - python setup.py install \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12 b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12 deleted file mode 100644 index fcf066eda..000000000 --- a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12 +++ /dev/null @@ -1,47 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install torch-2.6.0+cu126 + vllm-0.8.5.post1 -# torch-2.6.0+cu124: cxx11abi=False -# torch-2.6.0+cu126: cxx11abi=True -# see https://github.com/flashinfer-ai/flashinfer/issues/911 -RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 - -# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True) -# vllm-0.8.3 does not support flashinfer>=0.2.3 -# see https://github.com/vllm-project/vllm/pull/15777 -RUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 - -# Fix for transformers 4.53.0 -RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep deleted file mode 100644 index 61b4fdc6a..000000000 --- a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.12.deepep +++ /dev/null @@ -1,88 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install torch-2.6.0+cu126 + vllm-0.8.5.post1 -# torch-2.6.0+cu124: cxx11abi=False -# torch-2.6.0+cu126: cxx11abi=True -# see https://github.com/flashinfer-ai/flashinfer/issues/911 -RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 - -# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True) -# vllm-0.8.3 does not support flashinfer>=0.2.3 -# see https://github.com/vllm-project/vllm/pull/15777 -RUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 - -# Fix for transformers 4.53.0 -RUN pip3 install --no-cache-dir "transformers[hf_xet]<4.52.0" - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge - -# Install DeepEP -## the dependency of IBGDA -RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so - -## Clone and build deepep and deepep-nvshmem -RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ - git clone https://github.com/deepseek-ai/DeepEP.git && \ - cd DeepEP && git checkout a84a248 - -# Prepare nvshmem -RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ - tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ - cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch - -ENV CUDA_HOME=/usr/local/cuda -### Set MPI environment variables. Having errors when not set. -ENV CPATH=/usr/local/mpi/include:$CPATH -ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH -ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH -ENV GDRCOPY_HOME=/workspace/gdrcopy - -## Build deepep-nvshmem -RUN cd deepep-nvshmem && \ - NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install - -ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install -ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH -ENV PATH=$NVSHMEM_DIR/bin:$PATH - -## Build deepep -RUN cd DeepEP && \ - python setup.py install \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview deleted file mode 100644 index 1fba3fa86..000000000 --- a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.app.vllm.mcore0.13.preview +++ /dev/null @@ -1,85 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4 - -# Define environments -ENV MAX_JOBS=32 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install torch-2.6.0+cu126 + vllm-0.8.5.post1 -# torch-2.6.0+cu124: cxx11abi=False -# torch-2.6.0+cu126: cxx11abi=True -# see https://github.com/flashinfer-ai/flashinfer/issues/911 -RUN pip install --resume-retries 999 --no-cache-dir vllm==0.8.5.post1 - -# Install flashinfer-0.2.2.post1+cu126 (cxx11abi=True) -# vllm-0.8.3 does not support flashinfer>=0.2.3 -# see https://github.com/vllm-project/vllm/pull/15777 -RUN aria2c --max-tries=9999 https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - rm flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge - -# Install DeepEP -## the dependency of IBGDA -RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so - -## Clone and build deepep and deepep-nvshmem -RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ - git clone https://github.com/deepseek-ai/DeepEP.git && \ - cd DeepEP && git checkout a84a248 - -# Prepare nvshmem -RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ - tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ - cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch - -ENV CUDA_HOME=/usr/local/cuda -### Set MPI environment variables. Having errors when not set. -ENV CPATH=/usr/local/mpi/include:$CPATH -ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH -ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH -ENV GDRCOPY_HOME=/workspace/gdrcopy - -## Build deepep-nvshmem -RUN cd deepep-nvshmem && \ - NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install - -ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install -ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH -ENV PATH=$NVSHMEM_DIR/bin:$PATH - -## Build deepep -RUN cd DeepEP && \ - python setup.py install \ No newline at end of file diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base b/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base deleted file mode 100644 index 25b1d9431..000000000 --- a/docker/verl0.4-cu124-torch2.6-fa2.7.4/Dockerfile.base +++ /dev/null @@ -1,113 +0,0 @@ -# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks -# Target: verlai/verl:base-v2-cu124-cudnn9.8-torch2.6-fa2.8.0-te2.3 -# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html -FROM nvcr.io/nvidia/pytorch:24.08-py3 - -# Define environments -ENV MAX_JOBS=16 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Define installation arguments -ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ -ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Set apt source -RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ - { \ - echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ - } > /etc/apt/sources.list - -# Install systemctl -RUN apt-get update && \ - apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ - apt-get clean - -# Install tini -RUN apt-get update && \ - apt-get install -y tini aria2 && \ - apt-get clean - -# Change pip source -RUN pip config set global.index-url "${PIP_INDEX}" && \ - pip config set global.extra-index-url "${PIP_INDEX}" && \ - python -m pip install --upgrade pip - -# Uninstall nv-pytorch fork -RUN pip uninstall -y torch torchvision torchaudio \ - pytorch-quantization pytorch-triton torch-tensorrt \ - xgboost transformer_engine flash_attn apex megatron-core grpcio - -# Reinstall CUDA 12.4 -RUN aria2c https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin && \ - mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 - -RUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ - dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ - cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ && \ - apt-get update && \ - apt-get -y install cuda-toolkit-12-4 && \ - rm cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb && \ - update-alternatives --set cuda /usr/local/cuda-12.4 && \ - rm -rf /usr/local/cuda-12.6 - -RUN pip install --resume-retries 999 --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 - -RUN pip install --resume-retries 999 --no-cache-dir "tensordict==0.6.2" torchdata "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -# Install flash-attn-2.7.4.post1 (cxx11abi=False) -RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ - pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - -# Fix packages -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -# Install cudnn -RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ - dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ - cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \ - apt-get update && \ - apt-get -y install cudnn-cuda-12 && \ - rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb - -# Install Apex -RUN git clone https://github.com/NVIDIA/apex.git && \ - cd apex && \ - pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ - -# Profiling tools -RUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ - apt-get update && apt-get install -y libxcb-cursor0 && \ - dpkg -i ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ - rm -rf /usr/local/cuda/bin/nsys && \ - ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys /usr/local/cuda/bin/nsys && \ - rm -rf /usr/local/cuda/bin/nsys-ui && \ - ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \ - rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb - -# Fix opencv -RUN pip install --resume-retries 999 --no-cache-dir opencv-python - -RUN pip install --resume-retries 999 --no-cache-dir opencv-fixer && \ - python -c "from opencv_fixer import AutoFix; AutoFix()" - -RUN pip install --resume-retries 999 --no-cache-dir cuda-bindings - -# Reset pip config -RUN pip config unset global.index-url && \ - pip config unset global.extra-index-url - -RUN apt-get update && \ - apt-get install -y libfreeimage3 libfreeimage-dev zlib1g htop - diff --git a/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md b/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md deleted file mode 100644 index 6f77fee6a..000000000 --- a/docker/verl0.4-cu124-torch2.6-fa2.7.4/README.md +++ /dev/null @@ -1,31 +0,0 @@ -# verl image with verl v0.4.x - -## Important packages version - -```txt -cuda==12.4 -cudnn==9.8.0 -torch==2.6.0 -flash_attn=2.7.4 -sglang==0.4.6.post5 -vllm==0.8.5.post1 -vidia-cudnn-cu12==9.8.0.87 -transformer_engine==2.3 -megatron.core==core_v0.12.1 -# Preview -transformer_engine==2.5 -megatron.core==core_r0.13.0 -``` - -## Target - -- Base image: - - `verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4` -- App image: - - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1`: SGLang requires vLLM in 0.4.6.post5 version, vLLM can have some package conflicts with SGLang - - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1-deepep`: Built with deepep - - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1` - - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1-deepep`: Built with deepep -- Preview image: - - `verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.13.0-preview` - - `verlai/verl:app-verl0.4-vllm0.8.5-mcore0.13.0-preview` \ No newline at end of file diff --git a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12 b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12 deleted file mode 100644 index 07435b31c..000000000 --- a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.12 +++ /dev/null @@ -1,37 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0 - -# Define environments -ENV MAX_JOBS=8 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install sglang-0.4.8 and torch-memory-saver -# Install FlashInfer Python package -RUN pip install --upgrade pip setuptools packaging -RUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1 -RUN pip install --resume-retries 999 --no-cache-dir "sglang[all]==0.4.8" && pip install torch-memory-saver --no-cache-dir - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.3 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview deleted file mode 100644 index 24b831508..000000000 --- a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.mcore0.13.preview +++ /dev/null @@ -1,37 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0 - -# Define environments -ENV MAX_JOBS=8 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install sglang-0.4.8 and torch-memory-saver -# Install FlashInfer Python package -RUN pip install --upgrade pip setuptools packaging -RUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1 -RUN pip install --resume-retries 999 --no-cache-dir "sglang[all]==0.4.8" && pip install torch-memory-saver --no-cache-dir - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.1 - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base deleted file mode 100644 index 915834a0d..000000000 --- a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/Dockerfile.base +++ /dev/null @@ -1,132 +0,0 @@ -# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks -# Target: verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6 -# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html -FROM nvcr.io/nvidia/pytorch:24.08-py3 - -# Define environments -ENV MAX_JOBS=16 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Define installation arguments -ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ -ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Set apt source -RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ - { \ - echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ - } > /etc/apt/sources.list - -# Install systemctl -RUN apt-get update && \ - apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ - apt-get clean - -# Install tini -RUN apt-get update && \ - apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \ - apt-get clean - -# Change pip source -RUN pip config set global.index-url "${PIP_INDEX}" && \ - pip config set global.extra-index-url "${PIP_INDEX}" && \ - python -m pip install --upgrade pip - -# Uninstall nv-pytorch fork -RUN pip uninstall -y torch torchvision torchaudio \ - pytorch-quantization pytorch-triton torch-tensorrt \ - xgboost transformer_engine flash_attn apex megatron-core grpcio - -RUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 - -# Install flash-attn-2.8.0.post2 (cxx11abi=True) -RUN ABI_FLAG=$(python -c "import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')") && \ - URL="https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl" && \ - FILE="flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp310-cp310-linux_x86_64.whl" && \ - wget -nv "${URL}" && \ - pip install --no-cache-dir "${FILE}" - -# Fix packages -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -# Install cudnn -RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ - dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ - cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \ - apt-get update && \ - apt-get -y install cudnn-cuda-12 && \ - rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb - -# Install Apex -RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" --resume-retries 999 git+https://github.com/NVIDIA/apex.git - -# Profiling tools -RUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ - apt-get update && apt-get install -y libxcb-cursor0 - -RUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ - rm -rf /usr/local/cuda/bin/nsys && \ - ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys /usr/local/cuda/bin/nsys && \ - rm -rf /usr/local/cuda/bin/nsys-ui && \ - ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \ - rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb - -RUN pip install --resume-retries 999 --no-cache-dir "tensordict==0.6.2" torchdata "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas cuda-bindings \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pyext pre-commit ruff - -# Install DeepEP -## the dependency of IBGDA -RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so - -## Clone and build deepep and deepep-nvshmem -RUN git clone -b v2.3.1 https://github.com/NVIDIA/gdrcopy.git && \ - git clone https://github.com/deepseek-ai/DeepEP.git && \ - cd DeepEP && git checkout a84a248 - -# Prepare nvshmem -RUN wget https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_3.2.5-1.txz && \ - tar -xvf nvshmem_src_3.2.5-1.txz && mv nvshmem_src deepep-nvshmem && \ - cd deepep-nvshmem && git apply ../DeepEP/third-party/nvshmem.patch - -ENV CUDA_HOME=/usr/local/cuda -### Set MPI environment variables. Having errors when not set. -ENV CPATH=/usr/local/mpi/include:$CPATH -ENV LD_LIBRARY_PATH=/usr/local/mpi/lib:$LD_LIBRARY_PATH -ENV LD_LIBRARY_PATH=/usr/local/x86_64-linux-gnu:$LD_LIBRARY_PATH -ENV GDRCOPY_HOME=/workspace/gdrcopy - -## Build deepep-nvshmem -RUN cd deepep-nvshmem && \ - NVSHMEM_SHMEM_SUPPORT=0 \ - NVSHMEM_UCX_SUPPORT=0 \ - NVSHMEM_USE_NCCL=0 \ - NVSHMEM_MPI_SUPPORT=0 \ - NVSHMEM_IBGDA_SUPPORT=1 \ - NVSHMEM_PMIX_SUPPORT=0 \ - NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ - NVSHMEM_USE_GDRCOPY=1 \ - cmake -G Ninja -S . -B build/ -DCMAKE_INSTALL_PREFIX=/workspace/deepep-nvshmem/install && cmake --build build/ --target install - -ENV NVSHMEM_DIR=/workspace/deepep-nvshmem/install -ENV LD_LIBRARY_PATH=$NVSHMEM_DIR/lib:$LD_LIBRARY_PATH -ENV PATH=$NVSHMEM_DIR/bin:$PATH - -## Build deepep -RUN cd DeepEP && \ - python setup.py install - -# Reset pip config -RUN pip config unset global.index-url && \ - pip config unset global.extra-index-url - diff --git a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md b/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md deleted file mode 100644 index c29a7f1f7..000000000 --- a/docker/verl0.5-cu126-torch2.7.1-fa2.8.0/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# verl image with verl v0.5 - -## Important packages version - -```txt -cuda==12.6 -cudnn==9.8.0 -torch==2.7.1 -flash_attn=2.8.0 ## -sglang==0.4.8 -vllm==0.8.5.post1 -vidia-cudnn-cu12==9.8.0.87 -transformer_engine==2.3 -megatron.core==core_v0.12.1 -# Preview -transformer_engine==2.5 -megatron.core==core_r0.13.0 -``` - -## Target - -- Base image: - - `verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0`: We offer a base image with deep ep built in -- App image: - - `verlai/verl:app-verl0.5-sglang0.4.9-mcore0.12.1` - - `verlai/verl:app-verl0.5-sglang0.4.9-mcore0.13.0-preview` -- vllm temporarily not support latest version \ No newline at end of file diff --git a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron deleted file mode 100644 index d41ea19d6..000000000 --- a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.app.sglang.megatron +++ /dev/null @@ -1,36 +0,0 @@ -# Start from the verl base image -# Dockerfile.base -FROM verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6 - -# Define environments -ENV MAX_JOBS=8 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Install sglang-0.4.8 and torch-memory-saver -# Install FlashInfer Python package -RUN pip install --resume-retries 999 --no-cache-dir --no-build-isolation flashinfer-python==0.2.6.post1 -RUN pip install --resume-retries 999 --no-cache-dir "sglang[all]==0.4.8" && pip install torch-memory-saver --no-cache-dir - -# Fix packages -RUN pip install --no-cache-dir "tensordict==0.6.2" "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pre-commit ruff - -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --resume-retries 999 --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -RUN pip install --resume-retries 999 --no-cache-dir nvidia-cudnn-cu12==9.8.0.87 - -# Install TransformerEngine -RUN export NVTE_FRAMEWORK=pytorch && pip3 install --resume-retries 999 --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5 - -# Install Megatron-LM -RUN pip3 install --no-deps --no-cache-dir --no-build-isolation git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.13.0 - -# Install mbridge -RUN pip3 install --no-cache-dir mbridge \ No newline at end of file diff --git a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base deleted file mode 100644 index 29c49faa8..000000000 --- a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/Dockerfile.base +++ /dev/null @@ -1,91 +0,0 @@ -# Base Docker Image of verl, with CUDA/Torch/FlashAttn/Apex/TransformerEngine, without other frameworks -# Target: verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0-fi0.2.6 -# Start from the NVIDIA official image (ubuntu-22.04 + cuda-12.6 + python-3.10) -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html -FROM nvcr.io/nvidia/pytorch:25.02-py3 - -# Define environments -ENV MAX_JOBS=16 -ENV VLLM_WORKER_MULTIPROC_METHOD=spawn -ENV DEBIAN_FRONTEND=noninteractive -ENV NODE_OPTIONS="" -ENV PIP_ROOT_USER_ACTION=ignore -ENV HF_HUB_ENABLE_HF_TRANSFER="1" - -# Define installation arguments -ARG APT_SOURCE=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ -ARG PIP_INDEX=https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple - -# Set apt source -RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \ - { \ - echo "deb ${APT_SOURCE} jammy main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-updates main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-backports main restricted universe multiverse"; \ - echo "deb ${APT_SOURCE} jammy-security main restricted universe multiverse"; \ - } > /etc/apt/sources.list - -# Install systemctl -RUN apt-get update && \ - apt-get install -y -o Dpkg::Options::="--force-confdef" systemd && \ - apt-get clean - -# Install tini -RUN apt-get update && \ - apt-get install -y tini aria2 libfreeimage3 libfreeimage-dev zlib1g htop && \ - apt-get clean - -# Change pip source -RUN pip config set global.index-url "${PIP_INDEX}" && \ - pip config set global.extra-index-url "${PIP_INDEX}" && \ - python -m pip install --upgrade pip - -# Uninstall nv-pytorch fork -RUN pip uninstall -y torch torchvision torchaudio \ - pytorch-quantization pytorch-triton torch-tensorrt \ - xgboost transformer_engine flash_attn apex megatron-core grpcio - -RUN pip install --resume-retries 999 --no-cache-dir torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 --index-url https://download.pytorch.org/whl/cu128 - -# Install flash-attn-2.8.0.post2 (cxx11abi=True) -RUN ABI_FLAG=$(python -c "import torch; print('TRUE' if torch._C._GLIBCXX_USE_CXX11_ABI else 'FALSE')") && \ - URL="https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp312-cp312-linux_x86_64.whl" && \ - FILE="flash_attn-2.8.0.post2+cu12torch2.7cxx11abi${ABI_FLAG}-cp312-cp312-linux_x86_64.whl" && \ - wget -nv "${URL}" && \ - pip install --no-cache-dir "${FILE}" - -# Fix packages -RUN pip uninstall -y pynvml nvidia-ml-py && \ - pip install --no-cache-dir --upgrade "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" - -# Install cudnn -RUN aria2c --max-tries=9999 https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ - dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb && \ - cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ && \ - apt-get update && \ - apt-get -y install cudnn-cuda-12 && \ - rm cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb - -# Install Apex -RUN pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" --resume-retries 999 git+https://github.com/NVIDIA/apex.git - -# Profiling tools -RUN aria2c --always-resume=true --max-tries=99999 https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_3/nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ - apt-get update && apt-get install -y libxcb-cursor0 - -RUN apt-get install -y ./nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb && \ - rm -rf /usr/local/cuda/bin/nsys && \ - ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys /usr/local/cuda/bin/nsys && \ - rm -rf /usr/local/cuda/bin/nsys-ui && \ - ln -s /opt/nvidia/nsight-systems/2025.3.1/target-linux-x64/nsys-ui /usr/local/cuda/bin/nsys-ui && \ - rm nsight-systems-2025.3.1_2025.3.1.90-1_amd64.deb - -RUN pip install --resume-retries 999 --no-cache-dir "tensordict==0.6.2" torchdata "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=19.0.1" pandas cuda-bindings \ - ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler blobfile xgrammar \ - pytest py-spy pre-commit ruff - -# Reset pip config -RUN pip config unset global.index-url && \ - pip config unset global.extra-index-url - diff --git a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md b/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md deleted file mode 100644 index 07d68977f..000000000 --- a/docker/verl0.5-preview-cu128-torch2.7.1-fa2.8.0/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# verl image with verl v0.5 - -## Important packages version - -```txt -cuda==12.8 -cudnn==9.8.0 -torch==2.7.1 -flash_attn=2.8.0 ## -sglang==0.4.8 -transformer_engine==2.5 -megatron.core==core_r0.13.0 -vidia-cudnn-cu12==9.8.0.87 -``` - -## Target - -- Base image: - - `verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`: We offer a base image with flash infer 0.2.6.post1 built in -- App image: - - `verlai/verl:app-verl0.5-preview-sglang0.4.8-mcore0.13.0-preview` -- vllm temporarily not support latest version - -## !!!Notice!!! - -- pyext is lack of maintainace and cannot work with python 3.12, consider using replacement and deprecating this package. \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index 8bda904a9..000000000 --- a/docs/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -SPHINXPROJ = verl -SOURCEDIR = . -BUILDDIR = _build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/README.md b/docs/README.md deleted file mode 100644 index 8c5db0487..000000000 --- a/docs/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# verl documentations - -## Build the docs - -```bash -# If you want to view auto-generated API docstring, please make sure verl is available in python path. For instance, install verl via: -# pip install .. -e[test] - -# Install dependencies needed for building docs. -pip install -r requirements-docs.txt - -# Build the docs. -make clean -make html -``` - -## Open the docs with your browser - -```bash -python -m http.server -d _build/html/ -``` -Launch your browser and navigate to http://localhost:8000 to view the documentation. Alternatively you could drag the file `_build/html/index.html` to your local browser and view directly. diff --git a/docs/README_vllm0.7.md b/docs/README_vllm0.7.md deleted file mode 100644 index e84feddd7..000000000 --- a/docs/README_vllm0.7.md +++ /dev/null @@ -1,73 +0,0 @@ -# Upgrading to vllm >= 0.7 - -Note: verl+vllm 0.8.3 is now stable. Please see ``docs/README_vllm0.8.md`` for upgrade guide. - -## Installation - -Note: At time of writing, verl+vllm 0.7.x supports **FSDP** for training and **vLLM** for rollout. - -``` -# Create the conda environment -conda create -n verl python==3.10 -conda activate verl - -# Install verl -git clone https://github.com/volcengine/verl.git -cd verl -pip3 install -e . - -# Install the latest stable version of vLLM -pip3 install vllm==0.7.3 - -# Install flash-attn -pip3 install flash-attn --no-build-isolation - -``` - -Note that if you are installing lower versions of vLLM (0.7.0, 0.7.1, 0.7.2), you need to make some tiny patches manually on vllm (/path/to/site-packages/vllm after installation) after the above steps: - -- vllm/distributed/parallel_state.py: Remove the assertion below: - -``` -if (world_size - != tensor_model_parallel_size * pipeline_model_parallel_size): - raise RuntimeError( - f"world_size ({world_size}) is not equal to " - f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") - -``` - -- vllm/executor/uniproc_executor.py: change `local_rank = rank` to `local_rank = int(os.environ["LOCAL_RANK"])` -- vllm/model_executor/model_loader/weight_utils.py: remove the `torch.cuda.empty_cache()` in `pt_weights_iterator` - -## Features - -### Use cuda graph - -After installation, examples using FSDP as training backends can be used. By default, the `enforce_eager` is set to True, which disables the cuda graph. To enjoy cuda graphs and the sleep mode of vLLM>=0.7, add the following lines to the bash script: - -``` -actor_rollout_ref.rollout.enforce_eager=False \ -actor_rollout_ref.rollout.free_cache_engine=True \ - -``` - -For a typical job like examples/ppo_trainer/run_qwen2-7b_seq_balance.sh, the rollout generation time is 85 seconds with vLLM0.7.0. By enabling the cudagraph, the generation duration is further reduced to 62 seconds. - -**Note:** Currently, if the `n` is greater than 1 in `SamplingParams` in vLLM>=0.7, there is a potential performance issue on the stability of rollout generation time (Some iterations would see generation time bursts) using vLLM's V0 Engine. - -### Use vLLM V1 Engine - -Using the vLLM V1 engine can avoid instability issues and achieve additional performance improvements. To use the V1 engine, you can first uninstall the previously installed vLLM and then follow the steps below to install the newer version. - -``` -git clone https://github.com/vllm-project/vllm.git -cd vllm -git checkout 2275784 -sed -i "903a\ data_parallel_size = world_size // pipeline_model_parallel_size // tensor_model_parallel_size" ./vllm/distributed/parallel_state.py -VLLM_USE_PRECOMPILED=1 pip install --editable . -``` - -Then you can enable the V1 engine by setting `export VLLM_USE_V1=1`. In some benchmark tests, the V1 engine demonstrates a 1.5x speed improvement over the vLLM V0 engine. -The stable support of the vLLM V1 engine is available on verl main. diff --git a/docs/README_vllm0.8.md b/docs/README_vllm0.8.md deleted file mode 100644 index d4f509f19..000000000 --- a/docs/README_vllm0.8.md +++ /dev/null @@ -1,52 +0,0 @@ -# Upgrading to vLLM >= 0.8 - -Last updated: 05/04/2025. - -## Installation - -Note: This version of verl+vLLM 0.8+ supports **FSDP** for training and **vLLM** for rollout. - -```bash -# Create the conda environment -conda create -n verl python==3.10 -conda activate verl - -# Install verl -git clone https://github.com/volcengine/verl.git -cd verl -pip3 install -e . - -# Install the latest stable version of vLLM -pip3 install vllm==0.8.3 - -# Install flash-attn -pip3 install flash-attn --no-build-isolation - -``` - -We have a pre-built docker image for verl+vLLM 0.8.3. You can direct import it with the following command: - -```bash -docker pull hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0 -``` - -## Features - -vLLM 0.8+ supports cuda graph and V1 engine by default in verl. To enable these features, remember to add the following lines to the bash script: - -```bash -actor_rollout_ref.rollout.enforce_eager=False \ -actor_rollout_ref.rollout.free_cache_engine=True \ -``` - -and also **remove** the environment variable if it exists: - -## Notes - -When you just directly upgrade vllm>=0.8, some dependency packages may undergo version changes. If you encounter the following problems: - -```bash -in from torch.multiprocessing.reductions import ForkingPickler ImportError: cannot import name 'ForkingPickler' from 'torch.multiprocessing.reductions' (/opt/conda/lib/python3.11/site-packages/torch/multiprocessing/reductions.py) -``` - -You need to upgrade `tensordict` to version 0.6.2 using the command `pip install tensordict==0.6.2`. diff --git a/docs/_static/js/runllm-widget.js b/docs/_static/js/runllm-widget.js deleted file mode 100644 index bec345cac..000000000 --- a/docs/_static/js/runllm-widget.js +++ /dev/null @@ -1,14 +0,0 @@ -document.addEventListener("DOMContentLoaded", function () { - var script = document.createElement("script"); - script.type = "module"; - script.id = "runllm-widget-script"; - script.src = "https://widget.runllm.com"; - script.setAttribute("version", "stable"); - script.setAttribute("crossorigin", "true"); - script.setAttribute("runllm-keyboard-shortcut", "Mod+j"); - script.setAttribute("runllm-name", "verl Chatbot"); - script.setAttribute("runllm-position", "TOP_RIGHT"); - script.setAttribute("runllm-assistant-id", "679"); - script.async = true; - document.head.appendChild(script); - }); \ No newline at end of file diff --git a/docs/_static/logo.png b/docs/_static/logo.png deleted file mode 100644 index 6a3b61308..000000000 Binary files a/docs/_static/logo.png and /dev/null differ diff --git a/docs/advance/checkpoint.rst b/docs/advance/checkpoint.rst deleted file mode 100644 index 56bec4a75..000000000 --- a/docs/advance/checkpoint.rst +++ /dev/null @@ -1,183 +0,0 @@ -.. _checkpoint-page: - -Using Checkpoints to Support Fault Tolerance Training -===================================================== - -Last updated: 06/25/2025. - -There could be training errors or machine failure during the whole RLHF training process, -so it is recommended to enable checkpoints to minimize your loss. - -The API Interface has already been listed in :ref:`config-explain-page`, -and we will not repeat them. But there are still some technique details -we hope to clarify. - -.. note:: - - Notice that the ``checkpoint.contents`` field has no effect to FSDP checkpoint except ``hf_model``, - the other 3 fields are binded together to save and load. We recommend to include ``model``, ``optimizer`` and ``extra`` all. - -Checkpoint Saving Directory Structure -------------------------------------- - -Commonly, we use the ``default_local_dir`` declared in ``ppo_trainer.yaml`` or ``ppo_megatron_trainer.yml`` -to work as preffix when saving checkpoints, which is ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``. - -So the inner checkpoint structure of **FSDP** is like: - -.. code:: - - checkpoints/${trainer.project_name}/${trainer.experiment_name} - ├── global_steps_${i} - │ ├── actor - │ │ ├── huggingface # default save config and tokenizer, save huggingface model if include ``hf_model`` in checkpoint.contents - │ │ └── fsdp_config.json # FSDP config file, including world_size and fsdp version - │ │ ├── model_world_size_{self.world_size}_rank_{self.rank}.pt - │ │ ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt - │ │ └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt - │ ├── critic - │ │ ├── huggingface - │ │ └── fsdp_config.json - │ │ ├── model_world_size_{self.world_size}_rank_{self.rank}.pt - │ │ ├── optim_world_size_{self.world_size}_rank_{self.rank}.pt - │ │ └── extra_state_world_size_{self.world_size}_rank_{self.rank}.pt - └── latest_checkpointed_iteration.txt - -All model shards, optimizers and extra states are stored together, in a sharded and distributed way. - -While **Megatron** current checkpoint structure is: - -.. code:: - - checkpoints/${trainer.project_name}/${trainer.experiment_name} - ├── global_steps_${i} - │ ├── actor - │ │ ├── huggingface # default save config and tokenizer, save huggingface model if include ``hf_mode`` in checkpoint.contents - │ │ └── dist_ckpt # save sharded model/optimizer/rng_states, naming the same as Megatron - │ └── critic - │ │ ├── huggingface - │ │ └── dist_ckpt - └── latest_checkpointed_iteration.txt - -Convert FSDP and Megatron Checkpoints to HuggingFace Format Model ------------------------------------------------------------------ - -We provide a tool to convert the FSDP and Megatron checkpoints to HuggingFace format model. -The tool is located in ``verl/model_merger``. For older versions of verl that don't include fsdp_config.json in checkpoints, you can use the legacy model merger located at ``verl/scripts/legacy_model_merger.py``. - -The script supports two main sub-commands: `merge` (to convert and save checkpoints) and `test` (to validate merged checkpoints against a reference model). -The arguments for the `merge` sub-command are as follows: - -.. code:: bash - - usage: python -m verl.model_merger merge [-h] --backend {fsdp,megatron} [--local_dir LOCAL_DIR] [--tie-word-embedding] [--is-value-model] [--use_cpu_initialization] [--target_dir TARGET_DIR] - [--hf_upload_path HF_UPLOAD_PATH] [--private] - - options: - -h, --help show this help message and exit - --backend {fsdp,megatron} - The backend of the model - --local_dir LOCAL_DIR - Path to the saved model checkpoints - --tie-word-embedding Whether to tie word embedding weights (currently only Megatron supported) - --is-value-model Whether the model is a value model (currently only Megatron supported) - --use_cpu_initialization - Whether to use CPU initialization for the model. This is useful for large models that cannot fit into GPU memory during initialization. - --target_dir TARGET_DIR - Directory to save the merged huggingface model - --hf_upload_path HF_UPLOAD_PATH - Hugging Face repository ID to upload the model - --private Whether to upload the model to a private Hugging Face repository - -Example usage for merging Megatron checkpoints: - -.. code:: bash - - python -m verl.model_merger merge \ - --backend megatron \ - --tie-word-embedding \ - --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model - -Example usage for distributed merging Megatron checkpoints: - -.. code:: bash - - torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge \ - --backend megatron \ - --tie-word-embedding \ - --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model - -Example usage for merging FSDP checkpoints: - -.. code:: bash - - python -m verl.model_merger merge \ - --backend fsdp \ - --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model - - -Megatron Merger details ------------------------ - -Current implement of decoder layers uses ``nn.ModuleList`` to store the layers, -and thus the model layers on every PP rank and VPP rank starts their index from 0. - -There are 3 ways to correct this behavior: - -1. Modify the decoder layer's state_dict, add ``offset`` to each layer's index, thus rewrite ``nn.ModuleList`` implementation. -2. Modify the layer index when saving checkpoint and recover them when loading checkpoint. -3. The Checkpoint merger do this work, calculate the actual ``offset`` from ``state_dict`` only, a little complex. - -Current implementation use solution 2. - - -HuggingFace to Megatron DistCheckpoint details ----------------------------------------------- - -If your model is quite huge, we recommend you to use Megatron dist-checkpoint to load the model. -Megatron dist-checkpoint supports loading with different kinds of model parallelism, -and it is much faster than the original checkpoint loading. - -To convert original HuggingFace model to Megatron dist-checkpoint, -you can use the ``scripts/converter_hf_to_mcore.py`` script. Large MoE models are temporarily supported with CPU initialization, -which is a little slower. While we are working on a better solution to support large models. - -Example command to convert the model is as follows: - -.. code:: bash - - python scripts/converter_hf_to_mcore.py \ - --hf_model_path Qwen/Qwen1.5-MoE-A2.7B-Chat \ - --output_path /mnt/disk/Qwen/Qwen1.5-MoE-A2.7B-Chat \ - --use_cpu_initialization # Only work for MoE models - - -Example command to distributed convert the huge model like deepseekv3 671B is as follows: - -.. code:: bash - - torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} scripts/converter_hf_to_mcore.py \ - --hf_model_path deepseek-ai/DeepSeek-V3 \ - --output_path /mnt/disk/deepseek-ai/DeepSeek-V3 \ - --use_cpu_initialization # Only work for MoE models - -Original Checkpoint Utils -------------------------- - -Original Checkpoint Utils refer to original checkpoint implementation in ``verl/models/[model]/megatron/checkpoint_utils``. - -We only need ``[model]_loader.py`` in original checkpoint utils now, since we get rid of storing ``hf_model`` every time (which is not recommended for large model training, try only saving sharded models if you can). - -.. note:: - - Note that ``[model]_loader`` only support environments where **storage clusters are able to connect with every calculation nodes**. - Because it utilizes **sharded load way to minimize the loading checkpoint overhead**. - Every rank loads its own data from ``state_dict`` which can be accessed by all of them. - While there is also no need to broadcast among DP ranks, since the saved state_dict is only produced by DP rank 0. - - For users who can **only place the huggingface model on one device**, we keep the original costly implementation in ``[model]_loader_deprecated``. In this implementation, rank 0 broadcast all weights to each tp and pp rank, and then dp rank 0 broadcast to all dp ranks. There may be at risks of OOM. - - To use deprecated loader, change the import package of ``load_state_dict_to_megatron_llama``. diff --git a/docs/advance/dpo_extension.rst b/docs/advance/dpo_extension.rst deleted file mode 100644 index ee9ac619d..000000000 --- a/docs/advance/dpo_extension.rst +++ /dev/null @@ -1,273 +0,0 @@ -Extend to other RL(HF) algorithms -================================= - -Last updated: 02/25/2025. - -We already implemented the complete training pipeline of the PPO -algorithms. To extend to other algorithms, we analyze the high-level -principle to use verl and provide a tutorial to implement the DPO -algorithm. Users can follow the similar paradigm to extend to other RL algorithms. - -.. note:: **Key ideas**: Single process drives multi-process computation and data communication. - -Overall Approach ----------------- - -Step 1: Consider what multi-machine multi-GPU computations are needed -for each model, such as ``generate_sequence`` , ``compute_log_prob`` and -``update_policy`` in the actor_rollout model. Implement distributed -single-process-multiple-data (SPMD) computation and encapsulate them -into APIs - -Step 2: Based on different distributed scenarios, including FSDP and 3D -parallelism in Megatron-LM, implement single-process control of data -interaction among multi-process computations. - -Step 3: Utilize the encapsulated APIs to implement the control flow - -Example: Online DPO -------------------- - -We use verl to implement a simple online DPO algorithm. The algorithm -flow of Online DPO is as follows: - -1. There is a prompt (rollout) generator which has the same weight as - the actor model. After a batch of prompts are fed into the generator, - it generates N responses for each prompt. -2. Send all the prompts + responses to a verifier for scoring, which can - be reward model or a rule-based function. Then sort them in pairs to - form a training batch. -3. Use this training batch to train the actor model using DPO. During - the process, a reference policy is needed. - -Step 1: What are the multi-machine multi-GPU computations -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -**Sample Generator** - -Implementation details: - -.. code:: python - - from verl.single_controller.base import Worker - from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool - import ray - - @ray.remote - class SampleGenerator(Worker): - def __init__(self, config): - super().__init__() - self.config = config - - def generate_sequences(self, data): - pass - -Here, ``SampleGenerator`` can be viewed as a multi-process pulled up by -``torchrun``, with each process running the same code (SPMD). -``SampleGenerator`` needs to implement a ``generate_sequences`` API for -the control flow to call. The implementation details inside can use any -inference engine including vllm, sglang and huggingface. Users can -largely reuse the code in -verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py and we won't -go into details here. - -**ReferencePolicy inference** - -API: compute reference log probability - -.. code:: python - - from verl.single_controller.base import Worker - import ray - - @ray.remote - class ReferencePolicy(Worker): - def __init__(self): - super().__init__() - self.model = Model() - - def infer(self, data): - return self.model(data) - -**Actor update** - -API: Update actor model parameters - -.. code:: python - - from verl.single_controller.base import Worker - import ray - - @ray.remote - class DPOActor(Worker): - def __init__(self): - super().__init__() - self.model = Model() - self.model = FSDP(self.model) # or other distributed strategy - self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3) - self.loss_fn = xxx - - def update(self, data): - self.optimizer.zero_grad() - logits = self.model(data) - loss = self.loss_fn(logits) - loss.backward() - self.optimizer.step() - -**Notes: How to distinguish between control processes and distributed computation processes** -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -- Control processes are generally functions directly decorated with - ``@ray.remote`` -- Computation processes are all wrapped into a ``RayWorkerGroup``. - -Users can reuse most of the distribtued computation logics implemented -in PPO algorithm, including FSDP and Megatron-LM backend in -verl/verl/trainer/ppo. - -Step 2: Based on different distributed scenarios, implement single-process control of multi-process data interaction -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -**The core problem to solve here is how a single process sends data to -multiple processes, drives multi-process computation, and how the -control process obtains the results of multi-process computation.** -First, we initialize the multi-process ``WorkerGroup`` in the control -process. - -.. code:: python - - @ray.remote(num_cpus=1) - def main_task(config): - # construct SampleGenerator - resource_pool = RayResourcePool(process_on_nodes=[8] * 2) # 16 GPUs - ray_cls = RayClassWithInitArgs(SampleGenerator, config=config) - # put SampleGenerator onto resource pool - worker_group = RayWorkerGroup(resource_pool, ray_cls) - - # construct reference policy - -As we can see, in the control process, multiple processes are wrapped -into a ``RayWorkerGroup``. Inside this ``WorkerGroup``, there is a -``self._workers`` member, where each worker is a RayActor -(https://docs.ray.io/en/latest/ray-core/actors.html) of SampleGenerator. -ray_trainer.md also provide an implementation of -``MegatronRayWorkerGroup``. - -Assuming the model is distributed using FSDP, and there is a batch of -data on the control process, for data parallelism, the underlying -calling process is: - -.. code:: python - - data = xxx - data_list = data.chunk(dp_size) - - output = [] - for d in data_list: - # worker_group._workers[i] is a SampleGenerator - output.append(worker_group._workers[i].generate_sequences.remote(d)) - - output = ray.get(output) - output = torch.cat(output) - -Single process calling multiple processes involves the following 3 -steps: - -1. Split the data into DP parts on the control process. -2. Send the data to remote, call the remote computation through RPC, and - utilize multi-process computation. -3. Obtain the computation results of each worker on the control process - and merge them. - -Frequently calling these 3 steps on the controller process greatly hurts -code readability. **In verl, we have abstracted and encapsulated these 3 -steps, so that the worker's method + dispatch + collect can be -registered into the worker_group** - -.. code:: python - - from verl.single_controller.base.decorator import register - - def dispatch_data(worker_group, data): - return data.chunk(worker_group.world_size) - - def collect_data(worker_group, data): - return torch.cat(data) - - dispatch_mode = { - 'dispatch_fn': dispatch_data, - 'collect_fn': collect_data - } - - @register(dispatch_mode=dispatch_mode) - def generate_sequences(self, data): - pass - -In this way, we can directly call the method inside the worker through -the ``worker_group`` on the control (driver) process (which is a single -process): - -.. code:: python - - output = worker_group.generate_sequences(data) - -This single line includes data splitting, data distribution and -computation, and data collection. - -Furthermore, the model parallelism size of each model is usually fixed, -including dp, tp, pp. So for these common distributed scenarios, we have -pre-implemented specific dispatch and collect methods,in `decorator.py `_, which can be directly used to wrap the computations. - -.. code:: python - - from verl.single_controller.base.decorator import register, Dispatch - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def generate_sequences(self, data: DataProto) -> DataProto: - pass - -Here it requires the data interface to be ``DataProto``. Definition of -``DataProto`` is in `protocol.py `_. - -Step 3: Main training loop -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -With the above training flows, we can implement the algorithm's control -flow. It is recommended that ``main_task`` is also a ray remote process. - -.. code:: python - - @ray.remote(num_cpus=1) - def main_task(config): - # construct SampleGenerator - resource_pool = RayResourcePool(process_on_nodes=[8] * 2) # 16 GPUs - ray_cls = RayClassWithInitArgs(SampleGenerator, config=config) - # put SampleGenerator onto resource pool - sample_gen = RayWorkerGroup(resource_pool, ray_cls) - - # construct reference policy - ray_cls = RayClassWithInitArgs(ReferencePolicy) - ref_policy = RayWorkerGroup(resource_pool, ray_cls) - - # construct actor - ray_cls = RayClassWithInitArgs(DPOActor) - dpo_policy = RayWorkerGroup(resource_pool, ray_cls) - - dataloader = DataLoader() - - for data in dataloader: - # generate data - data = sample_gen.generate_sequences(data) - # generate scores for each data - data = generate_scores(data) - # generate pairwise data using scores - data = generate_pairwise_data(data) - # generate ref_log_prob - data.batch['ref_log_prob'] = ref_policy.infer(data) - # update using dpo - dpo_policy.update(data) - # logging - -Here, different ``WorkerGroups`` can be placed in the same resource pool or -in different resource pools using ``create_colocated_worker_cls`` -similar as in `ray_trainer.py `_. diff --git a/docs/advance/fsdp_extension.rst b/docs/advance/fsdp_extension.rst deleted file mode 100644 index 181e10908..000000000 --- a/docs/advance/fsdp_extension.rst +++ /dev/null @@ -1,97 +0,0 @@ - -Add models with the FSDP backend -================================== - -Last updated: 02/09/2025. - -Model --------------------------- - -In principle, our FSDP backend can support any HF model and we can -sychronoize the actor model weight with vLLM using `hf_weight_loader.py` under `third_party/vllm`. -However, ``hf_weight_loader`` is will gather the full state_dict of a -model during synchronization, which may cause OOM. We suggest using -``dtensor_weight_loader`` which gather the full model parameter layer by -layer to reduce the peak memory usage. We already support dtensor weight -loader for the models below in `dtensor_weight_loader.py` under `third_party/vllm`: - -- ``GPT2LMHeadModel`` -- ``LlamaForCausalLM`` -- ``LLaMAForCausalLM`` -- ``MistralForCausalLM`` -- ``InternLMForCausalLM`` -- ``AquilaModel`` -- ``AquilaForCausalLM`` -- ``Phi3ForCausalLM`` -- ``GemmaForCausalLM`` -- ``Gemma2ForCausalLM`` -- ``GPTBigCodeForCausalLM`` -- ``Starcoder2ForCausalLM`` -- ``Qwen2ForCausalLM`` -- ``DeepseekV2ForCausalLM`` - -To implement ``dtensor_weight_loader`` of a model that's supported in -vLLM, follow the guide of gemma model below: - -1. Copy the - ``load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]])`` from the vllm model class - to ``dtensor_weight_loaders.py`` -2. Modify the arguments to - ``(actor_weights: Dict, vllm_model: nn.Module)`` -3. Replace the ``self`` to ``vllm_model`` -4. Add the - ``local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight)`` - before each ``param = params_dict[name]`` and modify the following - weight loading using ``local_loaded_weight``. -5. Register the implemented dtensor weight loader to ``__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__``. - -.. code-block:: diff - - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - + def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - params_dict = dict(self.named_parameters()) - + params_dict = dict(vllm_model.named_parameters()) - loaded_params = set() - - for name, loaded_weight in weights: - + for name, loaded_weight in actor_weights.items(): - for (param_name, shard_name, shard_id) in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = param.weight_loader - - weight_loader(param, loaded_weight, shard_id) - + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) - break - else: - # lm_head is not used in vllm as it is tied with embed_token. - # To prevent errors, skip loading lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - - weight_loader(param, loaded_weight) - + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) - loaded_params.add(name) - unloaded_params = params_dict.keys() - loaded_params - if unloaded_params: - raise RuntimeError( - "Some weights are not initialized from checkpoints: " - f"{unloaded_params}") \ No newline at end of file diff --git a/docs/advance/megatron_extension.rst b/docs/advance/megatron_extension.rst deleted file mode 100644 index 9a52e6017..000000000 --- a/docs/advance/megatron_extension.rst +++ /dev/null @@ -1,20 +0,0 @@ -Add models with the Megatron-LM backend -========================================= - -Last updated: 04/25/2025. - -Model ------------ - - -If use latest verl, we have direct support of ``GPTModel`` for Megatron backend. -You can use the similar way of using Megatron to pretrain custom models. -We list the steps here: - -1. Find `model_initializer.py `_ -2. If your model is configurable by ``TransformerLayerSpec`` , you can - directly use ``GPTModel``. Otherwise, Please implement a new - ``ModelLayerSpec`` and ``ModelLayer`` here. -3. Use the right ``LayerSpec`` , ``TransformerConfig`` and ``HuggingfaceConfig`` - as arguments to initialize the GPTModel. -4. Return the model at last. diff --git a/docs/advance/placement.rst b/docs/advance/placement.rst deleted file mode 100644 index 43ba761f7..000000000 --- a/docs/advance/placement.rst +++ /dev/null @@ -1,13 +0,0 @@ -Ray API Design Tutorial -======================================= - -Last updated: 10/30/2024. - -We provide a tutorial for our Ray API design, including: - -- Ray basic concepts -- Resource Pool and RayWorkerGroup -- Data Dispatch, Execution and Collection -- Initialize the RayWorkerGroup and execute the distributed computation in the given Resource Pool - -See details in `tutorial.ipynb `_. \ No newline at end of file diff --git a/docs/advance/ppo_lora.rst b/docs/advance/ppo_lora.rst deleted file mode 100644 index baf3ab90a..000000000 --- a/docs/advance/ppo_lora.rst +++ /dev/null @@ -1,87 +0,0 @@ -RL(HF) algorithms with LoRA Support -=========================================== - -Last updated: 06/05/2025. - -We support LoRA (Low-Rank Adaptation) for reinforcement learning algorithms such as PPO, GRPO, and others. - -LoRA is a parameter-efficient fine-tuning technique that injects trainable low-rank matrices into pre-trained weights (typically linear layers). This reduces memory footprint and compute cost, making it possible to fine-tune large models with limited hardware. - -The benefits this brings include: - -- reinforcement learning with very large models (e.g. 70B+) with modest hardware (e.g. 8x80G GPUs), -- enable larger batch sizes due to reduced memory usage, -- simplify model transfer and deployment, as only LoRA adapters need to be saved, -- Combine with techniques like `SLoRA `_ or `CCoE `_ to serve multiple LoRA adapters efficiently - -This guide explains how to enable LoRA in RL training and configure related parameters. - -Usage Guide ------------------------- -1. Lora is available in the `verl.trainer.ppo.ray_trainer.RayPPOTrainer`. Examples are provided via the `verl.trainer.main_ppo` entry point. - -2. Currently, LoRA is supported via huggingface peft, only with fsdp/fsdp2 and vllm backend (sglang support coming soon). - -- `strategy=fsdp` or `strategy=fsdp2` -- `rollout.name=vllm` - -3. Required configurations for LoRA: - -- `actor_rollout_ref.model.lora_rank`: int, set to a reasonable value greater than 0 (e.g., 8, 16, 32, 64) -- `actor_rollout_ref.model.lora_alpha`: float, the alpha term in LoRA -- `actor_rollout_ref.rollout.load_format="safetensors"`: required. This enables vLLM to load the base model. -- `actor_rollout_ref.model.target_modules`: the target modules for LoRA. Typically set to "all-linear". - -4. Recommend options: - -- `actor_rollout_ref.model.use_shm=True`: preload the model into `/dev/shm` to improve model loading speed. -- `actor_rollout_ref.rollout.layered_summon=True`: this enables the actor-model to gather the FSDP shards per layers when synchronizing the LoRA Adapter to vLLM, thereby reducing GPU peak memory. Recommended if the model is very large (70B+) or the GPU memory is limited (< 48GB) - - -Best Practices and Notes -------------------------- - -1. **Learning rate**: it is recommended to increase the value of learning rate by an order of magnitude. - -2. **LoRA Rank**: - -- Too small a rank can hurt convergence. -- LoRA rank recommendation from @thelongestusernameofall: - - - A very small lora_rank can lead to slower convergence or worse training performance. It is recommended to set lora_rank to be>=32. Tests have shown that for a 0.5B model, with lora_rank=32,the training convergence speed and final performance are almost identical to non-LoRA training - - For a 32B model,with lora_rank=128,the training convergence speed and final performance are also almost identical to non-LoRA training. - - More comprehensive reference results are coming soon. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/f2b80b8b26829124dd393b7a795a0640eff11644/docs/lora.jpg?raw=true - -3. Reference configuration for RL training with the Qwen2.5-72B model using 8 x 80GB GPUs (increase lora_rank if needed): - -.. code-block:: - - data.train_batch_size=64 \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.lora_rank=32 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.actor.optim.lr=3e-5 \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=8 \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.max_num_seqs=64 \ - actor_rollout_ref.rollout.max_model_len=1536 \ - actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - -Example Script -------------------- - -For an end-to-end example, refer to the script below: - -examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh diff --git a/docs/advance/rollout_trace.rst b/docs/advance/rollout_trace.rst deleted file mode 100644 index ea203bbc0..000000000 --- a/docs/advance/rollout_trace.rst +++ /dev/null @@ -1,125 +0,0 @@ -Trace Function Usage Instructions -======================================== - -Last updated: 07/10/2025. - -Applicable Scenarios --------------------- - -Agentic RL involves multiple turns of conversations, tool invocations, and user interactions during the rollout process. During the Model Training process, it is necessary to track function calls, inputs, and outputs to understand the flow path of data within the application. The Trace feature helps, in complex multi-round conversations, to view the transformation of data during each interaction and the entire process leading to the final output by recording the inputs, outputs, and corresponding timestamps of functions, which is conducive to understanding the details of how the model processes data and optimizing the training results. - -The Trace feature integrates commonly used Agent trace tools, including wandb weave and mlflow, which are already supported. Users can choose the appropriate trace tool according to their own needs and preferences. Here, we introduce the usage of each tool. - - -Trace Parameter Configuration ------------------------------ - -- ``actor_rollout_ref.rollout.trace.backend=mlflow|weave`` # the trace backend type -- ``actor_rollout_ref.rollout.trace.token2text=True`` # To show decoded text in trace view - - -Glossary --------- - -+----------------+------------------------------------------------------------------------------------------------------+ -| Object | Explaination | -+================+======================================================================================================+ -| trajectory | A complete multi-turn conversation includes: | -| | 1. LLM output at least once | -| | 2. Tool Call | -+----------------+------------------------------------------------------------------------------------------------------+ -| step | The training step corresponds to the global_steps variable in the trainer | -+----------------+------------------------------------------------------------------------------------------------------+ -| sample_index | The identifier of the sample, defined in the extra_info.index of the dataset. It is usually a number,| -| | but may also be a uuid in some cases. | -+----------------+------------------------------------------------------------------------------------------------------+ -| rollout_n | In the GROP algorithm, each sample is rolled out n times. rollout_n represents the serial number of | -| | the rollout. | -+----------------+------------------------------------------------------------------------------------------------------+ -| validate | Whether the test dataset is used for evaluation? | -+----------------+------------------------------------------------------------------------------------------------------+ - -Rollout trace functions ------------------------ - -There are 2 functions used for tracing: - -1. ``rollout_trace_op``: This is a decorator function used to mark the functions to trace. In default, only few method has it, you can add it to more functions to trace more infor. -2. ``rollout_trace_attr``: This function is used to mark the entry of a trajectory and input some info to trace. If you add new type of agent, you may need to add it to enable trace. - - -Usage of wandb weave --------------------- - -1.1 Basic Configuration -~~~~~~~~~~~~~~~~~~~~~~~ - -1. Set the ``WANDB_API_KEY`` environment variable -2. Configuration Parameters - - 1. ``actor_rollout_ref.rollout.trace.backend=weave`` - 2. ``trainer.logger=['console', 'wandb']``: This item is optional. Trace and logger are independent functions. When using Weave, it is recommended to also enable the wandb logger to implement both functions in one system. - 3. ``trainer.project_name=$project_name`` - 4. ``trainer.experiment_name=$experiment_name`` - 5. ``actor_rollout_ref.rollout.mode=async``: Since trace is mainly used for agentic RL, need to enable agent toop using async mode for either vllm or sglang. - -Note: -The Weave Free Plan comes with a default monthly network traffic allowance of 1GB. During the training process, the amount of trace data generated is substantial, reaching dozens of gigabytes per day, so it is necessary to select an appropriate wandb plan. - - -1.2 View Trace Logs -~~~~~~~~~~~~~~~~~~~ - -After executing the training, on the project page, you can see the WEAVE sidebar. Click Traces to view it. - -Each Trace project corresponds to a trajectory. You can filter and select the trajectories you need to view by step, sample_index, rollout_n, and experiment_name. - -After enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the input and output content. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_list.png?raw=true - -1.3 Compare Trace Logs -~~~~~~~~~~~~~~~~~~~~~~ - -Weave can select multiple trace items and then compare the differences among them. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/weave_trace_compare.png?raw=true - -Usage of mlflow ---------------- - -1. Basic Configuration -~~~~~~~~~~~~~~~~~~~~~~ - -1. Set the ``MLFLOW_TRACKING_URI`` environment variable, which can be: - - 1. Http and https URLs corresponding to online services - 2. Local files or directories, such as ``sqlite:////tmp/mlruns.db``, indicate that data is stored in ``/tmp/mlruns.db``. When using local files, it is necessary to initialize the file first (e.g., start the UI: ``mlflow ui --backend-store-uri sqlite:////tmp/mlruns.db``) to avoid conflicts when multiple workers create files simultaneously. - -2. Configuration Parameters - - 1. ``actor_rollout_ref.rollout.trace.backend=mlflow`` - 2. ``trainer.logger=['console', 'mlflow']``. This item is optional. Trace and logger are independent functions. When using mlflow, it is recommended to also enable the mlflow logger to implement both functions in one system. - 3. ``trainer.project_name=$project_name`` - 4. ``trainer.experiment_name=$experiment_name`` - - -2. View Log -~~~~~~~~~~~ - -Since ``trainer.project_name`` corresponds to Experiments in mlflow, in the mlflow view, you need to select the corresponding project name, then click the "Traces" tab to view traces. Among them, ``trainer.experiment_name`` corresponds to the experiment_name of tags, and tags corresponding to step, sample_index, rollout_n, etc., are used for filtering and viewing. - -For example, searching for ``"tags.step = '1'"`` can display all trajectories of step 1. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_list.png?raw=true - -Opening one of the trajectories allows you to view each function call process within it. - -After enabling token2text, prompt_text and response_text will be automatically added to the output of ToolAgentLoop.run, making it convenient to view the content. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/mlflow_trace_view.png?raw=true - -Note: - -1. mlflow does not support comparing multiple traces -2. rollout_trace can not associate the mlflow trace with the run, so the trace content cannot be seen in the mlflow run logs. diff --git a/docs/advance/rope.rst b/docs/advance/rope.rst deleted file mode 100644 index 9463549e4..000000000 --- a/docs/advance/rope.rst +++ /dev/null @@ -1,39 +0,0 @@ -RoPE Scaling override -======================================= - -Last updated: 05/14/2025. - -Some models such as `Qwen/Qwen2.5-7B-Instruct `_ support RoPE Scaling but don't have it defined in their config.json file. -For example, this model supports this configuration: - -.. code:: python - - { - ..., - "rope_scaling": { - "factor": 4.0, - "original_max_position_embeddings": 32768, - "type": "yarn" - } - } - - - -In order to support a longer context for such models, you must override the model configs when starting the trainer. - -PPO example: - -.. code:: bash - - +actor_rollout_ref.model.override_config.rope_scaling.type=yarn \ - +actor_rollout_ref.model.override_config.rope_scaling.factor=4.0 \ - +actor_rollout_ref.model.override_config.rope_scaling.original_max_position_embeddings=32768 \ - - -And for the critic model - -.. code:: bash - - +critic.model.override_config.rope_scaling.type=yarn \ - +critic.model.override_config.rope_scaling.factor=4.0 \ - +critic.model.override_config.rope_scaling.original_max_position_embeddings=32768 \ diff --git a/docs/algo/baseline.md b/docs/algo/baseline.md deleted file mode 100644 index ce74c367d..000000000 --- a/docs/algo/baseline.md +++ /dev/null @@ -1,75 +0,0 @@ -# Algorithm Baselines - -Last updated: 06/18/2025. - -## Math related datasets - -### GSM8k - -Assuming GSM8k/math dataset is preprocessed via: - -```bash -python3 examples/data_preprocess/*.py -``` - -Refer to the table below to reproduce RL training from different pre-trained checkpoints. Below is the performance on the GSM8k dataset if not specified otherwise. More comprehensive benchmark results areavailable in the recipe folder. - - -| Hardware | Model | Method | Test score | Details | -|-------------|----------------------------------|-------------------|--------------|---------| -| NVIDIA GPU | google/gemma-2-2b-it | hf checkpoint | 23.9 | [Huggingface](https://huggingface.co/google/gemma-2-2b-it#benchmark-results) | -| NVIDIA GPU | google/gemma-2-2b-it | SFT | 52.06 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/gemma-2-2b-it-sft-0.411.log) | -| NVIDIA GPU | google/gemma-2-2b-it | SFT + PPO | 64.02 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/gemma-2-2b-it-ppo-bsz512_4-prompt1024-resp-512-0.640.log), [wandb](https://api.wandb.ai/links/verl-team/h7ux8602) | -| NVIDIA GPU | Qwen/Qwen2.5-0.5B-Instruct | hf checkpoint | 36.4 | [Qwen blog](https://qwenlm.github.io/blog/qwen2.5-llm/) | -| NVIDIA GPU | Qwen/Qwen2.5-0.5B-Instruct | PPO | 56.7 | [command and log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) | -| NVIDIA GPU | Qwen/Qwen2.5-0.5B-Instruct | PRIME | 58.7 | [script](https://github.com/volcengine/verl/blob/main/recipe/prime/run_prime_qwen.sh), [wandb](https://api.wandb.ai/links/zefan-wang-thu-tsinghua-university/rxd1btvb) | -| NVIDIA GPU | Qwen/Qwen2.5-0.5B-Instruct | GRPO-LoRA | 54.3 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz64_2-prompt512-resp1024-lorarank32-score0.543.log)| -| NVIDIA GPU | Qwen/Qwen2.5-1.5B-Instruct | GRPO-LoRA | 77.9 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-1.5B-bsz64_2-prompt512-resp1024-lorarank32-score0.779.log)| -| NVIDIA GPU | Qwen/Qwen2.5-3B-Instruct | GRPO-LoRA | 86.1 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-3B-bsz64_2-prompt512-resp1024-lorarank32-score0.861.log)| -| NVIDIA GPU | deepseek-ai/deepseek-llm-7b-chat | PPO (Megatron) | 69.5 [1] | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/deepseek-llm-7b-chat-megatron-bsz256_4-prompt512-resp512-0.695.log), [wandb](https://wandb.ai/verl-team/verl_megatron_gsm8k_examples/runs/10fetyr3) | -| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GRPO | 89 | [script](https://github.com/volcengine/verl/blob/a65c9157bc0b85b64cd753de19f94e80a11bd871/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh) | -| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GRPO (FSDP2) | 89.8 | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log) | -| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GRPO (Megatron) | 89.6 | [log](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b_math_megatron.log) | -| NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | ReMax | 97 | [script](https://github.com/eric-haibin-lin/verl/blob/main/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh), [wandb](https://wandb.ai/liziniu1997/verl_remax_example_gsm8k/runs/vxl10pln) | -| NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | SPPO | 65.6 (MATH) | [SPPO script](https://github.com/volcengine/verl/tree/main/recipe/sppo/README.md) | -| NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | GRPO-LoRA | 93.4 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-7B-bsz64_8-prompt512-resp1024-lorarank32-score0.934.log)| -| NVIDIA GPU | Mixtral-8x22B-Instruct-v0.1 | Instruct model | 83.7 | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/) | -| NVIDIA GPU | Mixtral-8x22B-Instruct-v0.1 | RLOO (Megatron) | 92.3 | [wandb](https://api.wandb.ai/links/ppo_dev/sbuiuf2d) | -| NVIDIA GPU | Qwen/Qwen2.5-7B-Instruct | SPIN | 92 | [script](https://github.com/volcengine/verl/tree/main/recipe/spin/README.md) | -| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GPG | 88 | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/ab86c4va) | -| NVIDIA GPU | Qwen/Qwen2-7B-Instruct | GPG (Megatron) | 88 | [log](https://github.com/diqiuzhuanzhuan/verldata/blob/main/run_logs/qwen2-7b_math_megatron.log), [wandb](https://wandb.ai/diqiuzhuanzhuan/verl_gpg_example_gsm8k_math/runs/yy8bheu8) | -| NVIDIA GPU | Qwen/Qwen2.5-VL-7B-Instruct | GRPO (Megatron) | 65.4 (GEO3k) | [script](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh), [wandb](https://api.wandb.ai/links/megatron-core-moe-dev/1yngvkek) | -| AMD MI300 | deepseek-ai/deepseek-llm-7b-chat | PPO | 70.5 [1] | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/ppo_run_deepseek7b_llm.log) | -| AMD MI300 | deepseek-ai/deepseek-llm-7b-chat | GRPO | 71.4 [1] | [log](https://github.com/yushengsu-thu/verl_training_log/blob/main/gsm8k/grpo_run_deepseek7b_llm.log) | -| NVIDIA GPU | Qwen/Qwen2.5-14B-Instruct | GRPO-LoRA | 94.6 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-14B-bsz64_8-prompt512-resp1024-lorarank32-score0.946.log)| -| NVIDIA GPU | Qwen/Qwen2.5-32B-Instruct | GRPO-LoRA | 95.8 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-32B-bsz64_8-prompt512-resp1024-lorarank32-score0.958.log)| -| NVIDIA GPU | Qwen/Qwen2.5-72B-Instruct | GRPO-LoRA | 96.0 | [command and logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-72B-bs64_8-prompt512-resp1024-lorarank32-score0.960.log)| - -### DAPO math-17k - -- Training DAPO math-17k dataset: https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k -- Testing: AIME'24: https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024 - -Note: -- For Qwen/Qwen2.5-Math-7B, we directly modify the max_position_embeddings to 32768 without observing performance degradation in order to train longer response length. - -| Hardware | Model | Method | Test score | Details | -|-------------|----------------------------------|-------------------|--------------|---------| -| NVIDIA GPU | Qwen/Qwen2.5-Math-7B (32k) | DAPO | 36.3 | [command](https://github.com/volcengine/verl/blob/main/recipe/dapo/test_dapo_7b_math.sh), [logs](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361)| - - - -## Coding related datasets - -Below is the result on leetcode if not specified otherwise. - -| Hardware | Model | Method | Test score | Details | -|-------------|----------------------------------|-------------------|--------------|---------| -| NVIDIA GPU | PRIME-RL/Eurus-2-7B-SFT | RPIME | 36.1 | [script](https://github.com/volcengine/verl/blob/main/recipe/prime/run_prime_qwen_code.sh), [swanlab](https://swanlab.cn/@wangzefan/prime_example/runs/7f541qhspgmy8nmhdlx35/chart) | - - -### Notes - -[1] During evaluation, we have only extracted answers following the format `"####"`. A more flexible answer extraction, longer response length, and better prompt engineering may lead to a higher score. - -[2] The default value of `actor_rollout_ref.actor.entropy_coeff` is set to `0.0` since verl 0.3.x on 2025-05-30, which is different from previous versions. diff --git a/docs/algo/dapo.md b/docs/algo/dapo.md deleted file mode 100644 index 96f242eaa..000000000 --- a/docs/algo/dapo.md +++ /dev/null @@ -1,187 +0,0 @@ -# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) - -Last updated: 06/19/2025. - -> Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211) - -🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) - -> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps. -> -> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png) - -## Quickstart - -1. Prepare the datasets **on the Ray cluster**: - -```bash -bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default -``` - -2. Submit the job to the Ray cluster **from any machine**: - -```bash -cd verl # Repo root -export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to -export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster -# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml -export RUNTIME_ENV="./recipe/dapo/runtime_env.yaml" # This sets environment variables for the Ray cluster -bash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts -``` - -## Reproduction Runs - -| Setup | AIME 2024 Acc. | Hardware | Image | Commit | Environment Variables | Training Script | Training Record | -| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | -| DAPO | 52% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -| DAPO w/o Dynamic Sampling | 50% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -| DAPO w/o Token-level Loss & Dynamic Sampling | 44% | 16x8xH20 | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | - -> [!IMPORTANT] -> -> **📢 Call for Contribution!** -> -> Welcome to submit your reproduction runs and setups! - -## Configuration - -### Separated Clip Epsilons (-> Clip-Higher) - -An example configuration: - -```yaml -actor_rollout_ref: - actor: - clip_ratio_low: 0.2 - clip_ratio_high: 0.28 -``` - -`clip_ratio_low` and `clip_ratio_high` specify the $\varepsilon_{\text {low }}$ and $\varepsilon_{\text {high }}$ in the DAPO objective. - -Core relevant code: - -```python -pg_losses1 = -advantages * ratio -pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) -pg_losses = torch.maximum(pg_losses1, pg_losses2) -``` - -### Dynamic Sampling (with Group Filtering) - -An example configuration: - -```yaml -data: - gen_batch_size: 1536 - train_batch_size: 512 -algorithm: - filter_groups: - enable: True - metric: acc # score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 10 # Non-positive values mean no upper limit -``` - -Setting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or 0. - -The trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups for `train_batch_size` or reaching the upper limit specified by `max_num_gen_batches`. - -Core relevant code: - -```python -prompt_bsz = self.config.data.train_batch_size -if num_prompt_in_batch < prompt_bsz: - print(f'{num_prompt_in_batch=} < {prompt_bsz=}') - num_gen_batches += 1 - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches - if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: - print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...') - continue - else: - raise ValueError( - f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.' - ) -else: - # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n - batch = batch[:traj_bsz] -``` - -### Flexible Loss Aggregation Mode (-> Token-level Loss) - -An example configuration: - -```yaml -actor_rollout_ref: - actor: - loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" - # NOTE: "token-mean" is the default behavior -``` - -Setting `loss_agg_mode` to `token-mean` will mean the (policy gradient) loss across all the tokens in all the sequences in a mini-batch. - -Core relevant code: - -```python -if loss_agg_mode == "token-mean": - loss = verl_F.masked_mean(loss_mat, loss_mask) -elif loss_agg_mode == "seq-mean-token-sum": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum - loss = torch.mean(seq_losses) # seq-mean -elif loss_agg_mode == "seq-mean-token-mean": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean - loss = torch.mean(seq_losses) # seq-mean -else: - raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") -``` - -### Overlong Reward Shaping - -An example configuration: - -```yaml -data: - max_response_length: 20480 # 16384 + 4096 -reward_model: - overlong_buffer: - enable: True - len: 4096 - penalty_factor: 1.0 -``` - -Setting `overlong_buffer.enable` to `True` will penalize the outputs whose lengths are overlong but still within the hard context limit. - -Specifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` when the length of the output exceeds the `max_response_length` by `0` to `overlong_buffer.len` tokens. - -Core relevant code: - -```python -if self.overlong_buffer_cfg.enable: - overlong_buffer_len = self.overlong_buffer_cfg.len - expected_len = self.max_resp_len - overlong_buffer_len - exceed_len = valid_response_length - expected_len - overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor - overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) - reward += overlong_reward -``` - -## FAQ - -### Where is the "Overlong Filtering" in the paper? - -Most experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here. - -### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)? - -[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features. - -[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, which will be maintained with new features. - -### Why can't I produce similar results after modifications? - -RL infrastructures nowadays still have inherent unrobustness, on which we are still working hard to improve. - -We strongly recommend to only modify one thing at a time. - -We also list some known problems here: - -1. Enabling CUDA graph (`enforce_eager=False`) might cause model performance degradation, whose cause is still under investigation. diff --git a/docs/algo/entropy.md b/docs/algo/entropy.md deleted file mode 100644 index 46153b7e8..000000000 --- a/docs/algo/entropy.md +++ /dev/null @@ -1,115 +0,0 @@ -# Recipe: Entropy Mechanism - -Last updated: 06/27/2025. - - -
- - The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning. - -[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617) [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue -)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861) - - - - -
- - -## 🎉News - -- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29). -- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse. - - - -## ✨Getting started - -After preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run: - -``` -cd verl -conda activate your_env -bash recipe/dapo/7b_kl_cov.sh -``` - -While for training Qwen2.5-32B on multi nodes, you can run the following commands: - -``` -cd verl -conda activate your_env -bash recipe/dapo/32b_kl_cov.sh -``` - -## 📖Introduction - -
- issue -
- -This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R=−aexp(H)+b$, showing performance is bottlenecked by entropy exhaustion. - -
- issue -
- -Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose ​​Clip-Cov​​ and ​​KL-Cov​​, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse, and improve performance. - -## 📃Evaluation - -
- issue -
- - -Our method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning better policy through RL. -| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** | -| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: | -| *Qwen2.5-7B* | | | | | | | | | -| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 | -| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 | -| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 | -| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** | -| *Qwen2.5-32B* | | | | | | | | | -| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 | -| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 | -| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 | -| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** | - -Our two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively. - - -## 🎈Citation -If you find this paper or repo helpful, please cite us. - -```bibtex -@article{cui2025entropy, - title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models}, - author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others}, - journal={arXiv preprint arXiv:2505.22617}, - year={2025} -} -``` -## 🌻Acknowledgement -We implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions! - -## 📬 Contact - -For questions, discussion, or collaboration opportunities, feel free to contact: -- Ganqu Cui: cuiganqu@pjlab.org.cn -- Yuchen Zhang: yuchen.zhang2003@gmail.com -- Jiacheng Chen: jackchan9345@gmail.com -- Ning Ding: ningding.cs@gmail.com - diff --git a/docs/algo/gpg.md b/docs/algo/gpg.md deleted file mode 100644 index 36bede8c3..000000000 --- a/docs/algo/gpg.md +++ /dev/null @@ -1,36 +0,0 @@ -# GPG: Group Policy Gradient - -Last updated: 07/03/2025. - -Group Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective—no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning -](https://arxiv.org/abs/2504.02546). - -## Key Components -- Use a corrected advantage function to improve policy gradient accuracy and training efficiency. -- By eliminating the critic and reference models, avoiding KL divergence constraints, significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO) - -## Configuration -To configure GPG within the framework, use the following YAML settings. - -```yaml -algorithm: - adv_estimator: gpg -actor_rollout_ref: - actor: - policy_loss: - loss_mode: "gpg" -``` - -## Advanced Extensions -GPG is a simple and strong baseline for model reasoning. Although it avoids using KL loss in its original form, you can still use KL loss to further improve the performance. - -```yaml -algorithm: - adv_estimator: gpg -actor_rollout_ref: - actor: - use_kl_loss: True # enable kl regularization - kl_loss_coef: 0.01 - policy_loss: - loss_mode: "gpg" -``` \ No newline at end of file diff --git a/docs/algo/grpo.md b/docs/algo/grpo.md deleted file mode 100644 index ba6d8ddab..000000000 --- a/docs/algo/grpo.md +++ /dev/null @@ -1,71 +0,0 @@ -# Group Relative Policy Optimization (GRPO) - -Last updated: 05/31/2025. - -In reinforcement learning, classic algorithms like PPO rely on a "critic" model to estimate the value of actions, guiding the learning process. However, training this critic model can be resource-intensive. - -GRPO simplifies this process by eliminating the need for a separate critic model. Instead, it operates as follows: -- Group Sampling: For a given problem, the model generates multiple possible solutions, forming a "group" of outputs. -- Reward Assignment: Each solution is evaluated and assigned a reward based on its correctness or quality. -- Baseline Calculation: The average reward of the group serves as a baseline. -- Policy Update: The model updates its parameters by comparing each solution's reward to the group baseline, reinforcing better-than-average solutions and discouraging worse-than-average ones. - -This approach reduces computational overhead by avoiding the training of a separate value estimation model, making the learning process more efficient. For more details, refer to the original paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/pdf/2402.03300) - -## Key Components - -- No Value Function (Critic-less): unlike PPO, GRPO does not train a separate value network (critic) -- Group Sampling (Grouped Rollouts): instead of evaluating one rollout per input, GRPO generates multiple completions (responses) from the current policy for each prompt. This set of completions is referred to as a group. -- Relative Rewards: within each group, completions are scored (e.g., based on correctness), and rewards are normalized relative to the group. - -## Configuration - -Note that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior. - -Despite that many configurations start with the `ppo_` prefix, they work across different RL algorithms in verl, as the GRPO training loop is similar to that of PPO (without critic). - -![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d) - -- `actor_rollout.ref.rollout.n`: For each prompt, sample n times. Default to 1. For GRPO, please set it to a value larger than 1 for group sampling. - -- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n` - -- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers. - -- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for GRPO updates on one set of sampled trajectories for actor - -- `actor_rollout_ref.actor.clip_ratio`: The GRPO clip range. Default to 0.2 - -- `algorithm.adv_estimator`: Default is gae. Please set it to grpo instead - -- `actor_rollout_ref.actor.loss_agg_mode`: Default is "token-mean". Options include "token-mean", "seq-mean-token-sum", "seq-mean-token-mean". The original GRPO paper takes the sample-level loss (seq-mean-token-mean), which may be unstable in long-CoT scenarios. All GRPO example scripts provided in verl uses the default configuration "token-mean" for loss aggregation instead. - -Instead of adding KL penalty in the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss: - -- `actor_rollout_ref.actor.use_kl_loss`: To use kl loss in the actor. When used, we are not applying KL in the reward function. Default is False. Please set it to True for GRPO. - -- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001. - -- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html - -## Advanced Extensions - -### DrGRPO - -[Understanding R1-Zero-Like Training: A Critical Perspective](https://arxiv.org/pdf/2503.20783) claims there's optimization bias in GRPO, which leads to artificially longer responses, especially for incorrect outputs. This inefficiency stems from the way GRPO calculates advantages using group-based reward normalization. Instead, DrGRPO aggregates token-level losses by normalizing with a global constant to eliminate length bias. - -Configure the following to enable DrGRPO, with all other parameters the same as GRPO's: - -- `actor_rollout_ref.actor.loss_agg_mode`: "seq-mean-token-sum-norm", which turns off seq-dim averaging -- `actor_rollout_ref.actor.use_kl_loss`: Please set it to False for DrGRPO -- `algorithm.norm_adv_by_std_in_grpo`: False, which turns off standard deviation norm - -## Reference Example - -Qwen2.5 GRPO training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log) - -```bash -bash examples/grpo_trainer/run_qwen3-8b.sh -``` - -For more reference performance, please see https://verl.readthedocs.io/en/latest/algo/baseline.html diff --git a/docs/algo/opo.md b/docs/algo/opo.md deleted file mode 100644 index 338f3a762..000000000 --- a/docs/algo/opo.md +++ /dev/null @@ -1,33 +0,0 @@ -# On-Policy RL with Optimal Reward Baseline (OPO) - -Last updated: 06/02/2025. - -Loose on-policy constraints and suboptimal baselines in reinforcement learning often lead to training instability such as large policy shifts and entropy collapse. OPO addresses these challenges by using exact on-policy training with the theretically optimal reward baseline for advantage estimation. It achieves lower policy shifts and higher output entropy, encouraging more diverse and less repetitive responses. - -OPO uses group sampling to generate multiple outputs for each input like GRPO. Unlike group-based algorithms which typically use the mean reward of a group as its baseline, OPO employs a theoretically optimal baseline: the length-weighted reward of the group. It also omits the standard deviation normalization. By adopting these two key components, OPO enables the training of a single policy model with the objective of maximizing only the expected reward. For more detailes, refer to the original paper [On-Policy RL with Optimal Reward Baseline](https://arxiv.org/pdf/2505.23585). - -## Key Components - -- Exact On-Policy Training: always generates responses from the current policy, without using any pre-generated data or off-policy data. -- Optimal Reward Baseline: uses a length-weighted reward of the group as the baseline for normalizing the rewards. - -## Configuration - -To configure OPO within the framework, use the following YAML settings. These parameters are crucial for enabling exact on-policy training and activating the optimal reward baseline. - -```yaml -algorithm: - adv_estimator: opo # Use OPO for optimal reward baseline -data: - train_batch_size: 1024 -actor_rollout_ref: - actor: - ppo_mini_batch_size: 1024 # ppo_mini_batch_size should equal to train_batch_size to enable exact on-policy training - entropy_coeff: 0 # disable entropy regularization - use_kl_loss: False # disable kl regularization - kl_loss_coef: 0 -``` - -## Advanced Extensions - -OPO can also be extended to other algorithms like RLOO and Reinforce++. It just needs to adjust their configurations to enable exact on-policy training and incorporate the optimal length-weighted reward baseline with minimal modifications to their advantage estimation functions. diff --git a/docs/algo/ppo.md b/docs/algo/ppo.md deleted file mode 100644 index d1f3046e5..000000000 --- a/docs/algo/ppo.md +++ /dev/null @@ -1,105 +0,0 @@ -# Proximal Policy Optimization (PPO) - -Last updated: 06/19/2025. - -Proximal Policy Optimization (PPO) is a family of policy gradient methods for reinforcement learning, proposed by OpenAI in 2017. PPO strikes a balance between simplicity, stability, and performance, making it one of the most widely used algorithms in modern RL applications, including large-scale language model fine-tuning. - -Traditional policy gradient methods like REINFORCE or Vanilla Policy Gradient suffer from: - -- High variance and sample inefficiency. -- Instability due to large policy updates. - -PPO addresses this problem using a clipped surrogate objective that avoids overly large updates without requiring second-order derivatives. - -For more technical details regarding PPO, we suggest reading the introduction in the [OpenAI spinning up tutorial](https://spinningup.openai.com/en/latest/algorithms/ppo.html), and the paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347). - -## Key Components - -- Actor-Critic Architecture: PPO requires both an actor model (policy) and a critic model (value function). This differs from other algorithms like GRPO and RLOO that don't require a critic model. - -- Generalized Advantage Estimation (GAE): PPO uses GAE for computing advantage values, which helps reduce variance in policy gradient estimates while maintaining low bias. - -- Clipped Surrogate Objective: The core of PPO is implemented through the clipped surrogate objective function that limits policy updates. - -## Configuration - -Note that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior. - -Most critic configs are similar to those of actors. Note that the critic model is omitted from the figure below. - -![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d) - -- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n` - -- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers - -- `actor_rollout_ref.critic.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO critic updates. The ppo_mini_batch_size is a global size across all workers - -- `actor_rollout_ref.actor.clip_ratio`: The PPO clip range. Default to 0.2 - -- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for actor - -- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to `actor_rollout_ref.actor.ppo_epochs` - -- `algorithm.gemma`: discount factor - -- `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator - -- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo - -## Advanced Extensions - -### KL Divergence Control - -Options to prevent the policy from diverging too far from a reference policy. Two mechanisms are available: KL reward penalty and KL loss. For more technical details, see [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) - -Options to use KL loss for KL divergence control: - -- `actor_rollout_ref.actor.use_kl_loss`: to use kl loss in the actor. When used, we are not applying KL in the reward function. Default is False - -- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001. - -- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html - -Options to use KL penalty in the reward: - -- `algorithm.use_kl_in_reward`: Whether to enable in-reward kl penalty. Default is False. - -- `algorithm.kl_penalty`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. This defines the way to calculate the kl divergence between actor and reference policy. For specific options, refer to `kl_penalty` in core_algos.py. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html - -- `algorithm.kl_ctrl.kl_coef`: The (initial) coefficient of in-reward kl_penalty. Default is 0.001. -- `algorithm.kl_ctrl.type`: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController. -- `algorithm.kl_ctrl.horizon`: See source code of AdaptiveKLController for details. -- `algorithm.kl_ctrl.target_kl`: See source code of AdaptiveKLController for details. - -### Dual-clip PPO - -The Dual-Clip PPO introduces a approach by applying a lower bound to the policy ratio when the advantage is less than zero, when multiplied by a large raito, does not exceed a specified lower bound. - -![image](https://github.com/user-attachments/assets/fc232181-d8b0-4307-8dd2-4dc0a4c1c139) - -- `actor_rollout_ref.actor.clip_ratio_c`: lower bound of the value for Dual-clip PPO, defaults to 3.0 - -## Reference Example - -Qwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) - -```bash -bash run_gemma.sh - trainer.n_gpus_per_node=1 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - trainer.logger=console \ - critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - data.train_batch_size=256 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size=2 \ - critic.ppo_micro_batch_size=2 -``` - -Reference performance with verl v0.2: - -| Model | Method | Score | Link | -|-------------------------------|------------------|-------|------------------------------------------------------------------------------------------------| -| Qwen/Qwen2.5-0.5B-Instruct | pretrained model | 36.4 | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/) | -| Qwen/Qwen2.5-0.5B-Instruct | PPO | 56.7 | [PPO Command and Logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) | diff --git a/docs/algo/spin.md b/docs/algo/spin.md deleted file mode 100644 index c2a834262..000000000 --- a/docs/algo/spin.md +++ /dev/null @@ -1,179 +0,0 @@ -# Recipe: Self-Play Fine-Tuning (SPIN) - -Last updated: 05/31/2025. - -`verl` provides a recipe inspired by the paper **"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory. - -**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models: - -1. **Synthetic Data Generation:** The current model generates responses, creating its own training data from previous iterations. -2. **Two-Player Game Setup:** A game involving two players acted by a single LLM. -3. **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration. - -Paper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\*, [Yihe Deng](https://github.com/uclaml/SPIN)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) - -[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)] - -verl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20) - ---- - -## Key Function (compute_online_dpo_loss) and Related works -SPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023). - -This `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data. - -Specifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. By incorporating online preference generation, this approach enables continuously refining language models without relying on fixed external preference datasets. - -**Reference Papers:** -* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024) -* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023) -* [Somethings are more cringe than others: Preference optimization with the pairwise cringe loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023) -* [Iterative preference learning from human feedback: Bridging theory and practice for rlhf under kl-constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023) -* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024) -* [Direct language model alignment from online ai feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024) - - -## Our Online DPO Implementation - -Our `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include: - -* **No Critic:** Unlike PPO, we omit the value function critic. -* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline. -* **Online Preference Generation:** The `compute_onlineDPO_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems). -* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences. -* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles. - ---- -## Algorithm - -This recipe implements an Online algorithm adapted to the `verl` Reinforcement Learning framework, which provides an alternative to PPO for fine-tuning language models. - -**Online Loop:** Instead of maximizing a scalar reward signal in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training: - -1. **Generation:** The current model generates multiple responses for each prompt in a batch. -2. **Preference Labeling:** A function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done using a reward function or implicit ranking based on specific rules. (In this recipe, we use rule-based ranking on the math problem). -3. **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model. - -**Connection with SPIN:** -Instead of only using a fixed target data distribution, the online generation loop in step 2 will dynamically change the target data distribution by using a certain Preference Labeling method (rule-based ranking on the math problem by selecting the better one in this recipe). This explores the direction mentioned in SPIN's paper Section 7 about "dynamically changing target data distribution" to potentially elevate LLM performance beyond the fixed human-annotated data ceiling. - ---- - -## Reproduce the Experiment (Example Setup) - -The following steps outline how to set up the environment and run the SPIN recipe, based on the provided test log using GSM8K and Qwen2.5-3B-Instruct. - -1. **Setup Environment (Example using Docker):** - ```bash - # Start a container with GPU access and shared memory - docker run -it --name spin_test --gpus all \ - --shm-size=32g \ - --ipc=host \ - -v /path/to/host/.cache:/root/.cache \ - -e HF_TOKEN= \ - lmsysorg/sglang:latest \ - /bin/bash - - # Inside the container or on your host machine: - # Ensure /tmp is writable - mkdir -p /tmp - chmod 1777 /tmp - - # Install Python 3.10 (if not present) and venv - sudo apt update - sudo apt install -y python3.10 python3.10-venv tmux - python3 -m ensurepip --upgrade - - # Create and activate a virtual environment - python3 -m venv ~/.python/spin_env - source ~/.python/spin_env/bin/activate - - # Install uv (fast package installer) - python3 -m pip install uv - ``` - -2. **Install verl and Dependencies:** - ```bash - # Clone the verl repository and checkout the spin branch - cd ~ - git clone git@github.com:volcengine/verl.git && cd verl - - # Install flash-attn (handle potential build issues) - python3 -m uv pip install wheel packaging - python3 -m uv pip install flash-attn --no-build-isolation --no-deps - - # Install verl with sglang extras - python3 -m uv pip install -e ".[sglang]" - ``` - *Note: If `flash-attn` installation fails, try the manual steps again or consult its documentation.* - -3. **Login & Download Data/Model:** - ```bash - # Login to Weights & Biases (optional, for logging) - export WANDB_API_KEY= - # wandb login - - # Download the GSM8K dataset - python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k # Adjusted path - - # Download the base model (Example: Qwen2.5-3B-Instruct) - huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct - ``` - -4. **Configure:** - * Modify the configuration file (e.g., `config/spin_trainer.yaml` or the one specified in the run script) with correct paths to your downloaded model, data, desired hyperparameters (`dpo_beta`, learning rate, etc.), and distributed training settings (nodes, GPUs per node). - * Pay attention to `actor_rollout_ref.model_path`, `data` paths, `reward_model` config (if using one), and `trainer.ref_update_freq`. - -5. **Run Training:** - ```bash - # Set CUDA visible devices (adjust based on your hardware and config) - export CUDA_VISIBLE_DEVICES=0,1,2,3 - - # Launch the training script (e.g., test.sh or a custom script) - # Ensure test.sh points to the correct config and main script - bash recipe/spin/run_spin.sh - ``` - ---- - -## Configuration - -* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`). -* Key configuration sections: - * `data`: Paths to training/validation prompt files, batch sizes, sequence lengths. - * `actor_rollout_ref`: Paths to the base model (used for actor and initial reference), FSDP settings, optimization parameters (learning rate, scheduler). - * `reward_model`: Configuration for the reward model used for online preference labeling (path, batch size, etc.). Can be omitted if using a simpler reward function. - * `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`. - * `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor). - ---- - -## Key Files - -* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`. -* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop. -* `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP. -* `dp_actor.py`: Contains the actor class, including the DPO policy update logic. -* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`. -* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe. -* `run_spin.sh` (or similar): Example bash script for launching a training run. -* `README.md`: This file. - ---- - -## Acknowledgement - -We sincerely thank the contribution and guidance from the `verl` community and advisors, including (adapted from SPPO): - -* [Zixiang Chen](https://sites.google.com/view/zxchen) -* [Yuhao Yang](https://github.com/yhyang201) -* [Yifan Zhang](https://github.com/yifanzhang-pro) -* [Yongan Xiang](https://github.com/BearBiscuit05) -* [Junrong Lin](https://github.com/ocss884) -* [Yuxuan Tong](https://github.com/tongyx361) -* [Guangming Shen](https://github.com/PeterSH6) -* [Biao He](https://www.linkedin.com/in/biao-he/) -* [Qingquan Song](https://qingquansong.github.io/) -* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/) -* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) diff --git a/docs/algo/sppo.md b/docs/algo/sppo.md deleted file mode 100644 index bf7c4e9e6..000000000 --- a/docs/algo/sppo.md +++ /dev/null @@ -1,52 +0,0 @@ -# Recipe: Self-Play Preference Optimization (SPPO) - -Last updated: 05/28/2025. - -verl provides a community recipe implementation for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preference, and empirically validated through extensive evaluations on multiple datasets. - -Paper Authors: [Yue Wu](https://yuewu.us/)\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) - -verl Implementation Authors: [Yuhao Yang](https://github.com/yhyang201), [Chenyang Zhao](https://github.com/zhaochenyang20) - -[[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)][[Original Implementation](https://github.com/uclaml/SPPO)] - -## Reproduce the Experiment - -We evaluate the performance of SPPO on the MATH dataset. Starting from an initial score of 46.6 with Qwen2.5-7B-Instruct, we achieve a score of 65.6 after 20 epochs of training, placing our model approximately in the top 20 on the [MATH leaderboard](https://paperswithcode.com/sota/math-word-problem-solving-on-math). It's important to note that verl's internal evaluation metrics may not perfectly align with the official evaluation methodology for Qwen2.5-7B-Instruct. Therefore, for consistency and fair comparison, we report only the results based on verl's evaluation framework. - -``` -git clone git@github.com:volcengine/verl.git -cd verl -python3 -m uv pip install -e ".[sglang]" - -export WANDB_API_KEY= - -python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math -huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct - -export CUDA_VISIBLE_DEVICES=0,1,2,3 -bash recipe/sppo/run_qwen2.5-7b_rm.sh -``` - -Note that the installation would occasionally fail to install flash-attn. If this happens, you can install it manually by running: - -```bash -python3 -m uv pip install wheel -python3 -m uv pip install packaging -python3 -m uv pip install flash-attn --no-build-isolation --no-deps -``` - -## Acknowledgement - -We sincerely thank the contribution and guidance from: - -- [Yue Wu](https://yuewu.us/) -- [Chendong Wang](https://cdwang96.github.io/) -- [Yifan Zhang](https://github.com/yifanzhang-pro) -- [Yongan Xiang](https://github.com/BearBiscuit05) -- [Junrong Lin](https://github.com/ocss884) -- [Yuxuan Tong](https://github.com/tongyx361) -- [Guangming Shen](https://github.com/PeterSH6) -- [Biao He](https://www.linkedin.com/in/biao-he/) -- [Qingquan Song](https://qingquansong.github.io/) -- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) diff --git a/docs/amd_tutorial/amd_build_dockerfile_page.rst b/docs/amd_tutorial/amd_build_dockerfile_page.rst deleted file mode 100644 index 51efa247c..000000000 --- a/docs/amd_tutorial/amd_build_dockerfile_page.rst +++ /dev/null @@ -1,796 +0,0 @@ -Getting started with AMD (ROCM Kernel) -===================================================== - -Last updated: 07/06/2025. - -Author: `Yusheng Su `_ - -Setup ------ - -If you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and set ``RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES`` or ``RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES`` when starting ray in verl's RLHF training. - - -docker/Dockerfile.rocm -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: bash - - FROM "rlfoundation.azurecr.io/rocm6.3.4:vllm-0.8.5-numa-patch-ubuntu-22.04" - - SHELL ["/bin/bash", "-ceuxo", "pipefail"] - - ENV MAX_JOBS=512 - - ENV PATH="/usr/local/python3.12/bin:$PATH" - RUN ln -sf /usr/bin/python3.12 /usr/bin/python && \ - ln -sf /usr/bin/pip3.12 /usr/bin/pip - - ############################################ - RUN apt-get update - RUN apt-get install -y pkg-config liblzma-dev - ############################################ - - ########################################### - ##########Install TransformerEngine######## - ########################################### - WORKDIR /workspace/ - # transformer-engine install - # https://github.com/ROCm/TransformerEngine - RUN rm -rf TransformerEngine - RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git - WORKDIR /workspace/TransformerEngine - git checkout 236178e5 - # git checkout bb061ade - # git checkout 864405c - ENV NVTE_FRAMEWORK=pytorch - ENV NVTE_ROCM_ARCH=gfx942 - ENV NVTE_USE_HIPBLASLT=1 - ENV NVTE_USE_ROCM=1 - # export CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr:${CMAKE_PREFIX_PATH:-}" - ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" - RUN MAX_JOBS=$(MAX_JOBS) pip install . -vvv - WORKDIR /workspace/ - ########################################### - ########################################### - ########################################### - - - - - - #################################################################################### - ################Install vllm - sglang require vllm 0.6.7 dependency################# - #################################################################################### - #### Require vllm 0.6.7 - checkout 113274a0 - WORKDIR /workspace/ - RUN rm -rf vllm - RUN pip uninstall -y vllm - # Refer to here (down-grade vllm to 0.6.3): https://docs.vllm.ai/en/v0.6.3/getting_started/amd-installation.html - RUN git clone https://github.com/ROCm/vllm.git - # git clone https://github.com/vllm-project/vllm.git - WORKDIR /workspace/vllm - RUN git checkout 113274a0 - ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" - #ENV MAX_JOBS=512 - ENV MAX_JOBS=${MAX_JOBS} - RUN pip install "boto3>=1.26.0" - RUN pip install setuptools_scm - # will add src into py. You can delete the repo - RUN python3 setup.py install - WORKDIR /workspace/ - #################################################################################### - #################################################################################### - #################################################################################### - - - - ########################################### - ############For hack docker################ - ########################################### - RUN pip install setuptools==75.8.0 - ########################################### - ########################################### - ########################################### - - - - ########################################### - ############build sgalng################### - ########################################### - # Set environment variables - ENV BASE_DIR=/sgl-workspace - ENV BUILD_TYPE=all - ENV SGL_REPO=https://github.com/sgl-project/sglang - ENV SGL_BRANCH=v0.4.6.post5 - ENV TRITON_REPO=https://github.com/ROCm/triton.git - ENV TRITON_COMMIT=improve_fa_decode_3.0.0 - ENV AITER_REPO=https://github.com/ROCm/aiter.git - ENV AITER_COMMIT=v0.1.2 - # v0.1.2 version - commit id: 9d11f47 - # ENV AITER_COMMIT=9d11f47 - ENV HIP_FORCE_DEV_KERNARG=1 - ENV HSA_NO_SCRATCH_RECLAIM=1 - ENV SGLANG_SET_CPU_AFFINITY=1 - ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 - ENV NCCL_MIN_NCHANNELS=112 - ENV MOE_PADDING=1 - ENV VLLM_FP8_PADDING=1 - ENV VLLM_FP8_ACT_PADDING=1 - ENV VLLM_FP8_WEIGHT_PADDING=1 - ENV VLLM_FP8_REDUCE_CONV=1 - ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 - ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 - ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" - ENV AMDGPU_TARGETS=gfx942 - ENV ROCM_ARCH=gfx942 - ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" - # Switch to working directory - WORKDIR /sgl-workspace - # Clean and create directory - RUN rm -rf /sgl-workspace && mkdir -p /sgl-workspace - - # Clone and build sglang - RUN git clone ${SGL_REPO} \ - && cd sglang \ - && git checkout ${SGL_BRANCH} || echo "Using default branch" \ - && cd sgl-kernel \ - && rm -f pyproject.toml \ - && mv pyproject_rocm.toml pyproject.toml \ - && python setup_rocm.py install \ - && cd .. \ - && if [ "$BUILD_TYPE" = "srt" ]; then \ - python -m pip --no-cache-dir install -e "python[srt_hip]"; \ - else \ - python -m pip --no-cache-dir install -e "python[all_hip]"; \ - fi \ - && cd /sgl-workspace \ - && cp -r /sgl-workspace/sglang /sglang \ - && python -m pip cache purge - - # Install common Python packages - RUN pip install IPython orjson python-multipart torchao pybind11 - # Rebuild Triton - RUN pip uninstall -y triton || true \ - && git clone ${TRITON_REPO} \ - && cd triton \ - && git checkout ${TRITON_COMMIT} \ - && cd python \ - && python3 setup.py install \ - && cd /sgl-workspace - # ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942 --amdgpu-lower-module-lds-strategy=1" - # ENV HIPCC_COMPILE_FLAGS_APPEND="--offload-arch=gfx942" - - # Build aiter - #version: Commit 9d11f47 - # && git checkout ${AITER_COMMIT} \ - RUN pip uninstall -y aiter || true - RUN git clone ${AITER_REPO} \ - && cd aiter \ - && git checkout ${AITER_COMMIT} \ - && git submodule sync \ - && git submodule update --init --recursive \ - && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py install \ - && cd /sgl-workspace - - # Copy MI300X config - RUN find /sgl-workspace/sglang/python/sglang/srt/layers/quantization/configs/ \ - /sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ - -type f -name '*MI300X*' | \ - xargs -I {} sh -c 'vf_config=$(echo "$1" | sed "s/MI300X/MI300X_VF/"); cp "$1" "$vf_config"' -- {} - - # Environment setup complete. - RUN echo "Environment setup complete." - - WORKDIR /workspace/ - ########################################### - ########################################### - ########################################### - - - - - - - ########################################### - ###############vllm v0.8.5################# - ########################################### - WORKDIR /workspace/ - - ENV VLLM_TARGET_DEVICE=rocm - ENV ROCM_PATH=/opt/rocm - ENV SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev - # Find the repo path in: DockerFile/Dockerfile.rocm_yang - # RUN git clone https://github.com/RLFoundation/vllm-patch.git - RUN pip uninstall -y vllm || true - RUN rm -rf vllm-patch - RUN git clone https://github.com/RLFoundation/vllm-patch.git \ - && cd vllm-patch \ - && git checkout v0.8.5-sleep-numa \ - && rm -rf build/ dist/ *.egg-info \ - && ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so \ - && SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py install - # RUN SETUPTOOLS_SCM_PRETEND_VERSION=0.8.5.dev PYTORCH_ROCM_ARCH="gfx90a;gfx942" MAX_JOBS=${MAX_JOBS} python3 setup.py develop - WORKDIR /workspace/ - ########################################### - ########################################### - ########################################### - - - - - ######################################### - #### Install megatron-core############### - ######################################### - RUN pip uninstall -y megatron-core && \ - git clone https://github.com/yushengsu-thu/Megatron-LM-amd_version.git && \ - cd Megatron-LM-amd_version && \ - pip install -vvv -e . && \ - cd /workspace/ - ######################################### - ######################################### - ######################################### - - - - - ####################################### - ################apex################### - ####################################### - WORKDIR /workspace/ - RUN pip uninstall -y apex && \ - git clone git@github.com:ROCm/apex.git && \ - cd apex && \ - python setup.py install && \ - cd /workspace/ - ####################################### - ####################################### - ####################################### - - - ################################################################################ - ###########################Add torch_memory_saver############################### - ################################################################################ - # Set environment variables - ENV HIPCC_COMPILE_FLAGS_APPEND="--amdgpu-target=gfx90a;gfx942 -D__HIP_PLATFORM_AMD__" - ENV CFLAGS="-D__HIP_PLATFORM_AMD__" - ENV CXXFLAGS="-D__HIP_PLATFORM_AMD__" - RUN pip install "git+https://github.com/YangWang92/torch_memory_saver_numa.git@numa" - ################################################################################ - ################################################################################ - ################################################################################ - - - - ######################################## - ######Install ray####################### - ######################################## - # need to add this patch: https://github.com/ray-project/ray/pull/53531/files - RUN pip uninstall ray -y - RUN pip install "ray[data,train,tune,serve]>=2.47.0" - ######################################## - ######################################## - ######################################## - - - ########################################## - #######Install other dependencies######### - ########################################## - RUN pip install "tensordict==0.6.2" --no-deps && \ - pip install accelerate \ - codetiming \ - datasets \ - dill \ - hydra-core \ - liger-kernel \ - numpy \ - pandas \ - peft \ - "pyarrow>=15.0.0" \ - pylatexenc \ - torchdata \ - wandb \ - orjson \ - pybind11 - - WORKDIR /workspace/ - RUN git clone https://github.com/volcengine/verl.git && \ - cd verl && \ - pip install -e . - ########################################## - ########################################## - ########################################## - - WORKDIR /workspace/ - CMD ["/usr/bin/bash"] - - -Build the image: -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: bash - - docker docker/build -t verl-rocm . - -Run the container -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Note: You can pull the docker from this DockerHub: [RLSys Foundation](https://hub.docker.com/u/yushengsuthu) -Pull the image: -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: bash - - docker pull yushengsuthu/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4 - - docker tag yushengsuthu/verl:verl-0.4.1_ubuntu-22.04_rocm6.3.4-numa-patch_vllm0.8.5_sglang0.4.6.post4 verl-rocm:latest - -Run the container -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - -Optional: Running without root and with user permissions -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: bash - - docker run --rm -it \ - --device /dev/dri \ - --device /dev/kfd \ - -p 8265:8265 \ - --group-add video \ - --cap-add SYS_PTRACE \ - --security-opt seccomp=unconfined \ - --privileged \ - -v $HOME/.ssh:/root/.ssh \ - -v $HOME:$HOME \ - --shm-size 128G \ - -w $PWD \ - verl-rocm \ - /bin/bash - -(Optional): If you do not want to root mode and require assign yourself as the user -Please add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` into the above docker launch script. - -Example -------- - -Due to to special setting in AMD (ROCM) torch, -1. If your ``ray>=2.45.0`` (default), you need to set ``RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES`` when starting ray in verl's RLHF training and add this [patch](https://github.com/ray-project/ray/pull/53531/files). -2. If your ``ray<2.45.0``, you need to set ``RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES`` when starting ray in verl's RLHF training. -Inference ``$ENGINE`` can be ``vllm`` or ``sglang``. We choose ``vllm`` as default in the following examples. - - - -PPO -~~~ - -.. code-block:: bash - - YOUR_PROJECT_NAME=r1-verl-ppo-upstream - YOUR_RUN_NAME=r1-training_ppo-upstream - # export HYDRA_FULL_ERROR=1 - - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - # [ray] < 2.45.0 - #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 - - # [ray] >= 2.45.0 - export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794 - - GPUS_PER_NODE=8 - MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct - python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k - python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')" - ENGINE=vllm #sglang - - PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=256 \ - data.val_batch_size=1312 \ - data.max_prompt_length=512 \ - data.max_response_length=256 \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - critic.optim.lr=1e-5 \ - critic.model.path=$MODEL_PATH \ - critic.ppo_micro_batch_size_per_gpu=4 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=console \ - trainer.project_name=$YOUR_PROJECT_NAME \ - trainer.experiment_name=$YOUR_RUN_NAME \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=$GPUS_PER_NODE \ - trainer.nnodes=1 \ - trainer.save_freq=10 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 #2>&1 | tee verl_demo.log - -GRPO -~~~~ - -.. code-block:: bash - - YOUR_PROJECT_NAME=r1-verl-grpo-upstream - YOUR_RUN_NAME=r1-training_grpo-upstream - # export HYDRA_FULL_ERROR=1 - # export FSDP_VERBOSE=1 - - #export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - # [ray] < 2.45.0 - #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 - - # [ray] >= 2.45.0 - export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794 - - GPUS_PER_NODE=8 - MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct - # MODEL_PATH=Qwen/Qwen2-7B-Instruct - python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k - python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')" - ENGINE=vllm #sglang - - python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.val_batch_size=1312 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.model.enable_gradient_checkpointing=Flase \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.fsdp_config.param_offload=False \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name=$YOUR_PROJECT_NAME \ - trainer.experiment_name=$YOUR_RUN_NAME \ - trainer.n_gpus_per_node=$GPUS_PER_NODE \ - trainer.val_before_train=False \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 - - - -Multi-node training: slurm with Docker/Podman container ---------------------------------------------------------------------------------------- - -If you want to run multi-node training with slurm, you can use the following script. - -.. note:: - 1. You need to use ``podman`` or ``docker`` in the following script. We will release the apptainer script later. - 2. If you want to use ``podman``, you just replace ``docker`` with ``podman`` in the following script. - -The script includes the following steps: - -1. SLURM Configuration -2. Environment Setup -3. Docker/Podman Container Setup -4. Ray Cluster Initialization -5. Data Preprocessing -6. Model Setup -7. Training Launch - - -slurm_script.sh -~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: bash - - #!/bin/bash - - #SBATCH --job-name=verl-ray-on-slurm - #SBATCH --nodes=2 - #SBATCH --ntasks-per-node=2 - #SBATCH --mem=200G - #SBATCH --time=30-00:00:00 - #SBATCH --gpus-per-node=8 - #SBATCH --cpus-per-task=28 - #SBATCH --output=../verl_log/slurm-%j.out - #SBATCH --error=../verl_log/slurm-%j.err - #SBATCH --nodelist=gpu-[0,1] - - - # load necessary modules - ### Run this setup - # [Cluster]: Use docker - # docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 - - - ########################################################################## - ###The following setting should be set in different project and cluster### - ########################################################################## - - ### Project - CONTAINER_NAME="multinode_verl_training" - IMG="verl.rocm" - DOCKERFILE="docker/Dockerfile.rocm" - # echo $PWD - verl_workdir="${HOME}/projects/verl_upstream" - export TRANSFORMERS_CACHE="${HOME}/.cache/huggingface" - export HF_HOME=$TRANSFORMERS_CACHE - - ### Cluster Network Setting - export NCCL_DEBUG=TRACE - export GPU_MAX_HW_QUEUES=2 - export TORCH_NCCL_HIGH_PRIORITY=1 - export NCCL_CHECKS_DISABLE=1 - # export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7 - export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_8,mlx5_9 - export NCCL_IB_GID_INDEX=3 - export NCCL_CROSS_NIC=0 - export CUDA_DEVICE_MAX_CONNECTIONS=1 - export NCCL_PROTO=Simple - export RCCL_MSCCL_ENABLE=0 - export TOKENIZERS_PARALLELISM=false - export HSA_NO_SCRATCH_RECLAIM=1 - ########################################################################## - - ## Assign using GPUs - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - ### For rocm and training script - # [ray] < 2.45.0 - #export RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 - - # [ray] >= 2.45.0 - export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 # Patch with https://github.com/ray-project/ray/pull/52794 - - - # Build and launch the Docker container - srun bash -c " - # Exit on any error - set -e - - # Clean up dangling images (images with tag) - docker image prune -f - - # Need to pull the docker first - docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 - - if ! docker images --format "{{.Repository}}:{{.Tag}}" | grep -q "${IMG}"; then - echo \"Building ${IMG} image...\" - docker build -f \"${DOCKERFILE}\" -t \"${IMG}\" . - else - echo \"${IMG} image already exists, skipping build\" - fi - - # Removing old container if exists - docker rm \"${CONTAINER_NAME}\" 2>/dev/null || true - - # Checking network devices - ibdev2netdev - - # Launch the docker - docker run --rm -d \ - -e HYDRA_FULL_ERROR=1 \ - -e RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 \ - -e RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1 \ - -e NCCL_DEBUG=${NCCL_DEBUG} \ - -e GPU_MAX_HW_QUEUES=${GPU_MAX_HW_QUEUES} \ - -e TORCH_NCCL_HIGH_PRIORITY=${TORCH_NCCL_HIGH_PRIORITY} \ - -e NCCL_CHECKS_DISABLE=${NCCL_CHECKS_DISABLE} \ - -e NCCL_IB_HCA=${NCCL_IB_HCA} \ - -e NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX} \ - -e NCCL_CROSS_NIC=${NCCL_CROSS_NIC} \ - -e CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS} \ - -e NCCL_PROTO=${NCCL_PROTO} \ - -e RCCL_MSCCL_ENABLE=${RCCL_MSCCL_ENABLE} \ - -e TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM} \ - -e HSA_NO_SCRATCH_RECLAIM=${HSA_NO_SCRATCH_RECLAIM} \ - -e TRANSFORMERS_CACHE=${TRANSFORMERS_CACHE} \ - -e HF_HOME=${HF_HOME} \ - --network host \ - --device /dev/dri \ - --device /dev/kfd \ - --device /dev/infiniband \ - --group-add video \ - --cap-add SYS_PTRACE \ - --security-opt seccomp=unconfined \ - --privileged \ - -v \${HOME}:\${HOME} \ - -v \${HOME}/.ssh:/root/.ssh \ - -w "${verl_workdir}" \ - --shm-size 128G \ - --name \"${CONTAINER_NAME}\" \ - \"${IMG}\" \ - tail -f /dev/null - - echo \"Container setup completed\" - " - # (Optional): If you do not want to root mode and require assign yuorself as the user - # Please add `-e HOST_UID=$(id -u)` and `-e HOST_GID=$(id -g)` into the above docker launch script. - - - - - - ### Ray launch the nodes before training - - # Getting the node names - nodes_array=($(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')) - - head_node=${nodes_array[0]} - head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) - - # if we detect a space character in the head node IP, we'll - # convert it to an ipv4 address. This step is optional. - if [[ "$head_node_ip" == *" "* ]]; then - IFS=' ' read -ra ADDR <<<"$head_node_ip" - if [[ ${#ADDR[0]} -gt 16 ]]; then - head_node_ip=${ADDR[1]} - else - head_node_ip=${ADDR[0]} - fi - echo "IPV6 address detected. We split the IPV4 address as $head_node_ip" - fi - - port=6379 - ip_head=$head_node_ip:$port - export ip_head - echo "IP Head: $ip_head" - - # make sure we set environment variables before Ray initialization - - # Print out all env variables - printenv - - echo "Starting HEAD at $head_node" - srun --nodes=1 --ntasks=1 -w "$head_node" \ - docker exec "${CONTAINER_NAME}" \ - ray start --head --node-ip-address="$head_node_ip" --port=$port \ - --dashboard-port=8266 \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & - # optional, though may be useful in certain versions of Ray < 1.0. - sleep 10 - - # number of nodes other than the head node - worker_num=$((SLURM_JOB_NUM_NODES - 1)) - - for ((i = 1; i <= worker_num; i++)); do - node_i=${nodes_array[$i]} - echo "Debug: Starting worker on node_i = ${node_i}" - if [ -z "$node_i" ]; then - echo "Error: Empty node name for worker $i" - continue - fi - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" \ - docker exec "${CONTAINER_NAME}" \ - ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & - sleep 5 - done - - - - - # Ray initlization test (See whether any error in the above execution) - echo "Testing Ray initialization in the slurm nodes..." - docker exec "${CONTAINER_NAME}" python3 -c ' - import ray - try: - ray.init(address="auto") - print("\n=== Ray Cluster Status ===") - print(f"Number of nodes: {len(ray.nodes())}") - for node in ray.nodes(): - print("Node: {}, Status: {}".format(node["NodeManagerHostname"], node["Alive"])) - # print(f"Node: {node}") - ray.shutdown() - print("Ray initialization successful!") - except Exception as e: - print(f"Ray initialization failed: {str(e)}") - ' - echo "=== Ray test completed ===" - ###### - - - - # Run data preprocessing - - echo "Starting data preprocessing..." - docker exec "${CONTAINER_NAME}" \ - python3 "examples/data_preprocess/gsm8k.py" "--local_dir" "../data/gsm8k" - - echo "Starting data preprocessing..." - docker exec "${CONTAINER_NAME}" \ - python3 "examples/data_preprocess/math_dataset.py" "--local_dir" "../data/math" - - train_files="../data/gsm8k/train.parquet" - val_files="../data/gsm8k/test.parquet" - - # Download and test model - echo "Loading model..." - docker exec "${CONTAINER_NAME}" \ - python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')" - MODEL_PATH="Qwen/Qwen2-7B-Instruct" - - # Set model path after pipeline test - MODEL_PATH="Qwen/Qwen2.5-0.5B-Instruct" - - echo "== Data and model loading Done ==" - - echo "Start to train..." - - docker exec "${CONTAINER_NAME}" \ - python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')" - MODEL_PATH="Qwen/Qwen2-7B-Instruct" - - - PYTHONUNBUFFERED=1 srun --overlap --nodes=${SLURM_NNODES} --ntasks=1 -w "$head_node" \ - docker exec "${CONTAINER_NAME}" \ - python3 -m verl.trainer.main_ppo \ - data.train_files=$train_files \ - data.val_files=$val_files \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.enable_gradient_checkpointing=False \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=$MODEL_PATH \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=8 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.kl_ctrl.kl_coef=0.0001 \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ - trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \ - trainer.val_before_train=False \ - trainer.nnodes=${SLURM_NNODES} \ - trainer.save_freq=-1 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 - - -Run slurm_script.sh -~~~~~~~~~~~~~~~~~~~~ -Just sbatch your slurm_script.sh - -.. code-block:: bash - - sbatch slurm_script.sh - diff --git a/docs/amd_tutorial/amd_vllm_page.rst b/docs/amd_tutorial/amd_vllm_page.rst deleted file mode 100644 index 9c64755cb..000000000 --- a/docs/amd_tutorial/amd_vllm_page.rst +++ /dev/null @@ -1,105 +0,0 @@ -verl performance tuning for AMD (ROCm Kernel) -===================================================== - -Last updated: 04/25/2025. - -Author: `Yang Wang `_ - -Patch vLLM to Enable Sleep Mode for AMD GPUs --------------------------------------------------------------- - -By default, verl requires vLLM to enable sleep mode, which allows vLLM to offload GPU memory to CPU memory after rollout. However, this feature is still under review by the vLLM community. - -To enable vLLM's sleep mode, you can first use community patched code (from `this pull request `_) to build vLLM from the source code in the corresponding pull request. After the patch merged in vLLM main branch, you can directly install vLLM from the latest version. - -1. Clone the vLLM repository and build it with the following commands: - -.. code-block:: bash - - git clone -b sleep_amd https://github.com/HollowMan6/vllm.git - cd vllm - sudo ln -sf /opt/rocm/lib/libamdhip64.so /usr/lib/libamdhip64.so - VLLM_TARGET_DEVICE=rocm ROCM_PATH=/opt/rocm/ VLLM_GPU_LANG=HIP SETUPTOOLS_SCM_PRETEND_VERSION=0.8.4.dev python3 setup.py develop - -2. Additionally, make sure to use the ROCm version in your Docker image lager than or equal to ROCm 6.3.4, and we recommend to use ROCm 6.4.0 for better performance (see `this comment `_). - -After the upgrade, you can verify whether sleep mode is enabled by running the following test code (from `this comment `_). - -.. code-block:: python - - import torch - from vllm import LLM - - llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_sleep_mode=True) - - def run_inference(prompt): - outputs = llm.generate(prompt) - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - - print("CUDA Memory Usage (after inference):") - torch.cuda.empty_cache() - print(f"{torch.cuda.memory_allocated()=}") - - run_inference("San Francisco is") - llm.sleep() - - print("CUDA Memory Usage (after sleep):") - torch.cuda.empty_cache() - print(f"{torch.cuda.memory_allocated()=}") - - llm.wake_up() - - print("CUDA Memory Usage (after wakeup):") - torch.cuda.empty_cache() - print(f"{torch.cuda.memory_allocated()=}") - - run_inference("Paris is") - -If sleep mode is enabled, you should see the memory usage reduce after sleep. - -After applying the vLLM patch and completing the installation, you can enable sleep mode in verl to reduce memory overhead. This allows verl to offload unused GPU memory during rollout, significantly lowering the memory footprint during long-context training or multi-node reinforcement learning. - - -Enable CUDA Graph and Bypass ROCm-related issues --------------------------------------------------------------- - -Due to potential issues with CUDA graph capture in ROCm, we’ve found that vLLM’s CUDA graph feature cannot be enabled on multiple nodes in verl on AMD platforms with vLLM V1 mode. This leads to significantly slower rollout performance. - -Our investigation shows that ROCm may trigger an unexpected crash when attempting to capture large batches with CUDA graph. One workaround is to patch the LLM configuration (from `this commit `_). - -.. code-block:: python - - self.inference_engine = LLM( - model=model_path, - enable_sleep_mode=True, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend="external_launcher", - dtype=config.dtype, - enforce_eager=config.enforce_eager, - gpu_memory_utilization=config.gpu_memory_utilization, - disable_custom_all_reduce=True, - disable_mm_preprocessor_cache=True, - limit_mm_per_prompt=limit_mm_per_prompt, - skip_tokenizer_init=False, - max_model_len=max_model_len, - load_format=load_format, - disable_log_stats=config.disable_log_stats, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=config.enable_chunked_prefill, - enable_prefix_caching=True, - trust_remote_code=trust_remote_code, - # enable compilation config to bypass oom on rocm - # change depends on your GPU memory size - compilation_config={"cudagraph_capture_sizes": [1, 2, 4, 8, 16, 32, 64]}, - seed=config.get('seed', 0), - ) - -Then, you can choose to enable CUDA graph by setting the following environment variables (see `this page `_): - -.. code-block:: bash - - actor_rollout_ref.rollout.enforce_eager=False \ diff --git a/docs/api/data.rst b/docs/api/data.rst deleted file mode 100644 index 1f6018bc9..000000000 --- a/docs/api/data.rst +++ /dev/null @@ -1,61 +0,0 @@ -Data interface -========================= - -Last updated: 05/19/2025 (API docstrings are auto-generated). - -DataProto is the interface for data exchange. - -The :class:`verl.DataProto` class contains two key members: - -- batch: a :class:`tensordict.TensorDict` object for the actual data -- meta_info: a :class:`Dict` with additional meta information - -TensorDict -~~~~~~~~~~~~ - -:attr:`DataProto.batch` is built on top of :class:`tensordict`, a project in the PyTorch ecosystem. -A TensorDict is a dict-like container for tensors. To instantiate a TensorDict, you must specify key-value pairs as well as the batch size. - -.. code-block:: python - - >>> import torch - >>> from tensordict import TensorDict - >>> tensordict = TensorDict({"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 5)}, batch_size=[2,]) - >>> tensordict["twos"] = 2 * torch.ones(2, 5, 6) - >>> zeros = tensordict["zeros"] - >>> tensordict - TensorDict( - fields={ - ones: Tensor(shape=torch.Size([2, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False), - twos: Tensor(shape=torch.Size([2, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False), - zeros: Tensor(shape=torch.Size([2, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([2]), - device=None, - is_shared=False) - -One can also index a tensordict along its batch_size. The contents of the TensorDict can be manipulated collectively as well. - -.. code-block:: python - - >>> tensordict[..., :1] - TensorDict( - fields={ - ones: Tensor(shape=torch.Size([1, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False), - twos: Tensor(shape=torch.Size([1, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False), - zeros: Tensor(shape=torch.Size([1, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([1]), - device=None, - is_shared=False) - >>> tensordict = tensordict.to("cuda:0") - >>> tensordict = tensordict.reshape(6) - -For more about :class:`tensordict.TensorDict` usage, see the official tensordict_ documentation. - -.. _tensordict: https://pytorch.org/tensordict/overview.html - - -Core APIs -~~~~~~~~~~~~~~~~~ - -.. autoclass:: verl.DataProto - :members: to, select, union, make_iterator, concat diff --git a/docs/api/single_controller.rst b/docs/api/single_controller.rst deleted file mode 100644 index 44ea366ff..000000000 --- a/docs/api/single_controller.rst +++ /dev/null @@ -1,30 +0,0 @@ -Single Controller interface -============================ - -Last updated: 05/27/2025 (API docstrings are auto-generated). - -The Single Controller provides a unified interface for managing distributed workers -using Ray or other backends and executing functions across them. -It simplifies the process of dispatching tasks and collecting results, particularly -when dealing with data parallelism or model parallelism. - - -Core APIs -~~~~~~~~~~~~~~~~~ - -.. autoclass:: verl.single_controller.Worker - :members: __init__, __new__, get_master_addr_port, get_cuda_visible_devices, world_size, rank - -.. autoclass:: verl.single_controller.WorkerGroup - :members: __init__, world_size - -.. autoclass:: verl.single_controller.ClassWithInitArgs - :members: __init__, __call__ - -.. autoclass:: verl.single_controller.ResourcePool - :members: __init__, world_size, local_world_size_list, local_rank_list - -.. autoclass:: verl.single_controller.ray.RayWorkerGroup - :members: __init__ - -.. autofunction:: verl.single_controller.ray.create_colocated_worker_cls \ No newline at end of file diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst deleted file mode 100644 index abfa51f01..000000000 --- a/docs/api/trainer.rst +++ /dev/null @@ -1,31 +0,0 @@ -Trainer Interface -================================ - -Last updated: 06/08/2025 (API docstrings are auto-generated). - -Trainers drive the training loop. Introducing new trainer classes in case of new training paradiam is encouraged. - -.. autosummary:: - :nosignatures: - - verl.trainer.ppo.ray_trainer.RayPPOTrainer - - -Core APIs -~~~~~~~~~~~~~~~~~ - -.. autoclass:: verl.trainer.ppo.ray_trainer.RayPPOTrainer - :members: __init__, init_workers, fit - -.. automodule:: verl.utils.tokenizer - :members: hf_tokenizer - -.. automodule:: verl.trainer.ppo.core_algos - :members: agg_loss, kl_penalty, compute_policy_loss, kl_penalty - -.. automodule:: verl.trainer.ppo.reward - :members: load_reward_manager, compute_reward, compute_reward_async - -.. autoclass:: verl.workers.reward_manager.NaiveRewardManager - -.. autoclass:: verl.workers.reward_manager.DAPORewardManager diff --git a/docs/api/utils.rst b/docs/api/utils.rst deleted file mode 100644 index e15e3a5a3..000000000 --- a/docs/api/utils.rst +++ /dev/null @@ -1,76 +0,0 @@ -Utilities -============ - -Last updated: 05/19/2025 (API docstrings are auto-generated). - -This section documents the utility functions and classes in the VERL library. - -Python Functional Utilities ------------------------------- - -.. automodule:: verl.utils.py_functional - :members: append_to_dict - -File System Utilities ------------------------- - -.. automodule:: verl.utils.fs - :members: copy_to_local - -Tracking Utilities ---------------------- - -.. automodule:: verl.utils.tracking - :members: Tracking - -Metrics Utilities ---------------------- - -.. automodule:: verl.utils.metric - :members: reduce_metrics - -Checkpoint Management ------------------------- - -.. automodule:: verl.utils.checkpoint.checkpoint_manager - :members: find_latest_ckpt_path - -.. automodule:: verl.utils.checkpoint.fsdp_checkpoint_manager - :members: FSDPCheckpointManager - -Dataset Utilities ---------------------- - -.. automodule:: verl.utils.dataset.rl_dataset - :members: RLHFDataset, collate_fn - -Torch Functional Utilities ------------------------------ - -.. automodule:: verl.utils.torch_functional - :members: get_constant_schedule_with_warmup, masked_whiten, masked_mean, logprobs_from_logits - -Sequence Length Balancing ----------------------------- - -.. automodule:: verl.utils.seqlen_balancing - :members: get_reverse_idx, rearrange_micro_batches - -Ulysses Utilities --------------------- - -.. automodule:: verl.utils.ulysses - :members: gather_outputs_and_unpad, ulysses_pad_and_slice_inputs - -FSDP Utilities ------------------- - -.. automodule:: verl.utils.fsdp_utils - :members: get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn, load_fsdp_model_to_gpu, load_fsdp_optimizer, offload_fsdp_model_to_cpu, offload_fsdp_optimizer, - -Debug Utilities -------------------- - -.. automodule:: verl.utils.profiler - :members: log_gpu_memory_usage, GPUMemoryLogger - diff --git a/docs/ascend_tutorial/ascend_profiling.rst b/docs/ascend_tutorial/ascend_profiling.rst deleted file mode 100644 index db2972d78..000000000 --- a/docs/ascend_tutorial/ascend_profiling.rst +++ /dev/null @@ -1,100 +0,0 @@ -在昇腾设备上基于FSDP后端进行数据采集 -==================================== - -Last updated: 07/14/2025. - -这是一份在昇腾设备上基于FSDP后端使用GRPO或DAPO算法进行数据采集的教程。 - -配置 ----- - -复用verl/trainer/config/ppo_trainer.yaml中的配置项控制采集的模式和步数, -通过verl/trainer/config/npu_profile/npu_profile.yaml中的配置项控制例如采集等级等参数。 - -全局采集控制 -~~~~~~~~~~~~ - -通过 ppo_trainer.yaml 中的参数控制采集步数和模式: - -- trainer.profile_steps: - 该参数可以设置为一个包含采集步数的列表,例如[2, - 4], 意味着将会采集第二步和第四步。如果该参数为null,则代表不进行采集 -- actor_rollout_ref.profiler: - 控制采集的ranks和模式 - - - all_ranks:设为True代表对所有rank进行采集 - - ranks:当all_ranks不为True时, - 通过ranks参数控制需要采集的rank,该参数设置为一个包含采集rank的列表, 例如[0, - 1] - - discrete: - 控制采集的模式。当该参数设置为False,代表采集端到端的数据;当该参数设置为True,代表采用离散模式分训练阶段采集数据 - -通过 npu_profile.yaml 中的参数控制具体采集行为: - -- save_path:采集数据的存放路径 -- level:采集等级,可选项为level_none、level0、level1和level2 - - - level_none:不采集所有Level层级控制的数据,即关闭profiler_level - - level0:采集上层应用数据、底层NPU数据以及NPU上执行的算子信息 - - level1:在level0的基础上多采集CANN层AscendCL数据和NPU上执行的AI - Core性能指标信息 - - level2:在level1的基础上多采集CANN层Runtime数据以及AI CPU - -- record_shapes:是否记录张量形状 -- with_memory:是否启用内存分析 -- with_npu:是否采集device侧性能数据 -- with_cpu:是否采集host侧性能数据 -- with_module:是否记录框架层python调用栈信息 -- with_stack:是否记录算子调用栈信息 -- analysis:是否自动解析数据 - -示例 ----- - -禁用采集 -~~~~~~~~ - -.. code:: yaml - - trainer: - profile_steps: null # disable profile - -端到端采集 -~~~~~~~~~~ - -.. code:: yaml - - trainer: - profile_steps: [1, 2, 5] - actor_rollout_ref: - profiler: - discrete: False - all_ranks: True - - -离散模式采集 -~~~~~~~~~~~~ - -.. code:: yaml - - trainer: - profile_steps: [1, 2, 5] - actor_rollout_ref: - profiler: - discrete: True - all_ranks: False - ranks: [0, 1] - - -可视化 ------- - -采集后的数据存放在用户设置的save_path下,可通过 `MindStudio Insight `_ 工具进行可视化。 - -如果analysis参数设置为False,采集之后需要进行离线解析: - -.. code:: python - - import torch_npu - # profiler_path请设置为"localhost.localdomain___ascend_pt"目录的上一级目录 - torch_npu.profiler.profiler.analyse(profiler_path=profiler_path) \ No newline at end of file diff --git a/docs/ascend_tutorial/ascend_profiling_en.rst b/docs/ascend_tutorial/ascend_profiling_en.rst deleted file mode 100644 index 3ab067ae2..000000000 --- a/docs/ascend_tutorial/ascend_profiling_en.rst +++ /dev/null @@ -1,109 +0,0 @@ -Data collection based on FSDP (Fully Sharded Data Parallel) backend on Ascend devices(NPU) -========================================================================================== - -Last updated: 07/14/2025. - -This is a tutorial for data collection using the GRPO or DAPO algorithm -based on FSDP on Ascend devices. - -Configuration -------------- - -Reuse the configuration items in -verl/trainer/config/ppo_trainer.yaml to control the collection mode -and steps, you can also manage the collection behaviors such as -collection level via verl/trainer/config/npu_profile/npu_profile.yaml. - -Global collection control -~~~~~~~~~~~~~~~~~~~~~~~~~ - -Use parameters in ppo_trainer.yaml to control the collection mode -and steps. - -- trainer.profile_steps: This parameter can be set as a list that has - collection steps, such as [2, 4], which means it will collect steps 2 - and 4. If set to null, no collection occurs. -- actor_rollout_ref.profiler: Control the ranks and mode of profiling - - - all_ranks: Collects data from all ranks when set to true. - - ranks: This parameter specifies which ranks to collect (e.g., [0, - 1]) when all_ranks is False. - - discrete: Controls the collection mode. If False, end-to-end data - is collected; if True, data is collected in discrete phases during - training. - -Use parameters in npu_profile.yaml to control collection behavior: - -- save_path: Storage path for collected data. -- level: Collection level—options are level_none, level0, level1, and - level2 - - - level_none: Disables all level-based data collection (turns off - profiler_level). - - level0: Collect high-level application data, underlying NPU data, - and operator execution details on NPU. - - level1: Extends level0 by adding CANN-layer AscendCL data and AI - Core performance metrics on NPU. - - level2: Extends level1 by adding CANN-layer Runtime data and AI - CPU metrics. - -- record_shapes: Whether to record tensor shapes. -- with_memory: Whether to enable memory analysis. -- with_npu: Whether to collect device-side performance data. -- with_cpu: Whether to collect host-side performance data. -- with_module: Whether to record framework-layer Python call stack - information. -- with_stack: Whether to record operator call stack information. -- analysis: Enables automatic data parsing. - -Examples --------- - -Disabling collection -~~~~~~~~~~~~~~~~~~~~ - -.. code:: yaml - - trainer: - profile_steps: null # disable profile - -End-to-End collection -~~~~~~~~~~~~~~~~~~~~~ - -.. code:: yaml - - trainer: - profile_steps: [1, 2, 5] - actor_rollout_ref: - profiler: - discrete: False - all_ranks: True - - -Discrete Mode Collection -~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: yaml - - trainer: - profile_steps: [1, 2, 5] - actor_rollout_ref: - profiler: - discrete: True - all_ranks: False - ranks: [0, 1] - - -Visualization -------------- - -Collected data is stored in the user-defined save_path and can be -visualized by using the `MindStudio Insight `_ tool. - -If the analysis parameter is set to False, offline parsing is required after data collection: - -.. code:: python - - import torch_npu - # Set profiler_path to the parent directory of the "localhost.localdomain___ascend_pt" folder - torch_npu.profiler.profiler.analyse(profiler_path=profiler_path) \ No newline at end of file diff --git a/docs/ascend_tutorial/ascend_quick_start.rst b/docs/ascend_tutorial/ascend_quick_start.rst deleted file mode 100644 index 589964328..000000000 --- a/docs/ascend_tutorial/ascend_quick_start.rst +++ /dev/null @@ -1,204 +0,0 @@ -verl x Ascend -=================================== - -Last updated: 06/17/2025. - -我们在 verl 上增加对华为昇腾设备的支持。 - -硬件支持 ------------------------------------ - -Atlas 200T A2 Box16 - -Atlas 900 A2 PODc - - -安装 ------------------------------------ - -基础环境准备 -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -+-----------+-------------+ -| software | version | -+-----------+-------------+ -| Python | == 3.10 | -+-----------+-------------+ -| CANN | == 8.1.RC1 | -+-----------+-------------+ -| torch | == 2.5.1 | -+-----------+-------------+ -| torch_npu | == 2.5.1.RC1| -+-----------+-------------+ - - -vllm & vllm-ascend -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -为了能够在 verl 中正常使用 vllm,需使用以下命令编译安装 vllm 和 vllm-ascend。请注意根据机器类型区分安装方式。 - -.. code-block:: bash - - # vllm - git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm.git - cd vllm - pip install -r requirements-build.txt - - # for Atlas 200T A2 Box16 - VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/ - - # for Atlas 900 A2 PODc - VLLM_TARGET_DEVICE=empty pip install -e . - -.. code-block:: bash - - # vllm-ascend - git clone -b v0.7.3.post1 --depth 1 https://github.com/vllm-project/vllm-ascend.git - cd vllm-ascend - export COMPILE_CUSTOM_KERNELS=1 - python setup.py install - -安装verl -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: bash - - git clone https://github.com/volcengine/verl.git - cd verl - pip install -r requirements-npu.txt - pip install -e . - -其他三方库说明 -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -+--------------+---------------+ -| software | description | -+--------------+---------------+ -| transformers | v4.52.4 | -+--------------+---------------+ -| flash_attn | not supported | -+--------------+---------------+ -| liger-kernel | not supported | -+--------------+---------------+ -| tensordict | 0.8.3 (ARM) | -+--------------+---------------+ - -1. 支持通过 transformers 使能 --flash_attention_2, transformers 需大于等于 4.52.0版本。 -2. 不支持通过 flash_attn 使能 flash attention 加速。 -3. 不支持 liger-kernel 使能。 -4. 针对 ARM 服务器,tensordict 要求 0.8.3,可在依赖安装完成后再手动安装 tensordict。 -5. 针对 x86 服务器,需要安装 cpu 版本的 torchvision。 - -.. code-block:: bash - - pip install torchvision==0.20.1+cpu --index-url https://download.pytorch.org/whl/cpu - - -快速开始 ------------------------------------ -正式使用前,建议您通过对Qwen2.5-0.5B GRPO的训练尝试以检验环境准备和安装的正确性。 - -1.下载数据集并将数据集预处理为parquet格式,以便包含计算RL奖励所需的必要字段 - -.. code-block:: bash - - python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k - -2.执行训练 - -.. code-block:: bash - - set -x - - export VLLM_ATTENTION_BACKEND=XFORMERS - - python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=128 \ - data.max_prompt_length=512 \ - data.max_response_length=128 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.actor.entropy_coeff=0.001 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 \ - trainer.device=npu $@ - - -支持现状 ------------------------------------ - -+-----------+-------------------------+-------------+-------------------+----------------------+ -| algorithm | model | rewards mae | throughput ratio | hardware | -+-----------+-------------------------+-------------+-------------------+----------------------+ -| GRPO | Qwen2.5-7B-instruct | 0.38% | 0.588 | Atlas 200T A2 Box16 | -+-----------+-------------------------+-------------+-------------------+----------------------+ -| GRPO | Qwen2.5-32B-instruct | 0.30% | 0.685 | Atlas 200T A2 Box16 | -+-----------+-------------------------+-------------+-------------------+----------------------+ -| GRPO | Qwen2.5-VL-3B-instruct | 3.14% | 0.470 | Atlas 200T A2 Box16 | -+-----------+-------------------------+-------------+-------------------+----------------------+ -| GRPO | Qwen2.5-VL-7B-instruct | 3.30% | 0.380 | Atlas 200T A2 Box16 | -+-----------+-------------------------+-------------+-------------------+----------------------+ -| GRPO | Qwen2.5-VL-32B-instruct | 0.79% | 0.568 | Atlas 200T A2 Box16 | -+-----------+-------------------------+-------------+-------------------+----------------------+ -| DAPO | Qwen2.5-7B-instruct | 3.83% | pending | Atlas 200T A2 Box16 | -+-----------+-------------------------+-------------+-------------------+----------------------+ -| SFT-PEFT | Qwen2.5-0.5B-instruct | 0.06% | 0.305 | Atlas 900 A2 PODc | -+-----------+-------------------------+-------------+-------------------+----------------------+ - -精度对比说明 -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -对于 SFT 类算法,我们期望在相同配置下华为昇腾设备与 A100 的 loss 平均绝对误差<= 2%。计算方式如下图。更多信息请参考 `精度计算说明 `_。 - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/loss_comparison.png?raw=true - :alt: loss_comparison - -根据经验,对于 GRPO 等 RL 类算法,我们期望在相同配置下华为昇腾设备与 A100 的 rewards 平均绝对误差<= 4%,计算方式参考上图。 - - -吞吐对比说明 -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Ascend npu 和 A100 分别取日志中前4个 step 的 "perf/throughput" 做平均, throughput ratio = npu 平均值 / A100 平均值。 - - - -计划 ------------------------------------ - -查看 `roadmap `_ 获取更多特性的支持进度。 - - - -声明 ------------------------------------ -verl中提供的ascend支持代码皆为参考样例,商业使用请通过官方正式途径沟通,谢谢。 diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index d405288ff..000000000 --- a/docs/conf.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) - - -# -- Project information ----------------------------------------------------- - -project = "verl" -copyright = "2024 ByteDance Seed Foundation MLSys Team" -author = "Guangming Sheng, Chi Zhang, Yanghua Peng, Haibin Lin" - - -# -- General configuration --------------------------------------------------- -# The master toctree document. -master_doc = "index" - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - "myst_parser", - "sphinx.ext.autodoc", - "sphinx.ext.autosummary", - "sphinx.ext.autosectionlabel", - "sphinx.ext.napoleon", - "sphinx.ext.viewcode", -] -# Use Google style docstrings instead of NumPy docstrings. -napoleon_google_docstring = True -napoleon_numpy_docstring = False - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -source_suffix = { - ".rst": "restructuredtext", - ".md": "markdown", -} - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = "en" - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = "sphinx_rtd_theme" - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] - -# Add the JavaScript file -html_js_files = [ - "js/runllm-widget.js", -] - -exclude_patterns += ["README.md", "README_vllm0.7.md"] - -suppress_warnings = ["ref.duplicate", "ref.myst"] diff --git a/docs/examples/config.rst b/docs/examples/config.rst deleted file mode 100644 index 0f05c181b..000000000 --- a/docs/examples/config.rst +++ /dev/null @@ -1,683 +0,0 @@ -.. _config-explain-page: - -Config Explanation -=================== - -Last updated: 06/18/2025. - -ppo_trainer.yaml for RL FSDP Backend -------------------------------------- - -Data -~~~~ - -.. code:: yaml - - data: - tokenizer: null - train_files: ~/data/rlhf/gsm8k/train.parquet - val_files: ~/data/rlhf/gsm8k/test.parquet - prompt_key: prompt - max_prompt_length: 512 - max_response_length: 512 - train_batch_size: 1024 - return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs - return_raw_chat: False - return_full_prompt: False - shuffle: True - filter_overlong_prompts: False - filter_overlong_prompts_workers: 1 - truncation: error - image_key: images - trust_remote_code: True - custom_cls: - path: null - name: null - -- ``data.train_files``: Training set parquet. Can be a list or a single - file. The program will read all files into memory, so it can't be too - large (< 100GB). The path can be either local path or HDFS path. For - HDFS path, we provide utils to download it to DRAM and convert the - HDFS path to local path. -- ``data.val_files``: Validation parquet. Can be a list or a single - file. -- ``data.prompt_key``: The field in the dataset where the prompt is - located. Default is 'prompt'. -- ``data.max_prompt_length``: Maximum prompt length. All prompts will be - left-padded to this length. An error will be reported if the length is - too long -- ``data.max_response_length``: Maximum response length. Rollout in RL - algorithms (e.g. PPO) generates up to this length -- ``data.train_batch_size``: Batch size sampled for one training - iteration of different RL algorithms. -- ``data.return_raw_input_ids``: Whether to return the original - input_ids without adding chat template. This is mainly used to - accommodate situations where the reward model's chat template differs - from the policy. It needs to be decoded first, then apply the RM's - chat template. If using a model-based RM, and the policy and RM - chat_templates are different, this flag needs to be set -- ``data.return_raw_chat``: Whether to return the original chat (prompt) - without applying chat template. -- ``data.return_full_prompt``: Whether to return the full prompt with chat template -- ``data.shuffle``: Whether to shuffle the data in the dataloader. -- ``data.filter_overlong_prompts``: Default don't filter. -- ``data.filter_overlong_prompts_workers``: For large-scale dataset, filtering - overlong prompts could be timeconsuming. You cat set the ``filter_overlong_prompts_workers`` - to use multiprocessing for speed up. Default to 1. -- ``data.truncation``: Truncate the input_ids or prompt length if they - exceed max_prompt_length. Default is 'error', not allow exceed the - max_prompt_length. The users should increase the max_prompt_length if - throwing the error. You can also set ``left``, ``right`` and ``middle``. - When ``middle`` is selected, the logic splits the allowed max length roughly in half - and keeps the head and tail of the sequence, effectively discarding the middle section. -- ``data.image_key``: The field in the multi-modal dataset where the image is - located. Default is 'images'. -- ``data.trust_remote_code``: If the remote tokenizer has python file, we can use this field to allow - using remote tokenizer. For example: moonshotai/Moonlight-16B-A3B-Instruct - -Customized Dataset -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Customized dataset extension is implemented for the SFT trainer and can be extended to other trainers with similar changes. - -.. code:: yaml - - custom_cls: - path: null - name: null - -- ``data.custom_cls.path``: The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used. -- ``data.custom_cls.name``: The name of the dataset class within the specified file. - -Actor/Rollout/Reference Policy -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: yaml - - actor_rollout_ref: - hybrid_engine: True - model: - path: ~/models/deepseek-llm-7b-chat - external_lib: null - override_config: - model_config: {} - moe_config: # Megatron only, can adjust moe configuration - freeze_moe_router: False # Megatron only, can freeze moe router (no grad) - enable_gradient_checkpointing: False - enable_activation_offload: False - trust_remote_code: False - use_remove_padding: False - actor: - strategy: fsdp # This is for backward-compatibility - ppo_mini_batch_size: 256 - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: 8 - use_dynamic_bsz: False - ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} - grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.0 - use_kl_loss: False # True for GRPO - use_torch_compile: True # False to disable torch compile - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo - ppo_epochs: 1 - data_loader_seed: null - shuffle: False - ulysses_sequence_parallel_size: 1 # sp size - optim: - lr: 1e-6 - lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: 0.0 # only used with cosine lr scheduler, default to 0.0 - num_cycles: 0.5 # only used with cosine lr scheduler, default to 0.5 - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - fsdp_config: - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - param_offload: False - optimizer_offload: False - fsdp_size: -1 - checkpoint: - # What to include in saved checkpoints - # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - save_contents: ['model', 'optimizer', 'extra'] - # For more flexibility, you can specify the contents to load from the checkpoint. - load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} - ref: - fsdp_config: - param_offload: False - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: 16 - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size - rollout: - name: vllm - temperature: 1.0 - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1 - prompt_length: ${data.max_prompt_length} # not use for opensource - response_length: ${data.max_response_length} - # for vllm rollout - dtype: bfloat16 # should align with FSDP - gpu_memory_utilization: 0.5 - ignore_eos: False - enforce_eager: True - free_cache_engine: True - load_format: dummy_dtensor - tensor_model_parallel_size: 2 - max_num_batched_tokens: 8192 - max_num_seqs: 1024 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: 16 - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - # for hf rollout - do_sample: True - engine_kwargs: # inference engine parameters - vllm: - swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB - disable_mm_preprocessor_cache: False # disable preprocessor cache for multimodel models - sglang: - attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla - - n: 1 # for each prompt, sample n responses (i.e. num sample times). set it to values > 1 for grpo, rloo - val_kwargs: - # sampling parameters for validation - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1.0 - temperature: 0 - n: 1 - do_sample: False # default eager for validation - - agent: - custom_async_server: # Use custom async server implementation for rollout - path: null - name: null - -**Common config for actor, rollout and reference model** - -- ``actor_rollout_ref.hybrid_engine``: Whether it's a hybrid engine, - currently only supports hybrid engine -- ``actor_rollout_ref.model.path``: Huggingface model path. This can be - either local path or HDFS path. For HDFS path, we provide utils to - download it to DRAM and convert the HDFS path to local path. -- ``actor_rollout_ref.model.external_libs``: Additional Python packages - that need to be imported. Used to register models or tokenizers into - the Huggingface system. -- ``actor_rollout_ref.model.override_config``: Used to override some of - the model's original configurations, mainly dropout -- ``actor_rollout_ref.model.enable_gradient_checkpointing``: Whether to - enable gradient checkpointing for the actor -- ``actor_rollout_ref.model.enable_activation_offload``: Whether to enable - activation offloading for the actor -- ``actor_rollout_ref.model.trust_remote_code``: Whether to enable loading - a remote code model -- ``actor_rollout_ref.model.use_fused_kernels``: Whether to use fused - kernels in the model. If set to True, the following parameters will be - used. - - ``actor_rollout_ref.model.fused_kernel_options.impl_backend``: The - implementation backend for fused kernels. Options: "triton" or - "torch". Default is "torch". - While in megatron, we only support "triton" as the - implementation backend, so there is no need for this option. -- ``actor_rollout_ref.model.use_remove_padding``: Whether to use remove - padding in the model. If set to True, the model will remove padding - tokens in the input_ids and response_ids. This helps a lot in improving model running efficiency. - -**Actor model** - -- ``actor_rollout_ref.actor.strategy``: fsdp or megatron. In this - example, we use fsdp backend. - -- ``actor_rollout_ref.actor.ppo_mini_batch_size``: One sample is split - into multiple sub-batches with batch_size=ppo_mini_batch_size for PPO - updates. The ppo_mini_batch_size is a global num across all workers/gpus - -- ``actor_rollout_ref.actor.ppo_micro_batch_size``: [Will be deprecated, use ppo_micro_batch_size_per_gpu] - Similar to gradient accumulation, the micro_batch_size_per_gpu for one forward pass, - trading speed for GPU memory. The value represent the global view. - -- ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``: Similar to gradient - accumulation, the micro_batch_size_per_gpu for one forward pass, trading speed - for GPU memory. The value represent the local num per gpu. - -- ``actor_rollout_ref.actor.grad_clip``: Gradient clipping for actor - updates -- ``actor_rollout_ref.actor.use_kl_loss``: to use kl loss in actor. When used, we are not applying KL in the reward function. - -- ``actor_rollout_ref.actor.clip_ratio``: PPO clip ratio - -- ``actor_rollout_ref.actor.use_torch_compile``: Whether to use torch compile in actor - -- ``actor_rollout_ref.actor.entropy_coeff``: The weight of entropy when - calculating PPO loss. The default value is changed to 0.0 since v0.3.x - -- ``actor_rollout_ref.actor.ppo_epochs``: Number of epochs for PPO - updates on one set of sampled data - -- ``actor_rollout_ref.actor.data_loader_seed``: From torch 2.6.0 Megatron backend can get wrong seed generated by pytorch - between cp ranks and cause misalignment between data on these ranks, so we shall manually set the seed to avoid hanging - issue. if ``actor_rollout_ref.actor.shuffle`` is not null, this must be set. - -- ``actor_rollout_ref.actor.shuffle``: Whether to shuffle data when - there are multiple epochs - -- ``actor_rollout_ref.actor.optim``: Actor's optimizer parameters - -- ``actor_rollout_ref.actor.fsdp_config``: FSDP config for actor - training - - - ``wrap_policy``: FSDP wrap policy. By default, it uses Huggingface's - wrap policy, i.e., wrapping by DecoderLayer - - - No need to set transformer_layer_cls_to_wrap, so we comment it. - - - ``*_offload``: Whether to enable parameter, gradient and optimizer - offload - - - Trading speed for GPU memory. - -- ``actor_rollout_ref.actor.use_kl_loss``: Whether to enable kl loss. Default is False. - -- ``actor_rollout_ref.actor.kl_loss_coef``: The coefficient of kl loss. Default is 0.001. - -- ``actor_rollout_ref.actor.kl_loss_type``: Support ``kl`` (``k1``), ``abs``, ``mse`` (``k2``), ``low_var_kl`` (``k3``) and ``full``. How to calculate the kl divergence between actor and reference policy. For specific options, refer to `kl_penalty()` in `core_algos.py `_ . See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html - -- ``actor_rollout_ref.actor.checkpoint``: The configurations of checkpoint function in actor - - - ``save_contents``: The contents to save in the checkpoint. By default, we save model, optimizer and extra information in the checkpoint. - The extra information includes Rng states currently, FSDP supported lr_scheduler, and Megatron opt_param_scheduler will coming soon. - We do not store hf_model in checkpoint by default, but we provide a tool in ``scripts/model_merge.py`` to convert checkpoint format to hf format. - - - ``load_contents``: The contents to load in the checkpoint, you can specify different checkpoint loading contents. By default, it is the same with ``save_checkpoint``. - -**Reference Model** - -Reference model will be enabled when ``actor.use_kl_loss`` or/and ``algorithm.use_kl_in_reward`` is/are True. - -- ``actor_rollout_ref.ref``: FSDP config same as actor. **For models - larger than 7B, it's recommended to turn on offload for ref by - default** - -- ``actor_rollout_ref.ref.log_prob_micro_batch_size``: [Will be deprecate, use log_prob_micro_batch_size_per_gpu] - The batch size for one forward pass in the computation of ``ref_log_prob``. The value represent the global num. - -- ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``: The batch size - for one forward pass in the computation of ``ref_log_prob``. The value represent the local num per gpu. - -**Rollout Model** - -- ``actor_rollout_ref.rollout.name``: hf/vllm/sglang. - -- Rollout (Auto-regressive) parameters. The key should be equal to the - property name in vLLM's ``SamplingParams``. - - - ``temperature``, ``top_k``, ``top_p`` and others: Sampling - parameters in ``SamplingParams``. - -- ``actor_rollout_ref.rollout.dtype``: Rollout model parameters type. This should be align with - the actor model parameter type in FSDP/Megatron backend. - -- ``actor_rollout_ref.rollout.gpu_memory_utilization``: - - - For vLLM v0.7.0 and later: The fraction of **total** GPU memory to be used for the vLLM instance. - - For SGLang: Corresponding to ``mem_fraction_static``, the fraction of the free GPU memory used for **static** memory like model weights and KV cache. - -- ``actor_rollout_ref.rollout.tensor_model_parallel_size``: TP size for rollout. Only effective - for vllm. - -- ``actor_rollout_ref.rollout.log_prob_micro_batch_size``: [Will be deprecate, use log_prob_micro_batch_size_per_gpu] - The batch size for one forward pass in the computation of ``log_prob``. The value represent the global num. - -- ``actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu``: Micro batch size per gpu (The batch size for - one forward pass) for recalculating ``log_prob``. The value represent the local num per gpu. - -- ``actor_rollout_ref.rollout.do_sample``: Whether to sample during training rollout. If set to False, the rollout model - will perform greedy sampling. - -- ``actor_rollout_ref.rollout.val_kwargs```: Sampling parameters used specifically during validation. - - - ``top_k``: Top-k sampling parameter. Default to -1 for vLLM rollout or 0 for HF rollout. - - ``top_p``: Top-p sampling parameter. Default is 1.0 (disabled). - - ``temperature``: Sampling temperature. Default is 0 (deterministic greedy). - - ``n``: Number of responses to generate during validation. Default is 1. - - ``do_sample``: Whether to use sampling during validation. Default is False for - deterministic outputs. When set to True, the rollout will use the ``actor_rollout_ref.rollout.val_kwargs`` parameters - (top_k, top_p, temperature) to control the sampling behavior. - -- ``actor_rollout_ref.rollout.engine_kwargs.vllm``: extra vllm engine args - - - ``swap_space``: swap space in GB used by the inference engine. Positive integer, e.g., ``32`` means 32 GB. ``null``: means not setting and using the engine default value (usually, e.g., 4 GB for vLLM) - - ``disable_mm_preprocessor_cache``: Whether to disable preprocessor cache for multimodel models. - -- ``actor_rollout_ref.rollout.engine_kwargs.sglang``: extra sglang engine args - - - ``attention_backend``: The attention backend to use for the inference engine. - - - ``null``: means not setting and using the engine default value (usually, e.g., ``fa3`` for SGLang) - - ``flashinfer``: Use flashinfer attention backend. - - ``triton``: Use triton attention backend. - - ``flashmla``: Use flashmla attention backend. - -- ``actor_rollout_ref.rollout.ignore_eos``: Whether to ignore the EOS - token and continue generating tokens after the EOS token is generated. - -- ``actor_rollout_ref.rollout.free_cache_engine``: Offload the KVCache - after rollout generation stage. Default is True. When set to True, - for vllm v0.5.4 and v0.6.3, we need to disable the usage of CUDAGraph - (set ``enforce_eager`` to True.) - -- ``actor_rollout_ref.rollout.enforce_eager``: Whether to use CUDAGraph - in vLLM generation. Default set to True to disable CUDAGraph. - -- ``actor_rollout_ref.rollout.load_format``: Which weight loader to use - to load the actor model weights to the rollout model. - - - ``auto``: Use Megatron weight loader. - - ``megatron``: Use Megatron weight loader. Deployed with Megatron - backend. The input model ``state_dict()`` is already partitioned - along TP dimension and already gathered along PP dimension. This - weight loader requires that the Rollout model and Actor model's - parameters shape and name should be identical. - - ``dtensor``: Default solution when using Huggingface weight loader. - Deployed with FSDP backend and the state_dict_type is - ``StateDictType.SHARDED_STATE_DICT``. Recommend to use this weight - loader - - ``hf``: Use Huggingface weight loader. Deployed with FSDP backend - and the state_dict_type is ``StateDictType.FULL_STATE_DICT``. This - solution doesn't need to rewrite the weight loader for each model - implemented in vLLM but it results in larger peak memory usage. - - ``dummy_hf``, ``dummy_megatron``, ``dummy_dtensor``: Random - initialization. - -.. note:: **NOTED**: In this config field, users only need to select from ``dummy_megatron``, ``dummy_dtensor``, ``dummy_hf`` for rollout initialization and our hybrid engine will select the corresponding weight loader (i.e., ``megatron``, ``dtensor``, ``hf``) during actor/rollout weight synchronization. - - -Megatron Optimizer and Optimizer Parameter Scheduler -____________________________________________________ - -.. code:: yaml - - optim: - optimizer: adam - lr: 1e-6 - clip_grad: 1.0 - total_training_steps: -1 # must be override by program - lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0 - lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - lr_decay_steps: null - lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root - min_lr: 0.0 # minimum learning rate, default to 0.0 - weight_decay: 0.01 - weight_decay_incr_style: constant # select from constant/linear/cosine - lr_wsd_decay_style: exponential # select from constant/exponential/cosine - lr_wsd_decay_steps: null - use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler - - -Notice that there are some differences in APIs between Megatron optimizer and FSDP optimizer. - -- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``warmup_style`` actually means the style of lr decay after warmup. -- Megatron optimizer also support weight decay decay mechanism -- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpoint optimizer parameter scheduler. If set to True, the optimizer parameter scheduler will be saved in the checkpoint and loaded from the checkpoint during resuming training. - -For learning rate decay, original Megatron pretrain default option of ``lr_decay_style`` is ``linear``, -meaning that the learning rate will be linearly decayed from the initial learning rate to ``min_lr`` within the -``lr_decay_steps``. However, in verl, to align with FSDP's default behavior, we set the default -``lr_decay_style`` to ``constant``, meaning that the learning rate will be kept constant after the warmup stage. - - -Critic Model -~~~~~~~~~~~~ - -Most parameters for Critic are similar to Actor Model. - -Reward Model -~~~~~~~~~~~~ - -.. code:: yaml - - reward_model: - enable: False - model: - input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical - path: ~/models/Anomy-RM-v0.1 - external_lib: ${actor_rollout_ref.model.external_lib} - trust_remote_code: False - fsdp_config: - min_num_params: 0 - param_offload: False - micro_batch_size_per_gpu: 16 - max_length: null - reward_manager: naive - -- ``reward_model.enable``: Whether to enable reward model. If False, we - compute the reward only with the user-defined reward functions. In - GSM8K and Math examples, we disable reward model. For RLHF alignment - example using full_hh_rlhf, we utilize reward model to assess the - responses. If False, the following parameters are not effective. -- ``reward_model.model`` - - - ``input_tokenizer``: Input tokenizer. If the reward model's chat - template is inconsistent with the policy, we need to first decode to - plaintext, then apply the rm's chat_template. Then score with RM. If - chat_templates are consistent, it can be set to null. - - ``path``: RM's HDFS path or local path. Note that RM only supports - AutoModelForSequenceClassification. Other model types need to define - their own RewardModelWorker and pass it from the code. - - ``trust_remote_code``: Whether to enable loading a remote code model, - default to False. -- ``reward_model.reward_manager``: Reward Manager. This defines the mechanism - of computing rule-based reward and handling different reward sources. Default - is ``naive``. If all verification functions are multiprocessing-safe, the reward - manager can be set to ``prime`` for parallel verification. - -Customized Reward Function -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: yaml - - custom_reward_function: - path: null - name: compute_score - -- ``custom_reward_function.path``: The path to the file containing your customized reward function. If not specified, pre-implemented reward functions will be used. -- ``custom_reward_function.name`` (Optional) : The name of the reward function within the specified file. Default is 'compute_score'. - -Algorithm -~~~~~~~~~ - -.. code:: yaml - - algorithm: - gamma: 1.0 - lam: 1.0 - adv_estimator: gae - use_kl_in_reward: False - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.005 - horizon: 10000 - target_kl: 0.1 - -- ``gamma``: discount factor -- ``lam``: Trade-off between bias and variance in the GAE estimator -- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo`` -- ``use_kl_in_reward``: Whether to enable in-reward kl penalty. Default is False. -- ``kl_penalty``: Support ``kl``, ``abs``, ``mse``, ``low_var_kl`` and ``full``. How to - calculate the kl divergence between actor and reference policy. For - specific options, refer to `kl_penalty()` in `core_algos.py `_ . -- ``kl_ctrl``: Config for in-reward kl_penalty controller - - ``kl_coef``: The (initial) coefficient of in-reward kl_penalty. Default is 0.001. - - ``type``: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController. - - ``horizon`` and ``target_kl``: See source code of AdaptiveKLController for details. - -Trainer -~~~~~~~ - -.. code:: yaml - - trainer: - total_epochs: 30 - project_name: verl_examples - experiment_name: gsm8k - logger: ['console', 'wandb'] - log_val_generations: 0 - nnodes: 1 - n_gpus_per_node: 8 - save_freq: -1 - val_before_train: True - test_freq: 2 - critic_warmup: 0 - default_hdfs_dir: null # hdfs checkpoint path - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} # local checkpoint path - resume_mode: auto # or disable or resume_path if resume_from_path is set - resume_from_path: null - remove_previous_ckpt_in_save: False - del_local_ckpt_after_load: False - ray_wait_register_center_timeout: 300 - -- ``trainer.total_epochs``: Number of epochs in training. -- ``trainer.project_name``: For wandb, swanlab, mlflow -- ``trainer.experiment_name``: For wandb, swanlab, mlflow -- ``trainer.logger``: Support console and wandb, swanlab, mlflow, tensorboard -- ``trainer.log_val_generations``: The number of logged generation during validation (default ``0``) -- ``trainer.nnodes``: Number of nodes used in the training. -- ``trainer.n_gpus_per_node``: Number of GPUs per node. -- ``trainer.save_freq``: The frequency (by iteration) to save checkpoint - of the actor and critic model. -- ``trainer.val_before_train``: Whether to run validation before training. -- ``trainer.test_freq``: The validation frequency (by iteration). -- ``trainer.critic_warmup``: The number of iteration to train the critic - model before actual policy learning. -- ``trainer.resume_mode``: The mode of resuming training. Support - ``disable``, ``auto`` and ``resume_path``. If set to ``auto`` as default, the - program will automatically resume from the latest checkpoint in the - ``default_local_dir``. If set to ``resume_path``, the program will resume - from the path specified in ``resume_from_path``. -- ``trainer.resume_from_path``: The path to resume training from. Only - effective when ``resume_mode`` is set to ``resume_path``. -- ``trainer.remove_previous_ckpt_in_save``: Whether to remove previous - checkpoints in the save directory. Default is False. -- ``trainer.del_local_ckpt_after_load``: Whether to delete local - checkpoints after loading them. Default is False. -- ``trainer.ray_wait_register_center_timeout``: The timeout for waiting - for the ray register center to be ready. Default is 300 seconds. - - -This figure illustrates how the configurations affect the training. - -https://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA - -.. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d - - -evaluation.yaml ---------------- - -Data -~~~~ - -.. code:: yaml - - data: - path: /tmp/math_Qwen2-7B-Instruct.parquet - prompt_key: prompt - response_key: responses - data_source_key: data_source - reward_model_key: reward_model - -- ``data.path``: Path to the dataset file (Parquet format). -- ``data.prompt_key``: The field in the dataset where the prompt is located. Default is 'prompt'. -- ``data.response_key``: The key holds the generated responses. This should be a list of strings representing the responses. Default is 'responses'. -- ``data.data_source_key``: This is used to separate metric calculations for different data sources, ensuring that metrics are calculated independently for each source. -- ``data.reward_model_key``: The key holds the reference answers. These reference answers typically serve as the ground truth or test cases for the task. - -Customized Reward Function -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: yaml - - custom_reward_function: - path: null - name: compute_score - -- ``custom_reward_function.path``: The path to the file containing your customized reward function. If not specified, pre-implemented reward functions will be used. -- ``custom_reward_function.name`` (Optional) : The name of the reward function within the specified file. Default is 'compute_score'. - -sft_trainer.yaml for SFT FSDP Backend --------------------------------------- - - -Optim -~~~~~~~ - -.. code:: yaml - - optim: - lr: 1e-5 - weight_decay: 0.01 - warmup_steps_ratio: 0.1 - clip_grad: 1.0 - lr_scheduler: cosine - -- ``optim.lr``: Learning rate for the optimizer. -- ``optim.weight_decay``: Weight decay for the optimizer. -- ``optim.warmup_steps_ratio``: Ratio of warmup steps to total training steps. -- ``optim.clip_grad``: Gradient clipping value. -- ``optim.lr_scheduler``: Learning rate scheduler type. Options: - - - ``cosine``: Cosine learning rate scheduler with warmup (default). - - ``wsd``: Warmup-Stable-Decay scheduler that provides a stable learning rate phase between warmup and decay phases. - -Model -~~~~~~~~~~~~ - -Most parameters for Model are similar to Reward Model. - -.. code:: yaml - - model: - partial_pretrain: ~/models/gemma-1.1-7b-it - fsdp_config: - model_dtype: fp32 - wrap_policy: - min_num_params: 0 - cpu_offload: False - offload_params: False - external_lib: null - enable_gradient_checkpointing: False - trust_remote_code: False - lora_rank: 0 - lora_alpha: 16 - target_modules: all-linear - use_liger: False - -- ``partial_pretrain``: HDFS path or local path for the pretrained model. -- ``fsdp_config`` - - - ``model_dtype``: Model parameters type, default to ``fp32``. - Support: ``bf16``, ``fp16``, ``fp32``. - - ``cpu_offload``: Whether to enable CPU offloading for FSDP. If True, - the offload_params will be used as argument. - - ``offload_params``: Whether to offload parameters to CPU - when not involved in computation. If True, then this offloads gradients - to CPU as well, meaning that the optimizer step runs on CPU. - -- ``lora_rank``: The rank of the LoRA model, default to 0. If ``lora_rank``>0, - we will train LoRA modules instead of tuning the full model. -- ``lora_alpha``: The alpha parameter for LoRA scaling, default to 16. -- ``target_modules``: The names of the modules to apply the adapter to, - default to ``all-linear``. See `peft docs `_ for detail. - -- ``use_liger``: Whether to enable Liger kernel, default to False. If True, - we apply Liger kernel to the model (depends on `liger-kernel`). diff --git a/docs/examples/gsm8k_example.rst b/docs/examples/gsm8k_example.rst deleted file mode 100644 index 02d1a526c..000000000 --- a/docs/examples/gsm8k_example.rst +++ /dev/null @@ -1,190 +0,0 @@ -GSM8K Example -============= - -Last updated: 03/25/2025. - -Introduction ------------- - -In this example, we train an LLM to tackle the GSM8k task. - -Paper: https://arxiv.org/pdf/2110.14168 - -Dataset: https://huggingface.co/datasets/gsm8k - -Note that the original paper mainly focuses on training a verifier (a -reward model) to solve math problems via Best-of-N sampling. In this -example, we train an RLHF agent using a rule-based reward model. - -Dataset Introduction --------------------- - -GSM8k is a math problem dataset. The prompt is an elementary school -problem. The LLM model is required to answer the math problem. - -The training set contains 7473 samples and the test set contains 1319 -samples. - -**An example** - -Prompt - - Katy makes coffee using teaspoons of sugar and cups of water in the - ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups - of water, calculate the number of teaspoonfuls of sugar she used. - -Solution - - The total ratio representing the ingredients she used to make the - coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the - number of teaspoons she used is 7/20, she used 7/20\ *120 = - <<7/20*\ 120=42>>42 #### 42 - -Step 1: Prepare dataset ------------------------ - -.. code:: bash - - cd examples/data_preprocess - python3 gsm8k.py --local_dir ~/data/gsm8k - -Step 2: Download Model ----------------------- - -There're three ways to prepare the model checkpoints for post-training: - -- Download the required models from huggingface or modelscope - -.. code:: bash - - huggingface-cli download deepseek-ai/deepseek-math-7b-instruct --local-dir ~/models/deepseek-math-7b-instruct --local-dir-use-symlinks False - # or - modelscope download --model deepseek-ai/deepseek-math-7b-instruct --local_dir ~/models/deepseek-math-7b-instruct - -- Already store your store model in the local directory or HDFS path. -- Also, you can directly use the model name in huggingface (e.g., - deepseek-ai/deepseek-math-7b-instruct) in - ``actor_rollout_ref.model.path`` and ``critic.model.path`` field in - the run script. You can also download models from modelscope by setting environmental variable ``VERL_USE_MODELSCOPE=True``. - See examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh for example. - -Noted that users should prepare checkpoints for actor, critic and reward -model. - -[Optional] Step 3: SFT your Model ---------------------------------- - -We provide a SFT Trainer using PyTorch FSDP in -`fsdp_sft_trainer.py `_. -Users can customize their own SFT -script using our FSDP SFT Trainer. - -We also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft directory `_. - -.. code:: shell - - set -x - - torchrun -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=question \ - data.response_key=answer \ - data.micro_batch_size_per_gpu=8 \ - model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ - trainer.total_epochs=4 \ - trainer.logger='["console","wandb"]' - - -If you use AMD GPUs (ROCm kernel), you need to add the following environment variables into the run script: - - .. code-block:: bash - - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES - export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES - - -Step 4: Perform PPO training with your model on GSM8K Dataset -------------------------------------------------------------- - -- Prepare your own run.sh script. Here's an example for GSM8k dataset - and deepseek-llm-7b-chat model. -- Users could replace the ``data.train_files`` ,\ ``data.val_files``, - ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on - their environment. -- See :doc:`config` for detailed explanation of each config field. - -**Reward Model/Function** - -We use a rule-based reward model. We force the model to produce a final -answer following 4 “#” as shown in the solution. We extract the final -answer from both the solution and model's output using regular -expression matching. We compare them and assign a reward of 1 to correct -answer, 0.1 to incorrect answer and 0 to no answer. - -**Training Script** - -The training script example for FSDP and Megatron-LM backend are stored in examples/ppo_trainer directory. - -.. code:: bash - - cd ../ppo_trainer - bash run_deepseek7b_llm.sh - -The script of run_deepseek7b_llm.sh - -.. code:: bash - - set -x - - python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size_per_gpu=32 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 $@ - - -If you use AMD GPUs (ROCm kernel), you need to add the following environment variables into the run script: - - .. code-block:: bash - - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES - export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES - -If you encounter any issues in using AMD GPUs running VeRL, feel free to contact me - `Yusheng Su `_. \ No newline at end of file diff --git a/docs/examples/multi_modal_example.rst b/docs/examples/multi_modal_example.rst deleted file mode 100644 index 844005b66..000000000 --- a/docs/examples/multi_modal_example.rst +++ /dev/null @@ -1,45 +0,0 @@ -Multi-Modal Example Architecture -================================= - -Last updated: 04/28/2025. - -Introduction ------------- - -Now, verl has supported multi-modal training. You can use fsdp and -vllm/sglang to start a multi-modal RL task. Megatron supports is also -on the way. - -Follow the steps below to quickly start a multi-modal RL task. - -Step 1: Prepare dataset ------------------------ - -.. code:: python - - # it will be saved in the $HOME/data/geo3k folder - python examples/data_preprocess/geo3k.py - -Step 2: Download Model ----------------------- - -.. code:: bash - - # download the model from huggingface - python3 -c "import transformers; transformers.pipeline(model='Qwen/Qwen2.5-VL-7B-Instruct')" - -Step 3: Perform GRPO training with multi-modal model on Geo3K Dataset ---------------------------------------------------------------------- - -.. code:: bash - - # run the task - bash examples/grpo_trainer/run_qwen2_5_vl-7b.sh - - - - - - - - diff --git a/docs/examples/ppo_code_architecture.rst b/docs/examples/ppo_code_architecture.rst deleted file mode 100644 index 94d62413a..000000000 --- a/docs/examples/ppo_code_architecture.rst +++ /dev/null @@ -1,209 +0,0 @@ -PPO Example Architecture -======================== - -Last updated: 02/17/2025. - -Let's start with the Proximal Policy Optimization algorithm, which is -most widely used algorithm in LLM post-training. - -The main entry point of the PPO algorithm example is: -`main_ppo.py `_. -In this tutorial, we will go through the code architecture in `main_ppo.py `_. - -Define the data ---------------- - -Users need to preprocess and store the dataset in parquet files. -And we implement `RLHFDataset` to load and tokenize the parquet files. - -For ``RLHFDataset`` (Default), at least 1 fields are required: - -- ``prompt``: Contains the string prompt - -We already provide some examples of processing the datasets to parquet -files in `data_preprocess directory `_. Currently, we support -preprocess of GSM8k, MATH, Hellasage, Full_hh_rlhf datasets. See :doc:`../preparation/prepare_data` for -more information. - -Define the reward functions for different datasets --------------------------------------------------- - -In this main entry point, the users only need to define their own reward -function based on the datasets (or applications) utilized in PPO -training. - -For example, we already provide reward functions for `GSM8k `_ -and `MATH `_ -datasets in the ``_select_rm_score_fn``. In the ``RewardManager``, we -will compute the reward score based on the data_source to select -corresponding reward functions. For some RLHF datasets (e.g., -full_hh_rlhf), the reward model is utilized to assess the responses -without any reward functions. In this case, the ``RewardManager`` will -return the ``rm_score`` computed by the reward model directly. - -See `reward functions `_ for detailed implementation. - -Define worker classes ---------------------- - -.. code:: python - - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: # for FSDP backend - assert config.critic.strategy in {"fsdp", "fsdp2"} - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker - from verl.single_controller.ray import RayWorkerGroup - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM - - else: - raise NotImplementedError - - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - role_worker_mapping = { - Role.ActorRollout: ActorRolloutRefWorker, - Role.Critic: CriticWorker, - Role.RefPolicy: ActorRolloutRefWorker - } - - global_pool_id = 'global_pool' - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - Role.Critic: global_pool_id, - Role.RefPolicy: global_pool_id, - } - -Step 1: Construct the mapping between roles and workers -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -A role represents a group of workers in the same process. We have -pre-defined several roles in `ray_trainer.py `_. - -.. code:: python - - class Role(Enum): - """ - To create more roles dynamically, you can subclass Role and add new members - """ - Actor = 0 # This worker only has Actor - Rollout = 1 # This worker only has Rollout - ActorRollout = 2 # This worker has both actor and rollout, it's a HybridEngine - Critic = 3 # This worker only has critic - RefPolicy = 4 # This worker only has reference policy - RewardModel = 5 # This worker only has reward model - ActorRolloutRef = 6 # This worker contains actor, rollout and reference policy simultaneously - -Step 2: Define the worker class corresponding to this role -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -- We have pre-implemented the ``ActorRolloutRefWorker``. Through - different configs, it can be a standalone actor, a standalone rollout, - an ActorRollout HybridEngine, or an ActorRolloutRef HybridEngine -- We also pre-implemented workers for ``Actor``, ``Rollout``, - ``Critic``, ``Reward Model`` and ``Reference model`` on two different - backend: PyTorch FSDP - and Megatron-LM. - See `FSDP Workers `_ - and `Megatron-LM Workers `_ - for more information. - -Step 3: Define resource pool id and resource pool spec -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -- Resource pool is a division of global GPU resources, - ``resource_pool_spec`` is a dict, mapping from id to # of GPUs - - - In the above example, we defined a global resource pool: - global_pool_id, and then put all roles on this one resource pool - with all the GPUs in this post-training task. This refers to - *co-locate* placement where all the models share the same set of - GPUs. - -- See resource pool and placement for advance usage. - -Defining reward model/function ------------------------------- - -.. code:: python - - # we should adopt a multi-source reward function here - # - for rule-based rm, we directly call a reward score - # - for model-based rm, we call a model - # - for code related prompt, we send to a sandbox if there are test cases - # - finally, we combine all the rewards together - # - The reward type depends on the tag of the data - if config.reward_model.enable: - from verl.workers.fsdp_workers import RewardModelWorker - role_worker_mapping[Role.RewardModel] = RewardModelWorker - mapping[Role.RewardModel] = global_pool_id - - reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0) - - # Note that we always use function-based RM for validation - val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1) - - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - -Since not all tasks use model-based RM, users need to define here -whether it's a model-based RM or a function-based RM - -- If it's a model-based RM, directly add the ``RewardModel`` role in the - resource mapping and add it to the resource pool mapping. - - - Note that the pre-defined ``RewardModelWorker`` only supports models - with the structure of huggingface - ``AutoModelForSequenceClassification``. If it's not this model, you - need to define your own RewardModelWorker in `FSDP Workers `_ - and `Megatron-LM Workers `_. - -- If it's a function-based RM, the users are required to classified the - reward function for each datasets. - -.. code:: python - - def _select_rm_score_fn(data_source): - if data_source == 'openai/gsm8k': - return gsm8k.compute_score - elif data_source == 'lighteval/MATH': - return math.compute_score - else: - raise NotImplementedError - -See reward functions implemented in `directory `_ -for more information. - -Define, init and run the PPO Trainer ------------------------------------- - -.. code:: python - - trainer = RayPPOTrainer(config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn) - trainer.init_workers() - trainer.fit() - -- We first initialize the ``RayPPOTrainer`` with user config, tokenizer - and all the above worker mapping, resource pool, worker group and - reward functions -- We first call the ``trainer.init_workers()`` to initialize the models - on the allocated GPUs (in the resource pool) -- The actual PPO training will be executed in ``trainer.fit()`` - -verl can be easily extended to other RL algorithms by reusing the Ray -model workers, resource pool and reward functions. See :doc:`extension<../advance/dpo_extension>` for -more information. - -Details of the ``RayPPOTrainer`` is discussed in :doc:`Ray Trainer<../workers/ray_trainer>`. diff --git a/docs/examples/sandbox_fusion_example.rst b/docs/examples/sandbox_fusion_example.rst deleted file mode 100644 index f3359efda..000000000 --- a/docs/examples/sandbox_fusion_example.rst +++ /dev/null @@ -1,54 +0,0 @@ -Sandbox Fusion Example -============================ - -Last updated: 06/27/2025. - -Introduction ------------- - -Sandbox Fusion is a remote code sandbox service that provides a secure environment for running and evaluating code generated by Large Language Models (LLMs). This example demonstrates how to train an LLM and use Sandbox Fusion to verify generated code, enhancing both security and performance. - -By leveraging a remote code sandbox service with greater CPU resources for concurrent code verification, you can reduce the reward stage time by 10-30%, depending on the quality of the generated code. - -Step 1: Prepare the Dataset ---------------------------- - -We use the Eurus-2-RL-Data dataset for training. This dataset combines math and code questions, making it suitable for LLM training tasks. You can download it from HuggingFace: `Eurus-2-RL-Data Dataset `_. - -Step 2: Set Up the Sandbox Fusion Service ------------------------------------------ - -Sandbox Fusion is a remote code sandbox service designed to securely run and evaluate LLM-generated code. To use it: - -1. **Access Full Documentation**: For detailed setup instructions, refer to the `Sandbox Fusion Documentation `_. -2. **Deploy the Service**: Choose one of the following deployment methods: - - - **Local Deployment**: Follow the guide `here `_. - - **FaaS Instance (Volcengine)**: Create an instance using the `Volcengine Documentation `_. - -After deployment, you will receive an API endpoint in the format: ``https:///run_code``. - -Step 3: Configure the Training Script -------------------------------------- - -To integrate Sandbox Fusion into your training script, configure the following parameters: - -**Key Settings for Sandbox Fusion** - -- ``reward_model.sandbox_fusion.url=''``: Enable Sandbox Fusion by specifying the API endpoint (must end with ``/run_code``). -- ``reward_model.sandbox_fusion.max_concurrent=256``: Set the maximum number of concurrent API requests to the Sandbox Fusion service. -- ``reward_model.sandbox_fusion.memory_limit_mb=1024``: Set the memory limit (in MB) for each sandbox instance. Defaults to 1024MB if not specified. - -**Additional Optimization** - -To further reduce code verification time, enable parallel processing with: - -- ``reward_model.reward_manager=prime``: The Prime reward manager verifies code across multiple subprocesses concurrently. - -**Example Script** - -For a practical implementation, refer to the example script: - -``examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh`` - -Once you’ve set your API endpoint in the script, you can start the training job. \ No newline at end of file diff --git a/docs/faq/faq.rst b/docs/faq/faq.rst deleted file mode 100644 index 328ad6eb7..000000000 --- a/docs/faq/faq.rst +++ /dev/null @@ -1,178 +0,0 @@ -Frequently Asked Questions -==================================== - -Last updated: 06/25/2025. - -Ray related ------------- - -How to add breakpoint for debugging with distributed Ray? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Please checkout the official debugging guide from Ray: https://docs.ray.io/en/latest/ray-observability/ray-distributed-debugger.html - - -"Unable to register worker with raylet" -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The cause of this issue is due to some system setting, e.g., SLURM added some constraints on how the CPUs are shared on a node. -While `ray.init()` tries to launch as many worker processes as the number of CPU cores of the machine, -some constraints of SLURM restricts the `core-workers` seeing the `raylet` process, leading to the problem. - -To fix this issue, you can set the config term ``ray_init.num_cpus`` to a number allowed by your system. - -Distributed training ------------------------- - -How to run multi-node post-training with Ray? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You can start a ray cluster and submit a ray job, following the official guide from Ray: https://docs.ray.io/en/latest/ray-core/starting-ray.html - -Then in the configuration, set the ``trainer.nnode`` config to the number of machines for your job. - -How to use verl on a Slurm-managed cluster? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Ray provides users with `this `_ official -tutorial to start a Ray cluster on top of Slurm. We have verified the :doc:`GSM8K example<../examples/gsm8k_example>` -on a Slurm cluster under a multi-node setting with the following steps. - -1. [Optional] If your cluster support `Apptainer or Singularity `_ and you wish -to use it, convert verl's Docker image to an Apptainer image. Alternatively, set up the environment with the package -manager available on your cluster or use other container runtimes (e.g. through `Slurm's OCI support `_) available to you. - -.. code:: bash - - apptainer pull /your/dest/dir/vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3.sif docker://verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 - -2. Follow :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints. - -3. Modify `examples/slurm/ray_on_slurm.slurm `_ with your cluster's own information. - -4. Submit the job script to the Slurm cluster with `sbatch`. - -Please note that Slurm cluster setup may vary. If you encounter any issues, please refer to Ray's -`Slurm user guide `_ for common caveats. - -If you changed Slurm resource specifications, please make sure to update the environment variables in the job script if necessary. - - -Install related ------------------------- - -NotImplementedError: TensorDict does not support membership checks with the `in` keyword. -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Detail error information: - -.. code:: bash - - NotImplementedError: TensorDict does not support membership checks with the `in` keyword. If you want to check if a particular key is in your TensorDict, please use `key in tensordict.keys()` instead. - -Cause of the problem: There is no suitable version of tensordict package for the linux-arm64 platform. The confirmation method is as follows: - -.. code:: bash - - pip install tensordict==0.6.2 - -Output example: - -.. code:: bash - - ERROR: Could not find a version that satisfies the requirement tensordict==0.6.2 (from versions: 0.0.1a0, 0.0.1b0, 0.0.1rc0, 0.0.2a0, 0.0.2b0, 0.0.3, 0.1.0, 0.1.1, 0.1.2, 0.8.0, 0.8.1, 0.8.2, 0.8.3) - ERROR: No matching distribution found for tensordict==0.6.2 - -Solution 1st: - Install tensordict from source code: - -.. code:: bash - - pip uninstall tensordict - git clone https://github.com/pytorch/tensordict.git - cd tensordict/ - git checkout v0.6.2 - python setup.py develop - pip install -v -e . - -Solution 2nd: - Temperally modify the error takeplace codes: tensordict_var -> tensordict_var.keys() - - -Illegal memory access ---------------------------------- - -If you encounter the error message like ``CUDA error: an illegal memory access was encountered`` during rollout, please check the vLLM documentation for troubleshooting steps specific to your vLLM version. - -Checkpoints ------------------------- - -If you want to convert the model checkpoint into huggingface safetensor format, please refer to ``verl/model_merger``. - - -Triton ``compile_module_from_src`` error ------------------------------------------------- - -If you encounter triton compilation error similar to the stacktrace below, please set the ``use_torch_compile`` flag according to -https://verl.readthedocs.io/en/latest/examples/config.html to disable just-in-time compilation for fused kernels. - -.. code:: bash - - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in - return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 338, in run - return self.fn.run(*args, **kwargs) - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/jit.py", line 607, in run - device = driver.active.get_current_device() - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py", line 23, in __getattr__ - self._initialize_obj() - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py", line 20, in _initialize_obj - self._obj = self._init_fn() - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/driver.py", line 9, in _create_driver - return actives[0]() - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 371, in __init__ - self.utils = CudaUtils() # TODO: make static - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 80, in __init__ - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils") - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 57, in compile_module_from_src - so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries) - File "/data/lbh/conda_envs/verl/lib/python3.10/site-packages/triton/runtime/build.py", line 48, in _build - ret = subprocess.check_call(cc_cmd) - File "/data/lbh/conda_envs/verl/lib/python3.10/subprocess.py", line 369, in check_call - raise CalledProcessError(retcode, cmd) - -What is the meaning of train batch size, mini batch size, and micro batch size? ------------------------------------------------------------------------------------------- - -This figure illustrates the relationship between different batch size configurations. - -https://excalidraw.com/#json=pfhkRmiLm1jnnRli9VFhb,Ut4E8peALlgAUpr7E5pPCA - -.. image:: https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d - -How to generate ray timeline to analyse performance of a training job? ------------------------------------------------------------------------------------------- - -To generate the ray timeline file, you can set the config term ``ray_init.timeline_file`` to a json file path. -For example: - -.. code:: bash - - ray_init.timeline_file=/tmp/ray_timeline.json - -The file will be generated in the specified path at the end of a training job. -You can use tools like chrome://tracing or the Perfetto UI and view the ray timeline file. - -This figure shows the ray timeline file generated by from a training job on 1 node with 4 GPUs - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray_timeline.png?raw=true - -How to set proxy only for wandb? ------------------------------------------------------------------------------------------- - -If you need a proxy to access wandb, you can add below config in your training job script. -Comparing to using global https_proxy env variable, this approach won't mess up other http requests, such as ChatCompletionScheduler. - -.. code:: bash - - +trainer.wandb_proxy=http:// \ No newline at end of file diff --git a/docs/hybrid_flow.rst b/docs/hybrid_flow.rst deleted file mode 100644 index 3aa5a4a97..000000000 --- a/docs/hybrid_flow.rst +++ /dev/null @@ -1,266 +0,0 @@ -========================================================= -HybridFlow Programming Guide -========================================================= - -Last updated: 06/02/2025. - -.. _vermouth: https://github.com/vermouth1992 - -Author: `Chi Zhang `_ - -verl is an open source implementation of the paper `HybridFlow `_ [1]_. In this section, we will introduce the basic concepts of HybridFlow, the motivation and how to program with verl APIs. - -Motivation and Design ------------------------- -We use dataflow to represent RL systems. [4]_. - -DataFlow -~~~~~~~~~~~~~~~~~~~~ - -Dataflow is an abstraction of computations. Neural Network training is a typical dataflow. It can be represented by computational graph. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/dataflow.jpeg?raw=true - :alt: The dataflow graph from CS231n 2024 lecture 4 - -This figure [2]_ represents the computation graph of a polynomial function followed by a sigmoid function. In the data flow of neural network computation, each node represents an operator, and each edge represents the direction of forward/backward propagation. The computation graph determines the architecture of the neural network. - -RL as a dataflow problem -++++++++++++++++++++++++++++++++++++++++++++++ - -Reinforcement learning (RL) training can also be represented as a dataflow. Below is the dataflow graph that represents the PPO algorithm used in RLHF [3]_: - -.. image:: https://picx.zhimg.com/70/v2-cb8ab5ee946a105aab6a563e92682ffa_1440w.avis?source=172ae18b&biz_tag=Post - :alt: PPO dataflow graph, credit to Zhihu 低级炼丹师 - -However, the dataflow of RL has fundamental differences compared with dataflow of neural network training as follows: - -+--------------------------+--------------------------------------------------+---------------------+ -| Workload | Node | Edge | -+--------------------------+--------------------------------------------------+---------------------+ -| Neural Network Training | Operator (+/-/matmul/softmax) | Tensor movement | -+--------------------------+--------------------------------------------------+---------------------+ -| Reinforcement Learning | High-level operators (rollout/model forward) | Data Movement | -+--------------------------+--------------------------------------------------+---------------------+ - -In the case of tabular reinforcement learning, each operator is a simple scalar math operation (e.g., bellman update). In deep reinforcement learning(DRL), each operator is a high-level neural network computation such as model inference/update. This makes RL a two-level dataflow problem: - -- Control flow: defines how the high-level operators are executed (e.g., In PPO, we first perform rollout. Then, we perform advantage computation. Finally, we perform training). It expresses the **core logics of RL algorithms**. -- Computation flow: defines the dataflow of **neural network computation** (e.g., model forward/backward/optimizer). - - -Design Choices -~~~~~~~~~~~~~~~~~~~~ -The model size used in DRL before the LLM era is typically small. Thus, the high-level neural network computation can be done in a single process. This enables embedding the computation flow inside the control flow as a single process. - -However, in the LLM era, the computation flow (e.g., training neural network) becomes a multi-process program. This naturally leads to two design choices: - -1. Convert the control flow into a multi-process program as well. Then colocate with computation flow (unified multi-controller) - -- Advantages: - - - Achieves the **optimal performance** under fixed computation flow and control flow as the communication overhead in both training and data transfer is minimized. - -- Disadvantages: - - - The computation and/or control flow is **hard to reuse** from software perspective as computation code is coupled with specific controller code. For example, the training loop of PPO is generic. Say we have an PPO training flow implemented with a specific computation flow such as FSDP. Neither the control flow or computation flow can be reused if we want to switch the computation flow from FSDP to Megatron, due to the coupling of control and computation flows. - - Requires more efforts from the user under flexible and dynamic control flows, due to the multi-process nature of the program. - -2. Separate the flows: single process for the control flow and multi-process for computation flow - -- Advantages: - - - The computation flow defined elsewhere can be **easily reused** after the decoupling. - - The controller runs on a single process. Implementing a new RL algorithm with a **different control flow is simple and easy**. - -- Disadvantages: - - - Additional **data communication overhead** each time the controller process and computatation processes interact. The data has to be sent back and forth. - -In verl, the latter strategy with separate control flow and computation flow is adopted. verl is designed to decouple the control flow of RL algorithms, and the implementation of computation engines. - -Overall Execution Diagram -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Below is a simplified diagram denoting the execution of a reinforcement learning job. In the diagram, the controller runs on a single process, while the generator/actor workers, critic workers run on multiple processes, placed with specific resource groups. For rollout, the controller passes the data to the generator to perform sample generation. When the rollout is done, the data is passed back to controller for the next step of the algorithm. Similar execution is done for other workers. With the hybrid controller design, the data flow and computation is decoupled to provide both efficiency in computation and flexibility in defining algorithm training loops. - -.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/driver_worker.png?raw=true - :alt: The execution diagram - -Codebase walkthrough (PPO) ------------------------------------------------- - -Entry function -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Code: https://github.com/volcengine/verl/blob/main/verl/trainer/main_ppo.py - -In this file, we define a remote function `main_task` that serves as the controller (driver) process as shown in the above figure. We also define a ``RewardManager``, where users can customize their reward function based on the data source in the dataset. Note that `RewardManager` should return the final token-level reward that is optimized by RL algorithms. Note that users can combine model-based rewards and rule-based rewards. -The ``main_task`` constructs a RayPPOTrainer instance and launch the fit. Note that ``main_task`` **runs as a single process**. - -We highly recommend that the ``main_task`` is NOT scheduled on the head of the ray cluster because ``main_task`` will consume a lot of memory but the head usually contains very few resources. - -Ray trainer -~~~~~~~~~~~~~~~~~~~~ -Code: https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/ray_trainer.py - -The RayPPOTrainer manages - -- Worker and WorkerGroup construction -- Runs the main loop of PPO algorithm - -Note that, the fit function of RayPPOTrainer **runs as a single process**. - -Worker and WorkerGroup construction -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Each workerGroup manages a list of workers that runs remotely. Note that the worker group runs in the process of its constructor. -Each worker inside the WorkerGroup runs on a GPU. The worker group serves as a proxy for the controller process to interact with a list of workers, in order to perform certain computations. **In order to do so, we have to bind the methods of the worker into the method of the WorkerGroup and define the data dispatch and data collection**. This is done via simple decoration that will be introduced in the Worker definition section. - -For example, in PPO, we define 3 worker groups: - -- ActorRolloutRef: manages actor, rollout and reference policy. ActorRolloutRefWorker can be instantiated as a single actor, a single rollout, a single reference policy, a combined actor/rollout or a combined actor/rollout/ref. This design is aimed for the maximum code reuse in various scenarios. The reason for colocating actor and rollout is for fast weight transfer using nccl. The reason for coloating actor and reference is to implement an efficient lora PPO as the reference policy is simply the base model of PPO in lora. The colocation is done via ``verl.single_controller.ray.base.create_colocated_worker_cls``, where it creates a single ray remote class exposing all class methods from these roles. -- Critic: manages the critic model -- Reward: manages the reward model - -The worker group will be constructed on the resource pool it designates. The resource pool is a set of GPUs in the ray cluster. - -Worker definition -~~~~~~~~~~~~~~~~~~~~ - -.. _ActorRolloutRefWorker: https://github.com/volcengine/verl/blob/main/verl/workers/fsdp_workers.py - -We take `ActorRolloutRefWorker `_ for an example. -The APIs it should expose to the controller process are: - -- init_model: build the underlying model -- generate_sequences: given prompts, generate responses -- compute_log_prob: compute the log-probability of a generated sequence using actor -- compute_ref_log_prob: compute the log-probability of a generated sequence using reference policy -- save_checkpoint: save the checkpoint - -Note that these methods are defined in the worker that can only be invoked via remote calls. For example, if the controller process wants to initialize the model, it has to call - -.. code-block:: python - - for worker in actor_rollout_ref_wg: - worker.init_model.remote() - -If the controller process wants to generate sequences, it has to call - -.. code-block:: python - - data = xxx - # split the data into dp chunks - data_dp_lst = data.split(dp_size) - output_dp_lst = [] - for i, worker in enumerate(actor_rollout_ref_wg): - output_future = worker.generate_sequences.remote(data_dp_lst[i]) - output_dp_lst.append(output_future) - output = torch.cat(ray.get(output_dp_lst), dim=0) - -We observe that controller process calling worker group methods in general can be divided into 3 parts: - -- Split the data into data parallel sizes -- Dispatch the corresponding data into each worker -- Collect and concatenate the data when the computation finishes - -In verl, we design a syntax sugar to encapsulate the 3 processes into a single call from the controller process. - -.. code-block:: python - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def generate_sequences(data): - ... - - # on the driver - output = actor_rollout_ref_wg.generate_sequences(data) - -We decorate the method of the worker with a ``register`` that explicitly defines how the input data should be split and dispatched to each worker, and how the output data should be collected and concatenated by the controller. For example, ``Dispatch.DP_COMPUTE_PROTO`` splits the input data into dp chunks, dispatch each data to each worker, collect the output and concatenate the results. Note that this function requires the input and output to be a DataProto defined here (https://github.com/volcengine/verl/blob/main/verl/protocol.py). - - -PPO main loop -~~~~~~~~~~~~~~~~~~~~ -With the aforementioned APIs, we can implement the main loop of PPO as if it is a single process program - -.. code-block:: python - - for prompt in dataloader: - output = actor_rollout_ref_wg.generate_sequences(prompt) - old_log_prob = actor_rollout_ref_wg.compute_log_prob(output) - ref_log_prob = actor_rollout_ref_wg.compute_ref_log_prob(output) - values = critic_wg.compute_values(output) - rewards = reward_wg.compute_scores(output) - # compute_advantages is running directly on the control process - advantages = compute_advantages(values, rewards) - output = output.union(old_log_prob) - output = output.union(ref_log_prob) - output = output.union(values) - output = output.union(rewards) - output = output.union(advantages) - # update actor - actor_rollout_ref_wg.update_actor(output) - critic.update_critic(output) - -Takeaways -~~~~~~~~~~~~~~~~~~~~ -- This programming paradigm enables users to use different computation backend without modification of the control process. -- This programming paradigm enables flexible placement (by changing the mapping of WorkerGroup and ResourcePool) without modification of the control process. - -Repository organization ------------------------------------------------- - -Important code files in the repository are organized as below: - -.. code-block:: bash - - verl # the verl package - trainer - main_ppo.py # the entrypoint for RL training - ppo - ray_trainer.py # the training loop for RL algorithms such as PPO - fsdp_sft_trainer.py # the SFT trainer with FSDP backend - config - generation.yaml # configuration template for rollout - ppo_trainer.yaml # configuration template for the RL trainer - workers - protocol.py # the interface of DataProto - fsdp_workers.py # the FSDP worker interfaces: ActorRolloutRefWorker, CriticWorker, RewardModelWorker - megatron_workers.py # the Megatron worker interfaces: ActorRolloutRefWorker, CriticWorker, RewardModelWorker - actor - dp_actor.py # data parallel actor with FSDP backend - megatron_actor.py # nD parallel actor with Megatron backend - critic - dp_critic.py # data parallel critic with FSDP backend - megatron_critic.py # nD parallel critic with FSDP backend - reward_model - megatron - reward_model.py # reward model with Megatron backend - rollout - vllm - vllm_rollout.py # rollout with vllm backend - hf_rollout.py # rollout with huggingface TGI backend - sharding_manager - fsdp_ulysses.py # data and model resharding when using FSDP + ulysses - fsdp_vllm.py # data and model resharding when using FSDP + ulysses + vllm - megatron_vllm.py # data and model resharding when using Megatron + vllm - utils - dataset # datasets for SFT/RM/RL - reward_score # function based reward - gsm8k.py # reward function for gsm8k dataset - math.py # reward function for math dataset - seqlen_balancing.py # the sequence balance optimization - models - llama # Megatron implementation for llama, deepseek, mistral, etc - transformers # ulysses integration with transformer models such as llama, qwen, etc - weight_loader_registery.py # registry of weight loaders for loading hf ckpt into Megatron - third_party - vllm # adaptor for vllm's usage in RL - vllm_spmd # vllm >= v0.7 adaptor - examples # example scripts - tests # integration and unit tests - .github # the configuration of continuous integration tests - - -.. [1] HybridFlow: A Flexible and Efficient RLHF Framework: https://arxiv.org/abs/2409.19256v2 -.. [2] Data flow graph credit to CS231n 2024 lecture 4: https://cs231n.stanford.edu/slides/2024/lecture_4.pdf -.. [3] PPO dataflow graph credit to 低级炼丹师 from Zhihu​: https://zhuanlan.zhihu.com/p/635757674 -.. [4] RLFlow diff --git a/docs/index.rst b/docs/index.rst deleted file mode 100644 index 980066a7f..000000000 --- a/docs/index.rst +++ /dev/null @@ -1,186 +0,0 @@ -Welcome to verl's documentation! -================================================ - -verl is a flexible, efficient and production-ready RL training framework designed for large language models (LLMs) post-training. It is an open source implementation of the `HybridFlow `_ paper. - -verl is flexible and easy to use with: - -- **Easy extension of diverse RL algorithms**: The hybrid programming model combines the strengths of single-controller and multi-controller paradigms to enable flexible representation and efficient execution of complex Post-Training dataflows. Allowing users to build RL dataflows in a few lines of code. - -- **Seamless integration of existing LLM infra with modular APIs**: Decouples computation and data dependencies, enabling seamless integration with existing LLM frameworks, such as PyTorch FSDP, Megatron-LM, vLLM and SGLang. Moreover, users can easily extend to other LLM training and inference frameworks. - -- **Flexible device mapping and parallelism**: Supports various placement of models onto different sets of GPUs for efficient resource utilization and scalability across different cluster sizes. - -- Ready integration with popular HuggingFace models - - -verl is fast with: - -- **State-of-the-art throughput**: By seamlessly integrating existing SOTA LLM training and inference frameworks, verl achieves high generation and training throughput. - -- **Efficient actor model resharding with 3D-HybridEngine**: Eliminates memory redundancy and significantly reduces communication overhead during transitions between training and generation phases. - --------------------------------------------- - -.. _Contents: - -.. toctree:: - :maxdepth: 2 - :caption: Quickstart - - start/install - start/quickstart - start/multinode - start/ray_debug_tutorial - start/more_resources - start/agentic_rl - -.. toctree:: - :maxdepth: 2 - :caption: Programming guide - - hybrid_flow - single_controller - -.. toctree:: - :maxdepth: 1 - :caption: Data Preparation - - preparation/prepare_data - preparation/reward_function - -.. toctree:: - :maxdepth: 2 - :caption: Configurations - - examples/config - -.. toctree:: - :maxdepth: 1 - :caption: PPO Example - - examples/ppo_code_architecture - examples/gsm8k_example - examples/multi_modal_example - -.. toctree:: - :maxdepth: 1 - :caption: Algorithms - - algo/ppo.md - algo/grpo.md - algo/dapo.md - algo/spin.md - algo/sppo.md - algo/entropy.md - algo/opo.md - algo/baseline.md - algo/gpg.md - -.. toctree:: - :maxdepth: 1 - :caption: PPO Trainer and Workers - - workers/ray_trainer - workers/fsdp_workers - workers/megatron_workers - workers/sglang_worker - -.. toctree:: - :maxdepth: 1 - :caption: Performance Tuning Guide - - perf/dpsk.md - perf/perf_tuning - README_vllm0.8.md - perf/device_tuning - perf/nsight_profiling.md - -.. toctree:: - :maxdepth: 1 - :caption: Adding new models - - advance/fsdp_extension - advance/megatron_extension - -.. toctree:: - :maxdepth: 1 - :caption: Advanced Features - - advance/checkpoint - advance/rope - advance/ppo_lora.rst - sglang_multiturn/multiturn.rst - sglang_multiturn/interaction_system.rst - advance/placement - advance/dpo_extension - examples/sandbox_fusion_example - advance/rollout_trace.rst - -.. toctree:: - :maxdepth: 1 - :caption: Hardware Support - - amd_tutorial/amd_build_dockerfile_page.rst - amd_tutorial/amd_vllm_page.rst - ascend_tutorial/ascend_quick_start.rst - ascend_tutorial/ascend_profiling.rst - ascend_tutorial/ascend_profiling_en.rst - -.. toctree:: - :maxdepth: 1 - :caption: API References - - api/data - api/single_controller.rst - api/trainer.rst - api/utils.rst - - -.. toctree:: - :maxdepth: 2 - :caption: FAQ - - faq/faq - -.. toctree:: - :maxdepth: 1 - :caption: Development Notes - - sglang_multiturn/sandbox_fusion.rst - -Contribution -------------- - -verl is free software; you can redistribute it and/or modify it under the terms -of the Apache License 2.0. We welcome contributions. -Join us on `GitHub `_, `Slack `_ and `Wechat `_ for discussions. - -Contributions from the community are welcome! Please check out our `project roadmap `_ and `good first issues `_ to see where you can contribute. - -Code Linting and Formatting -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -We use pre-commit to help improve code quality. To initialize pre-commit, run: - -.. code-block:: bash - - pip install pre-commit - pre-commit install - -To resolve CI errors locally, you can also manually run pre-commit by: - -.. code-block:: bash - - pre-commit run - -Adding CI tests -^^^^^^^^^^^^^^^^^^^^^^^^ - -If possible, please add CI test(s) for your new feature: - -1. Find the most relevant workflow yml file, which usually corresponds to a ``hydra`` default config (e.g. ``ppo_trainer``, ``ppo_megatron_trainer``, ``sft_trainer``, etc). -2. Add related path patterns to the ``paths`` section if not already included. -3. Minimize the workload of the test script(s) (see existing scripts for examples). - -We are HIRING! Send us an `email `_ if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment. diff --git a/docs/perf/device_tuning.rst b/docs/perf/device_tuning.rst deleted file mode 100644 index 567683b3b..000000000 --- a/docs/perf/device_tuning.rst +++ /dev/null @@ -1,281 +0,0 @@ -Hardware Resource Needed for RL -=============================== - -Last updated: 06/25/2025. - -Since RL requires more resources compared to regular training, -determining how much resources are needed to successfully run it before training -is a relatively difficult task. To provide more people with reference points for -resource selection when dealing with different models and tasks, this section is -mainly dedicated to introducing the environmental requirements based on experiments -we have conducted. - -However, due to limited staff and equipment resources, we also hope for more -contributions from the open-source community. When submitting a PR, it is necessary -to provide a script to be added to the example/tuning scripts. - -We need two types of scripts: one is the configuration that can run with the **minimum -resources(min)**, and the other is the configuration that runs with **recommended resources(recommended)**. For the former, -it can be understood as a script that can run after applying all memory optimization techniques -(e.g., offload, gradient checkpointing). For the latter, it can be understood as a script that -can run while avoiding operations that incur additional time overhead as much as possible (targetting best throughput). - -When defining script names, please follow this format: -``[model]_[task]_[gpunums]_[device]_[train]_[infer].sh``. This will effectively improve -the script's recognizability. You can place the script under the ``examples/tuning/`` directory. - -If you happen to have a configuration that has already been tested, we welcome you to submit -a PR and include a screenshot from Wandb or other verifiable evidence. - ----------------------------------------- - -0.5B -~~~ - -.. list-table:: - :widths: auto - :header-rows: 1 - - * - Tag - - Model - - Task - - Resource - - MaxBatch - - Train - - Infer - - Link - - Contributor - * - MIN - - Qwen2.5-0.5B - - GRPO-LoRA - - 1*H100 - - 116 - - fsdp - - vllm0.8.3 - - `qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh `_ - - `SimonHuang `_ - -1.5B -~~~ - -.. list-table:: - :widths: auto - :header-rows: 1 - - * - Tag - - Model - - Task - - Resource - - MaxBatch - - Train - - Infer - - Link - - Contributor - * - MIN - - Qwen2.5-1.5B - - GRPO-LoRA - - 1*H100 - - 128 - - fsdp - - vllm0.8.3 - - `qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh `_ - - `SimonHuang `_ - -3B -~~~ - -.. list-table:: - :widths: auto - :header-rows: 1 - - * - Tag - - Model - - Task - - Resource - - MaxBatch - - Train - - Infer - - Link - - Contributor - * - MIN - - Qwen2.5-3B - - GRPO-LoRA - - 1*H100 - - 62 - - fsdp - - vllm0.8.3 - - `qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh `_ - - `SimonHuang `_ - -7B -~~~ - -.. list-table:: - :widths: auto - :header-rows: 1 - - * - Tag - - Model - - Task - - Resource - - MaxBatch - - Train - - Infer - - Link - - Contributor - * - MIN - - Qwen2-7B - - GRPO - - 2*H800 - - \ - - fsdp - - vllm0.8.2 - - `qwen2-7b_grpo_2_h800_fsdp_vllm `_ - - `Xiangyongan `_ - * - MIN - - Qwen2.5-7B - - GRPO-LoRA - - 1*H100 - - 16 - - fsdp - - vllm0.8.3 - - `qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh `_ - - `SimonHuang `_ - -14B -~~~ - -.. list-table:: - :widths: auto - :header-rows: 1 - - * - Tag - - Model - - Task - - Resource - - MaxBatch - - Train - - Infer - - Link - - Contributor - * - MIN - - Qwen2-14B - - GRPO - - 4*H800 - - \ - - fsdp - - vllm0.8.2 - - `qwen2-14b_grpo_4_h800_fsdp_vllm `_ - - `Xiangyongan `_ - * - MIN - - Qwen2.5-14B - - GRPO-LoRA - - 2*H100 - - 116 - - fsdp - - vllm0.8.3 - - `qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh `_ - - `SimonHuang `_ - -32B -~~~ - -.. list-table:: - :widths: auto - :header-rows: 1 - - * - Tag - - Model - - Task - - Resource - - MaxBatch - - Train - - Infer - - Link - - Contributor - * - MIN - - Qwen2-32B - - GRPO - - 8*H20 - - \ - - megatron - - vllm0.8.2 - - `qwen2-32b_grpo_8_h20_megatron_vllm `_ - - `Xiangyongan `_ - * - MIN - - Qwen2.5-32B - - GRPO-LoRA - - 4*H100 - - 180 - - fsdp - - vllm0.8.3 - - `qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh `_ - - `SimonHuang `_ - -70B -~~~ - -.. list-table:: - :widths: auto - :header-rows: 1 - - * - Tag - - Model - - Task - - Resource - - MaxBatch - - Train - - Infer - - Link - - Contributor - * - MIN - - Qwen2-70B - - GRPO - - 32*H20 - - \ - - fsdp - - vllm0.8.2 - - `qwen2-70b_grpo_32_h20_fsdp_vllm `_ - - `Xiangyongan `_ - * - MIN - - Qwen2-70B - - GRPO - - 32*H800 - - \ - - fsdp - - vllm0.8.3 - - `qwen2-70b_grpo_32_h800_fsdp_vllm `_ - - `Xiangyongan `_ - * - MIN - - Qwen2.5-72B - - GRPO-LoRA - - 8*H100 - - 176 - - fsdp - - vllm0.8.3 - - `qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh `_ - - `SimonHuang `_ - -405B -~~~~ - -.. table:: - :widths: auto - - ====== ====== ====== ======== ======== ====== ====== ====== - tag model task resource MaxBatch train infer link - ====== ====== ====== ======== ======== ====== ====== ====== - \ \ \ \ \ \ \ - ====== ====== ====== ======== ======== ====== ====== ====== - -671B -~~~~ - -.. table:: - :widths: auto - - ====== ====== ====== ======== ======== ====== ====== ====== - tag model task resource MaxBatch train infer link - ====== ====== ====== ======== ======== ====== ====== ====== - \ \ \ \ \ \ \ - ====== ====== ====== ======== ======== ====== ====== ====== diff --git a/docs/perf/dpsk.md b/docs/perf/dpsk.md deleted file mode 100644 index 0a3b42a11..000000000 --- a/docs/perf/dpsk.md +++ /dev/null @@ -1,51 +0,0 @@ -# Training DeepSeek 671b - -Last updated: 06/13/2025. - -verl integrates Megatron to support large MoE models such as `Qwen3-235B-A22B` and `deepseek-ai/DeepSeek-V3`. This is an ongoing community effort. - -In the journey the community added the following features and optimizations that enable verl with larger models: -- per tensor weight resharding between rollout and training -- context parallelism and expert parallelism enabled via megatron -- dynamic batch size (sequence balance) for megatron -- reduced ray-related serialization overhead -- optimizer offloading, recomputation, and efficient kernels -- various debugging metrics and utils - -and the megatron backend now has a wider list of models supported: -- DeepSeek-V3 -- Moonlight -- Qwen3 -- Qwen2.5-VL (to be merged soon) -- Qwen2 -- Mixtral - -## Getting Started - -### DeepSeek 671b - -The recommended image with pre-built megatron dependency is `whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.1-te2.3-deepseekv3`, built with the Dockerfile in [docker/Dockerfile.vllm.sglang.megatron.deepseek](https://github.com/volcengine/verl/blob/main/docker/Dockerfile.vllm.sglang.megatron.deepseek). - -For checkpoint loading, we rely on megatron dist-ckpt for resharding. A converted dist-ckpt for DeepSeek-V3 is available from [huggingface BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt](https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main). - -To run end-to-end training on the DAPO dataset, run [recipe/dapo/test_dapo_dspk_671b_megatron.sh](https://github.com/volcengine/verl/blob/main/recipe/dapo/test_dapo_dspk_671b_megatron.sh). It runs on 512 H20(96GB) GPUs with the following setup: -- vllm rollout with TP=32, bfloat16 -- megatron training with attention DP, MoE EP=32, PP=16, bfloat16 - -MTP is disabled during RL training. - -### Qwen3 236b - -For Qwen3-236b, please refer to [examples/grpo_trainer/run_qwen3-236b_megatron.sh](https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen3-236b_megatron.sh), which runs on 128 H20(96GB) GPUs. - -## Upcoming Optimizations - -The community continue to optimize large MoE models further, ongoing efforts include: -- further optimizing memory consumption, and provide recommended/tuned configurations with various machine types -- optimizing long context RL training performance -- performance improvement with SGLang x Megatron - -We invite the community to try and improve verl together. Get connected with us on [slack](https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA)/[wechat](https://raw.githubusercontent.com/eric-haibin-lin/verl-community/refs/heads/main/WeChat.JPG)/[Github issues](https://github.com/volcengine/verl/issues/708)! - -## Acknowledgement -@vermouth1992 @ISEEKYAN @ETOgaosion @yzlnew @ShareLer @BearBiscuit05 @ccclyu @ann-qin-lu @SwordFaith @zzong2006 @zhaochenyang20 @ocss884 @eric-haibin-lin diff --git a/docs/perf/nsight_profiling.md b/docs/perf/nsight_profiling.md deleted file mode 100644 index ed083c38e..000000000 --- a/docs/perf/nsight_profiling.md +++ /dev/null @@ -1,107 +0,0 @@ -# NVIDIA Nsight Systems profiling in verl - -Last updated: 06/20/2025. - -This guide explains how to use NVIDIA Nsight Systems for profiling verl training runs. - -## Configuration - -Profiling in verl can be configured through several parameters in the trainer configuration file (ppo_trainer.yaml or other files like dapo_trainer.yaml): - -### Prerequisites - -Nsight Systems version is important, please reference `docker/Dockerfile.vllm.sglang.megatron` for the version we used. - -### Global profiling control - -verl has one single controller process and multiple worker processes. Both controller and worker processes can be profiled. Since the controller process can be executed in any nodes in the cluster, there is a message printed in the logging to indicate the controller process node hostname and process id. - -In `trainer`, three new config entries control the profiler behaviors: - -* **`trainer.profile_steps`**. List of step numbers at which profiling should be performed. For example: [1, 2, 5] will profile steps 1, 2, and 5. And ``null`` means no profiling. - - -* **`controller_nsight_options`**. This config group is for the single controller. All fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. `ppo_trainer.yaml` provides a workable example. Users can reference [Nsight Systems manual](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) and [Ray user guide](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html) for more details. - -* **`worker_nsight_options`**. This config group is for the worker processes. Similarly all fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. Capture range is used to control the profiler when to start and stop. So `capture-range: "cudaProfilerApi"` is fixed and does not change it. Users can change `capture-range-end` with some accurate calculation or just leave it `null`. - -### Worker process profiling - -Verl manages mulitiple RL roles, _Actor_, _Ref_, _Rollout_, _Critic_, _Reward_, which are implemented in different Worker classes. And these workers can be combined into one Ray Actor, running in a process group. Each RL role has its own profiling config group, `profiler`, which consists of three fields: - -* **`all_ranks` and `ranks`**. When `all_ranks` is set `True` then all ranks will be profiled; when set `False`, `ranks` will be profiled. By default, verl profiles the whole training process in a series ` worker_process_..nsys-rep` files for each process rank. PID is the process ID; RID is the capture range ID. - -* **`discrete`**. When set `False`, all the roles actions in one training step will be dumped in one database. When set `True`, the actions annotated by `DistProfiler.annotate` will be dumped into a discrete database. In this case, each role's action occupies one ``. - -* **`actor_rollout_ref`**. This Worker can be configured to contain at most 3 roles and executes together. So `actor_rollout_ref` has a `profiler` config and all the inside roles inherit it. - -* **Verl collocate mode**. Verl can combine two Worker sub classes to one Worker Actor. In this case, the user should take care that the combined Workers have consistent `discrete`. The Nsight Systems profiler uses a `torch.cuda.profiler.start()` and `stop()` pair to dump a `` database anyway. - -### where to find the profiling data - -By default the `*.nsys-rep` files are saved in the directory `/tmp/ray/session_latest/logs/nsight/` at each node. According to the Ray manual, this default directory is not changeable. ["however, Ray preserves the `--output` option of the default config"](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html). - -Some users may think it is not convenient, but it is understandable that Ray may start hundreds of processes and it would be a big network file system pressure if we save the files in one central place. - -## Usage Example - -To enable profiling for specific components and steps, modify your ppo_trainer.yaml like this: - -### Disable profiler -```yaml - trainer: - profile_steps: null # disable profile -``` - -### Enable profiler and one database for one training step -```yaml - trainer: - profile_steps: [1, 2, 5] - actor_rollout_ref: - profiler: - discrete: False - all_ranks: False - ranks: [0, 1] - critic: - profiler: - discrete: False - all_ranks: False - ranks: [0, 1] - reward_model: - profiler: - discrete: False - all_ranks: False - ranks: [0, 1] -``` - -### Enable profiler and multiple databases for one training step -```yaml - trainer: - profile_steps: [1, 2, 5] - actor_rollout_ref: - profiler: - discrete: True - all_ranks: False - ranks: [0, 1] - critic: - profiler: - discrete: True - all_ranks: False - ranks: [0, 1] - reward_model: - profiler: - discrete: True - all_ranks: False - ranks: [0, 1] -``` - -## Profiling Output - -When profiling is enabled, verl will generate Nsight Systems profiles for the specified components and steps. The profiles will include: - -- CUDA kernel execution -- Memory operations -- CPU-GPU synchronization -- NVTX markers for key operations - -Nsight Systems supports multi-report view, to open multiple databases together. In this mode, different processes and steps can be aligned in one time line for better analysis. diff --git a/docs/perf/perf_tuning.rst b/docs/perf/perf_tuning.rst deleted file mode 100644 index 58df6ce13..000000000 --- a/docs/perf/perf_tuning.rst +++ /dev/null @@ -1,199 +0,0 @@ -Performance Tuning Guide -============================== - -Last updated: 06/23/2025. - -Author: `Guangming Sheng `_, `Jiali Zheng `_ - -In this section, we will discuss how to tune the performance of all the stages in verl, including: - -1. Rollout generation throughput. - -2. Enable ``use_remove_padding=True`` for sequence packing (i.e., data packing and remove padding). - -3. Batch size tuning for forward and backward computation - -4. Enable ``use_dynamic_bsz=True`` for higher throughput. - -5. Utilize Ulysses Sequence Parallel for Long Context Training - -6. LigerKernel for SFT performance optimization - -7. Forward prefetch in FSDP training backend - -8. Memory optimization for entropy calculation from logits - -Rollout Generation Tuning --------------------------- - -verl currently supports two rollout backends: vLLM and TGI (with SGLang support coming soon). - -Below are key factors for tuning vLLM-based rollout. Before tuning, we recommend setting ``actor_rollout_ref.rollout.disable_log_stats=False`` so that rollout statistics are logged. - -- Increase ``gpu_memory_utilization``. - - - For vLLM v0.7.0 and later, the vLLM instance will only use gpu_memory_utilization of the **total** memory. - - For SGLang, it's the fraction of the free GPU memory used for **static** memory like model weights and KV cache. However, the remaining (1-gpu_memory_utilization) will also be used during inference. - - However, if model parameters and optimizer states are not offloaded, using too high a fraction can lead to OOM. - A value between 0.5 and 0.7 often strikes a good balance between high throughput and avoiding OOM. - - Note: since the definition of ``gpu_memory_utilization`` varies across inference engines, a value that works well for one engine may cause OOM for another. - -- Adjust ``max_num_seqs`` or ``max_num_batched_tokens``. - If the GPU cache utilization is relatively low in the log, increase ``max_num_seqs`` or ``max_num_batched_tokens`` - can enlarge the effective batch size in the decoding stage, allowing more concurrent requests per batch. - We recommend setting ``max_num_batched_tokens > 2048`` for higher throughput. - -- Use a smaller ``tensor_parallel_size``. - When GPU resources allow, a smaller tensor parallel size spawns more vLLM replicas. - Data parallelism (DP) can yield higher throughput than tensor parallelism (TP), but also increases KVCache consumption. - Carefully balance the trade-off between more replicas and higher memory usage. - Our experient in Sec. 8.4 of `HybridFlow paper `_ evaluate this trade-off. - -More tuning details such as dealing with Preemption and Chunked-prefill -can be found in `vLLM official tuning guide `_ - -For optimal performance, we recommend using vLLM v0.8.3 or later. See https://github.com/volcengine/verl/blob/main/docs/README_vllm0.8.md for details. - -Enable remove padding (sequence packing) ------------------------------------------ - -Currently, for llama, mistral, gemma1 and qwen based models, users can enable `use_remove_padding=True` to utilize the -sequence packing implementation provided by transformers library. - -For other models, transformers library may also support it but we haven't tested it yet. -Users can add the desired model config to the `test_transformer.py `_ file. -And test its functionaility by running the following command: - -.. code-block:: bash - - pytest -s tests/models/test_transformer.py - -If the test passes, you can add your desired model into the model `registry.py `_ file. -Then, you can enjoy the performance boost of sequence packing -and welcome to PR your tested model to verl! - - -Batch Size Tuning ------------------ - -To achieve higher throughput in experience preparation (i.e., model fwd) and model update (i.e., actor/critic fwd/bwd), -users may need to tune the ``*micro_batch_size_per_gpu`` for different computation. - -In verl, the core principle for setting batch sizes is: - -- **Algorithmic metrics** (train batch size, PPO mini-batch size) are *global* (from a single-controller perspective), - normalized in each worker. See the `normalization code `_. - -- **Performance-related parameters** (micro batch size, max token length for dynamic batch size) are *local* parameters that define the per-GPU data allocations. - See the `normalization code `_. - -.. note:: In your training script, please use ``*micro_batch_size_per_gpu`` instead of ``*micro_batch_size``. - So that you don't need to consider the normalization of the ``micro_batch_size`` and ``micro_batch_size`` will be deprecated. - -Batch Size Tuning tips -"""""""""""""""""""""" - -Therefore, users may need to tune the ``*micro_batch_size_per_gpu`` to accelerate training. Here're some tips: - -1. **Enable gradient checkpointing**: - Set ``actor_rollout_ref.model.enable_gradient_checkpointing=True`` and ``critic.model.enable_gradient_checkpointing=True``. - This often allows for larger micro-batch sizes and will be beneficial for large mini-batch training. - -2. Increase the ``*micro_batch_size_per_gpu`` as much as possible till equals to normalized ``mini_batch_size``. - -3. **Use larger forward-only parameters**: - Forward only parameter, such as ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``, - ``actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu``, ``critic.forward_micro_batch_size_per_gpu`` could be larger (e.g., 2x) than training related micro batch sizes, - such as ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``, ``critic.ppo_micro_batch_size_per_gpu``. - -4. **Allow larger micro-batch sizes for Critic and Reward models**: - micro batch size of Critic and Reward model could be larger than Actor model. This is because the actor model has much larger vocab size in the final layer. - -5. **Enable activation offloading**: - Set ``actor_rollout_ref.model.enable_activation_offload=True`` and ``critic.model.enable_activation_offload=True``. - This often works together with gradient checkpointing to get larger micro-batch sizes and it's only available in FSDP backend now. - -Tuning for Dynamic Batch Size ------------------------------ - -Dynamic batch size is a technique that allows the model to process similar number of tokens in a single forward pass (with different actual batch sizes). -This can significantly improve the training efficiency and reduce the memory usage. - -To utilize this technique, users can set ``use_dynamic_bsz=True`` in actor, ref, critic and reward models. -With ``use_dynamic_bsz=True``, users don't need to tune ``*micro_batch_size_per_gpu``. -Instead, users should tune the following parameters: - -- ``actor_rollout_ref.actor.ppo_max_token_len_per_gpu``, ``critic.ppo_max_token_len_per_gpu``: - The maximum number of tokens to be processed in fwd and bwd of ``update_policy`` and ``update_critic``. - -- ``actor_rollout_ref.ref.log_prob_max_token_len_per_gpu`` and ``actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu``: - The maximum number of tokens to be processed in a the fwd computation of ``compute_log_prob`` and ``comptue_ref_log_prob``. - -- ``critic.forward_micro_batch_size_per_gpu``, ``reward_model.forward_micro_batch_size_per_gpu``: - The maximum number of tokens to be processed in a the fwd computation of ``compute_values``, ``compute_rm_score``. - -Dynamic Batch Size Tuning tips -"""""""""""""""""""""""""""""" - -Here're some tips to tune the above parameters: - -1. **Increase** ``actor_rollout_ref.actor.ppo_max_token_len_per_gpu`` - Make it at least 2 x (max_prompt_length + max_response_length). We set it to 3x in `run_qwen2-7b_rm_seq_balance.sh `_. - Try to increase it to get higher throughput. - -2. **Forward-only parameters can be larger**: - Similar to the non-dynamic-batch scenario, forward-only token limits can exceed those used in forward/backward operations. - -3. **Use larger limits for Critic and Reward models**: - Critic and Reward parameters can be set at least 2× the Actor’s limits. For instance, we set them to 4× here: - `run_qwen2-7b_rm_seq_balance.sh `_ - -.. :math:`\text{critic.ppo_max_token_len_per_gpu} = 2 \times \text{actor.ppo_max_token_len_per_gpu})`. - -Ulysses Sequence Parallel for Long Context Training ----------------------------------------------------- - -To utilize this technique, users can set ``ulysses_sequence_parallel_size>1`` in actor, ref, critic and reward models. - -We support different model utilize different ulysses_sequence_parallel_size sizes. - -To train log sequence (>32k), users may need to decrease the ``*micro_batch_size_per_gpu`` and ``*max_token_len_per_gpu`` to avoid OOM. - -LigerKernel for SFT ----------------------- - -LigerKernel is a high-performance kernel for Supervised Fine-Tuning (SFT) that can improve training efficiency. To enable LigerKernel in your SFT training: - -1. Install liger-kernel via ``pip3 install liger-kernel``. In your SFT configuration file (e.g., ``verl/trainer/config/sft_trainer.yaml``), set the ``use_liger`` parameter: - - .. code-block:: yaml - - model: - use_liger: True # Enable LigerKernel for SFT - -2. The default value is ``False``. Enable it only when you want to use LigerKernel's optimizations. - -3. LigerKernel is particularly useful for improving training performance in SFT scenarios. - -Forward prefetch in FSDP training backend ----------------------- - -During the training phase, users can enable forward prefetching in FSDP by setting ``fsdp_config.forward_prefetch=True``. For example, ``actor_rollout_ref.actor.fsdp_config.forward_prefetch=True``. This configuration prefetches the next forward-pass all-gather operation before completing the current forward computation, overlapping communication with computation and improving efficiency. For further details, refer to the `FSDP forward_pefetch `_ documentation. - -.. note:: - Backward prefetch is unsupported because the ``BACKWARD_POST`` policy may prefetch incorrectly in nested-module cases. For details, see the `FSDP documentation `_ - -Memory optimization for entropy calculation from logits ----------------------- - -The ``logits`` tensor (typically of shape ``[bsz*seq_len, voc]``) can consume significant memory. When using ``compute_entropy_from_logits``, memory usage reaches approximately ``[bsz*seq_len, voc] × (4 bytes (float32) + 2 bytes (autocast for softmax+logsumexp) + 1 byte (softmax output))``. - -To reduce this memory peak, enable chunked computation by setting: -``actor_rollout_ref.ref.entropy_from_logits_with_chunking = True`` -This processes the tensor in chunks of shape ``[chunk_size, voc]`` (e.g., 2048) rather than the full sequence length, exclusively during the model's forward pass. - -Additionally, during training, standard gradient checkpointing (``enable_gradient_checkpointing=True``) does not apply to entropy calculations. To reduce memory peaks in this context, set: -``actor_rollout_ref.actor.entropy_checkpointing = True`` -This enables entropy recomputation specifically for the entropy calculation, lowering memory usage during training. diff --git a/docs/preparation/prepare_data.rst b/docs/preparation/prepare_data.rst deleted file mode 100644 index 312352826..000000000 --- a/docs/preparation/prepare_data.rst +++ /dev/null @@ -1,128 +0,0 @@ -Prepare Data for Post-Training -======================================== - -Last updated: 02/09/2025. - -Before starting the post-training job, we need to prepare the data for -the policy training. The data should be stored in the parquet format. - -We provide several data preprocess scripts for different datasets, -including GSM8K, MATH, HelloSwag, Full_hh_rlhf. To prepare other datasets, we need -to follow the following steps: The data preprocess script can be divided -into two parts: - -1. The first part is the common part, which loads the dataset from - huggingface's ``datasets`` package. Then preprocess the datasets with - the ``make_map_fn`` and then store in the parquet format. - -.. code:: python - - import re - import os - import datasets - - from verl.utils.hdfs_io import copy, makedirs - import argparse - - # To extract the solution for each prompts in the dataset - # def extract_solution(solution_str): - # ... - - - if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--local_dir', default='/opt/tiger/gsm8k') - parser.add_argument('--hdfs_dir', default=None) - - args = parser.parse_args() - - num_few_shot = 5 - data_source = 'openai/gsm8k' - - dataset = datasets.load_dataset(data_source, 'main') - - train_dataset = dataset['train'] - test_dataset = dataset['test'] - - # Construct a `def make_map_fn(split)` for the corresponding datasets. - # ... - - train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True) - test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True) - - local_dir = args.local_dir - hdfs_dir = args.hdfs_dir - - train_dataset.to_parquet(os.path.join(local_dir, 'train.parquet')) - test_dataset.to_parquet(os.path.join(local_dir, 'test.parquet')) - - makedirs(hdfs_dir) - - copy(src=local_dir, dst=hdfs_dir) - -2. The users are required to implement the ``make_map_fn()`` function - (as well as the ``extract_solution``) on their own to support - different datasets or tasks. - -We already implemented the data preprocess of GSM8k, MATH, Hellaswag and Full_hh_rlhf -datasets. And we take the GSM8k dataset as an example: - -**GSM8K** - -In the ``make_map_fn``, each data field should consist of the following -5 fields: - -1. ``data_source``: The name of the dataset. To index the corresponding - reward function in the ``RewardModule`` -2. ``prompt``: This field should be constructed in the format of - huggingface chat_template. The tokenizer in ``RLHFDataset`` will - apply chat template and tokenize the prompt. -3. ``ability``: Define the task category. -4. ``reward_model``: Currently, we only utilize the ``ground_truth`` - field during evaluation. The ``ground_truth`` is computed by the - ``extract_solution`` function. **NOTED** that the implementation of - the corresponding reward function should align with this extracted - ``ground_truth``. -5. ``extra_info``: Record some information of the current prompt. Not - use for now. - -.. code:: python - - def extract_solution(solution_str): - solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) # extract the solution after #### - assert solution is not None - final_solution = solution.group(0) - final_solution = final_solution.split('#### ')[1].replace(',', '') - return final_solution - - instruction_following = "Let's think step by step and output the final answer after \"####\"." - - # add a row to each data item that represents a unique id - def make_map_fn(split): - - def process_fn(example, idx): - question = example.pop('question') - - question = question + ' ' + instruction_following - - answer = example.pop('answer') - solution = extract_solution(answer) - data = { - "data_source": data_source, - "prompt": [{ - "role": "user", - "content": question - }], - "ability": "math", - "reward_model": { - "style": "rule", - "ground_truth": solution - }, - "extra_info": { - 'split': split, - 'index': idx - } - } - return data - - return process_fn diff --git a/docs/preparation/reward_function.rst b/docs/preparation/reward_function.rst deleted file mode 100644 index 286e2aff4..000000000 --- a/docs/preparation/reward_function.rst +++ /dev/null @@ -1,71 +0,0 @@ -Implement Reward Function for Dataset -====================================== - -Last updated: 06/02/2025. - -For each dataset, we need to implement a reward function or utilize a reward model to compute the rewards for the generated responses. -We already pre-implemented some reward functions in `reward_score directory `_. -You can also use customized reward functions. - -Currently, we support reward functions for GSM8k and MATH datasets. For RLHF datasets (e.g., -full_hh_rlhf) and Code Generation (e.g., APPS), we utilize reward model -and SandBox (will opensource soon) for evaluation respectively. - -RewardManager -------------- - -In the entrypoint of the PPO Post-Training script `main_ppo.py `_, -we implement a ``RewardManager`` that utilize pre-implemented reward functions to compute the scores for each response. - -In the ``RewardManager``, we implemented a ``__call__`` function to -compute the score for each response. -All the reward functions are executed by ``compute_score_fn``. -The input is a ``DataProto``, which includes: - -- ``input_ids``, ``attention_mask``: ``input_ids`` and ``attention_mask`` after applying - chat_template, including prompt and response -- ``responses``: response tokens -- ``ground_truth``: The ground truth string of the current prompt. - Stored in ``non_tensor_batch`` in the ``DataProto``, which should be - preprocessed in the parquet files. -- ``data_source``: The dataset name of the current prompt. Stored in - ``non_tensor_batch`` in the ``DataProto``, which should be - preprocessed in the parquet files. - -After detokenize the responses, the responses string and the ground -truth string will be input to the ``compute_score_fn`` to compute the -score for each response. - -Reward Functions ----------------- - -Pre-implemented -~~~~~~~~~~~~~~~ - -We already pre-implemented some reward functions in `reward_score directory `_. - -- In the `GSM8k example `_, we - force the response to output the final answer after four ####, then - use string matching to compare with the ground truth. If completely - correct, score 1 point; if the format is correct, score 0.1 points; if - the format is incorrect, score 0 points. -- In the `MATH example `_, we follow - the implementation in `lm-evaluation-harness repository `_. - -Customized -~~~~~~~~~~ - -You can implement customized reward functions in a separate file and specify them using ``custom_reward_function.path`` and ``custom_reward_function.name``. For the set of them, please refer to :ref:`config-explain-page`. - -The parameters of your reward function should be ``data_source``, ``solution_str``, ``ground_truth``, and ``extra_info``. -For example: - -.. code:: python - - def my_reward_fn(data_source, solution_str, ground_truth, extra_info=None): - return len(solution_str)/100 - -If you are testing only a single customized reward function, you can simply name it 'compute_score' and leave ``custom_reward_function.name`` unset. - -To run multiple tests with different customized reward functions, you can modify both ``custom_reward_function.path`` and ``custom_reward_function.name`` for each trial. -For instance, you might create a single `my_reward.py` file and implement multiple reward functions within it. This way, for different trials, you only need to adjust ``custom_reward_function.name``, making it more convenient to conduct multiple tests within scripts. diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt deleted file mode 100644 index 55ccdb8f7..000000000 --- a/docs/requirements-docs.txt +++ /dev/null @@ -1,13 +0,0 @@ -# markdown support -recommonmark -myst_parser -# markdown table support -sphinx-markdown-tables - -# theme default rtd - -# crate-docs-theme -sphinx-rtd-theme - -# pin tokenizers version to avoid env_logger version req -tokenizers==0.21 diff --git a/docs/sglang_multiturn/interaction_system.rst b/docs/sglang_multiturn/interaction_system.rst deleted file mode 100644 index 26b3db91e..000000000 --- a/docs/sglang_multiturn/interaction_system.rst +++ /dev/null @@ -1,419 +0,0 @@ -Interaction System for Multi-turn RL Training -============================================= - -Last updated: 06/25/2025. - -Overview --------- - -The verl interaction system enables dynamic, multi-turn conversational feedback during reinforcement learning training. This system allows models to engage in iterative problem-solving scenarios where interaction agents can provide corrective feedback, guidance, or evaluation based on the model's responses. - -**New in Multi-Interaction Support**: The system now supports multiple named interactions within a single training session, enabling sophisticated training scenarios where different samples can use different interaction strategies. This allows for curriculum learning, domain-specific feedback, and flexible agent switching at the sample level. - -Key features: - -- **Async-based Architecture**: Non-blocking interaction processing for distributed training -- **Instance Management**: Stateful session handling with unique instance IDs for concurrent interactions -- **SGLang Integration**: Seamless integration with SGLang rollout system for multi-turn conversations -- **Configuration-driven**: Dynamic agent loading via YAML configuration files -- **Multi-Interaction Support**: Registry system enabling multiple named interactions per rollout -- **Sample-Level Selection**: Each sample can specify which interaction to use via configuration -- **Reward Integration**: Turn-level scoring mechanism integrated with verl's reward system - -Architecture ------------- - -The interaction system follows a plugin-based architecture with clear separation of concerns: - -.. code-block:: - - Interaction Registry System - ↓ - BaseInteraction (Abstract Interface) - ↓ - Multiple Named Interactions (e.g., Gsm8kInteraction, CustomInteraction) - ↓ - SGLang Rollout Integration (interaction_map) - ↓ - Sample-Level Interaction Selection - ↓ - Async Request Lifecycle Management - -Core Components -~~~~~~~~~~~~~~~ - -**Interaction Registry System** - -The interaction registry system allows loading and managing multiple named interactions: - -.. code-block:: python - - from verl.interactions.utils.interaction_registry import initialize_interactions_from_config - - # Load multiple interactions from config - interaction_map = initialize_interactions_from_config("config.yaml") - - # Access specific interaction by name - gsm8k_interaction = interaction_map["gsm8k"] - custom_interaction = interaction_map["custom_solver"] - -**BaseInteraction Interface** - -All interaction agents must implement the ``BaseInteraction`` abstract class: - -.. code-block:: python - - from verl.interactions.base import BaseInteraction - from typing import Dict, Any, List, Tuple, Optional - - class BaseInteraction: - def __init__(self, config: Dict[str, Any]): - self.config = config - self.name: str = config.get("name", "interaction_agent") - - async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str: - """Initialize interaction session, return instance_id""" - - async def generate_response(self, instance_id: str, messages: List[Dict[str, Any]], **kwargs) -> Tuple[bool, str, float, Dict[str, Any]]: - """Generate response, return (should_terminate, response, score, metadata)""" - - async def calculate_score(self, instance_id: str, **kwargs) -> float: - """Calculate turn-level score for RL training""" - - async def finalize_interaction(self, instance_id: str, **kwargs) -> None: - """Clean up resources""" - -**Request Lifecycle** - -The interaction system integrates with SGLang's async rollout via state management: - -1. ``PENDING`` → Initialize interaction via ``start_interaction()`` -2. ``GENERATING`` → Model generates response -3. ``INTERACTING`` → Process response via ``generate_response()`` -4. ``GENERATING`` → Continue if not terminated, otherwise ``COMPLETED`` - -Configuration -------------- - -**Basic Setup** - -Enable interaction in your rollout configuration: - -.. code-block:: yaml - - actor_rollout_ref: - rollout: - multi_turn: - enable: true - interaction_config_path: "path/to/interaction_config.yaml" - max_user_turns: 10 - max_assistant_turns: 10 - -**Interaction Configuration File** - -Create an interaction configuration file (e.g., ``interaction_config.yaml``): - -**Single Interaction (Legacy Format)** - -.. code-block:: yaml - - interaction: - - name: "gsm8k" - class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" - config: {} - -**Multiple Interactions (New Format)** - -.. code-block:: yaml - - interaction: - - name: "gsm8k" - class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" - config: {} - - name: "custom_solver" - class_name: "custom.interactions.CustomInteraction" - config: - solver_type: "advanced" - timeout: 30 - - name: "code_verifier" - class_name: "verl.interactions.base.BaseInteraction" - config: - verification_mode: "strict" - -**Automatic Name Generation** - -If no ``name`` field is provided, the system will automatically generate one from the class name: - -.. code-block:: yaml - - interaction: - - class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" - config: {} - # Automatically generates name: "gsm8k" - -The system will dynamically load all specified interaction classes and make them available by name. - -Implementation Example: GSM8K ------------------------------ - -The GSM8K interaction demonstrates a complete implementation for math problem-solving scenarios: - -.. code-block:: python - - from verl.interactions.base import BaseInteraction - from verl.utils.reward_score import gsm8k - from uuid import uuid4 - - class Gsm8kInteraction(BaseInteraction): - def __init__(self, config: dict): - super().__init__(config) - self._instance_dict = {} - - async def start_interaction(self, instance_id=None, ground_truth=None, **kwargs): - if instance_id is None: - instance_id = str(uuid4()) - self._instance_dict[instance_id] = { - "response": "", - "ground_truth": ground_truth, - "reward": 0.0, - } - return instance_id - - async def generate_response(self, instance_id, messages, **kwargs): - # Extract last user message content - content = "" - for item in reversed(messages): - if item.get("role") == "user": - content = item.get("content", "") - break - - # Ensure GSM8K format (#### prefix) - if content.startswith("#### "): - self._instance_dict[instance_id]["response"] = content - else: - self._instance_dict[instance_id]["response"] = "#### " + content - - reward = await self.calculate_score(instance_id) - if reward == 1.0: - return True, "Your response is correct!", 1.0, {} - else: - return False, "Your response is incorrect! You need to reflect on your answer and try again.", 0.0, {} - - async def calculate_score(self, instance_id, **kwargs): - return gsm8k.compute_score( - self._instance_dict[instance_id]["response"], - self._instance_dict[instance_id]["ground_truth"], - method="flexible", format_score=0.0, score=1.0, - ) - - async def finalize_interaction(self, instance_id, **kwargs): - del self._instance_dict[instance_id] - -Training Integration --------------------- - -**Training Script Configuration** - -Include interaction configuration in your training command: - -.. code-block:: bash - - python3 -m verl.trainer.main_ppo \\ - --config-path="$CONFIG_PATH" \\ - --config-name='gsm8k_multiturn_grpo_w_interaction' \\ - algorithm.adv_estimator=grpo \\ - data.train_batch_size=512 \\ - data.return_raw_chat=True \\ - actor_rollout_ref.rollout.name=sglang \\ - actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml" \\ - trainer.total_epochs=15 - -**Data Requirements** - -Ensure your dataset includes interaction parameters with the ``name`` field for interaction selection: - -.. code-block:: python - - # Dataset should include interaction_kwargs in non_tensor_batch - interaction_kwargs = [ - {"name": "gsm8k", "query": "What is 2+2?", "ground_truth": "4"}, - {"name": "custom_solver", "query": "Solve: x^2 + 5x + 6 = 0", "ground_truth": "x = -2, -3"}, - {"name": "gsm8k", "query": "What is 3+3?", "ground_truth": "6"}, - ] - -**Sample-Level Interaction Selection** - -Each sample can specify which interaction to use via the ``name`` field. This enables flexible training scenarios where different samples use different interaction strategies: - -.. code-block:: python - - # Example: Math problems use GSM8K interaction, code problems use code verifier - data_samples = [ - { - "prompt": "What is 15% of 200?", - "interaction_kwargs": { - "name": "gsm8k", - "query": "What is 15% of 200?", - "ground_truth": "30" - } - }, - { - "prompt": "Write a function to check if a number is prime", - "interaction_kwargs": { - "name": "code_verifier", - "code_type": "python", - "expected_behavior": "return True for prime numbers" - } - } - ] - -**Backward Compatibility** - -If no ``name`` field is provided in ``interaction_kwargs``, the system defaults to ``"gsm8k"`` for backward compatibility. - -Best Practices --------------- - -**Resource Management** - -- Always implement proper cleanup in ``finalize_interaction()`` -- Use unique instance IDs to avoid conflicts in concurrent training -- Handle edge cases like empty messages or malformed content - -**Performance Optimization** - -- Keep interaction logic lightweight to avoid blocking training -- Use async/await properly to maintain non-blocking behavior -- Consider caching expensive computations within interaction instances - -**Testing** - -Comprehensive testing is essential for interaction systems: - -.. code-block:: python - - import pytest - from unittest.mock import patch - - @pytest.mark.asyncio - async def test_interaction_workflow(): - interaction = YourInteraction({}) - - # Test complete workflow - instance_id = await interaction.start_interaction(ground_truth="expected_answer") - - messages = [{"role": "user", "content": "user_response"}] - should_terminate, response, reward, metadata = await interaction.generate_response(instance_id, messages) - - assert should_terminate in [True, False] - assert isinstance(reward, float) - - await interaction.finalize_interaction(instance_id) - -Advanced Usage --------------- - -**Multi-Interaction Training Strategies** - -You can design sophisticated training scenarios using multiple interactions: - -.. code-block:: python - - # Example: Progressive difficulty with different interaction agents - class MathTrainingPipeline: - def create_interaction_config(self): - return { - "interaction": [ - { - "name": "basic_math", - "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", - "config": {"difficulty": "easy"} - }, - { - "name": "advanced_math", - "class_name": "custom.interactions.AdvancedMathInteraction", - "config": {"difficulty": "hard", "allow_hints": True} - }, - { - "name": "competition_math", - "class_name": "custom.interactions.CompetitionMathInteraction", - "config": {"time_limit": 300, "show_steps": False} - } - ] - } - - def create_curriculum_data(self, epoch): - if epoch < 5: - return [{"name": "basic_math", ...} for _ in samples] - elif epoch < 10: - return [{"name": "advanced_math", ...} for _ in samples] - else: - return [{"name": "competition_math", ...} for _ in samples] - -**Custom Scoring Functions** - -You can integrate custom reward functions: - -.. code-block:: python - - async def calculate_score(self, instance_id, **kwargs): - response = self._instance_dict[instance_id]["response"] - ground_truth = self._instance_dict[instance_id]["ground_truth"] - - # Custom evaluation logic - if custom_evaluation_function(response, ground_truth): - return 1.0 - else: - return 0.0 - -**Multi-step Interactions** - -For complex scenarios requiring multiple feedback rounds: - -.. code-block:: python - - async def generate_response(self, instance_id, messages, **kwargs): - instance = self._instance_dict[instance_id] - instance["attempts"] += 1 - - # Evaluate current response - reward = await self.calculate_score(instance_id) - - if reward > 0.8: - return True, "Excellent work!", reward, {} - elif instance["attempts"] < 3: - return False, "Good attempt, but try to improve...", reward, {} - else: - return True, "Maximum attempts reached.", reward, {} - -Troubleshooting ---------------- - -**Common Issues** - -1. **Instance ID Conflicts**: Ensure unique instance IDs across concurrent sessions -2. **Memory Leaks**: Always call ``finalize_interaction()`` to clean up resources -3. **Blocking Operations**: Keep interaction logic async and non-blocking -4. **Configuration Errors**: Verify interaction config path and class name are correct -5. **Interaction Name Conflicts**: Ensure all interactions have unique names in the configuration -6. **Missing Interaction**: Verify the ``name`` field in ``interaction_kwargs`` matches available interactions -7. **Backward Compatibility**: When migrating from single to multi-interaction, add ``name`` fields to existing data - -**Debugging** - -Enable debug logging to trace interaction flow: - -.. code-block:: bash - - export VERL_LOGGING_LEVEL=DEBUG - -**Performance Monitoring** - -Monitor interaction performance impact on training throughput and adjust accordingly. - -Related Documentation --------------------- - -- :doc:`multiturn`: Basic multi-turn rollout configuration -- :doc:`sandbox_fusion`: Tool integration with SGLang -- :doc:`search_tool_example`: Search tool implementation example \ No newline at end of file diff --git a/docs/sglang_multiturn/multiturn.rst b/docs/sglang_multiturn/multiturn.rst deleted file mode 100644 index 5a4c444cb..000000000 --- a/docs/sglang_multiturn/multiturn.rst +++ /dev/null @@ -1,343 +0,0 @@ -Multi-turn Rollout Support -========================== - -Last updated: 06/27/2025. - -Basic Configuration -~~~~~~~~~~~~~~~~~~~ - -To enable multi-turn rollout, make sure to configure the following fields in your rollout configuration: - -.. code-block:: yaml - - actor_rollout_ref: - rollout: - multi_turn: True - name: "sglang" - -These configuration activates the sglang engine for multi-turn interaction during rollout. - -Custom Tool Configuration -~~~~~~~~~~~~~~~~~~~~~~~~~ - -For custom environment interaction tools, you can implement your own tools based on ``verl.tools.base_tool.BaseTool``. Then, specify your tool configurations in a YAML file: - -.. code-block:: yaml - - tools: - - class_name: "" - config: - type: native - tool_schema: - -You may refer to GSM8KTool_example_configuration_, which is one example of the tool configurations. Its implementation can be found in gsm8k_tool.py_. - -Finally, set the ``tools_config_file`` in your rollout config: - -.. code-block:: yaml - - actor_rollout_ref: - rollout: - tool_kwargs: - tools_config_file: - -This allows integration of customized tool behaviors during actor rollout steps. - -If you want rollout with simulated interaction, you can set the ``interaction_config_file`` in your rollout config: - -.. code-block:: yaml - - interaction: - - class_name: "" - config: {} - -.. code-block:: yaml - - actor_rollout_ref: - rollout: - interaction_config_file: - -If your tool creates multi-modal inputs, you should return a list of multi-modal inputs in your tool.execute() implementation. - -Image and video should be processed before returning. For example, if you are using Qwen2.5-VL, you can use the following code to get the representations: - -.. code-block:: python - - async def execute(self, ...) -> Tuple[str | Dict[str, Any], float, dict]: - ... - from verl.utils.dataset.vision_utils import process_image, process_video - - img1 = process_image(img1) - video1 = process_video(video1) - - # due to the (image | video) key is ("image" | "video") instead of ("images" | "videos") in vllm, we need to use ("image" | "video") to specify list of images/videos - # link: https://github.com/vllm-project/vllm/blob/3c545c0c3b98ee642373a308197d750d0e449403/vllm/multimodal/parse.py#L205 - return {"image": [img1, ...], "video": [video1, ...], "text": "..."}, 0, {} - -remeber to set ``return_multi_modal_inputs: False`` in your dataset config in order to process the multi-modal inputs in the rollout correctly. -Refer to the `Handling Multi-Modal Inputs in Datasets`_ section for more details. - -MCP Tool Configuration -~~~~~~~~~~~~~~~~~~~~~~ - -For MCP interaction tools, you can flexibly configure them using a YAML file. The typical setup is as follows: - -.. code-block:: yaml - - tools: - - class_name: "" - config: - type: mcp - mcp: - mcp_servers_config_path: ./mcp_server.json - tool_selected_list: {} - -The ``tool_selected_list`` field is optional and specifies which tools to use from the servers. If you want to enable all available tools, simply omit this attribute. Besides, ``mcp_servers_config_path`` points to a JSON file containing the MCP server configurations. For example: - -.. code-block:: json - - { - "mcpServers": { - "SSE Server": { - "url": "your_server_url", - "auth_token": "your_server_api_token" - }, - "STDIO Server": { - "command": "npx", - "args": ["-y", "server-mcp@0.2.1"], - "env": { - "SERVER_API_KEY": "your_server_api_token" - } - } - } - } - -Since the content formats returned by the MCP server may vary, users can inherit from ``MCPBaseTool`` and override the ``_parse_tool_result`` method to implement custom parsing logic. - -.. code-block:: python - - class MCPYourTool(MCPBaseTool): - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - super().__init__(config, tool_schema) - - def _parse_tool_result(self, content: list) -> Tuple[str, dict]: - ... - -Overall, you may refer to mcp_search_tool.py_ and mcp_tool_config.yaml_ for custom implementation and configuration. - -Multi-turn Tokenization -~~~~~~~~~~~~~~~~~~~~~~~ - -Tokenizing multi-turn rollouts poses a challenge: after applying the chat template and tokenizing the full message list, it's hard to identify which tokens belong to assistant messages. Since the token list is flat, it lacks direct alignment with the message roles. - -To address this, we adopt a **delta-based tokenization** strategy. Each time the LLM generates a new message, we: - -1. Apply the chat template to all prior messages (`messages[:i]`). -2. Apply the chat template again including the latest message (`messages[:i+1]`). -3. Tokenize only the *delta* between these two serialized message strings. - -This ensures that only tokens generated by the assistant are included in the loss mask. - -.. code-block:: python - - # When using tokenizer - # Exclude the assistant prompt (e.g., "<|im_start|>assistant") from the loss by setting add_generation_prompt=True - prev = tokenizer.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False) - curr = tokenizer.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False) - token_ids += tokenizer.encode(curr[len(prev):], add_special_tokens=False) - loss_mask += [1] * len(token_ids) # Mask only the new assistant tokens - -.. code-block:: python - - # When using processor - # Exclude the assistant prompt (e.g., "<|im_start|>assistant") from the loss by setting add_generation_prompt=True - prev = processor.apply_chat_template(messages[:i], add_generation_prompt=True, tokenize=False) - prev_model_inputs = processor(text=prev, images=images, videos=videos, return_tensors="pt")[0].tolist() - curr = processor.apply_chat_template(messages[:i+1], add_generation_prompt=False, tokenize=False) - curr_model_inputs = processor(text=curr, images=images, videos=videos, return_tensors="pt")[0].tolist() - token_ids += curr_model_inputs["input_ids"][len(prev_model_inputs["input_ids"]):] - loss_mask += [1] * len(token_ids) # Mask only the new assistant tokens - -While we've validated this produces consistent results with full message tokenization, future models' chat template could break compatibility. To guard against silent inconsistencies, we compare the delta-based tokenization with full-tokenization results by default at the end of each rollout. - -If you see the following warning, you can check the mismatched substring in the log: - -.. code-block:: - - Inconsistent training and inference tokenization detected. This may lead to unexpected behavior during training. Please review your chat template to determine if this is intentional. For more information, refer to the multiturn README.md. - -The tokenization sanity check mode can be configured using the ``actor_rollout_ref.rollout.multi_turn.tokenization_sanity_check_mode`` parameter, which accepts the following values: - -- ``strict`` (default): Performs strict comparison between delta-based and full tokenization results, raising warnings for any differences. - -- ``ignore_strippable``: Ignores differences in whitespace characters (``\n``, ``\t``, ``\r``, spaces) while still checking for meaningful text mismatches. This is useful when debugging chat template issues where whitespace variations are expected and acceptable. - -- ``disable``: Completely disables the tokenization sanity check. Only use this if you have thoroughly validated that tokenization discrepancies are expected and won't impact training. - -Example configuration: - -.. code-block:: yaml - - actor_rollout_ref: - rollout: - multi_turn: - tokenization_sanity_check_mode: "ignore_strippable" # Choose from: "disable", "ignore_strippable", "strict" - -Handling Multi-Modal Inputs in Datasets -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If your dataset includes multi-modal inputs (such as images or videos), you can control whether these are pre-processed and included in each sample by setting the return_multi_modal_inputs flag in your dataset config (used by RLHFDataset). - -- ``return_multi_modal_inputs: True`` (default): The dataset will pre-process and include a multi_modal_inputs dictionary for each sample. This dict contains the model-ready representations (e.g., image tensors, video tensors, etc.) as produced by your processor. This is useful for single-turn or SFT-style training, where the model expects all modalities to be present in the batch. - -- ``return_multi_modal_inputs: False``: The dataset will not include the multi_modal_inputs field. This is recommended for multi-turn RL or tool-augmented rollouts, where the model may generate new multi-modal inputs dynamically during rollout, and you want to avoid conflicts or redundant data in the batch. - - -Special Cases -^^^^^^^^^^^^^ - -Some models (e.g., Qwen/QwQ-32B and Qwen3 series) remove internal reasoning content during chat template rendering. As a result, the message content can vary across turns, making the delta-based tokenization inaccurate. - -For example, for the following conversation: - -.. code-block:: python - - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is 2 + 2?"}, - {"role": "assistant", "content": "user asked about a simple math question. 2 + 2 = 4."}, - {"role": "user", "content": "Explain why."}, - {"role": "assistant", "content": "user wants to know the reasoning behind the answer. Search for a good explanation", - "tool_calls": [{"id": "tool1", "type": "search", "arguments": {"query": "Why is 2 + 2 = 4?"}}]}, - {"role": "tool", "content": "The sum of two and two is four because it is a basic arithmetic operation."}, - {"role": "assistant", "content": "The tool provided a good explanation.The sum of two and two is four because it is a basic arithmetic operation."} - ] - -1. Qwen/QwQ-32B will remove all reasoning content except the last assistant message after applying the chat template. - -.. code-block:: text - - <|im_start|>system - You are a helpful assistant.<|im_end|> - <|im_start|>user - What is 2 + 2?<|im_end|> - <|im_start|>assistant - 2 + 2 = 4.<|im_end|> - <|im_start|>user - Explain why.<|im_end|> - <|im_start|>assistant - - {"name": "", "arguments": {"query": "Why is 2 + 2 = 4?"}} - <|im_end|> - <|im_start|>user - - The sum of two and two is four because it is a basic arithmetic operation. - <|im_end|> - <|im_start|>assistant - The tool provided a good explanation. The sum of two and two is four because it is a basic arithmetic operation.<|im_end|> - -2. Qwen3 series will remove all reasoning content before the last user message. - -.. code-block:: text - - <|im_start|>system - You are a helpful assistant.<|im_end|> - <|im_start|>user - What is 2 + 2?<|im_end|> - <|im_start|>assistant - 2 + 2 = 4.<|im_end|> - <|im_start|>user - Explain why.<|im_end|> - <|im_start|>assistant - - user wants to know the reasoning behind the answer. Search for a good explanation - - - - {"name": "", "arguments": {"query": "Why is 2 + 2 = 4?"}} - <|im_end|> - <|im_start|>user - - The sum of two and two is four because it is a basic arithmetic operation. - <|im_end|> - <|im_start|>assistant - - The tool provided a good explanation. - - - The sum of two and two is four because it is a basic arithmetic operation.<|im_end|> - -To handle this, we fall back to a **fixed base conversation** containing only a single system and user message. Since this base doesn't include assistant messages or reasoning content, it remains consistent across turns. - -.. code-block:: python - - BASE_CHAT_HISTORY = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "I am a user."} - ] - prev = tokenizer.apply_chat_template(BASE_CHAT_HISTORY, add_generation_prompt=True, tokenize=False) - curr = tokenizer.apply_chat_template([*BASE_CHAT_HISTORY, messages[i]], add_generation_prompt=False, tokenize=False) - token_ids += tokenizer.encode(curr[len(prev):], add_special_tokens=False) - loss_mask += [1] * len(token_ids) - -This method works well for Qwen3 series. However, Qwen/QwQ-32B currently has a bug in its chat template. A fix_ has been proposed but not yet adopted. Until then, use the following command to download the fixed model revision: - -.. code-block:: bash - - pip install huggingface_hub - huggingface-cli download Qwen/QwQ-32B --revision refs/pr/81 - -.. _fix: https://huggingface.co/Qwen/QwQ-32B/discussions/81 - -Discrepancy Between Training and Inference Templates -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Although the above approach fixes the delta mismatch issue, the removal of reasoning content in the inference-time chat template introduces a new discrepancy: training uses the full reasoning content, while inference does not. - -This mismatch can affect model performance in unpredictable ways. To avoid it, we default to using the full response (including reasoning) for both training and rollout. - -However, this approach comes with trade-offs: - -1. Long reasoning contents can easily exceed the model's context window, especially in multi-turn rollout. -2. There's a mismatch between rollout and production environment now—models will not have reasoning content from past turns if you use the default chat template in production. - -We are still evaluating the impact of these issues. If you experience context length problems or prefer rollouts that match production (i.e., exclude reasoning), you can enable: - -``actor_rollout_ref.rollout.multi_turn.use_inference_chat_template = True`` - -GSM8K Multi-turn Training Performance -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -See the training performance of multi-turn rollout on the GSM8K task HERE_. - -.. _HERE: https://wandb.ai/zhaochenyang20/gsm8k_async_rl/runs/1ro1r7om?nw=nwuserzhaochenyang20 - -.. _GSM8KTool_example_configuration: https://github.com/volcengine/verl/blob/main/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml - -.. _gsm8k_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/gsm8k_tool.py - -.. _mcp_search_tool.py: https://github.com/volcengine/verl/blob/main/verl/tools/mcp_search_tool.py - -.. _mcp_tool_config.yaml: https://github.com/volcengine/verl/blob/main/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml - -Interaction System -~~~~~~~~~~~~~~~~~~ - -For dynamic conversational feedback during RL training, see: - -.. toctree:: - :maxdepth: 1 - - interaction_system - -Search Tool Integration -~~~~~~~~~~~~~~~~~~~~~~~ - -.. toctree:: - :maxdepth: 1 - - search_tool_example - -Code Walkthrough -~~~~~~~~~~~~~~~~~~~~~~~ -If you want to learn more in depth about the code execution flow, please read https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/rlhf/verl/multi-turn/code-walk-through diff --git a/docs/sglang_multiturn/sandbox_fusion.rst b/docs/sglang_multiturn/sandbox_fusion.rst deleted file mode 100644 index 207af5289..000000000 --- a/docs/sglang_multiturn/sandbox_fusion.rst +++ /dev/null @@ -1,292 +0,0 @@ -=============================== -Sandbox Fusion Tool Integration -=============================== - -Last updated: 06/10/2025. - -Motivations -=========== - -- As users of verl, we want to allow the model to call certain tools during Actor rollout, incorporating the results into the training process. -- A colleague from ByteDance proposed a paper aimed at enhancing model capability through code execution tools. -- We aim to support tool-calling capabilities of inference engines using `sandbox-fusion` as the code execution system, providing the community with a reimplementation of `retools`. - -Reward Compute with Sandbox Fusion + FaaS Integration -===================================================== - -- In current datasets and tasks, similar work already exists (e.g., Prime), which uses local processes as runners to execute model-generated code for reward computation. -- On this basis, #1429 has advanced the design by integrating FaaS as the runner for reward computation. - -Goals -===== - -- Adapt to the `sglang` tool-calling protocol and define tools for sandbox fusion. -- Integrate with the `async-rollout` process, ensuring sandbox fusion tools follow asyncIO conventions. -- Design and implement a basic rate limiter to prevent issues such as 429 errors. - -Non-Goals -========= - -- Training effectiveness is out of scope. -- Observability metrics are not considered. -- Distributed failover and component fault tolerance are not addressed. - -Design Details -============== - -Tool Schema Definition ----------------------- - -- Currently, only code execution is considered, requiring a `code` field in the JSON from the model. -- Only Python code is supported for now, so no `language` parameter is defined. - -.. code-block:: python - - OpenAIFunctionToolSchema( - type="function", - function=OpenAIFunctionSchema( - name="code_interpreter", - description="A tool for executing code.", - parameters=OpenAIFunctionParametersSchema( - type="object", - properties={ - "code": OpenAIFunctionPropertySchema( - type="string", - description="The code to execute.", - enum=None, - ) - }, - required=["code"], - ), - strict=False, - ) - ) - -Configuration Parameters --------------------------- - -+----------------------------+--------------------------------------------------------------+ -| Parameter Name | Description | -+============================+==============================================================+ -| `num_workers` | Number of worker threads/processes per DP to request runner. | -+----------------------------+--------------------------------------------------------------+ -| `rate_limit` | Global limit of concurrent code executions. Default: 10 | -+----------------------------+--------------------------------------------------------------+ -| `default_timeout` | Timeout (in seconds) for each code execution. Default: 30 | -+----------------------------+--------------------------------------------------------------+ -| `default_language` | Default programming language. Default: "python" | -+----------------------------+--------------------------------------------------------------+ -| `enable_global_rate_limit` | Whether to enable global rate limiting. Default: True | -+----------------------------+--------------------------------------------------------------+ -| `sandbox_fusion_url` | URL for the veFaas sandbox execution service | -+----------------------------+--------------------------------------------------------------+ - -Rate Limiting Design ------------------------ - -Objective: - -- Limit the number of inflight requests using a token bucket model. - -- Ensure ordered submission to code runners to avoid starvation due to backoff. - -Design Highlights: - -- Use Ray Global Actor as a singleton distributed counter at cluster level. - -- Semaphore used for counting, with `acquire` and `release` in separate thread pools to preserve order. - -- Use Ray’s cloud-pickle to serialize functions for decoupled `ExecutionWorker`. - -.. code-block:: python - - @ray.remote(concurrency_groups={"acquire": 1,"release": 10}) - class TokenBucketWorker: - def __init__(self, rate_limit: int): - self.rate_limit = rate_limit - self.current_count = 0 - self._semaphore = threading.Semaphore(rate_limit) - - @ray.method(concurrency_group="acquire") - def acquire(self): - self._semaphore.acquire() - self.current_count += 1 - - @ray.method(concurrency_group="release") - def release(self): - self._semaphore.release() - self.current_count -= 1 - - def get_current_count(self): - return self.current_count - - class ExecutionWorker: - def __init__(self, enable_global_rate_limit=True, rate_limit=10): - self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None - - def _init_rate_limit(self, rate_limit): - return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) - - def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: - with ExitStack() as stack: - stack.callback(self.rate_limit_worker.release.remote) - ray.get(self.rate_limit_worker.acquire.remote()) - try: - return fn(*fn_args, **fn_kwargs) - except Exception as e: - logger.warning(f"Error when executing code: {e}") - - def init_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode=PoolMode.ThreadMode): - if mode == PoolMode.ThreadMode: - return ray.remote(ExecutionWorker).options(max_concurrency=num_workers).remote( - enable_global_rate_limit=enable_global_rate_limit, - rate_limit=rate_limit - ) - else: - raise NotImplementedError("Process mode is not implemented yet") - -Tool Implementation -------------------- - -- Use `instance_id` to identify requests across multiple dialogue rounds. - -- Use `execution_pool` to implement async invocation. - -- Cleanup state after rollout completion. - -.. code-block:: python - - class SandboxFusionTool(BaseTool): - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - ... - self.execution_pool = init_execution_pool(...) - ... - - async def create(self, instance_id: Optional[str] = None, ...): - ... - - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]: - code = parameters.get("code", "") - timeout = parameters.get("timeout", self.default_timeout) - language = parameters.get("language", self.default_language) - if not isinstance(code, str): - code = str(code) - - result = await self.execution_pool.execute.remote(self.execute_code,instance_id,code,timeout,language) - self._instance_dict[instance_id]["reward"].append(result.strip()) - - return result, result, {} - - def execute_code(self,instance_id,code,timeout=30,language="python"): - result_status, metadata = _process_single_case(0, None, None,self.sandbox_fusion_url, code, timeout, language) - # we should always expect this since we don't have correct answer - if metadata["run_status"] == "Finished": - actual_output = metadata["stdout"] if metadata["stdout"] is not None else "" - return actual_output - else: - return "no stdout here" - - async def calc_reward(self, instance_id: str, ...): - ... - - async def release(self, instance_id: str, ...): - ... - -Test Plan -========= - -Unit Tests ----------- - -- **test_tools_registration**: Test tool registration and initialization. -- **test_rollout_req_creation**: Validate that `AsyncRolloutReq` is built correctly. -- **test_over_size_case**: Ensure rollout terminates early when exceeding `max_seq_len`. -- **test_tool_call_basic_case**: Mock `sglang` output, validate tool call and result. -- **test_tool_call_batch_case**: Test batch processing of tool calls. -- **test_basic_multi_process_init**: Validate Ray global actor behaves as singleton. -- **TestSingleNodeRateLimiterCase**: Verify rate limiter works in single-node mode. -- **test_rotten_execution**: Ensure rate limiter recovers from function errors. -- **TestMultiNodeRateLimiterCase**: Verify behavior in multi-node environments. - -e2e Tests ----------- -we provide e2e test scripts in `tests/special_e2e` folder, named `tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh` - -by setting 'trainer.rollout_data_dir' you can dump the rollout data to local disk. here is an sample taken from the rollout data: - -.. code-block:: python - - { - "input": " - - system\nYou are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### `.\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n{\"type\": \"function\", \"function\": {\"name\": \"code_interpreter\", \"description\": \"A tool for executing code.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"code\": {\"type\": \"string\", \"description\": \"The code to execute.\", \"enum\": null}}, \"required\": [\"code\"]}, \"strict\": false}}\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n\n - - user\nJohn gets a bonus that's the same percentage every year. Last year he made $100,000 and got a $10,000 bonus. This year he makes $200,000. How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`.\n - - assistant\n", - "output": "assistant\n\nOkay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together.\n\nFirst, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year.\n\nLast year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that.\n\n10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year.\n\nThis year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy.\n - - \n{\"name\": \"code_interpreter\", \"arguments\": {\"code\": \"total_pay_this_year = 200000\\nbonus_percentage = 10 / 100\\nbonus_this_year = total_pay_this_year * bonus_percentage\\ntotal_income_this_year = total_pay_this_year + bonus_this_year\\nprint(total_income_this_year)\", \"executes\": \"True\"}}\n\n - - tool\n220000.0\n\n - - assistant\n\nSo the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000.\n\nI need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters.\n\n\n#### 220000.0", - "score": 0, - "step": 1 - } - -here is the readable format version: - -.. code-block:: python - - [system] - - You are a math expert. You are given a question and you need to solve it step by step. Reasoning step by step before any tool call. You should use the `calc_gsm8k_reward` tool after step by step solving the question, before generate final answer at least once and refine your answer if necessary. Put your final answer in the format of `#### `. - - # Tools - - You may call one or more functions to assist with the user query. - - You are provided with function signatures within XML tags: - - {"type": "function", "function": {"name": "code_interpreter", "description": "A tool for executing code.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to execute.", "enum": null}}, "required": ["code"]}, "strict": false}} - - - For each function call, return a json object with function name and arguments within XML tags: - - {"name": , "arguments": } - - - [user] - - John gets a bonus that's the same percentage every year. Last year he made $100,000 and got a $10,000 bonus. This year he makes $200,000. How much will John make this year when adding both his total pay and bonus together? Let's think step by step and output the final answer after `####`. - - [assistant] - - - Okay, let's see. John gets a bonus that's the same percentage every year. Last year, he made $100,000 and got a $10,000 bonus. This year, he makes $200,000. The question is asking how much he'll make this year when adding both his total pay and bonus together. - - First, I need to figure out the percentage bonus he gets. Since the bonus is the same percentage each year, if I can find that percentage based on last year's details, I can apply it to this year's earnings to find this year's bonus. Then, I'll add that bonus to his total pay this year. - - Last year, he earned $100,000 and got a $10,000 bonus. To find the percentage, I can divide the bonus by his total earnings. So, 10,000 divided by 100,000. Let me calculate that. - - 10,000 / 100,000 equals 0.1, which is 10% as a percentage. So, his bonus is 10% of his pay each year. - - This year, he makes $200,000. So, the bonus this year would be 10% of 200,000. Let me compute that using the code interpreter to ensure accuracy. - - {"name": "code_interpreter", "arguments": {"code": "total_pay_this_year = 200000\nbonus_percentage = 10 / 100\nbonus_this_year = total_pay_this_year * bonus_percentage\ntotal_income_this_year = total_pay_this_year + bonus_this_year\nprint(total_income_this_year)", "executes": "True"}} - - - [tool] - - 220000.0 - - [assistant] - - - So the code shows that John's total income this year would be $220,000. This is consistent with the 10% bonus on last year's pay of $100,000, which was $10,000. Therefore, this year's bonus of $20,000 added to his total pay of $200,000 results in a total income of $220,000. - - I need to verify that there are no other factors at play. The problem states it's a constant percentage bonus each year, so there's no mention of changing percentages or other variables. Therefore, applying a straightforward percentage of 10% on this year's partial salary amount would result in $20,000. Adding that to his total pay gives $220,000. The calculation is correct and aligns with the question's parameters. - - - #### 220000.0 diff --git a/docs/sglang_multiturn/search_tool_example.rst b/docs/sglang_multiturn/search_tool_example.rst deleted file mode 100644 index cbbdeb0d0..000000000 --- a/docs/sglang_multiturn/search_tool_example.rst +++ /dev/null @@ -1,264 +0,0 @@ -======================= -Search Tool Integration -======================= - -Last updated: 05/30/2025. - -Introduction ------------- -- We have added a search tool calling function to Multi-Turn RL, enabling the model to initiate retrieval requests during Actor rollout and directly use retrieval results for training. **We support using a local dense retriever as the retrieval tool, as well as integrating with your own local retrieval engine.** - - - -Quick Reproduction ------------------- - -Create a New Docker Container -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: bash - - docker run \ - -it \ - --shm-size 32g \ - --gpus all \ - -v {Huggingface-Cache-Path}:/root/.cache \ - --ipc=host \ - --network=host \ - --privileged \ - --name sglang_{your-name} \ - lmsysorg/sglang:dev \ - /bin/zsh - -If you need to restart after exiting the container: - -.. code:: bash - - docker start -i sglang_{your-name} - -Update Python and Configure the Virtual Environment using uv -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code:: bash - - apt update - apt install -y python3.10 python3.10-venv - - # Create a virtual environment - python3 -m venv ~/.python/verl-multiturn-rollout - - # Activate the virtual environment - source ~/.python/verl-multiturn-rollout/bin/activate - - # Install uv - python3 -m pip install uv - -Install verl Upstream -~~~~~~~~~~~~~~~~~~~~~ - -.. code:: bash - - cd ~ - git clone https://github.com/volcengine/verl.git - cd verl - - # Install verl - python3 -m uv pip install . - python3 -m uv pip install -r ./requirements_sglang.txt - - # Manually install flash-attn - python3 -m uv pip install wheel - python3 -m uv pip install packaging - python3 -m uv pip install flash-attn --no-build-isolation --no-deps - -Set Up a Local Retrieval Engine -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If you are using your own local retrieval service, you can skip this -step. We chose the local dense retriever provided in the search-R1 -example; detailed instructions are in the `searchR1 -docs `__. -In brief: - -- The GPU version offers higher accuracy and speed; each GPU uses about - 5–7 GB of memory. -- The CPU version can be used for simple testing but has lower - retrieval precision, which will degrade training performance. See the - `retriever - documentation `__ - in search-R1 for details. -- Recommend using Conda to install faiss-gpu=1.8.0; venv may cause errors. - -**Note**: To start both the training process and the local retrieval -service, we launch two separate Python environments. The training uses -uv in the verl-multiturn-rollout environment, while the retriever uses -conda to install ``faiss-gpu``. - -.. code:: bash - - # Download the Miniconda installer script - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh - - # Install to $HOME/miniconda3 in batch mode - bash ~/miniconda.sh -b -p $HOME/miniconda3 - - # Activate conda (only in the current shell) - eval "$($HOME/miniconda3/bin/conda shell.bash hook)" - - # (Optional) Add conda to your default shell startup - conda init - - # Reload shell config - source ~/.bashrc - - # Create and activate the retriever environment with Python 3.10 - conda create -n retriever python=3.10 -y - conda activate retriever - - # Install PyTorch (with GPU support) and related libraries - conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia -y - - # Install other Python packages - pip install transformers datasets pyserini huggingface_hub - - # Install the GPU version of faiss - conda install faiss-gpu=1.8.0 -c pytorch -c nvidia -y - - # Install the API service framework - pip install uvicorn fastapi - -Download the Indexing and Corpus -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The local retrieval files are large—prepare sufficient disk space. -Downloading is about 60–70 GB, and uncompressed takes about 132 GB: - -.. code:: bash - - conda activate retriever - - save_path=/the/path/to/save - python examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py --save_path $save_path - cat $save_path/part_* > $save_path/e5_Flat.index - gzip -d $save_path/wiki-18.jsonl.gz - -Start the Local flat e5 Retrieval Server -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -1. The first startup will download models and load the index. -2. Apart from the download, startup takes about 1–2 minutes. -3. After startup, each GPU uses about 5–7 GB of memory, leaving the rest - for multi-turn RL training. - -.. code:: bash - - conda activate retriever - - index_file=$save_path/e5_Flat.index - corpus_file=$save_path/wiki-18.jsonl - retriever_name=e5 - retriever_path=intfloat/e5-base-v2 - - python examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py \ - --index_path $index_file \ - --corpus_path $corpus_file \ - --topk 3 \ - --retriever_name $retriever_name \ - --retriever_model $retriever_path \ - --faiss_gpu - -Set Up WANDB_API_KEY -~~~~~~~~~~~~~~~~~~~~ - -.. code:: bash - - export WANDB_API_KEY={YOUR_WANDB_API_KEY} - - # Define a timestamp function - function now() { - date '+%Y-%m-%d-%H-%M' - } - -**Preprocess the Dataset** -~~~~~~~~~~~~~~~~~~~~~~~~~~ - - **Note:** The following data processing and training commands must be - run in the verl-multiturn-rollout environment. - -.. code:: bash - - python3 examples/data_preprocess/preprocess_search_r1_dataset.py - -Testing on 8 x H20 -~~~~~~~~~~~~~~~~~~ - -.. code:: bash - - # Ensure the now() function is defined - # Create a logs directory - mkdir -p logs - - # Set GPUs and run with a suitable log path - export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - nohup bash examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh \ - trainer.experiment_name=qwen2.5-3b-it_rm-searchR1-like-sgl-multiturn-$(now) \ - > logs/searchR1-like$(now).log 2>&1 & - -Custom Search Configuration ---------------------------- - -To enable multi-turn reasoning, set the following fields in your config: - -.. code:: yaml - - actor_rollout_ref: - rollout: - name: "sglang" - multi_turn: - enable: True - -You must specify ``retrieval_service_url`` in ``examples/sglang_multiturn/config/tool_config/search_tool_config.yaml``, and properly configure concurrency. For more details on concurrency, refer to the Sandbox Fusion example: - -.. code:: yaml - - tools: - - class_name: verl.tools.search_tool.SearchTool - config: - retrieval_service_url: http://127.0.0.1:8000/retrieve - num_workers: 120 - rate_limit: 120 - timeout: 30 - -The retriever input/output formats are as follows. If your service -parameters match, only modify ``retrieval_service_url``. You can also -customize in ``search_r1_like_utils.py``. - -.. code:: python - - Input format: - { - "queries": ["What is Python?", "Tell me about neural networks."], - "topk": 3, - "return_scores": true - } - - Output format (when return_scores=True, similarity scores are returned): - { - "result": [ - [ # Results for each query - { - "document": doc, "score": score - }, - # ... more documents - ], - # ... results for other queries - ] - } - -Notes ------ - -1. The total training time is about 27 hours; meanwhile, the validation - dataset is very large (51 k), and each validation takes about 6000 s. - (Therefore, ``val_before_train=False`` by default) diff --git a/docs/single_controller.rst b/docs/single_controller.rst deleted file mode 100644 index d12177854..000000000 --- a/docs/single_controller.rst +++ /dev/null @@ -1,336 +0,0 @@ -The Design of ``verl.single_controller`` -============================================== - -Last updated: 05/21/2025. - -**Author:**\ `Wang Zhang `__ - -Preface -------- - -We prepared this document for developers of ``verl``, particularly those -interested in understanding or contributing to the -``verl.single_controller`` module. It is not intended for end users, but -for contributors seeking to understand the architectural rationale and -internal mechanics. - --------------- - -Origin ------- - -The ``single_controller`` module originated from a request I received — -to adapt a toy single-process RLHF script into a distributed system with -minimal changes, while maintaining ease of debugging. - -Common practice — such as using PyTorch’s Distributed Data Parallel -(DDP) — typically involves wrapping ``nn.Module`` and launching multiple -processes that execute the same function under different ranks. However, -this approach presents two main limitations in the context of -distributed RLHF: - Difficulty representing multiple DAGs as required by -PPO; - Difficulty inspecting intermediate tensors during training. - -To maintain debuggability, we opted for a different approach — breaking -the training loop into well-defined stages like ``generate_sequences``, -``compute_advantages``, and so on. - -We selected `Ray `__ as the initial backend for -``verl`` due to its ability to expose Python class methods as RPC -endpoints. However, Ray’s default model only supports **one method call, -one RPC**, while training LLMs typically requires coordination across -multiple processes. - -To hide this multi-Ray actors invocation for a single method from users, -we introduced the following components: - -- ``WorkerGroup`` – manages a group of remote workers and provides - a unified interface for multi-process distributed computation; -- ``ResourcePool`` – binds computational resources to worker - processes; -- ``ClassWithArgs`` – enables delayed remote instantiation with - specified initialization arguments. - --------------- - -A Running Example: ``generate_sequences`` ------------------------------------------ - -To illustrate the design, we walk through how the ``generate_sequences`` -method in the ``ActorRolloutRefWorker`` class is registered and invoked -across distributed workers. - --------------- - -Step 1: Register with a Decorator -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The first step is to define the ``generate_sequences`` and decorate it -with ``@register`` as it will be called in driver script. - -**Source:** -`fsdp_workers.py `__ - -.. code:: python - - class ActorRolloutRefWorker(Worker): - ... - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def generate_sequences(self, prompts: DataProto): - prompts = prompts.to(torch.cuda.current_device()) - ... - -The ``@register`` decorator adds metadata to the ``generate_sequences`` -method. Currently, it doesn’t alter functionality, but attaches -attributes via a magic key (``MAGIC_ATTR``): - -**Source:** -`decorator.py `__ - -.. code:: python - - def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): - ... - def decorator(func): - @wraps(func) - def inner(*args, **kwargs): - if materialize_futures: - args, kwargs = _materialize_futures(*args, **kwargs) - return func(*args, **kwargs) - - attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} - setattr(inner, MAGIC_ATTR, attrs) - return inner - - return decorator - -As the code shows, values of ``dispatch_mode``, ``execute_mode`` and -``blocking`` is attached the ``generate_sequences`` method. - --------------- - -Step 2: Binding During Initialization -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -These attached attributes are extracted and utilized when -``ActorRolloutRefWorker``, wrapped in a ``RayClassWithArgs``, is passed -into a ``RayWorkerGroup``. - -**Source:** -`main_generation.py `__ - -.. code:: python - - ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") - resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) - wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) - -During the -`initialization `__ -of ``RayWorkerGroup``, two key steps occur: - -1. Worker instances (Ray actors) are created: - `RayWorkerGroup._init_with_resource_pool `__ -2. Methods decorated with ``@register`` are bound to ``RayWorkerGroup``: - `RayWorkerGroup._bind_worker_method `__ - -.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/worker_group_init.png?raw=true - :alt: initialization_and_binding_of_worker_group - - initialization_and_binding_of_worker_group - -The binding procedure is the heart of ``verl.single_controller``. - -**Key function:** -`WorkerGroup._bind_worker_method `__ - -.. code:: python - - def _bind_worker_method(self, user_defined_cls, func_generator): - ... - for method_name in dir(user_defined_cls): - try: - method = getattr(user_defined_cls, method_name) - assert callable(method) - except Exception: - continue # Skip properties - <<>> - -When a method has the ``MAGIC_ATTR``, the attributes set by -``@register`` are extracted: - -.. code:: python - - <<>> - if hasattr(method, MAGIC_ATTR): - attribute = getattr(method, MAGIC_ATTR) - dispatch_mode = attribute["dispatch_mode"] - execute_mode = attribute["execute_mode"] - blocking = attribute["blocking"] - - <<>> - -As show in the flow chart above, these attributes are fed into -``func_generator``. However, ``func_generator`` takes ``method_name``, -``dispatch_fn``, ``collect_fn``, ``execute_fn``, ``blocking``. We need -to find the corresponding ``dispatch_fn`` and ``collect_fn`` associated -with the ``dispatch_mode`` (``DP_COMPUTE_PROTO``) from -`DISPATCH_MODE_FN_REGISTRY `__: - -.. code:: python3 - - DISPATCH_MODE_FN_REGISTRY = { - Dispatch.ONE_TO_ALL: { - "dispatch_fn": dispatch_one_to_all, - "collect_fn": collect_all_to_all, - }, - ... - Dispatch.DP_COMPUTE_PROTO: { - "dispatch_fn": dispatch_dp_compute_data_proto, - "collect_fn": collect_dp_compute_data_proto, - }, - ... - } - -Similarly, the ``execute_fn`` is selected by ``execute_mode`` and -extracted by: - -.. code:: python - - <<>> - # get execute_fn_name - execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) - wg_execute_fn_name = execute_mode["execute_fn_name"] - - # get execute_fn from string - try: - execute_fn = getattr(self, wg_execute_fn_name) - assert callable(execute_fn), "execute_fn must be callable" - except Exception: - print(f"execute_fn {wg_execute_fn_name} is invalid") - raise - <<>> - -In this ``generate_sequences`` cases: - -``dispatch_mode = Dispatch.DP_COMPUTE_PROTO`` - -``dispatch_fn = dispatch_dp_compute_data_proto`` - -``collect_fn = collect_dp_compute_data_proto`` - -``execute_fn = RayWorkerGroup.execute_all`` - -ONE_TO_ALL v.s. DP_COMPUTE_PROTO -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -``dispatch_mode`` is associated with a ``dispatch_fn`` and a -``collect_fn``. As the name implies, ``dispatch_fn`` processes the input -arguments in ``WorkerGroup`` and generate a batch (list) of input -arguments, each of which will be fed into a worker attached to the -``WorkerGroup``. - -``dispatch_fn`` of ``ONE_TO_ALL`` is -`dispatch_one_to_all `__, -which just duplicates all the input arguments into N replicas, where N -equals the number of Workers attached to the ``worker_group``: - -.. code:: python - - def dispatch_one_to_all(worker_group, *args, **kwargs): - args = tuple([arg] * worker_group.world_size for arg in args) - kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} - return args, kwargs - -``dispatch_fn`` of ``DP_COMPUTE_PROTO`` is -`dispatch_dp_compute_data_proto `__, -which uses ``DataProto.chunk`` to split a large ``DataProto`` into N -smaller ``DataProto``, where N equals the world_size (number of the -workers) of the ``worker_group``: - -.. code:: python - - def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): - from verl.single_controller.base.worker_group import WorkerGroup - - assert isinstance(worker_group, WorkerGroup) - # Note: enable auto padding for dp compute DatapProto - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding( - worker_group.world_size, - *args, - **kwargs, - ) - return splitted_args, splitted_kwargs - -The ``collect_fn`` follows the same pattern and process a batch (list) -of returned value from all workers of a ``WorkerGroup`` and merge it -into a list as ``collect_all_to_all`` does or a large ``DataProto`` as -``collect_dp_compute_data_proto`` does. - -Finally, a new method is dynamically generated using ``func_generator`` -and added to the ``WorkerGroup`` instance: - -.. code:: python - - <<>> - # bind a new method to the RayWorkerGroup - func = func_generator( - self, - method_name, - dispatch_fn=dispatch_fn, - collect_fn=collect_fn, - execute_fn=execute_fn, - blocking=blocking, - ) - - try: - setattr(self, method_name, func) - method_names.append(method_name) - except Exception as e: - raise ValueError(f"Fail to set method_name {method_name}") from e - -This makes the method invocable via the ``WorkerGroup`` interface. - --------------- - -Step 3: Call Chain -~~~~~~~~~~~~~~~~~~ - -All the machinery above ensures that distributed calls feel identical to -single-process ones. In the original single-process script, the code -looks like: - -.. code:: python - - rollout = Rollout() - rollout.generate_sequences(batch) - -With ``verl``, the multiprocess program becomes: - -.. code:: python - - rollout = RayWorkerGroup(resource_pool=[4], RayClassWithArgs(Rollout)) - rollout.generate_sequences(batch) - -.. figure:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/call_generate_sequences.png?raw=true - :alt: call_chain_of_generate_sequences - - call_chain_of_generate_sequences - -Behind this simple call: - ``dispatch_fn`` splits input across workers - -``execute_fn`` performs the actual remote invocation - ``collect_fn`` -gathers the results - -All of this is abstracted away, enabling developers to write distributed -code with minimal changes to their existing logic. - --------------- - -Beyond RL Post-Training: Generalizing ``verl.single_controller`` ----------------------------------------------------------------- - -The ``verl.single_controller`` module generalizes well beyond -reinforcement learning. It provides a clean abstraction to batch-process -remote method calls, with automatic input/output handling. - -By minimizing the gap between single-process and multi-process scripts, -``verl.single_controller`` opens the door to distributed computing in -broader domains — not limited to RL post-training. - -We hope this design inspires more examples and extensions from the -community. diff --git a/docs/start/agentic_rl.rst b/docs/start/agentic_rl.rst deleted file mode 100644 index 60af79f5f..000000000 --- a/docs/start/agentic_rl.rst +++ /dev/null @@ -1,125 +0,0 @@ -Agentic RL Training -=================== - -Last updated: 07/15/2025. - -Overview ----------- -The goal of Agentic RL is to improve the performance of backend models from reinforcement learning to the Agent. During the training process, a series of features are developed: - -1. Server-based asynchronous rollout -2. Multi-turn conversations and tool calls -3. LangGraph-based Agent - - -This document explains the system principles and usage involved to help users implement Agentic RL. - - -Server-based Asynchronous Rollout ---------------------------------- - -Since Agents need to interact with the environment through various tool calls, in order to avoid GPU idling while waiting for tool call return results, an asyncio based co-routing mechanism is utilized to execute each rollout requests asynchronously, thereby improving training performance. To support asynchronous rollout, the inference engine (server) and the agent (client) are architecturally separated, implementing a server-based system with the following objectives: - -1. Enabling load balancing mechanisms to balance loads across multiple GPUs and reduce the impact of long-tail requests on performance. For this purpose, scheduling capabilities in stream mode (recipe\stream_mode) are implemented as a recipe. -2. Preventing agent specific features such as tracing from affecting the inference engine. - -System Architecture -~~~~~~~~~~~~~~~~~~~ - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/agent_loop.png?raw=true - -System Components -~~~~~~~~~~~~~~~~~ - -+--------------------------+----------------------------------------------------------------------------+ -| Component | Role | -+==========================+============================================================================+ -| AgentLoop | Client, implements Agent functions | -+--------------------------+----------------------------------------------------------------------------+ -| AsyncLLMServerManager | Inference gateway, provides generate interface for AgentLoop | -+--------------------------+----------------------------------------------------------------------------+ -| AsyncServer | Server, each instance is connected to one DP group of the inference engine | -+--------------------------+----------------------------------------------------------------------------+ - -**"generate" Interface** - -The "generate" function based on ray actor is used between the Client and Server instead of the standard chat completion API. This is because the conversion between tokens and text can be irreversible. For example, the token converted from "" will be different from that generated by the LLM. During the training phase, it is necessary to strictly use the tokens generated by LLM inference to avoid inaccurate in computing advantage, which may affect model performance. Having the Server provide a token-based API helps the Client maintain the relationship between the text generated by tool calls and the tokens returned by the LLM, so as to output correct tokens for training. - - -**Inference Engine Adaptation** -AsyncServer uniformly provides a generate function to the upper layer, with separate implementations for SGLang and vLLM to hide underlying differences: - -1. The SGLang AsyncServer uses the async_generate interface of the SGLang engine, which is located on the first GPU of each TP group. Therefore, AsyncServer needs to remotely call async_generate through ray actor. -2. The vLLM AsyncServer uses the generate interface of the vLLM engine, which can communicate with the GPUs in the TP group through ZMQ and can be directly called in AsyncServer. - - -Usage Example -~~~~~~~~~~~~~ - -Follow :doc:`GSM8K example<../examples/gsm8k_example>` to prepare the dataset and model checkpoints. -This example uses the sglang inference engine by default, and you can also modify rollout_name to use vllm. - -.. code-block:: bash - - bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh - - -Multi-turn Conversations and Tool Calls ---------------------------------------- - -Follow :doc:`Multi-turn Rollout Support<../sglang_multiturn/multiturn>` to prepare tool and configuration files. - -The Tool Agent Loop has an additional requirement: adding an "agent_name" field to the dataset. During rollout, it will choose to use tool_agent_loop or single_turn_agent (default) based on this field. - -Usage Example -~~~~~~~~~~~~~ - -.. code-block:: bash - - # install mlflow to view toolcall and llm trace - pip install mlflow - - # This will download and preprocess the GSM8K dataset into ~/data/gsm8k/ and add the "agent_name" field. - bash examples/data_preprocess/gsm8k_tool_agent_loop.py - - # Start training with tool calls and enabled mlflow based trace helping to debug the rollout details - bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh - - # When training is done, start a mlflow server to view trace - mlflow ui -h 0.0.0.0 -p 5000 --backend-store-uri sqlite:////tmp/mlruns.db - - # then you can open http://:5000 from browser to view trace - - -Note: During training, because the model may sometimes fail to generate correct toolcall tags, an error message "Failed to decode tool call" will be output to the console, which does not indicate an abnormality in training. - - -Follow :doc:`Rollout trace<../advance/rollout_trace>` to known more about trace feature. - - - -Agent Framework ---------------- - -System Architecture -~~~~~~~~~~~~~~~~~~~ - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/langgraph_agent.png?raw=true - -System Components -~~~~~~~~~~~~~~~~~ - -+--------------------------+-----------------------------------------------------------------------------------------------+ -| Component | Role | -+==========================+===============================================================================================+ -| ChatModel | LLM object of LangChain, used to adapt to the “generate” api provided by AsyncLLMServerManager| -+--------------------------+-----------------------------------------------------------------------------------------------+ -| RectAgentLoop | Agent adaptation layer, which by default supports a naive LangGraph Agentic. | -| | New classes can be derived to support user-defined Agents, and the run function needs to be | -| | implemented to complete Agent calls. | -+--------------------------+-----------------------------------------------------------------------------------------------+ -| AsyncServer | Server, each instance is connected to one DP group of the inference engine. | -+--------------------------+-----------------------------------------------------------------------------------------------+ - - -Follow doc "recipe/langgraph_agent/example/README.md" for more details. \ No newline at end of file diff --git a/docs/start/install.rst b/docs/start/install.rst deleted file mode 100644 index 12c9c3531..000000000 --- a/docs/start/install.rst +++ /dev/null @@ -1,341 +0,0 @@ -Installation -============ - -Requirements ------------- - -- **Python**: Version >= 3.9 -- **CUDA**: Version >= 12.1 - -verl supports various backends. Currently, the following configurations are available: - -- **FSDP** and **Megatron-LM** (optional) for training. -- **SGLang**, **vLLM** and **TGI** for rollout generation. - -Choices of Backend Engines ----------------------------- - -1. Training: - -We recommend using **FSDP** backend to investigate, research and prototype different models, datasets and RL algorithms. The guide for using FSDP backend can be found in :doc:`FSDP Workers<../workers/fsdp_workers>`. - -For users who pursue better scalability, we recommend using **Megatron-LM** backend. Currently, we support `Megatron-LM v0.12.1 `_. The guide for using Megatron-LM backend can be found in :doc:`Megatron-LM Workers<../workers/megatron_workers>`. - - -2. Inference: - -For inference, vllm 0.8.3 and later versions have been tested for stability. We recommend turning on env var `VLLM_USE_V1=1` for optimal performance. - -For SGLang, refer to the :doc:`SGLang Backend<../workers/sglang_worker>` for detailed installation and usage instructions. SGLang rollout is under extensive development and offers many advanced features and optimizations. We encourage users to report any issues or provide feedback via the `SGLang Issue Tracker `_. - -For huggingface TGI integration, it is usually used for debugging and single GPU exploration. - -Install from docker image -------------------------- - -We provide pre-built Docker images for quick setup. And from this version, -we utilize a new image release hierarchy for productivity and stability. - -The image types are divided into three large categories: - -- **Base Image**: Without inference and training frameworks, only basic dependencies are installed. - Can directly install vllm or SGLang on top of it, without need of reinstall torch or CUDA. -- **Application Image**: Stable version with inference and training frameworks installed. -- **Community Image**: Unstable version with the latest frameworks and features. - -The first two types of images are hosted on dockerhub `verlai/verl `_ repository, while the preview images are hosted on community repository. - -.. note:: - - The image versions are mapped with verl releases, for example, image with tag ``verl0.4`` is built for verl release ``v0.4.x``. - -Base Image -:::::::::: - -The stable base image is ``verlai/verl:base-verl0.4-cu124-cudnn9.8-torch2.6-fa2.7.4``. The installed package versions can be found from tags, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.base``. - -The base images for preview are ``verlai/verl:base-verl0.5-cu126-cudnn9.8-torch2.7.1-fa2.8.0` and ``verlai/verl:base-verl0.5-preview-cu128-cudnn9.8-torch2.7.1-fa2.8.0`` with different CUDA versions. From verl0.5, images are built with `Deep-EP `_ for efficient EP communication. - -The update of base image is not frequent, and the app image can be built on top of it without reinstalling base packages. - -Application Image -::::::::::::::::: - -From this version, we divide images built for vLLM and SGLang as the divergence of dependent packages like FlashInfer. - -There are four types of application images available: - -- **vLLM with FSDP and Megatron**: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1``, with Deep-EP support: ``verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.1-deepep``. -- **SGLang with FSDP and Megatron**: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1`` (need vLLM support, but can have some package conflicts), with Deep-EP support: ``verlai/verl:app-verl0.4-sglang0.4.6.post5-vllm0.8.5-mcore0.12.1-deepep``. -- **Preview version of SGLang with FSDP and Megatron, CUDA 12.6**: ``verlai/verl:app-verl0.5-sglang0.4.8-mcore0.12.1`` -- **Preview version of SGLang with FSDP and Megatron, CUDA 12.8**: ``verlai/verl:app-preview-verl0.5-sglang0.4.8-mcore0.12.1`` - -The latest vLLM support is coming soon. - -Docker images with Megatron backends are runnable with large language model like ``Qwen/Qwen3-235B-A22B``, ``deepseek-ai/DeepSeek-V3-0324`` post-training. Refer to the :doc:`Large Language Model Post-Training documentation<../perf/dpsk>` for more details. - -Application images can be updated frequently, and the Dockerfile can be found in ``docker/verl[version]-[packages]/Dockerfile.app.[frameworks]``. Based on the base image, it is easy to build your own application image with the desired inference and training frameworks. - -Community Image -::::::::::::::: - -Community images are provided by the community, including the latest versions of vLLM and SGLang, and may include experimental features or configurations. And also works for other hardwares or platforms like AMD GPUs with ROCM or AWS EFA and Sagemaker. - -For latest vLLM with FSDP, please refer to `hiyouga/verl `_ repository and the latest version is ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``. - -For latest SGLang with FSDP, please refer to `ocss884/verl-sglang `_ repository and the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group. - -See files under ``docker/`` for NGC-based image or if you want to build your own. - -Note that For aws instances with EFA net interface (Sagemaker AI Pod), -you need to install EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa`` - -Installation from Docker -:::::::::::::::::::::::: - -After pulling the desired Docker image and installing desired inference and training frameworks, you can run it with the following steps: - -1. Launch the desired Docker image and attach into it: - -.. code:: bash - - docker create --runtime=nvidia --gpus all --net=host --shm-size="10g" --cap-add=SYS_ADMIN -v .:/workspace/verl --name verl sleep infinity - docker start verl - docker exec -it verl bash - - -2. If you use the images provided, you only need to install verl itself without dependencies: - -.. note:: - - # install the nightly version (recommended) - git clone https://github.com/volcengine/verl && cd verl - pip3 install --no-deps -e . - -[Optional] If you hope to switch between different frameworks, you can install verl with the following command: - -.. note:: - - # install the nightly version (recommended) - git clone https://github.com/volcengine/verl && cd verl - pip3 install -e .[vllm] - pip3 install -e .[sglang] - - -Install from custom environment ---------------------------------------------- - -We recommend to use docker images for convenience. However, if your environment is not compatible with the docker image, you can also install verl in a python environment. - - -Pre-requisites -:::::::::::::: - -For training and inference engines to utilize better and faster hardware support, CUDA/cuDNN and other dependencies are required, -and some of the dependencies are easy to be overridden when installing other packages, -so we put them in the :ref:`Post-installation` step. - -.. note:: - - The installation steps below are recommended configurations for the latest version of verl. - If you are trying to customize your own environment, please ignore the strict constraints. - -We need to install the following pre-requisites: - -- **CUDA**: Version >= 12.4 -- **cuDNN**: Version >= 9.8.0 -- **Apex** - -CUDA above 12.4 is recommended to use as the docker image, -please refer to `NVIDIA's official website `_ for other version of CUDA. - -.. code:: bash - - # change directory to anywher you like, in verl source code directory is not recommended - wget https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb - dpkg -i cuda-repo-ubuntu2204-12-4-local_12.4.1-550.54.15-1_amd64.deb - cp /var/cuda-repo-ubuntu2204-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ - apt-get update - apt-get -y install cuda-toolkit-12-4 - update-alternatives --set cuda /usr/local/cuda-12.4 - - -cuDNN can be installed via the following command, -please refer to `NVIDIA's official website `_ for other version of cuDNN. - -.. code:: bash - - # change directory to anywher you like, in verl source code directory is not recommended - wget https://developer.download.nvidia.com/compute/cudnn/9.8.0/local_installers/cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb - dpkg -i cudnn-local-repo-ubuntu2204-9.8.0_1.0-1_amd64.deb - cp /var/cudnn-local-repo-ubuntu2204-9.8.0/cudnn-*-keyring.gpg /usr/share/keyrings/ - apt-get update - apt-get -y install cudnn-cuda-12 - -NVIDIA Apex is required for Megatron-LM and FSDP training. -You can install it via the following command, but notice that this steps can take a very long time. -It is recommended to set the ``MAX_JOBS`` environment variable to accelerate the installation process, -but do not set it too large, otherwise the memory will be overloaded and your machines may hang. - -.. code:: bash - - # change directory to anywher you like, in verl source code directory is not recommended - git clone https://github.com/NVIDIA/apex.git && \ - cd apex && \ - MAX_JOB=32 pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ - - -Install dependencies -:::::::::::::::::::: - -.. note:: - - We recommend to use a fresh new conda environment to install verl and its dependencies. - - **Notice that the inference frameworks often strictly limit your pytorch version and will directly override your installed pytorch if not paying enough attention.** - - As a countermeasure, it is recommended to install inference frameworks first with the pytorch they needed. For vLLM, if you hope to use your existing pytorch, - please follow their official instructions - `Use an existing PyTorch installation `_ . - - -1. First of all, to manage environment, we recommend using conda: - -.. code:: bash - - conda create -n verl python==3.10 - conda activate verl - - -2. Then, execute the ``install.sh`` script that we provided in verl: - -.. code:: bash - - # Make sure you have activated verl conda env - # If you need to run with megatron - bash scripts/install_vllm_sglang_mcore.sh - # Or if you simply need to run with FSDP - USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh - - -If you encounter errors in this step, please check the script and manually follow the steps in the script. - - -Install verl -:::::::::::: - -For installing the latest version of verl, the best way is to clone and -install it from source. Then you can modify our code to customize your -own post-training jobs. - -.. code:: bash - - git clone https://github.com/volcengine/verl.git - cd verl - pip install --no-deps -e . - - -Post-installation -::::::::::::::::: - -Please make sure that the installed packages are not overridden during the installation of other packages. - -The packages worth checking are: - -- **torch** and torch series -- **vLLM** -- **SGLang** -- **pyarrow** -- **tensordict** -- **nvidia-cudnn-cu12**: For Magetron backend - -If you encounter issues about package versions during running verl, please update the outdated ones. - - -Install with AMD GPUs - ROCM kernel support ------------------------------------------------------------------- - -When you run on AMD GPUs (MI300) with ROCM platform, you cannot use the previous quickstart to run verl. You should follow the following steps to build a docker and run it. -If you encounter any issues in using AMD GPUs running verl, feel free to contact me - `Yusheng Su `_. - -Find the docker for AMD ROCm: `docker/Dockerfile.rocm `_ -:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: - -.. code-block:: bash - - # Build the docker in the repo dir: - # docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 . - # docker images # you can find your built docker - FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 - - # Set working directory - # WORKDIR $PWD/app - - # Set environment variables - ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942" - - # Install vllm - RUN pip uninstall -y vllm && \ - rm -rf vllm && \ - git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \ - cd vllm && \ - MAX_JOBS=$(nproc) python3 setup.py install && \ - cd .. && \ - rm -rf vllm - - # Copy the entire project directory - COPY . . - - # Install dependencies - RUN pip install "tensordict<0.6" --no-deps && \ - pip install accelerate \ - codetiming \ - datasets \ - dill \ - hydra-core \ - liger-kernel \ - numpy \ - pandas \ - datasets \ - peft \ - "pyarrow>=15.0.0" \ - pylatexenc \ - "ray[data,train,tune,serve]" \ - torchdata \ - transformers \ - wandb \ - orjson \ - pybind11 && \ - pip install -e . --no-deps - -Build the image -:::::::::::::::::::::::: - -.. code-block:: bash - - docker build -t verl-rocm . - -Launch the container -:::::::::::::::::::::::::::: - -.. code-block:: bash - - docker run --rm -it \ - --device /dev/dri \ - --device /dev/kfd \ - -p 8265:8265 \ - --group-add video \ - --cap-add SYS_PTRACE \ - --security-opt seccomp=unconfined \ - --privileged \ - -v $HOME/.ssh:/root/.ssh \ - -v $HOME:$HOME \ - --shm-size 128G \ - -w $PWD \ - verl-rocm \ - /bin/bash - -If you do not want to root mode and require assign yourself as the user, -Please add ``-e HOST_UID=$(id -u)`` and ``-e HOST_GID=$(id -g)`` into the above docker launch script. - -verl with AMD GPUs currently supports FSDP as the training engine, vLLM and SGLang as the inference engine. We will support Megatron in the future. diff --git a/docs/start/more_resources.rst b/docs/start/more_resources.rst deleted file mode 100644 index aa8cb2a62..000000000 --- a/docs/start/more_resources.rst +++ /dev/null @@ -1,7 +0,0 @@ -More Resources -============== - -Last updated: 06/30/2025. - -- Introduction to verl (`Slides `_) -- verl Code Walkthrough (`Slides `_, `Talk in Chinese `_) diff --git a/docs/start/multinode.rst b/docs/start/multinode.rst deleted file mode 100644 index 9e058055d..000000000 --- a/docs/start/multinode.rst +++ /dev/null @@ -1,591 +0,0 @@ -Multinode Training -================== - -Last updated: 06/10/2025. - -.. _wuxibin89: https://github.com/wuxibin89 - -Author: `Xibin Wu `_, `Yusheng Su `_. - -Manual ------- - -Set up multinode ray cluster -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -1. Start head node with ``ray start --head --dashboard-host=0.0.0.0``, there're 2 address you should care about: - -- GCS address: ``ray start --address=
``, where worker node should connect to. -- Dashboard address: ``
:8265``, where you should submit job to the cluster. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/head.png?raw=true - -2. Start worker node with ``ray start --address=
`` you get above. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/worker.png?raw=true - -3. Now you should see the cluster have 2 nodes with ``ray status``. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/status.png?raw=true - -4. Additionally, you can access dashboard in the browser with the address you get above. - -*Firewall rules maybe need configure to access the dashboard, if there's any trouble, please contact your network administrator.* - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/overview.png?raw=true - -Submit job to ray cluster -~~~~~~~~~~~~~~~~~~~~~~~~~ -1. Submit ray job to cluster with the dashboard address you get above. - -.. code-block:: bash - - ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env=verl/trainer/runtime_env.yaml \ - --no-wait \ - -- \ - python3 -m verl.trainer.main_ppo \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=2 \ - ... - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/submit.png?raw=true - -2. Then you can check the job status with the following commands: - -- ray job list: list all jobs submitted to the cluster. -- ray job logs : query the logs of the job. -- ray job status : query the status of the job. -- ray job stop : request the job to be stopped. - -3. You can also access driver/task/actor logs in ``/tmp/ray/session_latest/logs/``, driver log is ``job-driver-raysubmit_.log``. - -4. We strongly recommend you to view job detail from dashboard in multinode training, because it provide more structure way to view the job information. - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/job.png?raw=true -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/job_detail.png?raw=true - - -Slurm ------ -TBD - -dstack ------- -`dstackai/dstack `_ is an open-source container orchestrator that simplifies distributed training across cloud providers and on-premises environments -without the need to use K8S or Slurm. - -Prerequisite -~~~~~~~~~~~~ -Once dstack is `installed `_, initialize the directory as a repo with ``dstack init``. - -.. code-block:: bash - - mkdir myproject && cd myproject - dstack init - -**Create a fleet** - -Before submitting distributed training jobs, create a `dstack` `fleet `_. - -Run a Ray cluster task -~~~~~~~~~~~~~~~~~~~~~~ - -Once the fleet is created, define a Ray cluster task, e.g. in ``ray-cluster.dstack.yml``: - -.. code-block:: yaml - - type: task - name: ray-verl-cluster - - nodes: 2 - - env: - - WANDB_API_KEY - - PYTHONUNBUFFERED=1 - - CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.2 - commands: - - git clone https://github.com/volcengine/verl - - cd verl - - pip install --no-deps -e . - - pip install hf_transfer hf_xet - - | - if [ $DSTACK_NODE_RANK = 0 ]; then - python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k - python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-7B-Instruct')" - ray start --head --port=6379; - else - ray start --address=$DSTACK_MASTER_NODE_IP:6379 - fi - - # Expose Ray dashboard port - ports: - - 8265 - - resources: - gpu: 80GB:8 - shm_size: 128GB - - # Save checkpoints on the instance - volumes: - - /checkpoints:/checkpoints - -Now, if you run this task via `dstack apply`, it will automatically forward the Ray's dashboard port to `localhost:8265`. - -.. code-block:: bash - - dstack apply -f ray-cluster.dstack.yml - -As long as the `dstack apply` is attached, you can use `localhost:8265` to submit Ray jobs for execution - -Submit Ray jobs -~~~~~~~~~~~~~~~ - -Before you can submit Ray jobs, ensure to install `ray` locally: - -.. code-block:: shell - - pip install ray - -Now you can submit the training job to the Ray cluster which is available at ``localhost:8265``: - -.. code-block:: shell - - $ RAY_ADDRESS=http://localhost:8265 - $ ray job submit \ - -- python3 -m verl.trainer.main_ppo \ - data.train_files=/root/data/gsm8k/train.parquet \ - data.val_files=/root/data/gsm8k/test.parquet \ - data.train_batch_size=256 \ - data.max_prompt_length=512 \ - data.max_response_length=256 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - critic.optim.lr=1e-5 \ - critic.model.path=Qwen/Qwen2.5-7B-Instruct \ - critic.ppo_micro_batch_size_per_gpu=4 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.project_name=ppo_training \ - trainer.experiment_name=qwen-2.5-7B \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=2 \ - trainer.default_local_dir=/checkpoints \ - trainer.save_freq=10 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 2>&1 | tee verl_demo.log \ - trainer.resume_mode=disable - - -For more details on how `dstack` works, check out its `documentation `_. - -How to debug? ---------------------- - - -Ray Distributed Debugger VSCode Extension (Recommended) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -1. Starting with Ray 2.39, Anyscale has introduced the `Ray Distributed Debugger `_ VSCode extension. Follow the extension’s installation instructions, then add your cluster using the dashboard URL you obtained earlier. - - .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/debugger.png?raw=true - :alt: Ray Distributed Debugger VSCode extension screenshot - -2. Prerequisites. - - Ensure the following are installed (see the extension README for more detail): - - - Visual Studio Code - - `ray[default]` >= 2.9.1 - - `debugpy` >= 1.8.0 - - .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/c7098b755ff689859837773a916c857.png?raw=true - :alt: VSCode with Ray prerequisites - -3. Environment Variables. - - To enable post‑mortem debugging, set: - - .. code-block:: bash - - export RAY_DEBUG_POST_MORTEM=1 - - .. admonition:: Note - :class: important - - Be sure to remove any legacy flags before starting Ray: - - - `RAY_DEBUG=legacy` - - `--ray-debugger-external` - -4. Configuring BreakpointsSet up breakpoint() in your code, and submit job to cluster. Then the extension will show the breakpoint information. - - - 1. Insert `breakpoint()` calls into your remote functions. - 2. Submit your job to the cluster. - - The extension will detect active breakpoints and display them in VSCode. - - .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/4ddad74395c79a1402331c0ce73316f.png?raw=true - :alt: Detected breakpoint in VSCode - - **Note:** Breakpoints are only supported inside functions decorated with `@ray.remote`. - -5. Launching the Debugger. - - Run your job directly from the command line (do not use a `launch.json`): - - .. code-block:: bash - - python job.py - -6. Attaching to a Breakpoint. - - Once the process hits the first `breakpoint()`, click the Ray Distributed Debugger icon in the VSCode sidebar to attach the debugger. - - .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/4ddad74395c79a1402331c0ce73316f.png?raw=true - :alt: Attaching VSCode debugger to Ray process - -7. Debugging With Multiple breakpoint(). - - For each subsequent task, first disconnect the current debugger session, then click the extension icon again to attach to the next breakpoint. - - .. image:: https://github.com/aoshen524/verl/blob/main/docs/start/6e83c910a62c82fecb89c6619e001cd.png?raw=true - :alt: Disconnecting and reconnecting the debugger - -Legacy Ray Debugger -~~~~~~~~~~~~~~~~~~~ -1. Ray has a builtin legacy `debugger `_ that allows you to debug your distributed applications. To enable debugger, start ray cluster with ``RAY_DEBUG=legacy`` and ``--ray-debugger-external``. - -.. code-block:: bash - - # start head node - RAY_DEBUG=legacy ray start --head --dashboard-host=0.0.0.0 --ray-debugger-external - # start worker node - RAY_DEBUG=legacy ray start --address='10.124.46.192:6379' --ray-debugger-external - -2. Set up breakpoint in your code, and submit job to cluster. Then run ``ray debug`` to wait breakpoint: - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/legacy.png?raw=true - - -Multi-node training on AMD clusters ---------------------------------------------------------------------------------------- - -If you want to run multi-node training with slurm with Docker/Podman container on AMD Cluster, you can use the following script. - -If you encounter any issues in using AMD GPUs running verl, please contact `Yusheng Su `_. - -.. note:: - 1. You need to use ``podman`` or ``docker`` in the following script. We will release the apptainer script later. - 2. If you want to use ``podman``, you just replace ``docker`` with ``podman`` in the following script. - -The script includes the following steps: - -1. SLURM Configuration -2. Environment Setup -3. Docker/Podman Container Setup -4. Ray Cluster Initialization -5. Data Preprocessing -6. Model Setup -7. Training Launch - - -slurm_script.sh -~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: bash - - #!/bin/bash - - #SBATCH --job-name=verl-ray-on-slurm - #SBATCH --nodes=2 - #SBATCH --ntasks-per-node=2 - #SBATCH --mem=200G - #SBATCH --time=30-00:00:00 - #SBATCH --gpus-per-node=8 - #SBATCH --cpus-per-task=28 - #SBATCH --output=../verl_log/slurm-%j.out - #SBATCH --error=../verl_log/slurm-%j.err - #SBATCH --nodelist=gpu-[0,1] - - - # load necessary modules - ### Run this setup - # [Cluster]: Use docker - # docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 - - - ########################################################################## - ###The following setting should be set in different project and cluster### - ########################################################################## - - ### Project - CONTAINER_NAME="multinode_verl_training" - IMG="verl.rocm" - DOCKERFILE="docker/Dockerfile.rocm" - # echo $PWD - verl_workdir="${HOME}/projects/verl_upstream" - export TRANSFORMERS_CACHE="${HOME}/.cache/huggingface" - export HF_HOME=$TRANSFORMERS_CACHE - - ### Cluster Network Setting - export NCCL_DEBUG=TRACE - export GPU_MAX_HW_QUEUES=2 - export TORCH_NCCL_HIGH_PRIORITY=1 - export NCCL_CHECKS_DISABLE=1 - # export NCCL_IB_HCA=rdma0,rdma1,rdma2,rdma3,rdma4,rdma5,rdma6,rdma7 - export NCCL_IB_HCA=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_8,mlx5_9 - export NCCL_IB_GID_INDEX=3 - export NCCL_CROSS_NIC=0 - export CUDA_DEVICE_MAX_CONNECTIONS=1 - export NCCL_PROTO=Simple - export RCCL_MSCCL_ENABLE=0 - export TOKENIZERS_PARALLELISM=false - export HSA_NO_SCRATCH_RECLAIM=1 - ########################################################################## - - ### For rocm and training script - export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 - export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES - export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES - - - # Build and launch the Docker container - srun bash -c " - # Exit on any error - set -e - - # Clean up dangling images (images with tag) - docker image prune -f - - # Need to pull the docker first - docker pull docker.io/rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 - - if ! docker images --format "{{.Repository}}:{{.Tag}}" | grep -q "${IMG}"; then - echo \"Building ${IMG} image...\" - docker build -f \"${DOCKERFILE}\" -t \"${IMG}\" . - else - echo \"${IMG} image already exists, skipping build\" - fi - - # Removing old container if exists - docker rm \"${CONTAINER_NAME}\" 2>/dev/null || true - - # Checking network devices - ibdev2netdev - - # Launch the docker - docker run --rm -d \ - -e HYDRA_FULL_ERROR=1 \ - -e HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES} \ - -e ROCR_VISIBLE_DEVICES=${ROCR_VISIBLE_DEVICES} \ - -e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES} \ - -e NCCL_DEBUG=${NCCL_DEBUG} \ - -e GPU_MAX_HW_QUEUES=${GPU_MAX_HW_QUEUES} \ - -e TORCH_NCCL_HIGH_PRIORITY=${TORCH_NCCL_HIGH_PRIORITY} \ - -e NCCL_CHECKS_DISABLE=${NCCL_CHECKS_DISABLE} \ - -e NCCL_IB_HCA=${NCCL_IB_HCA} \ - -e NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX} \ - -e NCCL_CROSS_NIC=${NCCL_CROSS_NIC} \ - -e CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS} \ - -e NCCL_PROTO=${NCCL_PROTO} \ - -e RCCL_MSCCL_ENABLE=${RCCL_MSCCL_ENABLE} \ - -e TOKENIZERS_PARALLELISM=${TOKENIZERS_PARALLELISM} \ - -e HSA_NO_SCRATCH_RECLAIM=${HSA_NO_SCRATCH_RECLAIM} \ - -e TRANSFORMERS_CACHE=${TRANSFORMERS_CACHE} \ - -e HF_HOME=${HF_HOME} \ - --network host \ - --device /dev/dri \ - --device /dev/kfd \ - --device /dev/infiniband \ - --group-add video \ - --cap-add SYS_PTRACE \ - --security-opt seccomp=unconfined \ - --privileged \ - -v \${HOME}:\${HOME} \ - -v \${HOME}/.ssh:/root/.ssh \ - -w "${verl_workdir}" \ - --shm-size 128G \ - --name \"${CONTAINER_NAME}\" \ - \"${IMG}\" \ - tail -f /dev/null - - echo \"Container setup completed\" - " - # (Optional): If you do not want to root mode and require assign yuorself as the user - # Please add `-e HOST_UID=$(id -u)` and `-e HOST_GID=$(id -g)` into the above docker launch script. - - - - - - ### Ray launch the nodes before training - - # Getting the node names - nodes_array=($(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')) - - head_node=${nodes_array[0]} - head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) - - # if we detect a space character in the head node IP, we'll - # convert it to an ipv4 address. This step is optional. - if [[ "$head_node_ip" == *" "* ]]; then - IFS=' ' read -ra ADDR <<<"$head_node_ip" - if [[ ${#ADDR[0]} -gt 16 ]]; then - head_node_ip=${ADDR[1]} - else - head_node_ip=${ADDR[0]} - fi - echo "IPV6 address detected. We split the IPV4 address as $head_node_ip" - fi - - port=6379 - ip_head=$head_node_ip:$port - export ip_head - echo "IP Head: $ip_head" - - # make sure we set environment variables before Ray initialization - - # Print out all env variables - printenv - - echo "Starting HEAD at $head_node" - srun --nodes=1 --ntasks=1 -w "$head_node" \ - docker exec "${CONTAINER_NAME}" \ - ray start --head --node-ip-address="$head_node_ip" --port=$port \ - --dashboard-port=8266 \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & - # optional, though may be useful in certain versions of Ray < 1.0. - sleep 10 - - # number of nodes other than the head node - worker_num=$((SLURM_JOB_NUM_NODES - 1)) - - for ((i = 1; i <= worker_num; i++)); do - node_i=${nodes_array[$i]} - echo "Debug: Starting worker on node_i = ${node_i}" - if [ -z "$node_i" ]; then - echo "Error: Empty node name for worker $i" - continue - fi - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" \ - docker exec "${CONTAINER_NAME}" \ - ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & - sleep 5 - done - - - - - # Ray initlization test (See whether any error in the above execution) - echo "Testing Ray initialization in the slurm nodes..." - docker exec "${CONTAINER_NAME}" python3 -c ' - import ray - try: - ray.init(address="auto") - print("\n=== Ray Cluster Status ===") - print(f"Number of nodes: {len(ray.nodes())}") - for node in ray.nodes(): - print("Node: {}, Status: {}".format(node["NodeManagerHostname"], node["Alive"])) - # print(f"Node: {node}") - ray.shutdown() - print("Ray initialization successful!") - except Exception as e: - print(f"Ray initialization failed: {str(e)}") - ' - echo "=== Ray test completed ===" - ###### - - - - # Run data preprocessing - - echo "Starting data preprocessing..." - docker exec "${CONTAINER_NAME}" \ - python3 "examples/data_preprocess/gsm8k.py" "--local_dir" "../data/gsm8k" - - echo "Starting data preprocessing..." - docker exec "${CONTAINER_NAME}" \ - python3 "examples/data_preprocess/math_dataset.py" "--local_dir" "../data/math" - - train_files="../data/gsm8k/train.parquet" - val_files="../data/gsm8k/test.parquet" - - # Download and test model - echo "Loading model..." - docker exec "${CONTAINER_NAME}" \ - python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')" - MODEL_PATH="Qwen/Qwen2-7B-Instruct" - - # Set model path after pipeline test - MODEL_PATH="Qwen/Qwen2.5-0.5B-Instruct" - - echo "== Data and model loading Done ==" - - echo "Start to train..." - - docker exec "${CONTAINER_NAME}" \ - python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2-7B-Instruct')" - MODEL_PATH="Qwen/Qwen2-7B-Instruct" - - - PYTHONUNBUFFERED=1 srun --overlap --nodes=${SLURM_NNODES} --ntasks=1 -w "$head_node" \ - docker exec "${CONTAINER_NAME}" \ - python3 -m verl.trainer.main_ppo \ - data.train_files=$train_files \ - data.val_files=$val_files \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.enable_gradient_checkpointing=False \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=$MODEL_PATH \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=8 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.kl_ctrl.kl_coef=0.0001 \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ - trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \ - trainer.val_before_train=False \ - trainer.nnodes=${SLURM_NNODES} \ - trainer.save_freq=-1 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 - - -Run multi-node training with above slurm_script.sh -~~~~~~~~~~~~~~~~~~~~ -Just sbatch your slurm_script.sh - -.. code-block:: bash - - sbatch slurm_script.sh - diff --git a/docs/start/quickstart.rst b/docs/start/quickstart.rst deleted file mode 100644 index 22b8388a2..000000000 --- a/docs/start/quickstart.rst +++ /dev/null @@ -1,150 +0,0 @@ -.. _quickstart: - -========================================================= -Quickstart: PPO training on GSM8K dataset -========================================================= - -Post-train a LLM using GSM8K dataset. - -Introduction ------------- - -.. _hf_dataset_gsm8k: https://huggingface.co/datasets/gsm8k - -In this example, we train an LLM to tackle the `GSM8k `_ task with function-based rewards. [1]_ - -Prerequisite: - -- the latest version of ``verl`` and its dependencies installed following the installation guide. Using the docker image is recommended. - -- a GPU with at least 24 GB HBM - - -Dataset Introduction --------------------- - -GSM8k is a math problem dataset. The prompt is an elementary school -problem. The LLM model is asked to solve the math problem. Below is an example: - -Prompt - - Katy makes coffee using teaspoons of sugar and cups of water in the - ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups - of water, calculate the number of teaspoonfuls of sugar she used. - -Solution - - The total ratio representing the ingredients she used to make the - coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the - number of teaspoons she used is 7/20, she used 7/20\ *120 = - <<7/20*\ 120=42>>42 #### 42 - -Step 1: Prepare the dataset ----------------------------- - -We preprocess the dataset in parquet format so that (1) it contains necessary fields for computing RL rewards and (2) is faster to read. - -.. code-block:: bash - - python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k - -Step 2: Download a model for post-training -------------------------------------------- - -In this example, we start with the ``Qwen2.5-0.5B-Instruct`` model. - -If you want to perform SFT before RL, refer to the :doc:`Complete GSM8K Example<../examples/gsm8k_example>`, the `sft directory `_ and `SFT Trainer `_ for further details. - -.. code-block:: bash - - python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')" - -Step 3: Perform PPO training with the instruct model ----------------------------------------------------------------------- - -**Reward Model/Function** - -We use a pre-defined rule-based reward model. We force the model to produce a final -answer following 4 “#” as shown in the solution. We extract the final -answer from both the solution and model's output using regular -expression matching. We assign a reward of 1 to correct -answer, 0.0 to incorrect answer and 0 to no answer. - -For more details, please refer to `verl/utils/reward_score/gsm8k.py `_. - -**Training Script** - -Now let's run PPO training with the dataset and model above. [2]_ - - -Set the ``data.train_files`` ,\ ``data.val_files``, ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on your dataset and model names or paths. -You may set ``VERL_USE_MODELSCOPE=True`` to download models from `modelscope `_ instead of `huggingface `_. - -.. code-block:: bash - - PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=256 \ - data.max_prompt_length=512 \ - data.max_response_length=256 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - critic.optim.lr=1e-5 \ - critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - critic.ppo_micro_batch_size_per_gpu=4 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=console \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=1 \ - trainer.nnodes=1 \ - trainer.save_freq=10 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 2>&1 | tee verl_demo.log - -You are expected to see the following logs, indicating training in progress. The key metric ``val/test_score/openai/gsm8k`` is computed every ``trainer.test_freq`` steps: - -.. code-block:: bash - - step:0 - timing/gen:21.470 - timing/ref:4.360 - timing/values:5.800 - actor/reward_kl_penalty:0.000 - actor/reward_kl_penalty_coeff:0.001 - timing/adv:0.109 - timing/update_critic:15.664 - critic/vf_loss:14.947 - critic/vf_clipfrac:0.000 - critic/vpred_mean:-2.056 - critic/grad_norm:1023.278 - critic/lr(1e-4):0.100 - timing/update_actor:20.314 - actor/entropy_loss:0.433 - actor/pg_loss:-0.005 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/grad_norm:1.992 - actor/lr(1e-4):0.010 - critic/score/mean:0.004 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.004 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.360 - critic/advantages/min:-2.280 - critic/returns/mean:0.003 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.045 - critic/values/max:9.500 - critic/values/min:-14.000 - response_length/mean:239.133 - response_length/max:256.000 - response_length/min:77.000 - prompt_length/mean:104.883 - prompt_length/max:175.000 - prompt_length/min:68.000 - step:1 - timing/gen:23.020 - timing/ref:4.322 - timing/values:5.953 - actor/reward_kl_penalty:0.000 - actor/reward_kl_penalty:0.001 - timing/adv:0.118 - timing/update_critic:15.646 - critic/vf_loss:18.472 - critic/vf_clipfrac:0.384 - critic/vpred_mean:1.038 - critic/grad_norm:942.924 - critic/lr(1e-4):0.100 - timing/update_actor:20.526 - actor/entropy_loss:0.440 - actor/pg_loss:0.000 - actor/pg_clipfrac:0.002 - actor/ppo_kl:0.000 - actor/grad_norm:2.060 - actor/lr(1e-4):0.010 - critic/score/mean:0.000 - critic/score/max:0.000 - critic/score/min:0.000 - critic/rewards/mean:0.000 - critic/rewards/max:0.000 - critic/rewards/min:0.000 - critic/advantages/mean:0.000 - critic/advantages/max:2.702 - critic/advantages/min:-2.616 - critic/returns/mean:0.000 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.280 - critic/values/max:11.000 - critic/values/min:-16.000 - response_length/mean:232.242 - response_length/max:256.000 - response_length/min:91.000 - prompt_length/mean:102.398 - prompt_length/max:185.000 - prompt_length/min:70.000 - -Checkout ``Algorithm Baselines`` page for full training and validation logs for reference. - -The checkpoint is saved at the following dir by default: ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``. You can merge the saved checkpoints to huggingface model using ``verl.model_merger`` module, for example: - -.. code-block:: bash - - python3 -m verl.model_merger merge \ - --backend fsdp \ - --local_dir checkpoints/${trainer.project_name}/${trainer.experiment_name}/global_step_1/actor \ - --target_dir checkpoints/${trainer.project_name}/${trainer.experiment_name}/global_step_1/actor/huggingface - -For more details about checkpoint and model merging, please refer to :ref:`checkpoint-page`. - -To enable ``wandb`` for experiment tracking, set the following configs: - -.. code-block:: bash - - trainer.logger='["console","wandb"]' \ - trainer.project_name=$YOUR_PROJECT_NAME \ - trainer.experiment_name=$YOUR_RUN_NAME \ - -If you encounter out of memory issues with HBM less than 32GB, enable the following configs would help: - -.. code-block:: bash - - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - critic.ppo_micro_batch_size_per_gpu=1 \ - -For the full set of configs, please refer to :ref:`config-explain-page` for detailed explanation and performance tuning. - - -.. [1] The original paper (https://arxiv.org/pdf/2110.14168) mainly focuses on training a verifier (a reward model) to solve math problems via Best-of-N sampling. In this example, we train an RL agent using a rule-based reward model. -.. [2] More training script examples for FSDP and Megatron-LM backend are stored in `examples/ppo_trainer `_ directory. diff --git a/docs/start/ray_debug_tutorial.rst b/docs/start/ray_debug_tutorial.rst deleted file mode 100644 index 9e7c87dfa..000000000 --- a/docs/start/ray_debug_tutorial.rst +++ /dev/null @@ -1,96 +0,0 @@ -Ray Debug Tutorial -================== - -Last updated: 04/23/2025 - - -.. _wuxibin89: https://github.com/wuxibin89 - -Author: `Ao Shen `_. - -How to debug? ---------------------- - - -Ray Distributed Debugger VSCode Extension (Recommended) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -1. Starting with Ray 2.39, Anyscale has introduced the `Ray Distributed Debugger `_ VSCode extension. Follow the extension’s installation instructions, then add your cluster using the dashboard URL you obtained earlier. - - .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/debugger.png?raw=true - :alt: Ray Distributed Debugger VSCode extension screenshot - -2. Prerequisites. - - Ensure the following are installed (see the extension README for more detail): - - - Visual Studio Code - - `ray[default]` >= 2.9.1 - - `debugpy` >= 1.8.0 - - .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/readme.png?raw=true - :alt: VSCode with Ray prerequisites - -3. Environment Variables. - - To enable post‑mortem debugging, set: - - .. code-block:: bash - - export RAY_DEBUG_POST_MORTEM=1 - - .. admonition:: Note - :class: important - - Be sure to remove any legacy flags before starting Ray: - - - `RAY_DEBUG=legacy` - - `--ray-debugger-external` - -4. Configuring BreakpointsSet up breakpoint() in your code, and submit job to cluster. Then the extension will show the breakpoint information. - - - 1. Insert `breakpoint()` calls into your remote functions. - 2. Submit your job to the cluster. - - The extension will detect active breakpoints and display them in VSCode. - - **Note:** Breakpoints are only supported inside functions decorated with `@ray.remote`. - -5. Launching the Debugger. - - Run your job directly from the command line (do not use a `launch.json`): - - .. code-block:: bash - - python job.py - -6. Attaching to a Breakpoint. - - Once the process hits the first `breakpoint()`, click the Ray Distributed Debugger icon in the VSCode sidebar to attach the debugger. - - .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/launch.png?raw=true - :alt: Attaching VSCode debugger to Ray process - -7. Debugging With Multiple breakpoint(). - - For each subsequent task, first disconnect the current debugger session, then click the extension icon again to attach to the next breakpoint. - - .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/disconnect.png?raw=true - :alt: Disconnecting and reconnecting the debugger - -Legacy Ray Debugger -~~~~~~~~~~~~~~~~~~~ -1. Ray has a builtin legacy `debugger `_ that allows you to debug your distributed applications. To enable debugger, start ray cluster with ``RAY_DEBUG=legacy`` and ``--ray-debugger-external``. - -.. code-block:: bash - - # start head node - RAY_DEBUG=legacy ray start --head --dashboard-host=0.0.0.0 --ray-debugger-external - # start worker node - RAY_DEBUG=legacy ray start --address='10.124.46.192:6379' --ray-debugger-external - -2. Set up breakpoint in your code, and submit job to cluster. Then run ``ray debug`` to wait breakpoint: - -.. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/ray/legacy.png?raw=true - diff --git a/docs/workers/fsdp_workers.rst b/docs/workers/fsdp_workers.rst deleted file mode 100644 index b158fb265..000000000 --- a/docs/workers/fsdp_workers.rst +++ /dev/null @@ -1,144 +0,0 @@ -PyTorch FSDP Backend -====================== - -Last updated: 02/12/2025. - -We support PyTorch FSDP Backend by implementing various workers for -actor, critic, reference, rollout and reward models. We also implement -the ``FSDPVLLMShardingManager`` that reshard weight between FSDP and -vLLM in `fsdp_vllm.py `_. - -**Pros** - -- Readily support various models. - - - Users only need to implement the corresponding - ``dtensor_weight_loader`` for weight synchronization between FSDP - and vLLM. While for ``hf_weight_loader``, users can directly apply - any models supported both in HF and vLLM without any code change. - -- Easy to organize the forward and backward computation for each model. - -**Cons** - -- Poor scalability when it comes to large-scale models (e.g. Llama 70B - and 405B) -- The resharding overhead between actor and rollout could be larger than - Megatron-LM backend. - -Due to the simplicity, we recommend using FSDP backend for algorithm -research and prototyping. - -FSDP Workers --------------- - -ActorRolloutRefWorker -^^^^^^^^^^^^^^^^^^^^^ - -Actor/Rollout HybridEngine -'''''''''''''''''''''''''' - -1. HybridEngine, Actor and Rollout initialization API. - -.. code:: python - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - -``ONE_TO_ALL``: when calling the ``init_model`` function from the driver -process, each worker (on a GPU) will execute the following model -initialization process. - -The initialization details of HybridEngine, Actor and Rollout are -highlighted below: - -1. ``DataParallelPPOActor`` implements the simple PPO computation logics - when the model is built with FSDP, including compute log prob, model - update. -2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM - Engine and make it executed under SPMD to fit into our - ``WorkerGroup`` design. -3. ``FSDPVLLMShardingManager`` a context manager to perform actual - resharding between actor and rollout. - -See `source code `_. for more information. - -1. Generate sequence and recompute log prob - -.. code:: python - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def generate_sequences(self, prompts: DataProto): - -- ``Dispatch.DP_COMPUTE_PROTO``: The data will be dispatched and - collected along the DP dimension - -- In this function, the rollout model will perform auto-regressive - generation and the actor model will recompute the old log prob for the - generated response. - -3. Update actor model - -.. code:: python - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def update_actor(self, data: DataProto): - -- Update the actor model weight using PPO & entropy loss. - -ReferenceModel -'''''''''''''' - -1. Reference model initialization - -The reference model is initialized using the same function as the actor -model without initializing the HybridEngine and Optimizer. Then the -actor model is also wrapped by the ``DataParallelPPOActor``. - -2. Compute reference log prob - -.. code:: python - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_ref_log_prob(self, data: DataProto): - -- In this function, the reference model will call the compute log prob - function in ``DataParallelPPOActor`` to compute the reference log - prob. - -CriticWorker and RewardWorker -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -1. Model initialization - -Quite similar to reference model. The CriticWorker will perform -additional initialization for the Optimizer. - -2. Compute Values for CriticWorker - -.. code:: python - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_values(self, data: DataProto): - -3. Update Critic - -.. code:: python - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def update_critic(self, data: DataProto): - -4. Compute Reward - -.. code:: python - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_rm_score(self, data: DataProto): - - -HybridShard ------------- - -We didn't support FSDP `HybridShard`. To support this, we may need to -construct a 2D device mesh and test the corresponding -``dtensor_weight_loader`` and ``hf_weight_loader`` for each model. diff --git a/docs/workers/megatron_workers.rst b/docs/workers/megatron_workers.rst deleted file mode 100644 index b93bd033c..000000000 --- a/docs/workers/megatron_workers.rst +++ /dev/null @@ -1,304 +0,0 @@ -Megatron-LM Backend -=================== - -Last updated: 06/24/2025. - -We support Megatron Backend by implementing various workers for actor, -critic, reference, rollout and reward models. We also implement the -``3DHybridEngine`` using Megatron-LM and vLLM/SGLang in -`megatron_vllm.py `_ -and `megatron_sglang.py `_. - -**Pros** - -- Support 5D parallelism (TP, EP, CP, DP, PP) and sequence parallelism - for best scalablility and throughput. -- 3D HybridEngine can significantly reduce peak memory usage and reduce - weight synchronize overhead between actor and rollout. - -**Cons** - -- Huggingface Models and Megatron checkpoints need tools for conversion. - - -Development Progress --------------------- - - -Note that [Deprecated] means that the feature is not supported in the latest -version of verl. -[To-Optimize] means that the feature is implemented but not optimized yet. -[WIP] means that the feature is working in progress. -[In-Release] means that the feature is ready and in review process, -coming at any time. - - -+---------------+-----------------------------------------------------------+ -| [Deprecated] | Megatron 3D Parallelism with custom models | -+---------------+-----------------------------------------------------------+ -| [Done] | Megatron 0.11.0 ``GPTModel`` support | -+---------------+-----------------------------------------------------------+ -| [Done] | Megatron GRPO support | -+---------------+-----------------------------------------------------------+ -| [Done] | Megatron with vLLM 0.8.2, with per-tensor weights loading | -+---------------+-----------------------------------------------------------+ -| [Done] | Megatron with Context Parallel | -+---------------+-----------------------------------------------------------+ -| [Done] | Qwen2MoE model support | -+---------------+-----------------------------------------------------------+ -| [To-Optimize] | Megatron dist Checkpoint | -+---------------+-----------------------------------------------------------+ -| [To-Optimize] | Huggingface and Megatron Checkpoint Converter | -+---------------+-----------------------------------------------------------+ -| [To-Optimize] | Efficient fused linear, entropy and cross entropy | -+---------------+-----------------------------------------------------------+ -| [Done] | Megatron offload(param, grad, optimizer) | -+---------------+-----------------------------------------------------------+ -| [Done] | Megatron Profiler | -+---------------+-----------------------------------------------------------+ -| [In-Release] | Megatron 0.12.0, TE 2.2 with vLLM 0.8.3 and Fused Attn | -+---------------+-----------------------------------------------------------+ -| [WIP] | Moonlight/DeepSeek-V3 model support | -+---------------+-----------------------------------------------------------+ -| [WIP] | Expert Parallel support | -+---------------+-----------------------------------------------------------+ -| [WIP] | Megatron support dynamic batch size | -+---------------+-----------------------------------------------------------+ -| [To-Do] | Performance tuning | -+---------------+-----------------------------------------------------------+ -| [MileStone] | Runnable with DeepSeek-V3 671B post-training | -+---------------+-----------------------------------------------------------+ - - - -Utils of Megatron Workers -------------------------- - -MegatronWorker -^^^^^^^^^^^^^^ - -``MegatronWorker`` is the base class of different megatron worker -classes. In this class, ``get_megatron_global_info`` and -``get_megatron_rank_info`` function to retrieve the 3D parallel world -size and rank of each ``Worker`` running on specific GPU. These information -will be used in transfer protocol for Megatron Backend. - -The following ``Worker`` class for different models will be utilized to -construct the ``WorkerGroup`` . - -We implement various of APIs for each ``Worker`` class decorated by the -``@register(dispatch_mode=)`` . These APIs can be called by the ray -driver process. The data can be correctly collect and dispatch following -the ``dispatch_mode`` on each function. The supported dispatch_model -(i.e., transfer protocols) can be found in `decorator.py `_. - -ActorRolloutRefWorker -^^^^^^^^^^^^^^^^^^^^^ - -This class is implemented for Actor/Rollout HybridEngine or for the -reference model to initialize their model and perform computation. - -Actor/Rollout HybridEngine -'''''''''''''''''''''''''' - -1. HybridEngine, Actor and Rollout initialization API. - -.. code:: python - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - -``ONE_TO_ALL``: when calling the ``init_model`` function from the driver -process, each worker (on a GPU) will execute the following model -initialization process. - -The initialization details of HybridEngine, Actor and Rollout are -highlighted below: - -1. ``MegatronPPOActor`` implements the simple PPO computation logics - when the model is built with Megatron, including compute log prob, - model update. -2. ``vLLMRollout`` support generation with vLLM. We modify the vLLM - Engine and make it executed under SPMD to fit into our - ``WorkerGroup`` design. -3. ``MegatronVLLMShardingManager`` a context manager to perform actual - resharding between actor and rollout. - -See `source code `_ for more information. - -.. code:: python - - # build actor model - self.actor = MegatronPPOActor(config=self.config.actor, - model_config=self.actor_model_config, - megatron_config=megatron_config, - actor_module=self.actor_module, - actor_optimizer=self.actor_optimizer, - actor_optimizer_config=self.actor_optim_config) - - # build rollout - # rollout initialization - rollout = vLLMRollout(actor_module=params, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - train_tp=mpu.get_tensor_model_parallel_world_size()) - # perform weight resharding between actor and rollout - sharding_manager = MegatronVLLMShardingManager(module=self.hybrid_engine, - inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - layer_name_mapping=layer_name_mapping) - ... - -1. Generate sequence and recompute log prob - -.. code:: python - - @register(dispatch_mode=Dispatch.MEGATRON_PP_AS_DP_PROTO) - def generate_sequences(self, prompts: DataProto): - -- ``Dispatch.MEGATRON_PP_AS_DP_PROTO``: The PP dimension of the actor - model will be regarded as DP dimension. Then the driver process will - dispatch and collect the data according to this reorganization. This - is because, in HybridEngine, the actor weight, which usually applied - larger 3D parallel sizes, will be gathered along the PP dimension and - TP dimension. Therefore, the corresponding data should be dispatched - and collected through the 3D parallel group of the rollout model, - rather than the actor model. However, the world_size and rank - information can only be retrieved from ``get_megatron_global_info`` and - ``get_megatron_rank_info``, which records the 3D information for the - actor model. Moreover, the data resharding inside TP dimension will be - processed within the HybridEngine. - -- In this function, the rollout model will perform auto-regressive - generation and the actor model will recompute the old log prob for the - generated response. - -3. Update actor model - -.. code:: python - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def update_actor(self, data: DataProto): - -- ``Dispatch.MEGATRON_COMPUTE_PROTO``: User passes the data partitioned - by DP dimension. The data is dispatched to all tp/pp ranks within the - same dp group, and ultimately only collects output data from tp=0 and - the last pp. -- Update the actor model weight using PPO & entropy loss. - - -..note:: - - Currently, training Tensor Parallel Size can be different from inference - Tensor Parallel Size. - - -ReferenceModel -'''''''''''''' - -1. Reference model initialization - -The reference model is initialized using the same function as the actor -model without initializing the HybridEngine and Optimizer. Then the -actor model is also wrapped by the ``MegatronPPOActor``. - -2. Compute reference log prob - -.. code:: python - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def compute_ref_log_prob(self, data: DataProto): - -- In this function, the reference model will call the compute log prob - function in ``MegatronPPOActor`` to compute the reference log prob. - -CriticWorker and RewardWorker -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -1. Model initialization - -Quite similar to reference model. The CriticWorker will perform -additional initialization for the Optimizer. - -2. Compute Values for CriticWorker - -.. code:: python - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def compute_values(self, data: DataProto): - -3. Update Critic - -.. code:: python - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def update_critic(self, data: DataProto): - -4. Compute Reward - -.. code:: python - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def compute_rm_score(self, data: DataProto): - - -Utils of Train Optimization ---------------------------- - -Offload -^^^^^^^ -When resources are tight, the offload method can lower GPU memory -usage, helping training and inference frameworks work well under verl. -It moves parameters, gradients, and optimizers to CPU memory and only -loads them back to the GPU when needed. - -If you want to use the offload, you can add the following parameters -for the actor and ref separately. - -.. code:: python - - # For the actor - actor_rollout_ref.actor.megatron.param_offload=True \ - actor_rollout_ref.actor.megatron.grad_offload=True \ - actor_rollout_ref.actor.megatron.optimizer_offload=True \ - # For the ref w/o grad and optimizer - actor_rollout_ref.ref.megatron.param_offload=True \ - - -For the critic, you can include these parameters. - -.. code:: python - - # For the critic - critic.megatron.param_offload=True \ - critic.megatron.grad_offload=True \ - critic.megatron.optimizer_offload=True \ - -Profiler -^^^^^^^^ - -The profiler is a tool that helps you understand the performance of your -model. It can be used to profile the time spent on different operations -and identify the bottlenecks. You can get more information from -`torch.profiler `_. - -In verl, now the profiler is only support for the actor role In Megatron. You can set -the begin step and end step to profile. Notice, one step means one gradient update. And -the profile result will be saved in the save_path. If you just want to profile in the -specific rank, you can set the profile_ranks, by default, it will be [0]. - -.. code:: python - - actor_rollout_ref.actor.profile.use_profile=True \ - actor_rollout_ref.actor.profile.profile_ranks=[0] \ - actor_rollout_ref.actor.profile.step_start=0 \ - actor_rollout_ref.actor.profile.step_end=1 \ - actor_rollout_ref.actor.profile.save_path="./profile" - - -Related MCore Document ----------------------- - -There is also a detailed document of using MCore to train different -kinds of models, please refer to `MCore Document `_. diff --git a/docs/workers/ray_trainer.rst b/docs/workers/ray_trainer.rst deleted file mode 100644 index 9c482d39a..000000000 --- a/docs/workers/ray_trainer.rst +++ /dev/null @@ -1,241 +0,0 @@ -PPO Ray Trainer -=============== - -Last updated: 02/12/2025. - -We implement the RayPPOTrainer, which is a trainer runs on the driver -process on a single CPU/GPU node (default is CPU). - -The PPORayTrainer include 3 core functions for data preparation, -WorkerGroup initialization and PPO training loop. - -Data Preparation ----------------- - -The ``PPORayTrainer``, as a single process, is responsible for loading a -complete batch of samples (prompts) from the dataset and then dispatch -to different worker_groups running on different GPUs. - -To generalize the data loading, we implement the ``RLHFDataset`` class -to load the preprocessed parquet files, apply chat templates to the -prompts, add padding, truncate prompts that exceed max prompt length and -then tokenize. - -.. code:: python - - self.train_dataset = RLHFDataset(data_files=self.config.data.train_files, - tokenizer=self.tokenizer, - config=self.config.data) - -Then, the dataloader will iterate the dataset under PPO mini batch size. - -WorkerGroup Initialization --------------------------- - -We first introduce a basic implementation of initializing the -``WorkerGroup`` of the actor model on a given set of GPUs. - -.. code:: python - - # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. - # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different WorkerGroup for differnt models - resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes, - use_gpu=True, - max_colocate_count=1) - # define actor rollout cls to be init on remote - actor_rollout_cls = RayClassWithInitArgs(cls=ActorRolloutWorker) - # define actor_rollout worker group - actor_rollout_worker_group = MegatronRayWorkerGroup(resource_pool=resource_pool, - ray_cls_with_init=actor_rollout_cls, - default_megatron_kwargs=config.actor_rollout.megatron) - -Different WorkerGroups, like ``actor_rollout_worker_group`` , -``critic_worker_group`` and ``ref_worker_group`` lies on a separate -process in the above implementation. - -The driver process can then call the distributed compute function within -the ``actor_rollout_worker_group`` and other roles to construct the RL -training loop. - -For models colocated in the same set of GPUs, we further provide a -fine-grain optimization, which merge the ``worker_group`` of different roles -in the same process. This optimization can save the redundant -CUDA/distributed context in different processes. - -.. code:: python - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, - # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to different worker groups. - # See TODO(url) for more information. - all_wg = {} - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - - if self.use_critic: - self.critic_wg = all_wg['critic'] - self.critic_wg.init_model() - - if self.use_reference_policy: - self.ref_policy_wg = all_wg['ref'] - self.ref_policy_wg.init_model() - - if self.use_rm: - self.rm_wg = all_wg['rm'] - self.rm_wg.init_model() - - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg['actor_rollout'] - self.actor_rollout_wg.init_model() - -.. note:: For megatron backend, if we merge the ``worker_groups`` into the same processes, all the roles will utilize the same 3D parallel size. To optimize this, we may need to maintain several 3D process groups for each role in the same distributed context. If you want to use different 3D parallel size for different roles, please follow the similar architecture of the first code block to initialize each role's ``worker_group`` - - -PPO Training Loop ------------------ - -We implement the PPO training loop by calling the functions in -worker_group of each role. The input and output data of each function is -a ``DataProto`` object implemented in `protocol.py `_. In the training -loop, trainer will dispatch/collect the data to/from different GPUs -following the transfer protocols wrapped in the workers' functions. The -computation of PPO micro batches is processed in ``update_actor`` and -``update_critic`` functions. - -To extend to other RLHF algorithms, such as DPO, GRPO, please refer to -:doc:`../advance/dpo_extension`. - -.. code:: python - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from verl.utils.tracking import Tracking - from omegaconf import OmegaConf - - logger = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True)) - - global_steps = 0 - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None: - val_metrics = self._validate() - pprint(f'Initial validation metrics: {val_metrics}') - - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - - batch: DataProto = DataProto.from_single_dict(batch_dict) - # batch = batch.to('cuda') - - # pop those keys for generation - gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) - - # generate a batch - with Timer(name='gen', logger=None) as timer: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - metrics['timing/gen'] = timer.last - - batch = batch.union(gen_batch_output) - - if self.use_reference_policy: - # compute reference log_prob - with Timer(name='ref', logger=None) as timer: - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - metrics['timing/ref'] = timer.last - - # compute values - with Timer(name='values', logger=None) as timer: - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - metrics['timing/values'] = timer.last - - with Timer(name='adv', logger=None) as timer: - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - # we combine with rule-based rm - reward_tensor = self.reward_fn(batch) - batch.batch['token_level_scores'] = reward_tensor - - # compute rewards. apply_kl_penalty if available - batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl_in_reward, - kl_penalty=self.config.algorithm.kl_penalty) - metrics.update(kl_metrics) - - # compute advantages, executed on the driver process - batch = compute_advantage(batch, - self.config.algorithm.gamma, - self.config.algorithm.lam, - adv_estimator=self.config.algorithm.adv_estimator) - metrics['timing/adv'] = timer.last - - # update critic - if self.use_critic: - with Timer(name='update_critic', logger=None) as timer: - critic_output = self.critic_wg.update_critic(batch) - metrics['timing/update_critic'] = timer.last - critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= global_steps: - # update actor - with Timer(name='update_actor', logger=None) as timer: - actor_output = self.actor_rollout_wg.update_actor(batch) - metrics['timing/update_actor'] = timer.last - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) - metrics.update(actor_output_metrics) - - # validate - if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0: - with Timer(name='testing', logger=None) as timer: - val_metrics: dict = self._validate() - val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} - metrics['timing/testing'] = timer.last - metrics.update(val_metrics) - - # collect metrics - data_metrics = compute_data_metrics(batch=batch) - metrics.update(data_metrics) - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=global_steps) - - if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0: - actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor', - f'global_step_{global_steps}') - actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor') - self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path) - - if self.use_critic: - critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic', - f'global_step_{global_steps}') - critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic') - self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path) - - global_steps += 1 - - # perform validation after training - if self.val_reward_fn is not None: - val_metrics = self._validate() - pprint(f'Final validation metrics: {val_metrics}') diff --git a/docs/workers/sglang_worker.rst b/docs/workers/sglang_worker.rst deleted file mode 100644 index 1ef93823c..000000000 --- a/docs/workers/sglang_worker.rst +++ /dev/null @@ -1,237 +0,0 @@ -SGLang Backend -============== - -Last updated: 05/31/2025. - -**Authored By SGLang RL Team and listed alphabetically by last name** - -`Jingyi Chen `_, `Yitong Guan `_, `Zhuobin Huang `_, `Jiajun Li `_, `Ji Li `_, `Shenggui Li `_, `Junrong Lin `_, `Xiang Long `_, `Rui Lu `_, `Jin Pan `_, `Shuai Shi `_, `Yushen Su `_, `Xinyuan Tong `_, `Chendong Wang `_, `Hanchen Zhang `_, `Haoran Wang `_, `Yongan Xiang `_, `Chengxing Xie `_, `Yuhao Yang `_, `Jinwei Yao `_, `Qiaolin Yu `_, `Yuzhen Zhou `_, `Chenyang Zhao `_ - - - -Introduction ------------- -`SGLang `_ is an open-source state-of-the-art inference service engine, fully adopted by xAI to support all inference needs of Grok during research and serving processes. - -Currently, verl fully supports using SGLang as the inference engine during the rollout phase. As a rollout engine, SGLang provides the same feature coverage as vLLM., including memory saving and multi-node rollout features. After installing verl and SGLang, simply add ``actor_rollout_ref.rollout.name=sglang`` at startup script to seamlessly switch between the two inference frameworks. - -In addition, the SGLang team is actively working on supporting features such as Multi-Turn Agentic RL, VLM RLHF, Server-Based RLHF, and Partial Rollout. You can track the related development progress in the `Tracking Roadmap `_. - -Installation ------------- -Please always follow the following command to install SGLang with verl. - -.. code-block:: bash - - pip install --upgrade pip - # Currently 0.4.6.post5, subject to updates at any time, please refer to the latest version specified in `setup.py` - pip install -e ".[sglang]" - -You can check the following dependencies are in your environment: - -.. note:: - - - **PyTorch**: 2.6.0+cu124 - - **CUDA**: 12.4 - - **flashinfer-python**: 0.2.5+cu124torch2.6 - - **sgLang**: 0.4.6.post5 - - **sgl-kernel**: 0.1.4 - -Using SGLang as the Inference Backend for PPO Training on a Single Machine -------------------------------------------------------------------------- -We use Qwen/Qwen2-7B-Instruct on the gsm8k dataset for a simple test. - -1. Run the following command to prepare the gsm8k dataset: - -.. code-block:: bash - - python3 examples/data_preprocess/gsm8k.py - -2. Run the following script to conduct a PPO experiment on a single machine with 4 GPUs: - -.. code-block:: bash - - export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK=True - PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=4096 \ - data.max_prompt_length=4096 \ - data.max_response_length=4096 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - critic.optim.lr=1e-5 \ - critic.model.path=Qwen/Qwen2-7B-Instruct \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.model.fsdp_config.param_offload=True \ - critic.model.fsdp_config.optimizer_offload=True \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=console \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=4 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 2>&1 | tee verl_demo.log - -Why export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -1. ``verl`` initializes a ``SGLangRollout`` module during rollout, which is used to evaluate/generate samples. - -2. ``SGLangRollout`` will initialize ``Engine``, and further initialize a ``torch.distributed.DeviceMesh``, used to support Tensor Parallel (TP). - -3. ``DeviceMesh.init()`` internally checks the free GPU memory of all participating devices. If the difference is too large (more than ~10%), it directly reports an error to avoid initialization failures or deadlocks. - -Why might there be inconsistent GPU memory? -""""""""""""""""""""""""""""""""""""""""""" - -**1. Ray Distributed Actor loads the model at different times** - -``verl`` uses Ray-based multi-process, multi-GPU concurrent training. Each ``WorkerDict`` may be called at different times: - -.. code-block:: python - - self.rollout = SGLangRollout(...) - -Different workers initialize the model at different times → different memory usage. - -**2. Delayed initialization causes memory bias** - -Some workers start model loading/inference (e.g., ``generate_sequences()``, ``compute_log_prob()``) earlier than others. -Early workers already use up GPU memory → late workers still have empty memory → memory difference appears. - -**3. SGLang's TP init uses "all-device broadcast", but there's no uniform release timing** - -Although ``SGLangRollout`` may only involve subset of GPUs, its ``Engine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so: - -- Non-rollout GPUs also join the communication. -- Later on, ``DeviceMesh`` init will fail due to "inconsistent memory". - -**4. Different FSDP/TP loading behaviors also lead to mismatch** - -If using: - -.. code-block:: bash - - actor.fsdp_config.param_offload=True - ref.fsdp_config.param_offload=True - -Then some workers keep params on CPU while others already sharded to GPU → leads to asymmetric memory layout. - -Using SGLang as the Inference Backend for PPO Training Across Multiple Machines ------------------------------------------------------------------------------- -SGLang also supports running verl's RAY-based cross-machine inference in IPv4 and IPv6 scenarios. In the script below, we use TP=16 for cross-machine inference. Suppose we have two interconnected machines: node0 with IP 10.94.16.4 and node1 with IP 10.94.16.5. - -1. Start Ray on node0: - -.. code-block:: bash - - ray start --head --dashboard-host=0.0.0.0 - -You will see the following prompt: - -.. code-block:: bash - - Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details. - - Local node IP: 10.94.16.4 - - -------------------- - Ray runtime started. - -------------------- - - Next steps - To add another node to this Ray cluster, run - ray start --address='10.94.16.4:6379' - -2. Have node1 join the Ray cluster: - -Run the following command on node1: - -.. code-block:: bash - - ray start --address='10.94.16.4:6379' - -Run the following command to confirm that the Ray cluster now has two nodes: - -.. code-block:: bash - - ray status - -You can see that the cluster has two nodes with 16 GPUs: - -.. code-block:: bash - - ======== Autoscaler status: 2025-04-09 09:25:37.694016 ======== - Node status - --------------------------------------------------------------- - Active: - 1 node_ef382ffd687d8f6b060c1b68e63ada7341b936fe5b1901dd04de1027 - 1 node_1eb4d7d07e793114c23a89d1a41f1f76acf6ef5b35af844a4ee8e4ba - Pending: - (no pending nodes) - Recent failures: - (no failures) - - Resources - --------------------------------------------------------------- - Usage: - 0.0/360.0 CPU - 0.0/16.0 GPU - 0B/3.39TiB memory - 0B/372.53GiB object_store_memory - -3. Run the following script to train meta-llama/Llama-3.1-8B-Instruct with TP=16 across 2 machines using 16 GPUs: - -.. code-block:: bash - - DATA_DIR=$HOME/data/gsm8k - - python3 -m verl.trainer.main_ppo \ - actor_rollout_ref.rollout.name=sglang \ - data.train_files=$DATA_DIR/train.parquet \ - data.val_files=$DATA_DIR/test.parquet \ - data.train_batch_size=4096 \ - data.max_prompt_length=4096 \ - data.max_response_length=4096 \ - actor_rollout_ref.model.path=meta-llama/Llama-3.1-8B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=16 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.free_cache_engine=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=meta-llama/Llama-3.1-8B-Instruct \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size=16 \ - critic.model.fsdp_config.param_offload=True \ - critic.model.fsdp_config.optimizer_offload=True \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.val_before_train=True \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=2 \ - trainer.save_freq=-1 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 2>&1 | tee verl_demo.log diff --git a/examples/data_preprocess/hellaswag.py b/examples/data_preprocess/hellaswag.py deleted file mode 100644 index 1b3f20080..000000000 --- a/examples/data_preprocess/hellaswag.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Preprocess Hellaswag dataset. - -""" - -import argparse -import os -import re - -import datasets - -from verl.utils.hdfs_io import copy, makedirs - - -def preprocess(text): - text = text.strip() - # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. - text = text.replace(" [title]", ". ") - text = re.sub("\\[.*?\\]", "", text) - text = text.replace(" ", " ") - return text - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--local_dir", default="/opt/tiger/hellaswag") - parser.add_argument("--hdfs_dir", default=None) - - args = parser.parse_args() - - data_source = "Rowan/hellaswag" - - dataset = datasets.load_dataset(data_source, trust_remote_code=True) - - train_dataset = dataset["train"] - val_dataset = dataset["validation"] - test_dataset = dataset["test"] - - instruction = "Please complete the following sentence.\n" - - def make_map_fn(split): - def process_fn(doc, idx): - ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() - query = preprocess(doc["activity_label"] + ": " + ctx) - choices = [preprocess(ending) for ending in doc["endings"]] - gold = int(doc["label"]) - - data = { - "data_source": data_source, - "prompt": [{"role": "user", "content": query}], - "ability": "nlp", - "reward_model": { - "style": "model", - "eval": "multiple_choice", # using loglikelihood - "ground_truth": gold, - "choices": choices, - }, - "extra_info": {"split": split, "index": idx}, - } - return data - - return process_fn - - # filter data that doesn't have a label - train_dataset = train_dataset.filter(lambda x: len(x["label"]) > 0) - val_dataset = val_dataset.filter(lambda x: len(x["label"]) > 0) - test_dataset = test_dataset.filter(lambda x: len(x["label"]) > 0) - - train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) - val_dataset = val_dataset.map(function=make_map_fn("validation"), with_indices=True) - test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) - - local_dir = args.local_dir - hdfs_dir = args.hdfs_dir - - train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) - val_dataset.to_parquet(os.path.join(local_dir, "validation.parquet")) - test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) - - if hdfs_dir is not None: - makedirs(hdfs_dir) - - copy(src=local_dir, dst=hdfs_dir) diff --git a/examples/data_preprocess/math_dataset.py b/examples/data_preprocess/math_dataset.py deleted file mode 100644 index e2e5d3524..000000000 --- a/examples/data_preprocess/math_dataset.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Preprocess the MATH-lighteval dataset to parquet format -""" - -import argparse -import os - -import datasets - -from verl.utils.hdfs_io import copy, makedirs -from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed - - -def extract_solution(solution_str): - return remove_boxed(last_boxed_only_string(solution_str)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--local_dir", default="~/data/math") - parser.add_argument("--hdfs_dir", default=None) - - args = parser.parse_args() - - # 'lighteval/MATH' is no longer available on huggingface. - # Use mirror repo: DigitalLearningGmbH/MATH-lighteval - data_source = "DigitalLearningGmbH/MATH-lighteval" - print(f"Loading the {data_source} dataset from huggingface...", flush=True) - dataset = datasets.load_dataset(data_source, trust_remote_code=True) - - train_dataset = dataset["train"] - test_dataset = dataset["test"] - - instruction_following = "Let's think step by step and output the final answer within \\boxed{}." - - # add a row to each data item that represents a unique id - def make_map_fn(split): - def process_fn(example, idx): - question = example.pop("problem") - - question = question + " " + instruction_following - - answer = example.pop("solution") - solution = extract_solution(answer) - data = { - "data_source": data_source, - "prompt": [{"role": "user", "content": question}], - "ability": "math", - "reward_model": {"style": "rule", "ground_truth": solution}, - "extra_info": {"split": split, "index": idx}, - } - return data - - return process_fn - - train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) - test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) - - local_dir = args.local_dir - hdfs_dir = args.hdfs_dir - - train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) - test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) - - if hdfs_dir is not None: - makedirs(hdfs_dir) - - copy(src=local_dir, dst=hdfs_dir) diff --git a/examples/generation/run_deepseek7b_mutli_node.sh b/examples/generation/run_deepseek7b_mutli_node.sh deleted file mode 100644 index e939268ff..000000000 --- a/examples/generation/run_deepseek7b_mutli_node.sh +++ /dev/null @@ -1,22 +0,0 @@ -set -x - -data_path=$HOME/data/rlhf/gsm8k/test.parquet -save_path=$HOME/data/rlhf/math/deepseek_v2_lite_gen_test.parquet -model_path=deepseek-ai/deepseek-llm-7b-chat - -python3 -m verl.trainer.main_generation \ - trainer.nnodes=2 \ - trainer.n_gpus_per_node=8 \ - data.path=$data_path \ - data.prompt_key=prompt \ - data.n_samples=1 \ - data.output_path=$save_path \ - model.path=$model_path\ - +model.trust_remote_code=True \ - rollout.temperature=1.0 \ - rollout.top_k=50 \ - rollout.top_p=0.7 \ - rollout.prompt_length=2048 \ - rollout.response_length=1024 \ - rollout.tensor_model_parallel_size=16 \ - rollout.gpu_memory_utilization=0.8 diff --git a/examples/generation/run_deepseek_v2_lite_math.sh b/examples/generation/run_deepseek_v2_lite_math.sh deleted file mode 100644 index 0c5a74b1f..000000000 --- a/examples/generation/run_deepseek_v2_lite_math.sh +++ /dev/null @@ -1,22 +0,0 @@ -set -x - -data_path=$HOME/data/gsm8k/test.parquet -save_path=$HOME/data/gsm8k/deepseek_v2_lite_gen_test.parquet -model_path=deepseek-ai/deepseek-llm-7b-chat - -python3 -m verl.trainer.main_generation \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node=8 \ - data.path=$data_path \ - data.prompt_key=prompt \ - data.n_samples=1 \ - data.output_path=$save_path \ - model.path=$model_path \ - +model.trust_remote_code=True \ - rollout.temperature=1.0 \ - rollout.top_k=50 \ - rollout.top_p=0.7 \ - rollout.prompt_length=2048 \ - rollout.response_length=1024 \ - rollout.tensor_model_parallel_size=2 \ - rollout.gpu_memory_utilization=0.8 diff --git a/examples/gpg_trainer/gpg.md b/examples/gpg_trainer/gpg.md deleted file mode 100644 index b40cc83bc..000000000 --- a/examples/gpg_trainer/gpg.md +++ /dev/null @@ -1,34 +0,0 @@ -# GPG: Group Policy Gradient - -Group Policy Gradient (GPG) is a minimalist reinforcement learning (RL) method that enhances the reasoning ability of large language models without relying on supervised fine-tuning or complex tricks. GPG revisits traditional policy gradients and directly optimizes the RL objective—no surrogate losses, no KL penalties, no critic, and no reference model. Compared to GRPO, GPG is simpler, more efficient, and achieves better results on many tasks. For more details, please refer to the original paper [GPG: A Simple and Strong Reinforcement Learning Baseline for Model Reasoning -](https://arxiv.org/abs/2504.02546). - -## Key Components -- Use a corrected advantage function to improve policy gradient accuracy and training efficiency. -- By eliminating the critic and reference models, avoiding KL divergence constraints, significantly simplifies the training process compared to Group Relative Policy Optimization (GRPO) - -## Configuration -To configure GPG within the framework, use the following YAML settings. - -```yaml -algorithm: - adv_estimator: gpg -actor_rollout_ref: - actor: - policy_loss: - loss_mode: "gpg" -``` - -## Advanced Extensions -GPG is a simple and strong baseline for model reasoning. Although it avoids using KL loss in its original form, you can still use KL loss to further improve the performance. - -```yaml -algorithm: - adv_estimator: gpg -actor_rollout_ref: - actor: - use_kl_loss: True # enable kl regularization - kl_loss_coef: 0.01 - policy_loss: - loss_mode: "gpg" -``` \ No newline at end of file diff --git a/examples/gpg_trainer/run_qwen2-7b_math.sh b/examples/gpg_trainer/run_qwen2-7b_math.sh deleted file mode 100755 index 1454bf294..000000000 --- a/examples/gpg_trainer/run_qwen2-7b_math.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x - -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gpg \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.policy_loss.loss_mode=gpg \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_gpg_example_gsm8k_math' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh b/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh deleted file mode 100755 index 2317fa07d..000000000 --- a/examples/gpg_trainer/run_qwen2-7b_math_megatron.sh +++ /dev/null @@ -1,54 +0,0 @@ -set -x - -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=gpg \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.policy_loss.loss_mode=gpg \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_gpg_example_gsm8k_math' \ - trainer.experiment_name='qwen2_7b_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/README.md b/examples/grpo_trainer/README.md deleted file mode 100644 index c1df5ccfc..000000000 --- a/examples/grpo_trainer/README.md +++ /dev/null @@ -1,69 +0,0 @@ -# Group Relative Policy Optimization (GRPO) - -In reinforcement learning, classic algorithms like PPO rely on a "critic" model to estimate the value of actions, guiding the learning process. However, training this critic model can be resource-intensive. - -GRPO simplifies this process by eliminating the need for a separate critic model. Instead, it operates as follows: -- Group Sampling: For a given problem, the model generates multiple possible solutions, forming a "group" of outputs. -- Reward Assignment: Each solution is evaluated and assigned a reward based on its correctness or quality. -- Baseline Calculation: The average reward of the group serves as a baseline. -- Policy Update: The model updates its parameters by comparing each solution's reward to the group baseline, reinforcing better-than-average solutions and discouraging worse-than-average ones. - -This approach reduces computational overhead by avoiding the training of a separate value estimation model, making the learning process more efficient. For more details, refer to the original paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/pdf/2402.03300) - -## Key Components - -- No Value Function (Critic-less): unlike PPO, GRPO does not train a separate value network (critic) -- Group Sampling (Grouped Rollouts): instead of evaluating one rollout per input, GRPO generates multiple completions (responses) from the current policy for each prompt. This set of completions is referred to as a group. -- Relative Rewards: within each group, completions are scored (e.g., based on correctness), and rewards are normalized relative to the group. - -## Configuration - -Note that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior. - -Despite that many configurations start with the `ppo_` prefix, they work across different RL algorithms in verl, as the GRPO training loop is similar to that of PPO (without critic). - -![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d) - -- `actor_rollout.ref.rollout.n`: For each prompt, sample n times. Default to 1. For GRPO, please set it to a value larger than 1 for group sampling. - -- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n` - -- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers. - -- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for GRPO updates on one set of sampled trajectories for actor - -- `actor_rollout_ref.actor.clip_ratio`: The GRPO clip range. Default to 0.2 - -- `algorithm.adv_estimator`: Default is gae. Please set it to grpo instead - -- `actor_rollout_ref.actor.loss_agg_mode`: Default is "token-mean". Options include "token-mean", "seq-mean-token-sum", "seq-mean-token-mean". The original GRPO paper takes the sample-level loss (seq-mean-token-mean), which may be unstable in long-CoT scenarios. All GRPO example scripts provided in verl uses the default configuration "token-mean" for loss aggregation instead. - -Instead of adding KL penalty in the reward, GRPO regularizes by directly adding the KL divergence between the trained policy and the reference policy to the loss: - -- `actor_rollout_ref.actor.use_kl_loss`: To use kl loss in the actor. When used, we are not applying KL in the reward function. Default is False. Please set it to True for GRPO. - -- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001. - -- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html - -## Advanced Extensions - -### DrGRPO - -The work [Understanding R1-Zero-Like Training: A Critical Perspective](https://arxiv.org/pdf/2503.20783) claims there's optimization bias in GRPO, that leads to artificially longer responses, especially for incorrect outputs. This inefficiency stems from the way GRPO calculates advantages using group-based reward normalization, which can inadvertently favor longer, less accurate responses. Instead, DrGRPO aggregates token-level losses by normalizing with a global constant to eliminate length bias. - -Configure the following to enable DrGRPO, with all other parameters the same as GRPO's: - -- `actor_rollout_ref.actor.loss_agg_mode`: "seq-mean-token-sum-norm", which turns off seq-dim averaging -- `actor_rollout_ref.actor.use_kl_loss`: Please set it to False for DrGRPO -- `algorithm.norm_adv_by_std_in_grpo`: False, which turns off standard deviation norm - -## Reference Example - -Qwen2.5 GRPO training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/qwen2-7b-fsdp2.log) - -```bash -bash examples/grpo_trainer/run_qwen3-8b.sh -``` - -For more reference performance, please see https://verl.readthedocs.io/en/latest/algo/baseline.html diff --git a/examples/grpo_trainer/run_deepseek671b_math_megatron.sh b/examples/grpo_trainer/run_deepseek671b_math_megatron.sh deleted file mode 100644 index 2087570ed..000000000 --- a/examples/grpo_trainer/run_deepseek671b_math_megatron.sh +++ /dev/null @@ -1,104 +0,0 @@ -set -x - -# 0. download the config -# only need to download the `configuration_deepseek.py`, `config.json`, `tokenizer_config.json`, `tokenizer.json` and `generation_config.json` -# remove the `quantization_config` in the `config.json` -# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported - -huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json - -# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main -# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path -DIST_CKPT_PATH="" -LLM="" - - -# 2. run the script -gsm8k_train_path=/data/gsm8k/train.parquet -gsm8k_test_path=/data/gsm8k/test.parquet -train_files=$gsm8k_train_path -test_files=$gsm8k_test_path - -ALL_OFFLOAD=${ALL_OFFLOAD:-True} -COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} -COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} -COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} - -ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} - -# 512 H20(96GB) -NODES=64 -PP=16 -TP=1 -EP=32 -ETP=1 -INFER_TP=32 -# consider TP/ETP, and enable recompute if short of memory - -# full recompute -# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ -# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ -# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ - -n_resp_per_prompt=4 - -# RAY_ADDRESS='auto' ray job submit --working-dir . -- -python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=512 \ - data.max_prompt_length=2048 \ - data.max_response_length=4096 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$LLM \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.use_torch_compile=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.temperature=1.0 \ - actor_rollout_ref.rollout.top_p=1.0 \ - actor_rollout_ref.rollout.top_k=-1 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=$INFER_TP \ - algorithm.use_kl_in_reward=False \ - trainer.logger='["console","tensorboard"]' \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='dsv3-32nodes' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=$NODES \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=3 \ - +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=2 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \ - actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ - actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \ - actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ - actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \ - actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ - actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ - actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - trainer.default_local_dir=$CKPT_DIR \ - trainer.val_before_train=False \ - trainer.total_epochs=100 $@ diff --git a/examples/grpo_trainer/run_deepseek7b_llm.sh b/examples/grpo_trainer/run_deepseek7b_llm.sh deleted file mode 100644 index af9204ab1..000000000 --- a/examples/grpo_trainer/run_deepseek7b_llm.sh +++ /dev/null @@ -1,40 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_deepseek7b_llm_math.sh b/examples/grpo_trainer/run_deepseek7b_llm_math.sh deleted file mode 100644 index 198e6f4ae..000000000 --- a/examples/grpo_trainer/run_deepseek7b_llm_math.sh +++ /dev/null @@ -1,49 +0,0 @@ -set -x - - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ - trainer.experiment_name='deepseek_llm_7b_function_rm_math' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh b/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh deleted file mode 100644 index 84d59e2ee..000000000 --- a/examples/grpo_trainer/run_deepseek7b_llm_math_megatron.sh +++ /dev/null @@ -1,51 +0,0 @@ -set -x - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ - trainer.experiment_name='deepseek_llm_7b_math_megatron' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh b/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh deleted file mode 100644 index 72cd4445a..000000000 --- a/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh +++ /dev/null @@ -1,39 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm_seq_packing' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_minicpmo2_6.sh b/examples/grpo_trainer/run_minicpmo2_6.sh deleted file mode 100644 index d95622e1a..000000000 --- a/examples/grpo_trainer/run_minicpmo2_6.sh +++ /dev/null @@ -1,49 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - data.train_batch_size=128 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=False \ - data.truncation='error' \ - data.image_key=images \ - data.trust_remote_code=True \ - data.custom_cls.path=recipe/minicpmo/rl_dataset.py \ - data.custom_cls.name=RLHFDataset \ - actor_rollout_ref.model.path=openbmb/MiniCPM-o-2_6 \ - actor_rollout_ref.model.trust_remote_code=True \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.use_dynamic_bsz=False \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - +actor_rollout_ref.actor.fsdp_config.use_orig_params=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=False \ - actor_rollout_ref.rollout.free_cache_engine=False \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='minicpmo2_6_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_moonlight16b_math_megatron.sh b/examples/grpo_trainer/run_moonlight16b_math_megatron.sh deleted file mode 100644 index aebac5c18..000000000 --- a/examples/grpo_trainer/run_moonlight16b_math_megatron.sh +++ /dev/null @@ -1,53 +0,0 @@ -set -x - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files=/mnt/hdfs/gaoziyuan//data/gsm8k/train.parquet \ - data.val_files=/mnt/hdfs/gaoziyuan/data/gsm8k/test.parquet \ - data.train_batch_size=192 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.trust_remote_code=True \ - actor_rollout_ref.model.path=/mnt/hdfs/gaoziyuan/models/moonshotai/Moonlight-16B-A3B \ - actor_rollout_ref.model.trust_remote_code=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=3 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \ - actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=1 \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=/mnt/hdfs/gaoziyuan/dist_ckpt/moonshotai/Moonlight-16B-A3B \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=3 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \ - actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=1 \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=/mnt/hdfs/gaoziyuan/dist_ckpt/moonshotai/Moonlight-16B-A3B \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ - trainer.experiment_name='moonlight_megatron_ep' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=3 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2-7b.sh b/examples/grpo_trainer/run_qwen2-7b.sh deleted file mode 100644 index c32087e8c..000000000 --- a/examples/grpo_trainer/run_qwen2-7b.sh +++ /dev/null @@ -1,41 +0,0 @@ -set -x - - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2-7b_math.sh b/examples/grpo_trainer/run_qwen2-7b_math.sh deleted file mode 100644 index f4e6ec408..000000000 --- a/examples/grpo_trainer/run_qwen2-7b_math.sh +++ /dev/null @@ -1,49 +0,0 @@ -set -x - - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh deleted file mode 100644 index 0a23bab8f..000000000 --- a/examples/grpo_trainer/run_qwen2-7b_math_megatron.sh +++ /dev/null @@ -1,62 +0,0 @@ -set -x - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -rollout_mode="sync" -if [ "$rollout_mode" = "async" ]; then - export VLLM_USE_V1=1 - return_raw_chat="True" -fi - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -USE_FUSED_KERNELS=True - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.return_raw_chat=$return_raw_chat \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.model.use_fused_kernels=$USE_FUSED_KERNELS \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.mode=$rollout_mode \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ - trainer.experiment_name='qwen2_7b_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh deleted file mode 100644 index 79881a1e0..000000000 --- a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x - - -# For async rollout mode, dataset should return raw chat. -rollout_mode="async" -rollout_name="sglang" # sglang or vllm -if [ "$rollout_mode" = "async" ]; then - export VLLM_USE_V1=1 - return_raw_chat="True" -fi - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.return_raw_chat=$return_raw_chat \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=$rollout_name \ - actor_rollout_ref.rollout.mode=$rollout_mode \ - actor_rollout_ref.rollout.multi_turn.format=hermes \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm_kl1e-3' \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh deleted file mode 100644 index 54572c02d..000000000 --- a/examples/grpo_trainer/run_qwen2-7b_seq_balance_math_megatron.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ - trainer.experiment_name='qwen2_7b_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh b/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh deleted file mode 100644 index eeac388b4..000000000 --- a/examples/grpo_trainer/run_qwen2-7b_sgl_megatron.sh +++ /dev/null @@ -1,48 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh b/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh deleted file mode 100644 index 81236f621..000000000 --- a/examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh +++ /dev/null @@ -1,47 +0,0 @@ -set -x - - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.lora_rank=64 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.actor.optim.lr=3e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2.5_3b_grpo_lora' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh b/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh deleted file mode 100644 index d4a1a3fcd..000000000 --- a/examples/grpo_trainer/run_qwen2_5-7b_math_megatron_diff_tp.sh +++ /dev/null @@ -1,51 +0,0 @@ -set -x - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ - trainer.experiment_name='qwen2_7b_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh b/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh deleted file mode 100644 index 6d0d4fe4e..000000000 --- a/examples/grpo_trainer/run_qwen2_5_32b_grpo_npu.sh +++ /dev/null @@ -1,41 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6\ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_5_32b_function_rm' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=2 \ - trainer.save_freq=-1 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 \ - trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh deleted file mode 100644 index 44e94cd07..000000000 --- a/examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh +++ /dev/null @@ -1,72 +0,0 @@ -set -x - -# profiling configuration -PROFILE_STEPS="[2,4]" -PROFILE_RANKS_ALL=False -DISCRETE=True -PROFILE_RANKS="[1,2]" - -# profiling NPU options -SAVE_PATH="$HOME/profile_data" -LEVEL="level1" -WITH_MEMORY=False -RECORD_SHAPES=False -WITH_NPU=True -WITH_CPU=True -WITH_MODULE=False -WITH_STACK=False -ANALYSIS=True - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=5e-8 \ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \ - actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ - actor_rollout_ref.profiler.discrete=$DISCRETE \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.npu_profile.options.save_path=$SAVE_PATH \ - trainer.npu_profile.options.level=$LEVEL \ - trainer.npu_profile.options.with_memory=$WITH_MEMORY \ - trainer.npu_profile.options.record_shapes=$RECORD_SHAPES \ - trainer.npu_profile.options.with_npu=$WITH_NPU \ - trainer.npu_profile.options.with_cpu=$WITH_CPU \ - trainer.npu_profile.options.with_module=$WITH_MODULE \ - trainer.npu_profile.options.with_stack=$WITH_STACK \ - trainer.npu_profile.options.analysis=$ANALYSIS \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_5_7b_function_rm' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=5 \ - trainer.profile_steps=$PROFILE_STEPS \ - trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh deleted file mode 100644 index 70491c235..000000000 --- a/examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh +++ /dev/null @@ -1,70 +0,0 @@ -set -x - -# profiling configuration -PROFILE_STEPS="[2,4]" -PROFILE_RANKS_ALL=True -DISCRETE=False - -# profiling NPU options -SAVE_PATH="$HOME/profile_data" -LEVEL="level1" -WITH_MEMORY=False -RECORD_SHAPES=False -WITH_NPU=True -WITH_CPU=True -WITH_MODULE=False -WITH_STACK=False -ANALYSIS=True - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=5e-8 \ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ - actor_rollout_ref.profiler.discrete=$DISCRETE \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.npu_profile.options.save_path=$SAVE_PATH \ - trainer.npu_profile.options.level=$LEVEL \ - trainer.npu_profile.options.with_memory=$WITH_MEMORY \ - trainer.npu_profile.options.record_shapes=$RECORD_SHAPES \ - trainer.npu_profile.options.with_npu=$WITH_NPU \ - trainer.npu_profile.options.with_cpu=$WITH_CPU \ - trainer.npu_profile.options.with_module=$WITH_MODULE \ - trainer.npu_profile.options.with_stack=$WITH_STACK \ - trainer.npu_profile.options.analysis=$ANALYSIS \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_5_7b_function_rm' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=5 \ - trainer.profile_steps=$PROFILE_STEPS \ - trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh b/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh deleted file mode 100644 index 07dda340c..000000000 --- a/examples/grpo_trainer/run_qwen2_5_7b_grpo_npu.sh +++ /dev/null @@ -1,42 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=5e-8 \ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_5_7b_function_rm' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=5 \ - trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh deleted file mode 100644 index d0de1aac5..000000000 --- a/examples/grpo_trainer/run_qwen2_5_vl-7b-megatron.sh +++ /dev/null @@ -1,88 +0,0 @@ -set -x -ENGINE=${1:-vllm} -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -HF_MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct -DIST_CKPT_PATH=${DIST_CKPT_PATH} - -# convert HF model to meagatron format offlinely -# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH - - -# megatron tuning guide: -# 1. recommend to offload all states by setting ALL_OFFLOAD=True -# 2. enable dynamic batch size by setting actor_rollout_ref.actor.use_dynamic_bsz=True ref.log_prob_use_dynamic_bsz=True rollout.log_prob_use_dynamic_bsz=True -# 3. set ppo_max_token_len_per_gpu and log_prob_max_token_len_per_gpu as large as possible for better MFU (limited by GPU memory). assure ppo_max_token_len_per_gpu > max_prompt_length+max_response_length, if sequence length is too long, you can increase the TP/PP size -# 4. if memory is very limited, enable full recompute, but the mfu will be 30% lower -# full recompute settings: -# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \ -# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \ -# +actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \ - -ALL_OFFLOAD=${ALL_OFFLOAD:-True} -COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} -COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} -COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} - -ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} - - -train_path=$HOME/data/geo3k/train.parquet -test_path=$HOME/data/geo3k/test.parquet - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files="$train_path" \ - data.val_files="$test_path" \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$HF_MODEL_PATH \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=1 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=1 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ - actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ - actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='qwen2_5_vl_7b_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b.sh deleted file mode 100644 index 450390e25..000000000 --- a/examples/grpo_trainer/run_qwen2_5_vl-7b.sh +++ /dev/null @@ -1,46 +0,0 @@ -set -x -ENGINE=${1:-vllm} - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.image_key=images \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=False \ - actor_rollout_ref.rollout.free_cache_engine=True \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh deleted file mode 100644 index b00ad8087..000000000 --- a/examples/grpo_trainer/run_qwen2_5_vl-7b_lora.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x -ENGINE=${1:-vllm} -# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: -# export VLLM_ATTENTION_BACKEND=XFORMERS - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.image_key=images \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=3e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \ - actor_rollout_ref.model.lora_rank=64 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.model.exclude_modules='.*visual.*' \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=False \ - actor_rollout_ref.rollout.free_cache_engine=False \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh deleted file mode 100644 index e9933b106..000000000 --- a/examples/grpo_trainer/run_qwen2_5_vl-7b_seq_balance.sh +++ /dev/null @@ -1,45 +0,0 @@ -set -x -ENGINE=${1:-vllm} - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.image_key=images \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=6144 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=False \ - actor_rollout_ref.rollout.free_cache_engine=False \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=6144 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh deleted file mode 100644 index ef1301126..000000000 --- a/examples/grpo_trainer/run_qwen2_5_vl_32b_npu.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x -ENGINE=${1:-vllm} - -# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, -# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. -export USE_OPTIMIZED_MODEL=0 - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.image_key=images \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-32B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.use_torch_compile=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=True \ - actor_rollout_ref.rollout.free_cache_engine=True \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='qwen2_5_vl_32b_function_rm' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=2 \ - trainer.save_freq=-1 \ - trainer.test_freq=-1 \ - trainer.total_epochs=15 \ - trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh deleted file mode 100644 index b319dee99..000000000 --- a/examples/grpo_trainer/run_qwen2_5_vl_3b_npu.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x -ENGINE=${1:-vllm} - -# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, -# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. -export USE_OPTIMIZED_MODEL=0 - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.image_key=images \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=16 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.use_torch_compile=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=True \ - actor_rollout_ref.rollout.free_cache_engine=True \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='qwen2_5_vl_3b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=-1 \ - trainer.total_epochs=15 \ - trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh b/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh deleted file mode 100644 index 913da5424..000000000 --- a/examples/grpo_trainer/run_qwen2_5_vl_7b_npu.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x -ENGINE=${1:-vllm} - -# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, -# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. -export USE_OPTIMIZED_MODEL=0 - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.image_key=images \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=32 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.use_torch_compile=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=True \ - actor_rollout_ref.rollout.free_cache_engine=True \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='qwen2_5_vl_7b_function_rm' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=-1 \ - trainer.total_epochs=15 \ - trainer.device=npu $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen3-236b_megatron.sh b/examples/grpo_trainer/run_qwen3-236b_megatron.sh deleted file mode 100644 index 7c3f741db..000000000 --- a/examples/grpo_trainer/run_qwen3-236b_megatron.sh +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -# Note that we set the response length to 4k. This results in many truncations at the beginning. -# So the training dynamic acts as using RL to compress the math capabilities of QWen3 236b into 4k response instead of verbose thinking. -# We can achieve 0.5 on AIME'24 after 30 steps. - -project_name='DAPO' -exp_name='DAPO-Qwen3-236b-megatron-0531a1' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 4)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=0.1 - -loss_agg_mode="token-mean" - -train_prompt_bsz=256 -n_resp_per_prompt=4 -train_prompt_mini_bsz=16 - -# H20 GPUs -NNODES=${NNODES:-32} - -# Paths - -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} - -MODEL_PATH=$RAY_DATA_HOME/models/Qwen3-235B-A22B -MCORE_MODEL_PATH=$RAY_DATA_HOME/models/Qwen3-235B-A22B_dist_ckpt_mcore/ - -# convert QWen3-235b-A22b to dist ckpt of mcore. Conversion process will take about 4 hours -# python scripts/converter_hf_to_mcore.py --hf_model_path $MODEL_PATH --output_path $MCORE_MODEL_PATH --use_cpu_initialization -CKPTS_DIR=$RAY_DATA_HOME/ckpt/${project_name}/${exp_name} -TRAIN_FILE=$RAY_DATA_HOME/dataset/dapo-math-17k.parquet -TEST_FILE=$RAY_DATA_HOME/dataset/aime-2024.parquet - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - -# Performance Related Parameter -offload=True -gen_tp=8 -train_tp=4 -train_ep=4 -train_pp=8 - -python3 -m verl.trainer.main_ppo \ - --config-path=config \ - --config-name='ppo_megatron_trainer.yaml' \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.megatron.param_offload=${offload} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ - actor_rollout_ref.actor.megatron.grad_offload=${offload} \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=5 \ - +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=5 \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.optim.clip_grad=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \ - actor_rollout_ref.ref.megatron.param_offload=${offload} \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - reward_model.reward_manager=dapo \ - +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ - +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ - trainer.test_freq=10 \ - trainer.save_freq=20 \ - trainer.total_epochs=10 \ - trainer.total_training_steps=100 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ - trainer.log_val_generations=10 diff --git a/examples/grpo_trainer/run_qwen3-8b.sh b/examples/grpo_trainer/run_qwen3-8b.sh deleted file mode 100644 index a99b432d6..000000000 --- a/examples/grpo_trainer/run_qwen3-8b.sh +++ /dev/null @@ -1,43 +0,0 @@ -# Tested successfully on the hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0 image. -# It outperforms the Qwen2 7B base model by two percentage points on the test set of GSM8K. - -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen3-8B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen3_8b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh b/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh deleted file mode 100644 index 49d5eb999..000000000 --- a/examples/grpo_trainer/run_qwen3moe-30b_megatron.sh +++ /dev/null @@ -1,54 +0,0 @@ -set -x - -HF_MODEL_PATH=Qwen/Qwen3-30B-A3B -DIST_CKPT_PATH=${DIST_CKPT_PATH} - -python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=64 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$HF_MODEL_PATH \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.actor.megatron.expert_model_parallel_size=4 \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=4 \ - actor_rollout_ref.ref.megatron.expert_model_parallel_size=4 \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k_math' \ - trainer.experiment_name='qwen3_30b_moe_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=4 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/README.md b/examples/ppo_trainer/README.md deleted file mode 100644 index f4df70f9a..000000000 --- a/examples/ppo_trainer/README.md +++ /dev/null @@ -1,103 +0,0 @@ -# Proximal Policy Optimization (PPO) - -Proximal Policy Optimization (PPO) is a family of policy gradient methods for reinforcement learning, proposed by OpenAI in 2017. PPO strikes a balance between simplicity, stability, and performance, making it one of the most widely used algorithms in modern RL applications, including large-scale language model fine-tuning. - -Traditional policy gradient methods like REINFORCE or Vanilla Policy Gradient suffer from: - -- High variance and sample inefficiency. -- Instability due to large policy updates. - -PPO addresses this problem using a clipped surrogate objective that avoids overly large updates without requiring second-order derivatives. - -For more technical details regarding PPO, we suggest reading the introduction in the [OpenAI spinning up tutorial](https://spinningup.openai.com/en/latest/algorithms/ppo.html), and the paper [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347). - -## Key Components - -- Actor-Critic Architecture: PPO requires both an actor model (policy) and a critic model (value function). This differs from other algorithms like GRPO and RLOO that don't require a critic model. - -- Generalized Advantage Estimation (GAE): PPO uses GAE for computing advantage values, which helps reduce variance in policy gradient estimates while maintaining low bias. - -- Clipped Surrogate Objective: The core of PPO is implemented through the clipped surrogate objective function that limits policy updates. - -## Configuration - -Note that all configs containing `micro_batch_size` are used to configure the maximum sample or token count per forward or backward pass to avoid GPU OOMs, whose value should not change algorithmic/convergence behavior. - -Most critic configs are similar to those of actors. Note that the critic model is omitted from the figure below. - -![image](https://github.com/user-attachments/assets/16aebad1-0da6-4eb3-806d-54a74e712c2d) - -- `data.train_batch_size`: The global batch size of prompts used to generate a set of sampled trajectories/rollouts. The number of responses/trajectories is `data.train_batch_size * actor_rollout.ref.rollout.n` - -- `actor_rollout_ref.actor.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO actor updates. The ppo_mini_batch_size is a global size across all workers - -- `actor_rollout_ref.critic.ppo_mini_batch_size`: The set of sampled trajectories is split into multiple mini-batches with batch_size=ppo_mini_batch_size for PPO critic updates. The ppo_mini_batch_size is a global size across all workers - -- `actor_rollout_ref.actor.clip_ratio`: The PPO clip range. Default to 0.2 - -- `actor_rollout_ref.actor.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for actor - -- `critic.ppo_epochs`: Number of epochs for PPO updates on one set of sampled trajectories for critic. Defaults to `actor_rollout_ref.actor.ppo_epochs` - -- `algorithm.gamma`: discount factor - -- `algorithm.lam`: The lambda term that trades off between bias and variance in the GAE estimator - -- `algorithm.adv_estimator`: Support gae, grpo, reinforce_plus_plus, reinforce_plus_plus_baseline, rloo - -## Advanced Extensions - -### KL Divergence Control - -Options to prevent the policy from diverging too far from a reference policy. Two mechanisms are available: KL reward penalty and KL loss. For more technical details, see [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) - -Options to use KL loss for KL divergence control: - -- `actor_rollout_ref.actor.use_kl_loss`: to use kl loss in the actor. When used, we are not applying KL in the reward function. Default is False - -- `actor_rollout_ref.actor.kl_loss_coef`: The coefficient of kl loss. Default is 0.001. - -- `actor_rollout_ref.actor.kl_loss_type`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. How to calculate the kl divergence between actor and reference policy. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html - -Options to use KL penalty in the reward: - -- `algorithm.use_kl_in_reward`: Whether to enable in-reward kl penalty. Default is False. - -- `algorithm.kl_penalty`: Support kl(k1), abs, mse(k2), low_var_kl(k3) and full. This defines the way to calculate the kl divergence between actor and reference policy. For specific options, refer to `kl_penalty` in core_algos.py. See this blog post for detailed analysis: http://joschu.net/blog/kl-approx.html - -- `algorithm.kl_ctrl.kl_coef`: The (initial) coefficient of in-reward kl_penalty. Default is 0.001. -- `algorithm.kl_ctrl.type`: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController. -- `algorithm.kl_ctrl.horizon`: See source code of AdaptiveKLController for details. -- `algorithm.kl_ctrl.target_kl`: See source code of AdaptiveKLController for details. - -### Dual-clip PPO - -The Dual-Clip PPO introduces a approach by applying a lower bound to the policy ratio when the advantage is less than zero, when multiplied by a large raito, does not exceed a specified lower bound. - -![image](https://github.com/user-attachments/assets/fc232181-d8b0-4307-8dd2-4dc0a4c1c139) - -- `actor_rollout_ref.actor.clip_ratio_c`: lower bound of the value for Dual-clip PPO, defaults to 3.0 - -## Reference Example - -Qwen2.5 training log and commands: [link](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) - -```bash -bash run_gemma.sh - trainer.n_gpus_per_node=1 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - trainer.logger=console \ - critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - data.train_batch_size=256 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size=2 \ - critic.ppo_micro_batch_size=2 -``` - -Reference performance with verl v0.2: - -| Model | Method | Score | Link | -|-------------------------------|------------------|-------|------------------------------------------------------------------------------------------------| -| Qwen/Qwen2.5-0.5B-Instruct | pretrained model | 36.4 | [Qwen Blog](https://qwenlm.github.io/blog/qwen2.5-llm/) | -| Qwen/Qwen2.5-0.5B-Instruct | PPO | 56.7 | [PPO Command and Logs](https://github.com/eric-haibin-lin/verl-data/blob/experiments/gsm8k/Qwen2.5-0.5B-bsz256_2-prompt1024-resp512-0.567.log) | diff --git a/examples/ppo_trainer/run_deepseek7b_llm.sh b/examples/ppo_trainer/run_deepseek7b_llm.sh deleted file mode 100644 index 01e4a24a1..000000000 --- a/examples/ppo_trainer/run_deepseek7b_llm.sh +++ /dev/null @@ -1,41 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size_per_gpu=32 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh b/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh deleted file mode 100644 index eb6dc7923..000000000 --- a/examples/ppo_trainer/run_deepseek7b_llm_modelscope.sh +++ /dev/null @@ -1,42 +0,0 @@ -set -x - -VERL_USE_MODELSCOPE=True \ -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size_per_gpu=32 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh b/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh deleted file mode 100644 index 312c6b50b..000000000 --- a/examples/ppo_trainer/run_deepseek7b_llm_pfppo.sh +++ /dev/null @@ -1,45 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - algorithm.use_pf_ppo=True \ - algorithm.pf_ppo.reweight_method=pow \ # ["pow", "max_min", "max_random"] - algorithm.pf_ppo.weight_pow=2.0 \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.rollout.n=5 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size_per_gpu=32 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh b/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh deleted file mode 100644 index 69ee7b8bd..000000000 --- a/examples/ppo_trainer/run_deepseek7b_llm_sandbox_fusion.sh +++ /dev/null @@ -1,44 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - reward_model.sandbox_fusion.url='https://xxxxxxxxx.apigateway-cn-beijing.volceapi.com/run_code' \ - reward_model.sandbox_fusion.max_concurrent=128 \ - reward_model.reward_manager=prime \ - algorithm.adv_estimator=gae \ - data.train_files=$HOME/data/Eurus-2-RL-Data/train.parquet \ - data.val_files=$HOME/data/Eurus-2-RL-Data/validation.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size_per_gpu=32 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_sandbox_fusion' \ - trainer.experiment_name='deepseek_llm_7b_function_sandbox_fusion' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=1 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh b/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh deleted file mode 100644 index 3cb8a852b..000000000 --- a/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh +++ /dev/null @@ -1,43 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=64 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - critic.optim.lr=1e-5 \ - critic.ulysses_sequence_parallel_size=2 \ - critic.model.use_remove_padding=True \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size_per_gpu=64 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm_sp2' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh deleted file mode 100644 index 976641f13..000000000 --- a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh +++ /dev/null @@ -1,42 +0,0 @@ -set -x - -train_files=$HOME/data/full_hh_rlhf/rl/train.parquet -test_files=$HOME/data/full_hh_rlhf/rl/train.parquet # no use - -python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=512 \ - data.max_prompt_length=128 \ - data.max_response_length=128 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - critic.optim.lr=1e-5 \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - reward_model.enable=True \ - reward_model.megatron.tensor_model_parallel_size=4 \ - reward_model.model.path=deepseek-ai/deepseek-llm-7b-chat \ - reward_model.micro_batch_size_per_gpu=4 \ - reward_model.param_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_megatron_full_hh_rlhf_examples' \ - trainer.experiment_name='deepseek_llm_7b_model_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=100 $@ diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh deleted file mode 100644 index c747b573f..000000000 --- a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh +++ /dev/null @@ -1,50 +0,0 @@ -set -x - -# Example runnable on H20 * 8 - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=1e-5 \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_ppo_gsm8k_math_examples' \ - trainer.experiment_name='deepseek_llm_7b_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=100 $@ diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh deleted file mode 100644 index 9cbbade33..000000000 --- a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh +++ /dev/null @@ -1,64 +0,0 @@ -set -x - -# Example runnable on H20 * 8 - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files=${train_files:-"$gsm8k_train_path"} -test_files=${test_files:-"$gsm8k_test_path"} - -# Nsight profiling configuration -PROFILE_STEPS="[1,2,5]" # or [] or null -PROFILE_RANKS_ALL=False # or True -PROFILE_RANKS=[0,4,8,12] -DISCRETE=True # or True - -python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=256 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \ - actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ - actor_rollout_ref.profiler.discrete=$DISCRETE \ - critic.optim.lr=1e-5 \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.profiler.ranks=$PROFILE_RANKS \ - critic.profiler.all_ranks=$PROFILE_RANKS_ALL \ - critic.profiler.discrete=$DISCRETE \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_ppo_gsm8k_math_examples' \ - trainer.experiment_name='deepseek_llm_7b_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=2 \ - trainer.save_freq=-1 \ - trainer.test_freq=-1 \ - trainer.total_epochs=100 \ - trainer.total_training_steps=6 \ - trainer.profile_steps=$PROFILE_STEPS $@ diff --git a/examples/ppo_trainer/run_gemma.sh b/examples/ppo_trainer/run_gemma.sh deleted file mode 100644 index b015275c1..000000000 --- a/examples/ppo_trainer/run_gemma.sh +++ /dev/null @@ -1,40 +0,0 @@ -set -x - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=google/gemma-2-2b-it \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=False \ - critic.model.path=google/gemma-2-2b-it \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example' \ - trainer.experiment_name='gemma2b_function_rm' \ - trainer.n_gpus_per_node=2 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh b/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh deleted file mode 100644 index 64bdbb727..000000000 --- a/examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh +++ /dev/null @@ -1,107 +0,0 @@ -set -x - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - - -# 0. download the model -huggingface-cli download moonshotai/Moonlight-16B-A3B-Instruct - -# 1. convert the model to mcore format -# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path -HF_MODEL_PATH=/data/models/moonshotai/Moonlight-16B-A3B-Instruct -DIST_CKPT_PATH=/data/mcore_ckpt/Moonlight-16B-A3B-Instruct -python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH - - -# 2. run the script -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -train_files=$gsm8k_train_path -test_files=$gsm8k_test_path - -ALL_OFFLOAD=${ALL_OFFLOAD:-False} -COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} -COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} -COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} - -ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} - - -NODES=4 -PP=2 -TP=8 -EP=8 -ETP=1 -VLLM_TP=4 - -# RAY_ADDRESS='auto' ray job submit --working-dir . -- -python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.trust_remote_code=True \ - actor_rollout_ref.model.path=$LLM \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - critic.optim.lr=1e-5 \ - critic.model.path=$LLM \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='moonlight_16b_a3b_instruct_1node' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=$NODES \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - actor_rollout_ref.model.trust_remote_code=True \ - critic.model.trust_remote_code=True \ - +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=13 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \ - critic.megatron.pipeline_model_parallel_size=$PP \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \ - critic.megatron.tensor_model_parallel_size=$TP \ - actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \ - actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \ - critic.megatron.expert_model_parallel_size=$EP \ - actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \ - actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \ - critic.megatron.expert_tensor_parallel_size=$ETP \ - actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ - actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ - actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ - critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \ - critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \ - critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - critic.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - trainer.val_before_train=False \ - trainer.total_epochs=100 $@ - \ No newline at end of file diff --git a/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh deleted file mode 100644 index accdd7f65..000000000 --- a/examples/ppo_trainer/run_qwen1.5_moe_a2.7b-gsm8k_megatron.sh +++ /dev/null @@ -1,74 +0,0 @@ -set -x - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -# 0. download the model -huggingface-cli download Qwen/Qwen1.5-MoE-A2.7B-Chat - -# 1. convert the model to mcore format -# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path -HF_MODEL_PATH=/data/models/Qwen/Qwen1.5-MoE-A2.7B-Chat -DIST_CKPT_PATH=/data/mcore_ckpt/Qwen1.5-MoE-A2.7B-Chat -python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH - -# 2. run the script -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -train_files=$gsm8k_train_path -test_files=$gsm8k_test_path - -NODES=4 -PP=2 -TP=4 -CP=1 -VLLM_TP=4 - -# RAY_ADDRESS='auto' ray job submit --working-dir . -- -python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$HF_MODEL_PATH \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \ - actor_rollout_ref.actor.megatron.context_parallel_size=$CP \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \ - actor_rollout_ref.ref.megatron.context_parallel_size=$CP \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \ - critic.optim.lr=1e-5 \ - critic.model.path=$HF_MODEL_PATH \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - critic.megatron.tensor_model_parallel_size=$TP \ - critic.megatron.pipeline_model_parallel_size=$PP \ - critic.megatron.context_parallel_size=$CP \ - critic.megatron.use_dist_checkpointing=True \ - critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_megatron_gsm8k_examples' \ - trainer.experiment_name='qwen1.5_moe_nochat' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=$NODES \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=100 $@ - \ No newline at end of file diff --git a/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh deleted file mode 100644 index 22558c62b..000000000 --- a/examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh +++ /dev/null @@ -1,48 +0,0 @@ -set -x - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - critic.optim.lr=1e-5 \ - critic.model.path=Qwen/Qwen2-7B-Instruct \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=4 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_ppo_gsm8k_math_examples' \ - trainer.experiment_name='qwen2_7b_megatron' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=100 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm.sh b/examples/ppo_trainer/run_qwen2-7b_rm.sh deleted file mode 100644 index 98b305844..000000000 --- a/examples/ppo_trainer/run_qwen2-7b_rm.sh +++ /dev/null @@ -1,71 +0,0 @@ -# Discliamer: the model used in the script is only for academic purpose. -set -x - -# Data preparation scripts are available in ``examples/data_preprocess``. -# Example usage: -# -# python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math -# python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - - -# prepare model ckpt -huggingface-cli download Qwen/Qwen2-7B-Instruct --local-dir $HOME/models/Qwen2-7B-Instruct & -huggingface-cli download sfairXC/FsfairX-LLaMA3-RM-v0.1 --local-dir $HOME/models/FsfairX-LLaMA3-RM-v0.1 & -wait - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path="$HOME/models/Qwen2-7B-Instruct" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.optim.lr_warmup_steps_ratio=0.05 \ - critic.model.path="$HOME/models/Qwen2-7B-Instruct" \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size_per_gpu=32 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - reward_model.enable=True \ - reward_model.model.path="$HOME/models/FsfairX-LLaMA3-RM-v0.1" \ - reward_model.model.use_remove_padding=True \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size_per_gpu=32 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example' \ - trainer.val_before_train=False \ - trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh deleted file mode 100644 index e0ddc01e7..000000000 --- a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh +++ /dev/null @@ -1,60 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=4096 \ - data.max_prompt_length=4096 \ - data.max_response_length=4096 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=512 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=Qwen/Qwen2-7B-Instruct \ - critic.model.enable_gradient_checkpointing=True \ - critic.use_dynamic_bsz=True \ - critic.ppo_max_token_len_per_gpu=98304 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - reward_model.enable=True \ - reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ - reward_model.model.use_remove_padding=True \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size_per_gpu=32 \ - reward_model.use_dynamic_bsz=True \ - reward_model.forward_max_token_len_per_gpu=98304 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \ - trainer.n_gpus_per_node=8 \ - trainer.val_before_train=False \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh deleted file mode 100644 index 7e0a335ef..000000000 --- a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh +++ /dev/null @@ -1,64 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -FUSED_KERNEL_BACKEND=triton # or 'torch' for torch backend - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=4096 \ - data.max_prompt_length=4096 \ - data.max_response_length=4096 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.model.fused_kernel_options.impl_backend=$FUSED_KERNEL_BACKEND \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=512 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=Qwen/Qwen2-7B-Instruct \ - critic.model.enable_gradient_checkpointing=True \ - critic.use_dynamic_bsz=True \ - critic.ppo_max_token_len_per_gpu=98304 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - reward_model.enable=True \ - reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ - reward_model.model.use_remove_padding=True \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size_per_gpu=32 \ - reward_model.use_dynamic_bsz=True \ - reward_model.forward_max_token_len_per_gpu=98304 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing_fused_kernel' \ - trainer.n_gpus_per_node=8 \ - trainer.val_before_train=False \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh deleted file mode 100644 index 4173d02ea..000000000 --- a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh +++ /dev/null @@ -1,78 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files=${train_files:-"$gsm8k_train_path"} -test_files=${test_files:-"$gsm8k_test_path"} - -PROFILE_STEPS="[1,2,5]" # or [] or null -PROFILE_RANKS_ALL=False # or True -PROFILE_RANKS=[0,4,8,12] -DISCRETE=True # or True - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=4096 \ - data.max_prompt_length=4096 \ - data.max_response_length=4096 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=512 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ - actor_rollout_ref.profiler.ranks=$PROFILE_RANKS \ - actor_rollout_ref.profiler.all_ranks=$PROFILE_RANKS_ALL \ - actor_rollout_ref.profiler.discrete=$DISCRETE \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=Qwen/Qwen2-7B-Instruct \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_micro_batch_size_per_gpu=2 \ - critic.use_dynamic_bsz=True \ - critic.ppo_max_token_len_per_gpu=98304 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - critic.profiler.ranks=$PROFILE_RANKS \ - critic.profiler.all_ranks=$PROFILE_RANKS_ALL \ - critic.profiler.discrete=$DISCRETE \ - reward_model.enable=True \ - reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ - reward_model.model.use_remove_padding=True \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size_per_gpu=32 \ - reward_model.use_dynamic_bsz=True \ - reward_model.forward_max_token_len_per_gpu=98304 \ - reward_model.profiler.ranks=$PROFILE_RANKS \ - reward_model.profiler.all_ranks=$PROFILE_RANKS_ALL \ - reward_model.profiler.discrete=$DISCRETE \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing' \ - trainer.n_gpus_per_node=8 \ - trainer.val_before_train=False \ - trainer.nnodes=2 \ - trainer.save_freq=-1 \ - trainer.test_freq=-1 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=6 \ - trainer.profile_steps=$PROFILE_STEPS $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh deleted file mode 100644 index 9717e5f94..000000000 --- a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh +++ /dev/null @@ -1,60 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -# For async rollout mode, dataset should return raw chat. -rollout_mode="sync" -if [ "$rollout_mode" = "async" ]; then - return_raw_chat="True" -fi - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.return_raw_chat=$return_raw_chat \ - data.train_batch_size=4096 \ - data.max_prompt_length=4096 \ - data.max_response_length=4096 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=512 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.mode=$rollout_mode \ - actor_rollout_ref.rollout.multi_turn.format=hermes \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=Qwen/Qwen2-7B-Instruct \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_max_token_len_per_gpu=98304 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \ - trainer.n_gpus_per_node=8 \ - trainer.val_before_train=False \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh deleted file mode 100644 index 5108e8b5d..000000000 --- a/examples/ppo_trainer/run_qwen2-7b_sglang_seq_balance.sh +++ /dev/null @@ -1,51 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=4096 \ - data.max_prompt_length=4096 \ - data.max_response_length=4096 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=512 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=Qwen/Qwen2-7B-Instruct \ - critic.model.enable_gradient_checkpointing=True \ - critic.ppo_max_token_len_per_gpu=98304 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \ - trainer.n_gpus_per_node=8 \ - trainer.val_before_train=False \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2.5-32b.sh b/examples/ppo_trainer/run_qwen2.5-32b.sh deleted file mode 100644 index 580376585..000000000 --- a/examples/ppo_trainer/run_qwen2.5-32b.sh +++ /dev/null @@ -1,50 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ - actor_rollout_ref.model.enable_gradient_checkpointing=False \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding=True \ - critic.model.path=Qwen/Qwen2.5-32B-Instruct \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=8 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example' \ - trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=4 \ - trainer.save_freq=20 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 $@ diff --git a/examples/ray/tutorial.ipynb b/examples/ray/tutorial.ipynb deleted file mode 100644 index ca176af0f..000000000 --- a/examples/ray/tutorial.ipynb +++ /dev/null @@ -1,963 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0ddc582b", - "metadata": {}, - "source": [ - "# VeRL Ray API Tutorial" - ] - }, - { - "cell_type": "markdown", - "id": "71fe3b94", - "metadata": {}, - "source": [ - "## Chapter 1: Ray Basics" - ] - }, - { - "cell_type": "code", - "execution_count": 144, - "id": "1347d381", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import os" - ] - }, - { - "cell_type": "code", - "execution_count": 145, - "id": "e75b9d44", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "import ray\n", - "import torch\n", - "\n", - "warnings.filterwarnings(\"ignore\")" - ] - }, - { - "cell_type": "code", - "execution_count": 146, - "id": "2e90ae00", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-11-01 17:27:19,132\tINFO worker.py:1752 -- Started a local Ray instance.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9cc9d2ccbdfb48918c8fd6cd13a0807a", - "version_major": 2, - "version_minor": 0 - }, - "text/html": [ - "
\n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
Python version:3.9.2
Ray version:2.10.0
\n", - "\n", - "
\n", - "
\n" - ], - "text/plain": [ - "RayContext(dashboard_url='', python_version='3.9.2', ray_version='2.10.0', ray_commit='09abba26b5bf2707639bb637c208d062a47b46f6')" - ] - }, - "execution_count": 146, - "metadata": {}, - "output_type": "execute_result" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[36m(GPUAccumulator pid=224400)\u001b[0m rank 0, value: tensor([1.], device='cuda:0')\n", - "\u001b[36m(GPUAccumulator pid=225234)\u001b[0m rank 2, value: tensor([3.], device='cuda:0')\n", - "\u001b[36m(GPUAccumulator pid=225607)\u001b[0m rank 0, value: tensor([2.], device='cuda:0')\n", - "\u001b[36m(GPUAccumulator pid=226423)\u001b[0m rank 1, value: tensor([3.], device='cuda:0')\n", - "\u001b[36m(GPUAccumulator pid=226857)\u001b[0m rank 3, value: tensor([6.], device='cuda:0')\n", - "\u001b[36m(GPUAccumulatorDecorator pid=227475)\u001b[0m 10\n", - "\u001b[36m(GPUAccumulatorDecorator pid=227475)\u001b[0m rank 0, value: tensor([10.], device='cuda:0')\n", - "\u001b[36m(GPUAccumulatorDecorator pid=227655)\u001b[0m rank 1, value: tensor([11.], device='cuda:0')\n" - ] - } - ], - "source": [ - "# Build a local ray cluster. The head node and worker node are on this machine\n", - "ray.init()" - ] - }, - { - "cell_type": "markdown", - "id": "a127e4e4", - "metadata": {}, - "source": [ - "Implement an Accumulator class." - ] - }, - { - "cell_type": "code", - "execution_count": 147, - "id": "20e7b9a3", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "@ray.remote\n", - "class Accumulator:\n", - " def __init__(self):\n", - " self.value = 0\n", - "\n", - " def add(self, x):\n", - " self.value += x\n", - "\n", - " def get_value(self):\n", - " return self.value" - ] - }, - { - "cell_type": "code", - "execution_count": 148, - "id": "3b80098c", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Instantiate an accumulator. Accumulator can be viewed as a process, acting as an RPC service.\n", - "accumulator = Accumulator.remote()" - ] - }, - { - "cell_type": "code", - "execution_count": 149, - "id": "b14b1009", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0\n" - ] - } - ], - "source": [ - "value_ref = accumulator.get_value.remote() # Check the current value. Note that this function returns immediately and does not actually wait for the remote execution to complete.\n", - "# Get the value\n", - "value = ray.get(value_ref)\n", - "print(value)" - ] - }, - { - "cell_type": "code", - "execution_count": 150, - "id": "513a84b3", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "10\n" - ] - } - ], - "source": [ - "# Accumulate, then check the result.\n", - "accumulator.add.remote(10) # Similarly, the 'add' here will return immediately.\n", - "new_value = ray.get(accumulator.get_value.remote())\n", - "print(new_value)" - ] - }, - { - "cell_type": "markdown", - "id": "3c332fe0", - "metadata": {}, - "source": [ - "## Chapter 2: Resource Pool and RayWorkerGroup\n", - "In the previous example, it was a simple single-process worker. \n", - "In this example, we implement a worker with a GPU and form a RayWorkerGroup. Within this RayWorkerGroup, we implement a simple operation of an accumulator." - ] - }, - { - "cell_type": "code", - "execution_count": 151, - "id": "04229afb", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from verl.single_controller.base import Worker\n", - "from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool" - ] - }, - { - "cell_type": "code", - "execution_count": 152, - "id": "0d0dbd58", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "resource_pool = RayResourcePool([4], use_gpu=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 153, - "id": "68f6838a", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "@ray.remote\n", - "class GPUAccumulator(Worker):\n", - " def __init__(self) -> None:\n", - " super().__init__()\n", - " # The initial value of each rank is the same as the rank\n", - " self.value = torch.zeros(size=(1,), device=\"cuda\") + self.rank\n", - "\n", - " def add(self, x):\n", - " self.value += x\n", - " print(f\"rank {self.rank}, value: {self.value}\")\n", - " return self.value.cpu()" - ] - }, - { - "cell_type": "code", - "execution_count": 154, - "id": "23aad8fe", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[tensor([1.]), tensor([2.]), tensor([3.]), tensor([4.])]\n" - ] - } - ], - "source": [ - "# Each worker's initial value is its rank, and then each rank's value is incremented by 1, so the values obtained on each rank are [1, 2, 3, 4]\n", - "class_with_args = RayClassWithInitArgs(cls=GPUAccumulator)\n", - "worker_group = RayWorkerGroup(resource_pool, class_with_args)\n", - "print(worker_group.execute_all_sync(\"add\", x=[1, 1, 1, 1]))" - ] - }, - { - "cell_type": "markdown", - "id": "e6705284", - "metadata": {}, - "source": [ - "The principle of parameter passing: The input parameter is a list of length world_size, where each element in the list is dispatched respectively to each worker in the RayWorkerGroup. \n", - "The return parameter is also a list, corresponding to the return value of each worker." - ] - }, - { - "cell_type": "markdown", - "id": "d25c2412", - "metadata": {}, - "source": [ - "### GPU Resource Sharing" - ] - }, - { - "cell_type": "markdown", - "id": "f74f6d24", - "metadata": {}, - "source": [ - "RayWorkerGroups mapped to the same resource pool share the GPU. In this example, we implement three resource pools: the first occupies 4 GPUs, the second also occupies 4 GPUs, and the last occupies all 8 GPUs. Among them, the first resource pool reuses the resource pool mentioned above." - ] - }, - { - "cell_type": "code", - "execution_count": 155, - "id": "49f9c06f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Create a new resource pool and then merge the newly created resource pool with the previous one.\n", - "resource_pool_1 = RayResourcePool([4], use_gpu=True, name_prefix=\"a\")\n", - "resource_pool_merge = merge_resource_pool(resource_pool, resource_pool_1)" - ] - }, - { - "cell_type": "code", - "execution_count": 156, - "id": "05c2e305", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Establish a RayWorkerGroup on the newly created resource pool.\n", - "worker_group_1 = RayWorkerGroup(resource_pool_1, class_with_args)\n", - "worker_group_merge = RayWorkerGroup(resource_pool_merge, class_with_args)" - ] - }, - { - "cell_type": "code", - "execution_count": 157, - "id": "6b9b13f4", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[tensor([2.]), tensor([3.]), tensor([4.]), tensor([5.])]\n" - ] - } - ], - "source": [ - "# Run 'add' on the second set of 4 GPUs; the result should be [2, 3, 4, 5].\n", - "output_1 = worker_group_1.execute_all_sync(\"add\", x=[2, 2, 2, 2])\n", - "print(output_1)" - ] - }, - { - "cell_type": "code", - "execution_count": 158, - "id": "d856d030", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[tensor([3.]), tensor([4.]), tensor([5.]), tensor([6.]), tensor([7.]), tensor([8.]), tensor([9.]), tensor([10.])]\n" - ] - } - ], - "source": [ - "# Run 'add' on the merged set of 8 GPUs; the result should be [3, 4, 5, 6, 7, 8, 9, 10].\n", - "output_merge = worker_group_merge.execute_all_sync(\"add\", x=[3, 3, 3, 3, 3, 3, 3, 3])\n", - "print(output_merge)" - ] - }, - { - "cell_type": "code", - "execution_count": 159, - "id": "33a4628c", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4 4 8\n" - ] - } - ], - "source": [ - "print(worker_group.world_size, worker_group_1.world_size, worker_group_merge.world_size)" - ] - }, - { - "cell_type": "markdown", - "id": "3df19d13", - "metadata": {}, - "source": [ - "## Chapter 3: Data Dispatch, Execution and Collection" - ] - }, - { - "cell_type": "markdown", - "id": "acb22d9d", - "metadata": {}, - "source": [ - "In the above example, we used the `execute_all_sync` function in the RayWorkerGroup to dispatch data from the driver to each worker. This is very inconvenient for coding. \n", - "In this chapter, we use the form of function decorators to allow RayWorkerGroup to directly call functions written in the Worker, and to greatly simplify parameter passing." - ] - }, - { - "cell_type": "code", - "execution_count": 160, - "id": "35237432", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from verl.single_controller.base.decorator import Dispatch, Execute, register" - ] - }, - { - "cell_type": "code", - "execution_count": 161, - "id": "88b8ba3b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "@ray.remote\n", - "class GPUAccumulatorDecorator(Worker):\n", - " def __init__(self) -> None:\n", - " super().__init__()\n", - " # The initial value of each rank is the same as the rank\n", - " self.value = torch.zeros(size=(1,), device=\"cuda\") + self.rank\n", - "\n", - " # map from a single input to all the worker\n", - " @register(Dispatch.ONE_TO_ALL)\n", - " def add(self, x):\n", - " print(x)\n", - " self.value = self.value + x\n", - " print(f\"rank {self.rank}, value: {self.value}\")\n", - " return self.value.cpu()" - ] - }, - { - "cell_type": "code", - "execution_count": 162, - "id": "eddaa043", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class_with_args = RayClassWithInitArgs(cls=GPUAccumulatorDecorator)\n", - "gpu_accumulator_decorator = RayWorkerGroup(resource_pool_merge, class_with_args)" - ] - }, - { - "cell_type": "code", - "execution_count": 163, - "id": "10087c91", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[tensor([10.]), tensor([11.]), tensor([12.]), tensor([13.]), tensor([14.]), tensor([15.]), tensor([16.]), tensor([17.])]\n" - ] - } - ], - "source": [ - "# As we can see, 10 is automatically dispatched to each Worker in this RayWorkerGroup.\n", - "print(gpu_accumulator_decorator.add(x=10))" - ] - }, - { - "cell_type": "markdown", - "id": "540ee6ad", - "metadata": {}, - "source": [ - "### Custom Dispatch, Collection\n", - "Users can customize `dispatch` and `collection` function. You only need to write the `dispatch_fn` and `collect_fn` functions yourself. We also support executing RPC only on rank_zero, with specific examples provided below." - ] - }, - { - "cell_type": "code", - "execution_count": 164, - "id": "8e041270", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from verl.single_controller.base.decorator import Dispatch, collect_all_to_all, register" - ] - }, - { - "cell_type": "code", - "execution_count": 165, - "id": "43b5be31", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def two_to_all_dispatch_fn(worker_group, *args, **kwargs):\n", - " \"\"\"\n", - " Assume the input is a list of 2. Duplicate the input interleaved and pass to each worker.\n", - " \"\"\"\n", - " for arg in args:\n", - " assert len(arg) == 2\n", - " for i in range(worker_group.world_size - 2):\n", - " arg.append(arg[i % 2])\n", - " for k, v in kwargs.items():\n", - " assert len(v) == 2\n", - " for i in range(worker_group.world_size - 2):\n", - " v.append(v[i % 2])\n", - " return args, kwargs\n", - "\n", - "\n", - "@ray.remote\n", - "class TestActor(Worker):\n", - " # TODO: pass *args and **kwargs is bug prone and not very convincing\n", - " def __init__(self, x) -> None:\n", - " super().__init__()\n", - " self._x = x\n", - "\n", - " def foo(self, y):\n", - " return self._x + y\n", - "\n", - " @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO)\n", - " def foo_rank_zero(self, x, y):\n", - " return self._x + y + x\n", - "\n", - " @register(dispatch_mode={\"dispatch_fn\": two_to_all_dispatch_fn, \"collect_fn\": collect_all_to_all})\n", - " def foo_custom(self, x, y):\n", - " return self._x + y + x" - ] - }, - { - "cell_type": "code", - "execution_count": 166, - "id": "83ec6609", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)\n", - "worker_group = RayWorkerGroup(resource_pool, class_with_args)" - ] - }, - { - "cell_type": "code", - "execution_count": 167, - "id": "62c58d8a", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])\n", - "assert output_ref == [8, 10, 8, 10]\n", - "\n", - "output_ref = worker_group.foo_rank_zero(x=1, y=2)\n", - "assert output_ref == 5" - ] - }, - { - "cell_type": "code", - "execution_count": 168, - "id": "14689353", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "8\n" - ] - } - ], - "source": [ - "print(gpu_accumulator_decorator.world_size)" - ] - }, - { - "cell_type": "code", - "execution_count": 169, - "id": "2c80bbf4", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Shutdown ray cluster\n", - "ray.shutdown()" - ] - }, - { - "cell_type": "markdown", - "id": "a5c8151c", - "metadata": {}, - "source": [ - "## Chapter 4: NVMegatronRayWorkerGroup" - ] - }, - { - "cell_type": "markdown", - "id": "cd5680e9", - "metadata": {}, - "source": [ - "Due to the Ray issue, we can only support max_colocate_count=1 in RayResourcePool for now. \n", - "This means that each GPU can only have one process.\n", - "We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385" - ] - }, - { - "cell_type": "markdown", - "id": "92724419", - "metadata": {}, - "source": [ - "Therefore, we need to restart the ray and initialize a new resource_pool to demonstrate the **NVMegatronRayWorkerGroup**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9b038538", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Build a local ray cluster. The head node and worker node are on this machine\n", - "ray.init()" - ] - }, - { - "cell_type": "markdown", - "id": "ebfd8798", - "metadata": {}, - "source": [ - "Finally, we implement a `NVMegatronRayWorkerGroup`, within which we create a Megatron and then run a tensor parallel (tp) split Llama mlp layer. Here, we use a complex dispatch mode, `Megatron_COMPUTE`. This dispatch mode assumes that user passes the data partitioned by DP dimension. The data is dispatched to all tp/pp ranks within the same dp group, and ultimately only collects output data from tp=0 and the last pp. In this way, for users that only write code on the driver, the Megatron behind the RPC becomes transparent." - ] - }, - { - "cell_type": "code", - "execution_count": 171, - "id": "5a032154", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/opt/tiger/Megatron-LM\n", - "/opt/tiger/Megatron-LM/megatron/__init__.py\n" - ] - } - ], - "source": [ - "import sys\n", - "\n", - "current_pythonpath = os.environ.get(\"PYTHONPATH\", \"\")\n", - "\n", - "new_path = \"/opt/tiger/Megatron-LM\"\n", - "\n", - "new_pythonpath = f\"{new_path}:{current_pythonpath}\" if current_pythonpath else new_path\n", - "\n", - "os.environ[\"PYTHONPATH\"] = new_pythonpath\n", - "\n", - "print(new_path)\n", - "sys.path.append(new_path)\n", - "\n", - "import megatron\n", - "\n", - "print(megatron.__file__)" - ] - }, - { - "cell_type": "code", - "execution_count": 172, - "id": "8c84cd5a", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from megatron.core import parallel_state as mpu\n", - "from omegaconf import OmegaConf\n", - "\n", - "from verl.single_controller.base.decorator import Dispatch, Execute, register\n", - "from verl.single_controller.base.megatron.worker import MegatronWorker\n", - "from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup\n", - "from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup" - ] - }, - { - "cell_type": "code", - "execution_count": 173, - "id": "1b1debcc", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "resource_pool = RayResourcePool([4], use_gpu=True, max_colocate_count=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 174, - "id": "bccbe081", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "@ray.remote\n", - "class MLPLayerWorker(MegatronWorker):\n", - " def __init__(self):\n", - " super().__init__()\n", - " rank = int(os.environ[\"LOCAL_RANK\"])\n", - " torch.distributed.init_process_group(backend=\"nccl\")\n", - " torch.cuda.set_device(rank)\n", - "\n", - " mpu.initialize_model_parallel(\n", - " tensor_model_parallel_size=4,\n", - " pipeline_model_parallel_size=1,\n", - " virtual_pipeline_model_parallel_size=None,\n", - " pipeline_model_parallel_split_rank=None,\n", - " use_sharp=False,\n", - " context_parallel_size=1,\n", - " expert_model_parallel_size=1,\n", - " nccl_communicator_config_path=None,\n", - " )\n", - " from megatron.core import tensor_parallel\n", - "\n", - " tensor_parallel.model_parallel_cuda_manual_seed(10)\n", - "\n", - " @register(Dispatch.ONE_TO_ALL)\n", - " def init_model(self, config):\n", - " from omegaconf import OmegaConf\n", - "\n", - " from verl.models.llama.megatron.layers import ParallelLlamaMLP\n", - " from verl.utils.megatron_utils import init_model_parallel_config\n", - "\n", - " megatron_config = OmegaConf.create(\n", - " {\n", - " \"sequence_parallel\": False,\n", - " \"param_dtype\": \"fp32\",\n", - " \"tensor_model_parallel_size\": mpu.get_tensor_model_parallel_world_size(),\n", - " \"pipeline_model_parallel_rank\": mpu.get_pipeline_model_parallel_rank(),\n", - " \"pipeline_model_parallel_size\": mpu.get_pipeline_model_parallel_world_size(),\n", - " \"virtual_pipeline_model_parallel_rank\": mpu.get_virtual_pipeline_model_parallel_rank(),\n", - " \"virtual_pipeline_model_parallel_size\": mpu.get_virtual_pipeline_model_parallel_world_size(),\n", - " }\n", - " )\n", - "\n", - " megatron_config = init_model_parallel_config(megatron_config)\n", - " self.parallel_layer = ParallelLlamaMLP(config=config, megatron_config=megatron_config)\n", - "\n", - " @register(Dispatch.ONE_TO_ALL)\n", - " def get_weights(self):\n", - " output = {}\n", - " for key, val in self.parallel_layer.named_parameters():\n", - " output[key] = val\n", - " return output\n", - "\n", - " @register(Dispatch.MEGATRON_COMPUTE)\n", - " def run_layer(self, x):\n", - " x = x.to(\"cuda\")\n", - " y = self.parallel_layer(x)\n", - " return y" - ] - }, - { - "cell_type": "code", - "execution_count": 175, - "id": "a655271d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "layer_cls = RayClassWithInitArgs(cls=MLPLayerWorker)\n", - "layer_worker_group = NVMegatronRayWorkerGroup(\n", - " resource_pool=resource_pool,\n", - " ray_cls_with_init=layer_cls,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 176, - "id": "f105ebee", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4 4 1 1\n" - ] - } - ], - "source": [ - "print(layer_worker_group.world_size, layer_worker_group.tp_size, layer_worker_group.pp_size, layer_worker_group.dp_size)" - ] - }, - { - "cell_type": "code", - "execution_count": 177, - "id": "38655091", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "ffn_hidden_size = 11008\n", - "batch_size = 16\n", - "seq_len = 2048\n", - "hidden_size = 4096\n", - "\n", - "config = OmegaConf.create(\n", - " {\n", - " \"hidden_size\": hidden_size,\n", - " \"intermediate_size\": ffn_hidden_size,\n", - " \"hidden_act\": \"silu\",\n", - " \"pretraining_tp\": 1,\n", - " \"tp\": layer_worker_group.tp_size,\n", - " }\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 178, - "id": "a026efca", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "x = torch.rand(size=(seq_len, batch_size, hidden_size), dtype=torch.float32)" - ] - }, - { - "cell_type": "code", - "execution_count": 179, - "id": "f5fcaf13", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[None, None, None, None]" - ] - }, - "execution_count": 179, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "layer_worker_group.init_model(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 180, - "id": "3f5cc9b4", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([2048, 16, 4096])\n" - ] - } - ], - "source": [ - "output = layer_worker_group.run_layer(\n", - " [x]\n", - ") # This must be a list of size 1, ensuring that the input equals the data parallel (dp).\n", - "print(output[0].shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 181, - "id": "49792210", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Shutdown ray cluster\n", - "ray.shutdown()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh deleted file mode 100644 index c2bf6d05b..000000000 --- a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf.sh +++ /dev/null @@ -1,49 +0,0 @@ -set -x - - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=reinforce_plus_plus \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=3e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=1024 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=mse \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh b/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh deleted file mode 100644 index b134ee5d1..000000000 --- a/examples/reinforce_plus_plus_trainer/run_qwen2-7b_math_rf_baseline.sh +++ /dev/null @@ -1,49 +0,0 @@ -set -x - - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=reinforce_plus_plus_baseline \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=3e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=1024 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=mse \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh b/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh deleted file mode 100644 index feebe8a84..000000000 --- a/examples/remax_trainer/run_qwen2.5-3b_seq_balance.sh +++ /dev/null @@ -1,43 +0,0 @@ -set -x - -export HF_DATASETS_OFFLINE=1 -export TRANSFORMERS_OFFLINE=1 - - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=remax \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.n=4 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=True \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_remax_example_gsm8k' \ - trainer.experiment_name='qwen2.5_3b_function_rm_kl1e-3' \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=5 $@ diff --git a/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh b/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh deleted file mode 100644 index 8734eb351..000000000 --- a/examples/remax_trainer/run_qwen2.5-7b_seq_balance.sh +++ /dev/null @@ -1,43 +0,0 @@ -set -x - -export HF_DATASETS_OFFLINE=1 -export TRANSFORMERS_OFFLINE=1 - - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=remax \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.n=4 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=True \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_remax_example_gsm8k' \ - trainer.experiment_name='qwen2.5_7b_function_rm_kl1e-3' \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=10 $@ diff --git a/examples/rloo_trainer/run_qwen2-7b.sh b/examples/rloo_trainer/run_qwen2-7b.sh deleted file mode 100644 index fc9b6e29f..000000000 --- a/examples/rloo_trainer/run_qwen2-7b.sh +++ /dev/null @@ -1,40 +0,0 @@ -set -x - - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=rloo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=True \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_rloo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/examples/sft/gsm8k/run_deepseek_6b7.sh b/examples/sft/gsm8k/run_deepseek_6b7.sh deleted file mode 100644 index 8a067f05d..000000000 --- a/examples/sft/gsm8k/run_deepseek_6b7.sh +++ /dev/null @@ -1,28 +0,0 @@ -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_deepseek_6b7.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ - data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \ - trainer.default_local_dir=$save_path \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ - trainer.total_epochs=4 \ - trainer.logger='["console","wandb"]' $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_gemma_2b.sh b/examples/sft/gsm8k/run_gemma_2b.sh deleted file mode 100644 index 5b59893d2..000000000 --- a/examples/sft/gsm8k/run_gemma_2b.sh +++ /dev/null @@ -1,30 +0,0 @@ -# Tested with 2 & 4 GPUs - -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_gemma_2b.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ - data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=google/gemma-2b-it \ - trainer.default_local_dir=$save_path \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-gemma-2b-it \ - trainer.total_epochs=2 \ - trainer.logger='["console","wandb"]' $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_gemma_7b.sh b/examples/sft/gsm8k/run_gemma_7b.sh deleted file mode 100644 index aee7603d7..000000000 --- a/examples/sft/gsm8k/run_gemma_7b.sh +++ /dev/null @@ -1,26 +0,0 @@ -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_gemma_7b.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=prompt \ - data.response_key=answer \ - data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=google/gemma-1.1-7b-it \ - trainer.default_local_dir=$save_path \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-gemma-1.1-7b-it \ - trainer.total_epochs=4 \ - trainer.logger='["console","wandb"]' $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh b/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh deleted file mode 100644 index 45e427f39..000000000 --- a/examples/sft/gsm8k/run_qwen2_5_05b_sft_peft_sp2_npu.sh +++ /dev/null @@ -1,35 +0,0 @@ -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_qwen2_5_05b_sft_peft_sp2_npu.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ - data.micro_batch_size_per_gpu=64 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ - trainer.default_local_dir=$save_path \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ - trainer.logger=console \ - trainer.total_epochs=2 $@ \ - model.lora_rank=32 \ - model.lora_alpha=16 \ - model.target_modules=all-linear \ - model.strategy=fsdp \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true diff --git a/examples/sft/gsm8k/run_qwen_05_peft.sh b/examples/sft/gsm8k/run_qwen_05_peft.sh deleted file mode 100644 index 3a7d44558..000000000 --- a/examples/sft/gsm8k/run_qwen_05_peft.sh +++ /dev/null @@ -1,37 +0,0 @@ -# Tested with 2 & 4 GPUs - -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_qwen_05_peft.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ - data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ - trainer.default_local_dir=$save_path \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ - trainer.logger=console \ - trainer.total_epochs=1 $@ \ - model.lora_rank=32\ - model.lora_alpha=16 \ - model.target_modules=all-linear - - # Or you can do this: - # model.target_modules=[q_proj,v_proj] \ diff --git a/examples/sft/gsm8k/run_qwen_05_sp2.sh b/examples/sft/gsm8k/run_qwen_05_sp2.sh deleted file mode 100644 index 7210a5a40..000000000 --- a/examples/sft/gsm8k/run_qwen_05_sp2.sh +++ /dev/null @@ -1,31 +0,0 @@ -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_qwen_05_sp2.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ - data.micro_batch_size=4 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ - trainer.default_local_dir=$save_path \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \ - trainer.logger=console \ - trainer.total_training_steps=1 $@ \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true diff --git a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh b/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh deleted file mode 100644 index 1c5cd591f..000000000 --- a/examples/sft/gsm8k/run_qwen_05_sp2_liger.sh +++ /dev/null @@ -1,31 +0,0 @@ -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_qwen_05_sp2.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ - data.micro_batch_size=4 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ - model.use_liger=True \ - trainer.default_local_dir=$save_path \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2-liger \ - trainer.logger=console $@ \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true diff --git a/examples/sft/multiturn/run_qwen_05_sp2.sh b/examples/sft/multiturn/run_qwen_05_sp2.sh deleted file mode 100644 index 5e1fc47e9..000000000 --- a/examples/sft/multiturn/run_qwen_05_sp2.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash -set -x - -if [ "$#" -lt 2 ]; then - echo "Usage: run_qwen_05_sp2.sh [other_configs...]" - exit 1 -fi - -nproc_per_node=$1 -save_path=$2 - -# Shift the arguments so $@ refers to the rest -shift 2 - -torchrun --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/multiturn/train.parquet \ - data.val_files=$HOME/data/multiturn/test.parquet \ - data.multiturn.enable=true \ - data.multiturn.messages_key=messages \ - data.micro_batch_size=4 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ - trainer.default_local_dir=$save_path \ - trainer.project_name=multiturn-sft \ - trainer.experiment_name=multiturn-sft-qwen-2.5-0.5b-instruct-sp2 \ - trainer.logger=console \ - trainer.total_training_steps=1 $@ \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true \ No newline at end of file diff --git a/examples/sglang_multiturn/README.md b/examples/sglang_multiturn/README.md deleted file mode 100644 index 0c97c7e75..000000000 --- a/examples/sglang_multiturn/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# Multi-Turn Rollout Example (GSM8K) - -This example demonstrates how to perform **multi-turn rollout** using SGLang with a tool-calling capable model (e.g., Qwen2.5-3B) on the GSM8K dataset. - -## Usage - -### Step 1: Download GSM8K Dataset - -```bash -cd examples/data_preprocess -python3 gsm8k_multiturn_w_tool.py -``` - -This will download and preprocess the GSM8K dataset into ~/data/gsm8k/. - -### Step 2: Run Multi-Turn Rollout - -If you have 8 GPUs -Use the standard 8-GPU script: - -```bash -cd your_verl_root_dir -bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh -``` - -If you have only 4 GPUs -Use the fallback 4-GPU script: - -```bash -cd your_verl_root_dir -bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh -``` - -## Notes - -- The rollout supports multi-turn conversations with tool-calling capabilities. -- Current tools are used for GSM8K answer evaluation. -- Future versions may extend to search and code interpreter tools. diff --git a/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml b/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml deleted file mode 100644 index a9523f196..000000000 --- a/examples/sglang_multiturn/config/geo3k_multiturn_grpo.yaml +++ /dev/null @@ -1,25 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - max_prompt_length: 2048 - max_response_length: 2048 - train_batch_size: 256 - return_raw_chat: True - return_multi_modal_inputs: False - -actor_rollout_ref: - hybrid_engine: True - model: - custom_chat_template: "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" - rollout: - name: sglang - multi_turn: - enable: True - max_assistant_turns: 5 - # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml b/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml deleted file mode 100644 index 5e208f333..000000000 --- a/examples/sglang_multiturn/config/geo3k_multiturn_megatron_grpo.yaml +++ /dev/null @@ -1,25 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_megatron_trainer - - _self_ - -data: - max_prompt_length: 2048 - max_response_length: 2048 - train_batch_size: 256 - return_raw_chat: True - return_multi_modal_inputs: False - -actor_rollout_ref: - hybrid_engine: True - model: - custom_chat_template: "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{{- messages[0]['content'] }}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" - rollout: - name: sglang - multi_turn: - enable: True - max_assistant_turns: 5 - # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml deleted file mode 100644 index e9109232a..000000000 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml +++ /dev/null @@ -1,21 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - max_prompt_length: 1024 - max_response_length: 1024 - train_batch_size: 256 - return_raw_chat: True - -actor_rollout_ref: - hybrid_engine: True - rollout: - name: sglang - multi_turn: - enable: True - max_assistant_turns: 5 diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml deleted file mode 100644 index 122f7e50f..000000000 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo_w_interaction.yaml +++ /dev/null @@ -1,21 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - max_prompt_length: 1024 - max_response_length: 1024 - train_batch_size: 256 - return_raw_chat: True - -actor_rollout_ref: - hybrid_engine: True - rollout: - name: sglang - multi_turn: - enable: True - max_user_turns: 5 diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml deleted file mode 100644 index 8aff859cc..000000000 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml +++ /dev/null @@ -1,22 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_megatron_trainer - - _self_ - -data: - max_prompt_length: 1024 - max_response_length: 1024 - train_batch_size: 256 - return_raw_chat: True - -actor_rollout_ref: - hybrid_engine: True - rollout: - name: sglang - multi_turn: - enable: True - max_assistant_turns: 5 - diff --git a/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml b/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml deleted file mode 100644 index 78faf386e..000000000 --- a/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml +++ /dev/null @@ -1,4 +0,0 @@ -interaction: - - name: "gsm8k" - class_name: "verl.interactions.gsm8k_interaction.Gsm8kInteraction" - config: {} \ No newline at end of file diff --git a/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml b/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml deleted file mode 100644 index d1cfaccce..000000000 --- a/examples/sglang_multiturn/config/retool_multiturn_grpo.yaml +++ /dev/null @@ -1,22 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - max_prompt_length: 1024 - max_response_length: 1024 - train_batch_size: 256 - return_raw_chat: True - -actor_rollout_ref: - hybrid_engine: True - rollout: - name: sglang - multi_turn: - enable: True - max_assistant_turns: 5 - tool_config_path: "./config/tool_config/sandbox_fusion_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/search_multiturn_grpo.yaml b/examples/sglang_multiturn/config/search_multiturn_grpo.yaml deleted file mode 100644 index 0e24f62b7..000000000 --- a/examples/sglang_multiturn/config/search_multiturn_grpo.yaml +++ /dev/null @@ -1,23 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - max_prompt_length: 1024 - max_response_length: 1024 - train_batch_size: 256 - return_raw_chat: True - shuffle: False - -actor_rollout_ref: - hybrid_engine: True - rollout: - name: sglang - multi_turn: - enable: True - max_assistant_turns: 2 - format: qwen diff --git a/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml deleted file mode 100644 index 675a342e6..000000000 --- a/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -tools: - - class_name: "verl.tools.geo3k_tool.Geo3kTool" - config: - type: native - tool_schema: - type: "function" - function: - name: "calc_geo3k_reward" - description: "A tool for calculating the reward of geo3k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)" - parameters: - type: "object" - properties: - answer: - type: "string" - description: "The model's answer to the geo3k problem, must be a digits" - required: ["answer"] \ No newline at end of file diff --git a/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml deleted file mode 100644 index a4197baab..000000000 --- a/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -tools: - - class_name: "verl.tools.gsm8k_tool.Gsm8kTool" - config: - type: native - tool_schema: - type: "function" - function: - name: "calc_gsm8k_reward" - description: "A tool for calculating the reward of gsm8k. (1.0 if parsed answer is correct, 0.0 if parsed answer is incorrect or not correctly parsed)" - parameters: - type: "object" - properties: - answer: - type: "string" - description: "The model's answer to the GSM8K math problem, must be a digits" - required: ["answer"] diff --git a/examples/sglang_multiturn/config/tool_config/mcp_server.json b/examples/sglang_multiturn/config/tool_config/mcp_server.json deleted file mode 100644 index 29424f71e..000000000 --- a/examples/sglang_multiturn/config/tool_config/mcp_server.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "mcpServers": { - "Tavily Expert": { - "url": "your_tavily_expert_url", - "auth_token": "your_tavily_api_token" - } - } -} \ No newline at end of file diff --git a/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml deleted file mode 100644 index 40abf7c67..000000000 --- a/examples/sglang_multiturn/config/tool_config/mcp_tool_config.yaml +++ /dev/null @@ -1,11 +0,0 @@ -tools: - - class_name: verl.tools.mcp_search_tool.MCPSearchTool - config: - rate_limit: 120 - timeout: 120 - type: mcp - mcp: - mcp_servers_config_path: ./mcp_server.json - # optional - tool_selected_list: - - tavily_search_tool \ No newline at end of file diff --git a/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml deleted file mode 100644 index 516acf569..000000000 --- a/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml +++ /dev/null @@ -1,24 +0,0 @@ -tools: - - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool" - config: - sandbox_fusion_url: "https://xxx.apigateway-cn-beijing.volceapi.com/run_code" - num_workers: 10 - enable_global_rate_limit: true - rate_limit: 10 - default_timeout: 30 - default_language: "python" - memory_limit_mb: 1024 - type: native - - tool_schema: - type: "function" - function: - name: "code_interpreter" - description: "A tool for executing code." - parameters: - type: "object" - properties: - code: - type: "string" - description: "The code to execute." - required: ["code"] \ No newline at end of file diff --git a/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml deleted file mode 100644 index 926b6b832..000000000 --- a/examples/sglang_multiturn/config/tool_config/search_tool_config.yaml +++ /dev/null @@ -1,23 +0,0 @@ -tools: - - class_name: verl.tools.search_tool.SearchTool - config: - retrieval_service_url: http://127.0.0.1:8000/retrieve - num_workers: 120 - rate_limit: 120 - timeout: 30 - type: native - tool_schema: - type: function - function: - name: search - description: Searches the web for relevant information based on the given query. - parameters: - type: object - properties: - query_list: - type: array - item: - type: string - description: A list of fully-formed semantic queries. The tool will return search results for each query. - required: - - query_list \ No newline at end of file diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh deleted file mode 100644 index d9306e9df..000000000 --- a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn.sh +++ /dev/null @@ -1,54 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project - -set -x - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='geo3k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=256 \ - data.max_prompt_length=2048 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=16 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='geo3k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-verify-n16' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \ - data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ - trainer.total_epochs=15 $@ - diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh deleted file mode 100644 index 66f12a5e5..000000000 --- a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_geo3k_multiturn_4xgpu.sh +++ /dev/null @@ -1,58 +0,0 @@ -# run on 4xH100 -# make sure your current working directory is the root of the project - -set -x -export HYDRA_FULL_ERROR=1 -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='geo3k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=256 \ - data.max_prompt_length=2048 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=16 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='geo3k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-async-sgl-multi-w-tool-verify-n16-4cards' \ - trainer.n_gpus_per_node=4 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - trainer.total_epochs=15 \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \ - critic.ppo_max_token_len_per_gpu=8192 \ - critic.forward_max_token_len_per_gpu=8192 \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ - $@ \ No newline at end of file diff --git a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh b/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh deleted file mode 100644 index 547b34d43..000000000 --- a/examples/sglang_multiturn/geo3k/run_qwen2.5-3b_megatron_geo3k_multiturn.sh +++ /dev/null @@ -1,65 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project -# this is a verification training script, the parallel setting should be tuned to your model - -set -x - -export PYTHONUNBUFFERED=1 -export RAY_DEDUP_LOGS=0 -export RUST_BACKTRACE=1 -export HYDRA_FULL_ERROR=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='geo3k_multiturn_megatron_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=256 \ - data.max_prompt_length=2048 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.context_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.megatron.seed=42 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.context_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='geo3k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-geo3k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - data.train_files=$HOME/data/geo3k_multiturn_w_tool/train.parquet \ - data.val_files=$HOME/data/geo3k_multiturn_w_tool/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ - trainer.total_epochs=15 $@ - diff --git a/examples/sglang_multiturn/run_qwen0.5b_gsm8k_multiturn_curriculum.sh b/examples/sglang_multiturn/run_qwen0.5b_gsm8k_multiturn_curriculum.sh deleted file mode 100755 index d67a76e48..000000000 --- a/examples/sglang_multiturn/run_qwen0.5b_gsm8k_multiturn_curriculum.sh +++ /dev/null @@ -1,56 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project - -set -x - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.sampler.class_name="RandomCurriculumSampler" \ - data.sampler.class_path="pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu" \ - data.dataloader_num_workers=0 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.train_batch_size=256 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=16 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.total_epochs=15 $@ - diff --git a/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh b/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh deleted file mode 100644 index 2667664c9..000000000 --- a/examples/sglang_multiturn/run_qwen2.5-0.5b_gsm8k_multiturn_w_interaction.sh +++ /dev/null @@ -1,58 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project - -set -x - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" -TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-512} -MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-8} -OFFLOAD=${OFFLOAD:-False} - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_grpo_w_interaction' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=$TRAIN_BATCH_SIZE \ - data.max_prompt_length=1024 \ - data.max_response_length=$((1024 * 3)) \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - +actor_rollout_ref.model.enable_activation_offloading=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_BATCH_SIZE \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.fsdp_config.param_offload=$OFFLOAD \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=$OFFLOAD \ - +actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=$MICRO_BATCH_SIZE \ - actor_rollout_ref.ref.fsdp_config.param_offload=$OFFLOAD \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen2.5-0.5b_function_rm-gsm8k-sgl-multi-w-interaction-n8' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - data.train_files=$HOME/data/gsm8k_verl_sgl_multi_turn_w_interaction/train.parquet \ - data.val_files=$HOME/data/gsm8k_verl_sgl_multi_turn_w_interaction/test.parquet \ - actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml" \ - trainer.total_epochs=15 $@ - diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh deleted file mode 100644 index 662723df4..000000000 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh +++ /dev/null @@ -1,54 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project - -set -x - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=256 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=16 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.total_epochs=15 \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@ - diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh deleted file mode 100644 index 9e61893b0..000000000 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh +++ /dev/null @@ -1,60 +0,0 @@ -# run on 4xH100 -# make sure your current working directory is the root of the project - -set -x -export HYDRA_FULL_ERROR=1 -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=256 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=16 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16-4cards' \ - trainer.n_gpus_per_node=4 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - trainer.total_epochs=15 \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \ - critic.ppo_max_token_len_per_gpu=8192 \ - critic.forward_max_token_len_per_gpu=8192 \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - actor_rollout_ref.rollout.multi_turn.interaction_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/interaction_config/gsm8k_interaction_config.yaml" \ - actor_rollout_ref.rollout.multi_turn.max_user_turns=1 \ - $@ \ No newline at end of file diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh deleted file mode 100644 index 11c104fa9..000000000 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_tool_agent_mlflow.sh +++ /dev/null @@ -1,57 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project - -set -x - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=256 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.mode=async \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=16 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.rollout.trace.backend=mlflow \ - actor_rollout_ref.rollout.trace.token2text=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","mlflow"]' \ - trainer.project_name='gsm8k_tool-agent' \ - trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-tool-agent-verify-n16' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - trainer.total_training_steps=2 \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.total_epochs=15 $@ - diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh deleted file mode 100644 index a13d4f422..000000000 --- a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh +++ /dev/null @@ -1,65 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project -# this is a verification training script, the parallel setting should be tuned to your model - -set -x - -export PYTHONUNBUFFERED=1 -export RAY_DEDUP_LOGS=0 -export RUST_BACKTRACE=1 -export HYDRA_FULL_ERROR=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_megatron_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=/user/longxiang1/models/Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.actor.megatron.context_parallel_size=2 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.megatron.seed=42 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=2 \ - actor_rollout_ref.ref.megatron.context_parallel_size=2 \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - data.train_files=/user/longxiang1/data/gsm8k_verl_sgl_multi_turn_preprocessed_v2/train.parquet \ - data.val_files=/user/longxiang1/data/gsm8k_verl_sgl_multi_turn_preprocessed_v2/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.total_epochs=15 $@ - diff --git a/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh deleted file mode 100755 index 56228f4b5..000000000 --- a/examples/sglang_multiturn/run_qwen3-4b_gsm8k_multiturn.sh +++ /dev/null @@ -1,53 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project - -set -x - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=256 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen3-4B \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=16 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen3-4b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=20 \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.total_epochs=15 $@ - diff --git a/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py b/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py deleted file mode 100644 index 6fe554936..000000000 --- a/examples/sglang_multiturn/search_r1_like/local_dense_retriever/download.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 Search-R1 Contributors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/scripts/download.py - - -import argparse - -from huggingface_hub import hf_hub_download - -parser = argparse.ArgumentParser(description="Download files from a Hugging Face dataset repository.") -parser.add_argument("--repo_id", type=str, default="PeterJinGo/wiki-18-e5-index", help="Hugging Face repository ID") -parser.add_argument("--save_path", type=str, required=True, help="Local directory to save files") - -args = parser.parse_args() - -repo_id = "PeterJinGo/wiki-18-e5-index" -for file in ["part_aa", "part_ab"]: - hf_hub_download( - repo_id=repo_id, - filename=file, # e.g., "e5_Flat.index" - repo_type="dataset", - local_dir=args.save_path, - ) - -repo_id = "PeterJinGo/wiki-18-corpus" -hf_hub_download( - repo_id=repo_id, - filename="wiki-18.jsonl.gz", - repo_type="dataset", - local_dir=args.save_path, -) diff --git a/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py b/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py deleted file mode 100644 index 2f67c1439..000000000 --- a/examples/sglang_multiturn/search_r1_like/local_dense_retriever/retrieval_server.py +++ /dev/null @@ -1,415 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 Search-R1 Contributors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/search_r1/search/retrieval_server.py - -import argparse -import json -import warnings -from typing import Optional - -import datasets -import faiss -import numpy as np -import torch -import uvicorn -from fastapi import FastAPI -from pydantic import BaseModel -from tqdm import tqdm -from transformers import AutoModel, AutoTokenizer - - -def load_corpus(corpus_path: str): - corpus = datasets.load_dataset("json", data_files=corpus_path, split="train", num_proc=4) - return corpus - - -def load_docs(corpus, doc_idxs): - results = [corpus[int(idx)] for idx in doc_idxs] - return results - - -def load_model(model_path: str, use_fp16: bool = False): - model = AutoModel.from_pretrained(model_path, trust_remote_code=True) - model.eval() - model.cuda() - if use_fp16: - model = model.half() - tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True) - return model, tokenizer - - -def pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method="mean"): - if pooling_method == "mean": - last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) - return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] - elif pooling_method == "cls": - return last_hidden_state[:, 0] - elif pooling_method == "pooler": - return pooler_output - else: - raise NotImplementedError("Pooling method not implemented!") - - -class Encoder: - def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16): - self.model_name = model_name - self.model_path = model_path - self.pooling_method = pooling_method - self.max_length = max_length - self.use_fp16 = use_fp16 - - self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16) - self.model.eval() - - @torch.no_grad() - def encode(self, query_list: list[str], is_query=True) -> np.ndarray: - # processing query for different encoders - if isinstance(query_list, str): - query_list = [query_list] - - if "e5" in self.model_name.lower(): - if is_query: - query_list = [f"query: {query}" for query in query_list] - else: - query_list = [f"passage: {query}" for query in query_list] - - if "bge" in self.model_name.lower(): - if is_query: - query_list = [ - f"Represent this sentence for searching relevant passages: {query}" for query in query_list - ] - - inputs = self.tokenizer( - query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt" - ) - inputs = {k: v.cuda() for k, v in inputs.items()} - - if "T5" in type(self.model).__name__: - # T5-based retrieval model - decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to( - inputs["input_ids"].device - ) - output = self.model(**inputs, decoder_input_ids=decoder_input_ids, return_dict=True) - query_emb = output.last_hidden_state[:, 0, :] - else: - output = self.model(**inputs, return_dict=True) - query_emb = pooling( - output.pooler_output, output.last_hidden_state, inputs["attention_mask"], self.pooling_method - ) - if "dpr" not in self.model_name.lower(): - query_emb = torch.nn.functional.normalize(query_emb, dim=-1) - - query_emb = query_emb.detach().cpu().numpy() - query_emb = query_emb.astype(np.float32, order="C") - - del inputs, output - torch.cuda.empty_cache() - - return query_emb - - -class BaseRetriever: - def __init__(self, config): - self.config = config - self.retrieval_method = config.retrieval_method - self.topk = config.retrieval_topk - - self.index_path = config.index_path - self.corpus_path = config.corpus_path - - def _search(self, query: str, num: int, return_score: bool): - raise NotImplementedError - - def _batch_search(self, query_list: list[str], num: int, return_score: bool): - raise NotImplementedError - - def search(self, query: str, num: int = None, return_score: bool = False): - return self._search(query, num, return_score) - - def batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): - return self._batch_search(query_list, num, return_score) - - -class BM25Retriever(BaseRetriever): - def __init__(self, config): - super().__init__(config) - from pyserini.search.lucene import LuceneSearcher - - self.searcher = LuceneSearcher(self.index_path) - self.contain_doc = self._check_contain_doc() - if not self.contain_doc: - self.corpus = load_corpus(self.corpus_path) - self.max_process_num = 8 - - def _check_contain_doc(self): - return self.searcher.doc(0).raw() is not None - - def _search(self, query: str, num: int = None, return_score: bool = False): - if num is None: - num = self.topk - hits = self.searcher.search(query, num) - if len(hits) < 1: - if return_score: - return [], [] - else: - return [] - scores = [hit.score for hit in hits] - if len(hits) < num: - warnings.warn("Not enough documents retrieved!", stacklevel=2) - else: - hits = hits[:num] - - if self.contain_doc: - all_contents = [json.loads(self.searcher.doc(hit.docid).raw())["contents"] for hit in hits] - results = [ - { - "title": content.split("\n")[0].strip('"'), - "text": "\n".join(content.split("\n")[1:]), - "contents": content, - } - for content in all_contents - ] - else: - results = load_docs(self.corpus, [hit.docid for hit in hits]) - - if return_score: - return results, scores - else: - return results - - def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): - results = [] - scores = [] - for query in query_list: - item_result, item_score = self._search(query, num, True) - results.append(item_result) - scores.append(item_score) - if return_score: - return results, scores - else: - return results - - -class DenseRetriever(BaseRetriever): - def __init__(self, config): - super().__init__(config) - self.index = faiss.read_index(self.index_path) - if config.faiss_gpu: - co = faiss.GpuMultipleClonerOptions() - co.useFloat16 = True - co.shard = True - self.index = faiss.index_cpu_to_all_gpus(self.index, co=co) - - self.corpus = load_corpus(self.corpus_path) - self.encoder = Encoder( - model_name=self.retrieval_method, - model_path=config.retrieval_model_path, - pooling_method=config.retrieval_pooling_method, - max_length=config.retrieval_query_max_length, - use_fp16=config.retrieval_use_fp16, - ) - self.topk = config.retrieval_topk - self.batch_size = config.retrieval_batch_size - - def _search(self, query: str, num: int = None, return_score: bool = False): - if num is None: - num = self.topk - query_emb = self.encoder.encode(query) - scores, idxs = self.index.search(query_emb, k=num) - idxs = idxs[0] - scores = scores[0] - results = load_docs(self.corpus, idxs) - if return_score: - return results, scores.tolist() - else: - return results - - def _batch_search(self, query_list: list[str], num: int = None, return_score: bool = False): - if isinstance(query_list, str): - query_list = [query_list] - if num is None: - num = self.topk - - results = [] - scores = [] - for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc="Retrieval process: "): - query_batch = query_list[start_idx : start_idx + self.batch_size] - batch_emb = self.encoder.encode(query_batch) - batch_scores, batch_idxs = self.index.search(batch_emb, k=num) - batch_scores = batch_scores.tolist() - batch_idxs = batch_idxs.tolist() - - # load_docs is not vectorized, but is a python list approach - flat_idxs = sum(batch_idxs, []) - batch_results = load_docs(self.corpus, flat_idxs) - # chunk them back - batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))] - - results.extend(batch_results) - scores.extend(batch_scores) - - del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results - torch.cuda.empty_cache() - - if return_score: - return results, scores - else: - return results - - -def get_retriever(config): - if config.retrieval_method == "bm25": - return BM25Retriever(config) - else: - return DenseRetriever(config) - - -##################################### -# FastAPI server below -##################################### - - -class Config: - """ - Minimal config class (simulating your argparse) - Replace this with your real arguments or load them dynamically. - """ - - def __init__( - self, - retrieval_method: str = "bm25", - retrieval_topk: int = 10, - index_path: str = "./index/bm25", - corpus_path: str = "./data/corpus.jsonl", - dataset_path: str = "./data", - data_split: str = "train", - faiss_gpu: bool = True, - retrieval_model_path: str = "./model", - retrieval_pooling_method: str = "mean", - retrieval_query_max_length: int = 256, - retrieval_use_fp16: bool = False, - retrieval_batch_size: int = 128, - ): - self.retrieval_method = retrieval_method - self.retrieval_topk = retrieval_topk - self.index_path = index_path - self.corpus_path = corpus_path - self.dataset_path = dataset_path - self.data_split = data_split - self.faiss_gpu = faiss_gpu - self.retrieval_model_path = retrieval_model_path - self.retrieval_pooling_method = retrieval_pooling_method - self.retrieval_query_max_length = retrieval_query_max_length - self.retrieval_use_fp16 = retrieval_use_fp16 - self.retrieval_batch_size = retrieval_batch_size - - -class QueryRequest(BaseModel): - queries: list[str] - topk: Optional[int] = None - return_scores: bool = False - - -app = FastAPI() - - -@app.post("/retrieve") -def retrieve_endpoint(request: QueryRequest): - """ - Endpoint that accepts queries and performs retrieval. - - Input format: - { - "queries": ["What is Python?", "Tell me about neural networks."], - "topk": 3, - "return_scores": true - } - - Output format (when return_scores=True,similarity scores are returned): - { - "result": [ - [ # Results for each query - { - {"document": doc, "score": score} - }, - # ... more documents - ], - # ... results for other queries - ] - } - """ - if not request.topk: - request.topk = config.retrieval_topk # fallback to default - - # Perform batch retrieval - results, scores = retriever.batch_search( - query_list=request.queries, num=request.topk, return_score=request.return_scores - ) - - # Format response - resp = [] - for i, single_result in enumerate(results): - if request.return_scores: - # If scores are returned, combine them with results - combined = [] - for doc, score in zip(single_result, scores[i], strict=True): - combined.append({"document": doc, "score": score}) - resp.append(combined) - else: - resp.append(single_result) - return {"result": resp} - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Launch the local faiss retriever.") - parser.add_argument( - "--index_path", type=str, default="/home/peterjin/mnt/index/wiki-18/e5_Flat.index", help="Corpus indexing file." - ) - parser.add_argument( - "--corpus_path", - type=str, - default="/home/peterjin/mnt/data/retrieval-corpus/wiki-18.jsonl", - help="Local corpus file.", - ) - parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.") - parser.add_argument("--retriever_name", type=str, default="e5", help="Name of the retriever model.") - parser.add_argument( - "--retriever_model", type=str, default="intfloat/e5-base-v2", help="Path of the retriever model." - ) - parser.add_argument("--faiss_gpu", action="store_true", help="Use GPU for computation") - - args = parser.parse_args() - - # 1) Build a config (could also parse from arguments). - # In real usage, you'd parse your CLI arguments or environment variables. - config = Config( - retrieval_method=args.retriever_name, # or "dense" - index_path=args.index_path, - corpus_path=args.corpus_path, - retrieval_topk=args.topk, - faiss_gpu=args.faiss_gpu, - retrieval_model_path=args.retriever_model, - retrieval_pooling_method="mean", - retrieval_query_max_length=256, - retrieval_use_fp16=True, - retrieval_batch_size=512, - ) - - # 2) Instantiate a global retriever so it is loaded once and reused. - retriever = get_retriever(config) - - # 3) Launch the server. By default, it listens on http://127.0.0.1:8000 - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh b/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh deleted file mode 100644 index 4415e47a9..000000000 --- a/examples/sglang_multiturn/search_r1_like/run_qwen2.5-3b_instruct_search_multiturn.sh +++ /dev/null @@ -1,66 +0,0 @@ -# run on 8xH20 -# make sure your current working directory is the root of the project - -set -x - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - - -TRAIN_DATA="$HOME/data/searchR1_processed_direct/train.parquet" -VAL_DATA="$HOME/data/searchR1_processed_direct/test.parquet" - -TOOL_CONFIG="$CONFIG_PATH/tool_config/search_tool_config.yaml" - - - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='search_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=512 \ - data.val_batch_size=256 \ - data.max_prompt_length=4096 \ - data.max_response_length=3000 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.285 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.max_model_len=15000 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.multi_turn.max_assistant_turns=2 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.val_before_train=False \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='search_r1_like_async_rl' \ - trainer.experiment_name='qwen2.5-3b-instruct_function_rm-search-async-sgl-multi-w-searchtool-verify-n16' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=100 \ - trainer.test_freq=50 \ - data.train_files="$TRAIN_DATA" \ - data.val_files="$VAL_DATA" \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$TOOL_CONFIG" \ - trainer.total_epochs=1 $@ - diff --git a/examples/slurm/ray_on_slurm.slurm b/examples/slurm/ray_on_slurm.slurm deleted file mode 100644 index 86567d811..000000000 --- a/examples/slurm/ray_on_slurm.slurm +++ /dev/null @@ -1,98 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=verl-ray-on-slurm -#SBATCH --nodes=2 -#SBATCH --ntasks-per-node=1 -#SBATCH --mem=200G -#SBATCH --partition=your-partition -#SBATCH --time=01:00:00 -#SBATCH --account=your-account -#SBATCH --gpus-per-node=4 -#SBATCH --cpus-per-task=64 -#SBATCH --output=slurm-%j.out -#SBATCH --error=slurm-%j.err - -# load necessary modules - -# replace these information with your own -verl_workdir=/path/to/verl -train_files=/path/to/gsm8k/train.parquet -val_files=/path/to/gsm8k/test.parquet -apptainer_image_path=/path/to/verl-ngc.sif -# replace these information with your own - -# Getting the node names -nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") -nodes_array=("$nodes") - -head_node=${nodes_array[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) - -# if we detect a space character in the head node IP, we'll -# convert it to an ipv4 address. This step is optional. -if [[ "$head_node_ip" == *" "* ]]; then -IFS=' ' read -ra ADDR <<<"$head_node_ip" -if [[ ${#ADDR[0]} -gt 16 ]]; then - head_node_ip=${ADDR[1]} -else - head_node_ip=${ADDR[0]} -fi -echo "IPV6 address detected. We split the IPV4 address as $head_node_ip" -fi - -port=6379 -ip_head=$head_node_ip:$port -export ip_head -echo "IP Head: $ip_head" - -# make sure we set environment variables before Ray initialization - -printenv - -echo "Starting HEAD at $head_node" -srun --nodes=1 --ntasks=1 -w "$head_node" \ - apptainer run --nv --bind $verl_workdir $apptainer_image_path \ - ray start --head --node-ip-address="$head_node_ip" --port=$port \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & -# optional, though may be useful in certain versions of Ray < 1.0. -sleep 10 - -# number of nodes other than the head node -worker_num=$((SLURM_JOB_NUM_NODES - 1)) - -for ((i = 1; i <= worker_num; i++)); do - node_i=${nodes_array[$i]} - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" \ - apptainer run --nv --bind $verl_workdir $apptainer_image_path \ - ray start --address "$ip_head" --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus "${SLURM_GPUS_PER_NODE}" --block & - sleep 5 -done - -PYTHONUNBUFFERED=1 srun --overlap --nodes=1 --ntasks=1 -w "$head_node" \ - apptainer run --nv --bind $verl_workdir $apptainer_image_path \ - python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files=$train_files \ - data.val_files=$val_files \ - data.train_batch_size=256 \ - data.max_prompt_length=512 \ - data.max_response_length=256 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - critic.optim.lr=1e-5 \ - critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - critic.ppo_micro_batch_size_per_gpu=4 \ - algorithm.use_kl_in_reward=False \ - trainer.logger=console \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node="${SLURM_GPUS_PER_NODE}" \ - trainer.nnodes="${SLURM_NNODES}" \ - trainer.save_freq=10 \ - trainer.test_freq=10 \ - trainer.total_epochs=15 2>&1 | tee verl_demo_slurm.log diff --git a/examples/split_placement/README.md b/examples/split_placement/README.md deleted file mode 100644 index 226b5436d..000000000 --- a/examples/split_placement/README.md +++ /dev/null @@ -1,60 +0,0 @@ -# Split Placement Example -Here we introduce how to run the naive implementation of the split placement of PPO algorithm. -We will release the complete version of flexible placement in the near future. - - For quickstart, you can only follow Step 2 to modify the code and then follow Step 4 to execute the split placement example. - -### Step 1: Placing the models to different GPUs -Specify the placement and resource allocation. In the example, we place the actor and reference in the first half of the GPUs while map the critic and reward model (if any) to the second half of the GPUs. -```python -actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' -critic_pool_id = 'critic_pool' -if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: - resource_pool_spec = { - actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, - critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, - } -else: - resource_pool_spec = { - actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), - critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), - } -print(f'resource_pool_spec: {resource_pool_spec}') -mapping = { - Role.ActorRollout: actor_rollout_ref_pool_id, - Role.Critic: critic_pool_id, - Role.RefPolicy: actor_rollout_ref_pool_id, -} -mapping[Role.RewardModel] = critic_pool_id -``` - -### Step 2: Make the models executed asynchronously -Based on the model placement, we need to make the models executed asynchronously. - -To do so, you need to turn off the `blocking` flag (i.e., `blocking=False`) in our decorator of some model operations. -For example, we hope the actor update and critic update can be executed in parallel, then we need to make the following modification in `fsdp_workers.py` - -``` -@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) -def update_actor(self, data: DataProto): - ... -@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) -def update_critic(self, data: DataProto): - ... -``` - -We can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we don't do this in this example. - -### Step 3: Execute these operation in parallel in the single controller process -To implement the parallel execution of the actor and critic update, the only thing we need to modify in the `ray_trainer.py` is to `get` the concurrent `futures` on the single controller process. - -```python -critic_output = critic_output.get() -actor_output = actor_output.get() -``` - -### Step 4: Run the split placement example - -``` -bash run_deepseek7b_llm.sh -``` \ No newline at end of file diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py deleted file mode 100644 index c438e7a13..000000000 --- a/examples/split_placement/main_ppo_split.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. -""" - -import hydra -import ray -import torch -from split_monkey_patch import fit - -from verl import DataProto -from verl.trainer.ppo.ray_trainer import RayPPOTrainer -from verl.utils.reward_score import gsm8k, math - - -def _select_rm_score_fn(data_source): - if data_source == "openai/gsm8k": - return gsm8k.compute_score - elif data_source == "lighteval/MATH": - return math.compute_score - else: - raise NotImplementedError - - -class RewardManager: - def __init__(self, tokenizer, num_examine) -> None: - self.tokenizer = tokenizer - self.num_examine = num_examine # the number of batches of decoded responses to print to the console - - def __call__(self, data: DataProto, return_dict: bool = False): - """We will expand this function gradually based on the available datasets""" - - # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if "rm_scores" in data.batch.keys(): - return data.batch["rm_scores"] - - reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) - - already_print_data_sources = {} - - for i in range(len(data)): - data_item = data[i] # DataProtoItem - - prompt_ids = data_item.batch["prompts"] - - prompt_length = prompt_ids.shape[-1] - - valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() - valid_prompt_ids = prompt_ids[-valid_prompt_length:] - - response_ids = data_item.batch["responses"] - valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - sequences = torch.cat((valid_prompt_ids, valid_response_ids)) - sequences_str = self.tokenizer.decode(sequences) - - ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] - - # select rm_score - data_source = data_item.non_tensor_batch["data_source"] - compute_score_fn = _select_rm_score_fn(data_source) - - score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth) - reward_tensor[i, valid_response_length - 1] = score - - if data_source not in already_print_data_sources: - already_print_data_sources[data_source] = 0 - - if already_print_data_sources[data_source] < self.num_examine: - already_print_data_sources[data_source] += 1 - print(sequences_str) - - if return_dict: - return {"reward_tensor": reward_tensor} - else: - return reward_tensor - - -@hydra.main(config_path="config", config_name="ppo_trainer_split", version_base=None) -def main(config): - if not ray.is_initialized(): - # this is for local ray cluster - ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, - num_cpus=config.ray_init.num_cpus, - ) - - ray.get(main_task.remote(config)) - - -@ray.remote -def main_task(config): - # print initial config - from pprint import pprint - - from omegaconf import OmegaConf - - from verl.utils.fs import copy_to_local - - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - OmegaConf.resolve(config) - - # download the checkpoint from hdfs - local_path = copy_to_local(config.actor_rollout_ref.model.path) - - # instantiate tokenizer - from verl.utils import hf_tokenizer - - tokenizer = hf_tokenizer(local_path) - - # define worker classes - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: - assert config.critic.strategy in {"fsdp", "fsdp2"} - from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker - - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - - ray_worker_group_cls = NVMegatronRayWorkerGroup - - else: - raise NotImplementedError - - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - role_worker_mapping = { - Role.ActorRollout: ray.remote(ActorRolloutRefWorker), - Role.Critic: ray.remote(CriticWorker), - } - - # NOTE: initialze two resource pool - actor_rollout_ref_pool_id = "actor_rollout_ref_pool" - critic_pool_id = "critic_pool" - if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: - resource_pool_spec = { - actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, - critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, - } - else: - resource_pool_spec = { - actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), - critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), - } - print(f"resource_pool_spec: {resource_pool_spec}") - mapping = { - Role.ActorRollout: actor_rollout_ref_pool_id, - Role.Critic: critic_pool_id, - } - - # use reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) - mapping[Role.RefPolicy] = actor_rollout_ref_pool_id - - # we should adopt a multi-source reward function here - # - for rule-based rm, we directly call a reward score - # - for model-based rm, we call a model - # - for code related prompt, we send to a sandbox if there are test cases - # - finally, we combine all the rewards together - # - The reward type depends on the tag of the data - if config.reward_model.enable: - if config.reward_model.strategy in {"fsdp", "fsdp2"}: - from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == "megatron": - from verl.workers.megatron_workers import RewardModelWorker - else: - raise NotImplementedError - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - mapping[Role.RewardModel] = critic_pool_id - - reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0) - - # Note that we always use function-based RM for validation - val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1) - - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - RayPPOTrainer.fit = fit - trainer = RayPPOTrainer( - config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, - device_name=config.trainer.device, - ) - trainer.init_workers() - trainer.fit() - - -if __name__ == "__main__": - main() diff --git a/examples/split_placement/run_deepseek7b_llm.sh b/examples/split_placement/run_deepseek7b_llm.sh deleted file mode 100644 index 473dcccdd..000000000 --- a/examples/split_placement/run_deepseek7b_llm.sh +++ /dev/null @@ -1,37 +0,0 @@ -set -x - -python3 main_ppo_split.py \ - algorithm.adv_estimator=gae \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - critic.optim.lr=1e-5 \ - critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=8 \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_example_gsm8k' \ - trainer.experiment_name='deepseek_llm_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_epochs=15 $@ diff --git a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py deleted file mode 100644 index ef58509b9..000000000 --- a/examples/split_placement/split_monkey_patch.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -An naive implementation of split placment example -""" - -import uuid -from copy import deepcopy -from pprint import pprint - -import numpy as np -import torch - -from verl import DataProto -from verl.trainer.ppo.ray_trainer import ( - AdvantageEstimator, - apply_kl_penalty, - compute_advantage, - compute_data_metrics, - compute_timing_metrics, - marked_timer, -) -from verl.utils.metric import reduce_metrics - - -def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC - to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - timing_raw = {} - - batch: DataProto = DataProto.from_single_dict(batch_dict) - - # pop those keys for generation - gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) - is_last_step = self.global_steps >= self.total_training_steps - - with marked_timer("step", timing_raw): - # generate a batch - with marked_timer("gen", timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - timing_raw.update(gen_batch_output.meta_info["timing"]) - gen_batch_output.meta_info.pop("timing", None) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with marked_timer("gen_max", timing_raw): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - - batch = batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object - ) - # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - batch = batch.union(gen_batch_output) - - # Balance the number of valid tokens across DP ranks. - # NOTE: This usually changes the order of data in the `batch`, - # which won't affect the advantage calculation (since it's based on uid), - # but might affect the loss calculation (due to the change of mini-batching). - # TODO: Decouple the DP balancing and mini-batching. - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - # recompute old_log_probs - with marked_timer("old_log_prob", timing_raw): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - batch = batch.union(old_log_prob) - - if self.use_reference_policy: - # compute reference log_prob - with marked_timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with marked_timer("values", timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with marked_timer("adv", timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - # we combine with rule-based rm - reward_tensor = self.reward_fn(batch) - batch.batch["token_level_scores"] = reward_tensor - - # compute rewards. apply_kl_penalty if available - if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty( - batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty - ) - metrics.update(kl_metrics) - else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - - # compute advantages, executed on the driver process - norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with marked_timer("update_actor_call", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - else: - actor_output = None - - # update critic - if self.use_critic: - with marked_timer("update_critic_call", timing_raw): - critic_output = self.critic_wg.update_critic(batch) - - # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class - with marked_timer("update_actor_critic", timing_raw): - critic_output = critic_output.get() - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) - - if actor_output is not None: - actor_output = actor_output.get() - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - ): - with marked_timer("testing", timing_raw): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 - ): - with marked_timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - if self.global_steps >= self.total_training_steps: - pprint(f"Final validation metrics: {last_val_metrics}") - return - - self.global_steps += 1 diff --git a/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh deleted file mode 100644 index a40ae6f60..000000000 --- a/examples/tuning/0.5b/qwen2-0.5b_grpo-lora_1_h100_fsdp_vllm.sh +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -NOW=$(date +%Y%m%d) -export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-0.5b-${NOW} -export WANDB_PROJECT=${WANDB_DIR} -export WANDB_EXP=0.5b-${NOW} -MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct - -set -x -nproc_per_gpu=116 -nnodes=1 -ngpu_per_node=1 -total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) -mini_batch_size=$(( total_procs )) - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=${total_procs} \ - data.val_batch_size=${total_procs} \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.lora_rank=32 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.actor.optim.lr=3e-5 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.max_num_seqs=512 \ - actor_rollout_ref.rollout.max_model_len=1536 \ - actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.entropy_coeff=0.001 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXP} \ - trainer.n_gpus_per_node=1 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh deleted file mode 100644 index 6b6ede29b..000000000 --- a/examples/tuning/1.5b/qwen2-1.5b_grpo-lora_1_h100_fsdp_vllm.sh +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -NOW=$(date +%Y%m%d) -export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-1.5b-${NOW} -export WANDB_PROJECT=${WANDB_DIR} -export WANDB_EXP=1.5b-${NOW} -MODEL_PATH=Qwen/Qwen2.5-1.5B-Instruct - -set -x -nproc_per_gpu=128 -nnodes=1 -ngpu_per_node=1 -total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) -mini_batch_size=$(( total_procs )) - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=${total_procs} \ - data.val_batch_size=${total_procs} \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.lora_rank=32 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.actor.optim.lr=3e-5 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.max_num_seqs=512 \ - actor_rollout_ref.rollout.max_model_len=1536 \ - actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.entropy_coeff=0.001 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXP} \ - trainer.n_gpus_per_node=1 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh b/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh deleted file mode 100644 index 247945ffc..000000000 --- a/examples/tuning/14b/qwen2-14b_grpo-lora_2_h100_fsdp_vllm.sh +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -NOW=$(date +%Y%m%d) -export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-14b-${NOW} -export WANDB_PROJECT=${WANDB_DIR} -export WANDB_EXP=14b-${NOW} -MODEL_PATH=Qwen/Qwen2.5-14B-Instruct - -set -x -nproc_per_gpu=58 # 32√ → 64× → 48√ → 56√ → 60× → 58√ → 59× -nnodes=1 -ngpu_per_node=2 -total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) -mini_batch_size=$(( total_procs )) - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=${total_procs} \ - data.val_batch_size=${total_procs} \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.lora_rank=32 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.actor.optim.lr=3e-5 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.25 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.max_num_seqs=512 \ - actor_rollout_ref.rollout.max_model_len=1536 \ - actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ - actor_rollout_ref.actor.entropy_coeff=0.001 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXP} \ - trainer.n_gpus_per_node=2 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh b/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh deleted file mode 100644 index 2df21533c..000000000 --- a/examples/tuning/14b/qwen2_14b_grpo_4_h800_fsdp_vllm.sh +++ /dev/null @@ -1,47 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/rlhf/math/test.parquet -model_path=Qwen/Qwen2.5-Coder-14B-Instruct - -train_files="['$gsm8k_train_path']" -test_files="['$gsm8k_test_path']" - -PYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$model_path \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_14b_function_rm' \ - trainer.n_gpus_per_node=4 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ diff --git a/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh b/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh deleted file mode 100644 index d707a4adc..000000000 --- a/examples/tuning/32b/qwen2-32b_grpo-lora_4_h100_fsdp_vllm.sh +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -NOW=$(date +%Y%m%d) -export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-32b-${NOW} -export WANDB_PROJECT=${WANDB_DIR} -export WANDB_EXP=32b-${NOW} -MODEL_PATH=Qwen/Qwen2.5-32B-Instruct - -set -x -nproc_per_gpu=45 # 32√ → 64× → 48× → 40√ → 44√ → 46× → 45× -nnodes=1 -ngpu_per_node=4 -total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) -mini_batch_size=$(( total_procs )) - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=${total_procs} \ - data.val_batch_size=${total_procs} \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.lora_rank=32 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.actor.optim.lr=3e-5 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.max_num_seqs=512 \ - actor_rollout_ref.rollout.max_model_len=1536 \ - actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ - actor_rollout_ref.actor.entropy_coeff=0.001 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXP} \ - trainer.n_gpus_per_node=4 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh b/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh deleted file mode 100644 index 3a96fe504..000000000 --- a/examples/tuning/32b/qwen2_32B_grpo_8_h20_megatron_vllm.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x - -# we need this to avoid fragmentation of GPU memory -export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:256 - -gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/rlhf/math/test.parquet -train_files="['$gsm8k_train_path']" -test_files="['$gsm8k_test_path']" - -model_path=Qwen/Qwen2.5-32B - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=512 \ - data.max_prompt_length=2048 \ - data.max_response_length=6144 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$model_path \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=8 \ - actor_rollout_ref.actor.megatron.param_offload=True \ - actor_rollout_ref.actor.megatron.grad_offload=True \ - actor_rollout_ref.actor.megatron.optimizer_offload=True \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.ref.megatron.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='megatron_vllm_qwen2_32b' \ - trainer.experiment_name='qwen2_32b_grpo_8_h20' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh deleted file mode 100644 index fac34a5d5..000000000 --- a/examples/tuning/3b/qwen2-3b_grpo-lora_1_h100_fsdp_vllm.sh +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -NOW=$(date +%Y%m%d) -export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-3b-${NOW} -export WANDB_PROJECT=${WANDB_DIR} -export WANDB_EXP=3b-${NOW} -MODEL_PATH=Qwen/Qwen2.5-3B-Instruct - -set -x -nproc_per_gpu=62 -nnodes=1 -ngpu_per_node=1 -total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) -mini_batch_size=$(( total_procs )) - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=${total_procs} \ - data.val_batch_size=${total_procs} \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.lora_rank=32 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.actor.optim.lr=3e-5 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.1 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.max_num_seqs=512 \ - actor_rollout_ref.rollout.max_model_len=1536 \ - actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.entropy_coeff=0.001 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXP} \ - trainer.n_gpus_per_node=1 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh b/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh deleted file mode 100644 index 9a1d50ad1..000000000 --- a/examples/tuning/70b/qwen2-70b_grpo_32_h20_fsdp_vllm.sh +++ /dev/null @@ -1,43 +0,0 @@ -set -x - -gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet -gsm8k_val_path=$HOME/data/rlhf/math/test.parquet -model_path=Qwen/Qwen2-72B-Instruct - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$data_path \ - data.val_files=$gsm8k_val_path \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=model_path \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.tensor_model_parallel_size=16 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='Qwen2_72B_Instruct' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=4 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ \ No newline at end of file diff --git a/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh b/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh deleted file mode 100644 index b15f406b1..000000000 --- a/examples/tuning/70b/qwen2-70b_grpo_32_h800_fsdp_vllm.sh +++ /dev/null @@ -1,45 +0,0 @@ -set -x - -#### important: vllm version must be >= 0.8.3 - -gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet -gsm8k_val_path=$HOME/data/rlhf/math/test.parquet -model_path=Qwen/Qwen2-72B-Instruct - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$gsm8k_train_path \ - data.val_files=$gsm8k_val_path \ - data.train_batch_size=1024 \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$model_path \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.tensor_model_parallel_size=16 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='Qwen2_72B_Instruct' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=4 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ \ No newline at end of file diff --git a/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh b/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh deleted file mode 100644 index 7f93ed32f..000000000 --- a/examples/tuning/70b/qwen2-72b_grpo-lora_8_h100_fsdp_vllm.sh +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -NOW=$(date +%Y%m%d) -export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-72b-${NOW} -export WANDB_PROJECT=${WANDB_DIR} -export WANDB_EXP=72b-${NOW} -MODEL_PATH=Qwen/Qwen2.5-72B-Instruct - -set -x -nproc_per_gpu=22 # 16√ → 32× → 24× → 20√ → 22√ → 23× -nnodes=1 -ngpu_per_node=8 -total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) -mini_batch_size=$(( total_procs )) - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=${total_procs} \ - data.val_batch_size=${total_procs} \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.lora_rank=32 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.actor.optim.lr=3e-5 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=8 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.max_num_seqs=512 \ - actor_rollout_ref.rollout.max_model_len=1536 \ - actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ - actor_rollout_ref.actor.entropy_coeff=0.001 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXP} \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh b/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh deleted file mode 100644 index a663a90d6..000000000 --- a/examples/tuning/7b/qwen2-7b_grpo-lora_1_h100_fsdp_vllm.sh +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 -NOW=$(date +%Y%m%d) -export WANDB_DIR=gsm8k-grpo-lora-qwen2.5-7b-${NOW} -export WANDB_PROJECT=${WANDB_DIR} -export WANDB_EXP=7b-${NOW} -MODEL_PATH=Qwen/Qwen2.5-7B-Instruct - -set -x -nproc_per_gpu=16 # 64√ → 128× → 96√ → 112× → 104× → 100√ → 102× → 101× -nnodes=1 -ngpu_per_node=1 -total_procs=$(( nproc_per_gpu * nnodes * ngpu_per_node )) -mini_batch_size=$(( total_procs )) - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=data/gsm8k/train.parquet \ - data.val_files=data/gsm8k/test.parquet \ - data.train_batch_size=${total_procs} \ - data.val_batch_size=${total_procs} \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=$MODEL_PATH \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.lora_rank=32 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.model.target_modules=all-linear \ - actor_rollout_ref.actor.optim.lr=3e-5 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.2 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.max_num_seqs=512 \ - actor_rollout_ref.rollout.max_model_len=1536 \ - actor_rollout_ref.rollout.max_num_batched_tokens=1536 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${mini_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.entropy_coeff=0.001 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXP} \ - trainer.n_gpus_per_node=1 \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ 2>&1 | tee ${WANDB_PROJECT}.log diff --git a/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh b/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh deleted file mode 100644 index 598e82b41..000000000 --- a/examples/tuning/7b/qwen2-7b_grpo_2_h800_fsdp_vllm.sh +++ /dev/null @@ -1,48 +0,0 @@ -set -x - - -gsm8k_train_path=$HOME/data/rlhf/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/rlhf/math/test.parquet -model_path=Qwen/Qwen2-7B-Instruct - -train_files="['$gsm8k_train_path']" -test_files="['$gsm8k_test_path']" - -PYTHONPATH=/opt/tiger/open_verl python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$model_path \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=2 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=15 $@ diff --git a/pyproject.toml b/pyproject.toml index e10da9e4d..7a182f777 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,93 +1,26 @@ -# ------------------------------- -# build-system -# ------------------------------- [build-system] -requires = [ - "setuptools>=61.0", - "wheel" -] +requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" -# ------------------------------- -# project (PEP 621 metadata) -# ------------------------------- [project] -name = "verl" -# We'll mark the version as "dynamic" because it's read from the file "verl/version/version" -# (PEP 621 calls this "dynamic version"). -# The actual version is specified in the [tool.setuptools.dynamic] section below. -dynamic = ["version", "dependencies", "optional-dependencies", "authors", "urls"] - -description = "verl: Volcano Engine Reinforcement Learning for LLM" -license = {text = "Apache-2.0"} # Changed from file to text format -readme = {file = "README.md", content-type = "text/markdown"} -requires-python = ">=3.10" - -# ------------------------------- -# tool.ruff - Linting configuration -# ------------------------------- -[tool.ruff] -# Note: While the formatter will attempt to format lines such that they remain within the line-length, -# it isn't a hard upper bound, and formatted lines may exceed the line-length. -line-length = 120 -exclude = ["tests/workers/rollout/test_sglang_async_rollout_sf_tools.py", "scripts/legacy_model_merger.py"] - -[tool.ruff.lint] -isort = {known-first-party = ["verl"]} -# c.f. https://github.com/vllm-project/vllm/blob/ce8d6b75fc0586045df75ee1568a5b5f9957251b/pyproject.toml -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # isort - "I", - "G", +name = "reasoning360" +version = "0.1.0" +description = "Reasoning360 extension for Verl" +readme = "README.md" +requires-python = ">=3.9" +classifiers = [ + "Programming Language :: Python :: 3", ] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # `.log()` statement uses f-string - "G004", - # X | None for type annotations - "UP045", - # deprecated import - "UP035", - # line length - "E501" +dependencies = [ + "verl", # Assumes verl is installed in the environment + "hydra-core", + # Dependencies migrated from modified verl setup.py + "langdetect", + "immutabledict", + "nltk", + "polars", + "nvitop", # Moved from extras['gpu'] for convenience, or could be optional ] -# ------------------------------- -# tool.setuptools - Additional config -# ------------------------------- -[tool.setuptools] -# True means `setuptools` will attempt to include all relevant files in package_data automatically. -# This corresponds to `include_package_data=True` in setup.py. -include-package-data = true - -# We read the version from a file in 'verl/version/version' -[tool.setuptools.dynamic] -version = {file = "verl/version/version"} - -# If you need to mimic `package_dir={'': '.'}`: -[tool.setuptools.package-dir] -"" = "." - -# If you need to include specific non-Python data (like YAML files or version file): -# This is the rough equivalent of package_data={'': ['version/*'], 'verl': ['trainer/config/*.yaml']} -[tool.setuptools.package-data] -verl = [ - "version/*", - "trainer/config/*.yaml", - "trainer/config/*/*.yaml", -] +[tool.setuptools.packages.find] +where = ["src"] diff --git a/recipe/README.md b/recipe/README.md deleted file mode 100644 index 29fb40384..000000000 --- a/recipe/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# Recipe -The examples under `recipes/` are representative extensions to verl for specific end-to-end RL training recipes. -The help the community reproduce experiments, verl team provides a snapshot of the codebase when each recipe is initially PR'ed to verl main. You can find them via [github branches](https://github.com/volcengine/verl/branches/all?query=recipe) - -# Awesome work using verl - -- [Logic-RL](https://github.com/Unakar/Logic-RL): a reproduction of DeepSeek R1 Zero on 2K Tiny Logic Puzzle Dataset. ![GitHub Repo stars](https://img.shields.io/github/stars/Unakar/Logic-RL) -- [Seed-Coder](https://github.com/ByteDance-Seed/Seed-Coder): RL training of Seed-Coder boosts performance on competitive programming ![GitHub Repo stars](https://img.shields.io/github/stars/ByteDance-Seed/Seed-Coder) -- [all-hands/openhands-lm-32b-v0.1](https://www.all-hands.dev/blog/introducing-openhands-lm-32b----a-strong-open-coding-agent-model): A strong, open coding agent model, trained with [multi-turn fine-tuning](https://github.com/volcengine/verl/pull/195) -- [s3](https://github.com/pat-jj/s3) **Efficient Yet Effective** Search Agent Training via RL ![GitHub Repo stars](https://img.shields.io/github/stars/pat-jj/s3) -- [Rec-R1](https://arxiv.org/pdf/2503.24289): Bridging Generative Large Language Models and Recommendation Systems via Reinforcement Learning -- [Explore RL Data Scaling](https://arxiv.org/abs/2503.22230): Exploring Data Scaling Trends and Effects in Reinforcement Learning from Human Feedback -- [FIRE](https://arxiv.org/abs/2410.21236): Flaming-hot initiation with regular execution sampling for large language models -- [DQO](https://arxiv.org/abs/2410.09302): Enhancing multi-Step reasoning abilities of language models through direct Q-function optimization -- [ProRL](https://arxiv.org/abs/2505.24864): Prolonged Reinforcement Learning Expands Reasoning Boundaries in Large Language Models -- [cognition-engineering](https://github.com/gair-nlp/cognition-engineering): Test time scaling drives cognition engineering. ![GitHub Repo stars](https://img.shields.io/github/stars/gair-nlp/cognition-engineering) -- [Trust Region Preference Approximation](https://github.com/XueruiSu/Trust-Region-Preference-Approximation): A simple and stable **reinforcement learning algorithm** for LLM reasoning. ![GitHub Repo stars](https://img.shields.io/github/stars/XueruiSu/Trust-Region-Preference-Approximation) -- [AdaRFT](https://github.com/uscnlp-lime/verl): Efficient Reinforcement Finetuning via **Adaptive Curriculum Learning** ![GitHub Repo stars](https://img.shields.io/github/stars/uscnlp-lime/verl) -- [critic-rl](https://github.com/HKUNLP/critic-rl): LLM critics for code generation ![GitHub Repo stars](https://img.shields.io/github/stars/HKUNLP/critic-rl) -- [self-rewarding-reasoning-LLM](https://arxiv.org/pdf/2502.19613): self-rewarding and correction with **generative reward models** ![GitHub Repo stars](https://img.shields.io/github/stars/RLHFlow/Self-rewarding-reasoning-LLM) -- [DeepEnlighten](https://github.com/DolbyUUU/DeepEnlighten): Reproduce R1 with **social reasoning** tasks and analyze key findings ![GitHub Repo stars](https://img.shields.io/github/stars/DolbyUUU/DeepEnlighten) -- [MetaSpatial](https://github.com/PzySeere/MetaSpatial): Reinforcing **3D Spatial Reasoning** in **VLMs** for the **Metaverse** ![GitHub Repo stars](https://img.shields.io/github/stars/PzySeere/MetaSpatial) -- [PURE](https://github.com/CJReinforce/PURE): **Credit assignment** is the key to successful reinforcement fine-tuning using **process reward model** ![GitHub Repo stars](https://img.shields.io/github/stars/CJReinforce/PURE) -- [cognitive-behaviors](https://github.com/kanishkg/cognitive-behaviors): Cognitive Behaviors that Enable Self-Improving Reasoners, or, Four Habits of Highly Effective STaRs ![GitHub Repo stars](https://img.shields.io/github/stars/kanishkg/cognitive-behaviors) -- [deepscaler](https://github.com/agentica-project/rllm/tree/deepscaler): iterative context scaling with GRPO ![GitHub Repo stars](https://img.shields.io/github/stars/agentica-project/deepscaler) -- [DAPO](https://dapo-sia.github.io/): the fully open source SOTA RL algorithm that beats DeepSeek-R1-zero-32B ![GitHub Repo stars](https://img.shields.io/github/stars/volcengine/verl) diff --git a/recipe/char_count/README.md b/recipe/char_count/README.md deleted file mode 100644 index 18f902d15..000000000 --- a/recipe/char_count/README.md +++ /dev/null @@ -1,41 +0,0 @@ -# Char Count -## Introduction -Char count is a simple NLP task. We create it for beginners to grasp the idea of RLVR. The task can be trained using a tiny model (e.g., https://huggingface.co/HuggingFaceTB/SmolLM2-135M) on a consumer GPU with only 8GB. - -## Problem formulation -The prompt is: "How many {char} are there in {word}?". In order for LLM to better answer this question, we create SFT dataset with intermediate steps. For example, - -```text -Question: How many n are there in n-i-n-e? -Answer: -n = n -i != n -n = n -e != n -\boxed{2} -``` - -Note that -- We add a dash between each individual char to make the task easier because each individual char will be tokenized to the same token by most tokenizer. -- In the SFT dataset, we create a CoT by listing all the individual chars and whether it equals to the target. In the end, it outputs the final answer inside the box. -- The task can be verified. -- The word is not always meaningful. Each char is sampled uniformly from a to z. We make the total length and the answer uniformly distributed within a range. - -## Scripts -To create the dataset, run -```bash -python3 create_dataset.py -``` -We create a train set and a val set. Both of them are used of SFT and RL. You can specify the total number of data, min/max length and data path. - -To run the SFT -```bash -bash train_sft.sh -``` -We train SFT for 3 epochs. After 3 epochs, the validation score is around 0.12. - -To run GRPO -```bash -bash train_grpo.sh -``` -We train GRPO for 2 epochs. After 2 epochs, the validation score is around 0.36. diff --git a/recipe/char_count/create_dataset.py b/recipe/char_count/create_dataset.py deleted file mode 100644 index 47571e023..000000000 --- a/recipe/char_count/create_dataset.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Task description: -Given a random word and a random char, count the number of occurrence of char in the word. - -Create CoT dataset that split the word into separate char. Then list the char and count the occurrence. - -The word set comes from shakespeare -""" - -import os.path -import random - -prompt_template = "How many {} are there in word {}?" - - -def generate_random_char(): - return chr(97 + random.randint(0, 25)) - - -def create_prompt_response(min_length=3, max_length=5): - # randomly generate a length - word_length = random.randint(min_length, max_length) - # randomly generate a target count number. This makes the target number - target_count_number = random.randint(1, word_length) - - char_lst = [] - # generate the word - # step 1: generate the target word - target_char = generate_random_char() - - for _ in range(target_count_number): - char_lst.append(target_char) - - # step 2: generate other words - for _ in range(word_length - target_count_number): - while True: - char = generate_random_char() - if char != target_char: - char_lst.append(char) - break - - # step 3: random permute char_lst - random.shuffle(char_lst) - - word = "-".join(char_lst) - - prompt = prompt_template.format(target_char, word) - final_answer = [] - - # cot - number = 0 - for i, char in enumerate(char_lst): - cot = f"{char}" - if char != target_char: - cot += " != " - else: - cot += " = " - number += 1 - cot += f"{target_char}." - - final_answer.append(cot) - - conclusion = f"\\boxed{{{number}}} {target_char} in {word}." - - final_answer.append(conclusion) - - final_answer = "\n".join(final_answer) - - return prompt, final_answer - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--total_number", type=int, default=10000) - parser.add_argument("--min_length", type=int, default=5) - parser.add_argument("--max_length", type=int, default=20) - parser.add_argument("--data_path", type=str, default="~/data/char_count") - - args = vars(parser.parse_args()) - - total_number = args["total_number"] - min_length = args["min_length"] - max_length = args["max_length"] - data_path = args["data_path"] - data_path = os.path.expanduser(data_path) - - full_output = [] - for _ in range(total_number): - output = create_prompt_response(min_length=min_length, max_length=max_length) - full_output.append(output) - - # random reorder - random.shuffle(full_output) - - # split for train and test - train_split_len = int(0.9 * len(full_output)) - train_outputs = full_output[:train_split_len] - test_output = full_output[train_split_len:] - - sft_train_dataset = {"prompt": [], "response": []} - - for o in train_outputs: - sft_train_dataset["prompt"].append(o[0]) - sft_train_dataset["response"].append(o[1]) - - sft_test_dataset = {"prompt": [], "response": []} - - for o in test_output: - sft_test_dataset["prompt"].append(o[0]) - sft_test_dataset["response"].append(o[1]) - - import pandas as pd - - sft_train_dataset = pd.DataFrame(data=sft_train_dataset) - sft_test_dataset = pd.DataFrame(data=sft_test_dataset) - - folder = os.path.join(data_path, "sft") - - os.makedirs(folder, exist_ok=True) - - sft_train_dataset.to_parquet(os.path.join(folder, "train.parquet")) - sft_test_dataset.to_parquet(os.path.join(folder, "test.parquet")) - - # build RL dataset - rl_train_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []} - - rl_test_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []} - - from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed - - for o in train_outputs: - prompt = o[0] - response = o[1] - prompt_with_template = [ - { - "role": "user", - "content": prompt, - } - ] - - rl_train_dataset["prompt"].append(prompt_with_template) - rl_train_dataset["data_source"].append("char_count") - rl_train_dataset["ability"].append("other") - rl_train_dataset["reward_model"].append( - {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))} - ) - rl_train_dataset["extra_info"].append({"response": response}) - - for o in test_output: - prompt = o[0] - response = o[1] - prompt_with_template = [ - { - "role": "user", - "content": prompt, - } - ] - - rl_test_dataset["prompt"].append(prompt_with_template) - rl_test_dataset["data_source"].append("char_count") - rl_test_dataset["ability"].append("other") - rl_test_dataset["reward_model"].append( - {"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))} - ) - rl_test_dataset["extra_info"].append({"response": response}) - - rl_train_dataset = pd.DataFrame(data=rl_train_dataset) - rl_test_dataset = pd.DataFrame(data=rl_test_dataset) - - folder = os.path.join(data_path, "rl") - - os.makedirs(folder, exist_ok=True) - - rl_train_dataset.to_parquet(os.path.join(folder, "train.parquet")) - rl_test_dataset.to_parquet(os.path.join(folder, "test.parquet")) diff --git a/recipe/char_count/reward_function.py b/recipe/char_count/reward_function.py deleted file mode 100644 index 9bdffe2a5..000000000 --- a/recipe/char_count/reward_function.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Reward function -""" - -from verl.utils.reward_score import math - - -def char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None): - try: - last_boxed_string = math.last_boxed_only_string(solution_str) - if last_boxed_string is None: - return 0 - solution = math.remove_boxed(last_boxed_string) - if solution == ground_truth: - return 1 - else: - return 0 - except Exception: - print(ground_truth, solution_str) - return 0 diff --git a/recipe/char_count/train_grpo.sh b/recipe/char_count/train_grpo.sh deleted file mode 100644 index 5de85422f..000000000 --- a/recipe/char_count/train_grpo.sh +++ /dev/null @@ -1,43 +0,0 @@ -set -x - - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/char_count/rl/train.parquet \ - data.val_files=$HOME/data/char_count/rl/test.parquet \ - data.train_batch_size=128 \ - data.max_prompt_length=128 \ - data.max_response_length=128 \ - data.filter_overlong_prompts=False \ - data.truncation='error' \ - actor_rollout_ref.model.path=./models/sft/global_step_105 \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=16 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5000 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.kl_loss_coef=0.0 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","tensorboard"]' \ - trainer.project_name='verl_example' \ - trainer.experiment_name='smol135m_grpo' \ - trainer.val_before_train=True \ - trainer.n_gpus_per_node=1 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=2 \ - custom_reward_function.path=recipe/char_count/reward_function.py \ - custom_reward_function.name=char_count_reward_function diff --git a/recipe/char_count/train_sft.sh b/recipe/char_count/train_sft.sh deleted file mode 100644 index 56f5cec53..000000000 --- a/recipe/char_count/train_sft.sh +++ /dev/null @@ -1,21 +0,0 @@ -set -x - -nproc_per_node=1 -save_path=./models/sft - -torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/char_count/sft/train.parquet \ - data.val_files=$HOME/data/char_count/sft/test.parquet \ - data.prompt_key=prompt \ - data.response_key=response \ - data.micro_batch_size_per_gpu=8 \ - data.max_length=256 \ - data.train_batch_size=256 \ - use_remove_padding=True \ - model.partial_pretrain=HuggingFaceTB/SmolLM2-135M-Instruct \ - trainer.default_local_dir=$save_path \ - trainer.project_name=char_count-sft \ - trainer.experiment_name=char_count-sft-SmolLM2-135M-Instruct \ - trainer.total_epochs=3 \ - trainer.logger=console \ No newline at end of file diff --git a/recipe/dapo/README.md b/recipe/dapo/README.md deleted file mode 100644 index 75b80f1aa..000000000 --- a/recipe/dapo/README.md +++ /dev/null @@ -1,192 +0,0 @@ -# Recipe: Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) - -> Open-Source Algorithm Implementation & Expriement Running: [Yuxuan Tong](https://tongyx361.github.io/), [Guangming Sheng](https://hk.linkedin.com/in/guangming-sheng-b50640211) - -> [!IMPORTANT] -> -> **🔥 News!!!** -> -> - [2025/04] We reproduced the results of two versions of DAPO ([Full](./run_dapo_qwen2.5_32b.sh) & [w/o Dynamic Sampling](./run_dapo_wo_ds_qwen2.5_32b.sh)), achieving 52% and 50% on AIME 2024 respectively, based on [the latest codebase on `recipe/dapo`](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo). Please check the details in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n). -> - [2025/03] We published the training record of [an early version of DAPO (w/o Token-level PG Loss & Dynamic Sampling)](./run_dapo_early_qwen2.5_32b.sh), achieving 44% on AIME 2024, in [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n). - -🏠 [Homepage](https://dapo-sia.github.io/) | 📝 [Paper@arXiv](https://arxiv.org/abs/2503.14476) | 🤗 [Datasets&Models@HF](https://huggingface.co/collections/BytedTsinghua-SIA/dapo-67d7f1517ee33c8aed059da0) | 🐱 [Code@GitHub](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) | 🐱 [Repo@GitHub](https://github.com/BytedTsinghua-SIA/DAPO) - -> We propose the **D**ecoupled Clip and Dynamic s**A**mpling **P**olicy **O**ptimization (DAPO) algorithm. By making our work publicly available, we provide the broader research community and society with practical access to scalable reinforcement learning, enabling all to benefit from these advancements. Our system is based on the awesome [verl](https://github.com/volcengine/verl) framework. Thanks for their great work! Applying DAPO training to Qwen2.5-32B base model proves to outperform the previous state-of-the-art DeepSeek-R1-Zero-Qwen-32B on AIME 2024, achieving **50%** accuracy with **50%** less training steps. -> -> ![dapo-main-result](https://dapo-sia.github.io/static/images/score.png) - -## Quickstart - -1. Prepare the datasets **on the Ray cluster**: - -```bash -bash prepare_dapo_data.sh # This downloads the datasets to ${HOME}/verl/data by default -``` - -2. Submit the job to the Ray cluster **from any machine**: - -```bash -cd verl # Repo root -export RAY_ADDRESS="http://${RAY_IP:-localhost}:8265" # The Ray cluster address to connect to -export WORKING_DIR="${PWD}" # The local directory to package to the Ray cluster -# Set the runtime environment like env vars and pip packages for the Ray cluster in yaml -export RUNTIME_ENV="./recipe/dapo/runtime_env.yaml" # This sets environment variables for the Ray cluster -bash recipe/dapo/run_dapo_qwen2.5_32b.sh # or other scripts -``` - -## Reproduction Runs - -| Setup | AIME 2024 Acc. | Hardware | Image | Commit | Environment Variables | Training Script | Training Record | -| -------------------------------------------- | -------------- | --------- | -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | -| DAPO | 52% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -| DAPO w/o Dynamic Sampling | 50% | 16x8xH800 | `hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_wo_ds_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | -| DAPO w/o Token-level Loss & Dynamic Sampling | 44% | 16x8xH20 | `hiyouga/verl:ngc-th2.5.1-cu120-vllm0.7.4-hotfix` | [`4f80e4`](https://github.com/volcengine/verl/tree/4f80e465c2ec79ab9c3c30ec74b9745de61d0490) | [runtime_env.yaml](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/runtime_env.yaml) | [run_dapo_early_qwen2.5_32b.sh](https://github.com/volcengine/verl/blob/4f80e465c2ec79ab9c3c30ec74b9745de61d0490/recipe/dapo/run_dapo_early_qwen2.5_32b.sh) | [W&B](https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/workspace?nw=wmb4qxfht0n) | - -> [!IMPORTANT] -> -> **📢 Call for Contribution!** -> -> Welcome to submit your reproduction runs and setups! - -## Configuration - -### Separated Clip Epsilons (-> Clip-Higher) - -An example configuration: - -```yaml -actor_rollout_ref: - actor: - clip_ratio_low: 0.2 - clip_ratio_high: 0.28 -``` - -`clip_ratio_low` and `clip_ratio_high` specify the $\varepsilon_{\text {low }}$ and $\varepsilon_{\text {high }}$ in the DAPO objective. - -Core relevant code: - -```python -pg_losses1 = -advantages * ratio -pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) -pg_losses = torch.maximum(pg_losses1, pg_losses2) -``` - -### Dynamic Sampling (with Group Filtering) - -An example configuration: - -```yaml -data: - gen_batch_size: 1536 - train_batch_size: 512 -algorithm: - filter_groups: - enable: True - metric: acc # score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 10 # Non-positive values mean no upper limit -``` - -Setting `filter_groups.enable` to `True` will filter out groups whose outputs' `metric` are all the same, e.g., for `acc`, groups whose outputs' accuracies are all 1 or 0. - -The trainer will repeat sampling with `gen_batch_size` until there are enough qualified groups for `train_batch_size` or reaching the upper limit specified by `max_num_gen_batches`. - -Core relevant code: - -```python -prompt_bsz = self.config.data.train_batch_size -if num_prompt_in_batch < prompt_bsz: - print(f'{num_prompt_in_batch=} < {prompt_bsz=}') - num_gen_batches += 1 - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches - if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: - print(f'{num_gen_batches=} < {max_num_gen_batches=}. Keep generating...') - continue - else: - raise ValueError( - f'{num_gen_batches=} >= {max_num_gen_batches=}. Generated too many. Please check your data.' - ) -else: - # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n - batch = batch[:traj_bsz] -``` - -### Flexible Loss Aggregation Mode (-> Token-level Loss) - -An example configuration: - -```yaml -actor_rollout_ref: - actor: - loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" - # NOTE: "token-mean" is the default behavior -``` - -Setting `loss_agg_mode` to `token-mean` will mean the (policy gradient) loss across all the tokens in all the sequences in a mini-batch. - -Core relevant code: - -```python -if loss_agg_mode == "token-mean": - loss = verl_F.masked_mean(loss_mat, loss_mask) -elif loss_agg_mode == "seq-mean-token-sum": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum - loss = torch.mean(seq_losses) # seq-mean -elif loss_agg_mode == "seq-mean-token-mean": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean - loss = torch.mean(seq_losses) # seq-mean -else: - raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") -``` - -### Overlong Reward Shaping - -An example configuration: - -```yaml -data: - max_response_length: 20480 # 16384 + 4096 -reward_model: - overlong_buffer: - enable: True - len: 4096 - penalty_factor: 1.0 -``` - -Setting `overlong_buffer.enable` to `True` will penalize the outputs whose lengths are overlong but still within the hard context limit. - -Specifically, the penalty increases linearly from `0` to `overlong_buffer.penalty_factor` when the length of the output exceeds the `max_response_length` by `0` to `overlong_buffer.len` tokens. - -Core relevant code: - -```python -if self.overlong_buffer_cfg.enable: - overlong_buffer_len = self.overlong_buffer_cfg.len - expected_len = self.max_resp_len - overlong_buffer_len - exceed_len = valid_response_length - expected_len - overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor - overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) - reward += overlong_reward -``` - -## FAQ - -### Where is the "Overlong Filtering" in the paper? - -Most experiments in the paper, including the best-performant one, are run without Overlong Filtering because it's somehow overlapping with Overlong Reward Shaping in terms of properly learning from the longest outputs. So we don't implement it here. - -### What's the difference between [the `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) and the [`recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo)? - -[The `recipe/dapo` branch](https://github.com/volcengine/verl/tree/recipe/dapo/recipe/dapo) is for **as-is reproduction** and thus won't be updated with new features. - -[The `recipe/dapo` directory in the `main` branch](https://github.com/volcengine/verl/tree/main/recipe/dapo) works as an example of how to extend the latest `verl` to implement an algorithm recipe, which will be maintained with new features. - -### Why can't I produce similar results after modifications? - -RL infrastructures nowadays still have inherent unrobustness, on which we are still working hard to improve. - -We strongly recommend to only modify one thing at a time. - -We also list some known problems here: - -1. Enabling CUDA graph (`enforce_eager=False`) might cause model performance degradation, whose cause is still under investigation. diff --git a/recipe/dapo/config/dapo_fsdp_config.yaml b/recipe/dapo/config/dapo_fsdp_config.yaml deleted file mode 100644 index 47141447e..000000000 --- a/recipe/dapo/config/dapo_fsdp_config.yaml +++ /dev/null @@ -1,26 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - gen_batch_size: ${data.train_batch_size} - -reward_model: - reward_manager: dapo - overlong_buffer: - enable: False # We try to avoid forgetting to set enable - len: 0 - penalty_factor: 0.0 - log: False - -algorithm: - filter_groups: - _target_: verl.trainer.config.FilterGroupsConfig - enable: False # We try to avoid forgetting to set enable - metric: null # acc / score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 0 # Non-positive values mean no upper limit - diff --git a/recipe/dapo/config/dapo_megatron_config.yaml b/recipe/dapo/config/dapo_megatron_config.yaml deleted file mode 100644 index 5b83fab85..000000000 --- a/recipe/dapo/config/dapo_megatron_config.yaml +++ /dev/null @@ -1,25 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_megatron_trainer - - _self_ - -data: - gen_batch_size: ${data.train_batch_size} - -reward_model: - reward_manager: dapo - overlong_buffer: - enable: False # We try to avoid forgetting to set enable - len: 0 - penalty_factor: 0.0 - log: False - -algorithm: - filter_groups: - _target_: verl.trainer.config.FilterGroupsConfig - enable: False # We try to avoid forgetting to set enable - metric: null # acc / score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 0 # Non-positive values mean no upper limit \ No newline at end of file diff --git a/recipe/dapo/config/dapo_trainer.yaml b/recipe/dapo/config/dapo_trainer.yaml deleted file mode 100644 index 47ac00fd6..000000000 --- a/recipe/dapo/config/dapo_trainer.yaml +++ /dev/null @@ -1,28 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - gen_batch_size: ${data.train_batch_size} - -reward_model: - reward_manager: dapo - overlong_buffer: - enable: False # We try to avoid forgetting to set enable - len: 0 - penalty_factor: 0.0 - log: False - -algorithm: - filter_groups: - _target_: verl.trainer.config.FilterGroupsConfig - enable: False # We try to avoid forgetting to set enable - metric: null # acc / score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 0 # Non-positive values mean no upper limit - -trainer: - project_name: verl-dapo diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py deleted file mode 100644 index faf23423c..000000000 --- a/recipe/dapo/dapo_ray_trainer.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -FSDP PPO Trainer with Ray-based single controller. -This trainer supports model-agonistic model initialization with huggingface -""" - -import uuid -from collections import defaultdict -from copy import deepcopy -from pprint import pprint - -import numpy as np -import torch -from tqdm import tqdm - -from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss -from verl.trainer.ppo.metric_utils import ( - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, - reduce_metrics, -) -from verl.trainer.ppo.ray_trainer import ( - AdvantageEstimator, - RayPPOTrainer, - apply_kl_penalty, - compute_advantage, - compute_response_mask, -) -from verl.utils.profiler import marked_timer - - -class RayDAPOTrainer(RayPPOTrainer): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC - to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - assert val_metrics, f"{val_metrics=}" - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - - timing_raw = defaultdict(float) - batch = None - num_prompt_in_batch = 0 - num_gen_batches = 0 - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - - do_profile = ( - self.global_steps in self.config.trainer.profile_steps - if self.config.trainer.profile_steps is not None - else False - ) - with marked_timer("start_profile", timing_raw): - if do_profile: - self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) - if self.use_reference_policy: - self.ref_policy_wg.start_profile() - if self.use_critic: - self.critic_wg.start_profile() - if self.use_rm: - self.rm_wg.start_profile() - - new_batch: DataProto = DataProto.from_single_dict(batch_dict) - num_gen_batches += 1 - # pop those keys for generation - if "multi_modal_data" in new_batch.non_tensor_batch.keys(): - gen_batch = new_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"], - ) - else: - gen_batch = new_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], - ) - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - - is_last_step = self.global_steps >= self.total_training_steps - - with marked_timer("step", timing_raw): - # generate a batch - with marked_timer("gen", timing_raw, "red"): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - timing_raw.update(gen_batch_output.meta_info["timing"]) - gen_batch_output.meta_info.pop("timing", None) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with marked_timer("gen_max", timing_raw, "red"): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - - new_batch = new_batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(new_batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - new_batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - new_batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object - ) - # repeat to align with repeated responses in rollout - new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - new_batch = new_batch.union(gen_batch_output) - - with marked_timer("reward", timing_raw, "yellow"): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(new_batch) - new_batch = new_batch.union(reward_tensor) - - # we combine with rule-based rm - reward_extra_infos_dict: dict[str, list] - try: - reward_result = self.reward_fn(new_batch, return_dict=True) - reward_tensor = reward_result["reward_tensor"] - reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) - except Exception as e: - print(f"Error in reward_fn: {e}") - reward_tensor = self.reward_fn(new_batch) - reward_extra_infos_dict = {} - - new_batch.batch["token_level_scores"] = reward_tensor - - if reward_extra_infos_dict: - new_batch.non_tensor_batch.update( - {k: np.array(v) for k, v in reward_extra_infos_dict.items()} - ) - - # compute rewards. apply_kl_penalty if available - if self.config.algorithm.use_kl_in_reward: - new_batch, kl_metrics = apply_kl_penalty( - new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty - ) - metrics.update( - kl_metrics - ) # TODO: This will be cleared if we use multiple genenration batches - else: - new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] - - if not self.config.algorithm.filter_groups.enable: - batch = new_batch - else: # NOTE: When prompts after filtering is less than train batch size, - # we skip to the next generation batch - metric_name = self.config.algorithm.filter_groups.metric - if metric_name == "seq_final_reward": - # Turn to numpy for easier filtering - new_batch.non_tensor_batch["seq_final_reward"] = ( - new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() - ) - elif metric_name == "seq_reward": - new_batch.non_tensor_batch["seq_reward"] = ( - new_batch.batch["token_level_scores"].sum(dim=-1).numpy() - ) - - # Collect the sequence reward for each trajectory - prompt_uid2metric_vals = defaultdict(list) - for uid, metric_val in zip( - new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True - ): - prompt_uid2metric_vals[uid].append(metric_val) - - prompt_uid2metric_std = {} - for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): - prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) - - kept_prompt_uids = [ - uid - for uid, std in prompt_uid2metric_std.items() - if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 - ] - num_prompt_in_batch += len(kept_prompt_uids) - - kept_traj_idxs = [] - for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): - if traj_from_prompt_uid in kept_prompt_uids: - kept_traj_idxs.append(idx) - - new_batch = new_batch[kept_traj_idxs] - batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) - - prompt_bsz = self.config.data.train_batch_size - if num_prompt_in_batch < prompt_bsz: - print(f"{num_prompt_in_batch=} < {prompt_bsz=}") - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches - if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: - print(f"{num_gen_batches=}. Keep generating...") - progress_bar.update(1) - continue - else: - raise ValueError( - f"{num_gen_batches=} >= {max_num_gen_batches=}." - + " Generated too many. Please check if your data are too difficult." - + " You could also try set max_num_gen_batches=0 to enable endless trials." - ) - else: - # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n - batch = batch[:traj_bsz] - - # === Updating === - - batch.batch["response_mask"] = compute_response_mask(batch) - - # Balance the number of valid tokens across DP ranks. - # NOTE: This usually changes the order of data in the `batch`, - # which won't affect the advantage calculation (since it's based on uid), - # but might affect the loss calculation (due to the change of mini-batching). - # TODO: Decouple the DP balancing and mini-batching. - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - # recompute old_log_probs - with marked_timer("old_log_prob", timing_raw, "blue"): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch["entropys"] - response_masks = batch.batch["response_mask"] - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} - metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) - - if self.use_reference_policy: - # compute reference log_prob - with marked_timer("ref", timing_raw, "olive"): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with marked_timer("values", timing_raw, "cyan"): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with marked_timer("adv", timing_raw, "brown"): - # compute advantages, executed on the driver process - norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) - - # update critic - if self.use_critic: - with marked_timer("update_critic", timing_raw, "pink"): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with marked_timer("update_actor", timing_raw, "red"): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - ): - with marked_timer("testing", timing_raw, "green"): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 - ): - with marked_timer("save_checkpoint", timing_raw, "green"): - self._save_checkpoint() - - with marked_timer("stop_profile", timing_raw): - if do_profile: - self.actor_rollout_wg.stop_profile() - if self.use_reference_policy: - self.ref_policy_wg.stop_profile() - if self.use_critic: - self.critic_wg.stop_profile() - if self.use_rm: - self.rm_wg.stop_profile() - - # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - # TODO: implement actual tflpo and theoretical tflpo - n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) - timing_raw = defaultdict(float) # clear timing - - metrics["train/num_gen_batches"] = num_gen_batches - batch = None - num_prompt_in_batch = 0 - num_gen_batches = 0 - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - if is_last_step: - pprint(f"Final validation metrics: {last_val_metrics}") - progress_bar.close() - return - - progress_bar.update(1) - self.global_steps += 1 diff --git a/recipe/dapo/main_dapo.py b/recipe/dapo/main_dapo.py deleted file mode 100644 index 268591b8d..000000000 --- a/recipe/dapo/main_dapo.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. -""" - -import os -import socket - -import hydra -import ray -from omegaconf import OmegaConf - -from verl.trainer.ppo.reward import load_reward_manager -from verl.utils.device import is_cuda_available - -from .dapo_ray_trainer import RayDAPOTrainer - - -@hydra.main(config_path="config", config_name="dapo_trainer", version_base=None) -def main(config): - run_ppo(config) - - -def run_ppo(config) -> None: - if not ray.is_initialized(): - # this is for local ray cluster - ray.init( - runtime_env={ - "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} - }, - num_cpus=config.ray_init.num_cpus, - ) - - if ( - is_cuda_available - and OmegaConf.select(config.trainer, "profile_steps") is not None - and len(OmegaConf.select(config.trainer, "profile_steps")) > 0 - ): - nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) - runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() - else: - runner = TaskRunner.remote() - ray.get(runner.run.remote(config)) - - -@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head -class TaskRunner: - def run(self, config): - # print initial config - from pprint import pprint - - from omegaconf import OmegaConf - - from verl.utils.fs import copy_to_local - - print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") - - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - OmegaConf.resolve(config) - - # download the checkpoint from hdfs - local_path = copy_to_local(config.actor_rollout_ref.model.path) - - # instantiate tokenizer - from verl.utils import hf_processor, hf_tokenizer - - tokenizer = hf_tokenizer(local_path) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none - - # define worker classes - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: - assert config.critic.strategy in {"fsdp", "fsdp2"} - from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker - - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - - ray_worker_group_cls = NVMegatronRayWorkerGroup - - else: - raise NotImplementedError - - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - role_worker_mapping = { - Role.ActorRollout: ray.remote(ActorRolloutRefWorker), - Role.Critic: ray.remote(CriticWorker), - } - - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - Role.Critic: global_pool_id, - } - - # we should adopt a multi-source reward function here - # - for rule-based rm, we directly call a reward score - # - for model-based rm, we call a model - # - for code related prompt, we send to a sandbox if there are test cases - # - finally, we combine all the rewards together - # - The reward type depends on the tag of the data - if config.reward_model.enable: - if config.reward_model.strategy in {"fsdp", "fsdp2"}: - from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == "megatron": - from verl.workers.megatron_workers import RewardModelWorker - else: - raise NotImplementedError - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - mapping[Role.RewardModel] = global_pool_id - - # reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) - mapping[Role.RefPolicy] = global_pool_id - - reward_fn = load_reward_manager( - config, - tokenizer, - 0, - max_resp_len=config.data.max_response_length, - overlong_buffer_cfg=config.reward_model.overlong_buffer, - ) - - # Note that we always use function-based RM for validation - val_reward_fn = load_reward_manager( - config, - tokenizer, - 1, - max_resp_len=config.data.max_response_length, - overlong_buffer_cfg=config.reward_model.overlong_buffer, - ) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - trainer = RayDAPOTrainer( - config=config, - tokenizer=tokenizer, - processor=processor, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, - device_name=config.trainer.device, - ) - trainer.init_workers() - trainer.fit() - - -if __name__ == "__main__": - main() diff --git a/recipe/dapo/prepare_dapo_data.sh b/recipe/dapo/prepare_dapo_data.sh deleted file mode 100644 index b5dbb25a7..000000000 --- a/recipe/dapo/prepare_dapo_data.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash -set -uxo pipefail - -export VERL_HOME=${VERL_HOME:-"${HOME}/verl"} -export TRAIN_FILE=${TRAIN_FILE:-"${VERL_HOME}/data/dapo-math-17k.parquet"} -export TEST_FILE=${TEST_FILE:-"${VERL_HOME}/data/aime-2024.parquet"} -export OVERWRITE=${OVERWRITE:-0} - -mkdir -p "${VERL_HOME}/data" - -if [ ! -f "${TRAIN_FILE}" ] || [ "${OVERWRITE}" -eq 1 ]; then - wget -O "${TRAIN_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k/resolve/main/data/dapo-math-17k.parquet?download=true" -fi - -if [ ! -f "${TEST_FILE}" ] || [ "${OVERWRITE}" -eq 1 ]; then - wget -O "${TEST_FILE}" "https://huggingface.co/datasets/BytedTsinghua-SIA/AIME-2024/resolve/main/data/aime-2024.parquet?download=true" -fi diff --git a/recipe/dapo/run_dapo_early_qwen2.5_32b.sh b/recipe/dapo/run_dapo_early_qwen2.5_32b.sh deleted file mode 100644 index 81bc2cb12..000000000 --- a/recipe/dapo/run_dapo_early_qwen2.5_32b.sh +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -project_name='DAPO' -exp_name='DAPO-Early-Qwen2.5-32B' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 20)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -# An early version for DAPO -loss_agg_mode="seq-mean-token-mean" - -enable_filter_groups=False -gen_prompt_bsz=512 # NOTE: no filtering here -train_prompt_bsz=512 -train_prompt_mini_bsz=32 -n_resp_per_prompt=16 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-16} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - - -# Performance Related Parameter -sp_size=8 -use_dynamic_bsz=True -actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) -infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) -offload=True -gen_tp=4 - -ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ - --working-dir "${WORKING_DIR}" \ - -- python3 -m recipe.dapo.main_dapo \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=5 \ - trainer.save_freq=5 \ - trainer.total_epochs=1 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto diff --git a/recipe/dapo/run_dapo_qwen2.5_32b.sh b/recipe/dapo/run_dapo_qwen2.5_32b.sh deleted file mode 100644 index feb783a7c..000000000 --- a/recipe/dapo/run_dapo_qwen2.5_32b.sh +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -project_name='DAPO' -exp_name='DAPO-Qwen2.5-32B' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 20)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=True -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=512 -gen_prompt_bsz=$((train_prompt_bsz * 3)) -n_resp_per_prompt=16 -train_prompt_mini_bsz=32 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-16} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - -# Performance Related Parameter -sp_size=8 -use_dynamic_bsz=True -actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) -infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) -offload=True -gen_tp=4 - -ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ - --working-dir "${WORKING_DIR}" \ - -- python3 -m recipe.dapo.main_dapo \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=5 \ - trainer.save_freq=5 \ - trainer.total_epochs=1 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto diff --git a/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh b/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh deleted file mode 100644 index b0491aedf..000000000 --- a/recipe/dapo/run_dapo_wo_ds_qwen2.5_32b.sh +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env bash -set -euxo pipefail -# DAPO (w/o Dynamic Sampling) - -project_name='DAPO-verl' -exp_name='DAPO-wo-DS-Qwen2.5-32B' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 20)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=False -train_prompt_bsz=512 -n_resp_per_prompt=16 -train_prompt_mini_bsz=32 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-16} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-32B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - -# Performance Related Parameter -sp_size=8 -use_dynamic_bsz=True -actor_ppo_max_token_len=$((max_prompt_length + max_response_length)) -infer_ppo_max_token_len=$((max_prompt_length + max_response_length)) -offload=True -gen_tp=4 - -ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ - --working-dir "${WORKING_DIR}" \ - -- python3 -m recipe.dapo.main_dapo \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=5 \ - trainer.save_freq=5 \ - trainer.total_epochs=1 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto diff --git a/recipe/dapo/runtime_env.yaml b/recipe/dapo/runtime_env.yaml deleted file mode 100644 index 13f4b2ba2..000000000 --- a/recipe/dapo/runtime_env.yaml +++ /dev/null @@ -1,5 +0,0 @@ -working_dir: ./ -excludes: ["/.git/"] -env_vars: - TORCH_NCCL_AVOID_RECORD_STREAMS: "1" - VLLM_USE_V1: "1" diff --git a/recipe/dapo/test_dapo_7b.sh b/recipe/dapo/test_dapo_7b.sh deleted file mode 100644 index 2bb94963d..000000000 --- a/recipe/dapo/test_dapo_7b.sh +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -project_name='DAPO' -exp_name='DAPO-Qwen2.5-7B-Math-Test' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 2)) -enable_overlong_buffer=True -overlong_buffer_len=512 -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=True -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=512 -gen_prompt_bsz=$((train_prompt_bsz * 3)) -train_prompt_mini_bsz=32 -n_resp_per_prompt=16 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout - -# Mathematically equivalent -use_dynamic_bsz=True -infer_micro_batch_size=null -train_micro_batch_size=null -offload=False - -ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ - --working-dir "${WORKING_DIR}" \ - -- python3 -m recipe.dapo.main_dapo \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=2 \ - trainer.save_freq=2 \ - trainer.total_epochs=1 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=disable diff --git a/recipe/dapo/test_dapo_7b_math.sh b/recipe/dapo/test_dapo_7b_math.sh deleted file mode 100644 index 9574f7722..000000000 --- a/recipe/dapo/test_dapo_7b_math.sh +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -project_name='DAPO' -exp_name='DAPO-Qwen2.5-7b-MATH-0527a1' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -train_prompt_bsz=512 -n_resp_per_prompt=16 -train_prompt_mini_bsz=32 - -# Ray -# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -# WORKING_DIR=${WORKING_DIR:-"${PWD}"} -# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-8} -NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - -# Performance Related Parameter -sp_size=4 -use_dynamic_bsz=True -actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) -infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) -offload=True -gen_tp=4 -fsdp_size=32 - -# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361 - -python3 -m verl.trainer.main_ppo \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.model.use_remove_padding=True \ - +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ - reward_model.reward_manager=dapo \ - +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ - +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=10 \ - trainer.save_freq=10 \ - trainer.total_epochs=10 \ - trainer.total_training_steps=200 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ - trainer.log_val_generations=10 diff --git a/recipe/dapo/test_dapo_7b_math_lora.sh b/recipe/dapo/test_dapo_7b_math_lora.sh deleted file mode 100644 index d68e5d625..000000000 --- a/recipe/dapo/test_dapo_7b_math_lora.sh +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -project_name='DAPO' -exp_name='DAPO-Qwen2.5-7b-MATH-0527a1' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -train_prompt_bsz=512 -n_resp_per_prompt=16 -train_prompt_mini_bsz=32 - -# Ray -# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -# WORKING_DIR=${WORKING_DIR:-"${PWD}"} -# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-8} -NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - -# Performance Related Parameter -sp_size=4 -use_dynamic_bsz=True -actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) -infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) -offload=True -gen_tp=4 -fsdp_size=32 - -# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model - -python3 -m verl.trainer.main_ppo \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.model.use_remove_padding=True \ - +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.model.lora_rank=8 \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ - reward_model.reward_manager=dapo \ - +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ - +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=10 \ - trainer.save_freq=10 \ - trainer.total_epochs=10 \ - trainer.total_training_steps=200 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ - trainer.log_val_generations=10 diff --git a/recipe/dapo/test_dapo_7b_math_megatron.sh b/recipe/dapo/test_dapo_7b_math_megatron.sh deleted file mode 100644 index 4c16cd7d4..000000000 --- a/recipe/dapo/test_dapo_7b_math_megatron.sh +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -project_name='DAPO' -exp_name='DAPO-Qwen2.5-7b-MATH-megatron-0519a1' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -train_prompt_bsz=512 -n_resp_per_prompt=16 -train_prompt_mini_bsz=32 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - -# Performance Related Parameter -use_dynamic_bsz=True -actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) -infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) -offload=True -gen_tp=4 -train_tp=4 -train_pp=2 - -# TODO: support dynamic_bsz for megatron -# actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ -# actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ -# actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ -# actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ -# actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ -# actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - -python3 -m verl.trainer.main_ppo \ - --config-path=config \ - --config-name='ppo_megatron_trainer.yaml' \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.megatron.param_offload=${offload} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ - actor_rollout_ref.actor.megatron.grad_offload=${offload} \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.optim.clip_grad=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.ref.megatron.param_offload=${offload} \ - reward_model.reward_manager=dapo \ - +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ - +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=16 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ - trainer.test_freq=10 \ - trainer.save_freq=10 \ - trainer.total_epochs=10 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ - trainer.log_val_generations=10 diff --git a/recipe/dapo/test_dapo_dspk_671b_megatron.sh b/recipe/dapo/test_dapo_dspk_671b_megatron.sh deleted file mode 100644 index c6988d114..000000000 --- a/recipe/dapo/test_dapo_dspk_671b_megatron.sh +++ /dev/null @@ -1,143 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -# 0. download the config -# only need to download the configuration_deepseek.py and config.json -# remove the `quantization_config` in the `config.json` -# set `num_nextn_predict_layers=0` to disable MTP, which is not currently supported -huggingface-cli download deepseek-ai/DeepSeek-V3-0324 configuration_deepseek.py config.json - -project_name='DAPO' -exp_name='DAPO-DeepSeek-671b-megatron' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 4)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=0.1 - -loss_agg_mode="token-mean" - -train_prompt_bsz=512 # must be > n_gpus. need to fix -n_resp_per_prompt=2 -train_prompt_mini_bsz=16 # mini_bsz * n >= micro_bsz * pp * dp - -NNODES=${NNODES:-64} - -# 1. download the dist_ckpt format model from https://huggingface.co/BearBiscuit05/dpsk-v3-671B-BF16-dist_ckpt/tree/main -# change the MODEL_PATH and MCORE_MODEL_PATH to your own path -# Paths -MODEL_PATH="" -MCORE_MODEL_PATH="" -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -aime24_test_path=${RAY_DATA_HOME}/data/aime-2024.parquet -# TEST_FILE="['$math500_test_path', '$aime24_test_path']" - -TEST_FILE="['$aime24_test_path']" - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - -# Performance Related Parameter -use_dynamic_bsz=True -actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) -infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) -offload=True -gen_tp=32 -train_tp=1 -train_ep=32 -train_pp=16 - -python3 -m verl.trainer.main_ppo \ - --config-path=config \ - --config-name='ppo_megatron_trainer.yaml' \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.megatron.param_offload=${offload} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ - actor_rollout_ref.actor.megatron.grad_offload=${offload} \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.actor.megatron.expert_model_parallel_size=${train_ep} \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \ - +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=3 \ - +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=2 \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.optim.clip_grad=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.ref.megatron.expert_model_parallel_size=${train_ep} \ - actor_rollout_ref.ref.megatron.param_offload=${offload} \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=${MCORE_MODEL_PATH} \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \ - reward_model.reward_manager=dapo \ - +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ - +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ - trainer.test_freq=5 \ - trainer.save_freq=5 \ - trainer.total_epochs=10 \ - trainer.total_training_steps=10 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ - trainer.log_val_generations=10 diff --git a/recipe/dapo/test_dapo_qwen3_30b_math.sh b/recipe/dapo/test_dapo_qwen3_30b_math.sh deleted file mode 100644 index 741e0d6d0..000000000 --- a/recipe/dapo/test_dapo_qwen3_30b_math.sh +++ /dev/null @@ -1,126 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -project_name='DAPO' -exp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0527a1' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=True -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -train_prompt_bsz=512 -n_resp_per_prompt=16 -train_prompt_mini_bsz=32 - -# Ray -# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -# WORKING_DIR=${WORKING_DIR:-"${PWD}"} -# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-8} -NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} -CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} -TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} -TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -val_top_p=0.7 - -# Performance Related Parameter -sp_size=4 -use_dynamic_bsz=True -actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) -infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) -offload=True -gen_tp=4 -fsdp_size=32 - -python3 -m verl.trainer.main_ppo \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ - reward_model.reward_manager=dapo \ - +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ - +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ - +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=True \ - trainer.test_freq=10 \ - trainer.save_freq=10 \ - trainer.total_epochs=10 \ - trainer.total_training_steps=300 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=auto \ - trainer.log_val_generations=10 diff --git a/recipe/entropy/32b_clip_cov.sh b/recipe/entropy/32b_clip_cov.sh deleted file mode 100644 index 65cbe2e14..000000000 --- a/recipe/entropy/32b_clip_cov.sh +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -export WANDB_API_KEY=YOUR_WANDB_API_KEY -# export VLLM_USE_V1=1 - -project_name='Qwen2.5-32B' -exp_name='clipcov' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=1 -clip_ratio_high=1 -clip_cov_ratio=0.0002 -clip_cov_lb=1.0 -clip_cov_ub=5.0 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 2)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" -loss_mode="clip_cov" -enable_filter_groups=True -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=256 -gen_prompt_bsz=$((train_prompt_bsz * 3)) -train_prompt_mini_bsz=32 -n_resp_per_prompt=8 -max_token=20480 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} -CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} -TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} -TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -ppo_kl_coef=1 -kl_cov_ratio=0.02 - -# Mathematically equivalent -use_dynamic_bsz=True -infer_micro_batch_size=null -train_micro_batch_size=null -offload=False - -HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.filter_overlong_prompts=False \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ - data.train_batch_size=${train_prompt_bsz} \ - data.return_raw_chat=True \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ - actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \ - actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \ - actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.mode=sync \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.weight_decay=0 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.clip_cov_ratio=${clip_cov_ratio} \ - actor_rollout_ref.actor.clip_cov_lb=${clip_cov_lb} \ - actor_rollout_ref.actor.clip_cov_ub=${clip_cov_ub} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=False \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ - trainer.test_freq=4 \ - trainer.save_freq=32 \ - trainer.total_epochs=1000 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=disable diff --git a/recipe/entropy/32b_kl_cov.sh b/recipe/entropy/32b_kl_cov.sh deleted file mode 100644 index b0ba4519f..000000000 --- a/recipe/entropy/32b_kl_cov.sh +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -export WANDB_API_KEY=YOUR_WANDB_API_KEY -# export VLLM_USE_V1=1 - -project_name='Qwen2.5-32B' -exp_name='klcov' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.2 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 2)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" -loss_mode="kl_cov" -enable_filter_groups=True -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=256 -gen_prompt_bsz=$((train_prompt_bsz * 3)) -train_prompt_mini_bsz=32 -n_resp_per_prompt=8 -max_token=20480 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} -CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} -TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} -TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -ppo_kl_coef=1 -kl_cov_ratio=0.0002 - -# Mathematically equivalent -use_dynamic_bsz=True -infer_micro_batch_size=null -train_micro_batch_size=null -offload=False - -HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.filter_overlong_prompts=False \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ - data.train_batch_size=${train_prompt_bsz} \ - data.return_raw_chat=True \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.loss_mode=${loss_mode} \ - actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ - actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ - actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.mode=sync \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.weight_decay=0 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=False \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ - trainer.test_freq=4 \ - trainer.save_freq=32 \ - trainer.total_epochs=1000 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=disable diff --git a/recipe/entropy/32b_kl_cov_mininbsz.sh b/recipe/entropy/32b_kl_cov_mininbsz.sh deleted file mode 100644 index 15d191838..000000000 --- a/recipe/entropy/32b_kl_cov_mininbsz.sh +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -export WANDB_API_KEY=YOUR_WANDB_API_KEY -# export VLLM_USE_V1=1 - -project_name='Qwen2.5-32B' -exp_name='klcov' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.2 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 2)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" -loss_mode="kl_cov" -enable_filter_groups=True -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=256 -gen_prompt_bsz=$((train_prompt_bsz * 3)) -train_prompt_mini_bsz=16 -n_resp_per_prompt=8 -max_token=20480 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} -CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} -TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} -TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -ppo_kl_coef=1 -kl_cov_ratio=0.0002 - -# Mathematically equivalent -use_dynamic_bsz=True -infer_micro_batch_size=null -train_micro_batch_size=null -offload=False - -HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.filter_overlong_prompts=False \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ - data.train_batch_size=${train_prompt_bsz} \ - data.return_raw_chat=True \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ - actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ - actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.mode=sync \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.weight_decay=0 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=False \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ - trainer.test_freq=4 \ - trainer.save_freq=32 \ - trainer.total_epochs=1000 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=disable diff --git a/recipe/entropy/7b_clip_cov.sh b/recipe/entropy/7b_clip_cov.sh deleted file mode 100644 index 7a68f37df..000000000 --- a/recipe/entropy/7b_clip_cov.sh +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -export WANDB_API_KEY=YOUR_WANDB_API_KEY -# export VLLM_USE_V1=1 - -project_name='Qwen2.5-7B' -exp_name='clipcov' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=1 -clip_ratio_high=1 -clip_cov_ratio=0.0002 -clip_cov_lb=1.0 -clip_cov_ub=5.0 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 2)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" -loss_mode="clip_cov" -enable_filter_groups=True -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=256 -gen_prompt_bsz=$((train_prompt_bsz * 3)) -train_prompt_mini_bsz=32 -n_resp_per_prompt=8 -max_token=30720 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} -CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} -TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} -TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -ppo_kl_coef=1 -kl_cov_ratio=0.2 - -# Mathematically equivalent -use_dynamic_bsz=True -infer_micro_batch_size=null -train_micro_batch_size=null -offload=False - -HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.filter_overlong_prompts=False \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ - data.train_batch_size=${train_prompt_bsz} \ - data.return_raw_chat=True \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ - actor_rollout_ref.actor.policy_loss.clip_cov_ratio=${clip_cov_ratio} \ - actor_rollout_ref.actor.policy_loss.clip_cov_lb=${clip_cov_lb} \ - actor_rollout_ref.actor.policy_loss.clip_cov_ub=${clip_cov_ub} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.mode=sync \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.weight_decay=0 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=False \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ - trainer.test_freq=4 \ - trainer.save_freq=32 \ - trainer.total_epochs=1000 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=disable diff --git a/recipe/entropy/7b_kl_cov.sh b/recipe/entropy/7b_kl_cov.sh deleted file mode 100644 index 5dd1f8870..000000000 --- a/recipe/entropy/7b_kl_cov.sh +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -export WANDB_API_KEY=YOUR_WANDB_API_KEY -# export VLLM_USE_V1=1 - -project_name='Qwen2.5-7B' -exp_name='klcov' - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.2 - -max_prompt_length=$((1024 * 2)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 2)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" -loss_mode="kl_cov" -enable_filter_groups=True -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=256 -gen_prompt_bsz=$((train_prompt_bsz * 3)) -train_prompt_mini_bsz=32 -n_resp_per_prompt=8 -max_token=30720 - -# Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} -# Paths -RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} -MODEL_PATH=${MODEL_PATH:-"/YOUR_MODELPATH"} -CKPTS_DIR=${CKPTS_DIR:-"/YOUR_CKPTS_PATH"} -TRAIN_FILE=${TRAIN_FILE:-"/YOUR_TRAIN_FILE_PATH"} -TEST_FILE=${TEST_FILE:-["/YOUR_TRAIN_FILE_PATH"]} - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout -ppo_kl_coef=1 -kl_cov_ratio=0.002 - -# Mathematically equivalent -use_dynamic_bsz=True -infer_micro_batch_size=null -train_micro_batch_size=null -offload=False - -HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \ - data.train_files="${TRAIN_FILE}" \ - data.val_files="${TEST_FILE}" \ - data.prompt_key=prompt \ - data.truncation='left' \ - data.filter_overlong_prompts=False \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.gen_batch_size=${gen_prompt_bsz} \ - data.train_batch_size=${train_prompt_bsz} \ - data.return_raw_chat=True \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ - actor_rollout_ref.actor.policy_loss.kl_cov_ratio=${kl_cov_ratio} \ - actor_rollout_ref.actor.policy_loss.ppo_kl_coef=${ppo_kl_coef} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.mode=sync \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.weight_decay=0 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${max_token} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k="${top_k}" \ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.do_sample=False \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - reward_model.reward_manager=dapo \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger='["console","wandb"]' \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ - trainer.test_freq=4 \ - trainer.save_freq=32 \ - trainer.total_epochs=1000 \ - trainer.default_local_dir="${CKPTS_DIR}" \ - trainer.resume_mode=disable diff --git a/recipe/entropy/README.md b/recipe/entropy/README.md deleted file mode 100644 index 5238cec84..000000000 --- a/recipe/entropy/README.md +++ /dev/null @@ -1,110 +0,0 @@ -
- -# The Entropy Mechanism of Reinforcement Learning for Large Language Model Reasoning. - -[![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/pdf/2505.22617) [![Github](https://img.shields.io/badge/PRIME-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/PRIME-RL/Entropy-Mechanism-of-RL) [![alphaXiv](https://img.shields.io/badge/discussion-A42C25?style=for-the-badge&logo=arxiv&logoColor=white&color=blue -)](https://www.alphaxiv.org/abs/2505.22617) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/stingning/status/1928088554166505667) [![Twitter](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/charlesfornlp/status/1928089451080585283) [![Twitter-ak](https://img.shields.io/badge/Twitter-%23000000.svg?style=for-the-badge&logo=twitter&logoColor=white)](https://x.com/_akhaliq/status/1928077929105268861) - - - - -
- - -# 🎉News - -- **[2025/05/29]** 🎉 Ranked **#1** of the day on [Huggingface Daily Papers](https://huggingface.co/papers?date=2025-05-29). -- **[2025/05/29]** Released our Paper on arXiv. See [here](https://arxiv.org/pdf/2505.22617). We provide insights into the entropy mechanism of RL for LLMs and propose two simple yet effective strategies to alleviate the entropy collapse. - - - -# ✨Getting started - -After preparing the training data, for training Qwen2.5-7B on a single node, taking the KL-Cov approach as an example, you can simply run: - -``` -cd verl -conda activate your_env -bash recipe/dapo/7b_kl_cov.sh -``` - -While for training Qwen2.5-32B on multi nodes, you can run the following commands: - -``` -cd verl -conda activate your_env -bash recipe/dapo/32b_kl_cov.sh -``` - -# 📖Introduction - -
- issue -
- -This paper addresses the entropy collapse issue in scaling reinforcement learning (RL) for large language models (LLMs), where policy entropy drops sharply during training, leading to overconfidence and performance saturation. We empirically establish a relationship between entropy ($H$) and performance ($R$): $R=−aexp(H)+b$, showing performance is bottlenecked by entropy exhaustion. - -
- issue -
- -Theoretically, we find entropy changes are driven by the covariance between action probability and logit updates, which correlates with advantage in Policy Gradient methods. High-probability, high-advantage actions reduce entropy, while rare, high-advantage actions increase it. Empirically, the covariance term remains positive, explaining entropy’s monotonic decline. To mitigate this, we propose ​​Clip-Cov​​ and ​​KL-Cov​​, which restrict updates for high-covariance tokens. These methods effectively prevent entropy collapse, and improve performance. - -# 📃Evaluation - -
- issue -
- - -Our method is able to maintain a considerably higher level of entropy throughout training. For example, when the baseline's entropy reaches a plateau and can no longer be consumed, the KL-Cov method still sustains an entropy level over 10 times higher. Meanwhile, the response length of the policy model steadily increases, and its performance on the test set consistently surpasses that of the baseline. This indicates that our model is able to explore more freely during training, learning better policy through RL. -| **Method** | **AIME24** | **AIME25** | **AMC** | **MATH-500** | **OMNI-MATH** | **OlympiadBench** | **Minerva** | **Avg.** | -| ----------------- | ---------: | ---------: | -------: | -----------: | ------------: | ----------------: | ----------: | -------: | -| *Qwen2.5-7B* | | | | | | | | | -| GRPO | 21.2 | 9.6 | 58.7 | 78.8 | 27.9 | 40.7 | 36.7 | 38.6 | -| w. Clip-higher | 18.1 | 11.5 | 56.6 | 79.2 | 29.8 | 43.3 | 40.4 | 38.8 | -| w. **`CLIP-Cov`** | 22.1 | **15.8** | 58.2 | 80.4 | **30.5** | **44.1** | **41.1** | 40.4 | -| w. **`KL-Cov`** | **22.6** | 12.9 | **61.4** | **80.8** | 29.1 | 42.6 | 38.2 | **40.6** | -| *Qwen2.5-32B* | | | | | | | | | -| GRPO | 21.8 | 16.2 | 69.7 | 84.2 | 35.2 | 43.6 | 45.5 | 45.8 | -| w. Clip-higher | 35.6 | 22.3 | 69.5 | 77.2 | 35.1 | 42.5 | 43.0 | 47.2 | -| w. **`CLIP-Cov`** | 32.3 | 22.7 | 67.2 | **87.0** | **42.0** | **57.2** | 46.0 | 50.3 | -| w. **`KL-Cov`** | **36.8** | **30.8** | **74.5** | 84.6 | 39.1 | 49.0 | **46.3** | **52.2** | - -Our two approaches both achieve non-trivial improvements across all benchmarks. Compared to GRPO, our method outperforms it by 2.0% on average for the 7B model and by 6.4% for the 32B model. Moreover, we observe that our method yields more substantial gains on the larger Qwen2.5-32B. Specifically, our method achieves improvements of 15.0% and 14.6% compared to GRPO on the most challenging benchmarks, AIME24 and AIME25, respectively. - - -# 🎈Citation -If you find this paper or repo helpful, please cite us. - -```bibtex -@article{cui2025entropy, - title={The Entropy Mechanism of Reinforcement Learning for Reasoning Language Models}, - author={Cui, Ganqu and Zhang, Yuchen and Chen, Jiacheng and Yuan, Lifan and Wang, Zhi and Zuo, Yuxin and Li, Haozhan and Fan, Yuchen and Chen, Huayu and Chen, Weize and others}, - journal={arXiv preprint arXiv:2505.22617}, - year={2025} -} -``` -# 🌻Acknowledgement -We implement our reinforcement learning algorithm extending from [verl](https://github.com/volcengine/verl). We utilize [vLLM](https://github.com/vllm-project/vllm) for inference. Our models are trained primarily on [Qwen2.5 family](https://github.com/QwenLM/Qwen2.5). Our training data is built from [DAPO-MATH](https://huggingface.co/datasets/BytedTsinghua-SIA/DAPO-Math-17k). Thanks for their great contributions! - -# 📬 Contact - -For questions, discussion, or collaboration opportunities, feel free to contact: -- Ganqu Cui: cuiganqu@pjlab.org.cn -- Yuchen Zhang: yuchen.zhang2003@gmail.com -- Jiacheng Chen: jackchan9345@gmail.com -- Ning Ding: ningding.cs@gmail.com - diff --git a/recipe/entropy/config/entropy_trainer.yaml b/recipe/entropy/config/entropy_trainer.yaml deleted file mode 100644 index 969c72946..000000000 --- a/recipe/entropy/config/entropy_trainer.yaml +++ /dev/null @@ -1,39 +0,0 @@ -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - gen_batch_size: ${data.train_batch_size} - -reward_model: - reward_kwargs: - overlong_buffer_cfg: ${reward_model.overlong_buffer} - reward_manager: dapo - overlong_buffer: - enable: False - len: 0 - penalty_factor: 0.0 - log: False - -algorithm: - filter_groups: - enable: False # We try to avoid forgetting to set enable - metric: null # acc / score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 0 # Non-positive values mean no upper limit - -trainer: - project_name: verl-entropy - -actor_rollout_ref: - actor: - policy_loss: - loss_mode: "vanilla" # /clip-cov / kl-cov from https://arxiv.org/abs/2505. - clip_cov_ratio: 0.0002 # for clip-cov loss - clip_cov_lb: 1.0 # for clip-cov loss - clip_cov_ub: 5.0 # for clip-cov loss - kl_cov_ratio: 0.0002 # for kl-cov loss - ppo_kl_coef: 0.1 # for kl-cov loss \ No newline at end of file diff --git a/recipe/entropy/entropy_ray_trainer.py b/recipe/entropy/entropy_ray_trainer.py deleted file mode 100644 index 0b0b04318..000000000 --- a/recipe/entropy/entropy_ray_trainer.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -FSDP PPO Trainer with Ray-based single controller. -This trainer supports model-agonistic model initialization with huggingface -""" - -import uuid -from collections import defaultdict -from copy import deepcopy -from pprint import pprint - -import numpy as np -import torch -from tqdm import tqdm - -from verl import DataProto -from verl.trainer.ppo.metric_utils import ( - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, - reduce_metrics, -) -from verl.trainer.ppo.ray_trainer import ( - AdvantageEstimator, - RayPPOTrainer, - apply_kl_penalty, - compute_advantage, - compute_response_mask, -) -from verl.utils.profiler import simple_timer - - -class RayEntropyTrainer(RayPPOTrainer): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC - to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - assert val_metrics, f"{val_metrics=}" - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - - timing_raw = defaultdict(float) - batch = None - num_prompt_in_batch = 0 - num_gen_batches = 0 - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - - new_batch: DataProto = DataProto.from_single_dict(batch_dict) - num_gen_batches += 1 - # pop those keys for generation - if "multi_modal_inputs" in new_batch.non_tensor_batch.keys(): - gen_batch = new_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], - ) - else: - gen_batch = new_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], - ) - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - - is_last_step = self.global_steps >= self.total_training_steps - - with simple_timer("step", timing_raw): - # generate a batch - # with simple_timer("gen", timing_raw): - # gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - with simple_timer("gen", timing_raw): - if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - else: - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with simple_timer("gen_max", timing_raw): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - - new_batch = new_batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(new_batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - new_batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - new_batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - new_batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object - ) - # repeat to align with repeated responses in rollout - new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - new_batch = new_batch.union(gen_batch_output) - - with simple_timer("reward", timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(new_batch) - new_batch = new_batch.union(reward_tensor) - - # we combine with rule-based rm - reward_extra_infos_dict: dict[str, list] - try: - reward_result = self.reward_fn(new_batch, return_dict=True) - reward_tensor = reward_result["reward_tensor"] - reward_extra_infos_dict = reward_result["reward_extra_info"] - except Exception as e: - print(f"Error in reward_fn: {e}") - reward_tensor = self.reward_fn(new_batch) - reward_extra_infos_dict = {} - - new_batch.batch["token_level_scores"] = reward_tensor - - print(f"{list(reward_extra_infos_dict.keys())=}") - if reward_extra_infos_dict: - new_batch.non_tensor_batch.update( - {k: np.array(v) for k, v in reward_extra_infos_dict.items()} - ) - - # compute rewards. apply_kl_penalty if available - if self.config.algorithm.use_kl_in_reward: - new_batch, kl_metrics = apply_kl_penalty( - new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty - ) - metrics.update( - kl_metrics - ) # TODO: This will be cleared if we use multiple genenration batches - else: - new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"] - - if not self.config.algorithm.filter_groups.enable: - batch = new_batch - else: # NOTE: When prompts after filtering is less than train batch size, - # we skip to the next generation batch - metric_name = self.config.algorithm.filter_groups.metric - if metric_name == "seq_final_reward": - # Turn to numpy for easier filtering - new_batch.non_tensor_batch["seq_final_reward"] = ( - new_batch.batch["token_level_rewards"].sum(dim=-1).numpy() - ) - elif metric_name == "seq_reward": - new_batch.non_tensor_batch["seq_reward"] = ( - new_batch.batch["token_level_scores"].sum(dim=-1).numpy() - ) - - # Collect the sequence reward for each trajectory - prompt_uid2metric_vals = defaultdict(list) - for uid, metric_val in zip( - new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True - ): - prompt_uid2metric_vals[uid].append(metric_val) - - prompt_uid2metric_std = {} - for prompt_uid, metric_vals in prompt_uid2metric_vals.items(): - prompt_uid2metric_std[prompt_uid] = np.std(metric_vals) - - kept_prompt_uids = [ - uid - for uid, std in prompt_uid2metric_std.items() - if std > 0 or len(prompt_uid2metric_vals[uid]) == 1 - ] - num_prompt_in_batch += len(kept_prompt_uids) - - kept_traj_idxs = [] - for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]): - if traj_from_prompt_uid in kept_prompt_uids: - kept_traj_idxs.append(idx) - - new_batch = new_batch[kept_traj_idxs] - batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) - - prompt_bsz = self.config.data.train_batch_size - if num_prompt_in_batch < prompt_bsz: - print(f"{num_prompt_in_batch=} < {prompt_bsz=}") - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches - if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: - print(f"{num_gen_batches=}. Keep generating...") - continue - else: - raise ValueError( - f"{num_gen_batches=} >= {max_num_gen_batches=}." - + " Generated too many. Please check if your data are too difficult." - + " You could also try set max_num_gen_batches=0 to enable endless trials." - ) - else: - # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n - print( - f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. " - f"Collecting finished." - ) - batch = batch[:traj_bsz] - - # === Updating === - - batch.batch["response_mask"] = compute_response_mask(batch) - - # balance the number of valid tokens on each dp rank. - # Note that this breaks the order of data inside the batch. - # Please take care when you implement group based adv computation such as GRPO and rloo - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - # recompute old_log_probs - with simple_timer("old_log_prob", timing_raw): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - batch = batch.union(old_log_prob) - - if self.use_reference_policy: - # compute reference log_prob - with simple_timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with simple_timer("values", timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with simple_timer("adv", timing_raw): - # compute advantages, executed on the driver process - norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) - batch = compute_advantage( - batch, - adv_estimator=self.config.algorithm.adv_estimator, - gamma=self.config.algorithm.gamma, - lam=self.config.algorithm.lam, - num_repeat=self.config.actor_rollout_ref.rollout.n, - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, - ) - - # update critic - if self.use_critic: - with simple_timer("update_critic", timing_raw): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with simple_timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - ): - with simple_timer("testing", timing_raw): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 - ): - with simple_timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - # TODO: implement actual tflpo and theoretical tflpo - n_gpus = self.resource_pool_manager.get_n_gpus() - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) - timing_raw = defaultdict(float) # clear timing - - metrics["train/num_gen_batches"] = num_gen_batches - batch = None - num_prompt_in_batch = 0 - num_gen_batches = 0 - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - if is_last_step: - pprint(f"Final validation metrics: {last_val_metrics}") - progress_bar.close() - return - - progress_bar.update(1) - self.global_steps += 1 diff --git a/recipe/entropy/main_entropy.py b/recipe/entropy/main_entropy.py deleted file mode 100644 index a8bb0cb6a..000000000 --- a/recipe/entropy/main_entropy.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. -""" - -import hydra -import ray - -from .entropy_ray_trainer import RayEntropyTrainer -from .reward import load_reward_manager - - -@hydra.main(config_path="config", config_name="entropy_trainer", version_base=None) -def main(config): - run_ppo(config) - - -def run_ppo(config) -> None: - if not ray.is_initialized(): - # this is for local ray cluster - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "WARN", - "WANDB_API_KEY": "YOUR_WANDB_API_KEY", - } - }, - num_cpus=config.ray_init.num_cpus, - ) - - runner = TaskRunner.remote() - ray.get(runner.run.remote(config)) - - -def merge_dict(a: dict, b: dict) -> dict: - """Return a new dict that has `a` updated with `b` (b wins on conflicts). - - Example:: - - >>> d1 = {"x": 1, "y": 2} - >>> d2 = {"y": 20, "z": 3} - >>> new_dict = merge_dict(d1, d2) - >>> print(new_dict) # {'x': 1, 'y': 20, 'z': 3} - >>> print(d1) # {"x": 1, "y": 2} (unchanged) - >>> print(d2) # {"y": 20, "z": 3} (unchanged) - """ - return a | b - - -@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head -class TaskRunner: - def run(self, config): - # print initial config - from pprint import pprint - - from omegaconf import OmegaConf - - from verl.utils.fs import copy_to_local - - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - OmegaConf.resolve(config) - - # download the checkpoint from hdfs - local_path = copy_to_local(config.actor_rollout_ref.model.path) - print(f"{config.actor_rollout_ref.model.path}") - # instantiate tokenizer - from verl.utils import hf_processor, hf_tokenizer - - trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none - - # define worker classes - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: - assert config.critic.strategy in {"fsdp", "fsdp2"} - from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker - - actor_rollout_cls = ( - AsyncActorRolloutRefWorker - if config.actor_rollout_ref.rollout.mode == "async" - else ActorRolloutRefWorker - ) - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - - actor_rollout_cls = ActorRolloutRefWorker - ray_worker_group_cls = NVMegatronRayWorkerGroup - - else: - raise NotImplementedError - - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - role_worker_mapping = { - Role.ActorRollout: ray.remote(actor_rollout_cls), - Role.Critic: ray.remote(CriticWorker), - } - - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - Role.Critic: global_pool_id, - } - - # we should adopt a multi-source reward function here - # - for rule-based rm, we directly call a reward score - # - for model-based rm, we call a model - # - for code related prompt, we send to a sandbox if there are test cases - # - finally, we combine all the rewards together - # - The reward type depends on the tag of the data - if config.reward_model.enable: - if config.reward_model.strategy in {"fsdp", "fsdp2"}: - from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == "megatron": - from verl.workers.megatron_workers import RewardModelWorker - else: - raise NotImplementedError - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - mapping[Role.RewardModel] = global_pool_id - - # use reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) - mapping[Role.RefPolicy] = global_pool_id - - reward_kwargs = { - "max_resp_len": config.data.max_response_length, - "overlong_buffer_cfg": config.reward_model.overlong_buffer, - } - cfg_reward_kwargs = config.reward_model.get("reward_kwargs", {}) - reward_fn = load_reward_manager( - config, tokenizer, num_examine=0, **OmegaConf.merge(OmegaConf.create(reward_kwargs), cfg_reward_kwargs) - ) - val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1, **reward_kwargs) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - from verl.utils.dataset.rl_dataset import collate_fn - - train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) - val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) - train_sampler = create_rl_sampler(config.data, train_dataset) - trainer = RayEntropyTrainer( - config=config, - tokenizer=tokenizer, - processor=processor, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, - train_dataset=train_dataset, - val_dataset=val_dataset, - collate_fn=collate_fn, - train_sampler=train_sampler, - ) - trainer.init_workers() - trainer.fit() - - -def create_rl_dataset(data_paths, data_config, tokenizer, processor): - """Create a dataset. - - Arguments: - data_config: The data config. - tokenizer (Tokenizer): The tokenizer. - processor (Processor): The processor. - - Returns: - dataset (Dataset): The dataset. - """ - from torch.utils.data import Dataset - - from verl.utils.dataset.rl_dataset import RLHFDataset - - if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: - from verl.utils.import_utils import load_extern_type - - dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) - if not issubclass(dataset_cls, Dataset): - raise TypeError( - f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' " - f"must inherit from torch.utils.data.Dataset" - ) - else: - dataset_cls = RLHFDataset - print(f"Using dataset class: {dataset_cls.__name__}") - - dataset = dataset_cls( - data_files=data_paths, - tokenizer=tokenizer, - processor=processor, - config=data_config, - ) - - return dataset - - -def create_rl_sampler(data_config, dataset): - """Create a sampler for the dataset. - - Arguments: - data_config: The data config. - dataset (Dataset): The dataset. - - Returns: - sampler (Sampler): The sampler. - """ - import torch - from torch.utils.data import RandomSampler, SequentialSampler - - # use sampler for better ckpt resume - if data_config.shuffle: - train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(data_config.get("seed", 1)) - sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) - else: - sampler = SequentialSampler(data_source=dataset) - - return sampler - - -if __name__ == "__main__": - main() diff --git a/recipe/entropy/reward.py b/recipe/entropy/reward.py deleted file mode 100644 index 36b8b65a4..000000000 --- a/recipe/entropy/reward.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 Individual Contributor: Thibaut Barroyer -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import multiprocessing -from functools import partial - -import ray - -from verl import DataProto -from verl.trainer.ppo.reward import compute_reward, get_custom_reward_fn - -from .reward_score import _default_compute_score - - -def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): - """ - Load and initialize a reward manager based on the configuration. - - Args: - config: PPO trainer configuration object containing reward_model fields. - tokenizer: Tokenizer object used for processing text. - num_examine: Number of samples to examine. - **reward_kwargs: Additional keyword arguments for the reward manager. - - Returns: - An instance of the specified reward manager class. - """ - from verl.workers.reward_manager import get_reward_manager_cls - - # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`: - # naive: NaiveRewardManager - # prime: PrimeRewardManager - # batch: BatchRewardManager - # dapo: DAPORewardManager - # Note(haibin.lin): For custom reward managers, please make sure they are imported and - # registered via `verl.workers.reward_manager.register` - # By default reward_manager is set to naive (NaiveRewardManager) - reward_manager_name = config.reward_model.get("reward_manager", "naive") - reward_manager_cls = get_reward_manager_cls(reward_manager_name) - - # Try to get a custom reward function based on the configuration - compute_score = get_custom_reward_fn(config) - final_compute_score = compute_score - - if compute_score is None: - sandbox_config = config.reward_model.get("sandbox_fusion") - sandbox_url = sandbox_config.get("url") if sandbox_config else None - if sandbox_url: - sandbox_manager = multiprocessing.Manager() - # Create a semaphore to control concurrent access to the sandbox - _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) - final_compute_score = partial( - _default_compute_score, sandbox_fusion_url=sandbox_url, concurrent_semaphore=_concurrent_semaphore - ) - else: - final_compute_score = _default_compute_score - - # Instantiate and return the reward manager with the specified parameters - return reward_manager_cls( - tokenizer=tokenizer, - num_examine=num_examine, - compute_score=final_compute_score, - reward_fn_key=config.data.reward_fn_key, - **reward_kwargs, - ) - - -@ray.remote(num_cpus=1) -def compute_reward_async(data: DataProto, config, tokenizer): - """ - Load the reward manager and compute the reward for a batch of data. - This is meant to be run in a separate Ray worker. - """ - reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) - return compute_reward(data, reward_fn) diff --git a/recipe/entropy/reward_score/__init__.py b/recipe/entropy/reward_score/__init__.py deleted file mode 100644 index 7224bf3c3..000000000 --- a/recipe/entropy/reward_score/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# from . import gsm8k, math, prime_math, prime_code - -import traceback - -from . import entropy_math - - -def _default_compute_score( - data_source, solution_str, ground_truth, extra_info=None, sandbox_fusion_url=None, concurrent_semaphore=None -): - try: - res = entropy_math.compute_score(solution_str, str(ground_truth)) - # print(f"data_source: {data_source}") - # raise NotImplementedError(f"Reward function is not implemented for {data_source=}") - - if isinstance(res, dict): - return res - elif isinstance(res, int | float | bool): - return float(res) - else: - return float(res[0]) - except Exception as e: - print(f"[ERROR] Error in process_completion for task : {str(e)}") - traceback.print_exc() # 打印完整堆栈 - raise # 重新抛出异常以便上层捕获 diff --git a/recipe/entropy/reward_score/entropy_math/__init__.py b/recipe/entropy/reward_score/entropy_math/__init__.py deleted file mode 100644 index 1b2ba647d..000000000 --- a/recipe/entropy/reward_score/entropy_math/__init__.py +++ /dev/null @@ -1,1062 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except Exception in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Provides a math answer grading function with high recall. -Based on HF math_verify, verl, open reasoner zero, etc. -""" - -import os -import re -import signal -from itertools import islice, zip_longest -from math import isclose -from typing import Optional - -import sympy -from latex2sympy2_extended import latex2sympy -from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify -from pylatexenc import latex2text -from sympy import N, simplify -from sympy.parsing import sympy_parser -from sympy.parsing.latex import parse_latex -from sympy.parsing.sympy_parser import parse_expr - -""" -This code is adapted from: Dr. GRPO (https://github.com/sail-sg/understand-r1-zero/blob/main/understand_r1_zero/math_grader.py). -""" - - -def timeout_ours(timeout_seconds: int = 8): - if os.name == "posix": - import signal - - def decorator(func): - def handler(signum, frame): - raise TimeoutError("Operation timed out!") - - def wrapper(*args, **kwargs): - old_handler = signal.getsignal(signal.SIGALRM) - signal.signal(signal.SIGALRM, handler) - signal.alarm(timeout_seconds) - - try: - return func(*args, **kwargs) - finally: - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) - - return wrapper - - return decorator - else: - raise NotImplementedError(f"Unsupported OS: {os.name}") - - -# Dan Hendrycks' code -def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]: - if answer is None: - return None - answer = answer.strip() - try: - # Remove enclosing `\text{}`. - m = re.search("^\\\\text\{(?P.+?)\}$", answer) - if m is not None: - answer = m.group("text").strip() - return _strip_string(answer) - except Exception: - return answer - - -# units mainly from MathQA -unit_texts = [ - "east", - "degree", - "mph", - "kmph", - "ft", - "m sqaure", - " m east", - "sq m", - "deg", - "mile", - "q .", - "monkey", - "prime", - "ratio", - "profit of rs", - "rd", - "o", - "gm", - "p . m", - "lb", - "tile", - "per", - "dm", - "lt", - "gain", - "ab", - "way", - "west", - "a .", - "b .", - "c .", - "d .", - "e .", - "f .", - "g .", - "h .", - "t", - "a", - "h", - "no change", - "men", - "soldier", - "pie", - "bc", - "excess", - "st", - "inches", - "noon", - "percent", - "by", - "gal", - "kmh", - "c", - "acre", - "rise", - "a . m", - "th", - "π r 2", - "sq", - "mark", - "l", - "toy", - "coin", - "sq . m", - "gallon", - "° f", - "profit", - "minw", - "yr", - "women", - "feet", - "am", - "pm", - "hr", - "cu cm", - "square", - "v â € ™", - "are", - "rupee", - "rounds", - "cubic", - "cc", - "mtr", - "s", - "ohm", - "number", - "kmph", - "day", - "hour", - "minute", - "min", - "second", - "man", - "woman", - "sec", - "cube", - "mt", - "sq inch", - "mp", - "∏ cm ³", - "hectare", - "more", - "sec", - "unit", - "cu . m", - "cm 2", - "rs .", - "rs", - "kg", - "g", - "month", - "km", - "m", - "cm", - "mm", - "apple", - "liter", - "loss", - "yard", - "pure", - "year", - "increase", - "decrease", - "d", - "less", - "Surface", - "litre", - "pi sq m", - "s .", - "metre", - "meter", - "inch", -] - -unit_texts.extend([t + "s" for t in unit_texts]) - - -def _strip_string(string): - def _fix_fracs(string): - substrs = string.split("\\frac") - new_str = substrs[0] - if len(substrs) > 1: - substrs = substrs[1:] - for substr in substrs: - new_str += "\\frac" - if substr[0] == "{": - new_str += substr - else: - try: - assert len(substr) >= 2 - except Exception: - return string - a = substr[0] - b = substr[1] - if b != "{": - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr - else: - new_str += "{" + a + "}{" + b + "}" - else: - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr - else: - new_str += "{" + a + "}" + b - string = new_str - return string - - def _fix_a_slash_b(string): - if len(string.split("/")) != 2: - return string - a = string.split("/")[0] - b = string.split("/")[1] - try: - a = int(a) - b = int(b) - assert string == "{}/{}".format(a, b) - new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" - return new_string - except Exception: - return string - - def _remove_right_units(string): - # "\\text{ " only ever occurs (at least in the val set) when describing units - if "\\text{ " in string: - splits = string.split("\\text{ ") - assert len(splits) == 2 - return splits[0] - else: - return string - - def _fix_sqrt(string): - if "\\sqrt" not in string: - return string - splits = string.split("\\sqrt") - new_string = splits[0] - for split in splits[1:]: - if split[0] != "{": - a = split[0] - new_substr = "\\sqrt{" + a + "}" + split[1:] - else: - new_substr = "\\sqrt" + split - new_string += new_substr - return new_string - - # linebreaks - string = string.replace("\n", "") - # print(string) - - # remove inverse spaces - string = string.replace("\\!", "") - # print(string) - - # replace \\ with \ - string = string.replace("\\\\", "\\") - # print(string) - - # matrix - string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) - string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) - string = string.replace("bmatrix", "pmatrix") - - # replace tfrac and dfrac with frac - string = string.replace("tfrac", "frac") - string = string.replace("dfrac", "frac") - string = string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge") - # print(string) - - # remove \left and \right - string = string.replace("\\left", "") - string = string.replace("\\right", "") - # print(string) - - # Remove unit: miles, dollars if after is not none - _string = re.sub(r"\\text{.*?}$", "", string).strip() - if _string != "" and _string != string: - # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) - string = _string - - # Remove unit: texts - for _ in range(2): - for unit_text in unit_texts: - # use regex, the prefix should be either the start of the string or a non-alphanumeric character - # the suffix should be either the end of the string or a non-alphanumeric character - _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) - if _string != "": - string = _string - - # Remove circ (degrees) - string = string.replace("^{\\circ}", "") - string = string.replace("^\\circ", "") - - # remove dollar signs - string = string.replace("\\$", "") - - # remove units (on the right) - string = _remove_right_units(string) - - # remove percentage - string = string.replace("\\%", "") - string = string.replace("\%", "") - - # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string - string = string.replace(" .", " 0.") - string = string.replace("{.", "{0.") - # if empty, return empty string - if len(string) == 0: - return string - if string[0] == ".": - string = "0" + string - - # to consider: get rid of e.g. "k = " or "q = " at beginning - if len(string.split("=")) == 2: - if len(string.split("=")[0]) <= 2: - string = string.split("=")[1] - - # fix sqrt3 --> sqrt{3} - string = _fix_sqrt(string) - - # remove spaces - string = string.replace(" ", "") - - # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). - # Also does a/b --> \\frac{a}{b} - string = _fix_fracs(string) - - # manually change 0.5 --> \frac{1}{2} - if string == "0.5": - string = "\\frac{1}{2}" - - # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y - string = _fix_a_slash_b(string) - - return string - - -SUBSTITUTIONS = [ - ("an ", ""), - ("a ", ""), - (".$", "$"), - ("\\$", ""), - (r"\ ", ""), - (" ", ""), - ("mbox", "text"), - (",\\text{and}", ","), - ("\\text{and}", ","), - ("\\text{m}", "\\text{}"), -] - - -REMOVED_EXPRESSIONS = [ - "square", - "ways", - "integers", - "dollars", - "mph", - "inches", - "ft", - "hours", - "km", - "units", - "\\ldots", - "sue", - "points", - "feet", - "minutes", - "digits", - "cents", - "degrees", - "cm", - "gm", - "pounds", - "meters", - "meals", - "edges", - "students", - "childrentickets", - "multiples", - "\\text{s}", - "\\text{.}", - "\\text{\ns}", - "\\text{}^2", - "\\text{}^3", - "\\text{\n}", - "\\text{}", - r"\mathrm{th}", - r"^\circ", - r"^{\circ}", - r"\;", - r",\!", - "{,}", - '"', - "\\dots", -] - - -def normalize_final_answer(final_answer: str) -> str: - """ - Normalize a final answer to a quantitative reasoning question. - This code comes from https://arxiv.org/pdf/2206.14858.pdf, page18. - """ - # final_answer = final_answer.split("=")[-1] - - for before, after in SUBSTITUTIONS: - final_answer = final_answer.replace(before, after) - for expr in REMOVED_EXPRESSIONS: - final_answer = final_answer.replace(expr, "") - - # Extract answer that is in LaTeX math, is bold, - # is surrounded by a box, etc. - final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) - final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) - final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) - - # Normalize shorthand TeX: - # \fracab -> \frac{a}{b} - # \frac{abc}{bef} -> \frac{abc}{bef} - # \fracabc -> \frac{a}{b}c - # \sqrta -> \sqrt{a} - # \sqrtab -> sqrt{a}b - final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) - final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) - final_answer = final_answer.replace("$", "") - - # Normalize 100,000 -> 100000 - if final_answer.replace(",", "").isdigit(): - final_answer = final_answer.replace(",", "") - - return final_answer - - -def repeatness(s: str): - def ranks(seq): - index = {v: i for i, v in enumerate(sorted(set(seq)))} - return [index[v] for v in seq] - - def suffixArray(s): - line = ranks(s) - n, k, ans, sa = len(s), 1, line, [0] * len(s) - while k < n - 1: - line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1))) - ans, k = line, k << 1 - for i, k in enumerate(ans): - sa[k] = i - return ans, sa - - def lcp(arr, suffixArr, inv_suff): - n, ans, k = len(arr), [0] * len(arr), 0 - - for i in range(n): - if inv_suff[i] == n - 1: - k = 0 - continue - - j = suffixArr[inv_suff[i] + 1] - while i + k < n and j + k < n and arr[i + k] == arr[j + k]: - k += 1 - - ans[inv_suff[i]] = k - if k > 0: - k -= 1 - - return ans - - arr = [ord(i) for i in s] - n = len(arr) - if n <= 1: - return 0 - c, sa = suffixArray(arr) - cnt = sum(lcp(arr, sa, c)) - - return (cnt * 2 / (n * (n + 1))) > 0.2 - - -class timeout: - def __init__(self, seconds=1, error_message="Timeout"): - self.seconds = seconds - self.error_message = error_message - - def handle_timeout(self, signum, frame): - raise TimeoutError(self.error_message) - - def __enter__(self): - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - - def __exit__(self, type, value, traceback): - signal.alarm(0) - - -def latex_eval(latex): - sym = parse_latex(latex) - val = sym.evalf() - return sym, val - - -def numeric_equal(prediction: float, reference: float): - # Note that relative tolerance has significant impact - # on the result of the synthesized GSM-Hard dataset - # if reference.is_integer(): - # return isclose(reference, round(prediction), abs_tol=1e-4) - # else: - # prediction = round(prediction, len(str(reference).split(".")[-1])) - return isclose(reference, prediction, rel_tol=1e-4) - - -@timeout_ours(timeout_seconds=5) -def symbolic_equal(a, b): - def _parse(s): - for f in [parse_latex, parse_expr, latex2sympy]: - try: - return f(s.replace("\\\\", "\\")) - except Exception: - try: - return f(s) - except Exception: - pass - return s - - a = _parse(a) - b = _parse(b) - - # direct equal - try: - if str(a) == str(b) or a == b: - return True - except Exception: - pass - - # simplify equal - try: - if a.equals(b) or simplify(a - b) == 0: - return True - except Exception: - pass - - # equation equal - try: - if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): - return True - except Exception: - pass - - try: - if numeric_equal(float(N(a)), float(N(b))): - return True - except Exception: - pass - - # matrix - try: - # if a and b are matrix - if a.shape == b.shape: - _a = a.applyfunc(lambda x: round(x, 3)) - _b = b.applyfunc(lambda x: round(x, 3)) - if _a.equals(_b): - return True - except Exception: - pass - - return False - - -def _is_latex_equal(str1, str2): - try: - sym1, val1 = latex_eval(str1) - sym2, val2 = latex_eval(str2) - if sym1 == sym2 or val1 == val2: - return True - else: - raise ValueError - except Exception: # noqa - try: - norm1, norm2 = normalize_final_answer(str1), normalize_final_answer(str2) - sym1, val1 = latex_eval(norm1) - sym2, val2 = latex_eval(norm2) - if sym1 == sym2 or val1 == val2: - return True - except Exception: # noqa - return norm1 == norm2 - return False - - -def is_latex_equal(given_answer: str, ground_truth: str) -> bool: - try: - with timeout(1): - try: - if (len(given_answer) > 128 and repeatness(given_answer)) or ( - len(ground_truth) > 128 and repeatness(ground_truth) - ): - return False - # First conduct normalized string matching. - ground_truth_normalized = _normalize(ground_truth) - given_normalized = _normalize(given_answer) - if ground_truth_normalized is None: - return False - if ground_truth_normalized == given_normalized: - return True - - # Next call math verify. - given_answer.replace("\n", "") - ground_truth.replace("\n", "") - if "$" not in given_answer: - given_answer = f"${given_answer}$" - if "$" not in ground_truth: - ground_truth = f"${ground_truth}$" - return verify( - parse( - ground_truth, - extraction_config=( - LatexExtractionConfig(boxed_match_priority=0), - ExprExtractionConfig(), - ), - fallback_mode="no_fallback", - extraction_mode=["first_match"], - parsing_timeout=1, - ), - parse( - given_answer, - extraction_config=( - LatexExtractionConfig(boxed_match_priority=0), - ExprExtractionConfig(), - ), - fallback_mode="no_fallback", - extraction_mode=["first_match"], - parsing_timeout=1, - ), - timeout_seconds=1, - ) - # or symbolic_equal(ground_truth, given_answer) - except Exception: - return False - except TimeoutError: - return False - - -def is_value_equal(given_answer: str, ground_truth: str) -> bool: - assert ground_truth is not None - ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth) - given_answer_normalized_mathd = mathd_normalize_answer(given_answer) - - str_equal = ground_truth_normalized_mathd == given_answer_normalized_mathd - try: - number_equal = float(ground_truth_normalized_mathd) == float(given_answer_normalized_mathd) - return str_equal or number_equal - except Exception: - return str_equal - - -# sympy might hang -- we don't care about trying to be lenient in these cases -BAD_SUBSTRINGS = ["^{", "^("] -BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] -TUPLE_CHARS = "()[]" - - -def _sympy_parse(expr: str): - """Parses an expression with sympy.""" - py_expr = expr.replace("^", "**") - return sympy_parser.parse_expr( - py_expr, - transformations=(sympy_parser.standard_transformations + (sympy_parser.implicit_multiplication_application,)), - ) - - -def _parse_latex(expr: str) -> str: - """Attempts to parse latex to an expression sympy can read.""" - expr = expr.replace("\\tfrac", "\\frac") - expr = expr.replace("\\dfrac", "\\frac") - expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. - expr = latex2text.LatexNodes2Text().latex_to_text(expr) - - # Replace the specific characters that this parser uses. - expr = expr.replace("√", "sqrt") - expr = expr.replace("π", "pi") - expr = expr.replace("∞", "inf") - expr = expr.replace("∪", "U") - expr = expr.replace("·", "*") - expr = expr.replace("×", "*") - - return expr.strip() - - -def _is_float(num: str) -> bool: - try: - float(num) - return True - except ValueError: - return False - - -def _is_int(x: float) -> bool: - try: - return abs(x - int(round(x))) <= 1e-7 - except Exception: - return False - - -def _is_frac(expr: str) -> bool: - return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) - - -def _str_is_int(x: str) -> bool: - try: - x = _strip_properly_formatted_commas(x) - x = float(x) - return abs(x - int(round(x))) <= 1e-7 - except Exception: - return False - - -def _str_to_int(x: str) -> bool: - x = x.replace(",", "") - x = float(x) - return int(x) - - -def _inject_implicit_mixed_number(step: str): - """ - Automatically make a mixed number evalable - e.g. 7 3/4 => 7+3/4 - """ - p1 = re.compile("([0-9]) +([0-9])") - step = p1.sub("\\1+\\2", step) ## implicit mults - return step - - -def _strip_properly_formatted_commas(expr: str): - # We want to be careful because we don't want to strip tuple commas - p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") - while True: - next_expr = p1.sub("\\1\\3\\4", expr) - if next_expr == expr: - break - expr = next_expr - return next_expr - - -def _normalize(expr: str) -> str: - """Normalize answer expressions.""" - if expr is None: - return None - - # Remove enclosing `\text{}`. - m = re.search("^\\\\text\{(?P.+?)\}$", expr) - if m is not None: - expr = m.group("text") - - expr = expr.replace("\\%", "%") - expr = expr.replace("\\$", "$") - expr = expr.replace("$", "") - expr = expr.replace("%", "") - expr = expr.replace(" or ", " , ") - expr = expr.replace(" and ", " , ") - - expr = expr.replace("million", "*10^6") - expr = expr.replace("billion", "*10^9") - expr = expr.replace("trillion", "*10^12") - - for unit in [ - "degree", - "cm", - "centimeter", - "meter", - "mile", - "second", - "minute", - "hour", - "day", - "week", - "month", - "year", - "foot", - "feet", - "inch", - "yard", - ]: - expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) - expr = re.sub("\^ *\\\\circ", "", expr) - - if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": - expr = expr[1:-1] - - expr = re.sub(",\\\\! *", "", expr) - if _is_float(expr) and _is_int(float(expr)): - expr = str(int(round(float(expr)))) - if "\\" in expr: - try: - expr = _parse_latex(expr) - except Exception: - pass - - # edge case with mixed numbers and negative signs - expr = re.sub("- *", "-", expr) - - expr = _inject_implicit_mixed_number(expr) - expr = expr.replace(" ", "") - - # if we somehow still have latex braces here, just drop them - expr = expr.replace("{", "") - expr = expr.replace("}", "") - - # don't be case sensitive for text answers - expr = expr.lower() - - if _str_is_int(expr): - expr = str(_str_to_int(expr)) - - return expr - - -def count_unknown_letters_in_expr(expr: str): - expr = expr.replace("sqrt", "") - expr = expr.replace("frac", "") - letters_in_expr = set([x for x in expr if x.isalpha()]) - return len(letters_in_expr) - - -def should_allow_eval(expr: str): - # we don't want to try parsing unknown text or functions of more than two variables - if count_unknown_letters_in_expr(expr) > 2: - return False - - for bad_string in BAD_SUBSTRINGS: - if bad_string in expr: - return False - - for bad_regex in BAD_REGEXES: - if re.search(bad_regex, expr) is not None: - return False - - return True - - -@timeout_ours(timeout_seconds=5) -def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): - are_equal = False - try: - expr = f"({ground_truth_normalized})-({given_normalized})" - if should_allow_eval(expr): - sympy_diff = _sympy_parse(expr) - simplified = sympy.simplify(sympy_diff) - if simplified == 0: - are_equal = True - except Exception: - pass - return are_equal - - -def split_tuple(expr: str): - """ - Split the elements in a tuple/interval, while handling well-formatted commas in large numbers - """ - expr = _strip_properly_formatted_commas(expr) - if len(expr) == 0: - return [] - if ( - len(expr) > 2 - and expr[0] in TUPLE_CHARS - and expr[-1] in TUPLE_CHARS - and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) - ): - elems = [elem.strip() for elem in expr[1:-1].split(",")] - else: - elems = [expr] - return elems - - -def last_boxed_only_string(string): - idx = string.rfind("\\boxed") - if idx < 0: - idx = string.rfind("\\fbox") - if idx < 0: - return None - - i = idx - right_brace_idx = None - num_left_braces_open = 0 - while i < len(string): - if string[i] == "{": - num_left_braces_open += 1 - if string[i] == "}": - num_left_braces_open -= 1 - if num_left_braces_open == 0: - right_brace_idx = i - break - i += 1 - if right_brace_idx is None: - retval = None - else: - retval = string[idx : right_brace_idx + 1] - - return retval - - -def remove_boxed(s): - left = "\\boxed{" - try: - assert s[: len(left)] == left - assert s[-1] == "}" - return s[len(left) : -1] - except Exception: - return None - - -def extract_boxed_answer(solution: str) -> str: - """Extract the answer from inside a LaTeX \\boxed{} command""" - solution = last_boxed_only_string(solution) - solution = remove_boxed(solution) - return solution - - -def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool: - ground_truth_normalized = _normalize(ground_truth) - given_normalized = _normalize(given_answer) - - if ground_truth_normalized is None: - return False - - if ground_truth_normalized == given_normalized: - return True - - if len(given_normalized) == 0: - return False - - ground_truth_elems = split_tuple(ground_truth_normalized) - given_elems = split_tuple(given_normalized) - - if len(ground_truth_elems) > 1 and ( - ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1] - ): - is_correct = False - elif len(ground_truth_elems) != len(given_elems): - is_correct = False - else: - for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems, strict=True): - if _is_frac(ground_truth_elem) and _is_frac(given_elem): - # if fractions aren't reduced, then shouldn't be marked as correct - # so, we don't want to allow sympy.simplify in this case - is_correct = ground_truth_elem == given_elem - elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): - # if the ground truth answer is an integer, we require the given answer to be a strict match - # (no sympy.simplify) - is_correct = False - else: - is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) - if not is_correct: - break - - return is_correct - - -def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool: - ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth) - given_answer_normalized_mathd = mathd_normalize_answer(given_answer) - - # be at least as lenient as mathd - if ground_truth_normalized_mathd == given_answer_normalized_mathd: - return True - return False - - -def extract_answer(passage: str) -> str: - if "\\boxed" in passage: - return extract_boxed_answer(passage) - return None - - -def grade(model_answer: str, gt_answer: str, fast: bool = True): - if "\\boxed" in gt_answer: - gt_answer = extract_answer(gt_answer) - correct = grade_answer_mathd(model_answer, gt_answer) or grade_answer_sympy(model_answer, gt_answer) - if not fast: - # This mode further uses math_verify to recall originally false positives. - # Will be a bit slower, and sensitive to bad inputs. - correct = correct or is_latex_equal( - model_answer, - gt_answer, - ) - return correct - - -def compute_score(model_response, gt_answer, fast=False): - model_answer = extract_answer(model_response) - if model_answer is None: - return { - "score": 0.0, - "format_score": 0.0, - "acc": False, - "extracted_gt": gt_answer, - # "extracted_pred": None, - } - # return 0.0, 0.0 # Cannot even parse anything. - is_correct = False - if isinstance(gt_answer, float) or isinstance(gt_answer, int): - gt_answer = str(gt_answer) - if isinstance(gt_answer, str): - is_correct = grade(model_answer, gt_answer, fast) - elif isinstance(gt_answer, list): - is_correct = False - for gt in gt_answer: - is_correct |= grade(model_answer, gt, fast) - if is_correct: - return { - "score": 1.0, - "format_score": 1.0, - "acc": True, - "extracted_gt": gt_answer, - # "extracted_pred": None, - } - else: - return { - "score": 0.0, - "format_score": 1.0, - "acc": False, - "extracted_gt": gt_answer, - # "extracted_pred": None, - } diff --git a/recipe/entropy/reward_score/entropy_math/grader.py b/recipe/entropy/reward_score/entropy_math/grader.py deleted file mode 100644 index 02507e359..000000000 --- a/recipe/entropy/reward_score/entropy_math/grader.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright (c) Microsoft Corporation. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE - -# Copyright (c) 2023 OpenAI -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -# Copyright (c) 2021 Dan Hendrycks -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: -- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py -- https://github.com/microsoft/ProphetNet/tree/master/CRITIC -- https://github.com/openai/prm800k -""" - -import contextlib -import math -import re -from math import isclose - -# sympy related -from sympy import N, simplify -from sympy.parsing.latex import parse_latex -from sympy.parsing.sympy_parser import parse_expr - -# verl related -from verl.utils.py_functional import timeout_limit - - -def is_digit(s): - try: - if "{,}" in str(s): - num = float(str(s).replace("{,}", "")) - return True, num - - num = float(str(s).replace(",", "")) - return True, num - except ValueError: - return False, None - - -def normalize(answer, pi) -> str: - # checking if answer is $ and removing $ in that case to compare - if isinstance(answer, str) and bool(re.match(r"\$\d+(\.\d+)?", answer)): - return answer[1:] - - # checking if answer is % or \\% and removing % - if isinstance(answer, str) and ( - bool(re.match(r"^\d+(\.\d+)?%$", answer)) or bool(re.match(r"^\d+(\.\d+)?\\%$", answer)) - ): - return answer.replace("\\%", "").replace("%", "") - - # handle base - answer = handle_base(answer) - - # handle pi - answer = handle_pi(answer, pi) - - return answer - - -def handle_base(x) -> str: - if isinstance(x, str) and "_" in x: - # Due to base - x = x.split("_")[0] - x = float(x) - return int(x) - return x - - -def handle_pi(string, pi): - if isinstance(string, str) and "\pi" in string: - # Find the first occurrence of "\pi" - idx = string.find("\pi") - - # Iterate over the string and find all occurrences of "\pi" with a valid previous character - while idx != -1: - if idx > 0 and string[idx - 1].isdigit(): - # Replace "\pi" with "*math.pi" if the previous character is a digit - string = string[:idx] + f"*{pi}" + string[idx + 3 :] - else: - # Replace "\pi" with "1*math.pi" if the previous character is not a digit - string = string[:idx] + f"1*{pi}" + string[idx + 3 :] - - # Find the next occurrence of "\pi" - idx = string.find("\pi", idx + 1) - - # Evaluate the expression using eval() function - with contextlib.suppress(Exception): - string = eval(string) - - return string - - -def math_equal( - prediction: bool | float | str, - reference: float | str, - include_percentage: bool = True, - tolerance: float = 1e-4, - timeout: float = 10.0, - pi: float = math.pi, -) -> bool: - """ - Exact match of math if and only if: - 1. numerical equal: both can convert to float and are equal - 2. symbolic equal: both can convert to sympy expression and are equal - """ - - prediction = normalize(prediction, pi) - reference = normalize(reference, pi) - - if isinstance(prediction, str) and len(prediction) > 1000: # handling weird corner-cases - prediction = prediction[:1000] - - # 0. string comparison - if isinstance(prediction, str) and isinstance(reference, str): - if prediction.strip().lower() == reference.strip().lower(): - return True - if prediction.replace(" ", "") == reference.replace(" ", ""): - return True - - try: # 1. numerical equal - if is_digit(prediction)[0] and is_digit(reference)[0]: - prediction = is_digit(prediction)[1] - reference = is_digit(reference)[1] - # number questions - gt_result = [reference / 100, reference, reference * 100] if include_percentage else [reference] - for item in gt_result: - try: - if isclose(item, prediction, rel_tol=tolerance): - return True - except Exception: - continue - return False - except Exception: - pass - - if not prediction and prediction not in [0, False]: - return False - - # 2. symbolic equal - reference = str(reference).strip() - prediction = str(prediction).strip() - - ## deal with [], (), {} - prediction = format_intervals(prediction) - - pred_str, ref_str = prediction, reference - if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or ( - prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[") - ): - pred_str = pred_str.strip("[]()") - ref_str = ref_str.strip("[]()") - for s in ["{", "}", "(", ")"]: - ref_str = ref_str.replace(s, "") - pred_str = pred_str.replace(s, "") - if pred_str == ref_str: - return True - - ## [a, b] vs. [c, d], return a==c and b==d - if ( - prediction - and reference - and prediction[0] in "([" - and prediction[-1] in ")]" - and prediction[0] == reference[0] - and prediction[-1] == reference[-1] - ): - pred_parts = prediction[1:-1].split(",") - ref_parts = reference[1:-1].split(",") - if len(pred_parts) == len(ref_parts) and all( - [ - math_equal(pred_pt, ref_pt, include_percentage, tolerance) - for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True) - ] - ): - return True - - if "," in prediction and "," in reference: - pred_parts = [item.strip() for item in prediction.split(",")] - ref_parts = [item.strip() for item in reference.split(",")] - - if len(pred_parts) == len(ref_parts): - return bool( - all( - [ - math_equal(pred_parts[i], ref_parts[i], include_percentage, tolerance) - for i in range(len(pred_parts)) - ] - ) - ) - - # if we have point == tuple of values - if prediction.startswith("Point") and reference[0] == "(" and reference[-1] == ")": - pred_parts = prediction[prediction.find("(") + 1 : -1].split(",") - ref_parts = reference[1:-1].split(",") - if len(pred_parts) == len(ref_parts) and all( - [ - math_equal(pred_pt, ref_pt, include_percentage, tolerance) - for pred_pt, ref_pt in zip(pred_parts, ref_parts, strict=True) - ] - ): - return True - - # if reference is a matrix - if "\begin{pmatrix}" in reference and prediction.startswith("Matrix"): - try: - pred_matrix = parse_expr(prediction) - ref_matrix_items = reference.split()[1:-1:2] - if len(pred_matrix) == len(ref_matrix_items) and all( - [ - math_equal(pred, ref, include_percentage, tolerance) - for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True) - ] - ): - return True - except Exception: - pass - elif "\begin{pmatrix}" in reference and prediction.startswith("[") and prediction.endswith("]"): - if isinstance(eval(prediction), list): - try: - pred_matrix = eval(prediction) - # ref_matrix_items = reference.split()[1:-1:2] - ref_matrix_items = ( - reference.lstrip("\\begin{pmatrix}") # noqa: B005 - .lstrip("\begin{pmatrix}") - .rstrip("\\end{pmatrix}") - .rstrip("\end{pmatrix}") - ) # noqa: B005 - ref_matrix_items = ref_matrix_items.split("\\") - ref_matrix_items = [row.split("&") if "&" in row else row for row in ref_matrix_items] - if len(pred_matrix) == len(ref_matrix_items) and all( - [ - math_equal(pred, ref, include_percentage, tolerance) - for ref, pred in zip(ref_matrix_items, pred_matrix, strict=True) - ] - ): - return True - except Exception: - pass - - return symbolic_equal(prediction, reference, tolerance, timeout) - - -def symbolic_equal(a, b, tolerance, timeout=10.0): - def _parse(s): - for f in [parse_expr, parse_latex]: - try: - with timeout_limit(seconds=timeout): - return f(s) - except TimeoutError: - print(f"Parsing timed out for {s}") - continue - except Exception: - continue - return s - - a = _parse(a) - b = _parse(b) - - try: - with timeout_limit(seconds=timeout): - if simplify(a - b) == 0: - return True - except TimeoutError: - print(f"Simplification timed out for {a} - {b}") - pass - except Exception: - pass - - try: - with timeout_limit(seconds=timeout): - if isclose(N(a), N(b), rel_tol=tolerance): - return True - except TimeoutError: - print(f"Numerical evaluation timed out for {a}, {b}") - pass - except Exception: - pass - return False - - -def format_intervals(prediction): - patterns = { - "Interval(": r"^Interval\((.*)\)$", - "Interval.Ropen(": r"^Interval\.Ropen\((.*)\)$", - "Interval.Lopen(": r"^Interval\.Lopen\((.*)\)$", - "Interval.open(": r"^Interval\.open\((.*)\)$", - } - - for key, pattern in patterns.items(): - match = re.match(pattern, prediction) - if match: - inner_content = match.group(1) - - if key == "Interval(": # Intarval(a, b) == [a, b] - return f"[{inner_content}]" - elif key == "Interval.Ropen(": # Intarval.Ropen(a, b) == [a, b) - return f"[{inner_content})" - elif key == "Interval.Lopen(": # Intarval.Lopen(a, b) == (a, b] - return f"({inner_content}]" - elif key == "Interval.open(": # Intarval.open(a, b) == (a, b) - return f"({inner_content})" - - return prediction diff --git a/recipe/entropy/reward_score/entropy_math/math_normalize.py b/recipe/entropy/reward_score/entropy_math/math_normalize.py deleted file mode 100644 index 74d94cc41..000000000 --- a/recipe/entropy/reward_score/entropy_math/math_normalize.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright (c) 2021 Dan Hendrycks -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -""" -This logic is largely copied from the Hendrycks' MATH release (math_equivalence). - -From: https://github.com/openai/prm800k/blob/main/prm800k/grading/math_normalize.py -""" - -import re -from typing import Optional - - -def normalize_answer(answer: Optional[str]) -> Optional[str]: - if answer is None: - return None - answer = answer.strip() - try: - # Remove enclosing `\text{}`. - m = re.search("^\\\\text\{(?P.+?)\}$", answer) - if m is not None: - answer = m.group("text").strip() - return _strip_string(answer) - except: # noqa: E722 - return answer - - -def _fix_fracs(string): - substrs = string.split("\\frac") - new_str = substrs[0] - if len(substrs) > 1: - substrs = substrs[1:] - for substr in substrs: - new_str += "\\frac" - if substr[0] == "{": - new_str += substr - else: - try: - assert len(substr) >= 2 - except: # noqa: E722 - return string - a = substr[0] - b = substr[1] - if b != "{": - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr - else: - new_str += "{" + a + "}{" + b + "}" - else: - if len(substr) > 2: - post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr - else: - new_str += "{" + a + "}" + b - string = new_str - return string - - -def _fix_a_slash_b(string): - if len(string.split("/")) != 2: - return string - a = string.split("/")[0] - b = string.split("/")[1] - try: - a = int(a) - b = int(b) - assert string == "{}/{}".format(a, b) - new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" - return new_string - except: # noqa: E722 - return string - - -def _remove_right_units(string): - # "\\text{ " only ever occurs (at least in the val set) when describing units - if "\\text{ " in string: - splits = string.split("\\text{ ") - assert len(splits) == 2 - return splits[0] - else: - return string - - -def _fix_sqrt(string): - if "\\sqrt" not in string: - return string - splits = string.split("\\sqrt") - new_string = splits[0] - for split in splits[1:]: - if split[0] != "{": - a = split[0] - new_substr = "\\sqrt{" + a + "}" + split[1:] - else: - new_substr = "\\sqrt" + split - new_string += new_substr - return new_string - - -def _strip_string(string): - # linebreaks - string = string.replace("\n", "") - - # remove inverse spaces - string = string.replace("\\!", "") - - # replace \\ with \ - string = string.replace("\\\\", "\\") - - # replace tfrac and dfrac with frac - string = string.replace("tfrac", "frac") - string = string.replace("dfrac", "frac") - - # remove \left and \right - string = string.replace("\\left", "") - string = string.replace("\\right", "") - - # Remove circ (degrees) - string = string.replace("^{\\circ}", "") - string = string.replace("^\\circ", "") - - # remove dollar signs - string = string.replace("\\$", "") - - # remove units (on the right) - string = _remove_right_units(string) - - # remove percentage - string = string.replace("\\%", "") - string = string.replace("\%", "") - - # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string - string = string.replace(" .", " 0.") - string = string.replace("{.", "{0.") - # if empty, return empty string - if len(string) == 0: - return string - if string[0] == ".": - string = "0" + string - - # to consider: get rid of e.g. "k = " or "q = " at beginning - if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2: - string = string.split("=")[1] - - # fix sqrt3 --> sqrt{3} - string = _fix_sqrt(string) - - # remove spaces - string = string.replace(" ", "") - - # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). - # Also does a/b --> \\frac{a}{b} - string = _fix_fracs(string) - - # manually change 0.5 --> \frac{1}{2} - if string == "0.5": - string = "\\frac{1}{2}" - - # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y - string = _fix_a_slash_b(string) - - return string diff --git a/recipe/genrm_remote/README.md b/recipe/genrm_remote/README.md deleted file mode 100644 index 1a800fd88..000000000 --- a/recipe/genrm_remote/README.md +++ /dev/null @@ -1,39 +0,0 @@ -# Generative Reward Model - -## Scripts - -### Step 1: Launch a vLLM Server (Optional) - -Deploy the pretrained GenRM model using vLLM. Skip this step if you want to use an external api service. - -```bash -vllm serve verl-team/GenRM-CI-Test-1.5B --served-model-name genrm-demo -``` - -### Step 2: Perform RL using GenRM - -```bash -bash recipe/api-genrm/run_genrm_remote.sh -``` - -The implementation works by passing a customized reward function (see `reward_function.py`) - -For convenience, we run both the RL training and server on the same machine. To use an external server, configure the `BASE_URL` and `API_KEY` in `reward_function.py` first. - -## Advanced: Customizing Your GenRM - -You can use sglang server with data parallel for faster inference: - -```bash -CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4 -``` - -Note that you should modify the `BASE_URL` in `reward_function.py` to match your SGLang Server address. - -You can also create your own customized GenRM by implementing a custom reward function. Here are some tips for customizing your own GenRM based on `reward_function.py`: - -- Design appropriate prompts for your GenRM -- Convert GenRM responses into RL rewards -- ... - -Since these aspects are highly flexible, we only provide a demo implementation. The actual design and implementation of GenRM is left to the user's discretion. diff --git a/recipe/genrm_remote/reward_function.py b/recipe/genrm_remote/reward_function.py deleted file mode 100644 index b2d3fbc2f..000000000 --- a/recipe/genrm_remote/reward_function.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from concurrent.futures import ThreadPoolExecutor -from time import sleep - -import requests - -from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed - -BASE_URL = "http://localhost:30000" -API_KEY = "EMPTY" -MAX_RETRIES = 3 -BASE_DELAY = 2 -MAX_WORKERS = 32 -MODEL_NAME = "genrm-demo" -GENRM_PROMPT_TEMPLATE = """ -The following is a math problem and an AI solution: - -[Math Problem] - -{problem} - -[AI Solution] - -{solution} - -Your task is to review and critique the solution step by step, and output whether the AI solution is correct. - -Please put your final answer (i.e., 'True' or 'False') in \\boxed{{}}. -""".strip() - - -def get_response(problem, solution_str, ground_truth): - prompt = GENRM_PROMPT_TEMPLATE.format(problem=problem, solution=solution_str) - messages = [{"role": "user", "content": prompt}] - for attempt in range(MAX_RETRIES): - try: - headers = {"Content-Type": "application/json"} - chat_url = f"{BASE_URL}/v1/chat/completions" - data = {"model": MODEL_NAME, "messages": messages} - output = requests.post(chat_url, headers=headers, json=data, timeout=30) - response = output.json()["choices"][0]["message"]["content"] - return response - except Exception as e: - if attempt < MAX_RETRIES - 1: - print("Exception: ", repr(e)) - delay = BASE_DELAY * (2**attempt) - print(f"Retrying in {delay} seconds...") - sleep(delay) - else: - print(f"Failed after {MAX_RETRIES} attempts. Error: {e}") - - raise ConnectionRefusedError(f"Failed to run the model for {prompt}!") - - -def compute_reward(response): - reward_score = 0.0 - try: - boxed_result = last_boxed_only_string(response) - if boxed_result is not None: - result = remove_boxed(boxed_result) - reward_score = float(result == "True") - except Exception as e: - print(e) - return reward_score - - -def compute_score(data_source, solution_str, ground_truth, extra_info): - split = extra_info["split"] - from verl.utils.reward_score import default_compute_score - - func_rm_score = default_compute_score(data_source, solution_str, ground_truth, extra_info) - - if split == "test": - return func_rm_score - else: - problem = extra_info["question"] - response = get_response(problem, solution_str, ground_truth) - if response is not None: - reward_score = compute_reward(response) - else: - reward_score = 0.0 - - return reward_score - - -def compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos): - with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: - futures = [] - for data_source, solution_str, ground_truth, extra_info in zip( - data_sources, solution_strs, ground_truths, extra_infos, strict=True - ): - future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info) - futures.append(future) - - results = [future.result() for future in futures] - - return results diff --git a/recipe/genrm_remote/run_genrm_remote.sh b/recipe/genrm_remote/run_genrm_remote.sh deleted file mode 100644 index 6656dc8a7..000000000 --- a/recipe/genrm_remote/run_genrm_remote.sh +++ /dev/null @@ -1,45 +0,0 @@ -# vllm server -# CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve verl-team/GenRM-CI-Test-1.5B --served_model_name genrm-demo - -# sglang server -# CUDA_VISIBLE_DEVICES=0,1,2,3 python -m sglang_router.launch_server --model-path verl-team/GenRM-CI-Test-1.5B --dp-size 4 - -set -x - -CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=${HOME}/data/gsm8k/train.parquet \ - data.val_files=${HOME}/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.n=8 \ - algorithm.use_kl_in_reward=False \ - reward_model.reward_manager=batch \ - custom_reward_function.path=recipe/genrm_remote/reward_function.py \ - custom_reward_function.name=compute_score_batch \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_func_rm_example_gsm8k' \ - trainer.experiment_name='qwen2_5_3b_gen_rm' \ - trainer.n_gpus_per_node=4 \ - trainer.val_before_train=True \ - trainer.nnodes=1 \ - trainer.save_freq=20 \ - trainer.test_freq=5 \ - trainer.total_epochs=10 \ - trainer.resume_mode='disable' diff --git a/recipe/langgraph_agent/chat_model.py b/recipe/langgraph_agent/chat_model.py deleted file mode 100644 index f41f6ac37..000000000 --- a/recipe/langgraph_agent/chat_model.py +++ /dev/null @@ -1,357 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Ref: https://python.langchain.com/docs/how_to/custom_chat_model/ -""" - -import asyncio -import json -import logging -import os -import uuid -from typing import Any, Optional - -from langchain_core.language_models import BaseChatModel -from langchain_core.language_models.base import LanguageModelInput -from langchain_core.messages import ( - AIMessage, - BaseMessage, - convert_to_openai_messages, -) -from langchain_core.messages.tool import InvalidToolCall, ToolCall -from langchain_core.outputs import ChatGeneration, ChatResult -from langchain_core.runnables import Runnable, RunnableConfig -from langchain_core.tools import StructuredTool -from langchain_core.utils.function_calling import convert_to_openai_tool -from pydantic import Field - -from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, AsyncLLMServerManager -from verl.experimental.agent_loop.tool_parser import ToolParser - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class MaxTokenExceededError(Exception): - """Indicate that history chat messages + tool message exceeds LLM max_tokens.""" - - pass - - -class ChatModel(BaseChatModel): - model_name: str = Field(alias="model") - """The name of the model""" - - client: AsyncLLMServerManager - """AsyncLLM server manager""" - - tokenizer: Any - """Tokenizer for the model""" - - max_tokens: int - """Max tokens to generate""" - - tool_parser: str = "hermes" - """Tool parser for the model""" - - max_parallel_calls: int = 1 - """Max parallel tool calls""" - - temperature: float = 1.0 - """Temperature for sampling""" - - top_p: float = 1.0 - """Top p for sampling""" - - repetition_penalty: float = 1.0 - """Repetition penalty for sampling""" - - def bind_tools(self, tools, **kwargs) -> Runnable[LanguageModelInput, BaseMessage]: - """Bind tools to the model. - - Args: - tools: Sequence of tools to bind to the model. - - Returns: - A Runnable that returns a message. - """ - formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools] - - # used to remove system prompt prefix when encoding tool response - system_prompt = self.tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True) - kwargs["system_prompt"] = system_prompt - - return self.bind(tools=formatted_tools, **kwargs) - - def with_structured_output( - self, - schema: dict | type, - *, - include_raw: bool = False, - **kwargs: Any, - ) -> Runnable[LanguageModelInput, dict | BaseChatModel]: - """Ref: https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/""" - raise NotImplementedError - - def _generate( - self, - messages: list[BaseMessage], - stop: Optional[list[str]] = None, - **kwargs: Any, - ) -> ChatResult: - raise NotImplementedError - - async def _agenerate( - self, - messages: list[BaseMessage], - stop: Optional[list[str]] = None, - **kwargs: Any, - ) -> ChatResult: - """Asynchronously generate chat completion message. - - Args: - messages (list[BaseMessage]): List of list of messages. - stop (Optional[list[str]], optional): Stop words to use when generating. Model output is cut off at the - first occurrence of any of these substrings. Defaults to None. - - Returns: - ChatResult: Chat result. - """ - request_id, prompt_ids, response_mask = await self._preprocess(messages, **kwargs) - - sampling_params = { - "temperature": self.temperature, - "top_p": self.top_p, - "repetition_penalty": self.repetition_penalty, - } - if "sampling_params" in kwargs: - sampling_params.update(kwargs["sampling_params"]) - - response_ids = await self.client.generate( - request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params - ) - - message = await self._postprocess(request_id, prompt_ids, response_mask, response_ids, **kwargs) - generation = ChatGeneration(message=message) - return ChatResult(generations=[generation]) - - @property - def _llm_type(self) -> str: - """Get the type of language model used by this chat model.""" - return self.model_name - - async def _preprocess(self, messages: list[BaseMessage], **kwargs: Any) -> tuple[str, list[int], list[int]]: - """Preprocess messages for chat completion. - - To ensure strong consistency with policy model, AsyncLLM server generate response with token in token out - instead of messages list. - - But all agent frameworks use messages list to represent chat history. To mitigate the gap, we store trajectory - (prompt_ids, response_mask) in lastest AIMessage.response_metadata. - - 1. Encode ToolMessage to token ids. - 2. Retrieve trajectory (prompt_ids, response_mask) from lastest AIMessage.response_metadata. - 3. Append ToolMessage token ids to prompt_ids, and append 0 to response_mask. - - Ref: https://python.langchain.com/docs/concepts/chat_history/ - - Args: - messages (list[BaseMessage]): List of messages. - - Returns: - tuple[str, list[int], list[int]]: Request id, prompt ids, response mask. - """ - # messages: [system], human, ai, human|tool, ai, human|tool, ... - assert messages[-1].type in ["human", "tool"], ( - f"Last message must be human or tool, but got {messages[-1].type}" - ) - loop = asyncio.get_running_loop() - - # Case 1: initial chat completion: [system], human - if messages[-1].type == "human" and (len(messages) == 1 or messages[-2].type != "ai"): - prompt_ids = await loop.run_in_executor( - None, - lambda: self.tokenizer.apply_chat_template( - convert_to_openai_messages(messages), - tools=kwargs.get("tools"), - add_generation_prompt=True, - tokenize=True, - ), - ) - return str(uuid.uuid4()), prompt_ids, [] - - # Case 2: follow up chat completion with tool/human response: [system], human, ai, human|tool, ... - for i in range(len(messages) - 1, -1, -1): - if messages[i].type == "ai": - break - assert "prompt_ids" in messages[i].response_metadata, "Last message must have prompt_ids in response_metadata" - assert "response_mask" in messages[i].response_metadata, ( - "Last message must have response_mask in response_metadata" - ) - - # encode tool response - tool_responses = convert_to_openai_messages(messages[i + 1 :]) - tool_response_ids = await loop.run_in_executor( - None, - lambda messages=tool_responses: self.tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True - ), - ) - tool_response_ids = tool_response_ids[len(kwargs["system_prompt"]) :] - - # stop generation if response length exceeds max response length - if len(messages[i].response_metadata["response_mask"]) + len(tool_response_ids) >= self.max_tokens: - raise MaxTokenExceededError(f"Max response length {self.max_tokens} exceeded") - - # append tool response to prompt - request_id = messages[i].response_metadata.pop("request_id") - prompt_ids = messages[i].response_metadata.pop("prompt_ids") - response_mask = messages[i].response_metadata.pop("response_mask") - prompt_ids += tool_response_ids - response_mask += [0] * len(tool_response_ids) - - return request_id, prompt_ids, response_mask - - async def _postprocess( - self, request_id: str, prompt_ids: list[int], response_mask: list[int], response_ids: list[int], **kwargs: Any - ) -> AIMessage: - """Postprocess response_ids when chat completion is done. - - 1. Decode response_ids, parse tool calls to AIMessage. - 2. Append response_ids to prompt_ids, and append 1 to response_mask. - 3. Store trajectory (prompt_ids, response_mask) in AIMessage.response_metadata. - - Args: - request_id (str): Unique request id. - prompt_ids (list[int]): Input prompt token ids in this chat completion. - response_mask (list[int]): Response mask before this chat completion. - response_ids (list[int]): LLM generated token ids in this chat completion. - - Returns: - AIMessage: Postprocessed message. - """ - prompt_ids += response_ids - response_mask += [1] * len(response_ids) - - tool_parser = ToolParser.get_tool_parser(self.tool_parser, self.tokenizer) - content, function_calls = await tool_parser.extract_tool_calls(response_ids) - - tool_calls, invalid_tool_calls = [], [] - for function_call in function_calls: - try: - args = json.loads(function_call.arguments) - if not isinstance(args, dict): - raise json.JSONDecodeError(f"Invalid json tool arguments: {args}") - tool_call = ToolCall( - args=args, - name=function_call.name, - id=str(uuid.uuid4()), - ) - tool_calls.append(tool_call) - except json.JSONDecodeError as e: - logger.warning(f"Invalid json tool arguments: {e}") - tool_call = InvalidToolCall( - args=function_call.arguments, - name=function_call.name, - error=f"Invalid json tool arguments: {e}", - ) - invalid_tool_calls.append(tool_call) - - message = AIMessage( - content=content, - tool_calls=tool_calls[: self.max_parallel_calls], - invalid_tool_calls=invalid_tool_calls[: self.max_parallel_calls], - response_metadata={ - "request_id": request_id, - "prompt_ids": prompt_ids, - "response_mask": response_mask, - }, - ) - return message - - -class TruncateStructuredTool(StructuredTool): - """Structured tool with response truncation.""" - - tool_response_truncate_side: str - """truncate side of tool response: left, middle, right""" - - max_tool_response_length: int - """max length of tool response""" - - async def _arun( - self, - *args: Any, - config: RunnableConfig, - **kwargs: Any, - ) -> Any: - tool_response = await super()._arun(*args, config=config, **kwargs) - tool_response = str(tool_response) - - if len(tool_response) > self.max_tool_response_length: - if self.tool_response_truncate_side == "left": - tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)" - elif self.tool_response_truncate_side == "right": - tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :] - else: - length = self.max_tool_response_length // 2 - tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:] - - return tool_response - - -def convert_to_agent_output(messages: list[BaseMessage], response_length: int) -> AgentLoopOutput: - """Convert messages to AgentLoopOutput. - - Args: - messages (List[BaseMessage]): List of messages, last message must be assistant - with response_metadata containing `prompt_ids` and `response_mask`. - response_length (int): Max length of response. - - Returns: - AgentLoopOutput: agent loop output trajectory used for training. - """ - # skip last tool calls - for i in range(len(messages) - 1, -1, -1): - if messages[i].type != "tool": - break - last_message = messages[i] - assert last_message.type == "ai", f"Last message must be assistant, but got {last_message.type}" - assert "prompt_ids" in last_message.response_metadata, "Last message must have prompt_ids in response_metadata" - assert "response_mask" in last_message.response_metadata, ( - "Last message must have response_mask in response_metadata" - ) - - num_turns = 0 - for i in range(len(messages)): - if messages[i].type == "system": - continue - # parallel tool calls are in single turn - if i == 0 or messages[i].type != messages[i - 1].type: - num_turns += 1 - - prompt_ids = last_message.response_metadata["prompt_ids"] - response_mask = last_message.response_metadata["response_mask"] - - response_ids = prompt_ids[-len(response_mask) :] - prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)] - - output = AgentLoopOutput( - prompt_ids=prompt_ids, - response_ids=response_ids[:response_length], - response_mask=response_mask[:response_length], - num_turns=num_turns, - metrics={}, - ) - return output diff --git a/recipe/langgraph_agent/example/README.md b/recipe/langgraph_agent/example/README.md deleted file mode 100644 index 021e875bc..000000000 --- a/recipe/langgraph_agent/example/README.md +++ /dev/null @@ -1,111 +0,0 @@ -# MathExpression: LangGraph Agent Example - -MathExpression is a tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/). - -### Define react agent with tool -Firstly, to force ReactAgent to evaluate math expression by tool, we define a special operand `@`: -```python -@tool(parse_docstring=True) -def calculate(a: int, b: int, operand: str) -> int: - """ - Compute the results using operand with two integers - - Args: - a: the first operand - b: the second operand - operand: '+' or '-' or '*' or '@' - """ - assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}" - if operand == "@": - return 3 * a - 2 * b - return eval(f"{a} {operand} {b}") -``` - -Without calling `calculate`, ReactAgent is impossible to evaluate math expression correctly. - -Then, we can equip ReactAgent with `calculate` tool: -```python -class MathExpressionReactAgentLoop(ReactAgentLoop): - @classmethod - def init_class(cls, config, tokenizer): - cls.tools = [calculate] - super().init_class(config, tokenizer) -``` - -We can define agent loop config in yaml file, which will be used by AgentLoopWorker to dynamic load custom AgentLoop class. -```yaml -- name: math_expression - _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop -``` - -### Prepare dataset -Now, let's prepare two small datasets for training and evaluation: -```bash -python recipe/langgraph_agent/example/create_dataset.py -``` - -Note that dataset should contain a column `agent_name` with `math_expression`, which is used by `AgentLoopWorker` to select the -agent loop class. -| prompt | reward_model | agent_name | -|--------------------------------------|------------------------------|-----------------| -| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression | -| [{'role': 'user', 'content': '...'}] | {'ground_truth': '-10', ...} | math_expression | - -Generated math expressions are like below, requiring model to call `calculate` multiple times to solve sub expressions. -``` -(2 @ (8 @ 8 @ 5 @ 5 @ 3) @ 6 @ (1 @ 4 @ 4 @ 4) @ 2) @ 6 -(4.6 @ (9.05 @ 4.0) @ 8.3 @ 1.21) @ 8.6 -9 @ 4 -((2 @ 2) @ (3 @ 3)) @ 4 -``` - -### Training -Hook all these up and start training: -```bash -bash recipe/langgraph_agent/example/run_qwen2.5_3b.sh 2>&1 | tee train.log -``` - -After total 39 steps, model should achieve 100% accuray on test dataset: -- val-aux/lighteval/MATH/reward: 1.0 -- val-aux/num_turns/mean: 9.0, average number of messages include assistant and tool turns. - -We can see model can solve math expression by calling `calculate` tool: -``` -system -You are Qwen, created by Alibaba Cloud. You are a helpful assistant. -user -We define a new math operator @, where you can only call an external tool to compute. Please put your final answer inside \boxed{} only in the last turn. Now answer the following questions: -Compute 5 @ (3 @ 6 @ 8) -assistant -Sure, let's break this down step by step. We need to compute the expression \(5 @ (3 @ 6 @ 8)\). First, we need to compute the inner expression \(3 @ 6 @ 8\). - -Let's compute \(3 @ 6\) first using the external tool. - -{"name": "calculate", "arguments": {"a": 3, "b": 6, "operand": "@"}} -user - --3 - -assistant -We have computed that \(3 @ 6 = -3\). Now we need to compute \(-3 @ 8\). - -Let's use the external tool to compute this. - -{"name": "calculate", "arguments": {"a": -3, "b": 8, "operand": "@"}} -user - --25 - -assistant -We have computed that \(-3 @ 8 = -25\). Now we can put this result into the original expression to get the final result of \(5 @ (3 @ 6 @ 8)\), which is \(5 @ (-25)\). - -Let's compute \(5 @ (-25)\) using the external tool. - -{"name": "calculate", "arguments": {"a": 5, "b": -25, "operand": "@"}} -user - -65 - -assistant -The final result of \(5 @ (3 @ 6 @ 8)\) is \(\boxed{65}\). -``` diff --git a/recipe/langgraph_agent/example/agent.yaml b/recipe/langgraph_agent/example/agent.yaml deleted file mode 100644 index cbd8fb9eb..000000000 --- a/recipe/langgraph_agent/example/agent.yaml +++ /dev/null @@ -1,2 +0,0 @@ -- name: math_expression - _target_: recipe.langgraph_agent.example.math_expression.MathExpressionReactAgentLoop diff --git a/recipe/langgraph_agent/example/create_dataset.py b/recipe/langgraph_agent/example/create_dataset.py deleted file mode 100644 index fb14e755d..000000000 --- a/recipe/langgraph_agent/example/create_dataset.py +++ /dev/null @@ -1,277 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Create dataset for calculator -""" - -import random - -import pandas as pd - - -def generate_math_expression(min_terms=2, max_terms=5, min_number=1, max_number=10, allow_decimals=False, max_depth=2): - """ - Generate a random mathematical expression with operators +, -, *, /, and parentheses. - - Args: - min_terms (int): Minimum number of terms in the expression. - max_terms (int): Maximum number of terms in the expression. - max_number (int): Maximum value for numbers in the expression. - allow_decimals (bool): Whether to allow decimal numbers. - max_depth (int): Maximum nesting depth for parentheses. - - Returns: - str: A valid mathematical expression as a string. - """ - - def generate_number(): - """Generate a random number (integer or float).""" - assert min_number < max_number - num = random.uniform(min_number, max_number) - if not allow_decimals: - num = int(num) - else: - num = round(num, random.randint(0, 2)) # Round to 0-2 decimal places - return str(num) - - def generate_term(depth=0): - """Generate a term (number or parenthesized expression).""" - if depth < max_depth and random.random() < 0.5: # 50% chance to add parentheses - expr = generate_expression(depth + 1) - return f"({expr})" - else: - return generate_number() - - def generate_expression(depth=0): - """Generate a full expression with multiple terms and operators.""" - num_terms = random.randint(min_terms, max_terms) - terms = [generate_term(depth) for _ in range(num_terms)] - - # Randomly select operators - operators = ["+", "-", "*", "/", "@"] - expr = terms[0] - - for i in range(1, num_terms): - # Bias towards + and - for readability - op = random.choices( - operators, - weights=[0, 0, 0, 0, 1], # + and - are 1.5x more likely than * and / - )[0] - expr += f" {op} " + terms[i] - - return expr - - return generate_expression() - - -def test(): - # Example 1: Basic integer expression - print(generate_math_expression()) - # Output: (3 + 7) * 2 - 5 - - # Example 2: Expression with decimals - print(generate_math_expression(allow_decimals=True)) - # Output: 4.5 / (2.1 + 3.7) - 1.2 - - # Example 3: More complex expression with higher depth - print(generate_math_expression(max_terms=6, max_depth=3)) - # Output: ((5 * 2) - (3 + 1)) / (7 - 2) + 4 - - # Example 4: Simplified expression - print(generate_math_expression(min_terms=2, max_terms=3, max_number=5)) - # Output: 4 - 2 * 3 - - -def calculate(expression: str) -> float: - """ - Evaluate a mathematical expression with +, -, *, /, @, and parentheses. - The @ operator is defined as: a @ b = 3a - 2b. - - Args: - expression (str): Input mathematical expression (e.g., "3@2+4"). - - Returns: - float: Result of the evaluated expression. - - Raises: - ValueError: For invalid expressions (e.g., mismatched parentheses, division by zero). - """ - - def tokenize(s: str) -> list: - """Convert the input string into tokens (numbers, operators, parentheses).""" - tokens = [] - i = 0 - while i < len(s): - if s[i].isdigit() or s[i] == ".": - # Parse number (integer or float) - j = i - while j < len(s) and (s[j].isdigit() or s[j] == "."): - j += 1 - tokens.append(s[i:j]) - i = j - elif s[i] in "+-*/@()": - # Operator or parenthesis - tokens.append(s[i]) - i += 1 - elif s[i].isspace(): - # Skip whitespace - i += 1 - else: - raise ValueError(f"Invalid character: {s[i]}") - return tokens - - def infix_to_postfix(tokens: list) -> list: - """Convert infix notation to postfix notation (Reverse Polish Notation).""" - output = [] - stack = [] - # Higher precedence for @ (between * and +) - precedence = {"@": 3, "*": 2, "/": 2, "+": 1, "-": 1} - - for token in tokens: - if token.isdigit() or "." in token: - output.append(token) - elif token == "(": - stack.append(token) - elif token == ")": - while stack and stack[-1] != "(": - output.append(stack.pop()) - if not stack or stack[-1] != "(": - raise ValueError("Mismatched parentheses") - stack.pop() # Discard '(' - else: # Operator - while stack and stack[-1] != "(" and precedence.get(stack[-1], 0) >= precedence.get(token, 0): - output.append(stack.pop()) - stack.append(token) - - # Pop remaining operators - while stack: - if stack[-1] in "()": - raise ValueError("Mismatched parentheses") - output.append(stack.pop()) - - return output - - def evaluate_postfix(postfix: list) -> float: - """Evaluate postfix expression using a stack.""" - stack = [] - for token in postfix: - if token.isdigit() or "." in token: - stack.append(float(token)) - else: - if len(stack) < 2: - raise ValueError("Invalid expression") - b = stack.pop() - a = stack.pop() - if token == "+": - res = a + b - elif token == "-": - res = a - b - elif token == "*": - res = a * b - elif token == "/": - if b == 0: - raise ValueError("Division by zero") - res = a / b - elif token == "@": - res = 3 * a - 2 * b # Custom @ operator implementation - else: - raise ValueError(f"Invalid operator: {token}") - stack.append(res) - - if len(stack) != 1: - raise ValueError("Invalid expression") - return stack[0] - - # Remove spaces and validate parentheses - expression = expression.replace(" ", "") - if expression.count("(") != expression.count(")"): - raise ValueError("Mismatched parentheses") - - tokens = tokenize(expression) - postfix = infix_to_postfix(tokens) - result = evaluate_postfix(postfix) - - # Convert integers to integer representation - if result.is_integer(): - return int(result) - return result - - -def generate_data(total_num_dataset, split): - rl_dataset = { - "prompt": [], - "data_source": [], - "ability": [], - "reward_model": [], - "extra_info": [], - "agent_name": [], - } - - for idx in range(total_num_dataset): - while True: - try: - expression: str = generate_math_expression( - min_terms=2, max_terms=3, min_number=1, max_number=10, allow_decimals=False, max_depth=1 - ) - - num_plus = expression.count("+") - num_minus = expression.count("-") - num_mul = expression.count("*") - num_star = expression.count("@") - - answer = str(calculate(expression)) - # answer = str(eval(expression)) - break - except Exception as e: - print(e) - continue - - num_tool_calls = num_plus + num_minus + num_mul + num_star - - prompt = ( - f"We define a new math operator @, where you can only call an external tool to compute. " - f"Please put your final answer inside \\boxed{{}} only in the last turn. Now answer the " - f"following questions:\nCompute {expression}" - ) - prompt_with_template = [ - { - "role": "user", - "content": prompt, - } - ] - - rl_dataset["prompt"].append(prompt_with_template) - rl_dataset["data_source"].append("lighteval/MATH") - rl_dataset["ability"].append("math") - rl_dataset["reward_model"].append({"style": "lighteval/MATH", "ground_truth": answer}) - rl_dataset["extra_info"].append( - {"index": idx, "expression": expression, "split": split, "expected_tool_calls": num_tool_calls} - ) - rl_dataset["agent_name"].append("math_expression") - - rl_dataset = pd.DataFrame(data=rl_dataset) - return rl_dataset - - -if __name__ == "__main__": - # print(calculate("3@2")) # Output: 5 (3*3 - 2*2) - # print(calculate("3@2+4")) # Output: 9 (5 + 4) - # print(calculate("3*(4@2)")) # Output: 24 (3 * 8) - # print(calculate("(5@3)*2")) # Output: 18 (9 * 2) - - train_dataset = generate_data(total_num_dataset=5000, split="train") - test_dataset = generate_data(total_num_dataset=500, split="test") - - train_dataset.to_parquet("train.parquet") - test_dataset.to_parquet("test.parquet") diff --git a/recipe/langgraph_agent/example/math_expression.py b/recipe/langgraph_agent/example/math_expression.py deleted file mode 100644 index 4532c8af3..000000000 --- a/recipe/langgraph_agent/example/math_expression.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from langchain_core.tools import tool - -from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop - - -@tool(parse_docstring=True) -def calculate(a: int, b: int, operand: str) -> int: - """ - Compute the results using operand with two integers - - Args: - a: the first operand - b: the second operand - operand: '+' or '-' or '*' or '@' - """ - assert operand in ["+", "-", "*", "@"], f"unknown operand {operand}" - if operand == "@": - return 3 * a - 2 * b - return eval(f"{a} {operand} {b}") - - -class MathExpressionReactAgentLoop(ReactAgentLoop): - @classmethod - def init_class(cls, config, tokenizer, **kwargs): - cls.tools = [calculate] - super().init_class(config, tokenizer) diff --git a/recipe/langgraph_agent/example/run_qwen2.5_3b.sh b/recipe/langgraph_agent/example/run_qwen2.5_3b.sh deleted file mode 100644 index 4a398bb6a..000000000 --- a/recipe/langgraph_agent/example/run_qwen2.5_3b.sh +++ /dev/null @@ -1,99 +0,0 @@ -set -x - -# ================= data/model/tool ================= -HDFS_ROOT=${HDFS_ROOT:-$PWD} -DATA_ROOT=${DATA_ROOT:-$PWD} - -model_path=$DATA_ROOT/model/Qwen2.5-3B-Instruct - -train_files=$DATA_ROOT/dataset/math_expression_tool/train.parquet -test_files=$DATA_ROOT/dataset/math_expression_tool/test.parquet - -# agent -agent_loop_config_path=recipe/langgraph_agent/example/agent.yaml - -# wandb -project_name=math_expression_tool -experiment_name=qwen2.5-3b -default_local_dir=$DATA_ROOT/checkpoint/$experiment_name - -# ================= algorithm ================= -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_turns=8 -max_prompt_length=1024 -max_response_length=2048 -actor_lr=1e-6 - -train_batch_size=128 -ppo_mini_batch_size=16 -n_resp_per_prompt=8 -n_resp_per_prompt_val=1 - -# ================= perfomance ================= -infer_tp=2 # vllm -train_sp=4 # train -offload=True - -actor_max_token_len_per_gpu=$(( (max_prompt_length + max_response_length) * 4 )) -log_prob_max_token_len_per_gpu=$(( actor_max_token_len_per_gpu * 2 )) - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=$adv_estimator \ - algorithm.use_kl_in_reward=$use_kl_in_reward \ - algorithm.kl_ctrl.kl_coef=$kl_coef \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.return_raw_chat=True \ - data.train_batch_size=$train_batch_size \ - data.max_prompt_length=$max_prompt_length \ - data.max_response_length=$max_response_length \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=$model_path \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.use_kl_loss=$use_kl_loss \ - actor_rollout_ref.actor.kl_loss_coef=$kl_loss_coef \ - actor_rollout_ref.actor.clip_ratio_low=$clip_ratio_low \ - actor_rollout_ref.actor.clip_ratio_high=$clip_ratio_high \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.optim.lr=$actor_lr \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=$ppo_mini_batch_size \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$actor_max_token_len_per_gpu \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=$train_sp \ - actor_rollout_ref.actor.fsdp_config.param_offload=$offload \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=$offload \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=$log_prob_max_token_len_per_gpu \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.mode=async \ - actor_rollout_ref.rollout.tensor_model_parallel_size=$infer_tp \ - actor_rollout_ref.rollout.multi_turn.max_user_turns=$max_turns \ - actor_rollout_ref.rollout.multi_turn.max_assistant_turns=$max_turns \ - actor_rollout_ref.rollout.multi_turn.format=hermes \ - actor_rollout_ref.rollout.agent.agent_loop_config_path=$agent_loop_config_path \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \ - actor_rollout_ref.rollout.n=$n_resp_per_prompt \ - actor_rollout_ref.rollout.val_kwargs.top_p=0.6 \ - actor_rollout_ref.rollout.val_kwargs.temperature=1.0 \ - actor_rollout_ref.rollout.val_kwargs.n=$n_resp_per_prompt_val \ - trainer.logger=['console','wandb'] \ - trainer.project_name=$project_name \ - trainer.experiment_name=$experiment_name \ - trainer.n_gpus_per_node=$ARNOLD_WORKER_GPU \ - trainer.val_before_train=True \ - trainer.log_val_generations=50 \ - trainer.nnodes=$ARNOLD_WORKER_NUM \ - trainer.save_freq=-1 \ - trainer.default_local_dir=$default_local_dir \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ diff --git a/recipe/langgraph_agent/react_agent_loop.py b/recipe/langgraph_agent/react_agent_loop.py deleted file mode 100644 index 578968a92..000000000 --- a/recipe/langgraph_agent/react_agent_loop.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -LangGraph React Agent Loop. - -This implementation is exact same as `ToolAgentLoop`. - -Ref: https://langchain-ai.github.io/langgraph/tutorials/workflows/ -""" - -from typing import Any, Literal - -from langchain_core.runnables import RunnableConfig -from langgraph.graph import END, MessagesState, StateGraph -from langgraph.prebuilt import ToolNode - -from recipe.langgraph_agent.chat_model import ( - ChatModel, - MaxTokenExceededError, - convert_to_agent_output, -) -from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput - - -async def call_model(state: MessagesState, config: RunnableConfig): - model = config["configurable"]["model"] - sampling_params = config["configurable"]["sampling_params"] - try: - message = await model.ainvoke(state["messages"], sampling_params=sampling_params) - return {"messages": [message]} - except MaxTokenExceededError: - # last message is ToolMessage - return {"messages": []} - - -def should_continue(state: MessagesState, config: RunnableConfig) -> Literal["tools", END]: - max_assistant_turns = config["configurable"]["max_assistant_turns"] - num_assistant_turns = 0 - for message in state["messages"]: - if message.type == "ai": - num_assistant_turns += 1 - - last_message = state["messages"][-1] - - # LLM call failed, e.g: max response length exceeded - if last_message.type == "tool": - return END - - # max assistant turns exceeded - if max_assistant_turns and num_assistant_turns >= max_assistant_turns: - return END - - # no tool calls - if not last_message.tool_calls: - return END - - return "tools" - - -class ReactAgentLoop(AgentLoopBase): - @classmethod - def init_class(cls, config, tokenizer, **kwargs): - if cls._class_initialized: - return - cls._class_initialized = True - print("Performing class-level ReactAgentLoop initialization") - - # build graph - cls.graph = cls.build_graph() - - @classmethod - def build_graph(cls) -> StateGraph: - workflow = StateGraph(MessagesState) - - workflow.add_node("agent", call_model) - workflow.add_node("tools", ToolNode(cls.tools)) - workflow.set_entry_point("agent") - workflow.add_conditional_edges( - "agent", - should_continue, - { - "tools": "tools", - END: END, - }, - ) - - workflow.add_edge("tools", "agent") - graph = workflow.compile() - return graph - - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: - model_path = self.config.actor_rollout_ref.model.path - model_name = "/".join(model_path.split("/")[-2:]) - - rollout = self.config.actor_rollout_ref.rollout - model = ChatModel( - model=model_name, - client=self.server_manager, - tokenizer=self.tokenizer, - max_tokens=rollout.response_length, - max_parallel_calls=rollout.multi_turn.max_parallel_calls, - tool_parser=rollout.multi_turn.format, - ) - - model = model.bind_tools(self.tools, tool_choice="any") - - config = { - "configurable": { - "model": model, - "sampling_params": sampling_params, - "max_user_turns": rollout.multi_turn.max_user_turns, - "max_assistant_turns": rollout.multi_turn.max_assistant_turns, - } - } - - # TODO: how to handle multiple trajectories in an graph invocation? - # Each graph node may has its own LLM calls and state, e.g: - # https://github.com/google-gemini/gemini-fullstack-langgraph-quickstart - state = await self.graph.ainvoke(input={"messages": messages}, config=config) - - output = convert_to_agent_output(state["messages"], rollout.response_length) - return output diff --git a/recipe/langgraph_agent/test_react_agent_loop.py b/recipe/langgraph_agent/test_react_agent_loop.py deleted file mode 100644 index 0cdc91959..000000000 --- a/recipe/langgraph_agent/test_react_agent_loop.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os - -import numpy as np -import pytest -import ray -from langchain_core.tools import tool -from omegaconf import DictConfig - -from recipe.langgraph_agent.react_agent_loop import ReactAgentLoop -from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager -from verl.protocol import DataProto -from verl.utils import hf_tokenizer - - -@pytest.fixture -def init_config() -> DictConfig: - from hydra import compose, initialize_config_dir - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - config = compose(config_name="ppo_trainer") - model_path = "Qwen/Qwen2.5-1.5B-Instruct" - config.actor_rollout_ref.model.path = model_path - config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.prompt_length = 4096 - config.actor_rollout_ref.rollout.response_length = 4096 - config.actor_rollout_ref.rollout.n = 4 - config.actor_rollout_ref.rollout.agent.num_workers = 2 - - # test sleep/wake_up with fsdp offload - config.actor_rollout_ref.actor.fsdp_config.param_offload = True - config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True - - return config - - -@tool(parse_docstring=True) -def get_current_temperature(location: str, unit: str = "celsius"): - """Get current temperature at a location. - - Args: - location: The location to get the temperature for, in the format "City, State, Country". - unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) - - Returns: - the temperature, the location, and the unit in a dict - """ - print(f"[DEBUG] get_current_temperature: {location}, {unit}") - return { - "temperature": 26.1, - "location": location, - "unit": unit, - } - - -@tool(parse_docstring=True) -def get_temperature_date(location: str, date: str, unit: str = "celsius"): - """Get temperature at a location and date. - - Args: - location: The location to get the temperature for, in the format "City, State, Country". - date: The date to get the temperature for, in the format "Year-Month-Day". - unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) - - Returns: - the temperature, the location, the date and the unit in a dict - """ - print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}") - return { - "temperature": 25.9, - "location": location, - "date": date, - "unit": unit, - } - - -class TestReactAgentLoop(ReactAgentLoop): - @classmethod - def init_class(cls, config, tokenizer, **kwargs): - # TODO: find better way to configure tools - cls.tools = [get_current_temperature, get_temperature_date] - super().init_class(config, tokenizer, **kwargs) - - -def test_react_agent(init_config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - # =========================== 1. Init rollout manager =========================== - agent_loop_config = [ - { - "_target_": "recipe.langgraph_agent.test_react_agent_loop.TestReactAgentLoop", - "name": "react_agent", - }, - ] - agent_loop_config_path = "/tmp/agent_loop_config.json" - with open(agent_loop_config_path, "w") as f: - json.dump(agent_loop_config, f) - - n = 2 - init_config.actor_rollout_ref.rollout.n = n - # init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path - init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 - init_config.actor_rollout_ref.rollout.agent.agent_loop_config_path = agent_loop_config_path - agent_loop_manager = init_agent_loop_manager(init_config) - - # =========================== 2. Generate sequences =========================== - raw_prompts = [ - [ - {"role": "user", "content": "How are you?"}, - ], - [ - {"role": "user", "content": "What's the temperature in Los Angeles now?"}, - ], - [ - {"role": "user", "content": "What's the temperature in New York now?"}, - ], - [ - { - "role": "system", - "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" - "Current Date: 2024-09-30", - }, - {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, - ], - ] - batch = DataProto( - non_tensor_batch={ - "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), - "agent_name": np.array(["react_agent"] * len(raw_prompts)), - }, - ) - batch = batch.repeat(n) - result = agent_loop_manager.generate_sequences(prompts=batch) - assert len(result) == len(raw_prompts) * n - - # Check turns - num_turns = result.non_tensor_batch["__num_turns__"] - print(f"num_turns: {num_turns}") - for i in range(len(num_turns)): - if i // n == 0: - # [user, assistant] - assert num_turns[i] == 2 - else: - # [user, assistant, tool, assistant] - assert num_turns[i] == 4 - - # Check response_mask - tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) - responses = result.batch["responses"] - response_mask = result.batch["response_mask"] - attention_mask = result.batch["attention_mask"] - assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" - response_length = response_mask.size(1) - - for i in range(len(responses)): - # response with tool response - valid_tokens = responses[i][attention_mask[i][-response_length:].bool()] - response_with_obs = tokenizer.decode(valid_tokens) - - # response without tool response - valid_tokens = responses[i][response_mask[i].bool()] - response_without_obs = tokenizer.decode(valid_tokens) - - assert "" not in response_without_obs, ( - f"found in response: {response_without_obs}" - ) - assert "" not in response_without_obs, ( - f"found in response: {response_without_obs}" - ) - print("=========================") - print(response_with_obs) - print("---") - print(response_without_obs) - - print("Test passed!") - ray.shutdown() diff --git a/recipe/minicpmo/rl_dataset.py b/recipe/minicpmo/rl_dataset.py deleted file mode 100644 index 5ce15fb12..000000000 --- a/recipe/minicpmo/rl_dataset.py +++ /dev/null @@ -1,553 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import math -import os -import re -from typing import Optional - -import datasets -import torch -from omegaconf import DictConfig, ListConfig -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms -from transformers import PreTrainedTokenizer, ProcessorMixin - -import verl.utils.torch_functional as verl_F -from verl.utils.dataset.vision_utils import process_image -from verl.utils.model import compute_position_id_with_mask - -logger = logging.getLogger(__name__) - - -def build_transform(): - IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN - IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD - return transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - ] - ) - - -def build_image_bound(input_ids, tokenizer, new_schema=True, logger=None): - if new_schema: - start_cond = (input_ids == tokenizer.im_start_id) | (input_ids == tokenizer.slice_start_id) - end_cond = (input_ids == tokenizer.im_end_id) | (input_ids == tokenizer.slice_end_id) - else: - start_cond = input_ids == tokenizer.im_start_id - end_cond = input_ids == tokenizer.im_end_id - image_start_tokens = torch.where(start_cond)[0] - image_start_tokens += 1 - image_end_tokens = torch.where(end_cond)[0] - if len(image_start_tokens) != len(image_end_tokens): - logger.error("image start token != image end tokens") - raise Exception("image start token != image end tokens") - if len(image_start_tokens) > 0: - image_bound = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]) - else: - image_bound = [] - return image_bound - - -def preprocess( - images_dict, - conversations, - tokenizer, - transform, - query_nums=64, - slice_config=None, - llm_type=None, - patch_size=14, - batch_vision=False, - max_length=2048, - truncation="error", - logger=None, -): - """ - single(multi) image(s) preprocess, the image(s) will be placed at the top of the conversation - """ - conversations = copy.deepcopy(conversations) - assert conversations[0]["role"] == "user", "the first role must be user" - - if slice_config is not None: - assert isinstance(slice_config, dict) - assert "patch_size" in slice_config - assert "max_slice_nums" in slice_config - assert "scale_resolution" in slice_config - default_image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end - new_schema = False - use_image_id = False - if llm_type == "qwen": - new_schema = True - use_image_id = True - image_placeholder_dict = {} - images = [] - image_id_cnt = 0 - for img_name, image in images_dict.items(): - if slice_config: - source_image, patches, best_grid = slice_image( - image, - slice_config["max_slice_nums"], - slice_config["scale_resolution"], - slice_config["patch_size"], - ) - images.append(source_image) - image_placeholder = default_image_placeholder - if len(patches) > 0: - for i in range(len(patches)): - for j in range(len(patches[0])): - images.append(patches[i][j]) - if use_image_id: - image_placeholder = ( - f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder - ) - image_id_cnt += 1 - image_placeholder += get_grid_placeholder(tokenizer, best_grid, query_nums, new_schema=new_schema) - image_placeholder_dict[img_name] = image_placeholder - else: - images.append(image) - if use_image_id: - image_placeholder = f"{tokenizer.im_id_start}{image_id_cnt}{tokenizer.im_id_end}" + image_placeholder - image_id_cnt += 1 - else: - image_placeholder = default_image_placeholder - image_placeholder_dict[img_name] = image_placeholder - - images = [transform(i) for i in images] - - if len(images_dict) == 1 and "" in images_dict: - if "" in conversations[0]["content"]: - conversations[0]["content"] = conversations[0]["content"].replace("", image_placeholder) - else: - conversations[0]["content"] = image_placeholder + "\n" + conversations[0]["content"] - else: - pattern = r"" - new_conversations = [] - for conversation in conversations: - content = conversation["content"] - parts = re.split(f"({pattern})", content) - for i, part in enumerate(parts): - if not part.strip(): - continue - if re.match(pattern, part): - if part in image_placeholder_dict: - parts[i] = image_placeholder_dict[part] - else: - raise Exception(f"not found {part} in image dict") - conversation["content"] = "\n".join(parts) - new_conversations.append(conversation) - conversations = new_conversations - - # TODO change role in conversation for different llm - prompt_with_chat_template = tokenizer.apply_chat_template(conversations, add_generation_prompt=True, tokenize=False) - - input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( - prompt=prompt_with_chat_template, - tokenizer=tokenizer, - max_length=max_length, - pad_token_id=tokenizer.pad_token_id, - left_pad=True, - truncation=truncation, - ) - position_ids = compute_position_id_with_mask(attention_mask) - image_bound = build_image_bound(input_ids[0], tokenizer, new_schema, logger) - - input_dict = { - "input_ids": input_ids[0], - "attention_mask": attention_mask[0], - "position_ids": position_ids[0], - "image_bound": image_bound, - } - - if batch_vision: - tgt_sizes = [] - reshape_images = [] - for image in images: - H, W = image.shape[1:] - reshape_image = reshape_by_patch(image, patch_size) - reshape_images.append(reshape_image) - tgt_sizes.append([H // patch_size, W // patch_size]) - if tgt_sizes: - tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32) - - input_dict["pixel_values"] = reshape_images - input_dict["tgt_sizes"] = tgt_sizes - - else: - input_dict["pixel_values"] = images - input_dict["tgt_sizes"] = [] - - return input_dict - - -def slice_image(image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False): - original_size = image.size - original_width, original_height = original_size - log_ratio = math.log(original_width / original_height) - ratio = original_width * original_height / (scale_resolution * scale_resolution) - multiple = min(math.ceil(ratio), max_slice_nums) - - source_image = None - best_grid = None - patches = [] - - if multiple <= 1 or never_split: - # dont need to slice, upsample - best_size = find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True) - source_image = image.resize(best_size, Image.Resampling.BICUBIC) - else: - candidate_split_grids_nums = [] - for i in [multiple - 1, multiple, multiple + 1]: - if i == 1 or i > max_slice_nums: - continue - candidate_split_grids_nums.append(i) - - # source image, down-sampling and ensure divided by patch_size - best_resize = find_best_resize(original_size, scale_resolution, patch_size) - source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) - candidate_grids = [] - - # find best grid - for split_grids_nums in candidate_split_grids_nums: - m = 1 - while m <= split_grids_nums: - if split_grids_nums % m == 0: - candidate_grids.append([m, split_grids_nums // m]) - m += 1 - - best_grid = [1, 1] - min_error = float("inf") - for grid in candidate_grids: - error = abs(log_ratio - math.log(grid[0] / grid[1])) - if error < min_error: - best_grid = grid - min_error = error - - refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, allow_upscale=True) - - refine_image = image.resize(refine_size, Image.Resampling.BICUBIC) - patches = split_to_patches(refine_image, best_grid) - - return source_image, patches, best_grid - - -def ensure_divide(length, patch_size): - return max(round(length / patch_size) * patch_size, patch_size) - - -def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False): - width, height = original_size - if (width * height > scale_resolution * scale_resolution) or allow_upscale: - r = width / height - height = int(scale_resolution / math.sqrt(r)) - width = int(height * r) - best_width = ensure_divide(width, patch_size) - best_height = ensure_divide(height, patch_size) - return (best_width, best_height) - - -def get_refine_size(original_size, grid, scale_resolution, patch_size, allow_upscale=False): - width, height = original_size - grid_x, grid_y = grid - - refine_width = ensure_divide(width, grid_x) - refine_height = ensure_divide(height, grid_y) - - grid_width = refine_width / grid_x - grid_height = refine_height / grid_y - - best_grid_size = find_best_resize( - (grid_width, grid_height), - scale_resolution, - patch_size, - allow_upscale=allow_upscale, - ) - - refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) - - return refine_size - - -def split_to_patches(image, grid): - patches = [] - width, height = image.size - grid_x = int(width / grid[0]) - grid_y = int(height / grid[1]) - - for i in range(0, height, grid_y): - images = [] - for j in range(0, width, grid_x): - box = (j, i, j + grid_x, i + grid_y) - patch = image.crop(box) - images.append(patch) - patches.append(images) - - return patches - - -def get_grid_placeholder(tokenizer, grid, query_num, new_schema=False): - if new_schema: - image_placeholder = tokenizer.slice_start + tokenizer.unk_token * query_num + tokenizer.slice_end - else: - image_placeholder = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end - - cols = grid[0] - rows = grid[1] - slices = [] - for i in range(rows): - lines = [] - for j in range(cols): - lines.append(image_placeholder) - slices.append("".join(lines)) - if new_schema: - slice_placeholder = "\n".join(slices) - else: - slice_placeholder = tokenizer.slice_start + "\n".join(slices) + tokenizer.slice_end - return slice_placeholder - - -def reshape_by_patch(image_tensor, patch_size): - """ - :param image_tensor: shape [3, H, W] - :param patch_size: - :return: [3, patch_size, HW/patch_size] - """ - patches = torch.nn.functional.unfold(image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size)) - - patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1) - patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1) - return patches - - -def init_minicpmo_config(processor, config): - """Initialize MiniCPM-o specific configuration""" - minicpmo_config = { - "transform": build_transform(), - "patch_size": config.get("patch_size", 14), - "query_nums": config.get("query_nums", 64), - "slice_config": config.get( - "slice_config", {"max_slice_nums": 9, "patch_size": config.get("patch_size", 14), "scale_resolution": 448} - ), - "llm_type": config.get("llm_type", "qwen"), - "batch_vision": config.get("batch_vision", True), - } - return minicpmo_config - - -def process_minicpmo_data( - row_dict, messages, tokenizer, minicpmo_config, image_key, max_prompt_length, truncation, logger -): - """Process data for MiniCPM-o model""" - if len(row_dict[image_key]) == 1: - multi_modal_data = {} - image = process_image(row_dict.pop(image_key)[0]) - multi_modal_data["image"] = [image] - images_dict = {"": image} - else: - raise NotImplementedError - - model_inputs = preprocess( - images_dict, - messages, - tokenizer, - minicpmo_config["transform"], - query_nums=minicpmo_config["query_nums"], - slice_config=minicpmo_config["slice_config"], - llm_type=minicpmo_config["llm_type"], - patch_size=minicpmo_config["patch_size"], - batch_vision=minicpmo_config["batch_vision"], - max_length=max_prompt_length, - truncation=truncation, - logger=logger, - ) - - raw_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - raw_prompt = raw_prompt.replace("", "(./)") - - return model_inputs, multi_modal_data, raw_prompt - - -class RLHFDataset(Dataset): - """ - Load and preprocess RLHF data from Parquet files. - - - Caches files locally. - - Reads into a HuggingFace Dataset and tokenizes prompts. - - Optionally handles images/videos via a ProcessorMixin. - - Filters prompts over a max length. - - Supports resuming from checkpoints. - - Args: - data_files (str or list): Path(s) to Parquet file(s). - tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs. - config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc. - processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos. - """ - - def __init__( - self, - data_files: str | list[str], - tokenizer: PreTrainedTokenizer, - config: DictConfig, - processor: Optional[ProcessorMixin] = None, - ): - if not isinstance(data_files, list | ListConfig): - data_files = [data_files] - - self.data_files = copy.deepcopy(data_files) - self.original_data_files = copy.deepcopy(data_files) # use for resume - self.tokenizer = tokenizer - self.processor = processor - self.config = config - - self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) - self.prompt_key = config.get("prompt_key", "prompt") - self.image_key = config.get("image_key", "images") - self.video_key = config.get("video_key", "videos") - self.max_prompt_length = config.get("max_prompt_length", 1024) - self.return_raw_chat = config.get("return_raw_chat", False) - self.return_full_prompt = config.get("return_full_prompt", False) - self.truncation = config.get("truncation", "error") - self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) - - self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) - self.num_workers = min(self.num_workers, os.cpu_count()) - self.use_shm = config.get("use_shm", False) - self.chat_template_func = config.get("chat_template_func", None) - self.need_tools_kwargs = config.get("need_tools_kwargs", False) - self.filter_prompts = config.get("filter_prompts", True) - self.serialize_dataset = False - self.minicpmo_config = init_minicpmo_config(self.processor, config) - self._download() - self._read_files_and_tokenize() - - def _download(self, use_origin_parquet=False): - from verl.utils.fs import copy_to_local - - data_files = self.data_files if not use_origin_parquet else self.original_data_files - for i, parquet_file in enumerate(data_files): - self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm) - - def _read_files_and_tokenize(self): - dataframes = [] - for parquet_file in self.data_files: - # read parquet files and cache - dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] - dataframes.append(dataframe) - self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) - - print(f"dataset len: {len(self.dataframe)}") - - def resume_dataset_state(self): - self.serialize_dataset = not hasattr(self, "original_data_files") - # resume dataframe if not it's serialized in data.pt - if not self.serialize_dataset: - self._download(use_origin_parquet=True) # download and resume from original parquet files - self._read_files_and_tokenize() - else: - print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") - - def __len__(self): - return len(self.dataframe) - - def _build_messages(self, example: dict): - return example.pop(self.prompt_key) - - def __getitem__(self, item): - """ - Note that we also return the raw_input_ids so that it can be combined with other chat template - """ - row_dict: dict = self.dataframe[item] - messages = self._build_messages(row_dict) - model_inputs = {} - - if self.processor is not None: - model_inputs, multi_modal_data, raw_prompt = process_minicpmo_data( - row_dict, - messages, - self.tokenizer, - self.minicpmo_config, - self.image_key, - self.max_prompt_length, - self.truncation, - logger, - ) - input_ids = model_inputs.pop("input_ids") - attention_mask = model_inputs.pop("attention_mask") - position_ids = model_inputs.pop("position_ids") - - # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature - row_dict["multi_modal_data"] = multi_modal_data - row_dict["multi_modal_inputs"] = dict(model_inputs) - else: - raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) - model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False) - input_ids = model_inputs.pop("input_ids") - attention_mask = model_inputs.pop("attention_mask") - position_ids = compute_position_id_with_mask(attention_mask) - - row_dict["input_ids"] = input_ids - row_dict["attention_mask"] = attention_mask - row_dict["position_ids"] = position_ids - - raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False) - if len(raw_prompt_ids) > self.max_prompt_length: - if self.truncation == "left": - raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :] - elif self.truncation == "right": - raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length] - elif self.truncation == "middle": - left_half = self.max_prompt_length // 2 - right_half = self.max_prompt_length - left_half - raw_prompt_ids = raw_prompt_ids[:left_half] + raw_prompt_ids[-right_half:] - elif self.truncation == "error": - raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.") - - row_dict["raw_prompt_ids"] = raw_prompt_ids - # encode prompts without chat template - if self.return_raw_chat: - row_dict["raw_prompt"] = messages - - # get prompts with chat template - if self.return_full_prompt: - row_dict["full_prompts"] = raw_prompt # array of strings - - # add index for each prompt - index = row_dict.get("extra_info", {}).get("index", 0) - tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {}) - interaction_kwargs = row_dict.get("extra_info", {}).get("interaction_kwargs", {}) - need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs) - if need_tools_kwargs and not tools_kwargs: - logger.warning("tools_kwargs is empty for index {}, data source: {}", index, row_dict["data_source"]) - row_dict["index"] = index - row_dict["tools_kwargs"] = tools_kwargs - row_dict["interaction_kwargs"] = interaction_kwargs - return row_dict - - def __getstate__(self): - if not self.serialize_dataset: - state = self.__dict__.copy() - - if "dataframe" in state: - del state["dataframe"] - return state - - return self.__dict__.copy() diff --git a/recipe/prime/config/prime_trainer.yaml b/recipe/prime/config/prime_trainer.yaml deleted file mode 100644 index 23a3c4403..000000000 --- a/recipe/prime/config/prime_trainer.yaml +++ /dev/null @@ -1,76 +0,0 @@ -# the prime config will override default ppo_trainer.yaml - -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -data: - filter_accuracy: True - accuracy_lower_bound: 0.2 - accuracy_upper_bound: 0.8 - oversample_factor: 4.0 # Sample more responses than the batch size. prompts satisfying the filter will be prioritized. - filter_truncate: True - truncation: right - -actor_rollout_ref: - hybrid_engine: True - model: - use_remove_padding: True - rollout: - # number of responses (i.e. num sample times) - n: 4 - actor: - entropy_coeff: 0.001 - -reward_model: - enable: True - strategy: fsdp - model: - ref_path: ${reward_model.model.path} - use_remove_padding: True - use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} - fused_kernel_options: - impl_backend: torch # triton, torch - tokenizer_path: ${actor_rollout_ref.model.path} - enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing} - ref_type: freeze - fsdp_config: - min_num_params: 0 - param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload} -# grad_offload: ${actor_rollout_ref.actor.fsdp_config.grad_offload} - optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload} - update: before # ``before`` for double-forward, ``after`` for single-forward - optim: - lr: 1e-6 - lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null - warmup_style: constant - total_training_steps: -1 # must be overridden by program - weight_decay: 0. - grad_clip: 10.0 - beta_train: 0.05 - loss_type: ce # currently only supports ce loss - prime_granularity: token - prime_norm: batch_norm # batch_norm or none. if set to none, the normalizer is beta_train - mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - reward_manager: prime - -algorithm: - adv_estimator: rloo - # now supports rloo. it treats different source of reward separately. - kl_ctrl: - type: fixed - kl_coef: 0.000 - reward_gt_coef: 5 - reward_dpo_coef: 5 - -trainer: - project_name: prime - experiment_name: examples - val_before_train: False - balance_batch: False diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py deleted file mode 100644 index 6bf7f5e45..000000000 --- a/recipe/prime/main_prime.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. -""" - -import hydra -import ray - -from .prime_ray_trainer import RayPRIMETrainer - - -@hydra.main(config_path="config", config_name="prime_trainer", version_base=None) -def main(config): - run_prime(config) - - -def run_prime(config, compute_score=None): - if not ray.is_initialized(): - # this is for local ray cluster - ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, - num_cpus=config.ray_init.num_cpus, - ) - - ray.get(main_task.remote(config, compute_score)) - - -@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head -def main_task(config, compute_score=None): - # print initial config - from pprint import pprint - - from omegaconf import OmegaConf - - from verl.utils.fs import copy_local_path_from_hdfs - - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - OmegaConf.resolve(config) - - # download the checkpoint from hdfs - local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) - - # instantiate tokenizer - from verl.utils import hf_tokenizer - - tokenizer = hf_tokenizer(local_path) - - # define worker classes - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: - assert config.critic.strategy in {"fsdp", "fsdp2"} - from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker - - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker - - ray_worker_group_cls = NVMegatronRayWorkerGroup - - else: - raise NotImplementedError - - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - role_worker_mapping = { - Role.ActorRollout: ray.remote(ActorRolloutRefWorker), - } - - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - } - - # use reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) - mapping[Role.RefPolicy] = global_pool_id - - if config.reward_model.enable: - from .prime_fsdp_workers import PRIMERewardModelWorker - - role_worker_mapping[Role.RewardModel] = ray.remote(PRIMERewardModelWorker) - mapping[Role.RewardModel] = global_pool_id - - reward_manager_name = config.reward_model.get("reward_manager", "naive") - if reward_manager_name == "naive": - from verl.workers.reward_manager import NaiveRewardManager - - reward_manager_cls = NaiveRewardManager - elif reward_manager_name == "prime": - from verl.workers.reward_manager import PrimeRewardManager - - reward_manager_cls = PrimeRewardManager - else: - raise NotImplementedError - reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=0, compute_score=compute_score) - - # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls(tokenizer=tokenizer, num_examine=1, compute_score=compute_score) - - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - trainer = RayPRIMETrainer( - config=config, - tokenizer=tokenizer, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, - device_name=config.trainer.device, - ) - trainer.init_workers() - trainer.fit() - - -if __name__ == "__main__": - main() diff --git a/recipe/prime/prime_core_algos.py b/recipe/prime/prime_core_algos.py deleted file mode 100644 index 825671216..000000000 --- a/recipe/prime/prime_core_algos.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -import verl -import verl.utils.torch_functional as verl_F - - -def compute_rloo_advantage_return(data: verl.DataProto, response_mask: torch.Tensor, n_samples, config): - # calculate rloo reward on different reward sources, and sum again - def masked_rloo(reward_tensor_original, mask_tensor): - reward_tensor = reward_tensor_original.clone() - reward_tensor[~mask_tensor] = 0 - for start_pos in range(0, reward_tensor.shape[0], n_samples): - cur_rewards_mean = torch.cat( - [ - reward_tensor[pos : pos + 1][mask_tensor[pos : pos + 1]].mean(dim=0, keepdim=True) - for pos in range(start_pos, start_pos + n_samples) - ], - dim=0, - ) - cur_rewards_sum = cur_rewards_mean.sum() - cur_reward_baseline = cur_rewards_sum / (n_samples - 1) - reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] = ( - reward_tensor[start_pos : start_pos + n_samples][mask_tensor[start_pos : start_pos + n_samples]] - * (n_samples / (n_samples - 1)) - - cur_reward_baseline - ) - - return reward_tensor - - reward_tensors = [] - - with torch.no_grad(): - if "rm_scores" in data.batch.keys() and config.algorithm.reward_dpo_coef != 0.0: - reward_tensor = data.batch["rm_scores"] - reward_mask = response_mask.bool() - - reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_dpo_coef) - - if "acc" in data.batch.keys() and config.algorithm.reward_gt_coef != 0.0: - reward_tensor = torch.zeros_like(response_mask, dtype=torch.float32) - reward_mask = torch.zeros_like(response_mask, dtype=torch.bool) - - prompt_ids = data.batch["prompts"] - prompt_length = prompt_ids.shape[-1] - valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(-1) - - reward_mask[ - torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), - valid_response_length - 1, - ] = True - reward_tensor[ - torch.arange(0, valid_response_length.shape[0], dtype=torch.long, device=valid_response_length.device), - valid_response_length - 1, - ] = data.batch["acc"] - - reward_tensors.append(masked_rloo(reward_tensor, reward_mask) * config.algorithm.reward_gt_coef) - - final_reward_tensor = sum(reward_tensors) - - returns = (final_reward_tensor * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) - - advantages = returns.clone() - advantages = verl_F.masked_whiten(advantages, response_mask) - - return advantages, returns - - -def compute_ce_dpo_loss_rm(token_level_scores, acc, response_mask, beta): - cur_scores = ((token_level_scores * response_mask).sum(dim=1) * beta).sigmoid() - cur_dpo_loss = torch.nn.functional.binary_cross_entropy(cur_scores, acc) - return cur_dpo_loss - - -def compute_detach_dpo_loss_rm(token_level_scores, acc, Q_bc, acc_bc, response_mask, beta, bon_mode="none"): - # we always assume that the BoN size equals n_samples - # mode1: use acc as rm - # mode2: use Q as rm - cur_Q = (token_level_scores * response_mask).sum(dim=1) * beta - other_Q = torch.zeros_like(cur_Q) - for i in range(token_level_scores.shape[0]): - Q_chosen = Q_bc[i][acc_bc[i] < acc[i]] if acc[i] > 0 else Q_bc[i][acc_bc[i] > acc[i]] - if len(Q_chosen) > 0: - other_Q[i] = Q_chosen.mean() * beta - else: - other_Q[i] = 0 - dpo_loss = -torch.log(torch.sigmoid((cur_Q - other_Q) * ((acc > 0).float() * 2 - 1))) - if bon_mode == "none": - dpo_loss = dpo_loss.mean() - else: - weight = torch.zeros_like(dpo_loss) - n_samples = acc_bc.shape[1] - if bon_mode == "bon_rm": - for i in range(token_level_scores.shape[0]): - weight[i] = n_samples * torch.pow((Q_bc[i] * beta <= cur_Q[i]).float().mean(), n_samples - 1) - elif bon_mode == "bon_acc": - for i in range(token_level_scores.shape[0]): - weight[i] = n_samples * torch.pow((acc_bc[i] <= acc[i]).float().mean(), n_samples - 1) - else: - raise NotImplementedError - dpo_loss = (dpo_loss * weight).sum() - - return dpo_loss - - -def compute_dpo_accuracy(token_level_scores, acc, response_mask, n_samples): - dpo_acc = [] - for start_id in range(0, token_level_scores.shape[0], n_samples): - cur_scores = ( - token_level_scores[start_id : start_id + n_samples] * response_mask[start_id : start_id + n_samples] - ).sum(dim=1) - - def get_upper_triangle(tensor_x): - diff_matrix = tensor_x.unsqueeze(1) - tensor_x.unsqueeze(0) - upper_tri_indices = torch.triu(torch.ones_like(diff_matrix).bool(), diagonal=1) - return diff_matrix[upper_tri_indices] - - cur_acc_diff = get_upper_triangle(acc[start_id : start_id + n_samples]) # in range [-1,1] - cur_score_diff = get_upper_triangle(cur_scores) # in R - cur_score_prediction = (cur_score_diff > 0).float() # in [0,1] - if cur_acc_diff.abs().sum() == 0: - cur_acc = torch.zeros_like(cur_score_prediction[0]) + 0.5 - else: - cur_acc = ( - ((cur_score_diff > 0) == (cur_acc_diff > 0)).float() * cur_acc_diff.abs() - ).sum() / cur_acc_diff.abs().sum() - - dpo_acc.append(cur_acc.unsqueeze(0)) - - return torch.cat(dpo_acc, dim=0).mean() - - -def compute_dpo_abs_accuracy(token_level_scores, acc, response_mask, n_samples): - return (torch.sign((token_level_scores * response_mask).sum(dim=-1)) == torch.sign(acc * 2 - 1)).float().mean() diff --git a/recipe/prime/prime_dp_rm.py b/recipe/prime/prime_dp_rm.py deleted file mode 100644 index d15d772f0..000000000 --- a/recipe/prime/prime_dp_rm.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Implement a multiprocess PPOCritic -""" - -import itertools - -import torch -import torch.distributed -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input -from torch import nn, optim -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -import verl.utils.torch_functional as verl_F -from verl import DataProto -from verl.utils.device import get_device_name -from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches -from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs - -from .prime_core_algos import compute_ce_dpo_loss_rm, compute_detach_dpo_loss_rm - -__all__ = ["DataParallelPRIMERewardModel"] - - -class DataParallelPRIMERewardModel: - def __init__(self, config, reward_module: nn.Module, ref_module: nn.Module, reward_optimizer: optim.Optimizer): - self.config = config - self.reward_module = reward_module - self.ref_module = ref_module - self.reward_optimizer = reward_optimizer - self.use_remove_padding = self.config.model.get("use_remove_padding", False) - print(f"Reward model use_remove_padding={self.use_remove_padding}") - self.use_fused_kernels = self.config.model.get("use_fused_kernels", False) - print(f"Reward model use_fused_kernels={self.use_fused_kernels}") - - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - - def _forward_micro_batch(self, micro_batch, prompt_length): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - - num_actions = micro_batch["input_ids"].shape[-1] - prompt_length - max_positions = micro_batch["attention_mask"][:, prompt_length:].sum(-1) - - if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # for compute the log_prob - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) - - # pad and slice the inputs if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size - ) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, None, self.ulysses_sequence_parallel_size - ) - - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) - output = self.reward_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False, - return_dict=self.use_fused_kernels, - ) - - if self.use_fused_kernels: - rm_log_labels = output.log_probs.squeeze(0) # (total_nnz,) - rm_log_labels = rm_log_labels.to(torch.float32) - - else: - rm_output_logits = output.logits.squeeze(0) - rm_log_labels = verl_F.logprobs_from_logits( - logits=rm_output_logits, - labels=input_ids_rmpad_rolled, - ) - - if self.ulysses_sequence_parallel_size > 1: - rm_log_labels = gather_outputs_and_unpad( - rm_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - rm_log_labels = pad_input( - hidden_states=rm_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen - ).squeeze(-1)[:, -num_actions - 1 : -1] - - else: - output = self.reward_module( - input_ids=micro_batch["input_ids"], - attention_mask=micro_batch["attention_mask"], - position_ids=micro_batch["position_ids"], - use_cache=False, - return_dict=self.use_fused_kernels, - ) - - if self.use_fused_kernels: - rm_log_labels = output.log_probs[:, :-1] # (bsz, seq_length) - rm_log_labels = rm_log_labels.to(torch.float32) - - else: - rm_output_logits = output.logits - rm_log_prob = torch.nn.functional.log_softmax( - rm_output_logits[:, :-1, :], dim=-1 - ) # (batch_size, seq_length, vocab_size) - rm_log_labels = rm_log_prob.gather(dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1)).squeeze( - -1 - ) # (batch, seq_length) - - if self.ref_module is not None: - # do not have to pad again - with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): - if self.ulysses_sequence_parallel_size > 1 and self.use_remove_padding: - ref_output = self.ref_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - use_cache=False, - ) - - if self.use_fused_kernels: - ref_log_labels = ref_output.log_probs.squeeze(0) # (total_nnz,) - ref_log_labels = ref_log_labels.to(torch.float32) - - else: - ref_output_logits = ref_output.logits.squeeze(0) - ref_log_labels = verl_F.logprobs_from_logits( - logits=ref_output_logits, labels=input_ids_rmpad_rolled - ) - - ref_log_labels = gather_outputs_and_unpad( - ref_log_labels, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - ref_log_labels = pad_input( - hidden_states=ref_log_labels.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen - ).squeeze(-1)[:, -num_actions - 1 : -1] - else: - ref_output = self.ref_module( - input_ids=micro_batch["input_ids"], - attention_mask=micro_batch["attention_mask"], - position_ids=micro_batch["position_ids"], - use_cache=False, - ) - - if self.use_fused_kernels: - ref_log_labels = ref_output.log_probs[:, :-1] # (batch_size, seq_length) - ref_log_labels = ref_log_labels.to(torch.float32) - - else: - ref_output_logits = ref_output.logits - ref_log_prob = torch.nn.functional.log_softmax( - ref_output_logits[:, :-1, :], dim=-1 - ) # (batch_size, seq_length, vocab_size) - ref_log_labels = ref_log_prob.gather( - dim=-1, index=micro_batch["input_ids"][:, 1:].unsqueeze(-1) - ).squeeze(-1) # (batch, seq_length) - - else: - ref_log_labels = micro_batch["old_log_probs"] - - ref_log_labels.to(rm_log_labels.dtype) - q = rm_log_labels[:, -num_actions:] - ref_log_labels[:, -num_actions:] # this is actually diff of q - - # trim unnecessary logprobs here - for i in range(micro_batch["input_ids"].shape[0]): - q[i, max_positions[i] :] = 0 - - # reward computation does not need gradient. only q needs - with torch.no_grad(): - # generalized estimation of r should go before the reward filling. r means process reward for policy - # model, or the advantage of reward model. - lam = self.config.get("lambda", 0.0) - beta = self.config.model.get("beta_train", 0.05) - if lam == 0.0: - r = q * beta - else: - # reward coefficient takes no effect here - acc = micro_batch["acc"] - q_ = q * beta - r = torch.zeros_like(q) - lastgaelam = 0 - # change the last token and mask out all paddings to make this process easier if we rely on - # outcome reward to calculate V - for i in range(q.shape[0]): - if self.config.prime_use_gt: - q_[i, max_positions[i] - 1] = acc[i] - q_[i, : max_positions[i] - 1].sum() - q_[i, max_positions[i] :] = 0 - - for t in reversed(range(num_actions)): - delta = q_[:, t] - lastgaelam = delta + lam * lastgaelam - r[:, t] = lastgaelam - - token_level_score = torch.zeros_like(q) - - if self.config.prime_granularity == "token": - for i in range(micro_batch["input_ids"].shape[0]): - token_level_score[i, : max_positions[i] - 1] = r[i, : max_positions[i] - 1] - elif self.config.prime_granularity == "whole": - for i in range(micro_batch["input_ids"].shape[0]): - token_level_score[i, max_positions[i] - 1] = r[i, : max_positions[i]] - else: - raise NotImplementedError - - return token_level_score, q - - def _optimizer_step(self): - assert self.config.model.optim.grad_clip is not None - - if isinstance(self.reward_module, FSDP): - grad_norm = self.reward_module.clip_grad_norm_(self.config.model.optim.grad_clip) - else: - grad_norm = torch.nn.utils.clip_grad_norm_( - self.reward_module.parameters(), max_norm=self.config.model.optim.grad_clip - ) - self.reward_optimizer.step() - return grad_norm - - def prime_norm(self, token_level_scores): - if self.config.prime_norm == "batch_norm": - reverse_cumsum = torch.cumsum(token_level_scores.flip(dims=[1]), dim=-1).flip(dims=[1]) - token_level_scores = token_level_scores / (reverse_cumsum.abs().max() + 1e-6) - return token_level_scores - - def compute_rm_score(self, data: DataProto): - self.reward_module.eval() - self.ref_module.eval() - micro_batch_size = data.meta_info["micro_batch_size"] - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "acc"] - batch = data.select(batch_keys=select_keys).batch - use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - prompt_length = data.batch["input_ids"].shape[-1] - data.batch["responses"].shape[-1] - - if use_dynamic_bsz: - # split using dynamic bsz - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) - else: - micro_batches = batch.split(micro_batch_size) - - rm_scores_lst = [] - q_lst = [] - for micro_batch in micro_batches: - with torch.no_grad(): - rm_score, q = self._forward_micro_batch(micro_batch, prompt_length) - rm_scores_lst.append(rm_score) - q_lst.append(q) - rm_scores = torch.concat(rm_scores_lst, dim=0) - q = torch.concat(q_lst, dim=0) - - rm_scores = self.prime_norm(rm_scores) - - if use_dynamic_bsz: - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == rm_scores.size(0), f"{len(indices)} vs. {rm_scores.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - rm_scores = rm_scores[revert_indices] - - return ( - rm_scores, - q.detach(), - { - "reward_model/reward": rm_scores.sum(dim=-1).mean().item(), - "reward_model/raw_reward": q.sum(dim=-1).mean().item(), - }, - ) - - def update_rm(self, data: DataProto): - # make sure we are in training mode - self.reward_module.train() - metrics = {} - - beta = self.config.model.get("beta_train", 0.05) - - select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "acc", "prompts"] - - for key in ["Q_bc", "acc_bc"]: - if key in data.batch.keys(): - select_keys.append(key) - - batch = data.select(batch_keys=select_keys).batch - # Split to make minibatch iterator for updating the actor - # See PPO paper for details. https://arxiv.org/abs/1707.06347 - dataloader = batch.split(self.config.mini_batch_size) - - rm_scores_lst = [] - q_lst = [] - - for batch_idx, data in enumerate(dataloader): - # split batch into micro_batches - mini_batch = data - if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) - else: - micro_batches = mini_batch.split(self.config.micro_batch_size_per_gpu) - self.gradient_accumulation = self.config.mini_batch_size // self.config.micro_batch_size_per_gpu - - self.reward_optimizer.zero_grad() - - for data in micro_batches: - data = data.to(get_device_name()) - attention_mask = data["attention_mask"] - acc = data["acc"] - - prompt_ids = data["prompts"] - prompt_length = prompt_ids.shape[-1] - - response_mask = attention_mask[:, prompt_length:] - - rm_score, q = self._forward_micro_batch(data, prompt_length) - - rm_scores_lst.append(rm_score) - q_lst.append(q.detach()) - - if self.config.model.loss_type == "ce": - dpo_loss = compute_ce_dpo_loss_rm(q, acc, response_mask=response_mask, beta=beta) - elif self.config.model.loss_type == "dpo": - # the implementation of dpo is actually detached, which means we have to know the average - # value of w/l reward before the update. - dpo_loss = compute_detach_dpo_loss_rm( - q, acc, Q_bc=data["Q_bc"], acc_bc=data["acc_bc"], response_mask=response_mask, beta=beta - ) - elif self.config.model.loss_type == "bon_acc": - # change the original distribution of each sample to BoN distribution, then update reward model - dpo_loss = compute_detach_dpo_loss_rm( - q, - acc, - Q_bc=data["Q_bc"], - acc_bc=data["acc_bc"], - response_mask=response_mask, - beta=beta, - bon_mode="bon_acc", - ) - elif self.config.model.loss_type == "bon_rm": - dpo_loss = compute_detach_dpo_loss_rm( - q, - acc, - Q_bc=data["Q_bc"], - acc_bc=data["acc_bc"], - response_mask=response_mask, - beta=beta, - bon_mode="bon_rm", - ) - else: - raise NotImplementedError - - data = {"reward_model/dpo_loss": dpo_loss.detach().item()} - - if self.config.use_dynamic_bsz: - # relative to the dynamic bsz - loss = dpo_loss * (len(data) / self.config.ppo_mini_batch_size) - else: - loss = dpo_loss / self.gradient_accumulation - - loss.backward() - - append_to_dict(metrics, data) - - grad_norm = self._optimizer_step() - data = {"reward_model/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) - self.reward_optimizer.zero_grad() - - rm_scores = torch.cat(rm_scores_lst, dim=0) - q = torch.concat(q_lst, dim=0) - - rm_scores = self.prime_norm(rm_scores) - - metrics.update( - { - "reward_model/reward": rm_scores.sum(dim=-1).mean().item(), - "reward_model/raw_reward": q.sum(dim=-1).mean().item(), - } - ) - - return rm_scores, metrics diff --git a/recipe/prime/prime_fsdp_workers.py b/recipe/prime/prime_fsdp_workers.py deleted file mode 100644 index e35340464..000000000 --- a/recipe/prime/prime_fsdp_workers.py +++ /dev/null @@ -1,381 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import os -import warnings - -import torch -import torch.distributed -from torch.distributed.device_mesh import init_device_mesh - -from verl import DataProto -from verl.models.transformers.monkey_patch import apply_monkey_patch -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils import hf_tokenizer -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.device import get_device_id, get_device_name, get_nccl_backend -from verl.utils.flops_counter import FlopsCounter -from verl.utils.fs import copy_local_path_from_hdfs -from verl.utils.fsdp_utils import ( - get_fsdp_wrap_policy, - get_init_weight_context_manager, - init_fn, - load_fsdp_model_to_gpu, - load_fsdp_optimizer, - offload_fsdp_model_to_cpu, - offload_fsdp_optimizer, -) -from verl.utils.import_utils import import_external_libs -from verl.utils.profiler import log_gpu_memory_usage -from verl.workers.fsdp_workers import create_device_mesh, get_sharding_strategy -from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager - -from .prime_core_algos import compute_dpo_abs_accuracy, compute_dpo_accuracy - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class PRIMERewardModelWorker(Worker): - def __init__(self, config): - super().__init__() - import torch.distributed - - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend=get_nccl_backend()) - self.config = config - - # build device mesh for Ulysses Sequence Parallel - world_size = torch.distributed.get_world_size() - - fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) - - self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - dp = world_size // self.ulysses_sequence_parallel_size - if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh( - get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] - ) - - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - - # set FSDP offload params - self._is_offload_param = self.config.model.fsdp_config.param_offload - self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload - - # normalize config - self.config.mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size - if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size - self.config.micro_batch_size_per_gpu = self.config.micro_batch_size - assert self.config.mini_batch_size % self.config.micro_batch_size_per_gpu == 0 - - def _build_reward_ref_model_optimizer(self, config): - # the following line is necessary - from torch import optim - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp import MixedPrecision - - from verl.utils.model import print_model_size - from verl.utils.torch_dtypes import PrecisionType - - local_path = copy_local_path_from_hdfs(config.model.path) - - tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) - self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) - - from omegaconf import OmegaConf - - override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - override_config_kwargs = { - "bos_token_id": self.tokenizer.bos_token_id, - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - } - override_config_kwargs.update(override_config) - if self.rank == 0: - print(f"Reward model overriding config {override_config_kwargs}") - - torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") - torch_dtype = PrecisionType.to_dtype(torch_dtype) - - from transformers import AutoConfig, AutoModelForCausalLM - - trust_remote_code = False - reward_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) - reward_model_config.num_labels = 1 - - init_context = get_init_weight_context_manager(use_meta_tensor=not reward_model_config.tie_word_embeddings) - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - reward_model_config.classifier_dropout = 0.0 - reward_model_config.hidden_dropout = "0" - reward_module = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=reward_model_config, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - - fused_kernel_options = config.model.get("fused_kernel_options", None) - fused_kernels_backend = ( - fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None - ) - - apply_monkey_patch( - model=reward_module, - ulysses_sp_size=self.ulysses_sequence_parallel_size, - use_remove_padding=config.model.get("use_remove_padding", False), - use_fused_kernels=config.model.get("use_fused_kernels", False), - fused_kernels_backend=fused_kernels_backend, - ) - - # some parameters may not in torch_dtype - reward_module.to(torch_dtype) - - if config.model.get("enable_gradient_checkpointing", False): - reward_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - if self.rank == 0: - print_model_size(reward_module) - - self.reward_model_config = reward_model_config - - fsdp_config = self.config.model.fsdp_config - mixed_precision_config = fsdp_config.get("mixed_precision", None) - if mixed_precision_config is not None: - param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) - buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) - else: - param_dtype = torch.bfloat16 - reduce_dtype = torch.float32 - buffer_dtype = torch.float32 - - mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) - - auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config.wrap_policy) - - log_gpu_memory_usage("Before reward model FSDP", logger=None) - - fsdp_mesh = self.device_mesh - sharding_strategy = get_sharding_strategy(fsdp_mesh) - - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - reward_model_config.classifier_dropout = 0.0 - reward_model_config.hidden_dropout = "0" - ref_module = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=copy_local_path_from_hdfs(config.model.ref_path), - torch_dtype=torch_dtype, - config=reward_model_config, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - - # some parameters may not in torch_dtype - ref_module.to(torch_dtype) - - reward_module = FSDP( - reward_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=get_device_id(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=False, - device_mesh=self.device_mesh, - cpu_offload=None, - ) - - log_gpu_memory_usage("After reward FSDP", logger=None) - - ref_module = FSDP( - ref_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=get_device_id(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=False, - device_mesh=self.device_mesh, - cpu_offload=None, - ) - - reward_optimizer = optim.AdamW( - reward_module.parameters(), - lr=config.model.optim.lr, - betas=config.model.optim.get("betas", (0.9, 0.999)), - weight_decay=config.model.optim.get("weight_decay", 1e-2), - ) - - total_steps = config.model.optim.get("total_training_steps", 0) - num_warmup_steps = int(config.model.optim.get("lr_warmup_steps", -1)) - if num_warmup_steps < 0: - num_warmup_steps_ratio = config.model.optim.get("lr_warmup_steps_ratio", 0.0) - num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - - from verl.utils.torch_functional import get_constant_schedule_with_warmup - - reward_lr_scheduler = get_constant_schedule_with_warmup( - optimizer=reward_optimizer, num_warmup_steps=num_warmup_steps - ) - - return reward_module, ref_module, reward_optimizer, reward_lr_scheduler - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - - from .prime_dp_rm import DataParallelPRIMERewardModel - - self.reward_module, self.ref_module, self.reward_optimizer, self.reward_lr_scheduler = ( - self._build_reward_ref_model_optimizer(config=self.config) - ) - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.reward_module) - offload_fsdp_model_to_cpu(self.ref_module) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.reward_optimizer) - - self.rm = DataParallelPRIMERewardModel( - config=self.config, - reward_module=self.reward_module, - ref_module=self.ref_module, - reward_optimizer=self.reward_optimizer, - ) - - self.flops_counter = FlopsCounter(self.reward_model_config) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.reward_module, - optimizer=self.reward_optimizer, - lr_scheduler=self.reward_lr_scheduler, - tokenizer=self.tokenizer, - ) - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_rm_score(self, data: DataProto): - data = data.to(get_device_name()) - - if self._is_offload_param: - load_fsdp_model_to_gpu(self.reward_module) - load_fsdp_model_to_gpu(self.ref_module) - micro_batch_size = self.config.micro_batch_size_per_gpu - data.meta_info["micro_batch_size"] = micro_batch_size - data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz - # perform forward computation - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) - rm_scores, q, metrics = self.rm.compute_rm_score(data=data) - - prompt_length = data.batch["prompts"].shape[-1] - response_mask = data.batch["attention_mask"][:, prompt_length:] - acc = data.batch["acc"] - - dpo_acc = compute_dpo_accuracy(rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"]) - dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) - - metrics["reward_model/dpo_acc"] = dpo_acc.detach().item() - metrics["reward_model/dpo_acc_abs"] = dpo_acc_abs.detach().item() - - output = DataProto.from_dict(tensors={"rm_scores": rm_scores, "q": q}, meta_info={"metrics": metrics}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - - output = output.to("cpu") - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.reward_module) - offload_fsdp_model_to_cpu(self.ref_module) - return output - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def update_rm(self, data: DataProto): - data = data.to(get_device_name()) - if self._is_offload_param: - load_fsdp_model_to_gpu(self.ref_module) - load_fsdp_model_to_gpu(self.reward_module) - if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=get_device_id()) - - # perform forward computation - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) - - rm_scores, metrics = self.rm.update_rm(data=data) - - self.reward_lr_scheduler.step() - lr = self.reward_lr_scheduler.get_last_lr()[0] - metrics["rm/lr"] = lr - - prompt_length = data.batch["prompts"].shape[-1] - response_mask = data.batch["attention_mask"][:, prompt_length:] - acc = data.batch["acc"] - - dpo_acc_before = compute_dpo_accuracy( - rm_scores, acc, response_mask=response_mask, n_samples=data.meta_info["n"] - ) - dpo_acc_abs = compute_dpo_abs_accuracy(rm_scores, acc, response_mask, n_samples=data.meta_info["n"]) - - metrics["reward_model/dpo_acc_before"] = dpo_acc_before.detach().item() - metrics["reward_model/dpo_acc_abs_before"] = dpo_acc_abs.detach().item() - - output = DataProto.from_dict(tensors={"rm_scores": rm_scores}, meta_info={"metrics": metrics}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.reward_module) - offload_fsdp_model_to_cpu(self.ref_module) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.reward_optimizer) - output = output.to("cpu") - return output - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): - import torch - - if self._is_offload_param: - load_fsdp_model_to_gpu(self.reward_module) - - self.checkpoint_manager.save_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep - ) - - torch.distributed.barrier() - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.reward_module) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def load_checkpoint(self, local_path, del_local_after_load=True): - import torch - - if self._is_offload_param: - load_fsdp_model_to_gpu(self.reward_module) - - self.checkpoint_manager.load_checkpoint(local_path=local_path, del_local_after_load=del_local_after_load) - - torch.distributed.barrier() - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.reward_module) diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py deleted file mode 100644 index a5ad96431..000000000 --- a/recipe/prime/prime_ray_trainer.py +++ /dev/null @@ -1,575 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -FSDP PPO Trainer with Ray-based single controller. -This trainer supports model-agonistic model initialization with huggingface -""" - -import os -import statistics -import uuid -from copy import deepcopy -from pprint import pprint - -import numpy as np -import torch -from omegaconf import OmegaConf, open_dict - -from verl import DataProto -from verl.single_controller.ray import RayWorkerGroup -from verl.trainer.ppo.core_algos import agg_loss -from verl.trainer.ppo.metric_utils import _compute_response_info -from verl.trainer.ppo.ray_trainer import RayPPOTrainer, ResourcePoolManager, Role, WorkerType -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path -from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn -from verl.utils.metric import reduce_metrics -from verl.utils.profiler.performance import simple_timer - -from . import prime_core_algos - - -def compute_advantage(data: DataProto, adv_estimator, config): - if adv_estimator == "rloo": - responses = data.batch["responses"] - response_length = responses.size(-1) - attention_mask = data.batch["attention_mask"] - response_mask = attention_mask[:, -response_length:] - advantages, returns = prime_core_algos.compute_rloo_advantage_return( - data, response_mask, config.actor_rollout_ref.rollout.n, config - ) - data.batch["advantages"] = advantages - data.batch["returns"] = returns - else: - raise NotImplementedError - return data - - -def compute_data_metrics(batch, use_critic=True): - advantages = batch.batch["advantages"] - returns = batch.batch["returns"] - - max_response_length = batch.batch["responses"].shape[-1] - - prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool() - response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool() - - max_prompt_length = prompt_mask.size(-1) - - response_info = _compute_response_info(batch) - prompt_length = response_info["prompt_length"] - response_length = response_info["response_length"] - - valid_adv = torch.masked_select(advantages, response_mask) - valid_returns = torch.masked_select(returns, response_mask) - - if use_critic: - values = batch.batch["values"] - valid_values = torch.masked_select(values, response_mask) - return_diff_var = torch.var(valid_returns - valid_values) - return_var = torch.var(valid_returns) - - metrics = { - # adv - "critic/advantages/mean": torch.mean(valid_adv).detach().item(), - "critic/advantages/max": torch.max(valid_adv).detach().item(), - "critic/advantages/min": torch.min(valid_adv).detach().item(), - # returns - "critic/returns/mean": torch.mean(valid_returns).detach().item(), - "critic/returns/max": torch.max(valid_returns).detach().item(), - "critic/returns/min": torch.min(valid_returns).detach().item(), - **( - { - # values - "critic/values/mean": torch.mean(valid_values).detach().item(), - "critic/values/max": torch.max(valid_values).detach().item(), - "critic/values/min": torch.min(valid_values).detach().item(), - # vf explained var - "critic/vf_explained_var": (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), - } - if use_critic - else {} - ), - # response length - "response_length/mean": torch.mean(response_length).detach().item(), - "response_length/max": torch.max(response_length).detach().item(), - "response_length/min": torch.min(response_length).detach().item(), - "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) - .detach() - .item(), - # prompt length - "prompt_length/mean": torch.mean(prompt_length).detach().item(), - "prompt_length/max": torch.max(prompt_length).detach().item(), - "prompt_length/min": torch.min(prompt_length).detach().item(), - "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), - } - return metrics - - -def compute_response_mask(data: DataProto): - responses = data.batch["responses"] - response_length = responses.size(1) - attention_mask = data.batch["attention_mask"] - return attention_mask[:, -response_length:] - - -def compute_timing_metrics(batch, timing_raw): - response_info = _compute_response_info(batch) - num_prompt_tokens = torch.sum(response_info["prompt_length"]).item() - num_response_tokens = torch.sum(response_info["response_length"]).item() - num_overall_tokens = num_prompt_tokens + num_response_tokens - - num_tokens_of_section = { - "gen": num_response_tokens, - **{name: num_overall_tokens for name in ["ref", "values", "adv", "update_critic", "update_actor"]}, - } - - return { - **{f"timing_s/{name}": value for name, value in timing_raw.items()}, - **{ - f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] - for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) - }, - } - - -class RayPRIMETrainer(RayPPOTrainer): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - # TODO: support each role have individual ray_worker_group_cls, - # i.e., support different backend of different role - def __init__( - self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - reward_fn=None, - val_reward_fn=None, - device_name="cuda", - ): - # assert get_torch_device().is_available(), 'cuda must be available on driver' - - super().__init__( - config, - tokenizer, - role_worker_mapping, - resource_pool_manager, - ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, - device_name=device_name, - ) - - self.use_critic = False - - def _validate_config(self): - super()._validate_config() - # TODO: Additional config checks can be added here - - def _create_dataloader(self, *args, **kwargs): - from torch.utils.data import DataLoader, RandomSampler, SequentialSampler - - # TODO: we have to make sure the batch size is divisible by the dp size - self.train_dataset = RLHFDataset( - data_files=self.config.data.train_files, tokenizer=self.tokenizer, config=self.config.data - ) - # use sampler for better ckpt resume - if self.config.data.shuffle: - train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(self.config.data.get("seed", 1)) - sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator) - else: - sampler = SequentialSampler(data_source=self.train_dataset) - - self.train_dataloader = DataLoader( - dataset=self.train_dataset, - batch_size=int(self.config.data.train_batch_size * self.config.data.oversample_factor), - drop_last=True, - collate_fn=collate_fn, - sampler=sampler, - ) - - self.val_dataset = RLHFDataset( - data_files=self.config.data.val_files, tokenizer=self.tokenizer, config=self.config.data - ) - self.val_dataloader = DataLoader( - dataset=self.val_dataset, - batch_size=len(self.val_dataset), - shuffle=True, - drop_last=True, - collate_fn=collate_fn, - ) - - assert len(self.train_dataloader) >= 1 - assert len(self.val_dataloader) >= 1 - - print(f"Size of train dataloader: {len(self.train_dataloader)}") - print(f"Size of val dataloader: {len(self.val_dataloader)}") - - # inject total_training_steps to actor/critic optim_config. This is hacky. - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - print(f"Total training steps: {self.total_training_steps}") - - OmegaConf.set_struct(self.config, True) - with open_dict(self.config): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps - self.config.critic.optim.total_training_steps = total_training_steps - - def _save_checkpoint(self): - # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join( - self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" - ) - print(f"local_global_step_folder: {local_global_step_folder}") - actor_local_path = os.path.join(local_global_step_folder, "actor") - - actor_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") - ) - self.actor_rollout_wg.save_checkpoint( - actor_local_path, - actor_remote_path, - self.global_steps, - ) - - if self.use_rm: - reward_local_path = os.path.join(local_global_step_folder, "reward") - reward_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "reward") - ) - self.rm_wg.save_checkpoint( - reward_local_path, - reward_remote_path, - self.global_steps, - ) - - # save dataloader - dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") - import dill - - torch.save(self.train_dataloader, dataloader_local_path, pickle_module=dill) - - # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join( - self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" - ) - with open(local_latest_checkpointed_iteration, "w") as f: - f.write(str(self.global_steps)) - - def _load_checkpoint(self): - if self.config.trainer.resume_mode == "disable": - return 0 - - # load from hdfs - if self.config.trainer.default_hdfs_dir is not None: - NotImplementedError("load from hdfs is not implemented yet") - else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path - if not os.path.isabs(checkpoint_folder): - working_dir = os.getcwd() - checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest - - # find global_step_folder - if self.config.trainer.resume_mode == "auto": - if global_step_folder is None: - print("Training from scratch") - return 0 - else: - if self.config.trainer.resume_mode == "resume_path": - assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, ( - "resume ckpt must specify the global_steps" - ) - global_step_folder = self.config.trainer.resume_from_path - if not os.path.isabs(global_step_folder): - working_dir = os.getcwd() - global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") - # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) - - print(f"Setting global step to {self.global_steps}") - print(f"Resuming from {global_step_folder}") - - actor_path = os.path.join(global_step_folder, "actor") - reward_path = os.path.join(global_step_folder, "reward") - # load actor - self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - # load rm - if self.use_rm: - self.rm_wg.load_checkpoint(reward_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load) - - # load dataloader, - # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, "data.pt") - self.train_dataloader = torch.load(dataloader_local_path) - if isinstance(self.train_dataloader.dataset, RLHFDataset): - self.train_dataloader.dataset.resume_dataset_state() - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the worker group through RPC to - construct the PPO dataflow. The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - assert val_metrics, f"{val_metrics=}" - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - # we start from step 1 - self.global_steps += 1 - - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - timing_raw = {} - - batch: DataProto = DataProto.from_single_dict(batch_dict) - - # pop those keys for generation - gen_batch = batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - - with simple_timer("step", timing_raw): - # generate a batch - with simple_timer("gen", timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - timing_raw.update(gen_batch_output.meta_info["timing"]) - gen_batch_output.meta_info.pop("timing", None) - - if self.config.algorithm.adv_estimator == "remax": - with simple_timer("gen_max", timing_raw): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - - batch = batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object - ) - # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - batch = batch.union(gen_batch_output) - - # Balance the number of valid tokens across DP ranks. - # NOTE: This usually changes the order of data in the `batch`, - # which won't affect the advantage calculation (since it's based on uid), - # but might affect the loss calculation (due to the change of mini-batching). - # TODO: Decouple the DP balancing and mini-batching. - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - # verify - with simple_timer("verify", timing_raw): - scores = self.reward_fn.verify(batch) - metrics["acc"] = statistics.mean(scores) - - # filter the batch. 1/oversample_factor samples will be kept. - # If there is a filter, prompts passing it will be prioritized. - - batch = self.filter_and_downsample(scores, batch) - batch.meta_info["n"] = self.config.actor_rollout_ref.rollout.n - n_samples = self.config.actor_rollout_ref.rollout.n - - # recompute old_log_probs - with simple_timer("old_log_prob", timing_raw): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch["entropys"] - response_masks = compute_response_mask(batch) - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} - metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) - - if self.use_reference_policy: - # compute reference log_prob - with simple_timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - with simple_timer("adv", timing_raw): - if self.use_rm: - update_style = self.config.reward_model.model.get("update", "none") - if update_style == "none": # only run forward - reward_output = self.rm_wg.compute_rm_score(batch) - elif update_style == "after": # update and directly return the reward - reward_output = self.rm_wg.update_rm(batch) - elif update_style == "before": # update reward model, and then run forward - reward_output = self.rm_wg.update_rm(batch) - if "metrics" in reward_output.meta_info.keys(): - reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"]) - metrics.update(reward_output_metrics) - - reward_output = self.rm_wg.compute_rm_score(batch) - elif ( - update_style == "reverse" - ): # run forward to calculate statistics, then update reward model - reward_output = self.rm_wg.compute_rm_score(batch) - # broadcast q and acc tensor to each result - bc_td = DataProto.from_dict( - tensors={ - "Q_bc": reward_output.batch["q"] - .sum(dim=-1) - .view(-1, n_samples) - .unsqueeze(1) - .expand(-1, n_samples, -1) - .reshape(-1, n_samples), - "acc_bc": batch.batch["acc"] - .view(-1, n_samples) - .unsqueeze(1) - .expand(-1, n_samples, -1) - .reshape(-1, n_samples), - } - ) - batch = batch.union(bc_td) - reward_output = self.rm_wg.update_rm(batch) - else: - raise NotImplementedError - batch = batch.union(reward_output) - if "metrics" in reward_output.meta_info.keys(): - reward_output_metrics = reduce_metrics(reward_output.meta_info["metrics"]) - metrics.update(reward_output_metrics) - - # compute advantages, executed on the driver process - batch = compute_advantage( - batch, adv_estimator=self.config.algorithm.adv_estimator, config=self.config - ) - - # update actor - with simple_timer("update_actor", timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and self.global_steps % self.config.trainer.test_freq == 0 - ): - with simple_timer("testing", timing_raw): - val_metrics: dict = self._validate() - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and self.global_steps % self.config.trainer.save_freq == 0: - with simple_timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - self.global_steps += 1 - - if self.global_steps >= self.total_training_steps: - # perform validation after training - if self.val_reward_fn is not None: - val_metrics = self._validate() - pprint(f"Final validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if ( - self.config.trainer.save_freq > 0 - and (self.global_steps - 1) % self.config.trainer.save_freq != 0 - ): - with simple_timer("save_checkpoint", timing_raw): - self._save_checkpoint() - return - - def filter_and_downsample(self, scores, batch: DataProto): - """ - downsample the batch according to oversample_factor - samples passing the filters will be prioritized - """ - n_samples = int(self.config.actor_rollout_ref.rollout.n) - reward_matrix = torch.tensor(scores).reshape(-1, n_samples) - - filter_mask = torch.ones((reward_matrix.shape[0]), dtype=torch.bool) - - if self.config.data.filter_accuracy: - acc_tensor = torch.mean(reward_matrix, dim=-1) - filter_mask[ - (acc_tensor > self.config.data.accuracy_upper_bound) - | (acc_tensor < self.config.data.accuracy_lower_bound) - ] = False - - if self.config.data.filter_truncate: - length_matrix = ( - batch.batch["attention_mask"][:, -batch.batch["responses"].shape[-1] :] - .sum(dim=-1) - .reshape(-1, n_samples) - ) - length_tensor = torch.max(length_matrix, dim=-1)[0] - filter_mask[length_tensor >= self.config.data.max_response_length - 1] = False - - reorder_index = torch.argsort(filter_mask, descending=True) - reorder_index = (reorder_index.unsqueeze(-1) * n_samples + torch.arange(0, n_samples).unsqueeze(0)).view(-1) - batch.reorder( - reorder_index[: int(len(batch) // self.config.data.oversample_factor)] - ) # this operation is inplace - - return batch diff --git a/recipe/prime/run_prime_qwen.sh b/recipe/prime/run_prime_qwen.sh deleted file mode 100644 index 145f31b7b..000000000 --- a/recipe/prime/run_prime_qwen.sh +++ /dev/null @@ -1,64 +0,0 @@ -set -x - - -gsm8k_train_path=$HOME/data/gsm8k/train.parquet -gsm8k_test_path=$HOME/data/gsm8k/test.parquet - -# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data -math_train_path=$HOME/data/math/train.parquet -math_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path', '$math_train_path']" -test_files="['$gsm8k_test_path', '$math_test_path']" - -model_path=PRIME-RL/Eurus-2-7B-SFT -# model_path=Qwen/Qwen2.5-0.5B-Instruct - -python3 -m recipe.prime.main_prime \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=64 \ - data.val_batch_size=6312 \ - data.max_prompt_length=1024 \ - data.max_response_length=3072 \ - data.filter_overlong_prompts=True \ - data.filter_accuracy=True \ - data.accuracy_lower_bound=0.2 \ - data.accuracy_upper_bound=0.8 \ - data.oversample_factor=4 \ - actor_rollout_ref.model.path=$model_path \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=4 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - algorithm.adv_estimator=rloo \ - algorithm.use_kl_in_reward=True \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - reward_model.model.path=$model_path \ - reward_model.micro_batch_size_per_gpu=1 \ - reward_model.model.update=before \ - reward_model.model.beta_train=0.05 \ - reward_model.model.optim.lr=1e-6 \ - reward_model.model.optim.grad_clip=10.0 \ - reward_model.model.input_tokenizer=null \ - reward_model.mini_batch_size=64 \ - trainer.val_before_train=False \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='prime_example' \ - trainer.experiment_name='Eurus-2-7B-SFT-gsm8k' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=64 \ - trainer.test_freq=64 \ - trainer.total_epochs=15 $@ diff --git a/recipe/prime/run_prime_qwen_code.sh b/recipe/prime/run_prime_qwen_code.sh deleted file mode 100644 index e179c0858..000000000 --- a/recipe/prime/run_prime_qwen_code.sh +++ /dev/null @@ -1,61 +0,0 @@ -set -x - - -# download from https://huggingface.co/datasets/PRIME-RL/Eurus-2-RL-Data -code_train_path=$HOME/data/code/train.parquet -code_test_path=$HOME/data/code/test.parquet - -train_files="['$code_train_path']" -test_files="['$code_test_path']" - -model_path=PRIME-RL/Eurus-2-7B-SFT -# model_path=Qwen/Qwen2.5-0.5B-Instruct - -python3 -m recipe.prime.main_prime \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=64 \ - data.val_batch_size=6312 \ - data.max_prompt_length=1024 \ - data.max_response_length=3072 \ - data.filter_overlong_prompts=True \ - data.filter_accuracy=True \ - data.accuracy_lower_bound=0.2 \ - data.accuracy_upper_bound=0.8 \ - data.oversample_factor=4 \ - actor_rollout_ref.model.path=$model_path \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=4 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - algorithm.adv_estimator=rloo \ - algorithm.use_kl_in_reward=True \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - reward_model.model.path=$model_path \ - reward_model.micro_batch_size_per_gpu=1 \ - reward_model.model.update=before \ - reward_model.model.beta_train=0.05 \ - reward_model.model.optim.lr=1e-6 \ - reward_model.model.optim.grad_clip=10.0 \ - reward_model.model.input_tokenizer=null \ - reward_model.mini_batch_size=64 \ - trainer.val_before_train=False \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='prime_example' \ - trainer.experiment_name='Eurus-2-7B-SFT-code' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=64 \ - trainer.test_freq=64 \ - trainer.total_epochs=15 $@ diff --git a/recipe/r1/README.md b/recipe/r1/README.md deleted file mode 100644 index ddd23bcc3..000000000 --- a/recipe/r1/README.md +++ /dev/null @@ -1,26 +0,0 @@ -# DeepSeek R1 Reproduction - -This recipe is under development, if you are interested, checkout the TODO list and join this project! https://github.com/volcengine/verl/issues/708 - -## Reproducing Evaluation - -Eval Results of DS-R1-Distill-Qwen2.5-1.5B (k=8) - -Dataset | Test Results | Reported --- | -- | -- -GPQA Diamond | 35.3 | 33.8 -LiveCodeBench | 16.9 | 16.9 -AIME 2024 | 30.4 | 28.9 -CNMO 2024 (en) | 45.1 | - -CNMO 2024 (zh) | 41.0 | - - ---- - -Eval Results (DS-R1) - -Dataset | Test Results (k=1) | Test Results (k=4) | Reported --- | -- | -- | -- -GPQA Diamond | 67.7 | 69.6 | 71.5 -LiveCodeBench | 64.7 | 63.1 | 65.9 -AIME 2024 | 86.7 | 79.2 | 79.8 -CNMO 2024 | 75.0 | 78.5 | 78.8 diff --git a/recipe/r1/__init__.py b/recipe/r1/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/recipe/r1/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/recipe/r1/config/evaluation.yaml b/recipe/r1/config/evaluation.yaml deleted file mode 100644 index 1bd9f4e93..000000000 --- a/recipe/r1/config/evaluation.yaml +++ /dev/null @@ -1,13 +0,0 @@ -data: - path: /tmp/math_Qwen2-7B-Instruct.parquet - prompt_key: prompt - response_key: responses - data_source_key: data_source - reward_model_key: reward_model - -custom_reward_function: - path: null - name: compute_score - -ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. \ No newline at end of file diff --git a/recipe/r1/data_process.py b/recipe/r1/data_process.py deleted file mode 100644 index fb41c8143..000000000 --- a/recipe/r1/data_process.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Preprocess the dataset to parquet format -""" - -import argparse -import os -from functools import partial - -from datasets import concatenate_datasets, load_dataset - -from verl.utils.hdfs_io import copy, makedirs - - -def example_map_fn(example, idx, process_fn, data_source, ability, split): - question, solution = process_fn(example) - data = { - "data_source": data_source, - "prompt": [{"role": "user", "content": question}], - "ability": ability, - "reward_model": {"style": "rule", "ground_truth": solution}, - "extra_info": {"split": split, "index": idx}, - } - return data - - -def build_aime2024_dataset(): - def process_aime2024(example): - return example["Problem"], str(example["Answer"]) - - data_source = "Maxwell-Jia/AIME_2024" - print(f"Loading the {data_source} dataset from huggingface...", flush=True) - dataset = load_dataset(data_source, split="train") - map_fn = partial( - example_map_fn, process_fn=process_aime2024, data_source=data_source, ability="English", split="test" - ) - dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) - return dataset - - -def build_gpqa_dimond_dataset(): - import random - - GPQA_QUERY_TEMPLATE = ( - "Answer the following multiple choice question. The last line of your response should be of the following " - "format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before " - "answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}" - ) - - def process_gpqa_diamond(example): - choices = [example["Incorrect Answer 1"], example["Incorrect Answer 2"], example["Incorrect Answer 3"]] - random.shuffle(choices) - gold_index = random.randint(0, 3) - choices.insert(gold_index, example["Correct Answer"]) - query_prompt = GPQA_QUERY_TEMPLATE.format( - A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=example["Question"] - ) - gold_choice = "ABCD"[gold_index] - return query_prompt, gold_choice - - data_source = "Idavidrein/gpqa" - print(f"Loading the {data_source} dataset from huggingface...", flush=True) - - dataset = load_dataset(data_source, "gpqa_diamond", split="train") - map_fn = partial( - example_map_fn, process_fn=process_gpqa_diamond, data_source=data_source, ability="Math", split="test" - ) - dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names) - return dataset - - -def build_cnmo2024_dataset(): - def process_cnmo2024(example): - return example["question"], example["answer"] - - data_source = "opencompass/LiveMathBench" - print(f"Loading the {data_source} dataset from huggingface...", flush=True) - - dataset_en = load_dataset(data_source, "v202412_CNMO_en", split="test") - map_fn_en = partial( - example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_en", ability="Math", split="test" - ) - dataset_en = dataset_en.map(map_fn_en, with_indices=True, remove_columns=dataset_en.column_names) - - dataset_zh = load_dataset(data_source, "v202412_CNMO_cn", split="test") - map_fn_zh = partial( - example_map_fn, process_fn=process_cnmo2024, data_source="opencompass/cnmo2024_zh", ability="Math", split="test" - ) - dataset_zh = dataset_zh.map(map_fn_zh, with_indices=True, remove_columns=dataset_zh.column_names) - - dataset = concatenate_datasets([dataset_en, dataset_zh]) - return dataset - - -def build_livecodebench_dataset(): - import base64 - import json - import pickle - import zlib - - def process_livecodebench(example): - # Construct Query Prompt - # From https://github.com/LiveCodeBench/LiveCodeBench/blob/998c52d394b836f15fff3b9a29866191108ff81b/lcb_runner/prompts/code_generation.py#L140 - query_prompt = ( - f"You will be given a question (problem specification) and will generate a correct Python program " - f"that matches the specification and passes all tests.\n\nQuestion: {example['question_content']}\n\n" - ) - if example["starter_code"]: - query_prompt += ( - f"You will use the following starter code to write the solution to the problem and enclose your " - f"code within delimiters.\n```python\n{example['starter_code']}\n```" - ) - else: - query_prompt += ( - "Read the inputs from stdin solve the problem and write the answer to stdout (do not directly test " - "on the sample inputs). Enclose your code within delimiters as follows. Ensure that when the python " - "program runs, it reads the inputs, runs the algorithm and writes output to STDOUT." - "```python\n# YOUR CODE HERE\n```" - ) - - # Construct test cases - public_test_cases = json.loads(example["public_test_cases"]) - try: - private_test_cases = json.loads(example["private_test_cases"]) - except Exception as e: - print(f"Error loading private test cases: {e}") - private_test_cases = json.loads( - pickle.loads(zlib.decompress(base64.b64decode(example["private_test_cases"].encode("utf-8")))) - ) - full_test_cases = public_test_cases + private_test_cases - - metadata = json.loads(example["metadata"]) - test_cases = { - "inputs": [t["input"] for t in full_test_cases], - "outputs": [t["output"] for t in full_test_cases], - "fn_name": metadata.get("func_name", None), - } - text_cases_compressed = base64.b64encode(zlib.compress(pickle.dumps(json.dumps(test_cases)))).decode("utf-8") - return query_prompt, text_cases_compressed - - data_source = "livecodebench/code_generation_lite" - print(f"Loading the {data_source} dataset from huggingface...", flush=True) - dataset = load_dataset(data_source, split="test") - # R1 Evaluation use LiveCodeBench 24.08-25.01 - dataset = dataset.filter(lambda line: "2024-08-00T00:00:00" <= line["contest_date"] < "2025-01-00T00:00:00") - map_fn = partial( - example_map_fn, process_fn=process_livecodebench, data_source=data_source, ability="Code", split="test" - ) - - dataset = dataset.map(map_fn, with_indices=True, remove_columns=dataset.column_names, num_proc=8) - return dataset - - -TASK2DATA = { - "aime2024": build_aime2024_dataset, - "gpqa_diamond": build_gpqa_dimond_dataset, - "cnmo2024": build_cnmo2024_dataset, - "livecodebench": build_livecodebench_dataset, -} -SUPPORTED_TASKS = TASK2DATA.keys() - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--local_dir", default="~/data/r1") - parser.add_argument("--hdfs_dir", default=None) - parser.add_argument("--tasks", default="all") - - args = parser.parse_args() - - if args.tasks.lower() == "all": - args.tasks = SUPPORTED_TASKS - else: - args.tasks = [task.strip() for task in args.tasks.split(",") if task.strip()] - for task in args.tasks: - if task not in SUPPORTED_TASKS: - raise NotImplementedError(f"{task} has not been supported.") - - datasets = [] - for task in args.tasks: - datasets.append(TASK2DATA[task]()) - test_dataset = concatenate_datasets(datasets) - - local_dir = args.local_dir - hdfs_dir = args.hdfs_dir - - test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) - - if hdfs_dir is not None: - makedirs(hdfs_dir) - - copy(src=local_dir, dst=hdfs_dir) diff --git a/recipe/r1/main_eval.py b/recipe/r1/main_eval.py deleted file mode 100644 index b9c03791b..000000000 --- a/recipe/r1/main_eval.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Offline evaluate the performance of a generated file using reward model and ground truth verifier. -The input is a parquet file that contains N generated sequences and (optional) the ground truth. - -""" - -from collections import defaultdict - -import hydra -import numpy as np -import pandas as pd -import ray -from tqdm import tqdm - -from verl.trainer.ppo.reward import get_custom_reward_fn -from verl.utils.fs import copy_to_local - - -@ray.remote -def process_item(config, data_source, response_lst, reward_data): - reward_fn = get_custom_reward_fn(config) - ground_truth = reward_data["ground_truth"] - score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] - return data_source, np.mean(score_lst) - - -@hydra.main(config_path="config", config_name="evaluation", version_base=None) -def main(config): - local_path = copy_to_local(config.data.path) - dataset = pd.read_parquet(local_path) - responses = dataset[config.data.response_key] - data_sources = dataset[config.data.data_source_key] - reward_model_data = dataset[config.data.reward_model_key] - - total = len(dataset) - - # Initialize Ray - if not ray.is_initialized(): - ray.init(num_cpus=config.ray_init.num_cpus) - - # evaluate test_score based on data source - data_source_reward = defaultdict(list) - - # Create remote tasks - remote_tasks = [ - process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) - ] - - # Process results as they come in - with tqdm(total=total) as pbar: - while len(remote_tasks) > 0: - # Use ray.wait to get completed tasks - done_ids, remote_tasks = ray.wait(remote_tasks) - for result_id in done_ids: - data_source, score = ray.get(result_id) - data_source_reward[data_source].append(score) - pbar.update(1) - - metric_dict = {} - for data_source, rewards in data_source_reward.items(): - metric_dict[f"test_score/{data_source}"] = np.mean(rewards) - - print(metric_dict) - - -if __name__ == "__main__": - main() diff --git a/recipe/r1/reward_score.py b/recipe/r1/reward_score.py deleted file mode 100644 index 2010665aa..000000000 --- a/recipe/r1/reward_score.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def reward_func(data_source, solution_str, ground_truth, extra_info=None): - if data_source in ["Maxwell-Jia/AIME_2024", "opencompass/cnmo2024_en", "opencompass/cnmo2024_zh"]: - from recipe.r1.tasks import math - - return math.compute_score(solution_str, ground_truth) - elif data_source == "Idavidrein/gpqa": - from recipe.r1.tasks import gpqa - - return gpqa.compute_score(solution_str, ground_truth) - elif data_source in ["livecodebench/code_generation_lite", "livecodebench/code_generation"]: - from recipe.r1.tasks import livecodebench - - return livecodebench.compute_score(solution_str, ground_truth) - else: - raise NotImplementedError diff --git a/recipe/r1/run_r1_distill_qwen.sh b/recipe/r1/run_r1_distill_qwen.sh deleted file mode 100644 index a1aa9edcc..000000000 --- a/recipe/r1/run_r1_distill_qwen.sh +++ /dev/null @@ -1,33 +0,0 @@ -MODEL_PATH=Qwen/DeepSeek-R1-Distill-Qwen-1.5B -DATA_PATH=/workspace/datasets/r1_bench - -# Eval Data Process -python3 -m recipe.r1.data_process \ - --local_dir $DATA_PATH \ - --tasks all - -# Generation -python3 -m verl.trainer.main_generation \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node=8 \ - data.path=$DATA_PATH/test.parquet \ - data.prompt_key=prompt \ - data.batch_size=1024 \ - data.n_samples=8 \ - data.output_path=$DATA_PATH/test-output-8.parquet \ - model.path=$MODEL_PATH \ - rollout.temperature=0.6 \ - rollout.top_p=0.95 \ - rollout.prompt_length=1024 \ - rollout.response_length=32768 \ - rollout.tensor_model_parallel_size=1 \ - rollout.gpu_memory_utilization=0.9 \ - rollout.max_num_batched_tokens=65536 - -# Evaluation -python3 -m recipe.r1.main_eval \ - data.path=$DATA_PATH/test-output-8.parquet \ - data.prompt_key=prompt \ - data.response_key=responses \ - custom_reward_function.path=recipe/r1/reward_score.py \ - custom_reward_function.name=reward_func diff --git a/recipe/r1/tasks/__init__.py b/recipe/r1/tasks/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/recipe/r1/tasks/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/recipe/r1/tasks/gpqa.py b/recipe/r1/tasks/gpqa.py deleted file mode 100644 index 65b37e916..000000000 --- a/recipe/r1/tasks/gpqa.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re - -# Extraction Template from https://github.com/openai/simple-evals/blob/90e3e821cabba2aeb6be651dcb662b253df04225/common.py#L25 -ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?" - - -def compute_score(solution_str, ground_truth) -> float: - match = re.search(ANSWER_PATTERN_MULTICHOICE, solution_str) - extracted_answer = match.group(1) if match else None - score = 1.0 if extracted_answer == ground_truth else 0.0 - return score diff --git a/recipe/r1/tasks/livecodebench.py b/recipe/r1/tasks/livecodebench.py deleted file mode 100644 index f0cbab681..000000000 --- a/recipe/r1/tasks/livecodebench.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -import json -import multiprocessing -import pickle -import zlib - -# Reuse `run_test` for convenience -from verl.utils.reward_score.prime_code.testing_util import run_test - - -def _temp_run(in_outs, generation, debug, result, metadata_list, timeout): - res, metadata = run_test(in_outs, test=generation, debug=debug, timeout=timeout) - result.append(res) - metadata_list.append(metadata) - - -def check_correctness(in_outs, generation, timeout, debug=True): - """Check correctness of code generation with a global timeout. - The global timeout is to catch some extreme/rare cases not handled by the timeouts - inside `run_test`""" - - manager = multiprocessing.Manager() - result = manager.list() - metadata_list = manager.list() - p = multiprocessing.Process( - target=_temp_run, - args=(in_outs, generation, debug, result, metadata_list, timeout), - ) - p.start() - p.join(timeout=(timeout + 1) * len(in_outs["inputs"]) + 5) - if p.is_alive(): - p.kill() - if not result: - # consider that all tests failed - result = [[-1 for i in range(len(in_outs["inputs"]))]] - if debug: - print("global timeout") - return result[0], metadata_list[0] - - -def compute_score(completion, test_cases): - solution = completion.split("```python")[-1].split("```")[0] - - # extract test cases - try: - in_outs = json.loads(test_cases) - except Exception as e: - print(f"Error loading test cases: {e}") - in_outs = json.loads(pickle.loads(zlib.decompress(base64.b64decode(test_cases.encode("utf-8"))))) - - success = False - try: - res, metadata = check_correctness(in_outs=in_outs, generation=solution, timeout=6, debug=False) - success = all(map(lambda x: x is True, res)) - except Exception: - pass - - return success diff --git a/recipe/r1/tasks/math.py b/recipe/r1/tasks/math.py deleted file mode 100644 index 5ecde5494..000000000 --- a/recipe/r1/tasks/math.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import contextlib - -try: - from math_verify.metric import math_metric - from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig -except ImportError: - print("To use Math-Verify, please install it first by running `pip install math-verify`.") - - -def compute_score(model_output: str, ground_truth: str) -> bool: - verify_func = math_metric( - gold_extraction_target=(LatexExtractionConfig(),), - pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), - ) - ret_score = 0.0 - - # Wrap the ground truth in \boxed{} format for verification - ground_truth_boxed = "\\boxed{" + ground_truth + "}" - with contextlib.suppress(Exception): - ret_score, _ = verify_func([ground_truth_boxed], [model_output]) - - return ret_score diff --git a/recipe/retool/retool.py b/recipe/retool/retool.py deleted file mode 100644 index b4d6028ff..000000000 --- a/recipe/retool/retool.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re -from typing import Any - -import datasets - -from verl.tools.base_tool import OpenAIFunctionToolSchema -from verl.tools.sandbox_fusion_tools import SandboxFusionTool -from verl.utils.dataset import RLHFDataset -from verl.utils.reward_score import math_dapo -from verl.utils.rollout_trace import rollout_trace_op - -logger = logging.getLogger(__name__) - - -class CustomSandboxFusionTool(SandboxFusionTool): - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - super().__init__(config, tool_schema) - self.code_pattern = re.compile(r"```python(.*?)```", re.DOTALL) - - @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - code = parameters["code"] - matches = self.code_pattern.findall(code) - if matches: - code = matches[0].strip() - - # NOTE: some script may not explicitly print result, we need to add a print statement to the end of the script - lines = code.split("\n") - for i, line in reversed(list(enumerate(lines))): - if line == "": - continue - if not lines[i].startswith("print"): - lines[i] = f"print({line})" - break - code = "\n".join(lines) - - timeout = parameters.get("timeout", self.default_timeout) - language = parameters.get("language", self.default_language) - if not isinstance(code, str): - code = str(code) - - result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) - # sandbox has no score or metrics, use Nones - return result, None, None - - -answer_format = """\nThe answer format must be: \\boxed{'The final answer goes here.'}""" - - -class CustomRLHFDataset(RLHFDataset): - """Custom dataset class to process Maxwell-Jia/AIME_2024, yentinglin/aime_2025 datasets.""" - - def _read_files_and_tokenize(self): - dataframes = [] - for parquet_file in self.data_files: - # read parquet files and cache - dataframe = datasets.load_dataset(parquet_file)["train"] - data_source = "/".join(parquet_file.split("/")[-2:]) - if data_source in ["Maxwell-Jia/AIME_2024", "yentinglin/aime_2025"]: - dataframe = dataframe.map( - self.map_fn, fn_kwargs={"data_source": data_source}, remove_columns=dataframe.column_names - ) - else: - dataframe = dataframe.map(self.map_fn2, num_proc=16) - dataframes.append(dataframe) - self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) - - print(f"dataset len: {len(self.dataframe)}") - - def map_fn(self, row: dict, *, data_source: str = None): - if data_source == "Maxwell-Jia/AIME_2024": - problem, answer = row["Problem"], row["Answer"] - elif data_source == "yentinglin/aime_2025": - problem, answer = row["problem"], row["answer"] - - prompt = problem + answer_format - data = { - "data_source": data_source.split("/")[1].lower(), # aime_2024, aime_2025 - "prompt": [{"role": "user", "content": prompt}], - "ability": "MATH", - "reward_model": {"ground_truth": str(answer)}, - "agent_name": "tool_agent", - } - return data - - def map_fn2(self, row: dict): - content = row["prompt"][0]["content"] - row["prompt"][0]["content"] = content + answer_format - row["agent_name"] = "tool_agent" - return row - - -def compute_score(data_source, solution_str, ground_truth, extra_info): - # use \\boxed{...} answer - result = math_dapo.compute_score(solution_str, ground_truth, strict_box_verify=True) - - # encourage model to call tools - num_turns = extra_info["num_turns"] - if result["score"] < 0: - tool_call_reward = (num_turns - 2) / 2 * 0.1 - result["score"] = min(0, result["score"] + tool_call_reward) - - if result["pred"] is None: - result["pred"] = "" - - return result diff --git a/recipe/retool/retool_multi_turn_sft_preprocess.py b/recipe/retool/retool_multi_turn_sft_preprocess.py deleted file mode 100644 index 201ee6892..000000000 --- a/recipe/retool/retool_multi_turn_sft_preprocess.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Preprocess the Retool dataset to parquet format -""" - -import argparse -import os - -import datasets - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--local_dir", default="~/data/retool_multiturn") - parser.add_argument("--hdfs_dir", default=None) - parser.add_argument("--train_ratio", default=0.9, type=float) - parser.add_argument("--seed", default=42, type=int) - args = parser.parse_args() - - data_source = "swordfaith/ReTool-SFT-multi-turn" - dataset = datasets.load_dataset(data_source, "default") - - train_dataset = dataset["train"] - shuffled_train_dataset = train_dataset.shuffle(seed=args.seed) - split_idx = int(len(shuffled_train_dataset) * args.train_ratio) - train_dataset = shuffled_train_dataset.select(range(split_idx)) - test_dataset = shuffled_train_dataset.select(range(split_idx, len(shuffled_train_dataset))) - - # add a row to each data item that represents a unique id - def make_map_fn(split): - def process_fn(example, idx): - messages = example.pop("messages") - tools = example.pop("tools") - data = { - "data_source": data_source, - "messages": messages, - "tools": tools, - "enable_thinking": False, - "extra_info": { - "split": split, - "index": idx, - }, - } - return data - - return process_fn - - train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) - test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) - - # Create output directory - local_dir = os.path.expanduser(args.local_dir) - os.makedirs(local_dir, exist_ok=True) - - # Save to parquet files - local_dir = args.local_dir - hdfs_dir = args.hdfs_dir - - train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) - test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) - - # Handle HDFS if specified - if hdfs_dir is not None: - try: - from verl.utils.hdfs_io import copy, makedirs - - makedirs(hdfs_dir) - copy(src=local_dir, dst=hdfs_dir) - except ImportError: - print("Warning: HDFS support not available. Skipping HDFS copy.") - - # Print statistics - print(f"Train dataset size: {len(train_dataset)}") - print(f"Test dataset size: {len(test_dataset)}") - print(f"Data saved to {local_dir}") - - -if __name__ == "__main__": - main() diff --git a/recipe/retool/retool_sft_preprocess.py b/recipe/retool/retool_sft_preprocess.py deleted file mode 100644 index 0a46c1522..000000000 --- a/recipe/retool/retool_sft_preprocess.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Convert JoeYing/ReTool-SFT to standard multi-turn tool calling messages. -""" - -import json -import re -from typing import Any - -import datasets -from omegaconf import OmegaConf - -code_pattern = re.compile(r"```python(.*?)```", re.DOTALL) - - -def extract_code_message(content: str) -> tuple[dict[str, Any], str]: - start, stop = "", "" - i = content.find(start) - if i == -1: - return None, content - j = content.find(stop) - assert j > i - - code = content[i + len(start) : j] - matches = code_pattern.findall(code) - if matches: - code = matches[0].strip() - - message = { - "role": "assistant", - "content": content[:i].strip(), - "tool_calls": [ - { - "type": "function", - "function": { - "name": "code_interpreter", - "arguments": {"code": code}, - }, - }, - ], - } - return message, content[j + len(stop) :] - - -def extract_answer_message(content: str) -> tuple[dict[str, Any], str]: - start, stop = "", "" - i = content.find(start) - if i == -1: - return None, content - j = content.find(stop) - assert j > i - - answer = content[:i] + content[i + len(start) : j] - message = { - "role": "assistant", - "content": answer.strip(), - } - return message, content[j + len(stop) :] - - -def extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]: - start, stop = "", "" - i = content.find(start) - if i == -1: - return None, content - j = content.find(stop) - assert j > i - - interpreter = content[i + len(start) : j] - message = { - "role": "tool", - "content": interpreter.strip(), - } - return message, content[j + len(stop) :] - - -def process(row: dict, *, tools: str): - messages = [] - - # extract problem - content = row["messages"][0]["content"] - start = "*user question:*" - i = content.find(start) - assert i != -1 - prompt = content[i + len(start) :].replace("", "").replace("", "").strip() - messages.append( - { - "role": "user", - "content": prompt, - } - ) - - # extract multi turns - content = row["messages"][1]["content"] - role = "assistant" - while len(content) > 0: - if role == "assistant": - message, content = extract_code_message(content) - if message is None: - message, content = extract_answer_message(content) - assert message is not None - messages.append(message) - role = "tool" - else: - message, content = extract_interpreter_message(content) - assert message is not None - messages.append(message) - role = "assistant" - - return {"messages": messages, "tools": tools} - - -if __name__ == "__main__": - tools_config_file = "recipe/retool/sandbox_fusion_tool_config.yaml" - tools_config = OmegaConf.load(tools_config_file) - tool_schema = OmegaConf.to_container(tools_config["tools"][0]["tool_schema"]) - tools = json.dumps([tool_schema]) - - data = datasets.load_dataset("JoeYing/ReTool-SFT")["train"] - data = data.map(process, fn_kwargs={"tools": tools}) - data.to_parquet("wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet") diff --git a/recipe/retool/run_qwen2-32b_sft.sh b/recipe/retool/run_qwen2-32b_sft.sh deleted file mode 100644 index 137698138..000000000 --- a/recipe/retool/run_qwen2-32b_sft.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -set -x - -# set dist args -nproc_per_node=${ARNOLD_WORKER_GPU} -if [ ! -z "$SINGLE" ] && [ "$SINGLE" != "0" ]; then - echo "[single node alone] SINGLE=$SINGLE" - MASTER_NODE_ID=${ARNOLD_ID} - nnodes=1 - node_rank=0 -else - MASTER_NODE_ID=0 - nnodes=${ARNOLD_WORKER_NUM} - node_rank=${ARNOLD_ID} -fi -master_addr="METIS_WORKER_${MASTER_NODE_ID}_HOST" -master_addr=${!master_addr} -master_port="METIS_WORKER_${MASTER_NODE_ID}_PORT" -master_port=${!master_port} -ports=(`echo $master_port | tr ',' ' '`) -master_port=${ports[0]} -echo "[nproc_per_node: ${nproc_per_node}]" -echo "[nnodes: ${nnodes}]" -echo "[node_rank: ${node_rank}]" -echo "[master_addr: ${master_addr}]" -echo "[master_port: ${master_port}]" - -experiment_name=multiturn-sft-qwen-2.5-32b-instruct -HDFS_ROOT=${HDFS_ROOT:-$PWD} -DATA_ROOT=${DATA_ROOT:-$PWD} - -TRAIN_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet -EVAL_DATA=$DATA_ROOT/dataset/wuxibin/ReTool-SFT/data/train-00000-of-00001.parquet -MODEL_PATH=$HDFS_ROOT/model/Qwen2.5-32B-Instruct -SAVE_PATH=$DATA_ROOT/checkpoint/$experiment_name - -torchrun --nnodes=$ARNOLD_WORKER_NUM \ - --nproc_per_node=$ARNOLD_WORKER_GPU \ - --master-addr=$master_addr \ - --master-port=$master_port \ - --node-rank=$node_rank \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$TRAIN_DATA \ - data.val_files=$EVAL_DATA \ - data.max_length=16384 \ - data.train_batch_size=32 \ - data.multiturn.enable=true \ - data.multiturn.messages_key=messages \ - data.multiturn.tools_key=tools \ - data.micro_batch_size_per_gpu=4 \ - model.partial_pretrain=$MODEL_PATH \ - model.strategy=fsdp \ - trainer.default_local_dir=$SAVE_PATH \ - trainer.project_name=wuxibin-multiturn-sft \ - trainer.experiment_name=$experiment_name \ - trainer.logger='["console","wandb"]' \ - trainer.total_epochs=6 \ - ulysses_sequence_parallel_size=4 \ - use_remove_padding=true \ No newline at end of file diff --git a/recipe/retool/run_qwen2.5_32b_sp8.sh b/recipe/retool/run_qwen2.5_32b_sp8.sh deleted file mode 100644 index 4d6daa1dd..000000000 --- a/recipe/retool/run_qwen2.5_32b_sp8.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -set -x - -export PYTHONUNBUFFERED=1 -export RUST_BACKTRACE=1 -export HYDRA_FULL_ERROR=1 - -ulimit -n 65535 - -EXPERIMENT_NAME=retool-multiturn-sft-qwen2.5-32b-sp8 - -torchrun --nnodes=1 --nproc_per_node=8 \ - -m verl.trainer.fsdp_sft_trainer \ - data.max_length=16384 \ - data.train_batch_size=128 \ - data.micro_batch_size_per_gpu=4 \ - data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \ - data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \ - data.multiturn.enable=true \ - data.multiturn.messages_key=messages \ - data.multiturn.tools_key=tools \ - model.partial_pretrain=$HOME/models/Qwen/Qwen2.5-32B-Instruct \ - model.trust_remote_code=true \ - model.fsdp_config.cpu_offload=true \ - model.fsdp_config.offload_params=true \ - optim.lr=1e-6 \ - trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \ - trainer.project_name=retool-multiturn-sft \ - trainer.experiment_name=$EXPERIMENT_NAME \ - trainer.logger='["console","wandb"]' \ - trainer.total_epochs=12 $@ \ - ulysses_sequence_parallel_size=8 \ - use_remove_padding=true diff --git a/recipe/retool/run_qwen2.5_7b_sp4.sh b/recipe/retool/run_qwen2.5_7b_sp4.sh deleted file mode 100644 index 9265dbbac..000000000 --- a/recipe/retool/run_qwen2.5_7b_sp4.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -set -x - -export PYTHONUNBUFFERED=1 -export RUST_BACKTRACE=1 -export HYDRA_FULL_ERROR=1 - -ulimit -n 65535 - -EXPERIMENT_NAME=retool-multiturn-sft-qwen2.5-7b-sp4 - -torchrun --nnodes=1 --nproc_per_node=8 \ - -m verl.trainer.fsdp_sft_trainer \ - data.max_length=16384 \ - data.train_batch_size=128 \ - data.micro_batch_size_per_gpu=16 \ - data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \ - data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \ - data.multiturn.enable=true \ - data.multiturn.messages_key=messages \ - data.multiturn.tools_key=tools \ - model.partial_pretrain=$HOME/models/Qwen/Qwen2.5-7B-Instruct \ - model.trust_remote_code=true \ - model.fsdp_config.cpu_offload=false \ - model.fsdp_config.offload_params=false \ - optim.lr=1e-6 \ - trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \ - trainer.project_name=retool-multiturn-sft \ - trainer.experiment_name=$EXPERIMENT_NAME \ - trainer.logger='["console","wandb"]' \ - trainer.total_epochs=8 $@ \ - ulysses_sequence_parallel_size=4 \ - use_remove_padding=true diff --git a/recipe/retool/run_qwen3_4b_sp4.sh b/recipe/retool/run_qwen3_4b_sp4.sh deleted file mode 100644 index 23ec986e3..000000000 --- a/recipe/retool/run_qwen3_4b_sp4.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -set -x - -export PYTHONUNBUFFERED=1 -export RUST_BACKTRACE=1 -export HYDRA_FULL_ERROR=1 - -ulimit -n 65535 - -EXPERIMENT_NAME=retool-multiturn-sft-qwen3-4b-sp4 - -torchrun --nnodes=1 --nproc_per_node=8 \ - -m verl.trainer.fsdp_sft_trainer \ - data.max_length=16384 \ - data.train_batch_size=128 \ - data.micro_batch_size_per_gpu=16 \ - data.train_files=$HOME/data/retool_multi_turn_sft_preprocessed/train.parquet \ - data.val_files=$HOME/data/retool_multi_turn_sft_preprocessed/test.parquet \ - data.multiturn.enable=true \ - data.multiturn.messages_key=messages \ - data.multiturn.tools_key=tools \ - model.partial_pretrain=$HOME/models/Qwen/Qwen3-4B \ - model.trust_remote_code=true \ - optim.lr=1e-6 \ - trainer.default_local_dir=$HOME/checkpoints/retool-multiturn-sft/$EXPERIMENT_NAME \ - trainer.project_name=retool-multiturn-sft \ - trainer.experiment_name=$EXPERIMENT_NAME \ - trainer.logger='["console","wandb"]' \ - trainer.total_epochs=12 $@ \ - ulysses_sequence_parallel_size=4 \ - use_remove_padding=true diff --git a/recipe/retool/sandbox_fusion_tool_config.yaml b/recipe/retool/sandbox_fusion_tool_config.yaml deleted file mode 100644 index 203457155..000000000 --- a/recipe/retool/sandbox_fusion_tool_config.yaml +++ /dev/null @@ -1,24 +0,0 @@ -tools: - - class_name: "recipe.retool.retool.CustomSandboxFusionTool" - config: - sandbox_fusion_url: "https://***.apigateway-cn-beijing.volceapi.com/run_code" - num_workers: 128 - enable_global_rate_limit: true - rate_limit: 128 - default_timeout: 30 - default_language: "python" - memory_limit_mb: 1024 - type: native - - tool_schema: - type: "function" - function: - name: "code_interpreter" - description: "A tool for executing code." - parameters: - type: "object" - properties: - code: - type: "string" - description: "The code to execute." - required: ["code"] diff --git a/recipe/spin/README.md b/recipe/spin/README.md deleted file mode 100644 index 0fc35ba7b..000000000 --- a/recipe/spin/README.md +++ /dev/null @@ -1,179 +0,0 @@ -# SPIN: Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models - -This repository hosts a `verl` recipe inspired by the paper **"Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models"** (SPIN). SPIN is a language model finetuning algorithm that enables iterative self-improvement through a self-play mechanism inspired by game theory. - -**Core Idea:** Models learn by playing against themselves, reducing reliance on external preference datasets or stronger teacher models: - -1. **Synthetic Data Generation:** The current model generates responses, creating its own training data from previous iterations. -2. **Two-Player Game Setup:** A game involving two players acted by a single LLM. -3. **Iterative Training:** The model progressively improves by refining its policy, with each iteration's model becoming the opponent for the next iteration. - -Paper Authors: [Zixiang Chen](https://github.com/uclaml/SPIN)\*, [Yihe Deng](https://github.com/uclaml/SPIN)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) - -[[Webpage](https://uclaml.github.io/SPIN/)] [[Huggingface](https://huggingface.co/papers/2401.01335)] [[Paper](https://arxiv.org/abs/2401.01335)] [[Original Implementation](https://github.com/uclaml/SPIN)] - -verl Implementation Authors: [Chendong Wang](https://cdwang96.github.io/), [Chenyang Zhao](https://github.com/zhaochenyang20) - ---- - -## Key Function (compute_online_dpo_loss) and Related works -SPIN (Chen et al., 2024) proposes an iterative self-play mechanism to fine-tune language models. In each iteration, SPIN's training objective, when using a logistic loss function, is equivalent to Direct Preference Optimization (DPO) loss (Rafailov et al., 2023). - -This `verl` recipe realizes SPIN's core concept by using DPO loss iteratively (Xu et al., 2023; Xiong et al., 2023; Snorkel AI, 2024). This means that in each iteration, we fine-tune the LLM using DPO loss for preference optimization. Notably, Xu et al. (2023) explored iterative preference optimization with pairwise cringe loss, while Xiong et al. (2023) discussed how to bridge theory and practice for RLHF under KL constraints using iterative training. The concept of iterative preference learning was also explored in online DPO (Guo et al., 2024), which focuses on direct alignment from online AI feedback. In online DPO, preference data is dynamically updated during training, allowing the model to learn from its own generated data. - -Specifically, we developed the **`compute_online_dpo_loss`** function and built this SPIN recipe on top of it. By incorporating online preference generation, this approach enables continuously refining language models without relying on fixed external preference datasets. - -**Reference Papers:** -* [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/abs/2401.01335) (Chen et al., 2024) -* [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) (Rafailov et al., 2023) -* [Somethings are more cringe than others: Preference optimization with the pairwise cringe loss](https://arxiv.org/abs/2312.16682) (Xu et al., 2023) -* [Iterative preference learning from human feedback: Bridging theory and practice for rlhf under kl-constraint](https://arxiv.org/abs/2312.11456) (Xiong et al., 2023) -* [Snorkel-Mistral-PairRM-DPO](https://huggingface.co/snorkelai/Snorkel-Mistral-PairRM-DPO) (Snorkel AI, 2024) -* [Direct language model alignment from online ai feedback](https://arxiv.org/abs/2402.04792) (Guo et al., 2024) - - -## Our Online DPO Implementation - -Our `compute_online_dpo_loss` function adapts `verl`'s existing PPO infrastructure (based on `verl` v0.3.0.post1) for this iterative online DPO. Key aspects of our implementation include: - -* **No Critic:** Unlike PPO, we omit the value function critic. -* **Dynamic Reference Model:** An explicit reference policy (`ref_policy_wg`) is used for DPO loss. This reference model's weights can be periodically updated from the actor (`ref_update_freq`), providing a dynamic baseline. -* **Online Preference Generation:** The `compute_onlineDPO_pref` function (in `core_algos.py`) dynamically creates chosen/rejected pairs based on a reward source (e.g., rule-based ranking for math problems). -* **DPO Loss Integration:** We replace PPO's policy loss with our `compute_online_dpo_loss` (in `core_algos.py`) within the actor update (`dp_actor.py`), directly optimizing the policy using the generated preferences. -* **Iterative Training Orchestration:** The `SpinTrainer` (in `spin_trainer.py`) manages the entire self-play loop: generation, preference labeling, optional reference model updates, and policy updates, enabling continuous self-improvement aligned with SPIN's principles. - ---- -## Algorithm - -This recipe implements an Online algorithm adapted to the `verl` Reinforcement Learning framework, which provides an alternative to PPO for fine-tuning language models. - -**Online Loop:** Instead of maximizing a scalar reward signal in PPO, this approach directly optimizes the policy model to align with preference data generated *online* during training: - -1. **Generation:** The current model generates multiple responses for each prompt in a batch. -2. **Preference Labeling:** A function evaluates these generated responses to determine which one is preferred (chosen) and which is dispreferred (rejected). This can be done using a reward function or implicit ranking based on specific rules. (In this recipe, we use rule-based ranking on the math problem). -3. **Update:** This preference tuple (`prompt`, `chosen_response`, `rejected_response`) is used to update the actor model using `compute_online_dpo_loss`, comparing against a reference model. - -**Connection with SPIN:** -Instead of only using a fixed target data distribution, the online generation loop in step 2 will dynamically change the target data distribution by using a certain Preference Labeling method (rule-based ranking on the math problem by selecting the better one in this recipe). This explores the direction mentioned in SPIN's paper Section 7 about "dynamically changing target data distribution" to potentially elevate LLM performance beyond the fixed human-annotated data ceiling. - ---- - -## Reproduce the Experiment (Example Setup) - -The following steps outline how to set up the environment and run the SPIN recipe, based on the provided test log using GSM8K and Qwen2.5-3B-Instruct. - -1. **Setup Environment (Example using Docker):** - ```bash - # Start a container with GPU access and shared memory - docker run -it --name spin_test --gpus all \ - --shm-size=32g \ - --ipc=host \ - -v /path/to/host/.cache:/root/.cache \ - -e HF_TOKEN= \ - lmsysorg/sglang:latest \ - /bin/bash - - # Inside the container or on your host machine: - # Ensure /tmp is writable - mkdir -p /tmp - chmod 1777 /tmp - - # Install Python 3.10 (if not present) and venv - sudo apt update - sudo apt install -y python3.10 python3.10-venv tmux - python3 -m ensurepip --upgrade - - # Create and activate a virtual environment - python3 -m venv ~/.python/spin_env - source ~/.python/spin_env/bin/activate - - # Install uv (fast package installer) - python3 -m pip install uv - ``` - -2. **Install verl and Dependencies:** - ```bash - # Clone the verl repository and checkout the spin branch - cd ~ - git clone git@github.com:volcengine/verl.git && cd verl - - # Install flash-attn (handle potential build issues) - python3 -m uv pip install wheel packaging - python3 -m uv pip install flash-attn --no-build-isolation --no-deps - - # Install verl with sglang extras - python3 -m uv pip install -e ".[sglang]" - ``` - *Note: If `flash-attn` installation fails, try the manual steps again or consult its documentation.* - -3. **Login & Download Data/Model:** - ```bash - # Login to Weights & Biases (optional, for logging) - export WANDB_API_KEY= - # wandb login - - # Download the GSM8K dataset - python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k # Adjusted path - - # Download the base model (Example: Qwen2.5-3B-Instruct) - huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct - ``` - -4. **Configure:** - * Modify the configuration file (e.g., `config/spin_trainer.yaml` or the one specified in the run script) with correct paths to your downloaded model, data, desired hyperparameters (`dpo_beta`, learning rate, etc.), and distributed training settings (nodes, GPUs per node). - * Pay attention to `actor_rollout_ref.model_path`, `data` paths, `reward_model` config (if using one), and `trainer.ref_update_freq`. - -5. **Run Training:** - ```bash - # Set CUDA visible devices (adjust based on your hardware and config) - export CUDA_VISIBLE_DEVICES=0,1,2,3 - - # Launch the training script (e.g., test.sh or a custom script) - # Ensure test.sh points to the correct config and main script - bash recipe/spin/run_spin.sh - ``` - ---- - -## Configuration - -* The primary configuration is typically managed through a YAML file specified in the launch script (e.g., `config/spin_trainer.yaml`). -* Key configuration sections: - * `data`: Paths to training/validation prompt files, batch sizes, sequence lengths. - * `actor_rollout_ref`: Paths to the base model (used for actor and initial reference), FSDP settings, optimization parameters (learning rate, scheduler). - * `reward_model`: Configuration for the reward model used for online preference labeling (path, batch size, etc.). Can be omitted if using a simpler reward function. - * `algorithm`: DPO-specific hyperparameters like `dpo_beta`, `dpo_loss_type`. - * `trainer`: Distributed training settings (nodes, GPUs per node), logging (WandB), checkpointing frequency, and `ref_update_freq` (set > 0 to enable periodic reference model updates from the actor). - ---- - -## Key Files - -* `main_spin.py`: Main entry point using Hydra to load the config and launch the `SpinTrainer`. -* `spin_trainer.py`: Defines the `SpinTrainer` class, orchestrating the Online DPO training loop. -* `fsdp_workers.py`: Implements Ray workers (Actor, Reference) potentially using FSDP. -* `dp_actor.py`: Contains the actor class, including the DPO policy update logic. -* `core_algos.py`: Includes helper functions for `compute_online_dpo_loss` and `compute_onlineDPO_pref`. -* `config/spin_trainer.yaml` (or similar): Main Hydra configuration file for the recipe. -* `run_spin.sh` (or similar): Example bash script for launching a training run. -* `README.md`: This file. - ---- - -## Acknowledgement - -We sincerely thank the contribution and guidance from the `verl` community and advisors, including (adapted from SPPO): - -* [Zixiang Chen](https://sites.google.com/view/zxchen) -* [Yuhao Yang](https://github.com/yhyang201) -* [Yifan Zhang](https://github.com/yifanzhang-pro) -* [Yongan Xiang](https://github.com/BearBiscuit05) -* [Junrong Lin](https://github.com/ocss884) -* [Yuxuan Tong](https://github.com/tongyx361) -* [Guangming Shen](https://github.com/PeterSH6) -* [Biao He](https://www.linkedin.com/in/biao-he/) -* [Qingquan Song](https://qingquansong.github.io/) -* [Chenyang Zhao](https://zhaochenyang20.github.io/Chayenne/) -* [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) - ---- diff --git a/recipe/spin/config/spin_trainer.yaml b/recipe/spin/config/spin_trainer.yaml deleted file mode 100644 index ee105c421..000000000 --- a/recipe/spin/config/spin_trainer.yaml +++ /dev/null @@ -1,28 +0,0 @@ -# the sppo config will override default ppo_trainer.yaml - -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -actor_rollout_ref: - actor: - dpo_beta: 0.1 - optim: - lr_warmup_steps: 15 - rollout: - name: sglang - tensor_model_parallel_size: 2 - gpu_memory_utilization: 0.5 - val_kwargs: - n: 2 # 2 will trigger validation, 1 will bypass - -algorithm: - adv_estimator: null - -trainer: - log_val_generations: 0 - ref_update_freq: 1 \ No newline at end of file diff --git a/recipe/spin/core_algos.py b/recipe/spin/core_algos.py deleted file mode 100644 index c48027e54..000000000 --- a/recipe/spin/core_algos.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import numpy as np -import torch - - -class AdaptiveKLController: - """ - Adaptive KL controller described in the paper: - https://arxiv.org/pdf/1909.08593.pdf - """ - - def __init__(self, init_kl_coef, target_kl, horizon): - self.value = init_kl_coef - self.target = target_kl - self.horizon = horizon - - def update(self, current_kl, n_steps): - target = self.target - proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.horizon - self.value *= mult - - -class FixedKLController: - """Fixed KL controller.""" - - def __init__(self, kl_coef): - self.value = kl_coef - - def update(self, current_kl, n_steps): - pass - - -def get_kl_controller(kl_ctrl): - if kl_ctrl.type == "fixed": - return FixedKLController(kl_coef=kl_ctrl.kl_coef) - elif kl_ctrl.type == "adaptive": - assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" - return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) - else: - raise NotImplementedError - - -def compute_onlinedpo_pref( - token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, -) -> torch.Tensor: - """ - Computes preferences between pairs of sequences based on summed rewards - and returns a mask aligned with the interleaved batch. - - Assumes inputs are interleaved: [Resp1_Prompt0, Resp2_Prompt0, Resp1_Prompt1, Resp2_Prompt1, ...] - - Args: - token_level_rewards: Tensor of shape [batch_size * 2, seq_len] - response_mask: Tensor of shape [batch_size * 2, seq_len] - - Returns: - torch.Tensor: A boolean mask of shape [batch_size * 2], where True indicates - the corresponding entry is the chosen response for its pair. - Example: [True, False, False, True, ...] means for prompt 0, - response 1 was chosen; for prompt 1, response 2 was chosen. - """ - # print(f"---- [DEBUG] Inside compute_onlinedpo_pref ----") - if token_level_rewards.shape[0] % 2 != 0 or response_mask.shape[0] % 2 != 0: - raise ValueError( - f"Input tensor batch dimension must be even for pair comparison, got shapes: " - f"{token_level_rewards.shape}, {response_mask.shape}" - ) - if token_level_rewards.shape != response_mask.shape: - raise ValueError(f"Shape mismatch between rewards {token_level_rewards.shape} and mask {response_mask.shape}") - - # 1. Calculate Sequence Scores - scores = (token_level_rewards * response_mask).sum(dim=-1) - # print(f" Calculated sequence scores shape: {scores.shape}") # [batch_size * 2] - - # 2. Reshape scores to group pairs: [batch_size, 2] - try: - score_pairs = scores.view(-1, 2) - except RuntimeError as e: - print(f"ERROR reshaping scores (shape {scores.shape}) into pairs: {e}") - raise e - print(f" Reshaped score pairs shape: {score_pairs.shape}") # [batch_size, 2] - - # 3. Compare scores to find which index (0 or 1) is the winner within each pair - # winner_indices[i] = 0 if score_pairs[i, 0] >= score_pairs[i, 1] else 1 - winner_indices = torch.argmax(score_pairs, dim=1) # 0 if first is max, 1 if second is max - # Handle ties explicitly if argmax behavior isn't guaranteed (usually picks first max) - # Alternatively: winner_mask_original = score_pairs[:, 0] >= score_pairs[:, 1] - # print(f" Winner indices shape: {winner_indices.shape}") # [batch_size] - # print(f" Number where Response 2 (index 1) is preferred: {winner_indices.sum().item()}") # Counts number of 1s - - # 4. Create the final [batch_size * 2] mask - num_pairs = score_pairs.shape[0] - full_batch_size = num_pairs * 2 - # Create indices for the full batch [0, 1, 2, 3, ..., N*2-1] - # full_indices = torch.arange(full_batch_size, device=scores.device) - # Create indices corresponding to the winner within each pair's original index - # E.g., if winner_indices is [0, 1, 0], pair_indices is [0, 1, 2] - # winner_global_indices = (pair_indices * 2) + winner_indices -> [ (0*2)+0, (1*2)+1, (2*2)+0 ] -> [0, 3, 4] - pair_indices = torch.arange(num_pairs, device=scores.device) - winner_global_indices = (pair_indices * 2) + winner_indices - - # Create boolean mask - True at the winner's position - output_preference_mask = torch.zeros(full_batch_size, dtype=torch.bool, device=scores.device) - output_preference_mask[winner_global_indices] = True - - # print(f" Output preference mask shape: {output_preference_mask.shape}") # Should be [batch_size * 2] - # print(f" Output mask True count (Chosen): {output_preference_mask.sum().item()}") # Should be batch_size - # print(f" Output mask False count (Rejected): {(~output_preference_mask).sum().item()}") # Should be batch_size - # print(f"---- [DEBUG] Exiting compute_onlinedpo_pref ----") - - return output_preference_mask - - -def compute_online_dpo_loss( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta: float, - label_smoothing: float = 0.0, - loss_type: str = "sigmoid", - reference_free: bool = False, -) -> torch.Tensor: - import torch.nn.functional as F - - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - if reference_free: - ref_logratios = torch.zeros_like(pi_logratios) - - logits = pi_logratios - ref_logratios - - if loss_type == "sigmoid": - losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing - elif loss_type == "ipo": - losses = (logits - 1 / (2 * beta)) ** 2 - else: - raise ValueError(f"Unsupported loss_type: {loss_type}. Choose 'sigmoid', 'ipo', or 'hinge'.") - - return losses.mean() - - -def get_batch_logps( - logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False -) -> torch.FloatTensor: - """ - Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (e.g., huggingface CausalLMOutputs `logits`). - Shape: (batch_size, sequence_length, vocab_size) - labels: Labels for computing the sequence log probabilities. Shape: (batch_size, sequence_length) - average_log_prob: If True, return the average log probability per sequence. Otherwise, return the sum. - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given sequences. - """ - if logits.shape[:-1] != labels.shape: - raise ValueError("Logits and labels must have the same shape[:-1]") - - # Ensure labels are contiguous and on the same device as logits - labels = labels.contiguous().to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - - # Calculate per token log probability - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none") - per_token_logps = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - per_token_logps = per_token_logps.view( - shift_logits.size(0), shift_logits.size(1) - ) # Reshape back to (batch_size, seq_len-1) - - # Create a mask for the labels that are not -100 - loss_mask = shift_labels != -100 - - # Apply the mask to the per token log probabilities - masked_logps = per_token_logps * loss_mask - - # Calculate the sum or average log probability per sequence - sequence_logps = masked_logps.sum(dim=-1) - - if average_log_prob: - # Avoid division by zero for sequences with no valid tokens - num_valid_tokens = loss_mask.sum(dim=-1) - return sequence_logps / torch.clamp(num_valid_tokens, min=1) - else: - return sequence_logps diff --git a/recipe/spin/dp_actor.py b/recipe/spin/dp_actor.py deleted file mode 100644 index 35caa29c7..000000000 --- a/recipe/spin/dp_actor.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import itertools -import math -from collections import defaultdict - -import numpy as np -import torch - -from recipe.spin.core_algos import compute_online_dpo_loss, get_batch_logps -from verl import DataProto -from verl.utils.device import get_device_name -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches -from verl.workers.actor import DataParallelPPOActor - -__all__ = ["DataParallelPPOActor"] - - -class SPINDataParallelPPOActor(DataParallelPPOActor): - def compute_log_prob(self, data: DataProto) -> torch.Tensor: - """Compute the log probability of the responses given input_ids, attention_mask and position_ids - - Args: - data (DataProto): a DataProto containing keys - - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the - concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. - - ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. - - Returns: - torch.Tensor: the log_prob tensor - """ - # set to eval - self.actor_module.eval() - - micro_batch_size = data.meta_info["micro_batch_size"] - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error - use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - - select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - - if has_multi_modal_inputs: - num_micro_batches = data.batch.batch_size[0] // micro_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif use_dynamic_bsz: - # split using dynamic bsz - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len) - else: - micro_batches = batch.split(micro_batch_size) - - log_probs_lst = [] - for micro_batch in micro_batches: - if isinstance(micro_batch, DataProto): - micro_batch = {**micro_batch.batch, **micro_batch.non_tensor_batch} - - with torch.no_grad(): - _, log_probs = self._forward_micro_batch(micro_batch, temperature=temperature) - log_probs_lst.append(log_probs) - log_probs = torch.concat(log_probs_lst, dim=0) - - if use_dynamic_bsz: - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - log_probs = log_probs[revert_indices] - - return log_probs - - def update_policy_dpo_with_ref(self, data: DataProto): - """ - Performs the DPO update step using pre-calculated reference log probs - from an external, periodically updated reference model. - """ - self.actor_module.train() # Ensure training mode - - # --- Retrieve necessary data --- - try: - # Expects batch prepared by fit_dpo loop, including reference log probs - batch_td = data.batch - chosen_labels = batch_td["chosen_labels"] - rejected_labels = batch_td["rejected_labels"] - # ... other needed tensors like chosen/rejected input_ids, attention_mask, position_ids ... - - # === Get PRE-CALCULATED reference log probs from input data === - reference_chosen_logps = batch_td["reference_chosen_logps"] # Should be sequence-level logps - reference_rejected_logps = batch_td["reference_rejected_logps"] # Should be sequence-level logps - # ============================================================ - - # Get DPO params from meta_info - # beta = data.meta_info.get('dpo_beta', 0.1) # Default beta - beta = self.config.get("dpo_beta", 0.1) # Default beta - loss_type = data.meta_info.get("dpo_loss_type", "sigmoid") - label_smoothing = data.meta_info.get("dpo_label_smoothing", 0.0) - # reference_free should now be False as we provide ref logps - reference_free = data.meta_info.get("reference_free", False) # Default False - - except KeyError as e: - print(f"ERROR: Missing required key for DPO update (in update_policy_dpo): {e}") - print(f"Available keys in data.batch: {list(batch_td.keys())}") # Debug print - return {} # Return empty metrics on error - except Exception as e_data: - print(f"ERROR accessing data for DPO update (in update_policy_dpo): {e_data}") - return {} - - # --- Micro-batching Setup --- - micro_batch_size = self.config.get("ppo_micro_batch_size_per_gpu") - if micro_batch_size is None: - # Fallback or default if not set, or raise error - micro_batch_size = 1 # Example fallback, adjust as needed - print(f"Warning: 'ppo_micro_batch_size_per_gpu' not set, defaulting to {micro_batch_size}") - # raise ValueError("Config 'ppo_micro_batch_size_per_gpu' must be set.") - - # Ensure chosen_input_ids exists before getting shape - if "chosen_input_ids" not in batch_td: - print("ERROR: 'chosen_input_ids' not found in batch_td for DPO update.") - return {} - bsz = batch_td["chosen_input_ids"].shape[0] - - if bsz == 0: - print("Warning: DPO batch size is 0 in update_policy_dpo. Skipping update.") - return {"actor/dpo_loss": 0.0, "actor/grad_norm": 0.0} # Return zero metrics if batch is empty - - num_micro_batches = math.ceil(bsz / micro_batch_size) - gradient_accumulation_steps = num_micro_batches - - # --- Metrics Accumulation --- - total_loss = 0.0 - accumulated_metrics = defaultdict(list) - metrics = {} # Final metrics dict - - # --- Zero Gradients --- - self.actor_optimizer.zero_grad(set_to_none=True) - - # --- Micro-batch Loop --- - for i in range(num_micro_batches): - start_idx = i * micro_batch_size - end_idx = min(start_idx + micro_batch_size, bsz) - if start_idx >= end_idx: - continue - - # Slice the full DPO batch into micro-batches - # Important: Slice ALL required tensors, including labels and inputs - micro_batch_chosen_labels = chosen_labels[start_idx:end_idx] - micro_batch_rejected_labels = rejected_labels[start_idx:end_idx] - micro_batch_chosen_inputs = { - "input_ids": batch_td["chosen_input_ids"][start_idx:end_idx], - "attention_mask": batch_td["chosen_attention_mask"][start_idx:end_idx], - } - if "chosen_position_ids" in batch_td: - micro_batch_chosen_inputs["position_ids"] = batch_td["chosen_position_ids"][start_idx:end_idx] - - micro_batch_rejected_inputs = { - "input_ids": batch_td["rejected_input_ids"][start_idx:end_idx], - "attention_mask": batch_td["rejected_attention_mask"][start_idx:end_idx], - } - if "rejected_position_ids" in batch_td: - micro_batch_rejected_inputs["position_ids"] = batch_td["rejected_position_ids"][start_idx:end_idx] - - # Determine autocast dtype - autocast_dtype = torch.bfloat16 # Or get dynamically from config/FSDP settings - # --- Autocast Forward Pass --- - with torch.autocast(device_type=get_device_name(), dtype=autocast_dtype): - # --- Step 1: Forward pass for CURRENT policy log probs (with grad) --- - policy_chosen_outputs = self.actor_module(**micro_batch_chosen_inputs, use_cache=False) - policy_rejected_outputs = self.actor_module(**micro_batch_rejected_inputs, use_cache=False) - - # --- Step 2: Calculate CURRENT policy log probs using get_batch_logps --- - policy_chosen_logps = get_batch_logps( - policy_chosen_outputs.logits, micro_batch_chosen_labels, average_log_prob=False - ) - policy_rejected_logps = get_batch_logps( - policy_rejected_outputs.logits, micro_batch_rejected_labels, average_log_prob=False - ) - - # --- Step 3: Retrieve PRE-CALCULATED reference log probs (NO grad needed) --- - # Slice the full batch reference logps for the current micro-batch - micro_ref_chosen_logps = reference_chosen_logps[start_idx:end_idx] - micro_ref_rejected_logps = reference_rejected_logps[start_idx:end_idx] - # --- The ActorAsRef calculation block is REMOVED --- - - # --- Step 4: Calculate DPO Logits and Loss --- - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = micro_ref_chosen_logps - micro_ref_rejected_logps # Uses pre-calculated values - logits = pi_logratios - ref_logratios # DPO logits - - loss = compute_online_dpo_loss( - policy_chosen_logps=policy_chosen_logps, # Has grad - policy_rejected_logps=policy_rejected_logps, # Has grad - reference_chosen_logps=micro_ref_chosen_logps, # No grad (from input) - reference_rejected_logps=micro_ref_rejected_logps, # No grad (from input) - beta=beta, - label_smoothing=label_smoothing, - loss_type=loss_type, - reference_free=reference_free, # Should be False now - ) - - # --- Scale loss for gradient accumulation --- - scaled_loss = loss / gradient_accumulation_steps - - # --- Accumulate Metrics --- - total_loss += loss.item() # Unscaled loss - accumulated_metrics["actor/dpo_loss_batch"].append(loss.item()) - accumulated_metrics["actor/dpo_logits_batch"].append(logits.mean().item()) - # Accumulate policy and reference log probs/ratios if needed for debugging - accumulated_metrics["actor/policy_chosen_logps_batch"].append(policy_chosen_logps.mean().item()) - accumulated_metrics["actor/policy_rejected_logps_batch"].append(policy_rejected_logps.mean().item()) - accumulated_metrics["actor/reference_chosen_logps_batch"].append(micro_ref_chosen_logps.mean().item()) - accumulated_metrics["actor/reference_rejected_logps_batch"].append( - micro_ref_rejected_logps.mean().item() - ) - - # --- Backward Pass (outside autocast) --- - # Check if loss requires grad before backward - if scaled_loss.requires_grad: - scaled_loss.backward() - else: - print(f"Warning: Scaled loss at micro-batch {i} does not require grad. Skipping backward.") - - # --- End Micro-batch Loop --- - - # --- Optimizer Step (after accumulating gradients for all micro-batches) --- - grad_norm = self._optimizer_step() - - # --- Populate Final Metrics --- - if num_micro_batches > 0 and bsz > 0: # Check if any processing happened - metrics["actor/dpo_loss"] = total_loss / num_micro_batches - metrics["actor/grad_norm"] = ( - grad_norm.item() if torch.is_tensor(grad_norm) and torch.isfinite(grad_norm) else float("inf") - ) - # Average other accumulated metrics - for key, val_list in accumulated_metrics.items(): - if val_list: - metrics[key.replace("_batch", "")] = np.mean(val_list) - - # Calculate accuracy / rewards / margins based on averaged logprobs if desired - if ( - "actor/policy_chosen_logps" in metrics - and "actor/policy_rejected_logps" in metrics - and "actor/reference_chosen_logps" in metrics - and "actor/reference_rejected_logps" in metrics - ): - policy_ratio_mean = metrics["actor/policy_chosen_logps"] - metrics["actor/policy_rejected_logps"] - ref_ratio_mean = metrics["actor/reference_chosen_logps"] - metrics["actor/reference_rejected_logps"] - logits_mean = policy_ratio_mean - ref_ratio_mean - metrics["actor/rewards_chosen"] = beta * ( - metrics["actor/policy_chosen_logps"] - metrics["actor/reference_chosen_logps"] - ) - metrics["actor/rewards_rejected"] = beta * ( - metrics["actor/policy_rejected_logps"] - metrics["actor/reference_rejected_logps"] - ) - metrics["actor/rewards_accuracies"] = float(logits_mean > 0) # Mean accuracy proxy - metrics["actor/rewards_margins"] = metrics["actor/rewards_chosen"] - metrics["actor/rewards_rejected"] - - else: # Handle case where no micro-batches were run (e.g., bsz=0) - metrics["actor/dpo_loss"] = 0.0 - metrics["actor/grad_norm"] = 0.0 - # Initialize other metrics to 0 or NaN as appropriate - for key in accumulated_metrics.keys(): - metrics[key.replace("_batch", "")] = 0.0 - metrics["actor/rewards_chosen"] = 0.0 - metrics["actor/rewards_rejected"] = 0.0 - metrics["actor/rewards_accuracies"] = 0.0 - metrics["actor/rewards_margins"] = 0.0 - - return metrics # Return aggregated metrics diff --git a/recipe/spin/fsdp_workers.py b/recipe/spin/fsdp_workers.py deleted file mode 100644 index bbbfa0ed0..000000000 --- a/recipe/spin/fsdp_workers.py +++ /dev/null @@ -1,599 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -import os -import warnings - -import psutil -import torch -import torch.distributed -from codetiming import Timer -from omegaconf import open_dict -from torch.distributed.device_mesh import init_device_mesh - -import verl.utils.torch_functional as verl_F -from verl import DataProto -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils import hf_tokenizer -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device -from verl.utils.flops_counter import FlopsCounter -from verl.utils.fs import copy_to_local -from verl.utils.fsdp_utils import ( - get_fsdp_wrap_policy, - get_init_weight_context_manager, - init_fn, - load_fsdp_model_to_gpu, - load_fsdp_optimizer, - offload_fsdp_model_to_cpu, - offload_fsdp_optimizer, -) -from verl.utils.import_utils import import_external_libs -from verl.utils.model import compute_position_id_with_mask -from verl.utils.profiler import log_gpu_memory_usage -from verl.workers.fsdp_workers import ActorRolloutRefWorker -from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) - - -def create_device_mesh(world_size, fsdp_size): - if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) - else: - device_mesh = init_device_mesh( - get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] - ) - return device_mesh - - -def get_sharding_strategy(device_mesh): - from torch.distributed.fsdp import ShardingStrategy - - if device_mesh.ndim == 1: - sharding_strategy = ShardingStrategy.FULL_SHARD - elif device_mesh.ndim == 2: - sharding_strategy = ShardingStrategy.HYBRID_SHARD - else: - raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") - return sharding_strategy - - -class SPINRolloutRefWorker(ActorRolloutRefWorker): - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor - - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - - from omegaconf import OmegaConf - - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - - use_remove_padding = self.config.model.get("use_remove_padding", False) - use_fused_kernels = self.config.model.get("use_fused_kernels", False) - - if self._is_actor or self._is_rollout or self._is_ref: - # we need the model for actor and rollout - if self._is_actor or self._is_ref: - optim_config = self.config.actor.optim - fsdp_config = self.config.actor.fsdp_config - else: - optim_config = None - fsdp_config = OmegaConf.create() - self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = ( - self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=fsdp_config, - optim_config=optim_config, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, - enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="actor", - ) - ) - - # get the original unwrapped module - self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module - - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) - # load from checkpoint - if self._is_actor or self._is_ref: - OmegaConf.set_struct(self.config.actor, True) - with open_dict(self.config.actor): - self.config.actor.use_remove_padding = use_remove_padding - self.config.actor.use_fused_kernels = use_fused_kernels - self.actor = DataParallelPPOActor( - config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer - ) - - if self._is_rollout: - self.rollout, self.rollout_sharding_manager = self._build_rollout( - trust_remote_code=self.config.model.get("trust_remote_code", False) - ) - - if self._is_ref: - self.ref_module_fsdp = self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=self.config.ref.fsdp_config, - optim_config=None, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="ref", - )[0] - OmegaConf.set_struct(self.config.ref, True) - with open_dict(self.config.ref): - self.config.ref.use_remove_padding = use_remove_padding - self.config.ref.use_fused_kernels = use_fused_kernels - self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.actor_module_fsdp, - optimizer=self.actor.actor_optimizer, - lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_config=self.config.actor.checkpoint, - ) - - if self._is_actor: - self.flops_counter = FlopsCounter(self.actor_model_config) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.actor_module_fsdp, - optimizer=self.actor.actor_optimizer, - lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_config=self.config.actor.checkpoint, - ) - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_ref_log_prob(self, data: DataProto): - assert self._is_ref - - # Support all hardwares - data = data.to(get_device_id()) - - micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu - data.meta_info["micro_batch_size"] = micro_batch_size - data.meta_info["temperature"] = self.config.rollout.temperature - data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data) - output = self.ref_policy.compute_log_prob(data=data) - output = DataProto.from_dict(tensors={"ref_log_prob": output}) - output = self.ulysses_sharding_manager.postprocess_data(output) - - output = output.to("cpu") - - # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes - # unshard the root FSDP module - if self.world_size > 1: - self.ref_policy.actor_module._handle.reshard(True) - - return output - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_log_prob(self, data: DataProto): - assert self._is_actor - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - - # Support all hardwares - data = data.to(get_device_id()) - # we should always recompute old_log_probs when it is HybridEngine - data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz - data.meta_info["temperature"] = self.config.rollout.temperature - # perform recompute log_prob - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data) - output = self.actor.compute_log_prob(data=data) - output = DataProto.from_dict( - tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature} - ) - output = self.ulysses_sharding_manager.postprocess_data(output) - - output = output.to("cpu") - - # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes - # unshard the root FSDP module - if self.world_size > 1: - self.actor.actor_module._handle.reshard(True) - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - - log_gpu_memory_usage("After compute_log_prob", logger=logger) - return output - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def update_actor_dpo(self, data: DataProto): - """ - Wrapper for actor update step. Handles FSDP state management. - Calls self.actor.update_policy which now contains DPO logic based - on pre-calculated log probabilities. - """ - # Support all hardwares - data = data.to(get_device_id()) - - assert self._is_actor # Make sure this worker has the actor role - if self.actor is None: - raise RuntimeError("Actor instance (self.actor) not initialized in worker.") - - # --- FSDP State Management --- - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) - - log_gpu_memory_usage("Before update policy (DPO via PPO path)", logger=logger) - - # --- Ulysses Sharding (if used) --- - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) - - # --- Call the core update method (now containing DPO logic) --- - with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer: # Use a distinct timer name - # Calls the modified update_policy method - metrics = self.actor.update_policy_dpo_with_ref(data=data) # <-- THIS CALLS THE MODIFIED FUNCTION - delta_time = timer.last - - # --- Add Performance Metrics --- - # MFU calculation might be less accurate/meaningful here for DPO - metrics["perf/approx_tokens_processed"] = torch.sum( - data.batch.get("attention_mask", torch.tensor(0)) - ).item() # Approx tokens - metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) - metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) - global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/actor"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size - - # --- LR Scheduler Step --- - lr = self.actor_lr_scheduler.get_last_lr()[0] - metrics["actor/lr"] = lr - self.actor_lr_scheduler.step() - - log_gpu_memory_usage("After update policy (DPO via PPO path)", logger=logger) - - # --- Prepare Output --- - output = DataProto(meta_info={"metrics": metrics}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - output = output.to("cpu") - - # --- FSDP State Management (Offload) --- - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) - - return output - - -# TODO(sgm): we may need to extract it to dp_reward_model.py -class RewardModelWorker(Worker): - """ - Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. - """ - - def __init__(self, config): - super().__init__() - import torch.distributed - - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend=get_nccl_backend()) - self.config = config - - # build device mesh for Ulysses Sequence Parallel - world_size = torch.distributed.get_world_size() - from torch.distributed.device_mesh import init_device_mesh - - fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) - - self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - dp = world_size // self.ulysses_sequence_parallel_size - if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh( - get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] - ) - - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - - self.use_remove_padding = self.config.model.get("use_remove_padding", False) - - # normalize config - if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= torch.distributed.get_world_size() - self.config.micro_batch_size_per_gpu = self.config.micro_batch_size - - def _build_model(self, config): - # the following line is necessary - from torch.distributed.fsdp import CPUOffload - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from transformers import AutoConfig, AutoModelForTokenClassification - - # download the checkpoint from hdfs - local_path = copy_to_local(config.model.path) - - if self.config.model.input_tokenizer is None: - self._do_switch_chat_template = False - else: - self._do_switch_chat_template = True - input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer) - self.input_tokenizer = hf_tokenizer( - input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) - ) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) - - trust_remote_code = config.model.get("trust_remote_code", False) - model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) - model_config.num_labels = 1 - - # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect - init_context = get_init_weight_context_manager( - use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh - ) - - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - model_config.classifier_dropout = 0.0 - reward_module = AutoModelForTokenClassification.from_pretrained( - pretrained_model_name_or_path=local_path, - config=model_config, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - - if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1: - from verl.models.transformers.monkey_patch import apply_monkey_patch - - apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size) - - reward_module.to(torch.bfloat16) - - auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config) - - fsdp_mesh = self.device_mesh - sharding_strategy = get_sharding_strategy(fsdp_mesh) - - reward_module = FSDP( - reward_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=get_device_id(), - sharding_strategy=sharding_strategy, # zero3 - sync_module_states=True, - cpu_offload=CPUOffload(offload_params=True), - forward_prefetch=False, - device_mesh=self.device_mesh, - ) - - return reward_module - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - self.reward_module = self._build_model(config=self.config) - - def _forward_micro_batch(self, micro_batch): - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input - - from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs - - with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - - if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # pad and slice the inputs if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size - ) - - # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module( - input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False - ) # prevent model thinks we are generating - reward_rmpad = output.logits - reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) - - # gather output if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - reward_rmpad = gather_outputs_and_unpad( - reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - - # pad it back - rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) - else: - output = self.reward_module( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False - ) - rm_score = output.logits # (batch_size, seq_len, 1) - rm_score = rm_score.squeeze(-1) - - # extract the result of the last valid token - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] - return rm_score - - def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): - batch_size = data.batch.batch_size[0] - # expand as token_level_reward - attention_mask = data.batch["attention_mask"] - position_ids = data.batch["position_ids"] - response_length = data.batch["responses"].shape[-1] - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) - token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores - - # select the response part - token_level_scores = token_level_scores[:, -response_length:] - - return token_level_scores - - def _switch_chat_template(self, data: DataProto): - src_max_length = data.batch["attention_mask"].shape[-1] - - src_tokenizer = self.input_tokenizer - target_tokenizer = self.tokenizer - - rm_input_ids = [] - rm_attention_mask = [] - - for i in range(data.batch.batch_size[0]): - # extract raw prompt - if isinstance(data.non_tensor_batch["raw_prompt"][i], list): - chat: list = data.non_tensor_batch["raw_prompt"][i] - else: - chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() - - # extract response - response_ids = data.batch["responses"][i] - response_length = response_ids.shape[-1] - valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - response = src_tokenizer.decode(valid_response_ids) - # remove bos and eos - response = response.replace(src_tokenizer.eos_token, "") - - chat.append({"role": "assistant", "content": response}) - - prompt_with_chat_template = target_tokenizer.apply_chat_template( - chat, add_generation_prompt=False, tokenize=False - ) - if self.rank == 0 and i == 0: - # for debugging purpose - print(f"Switch template. chat: {prompt_with_chat_template}") - - # the maximum length is actually determined by the reward model itself - max_length = self.config.get("max_length", src_max_length) - if max_length is None: - max_length = src_max_length - - model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) - input_ids, attention_mask = verl_F.postprocess_data( - input_ids=model_inputs["input_ids"], - attention_mask=model_inputs["attention_mask"], - max_length=max_length, - pad_token_id=target_tokenizer.pad_token_id, - left_pad=False, # right padding - truncation=self.config.get("truncation", "right"), - ) # truncate from the right - - rm_input_ids.append(input_ids) - rm_attention_mask.append(attention_mask) - - rm_input_ids = torch.cat(rm_input_ids, dim=0) - rm_attention_mask = torch.cat(rm_attention_mask, dim=0) - - rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - - rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} - - return DataProto.from_dict(rm_inputs) - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def compute_rm_score(self, data: DataProto): - import itertools - - from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches - - # Support all hardwares - data = data.to(get_device_id()) - if self._do_switch_chat_template: - rm_data = self._switch_chat_template(data) - else: - rm_input_ids = data.batch["input_ids"] - rm_attention_mask = data.batch["attention_mask"] - rm_position_ids = data.batch["position_ids"] - rm_inputs = { - "input_ids": rm_input_ids, - "attention_mask": rm_attention_mask, - "position_ids": rm_position_ids, - } - rm_data = DataProto.from_dict(rm_inputs) - - # Support all hardwares - rm_data.batch = rm_data.batch.to(get_device_id()) - - # perform forward computation - with self.ulysses_sharding_manager: - rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data) - data = self.ulysses_sharding_manager.preprocess_data(data=data) - - use_dynamic_bsz = self.config.use_dynamic_bsz - if use_dynamic_bsz: - max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) - else: - micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) - output = [] - for micro_batch in micro_batches: - rm_score = self._forward_micro_batch(micro_batch) - output.append(rm_score) - scores = torch.cat(output, dim=0) # (batch_size) - - if use_dynamic_bsz: - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - scores = scores[revert_indices] - - token_level_scores = self._expand_to_token_level(data, scores) - # Note that this is only the scores, may not be the final rewards used to train RL - output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - - # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes - # unshard the root FSDP module - self.reward_module._handle.reshard(True) - - output = output.to("cpu") - return output diff --git a/recipe/spin/main_spin.py b/recipe/spin/main_spin.py deleted file mode 100644 index 9a879ee77..000000000 --- a/recipe/spin/main_spin.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import hydra -import ray - -from recipe.spin.spin_trainer import RaySPINTrainer -from verl.trainer.ppo.reward import get_custom_reward_fn - - -@hydra.main(config_path="config", config_name="spin_trainer", version_base=None) -def main(config): - run_ppo(config) - - -def run_ppo(config) -> None: - # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices - # isolation, will solve in the future - os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") - if not ray.is_initialized(): - # this is for local ray cluster - ray.init( - runtime_env={ - "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} - } - ) - - runner = TaskRunner.remote() - ray.get(runner.run.remote(config)) - - -@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head -class TaskRunner: - def run(self, config): - # print initial config - from pprint import pprint - - from omegaconf import OmegaConf - - from verl.utils.fs import copy_to_local - - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - OmegaConf.resolve(config) - - # download the checkpoint from hdfs - local_path = copy_to_local(config.actor_rollout_ref.model.path) - - # instantiate tokenizer - from verl.utils import hf_processor, hf_tokenizer - - trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none - - # define worker classes - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: - assert config.critic.strategy in {"fsdp", "fsdp2"} - # from recipe.spin.fsdp_workers import ActorRolloutRefWorker - from recipe.spin.fsdp_workers import SPINRolloutRefWorker - from verl.single_controller.ray import RayWorkerGroup - - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - - ray_worker_group_cls = NVMegatronRayWorkerGroup - - else: - raise NotImplementedError - - from recipe.spin.spin_trainer import ResourcePoolManager, Role - - role_worker_mapping = { - # Role.ActorRollout: ray.remote(ActorRolloutRefWorker), - Role.ActorRollout: ray.remote(SPINRolloutRefWorker), - # Role.Critic: ray.remote(CriticWorker), - } - - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - # Role.Critic: global_pool_id, - } - - if config.reward_model.enable: - if config.reward_model.strategy in {"fsdp", "fsdp2"}: - from recipe.spin.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == "megatron": - from verl.workers.megatron_workers import RewardModelWorker - else: - raise NotImplementedError - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - mapping[Role.RewardModel] = global_pool_id - - # use reference model - # if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - # role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) - role_worker_mapping[Role.RefPolicy] = ray.remote(SPINRolloutRefWorker) - mapping[Role.RefPolicy] = global_pool_id - - from verl.workers.reward_manager import get_reward_manager_cls - - # Note(haibin.lin): please make sure custom reward managers are imported and - # registered via `verl.workers.reward_manager.register` - reward_manager_name = config.reward_model.get("reward_manager", "naive") - reward_manager_cls = get_reward_manager_cls(reward_manager_name) - - compute_score = get_custom_reward_fn(config) - reward_kwargs = dict(config.reward_model.get("reward_kwargs", {})) - reward_fn = reward_manager_cls( - tokenizer=tokenizer, - num_examine=0, - compute_score=compute_score, - reward_fn_key=config.data.reward_fn_key, - **reward_kwargs, - ) - - # Note that we always use function-based RM for validation - val_reward_fn = reward_manager_cls( - tokenizer=tokenizer, num_examine=1, compute_score=compute_score, reward_fn_key=config.data.reward_fn_key - ) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - trainer = RaySPINTrainer( - config=config, - tokenizer=tokenizer, - processor=processor, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, - device_name=config.trainer.device, - ) - trainer.init_workers() - trainer.fit_dpo() - - -if __name__ == "__main__": - main() diff --git a/recipe/spin/run_spin.sh b/recipe/spin/run_spin.sh deleted file mode 100644 index 798dedabe..000000000 --- a/recipe/spin/run_spin.sh +++ /dev/null @@ -1,29 +0,0 @@ -set -e -set -x -VISIBLE_DEVICES="4,5,6,7" -export HYDRA_FULL_ERROR=1 - -CUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size=8 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=64 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=console \ - trainer.val_before_train=True \ - trainer.n_gpus_per_node=4 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - +trainer.log_freq=1 \ - trainer.ref_update_freq=1 \ - trainer.total_epochs=1000 2>&1 | tee verl_demo.log \ No newline at end of file diff --git a/recipe/spin/spin_trainer.py b/recipe/spin/spin_trainer.py deleted file mode 100644 index fa435dbdd..000000000 --- a/recipe/spin/spin_trainer.py +++ /dev/null @@ -1,1458 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import traceback -import uuid -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass, field -from enum import Enum -from pprint import pprint -from typing import Any, Optional - -import numpy as np -import ray -import torch -from codetiming import Timer -from omegaconf import OmegaConf, open_dict -from torch.utils.data import Dataset, Sampler -from torchdata.stateful_dataloader import StatefulDataLoader -from tqdm import tqdm - -from recipe.spin import core_algos -from verl import DataProto -from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.base import Worker -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo.metric_utils import ( - compute_throughout_metrics, - compute_timing_metrics, - process_validation_metrics, - reduce_metrics, -) -from verl.trainer.ppo.ray_trainer import Role -from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path -from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance -from verl.utils.torch_functional import masked_mean -from verl.utils.tracking import ValidationGenerationsLogger - -WorkerType = type[Worker] - - -class AdvantageEstimator(str, Enum): - """ - Using an enumeration class to avoid spelling errors in adv_estimator - """ - - GAE = "gae" - GRPO = "grpo" - REINFORCE_PLUS_PLUS = "reinforce_plus_plus" - REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" - REMAX = "remax" - RLOO = "rloo" - - -@dataclass -class ResourcePoolManager: - """ - Define a resource pool specification. Resource pool will be initialized first. - Mapping - """ - - resource_pool_spec: dict[str, list[int]] - mapping: dict[Role, str] - resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict) - - def create_resource_pool(self): - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. - # For Megatron backend, we recommend using max_colocate_count>1 that can utilize different - # WorkerGroup for different models - resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name - ) - self.resource_pool_dict[resource_pool_name] = resource_pool - - self._check_resource_available() - - def get_resource_pool(self, role: Role) -> RayResourcePool: - """Get the resource pool of the worker_cls""" - return self.resource_pool_dict[self.mapping[role]] - - def get_n_gpus(self) -> int: - """Get the number of gpus in this cluster.""" - return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]) - - def _check_resource_available(self): - """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()} - - # check total required gpus can be satisfied - total_available_gpus = sum(node_available_gpus.values()) - total_required_gpus = sum( - [n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes] - ) - if total_available_gpus < total_required_gpus: - raise ValueError( - f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" - ) - - # check each resource pool can be satisfied, O(#resource_pools * #nodes) - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) - for node, available_gpus in node_available_gpus.items(): - if available_gpus >= num_gpus: - node_available_gpus[node] -= num_gpus - num_nodes -= 1 - if num_nodes == 0: - break - if num_nodes > 0: - raise ValueError( - f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes} cannot be satisfied in this " - f"ray cluster" - ) - - -def _compute_response_info(batch: DataProto) -> dict[str, Any]: - """Placeholder: Computes prompt and response lengths.""" - try: - # Assuming 'prompts' and 'responses' keys exist after generation/union - prompt_len = batch.batch["prompts"].shape[1] - resp_len = batch.batch["responses"].shape[1] - # This is simplified - real implementation might use attention masks - # to get actual lengths per sample. - batch_size = batch.batch.batch_size[0] - prompt_lengths_tensor = torch.full((batch_size,), prompt_len, dtype=torch.float32, device=batch.batch.device) - response_lengths_tensor = torch.full((batch_size,), resp_len, dtype=torch.float32, device=batch.batch.device) - - # Try getting actual lengths from attention mask if possible (more accurate) - if "response_mask" in batch.batch: - response_lengths_tensor = batch.batch["response_mask"].sum(dim=1).float() - # if "attention_mask" in batch.batch and "response_mask" in batch.batch: - # full_mask = batch.batch["attention_mask"] - # resp_mask = batch.batch["response_mask"] - # Infer prompt mask length based on where response mask starts or total length - # This logic depends heavily on how your masks are constructed. - # Example: prompt_lengths_tensor = full_mask.sum(dim=1).float() - response_lengths_tensor - # Fallback to using prompt shape if mask logic is complex: - prompt_lengths_tensor = torch.tensor( - [batch.batch["prompts"].shape[1]] * batch_size, dtype=torch.float32, device=batch.batch.device - ) - - return { - "prompt_length": prompt_lengths_tensor, - "response_length": response_lengths_tensor, - "max_response_length": resp_len, - "max_prompt_length": prompt_len, # Or from config if fixed padding - } - except KeyError as e: - print(f"Warning: Missing key in _compute_response_info: {e}. Returning defaults.") - # Return default/dummy values if keys are missing - b_size = batch.batch.batch_size[0] if batch.batch.batch_size else 1 - max_resp = batch.batch.get("responses").shape[1] if batch.batch.get("responses") is not None else 0 - max_prompt = batch.batch.get("prompts").shape[1] if batch.batch.get("prompts") is not None else 0 - return { - "prompt_length": torch.zeros(b_size), - "response_length": torch.zeros(b_size), - "max_response_length": max_resp, - "max_prompt_length": max_prompt, - } - - -# --- Modified Metric Function --- -def compute_dpo_data_metrics(batch: DataProto) -> dict[str, Any]: - """ - Computes and returns metrics relevant for the DPO-like process. - Assumes 'batch' contains results after generation and preference marking, - potentially including 'dpo_logits', 'preferences', 'chosen_logps', etc. - Removes PPO-specific advantage/return/critic metrics. - """ - print("---- [DEBUG] Computing DPO Data Metrics ----") - metrics = {} - try: - # --- Scores and Rewards (from reward_fn) --- - if "token_level_scores" in batch.batch and batch.batch["token_level_scores"] is not None: - sequence_score = batch.batch["token_level_scores"].sum(-1) - metrics.update( - { - "reward/score/mean": torch.mean(sequence_score).item(), - "reward/score/max": torch.max(sequence_score).item(), - "reward/score/min": torch.min(sequence_score).item(), - } - ) - else: - print("DEBUG compute_dpo_data_metrics: 'token_level_scores' not found.") - - if "token_level_rewards" in batch.batch and batch.batch["token_level_rewards"] is not None: - sequence_reward = batch.batch["token_level_rewards"].sum(-1) - metrics.update( - { - "reward/rewards/mean": torch.mean(sequence_reward).item(), - "reward/rewards/max": torch.max(sequence_reward).item(), - "reward/rewards/min": torch.min(sequence_reward).item(), - } - ) - else: - print("DEBUG compute_dpo_data_metrics: 'token_level_rewards' not found.") - - # --- DPO Specific Metrics (if stored previously) --- - if "dpo_logits" in batch.batch and batch.batch["dpo_logits"] is not None: - metrics["actor/dpo_logits"] = batch.batch["dpo_logits"].mean().item() - else: - print("DEBUG compute_dpo_data_metrics: 'dpo_logits' not found.") - - if "chosen_logps" in batch.batch and batch.batch["chosen_logps"] is not None: - metrics["actor/chosen_logps"] = batch.batch["chosen_logps"].mean().item() - else: - print("DEBUG compute_dpo_data_metrics: 'chosen_logps' not found.") - - if "rejected_logps" in batch.batch and batch.batch["rejected_logps"] is not None: - metrics["actor/rejected_logps"] = batch.batch["rejected_logps"].mean().item() - else: - print("DEBUG compute_dpo_data_metrics: 'rejected_logps' not found.") - - # Add metrics based on the 'preferences' mask if available - # if "preferences" in batch.batch and batch.batch["preferences"] is not None: - # prefs_mask = batch.batch["preferences"] # Shape [batch_size * n] - # Calculate accuracy based on RM scores (assuming higher score -> True in mask) - # Requires chosen/rejected scores to be available or recalculated - # This is complex here, better calculated in the main loop or update function - - # --- Length Metrics --- - response_info = _compute_response_info(batch) - prompt_length = response_info["prompt_length"] - response_length = response_info["response_length"] - max_response_length = response_info["max_response_length"] - max_prompt_length = response_info["max_prompt_length"] # Use calculated or from config - - metrics.update( - { - "response_length/mean": torch.mean(response_length).item(), - "response_length/max": torch.max(response_length).item(), - "response_length/min": torch.min(response_length).item(), - "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()).item(), - "prompt_length/mean": torch.mean(prompt_length).item(), - "prompt_length/max": torch.max(prompt_length).item(), - "prompt_length/min": torch.min(prompt_length).item(), - # Prompt clip ratio might need adjustment based on how max_prompt_length is defined - "prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).item(), - } - ) - - except KeyError as e: - print(f"ERROR in compute_dpo_data_metrics: Missing key {e}") - except Exception as e: - print(f"ERROR in compute_dpo_data_metrics: {e}") - traceback.print_exc() - - print(f"---- [DEBUG] Calculated DPO Data Metrics: {list(metrics.keys())} ----") - return metrics - - -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): - responses = data.batch["responses"] - response_length = responses.size(1) - token_level_scores = data.batch["token_level_scores"] - batch_size = data.batch.batch_size[0] - attention_mask = data.batch["attention_mask"] - response_mask = attention_mask[:, -response_length:] - - # compute kl between ref_policy and current policy - # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. - kld = core_algos.kl_penalty( - data.batch["old_log_probs"], data.batch["ref_log_prob"], kl_penalty=kl_penalty - ) # (batch_size, response_length) - kld = kld * response_mask - beta = kl_ctrl.value - - token_level_rewards = token_level_scores - beta * kld - - current_kl = masked_mean(kld, mask=response_mask, axis=-1) # average over sequence - current_kl = torch.mean(current_kl, dim=0).item() - - # according to https://github.com/huggingface/trl/blob/951ca1841f29114b969b57b26c7d3e80a39f75a0/trl/trainer/ppo_trainer.py#L837 - kl_ctrl.update(current_kl=current_kl, n_steps=batch_size) - data.batch["token_level_rewards"] = token_level_rewards - - metrics = {"actor/reward_kl_penalty": current_kl, "actor/reward_kl_penalty_coeff": beta} - - return data, metrics - - -def compute_response_mask(data: DataProto): - responses = data.batch["responses"] - response_length = responses.size(1) - attention_mask = data.batch["attention_mask"] - return attention_mask[:, -response_length:] - - -def compute_onlineDPO_pref(data: DataProto): - """ - Wrapper to compute DPO preference and add it to the DataProto batch. - Includes debugging prints. - """ - # print(f"\n---- [DEBUG] Entering compute_onlineDPO_pref ----") - # print(f" Input batch keys: {list(data.batch.keys())}") - - # Check inputs - rewards_tensor = data.batch.get("token_level_rewards") - mask_tensor = data.batch.get("response_mask") - - if rewards_tensor is None or mask_tensor is None: - print(" ERROR: Missing 'token_level_rewards' or 'response_mask' in input data!") - # Handle error case - maybe return original data or raise? - # Returning original data for now to potentially allow skipping - return data - - try: - preferences = core_algos.compute_onlinedpo_pref(token_level_rewards=rewards_tensor, response_mask=mask_tensor) - # Store the result - data.batch["preferences"] = preferences - - except AttributeError: - print("ERROR: Function 'compute_online_dpo_preference' not found in core_algos.py!") - # Assign dummy value or raise error - data.batch["preferences"] = None # Indicate failure - except Exception as e_pref: - print(f"ERROR during core_algos.compute_online_dpo_preference: {e_pref}") - import traceback - - traceback.print_exc() - data.batch["preferences"] = None # Indicate failure - - # print(f"---- [DEBUG] Exiting compute_onlineDPO_pref ----") - return data - - -@contextmanager -def _timer(name: str, timing_raw: dict[str, float]): - with Timer(name=name, logger=None) as timer: - yield - timing_raw[name] = timer.last - - -class RaySPINTrainer: - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - # TODO: support each role have individual ray_worker_group_cls, - # i.e., support different backend of different role - def __init__( - self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - processor=None, - reward_fn=None, - val_reward_fn=None, - train_dataset: Optional[Dataset] = None, - val_dataset: Optional[Dataset] = None, - collate_fn=None, - train_sampler: Optional[Sampler] = None, - device_name="cuda", - ): - # assert get_torch_device().is_available(), 'cuda must be available on driver' - - self.tokenizer = tokenizer - self.processor = processor - self.config = config - self.reward_fn = reward_fn - self.val_reward_fn = val_reward_fn - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, "Currently, only support hybrid engine" - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = Role.RefPolicy in role_worker_mapping - self.use_rm = Role.RewardModel in role_worker_mapping - self.ray_worker_group_cls = ray_worker_group_cls - self.validation_generations_logger = ValidationGenerationsLogger() - self.async_rollout_mode = False - self.device_name = device_name - - # define in-reward KL control - # kl loss control currently not suppoorted - if config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) - - self.use_critic = False - self._validate_config() - self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - - def _validate_config(self): - config = self.config - # number of GPUs total - n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes - - # 1. Check total batch size for data correctness - real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert real_train_batch_size % n_gpus == 0, ( - f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." - ) - - # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" - # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". - def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): - settings = { - "actor_rollout_ref.actor": "micro_batch_size", - "critic": "micro_batch_size", - "reward_model": "micro_batch_size", - "actor_rollout_ref.ref": "log_prob_micro_batch_size", - "actor_rollout_ref.rollout": "log_prob_micro_batch_size", - } - - if name in settings: - param = settings[name] - param_per_gpu = f"{param}_per_gpu" - - if mbs is None and mbs_per_gpu is None: - raise ValueError( - f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'." - ) - - if mbs is not None and mbs_per_gpu is not None: - raise ValueError( - f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. " - f"Please remove '{name}.{param}' because only '*_{param_per_gpu}' is supported " - f"(the former is deprecated)." - ) - - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.actor.ppo_micro_batch_size, - config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "actor_rollout_ref.actor", - ) - - if self.use_reference_policy: - # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref", - ) - - # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.rollout.log_prob_micro_batch_size, - config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.rollout", - ) - - if self.use_critic and not config.critic.use_dynamic_bsz: - # Check for critic micro-batch size conflicts - check_mutually_exclusive( - config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" - ) - - # Check for reward model micro-batch size conflicts - if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive( - config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" - ) - - # Actor - # check if train_batch_size is larger than ppo_mini_batch_size - # if NOT dynamic_bsz, we must ensure: - # ppo_mini_batch_size is divisible by ppo_micro_batch_size - # ppo_micro_batch_size * sequence_parallel_size >= n_gpus - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size - sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) - if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert ( - config.actor_rollout_ref.actor.ppo_mini_batch_size - % config.actor_rollout_ref.actor.ppo_micro_batch_size - == 0 - ) - assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus - - assert config.actor_rollout_ref.actor.loss_agg_mode in [ - "token-mean", - "seq-mean-token-sum", - "seq-mean-token-mean", - ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" - - if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: - print("NOTICE: You have both enabled in-reward kl and kl loss.") - - # critic - if self.use_critic and not config.critic.use_dynamic_bsz: - assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size - sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) - if config.critic.ppo_micro_batch_size is not None: - assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 - assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus - - # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: - if ( - config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 - or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 - ): - assert config.actor_rollout_ref.model.use_remove_padding, ( - "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." - ) - - if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}: - if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: - assert config.critic.model.use_remove_padding, ( - "When using sequence parallelism for critic, you must enable `use_remove_padding`." - ) - - if config.data.get("val_batch_size", None) is not None: - print( - "WARNING: val_batch_size is deprecated. Validation datasets are sent to inference engines " - "as a whole batch, which will schedule the memory themselves." - ) - - # check eval config - if config.actor_rollout_ref.rollout.val_kwargs.do_sample: - assert config.actor_rollout_ref.rollout.temperature > 0, ( - "validation gen temperature should be greater than 0 when enabling do_sample" - ) - - print("[validate_config] All configuration checks passed successfully!") - - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): - """ - Creates the train and validation dataloaders. - """ - # TODO: we have to make sure the batch size is divisible by the dp size - from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler - - if train_dataset is None: - train_dataset = create_rl_dataset( - self.config.data.train_files, self.config.data, self.tokenizer, self.processor - ) - if val_dataset is None: - val_dataset = create_rl_dataset( - self.config.data.val_files, self.config.data, self.tokenizer, self.processor - ) - self.train_dataset, self.val_dataset = train_dataset, val_dataset - - if train_sampler is None: - train_sampler = create_rl_sampler(self.config.data, self.train_dataset) - if collate_fn is None: - from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn - - collate_fn = default_collate_fn - - self.train_dataloader = StatefulDataLoader( - dataset=self.train_dataset, - batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), - num_workers=self.config.data.get("dataloader_num_workers", 8), - drop_last=True, - collate_fn=collate_fn, - sampler=train_sampler, - ) - - val_batch_size = self.config.data.val_batch_size # Prefer config value if set - if val_batch_size is None: - val_batch_size = len(self.val_dataset) - - self.val_dataloader = StatefulDataLoader( - dataset=self.val_dataset, - batch_size=val_batch_size, - num_workers=self.config.data.get("dataloader_num_workers", 8), - shuffle=False, - drop_last=False, - collate_fn=collate_fn, - ) - - assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" - assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" - - print( - f"Size of train dataloader: {len(self.train_dataloader)}, " - f"Size of val dataloader: {len(self.val_dataloader)}" - ) - - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - print(f"Total training steps: {self.total_training_steps}") - - try: - OmegaConf.set_struct(self.config, True) - with open_dict(self.config): - if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps - if OmegaConf.select(self.config, "critic.optim"): - self.config.critic.optim.total_training_steps = total_training_steps - except Exception as e: - print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") - - def _maybe_log_val_generations(self, inputs, outputs, scores): - """Log a table of validation samples to the configured logger (wandb or swanlab)""" - - generations_to_log = self.config.trainer.log_val_generations - - if generations_to_log == 0: - return - - import numpy as np - - # Create tuples of (input, output, score) and sort by input text - samples = list(zip(inputs, outputs, scores, strict=True)) - samples.sort(key=lambda x: x[0]) # Sort by input text - - # Use fixed random seed for deterministic shuffling - rng = np.random.RandomState(42) - rng.shuffle(samples) - - # Take first N samples after shuffling - samples = samples[:generations_to_log] - - # Log to each configured logger - self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) - - def _validate(self): - data_source_lst = [] - reward_extra_infos_dict: dict[str, list] = defaultdict(list) - - # Lists to collect samples for the table - sample_inputs = [] - sample_outputs = [] - sample_scores = [] - - for test_data in self.val_dataloader: - test_batch = DataProto.from_single_dict(test_data) - - # repeat test batch - test_batch = test_batch.repeat( - repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True - ) - - # we only do validation on rule-based rm - if self.config.reward_model.enable and test_batch[0].non_tensor_batch["reward_model"]["style"] == "model": - return {} - - # Store original inputs - input_ids = test_batch.batch["input_ids"] - # TODO: Can we keep special tokens except for padding tokens? - input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] - sample_inputs.extend(input_texts) - - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] - if "multi_modal_inputs" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) - if "raw_prompt" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("raw_prompt") - if "tools_kwargs" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("tools_kwargs") - test_gen_batch = test_batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) - - test_gen_batch.meta_info = { - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - "recompute_log_prob": False, - "do_sample": self.config.actor_rollout_ref.rollout.val_kwargs.do_sample, - "validate": True, - } - print(f"test_gen_batch meta info: {test_gen_batch.meta_info}") - - # pad to be divisible by dp_size - test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) - if not self.async_rollout_mode: - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) - else: - test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded) - - # unpad - test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) - print("validation generation end") - - # Store generated outputs - output_ids = test_output_gen_batch.batch["responses"] - output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids] - sample_outputs.extend(output_texts) - - test_batch = test_batch.union(test_output_gen_batch) - - # evaluate using reward_function - result = self.val_reward_fn(test_batch, return_dict=True) - reward_tensor = result["reward_tensor"] - scores = reward_tensor.sum(-1).cpu().tolist() - sample_scores.extend(scores) - - reward_extra_infos_dict["reward"].extend(scores) - if "reward_extra_info" in result: - for key, lst in result["reward_extra_info"].items(): - reward_extra_infos_dict[key].extend(lst) - - data_source_lst.append(test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])) - - self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores) - - # dump generations - val_data_dir = self.config.trainer.get("validation_data_dir", None) - if val_data_dir: - self._dump_generations( - inputs=sample_inputs, - outputs=sample_outputs, - scores=sample_scores, - reward_extra_infos_dict=reward_extra_infos_dict, - dump_path=val_data_dir, - ) - - for key_info, lst in reward_extra_infos_dict.items(): - assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}" - - data_sources = np.concatenate(data_source_lst, axis=0) - print(f"DEBUG: Data sources shape: {data_sources.shape}") # Added Print - print(f"DEBUG: reward_extra_infos_dict keys before processing: {reward_extra_infos_dict.keys()}") # Added Print - - data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) - print( - f"DEBUG: Output of process_validation_metrics (data_src2var2metric2val): {data_src2var2metric2val}" - ) # Added Print - metric_dict = {} - for data_source, var2metric2val in data_src2var2metric2val.items(): - core_var = "acc" if "acc" in var2metric2val else "reward" - for var_name, metric2val in var2metric2val.items(): - n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) - for metric_name, metric_val in metric2val.items(): - if ( - (var_name == core_var) - and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best"]) - and (f"@{n_max}" in metric_name) - ): - metric_sec = "val-core" - else: - metric_sec = "val-aux" - pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" - metric_dict[pfx] = metric_val - - return metric_dict - - def init_workers(self): - """Init resource pool and worker group""" - self.resource_pool_manager.create_resource_pool() - - self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} - - # create actor and rollout - if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], - config=self.config.actor_rollout_ref, - role="actor_rollout", - ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls - else: - raise NotImplementedError - - # create critic - if self.use_critic: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls - - # create reference policy if needed - if self.use_reference_policy: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) - ref_policy_cls = RayClassWithInitArgs( - self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, role="ref" - ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls - - # create a reward model if reward_fn is None - if self.use_rm: - # we create a RM here - resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) - self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls - - # initialize WorkerGroup - # NOTE: if you want to use a different resource pool for each role, which can support different - # parallel size, - # you should not use `create_colocated_worker_cls`. Instead, directly pass different resource pool to - # different worker groups. - # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. - all_wg = {} - self.wg_dicts = [] - wg_kwargs = {} # Setting up kwargs for RayWorkerGroup - if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: - wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout - - for resource_pool, class_dict in self.resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, - ray_cls_with_init=worker_dict_cls, - device_name=self.device_name, - **wg_kwargs, - ) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - # keep the referece of WorkerDict to support ray >= 2.31. Ref: https://github.com/ray-project/ray/pull/45699 - self.wg_dicts.append(wg_dict) - - if self.use_critic: - self.critic_wg = all_wg["critic"] - self.critic_wg.init_model() - - if self.use_reference_policy: - self.ref_policy_wg = all_wg["ref"] - self.ref_policy_wg.init_model() - - if self.use_rm: - self.rm_wg = all_wg["rm"] - self.rm_wg.init_model() - - # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor_rollout"] - self.actor_rollout_wg.init_model() - - def _save_checkpoint(self): - # path: given_path + `/global_step_{global_steps}` + `/actor` - local_global_step_folder = os.path.join( - self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" - ) - - print(f"local_global_step_folder: {local_global_step_folder}") - actor_local_path = os.path.join(local_global_step_folder, "actor") - - actor_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor") - ) - - remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) - if remove_previous_ckpt_in_save: - print( - "Warning: remove_previous_ckpt_in_save is deprecated, set max_actor_ckpt_to_keep=1 and " - "max_critic_ckpt_to_keep=1 instead" - ) - max_actor_ckpt_to_keep = ( - self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - ) - max_critic_ckpt_to_keep = ( - self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 - ) - - self.actor_rollout_wg.save_checkpoint( - actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep - ) - - if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, "critic") - critic_remote_path = ( - None - if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") - ) - self.critic_wg.save_checkpoint( - critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep - ) - - # save dataloader - dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") - dataloader_state_dict = self.train_dataloader.state_dict() - torch.save(dataloader_state_dict, dataloader_local_path) - - # latest checkpointed iteration tracker (for atomic usage) - local_latest_checkpointed_iteration = os.path.join( - self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" - ) - with open(local_latest_checkpointed_iteration, "w") as f: - f.write(str(self.global_steps)) - - def _load_checkpoint(self): - if self.config.trainer.resume_mode == "disable": - return 0 - - # load from hdfs - if self.config.trainer.default_hdfs_dir is not None: - raise NotImplementedError("load from hdfs is not implemented yet") - else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path - if not os.path.isabs(checkpoint_folder): - working_dir = os.getcwd() - checkpoint_folder = os.path.join(working_dir, checkpoint_folder) - global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest - - # find global_step_folder - if self.config.trainer.resume_mode == "auto": - if global_step_folder is None: - print("Training from scratch") - return 0 - else: - if self.config.trainer.resume_mode == "resume_path": - assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" - assert "global_step_" in self.config.trainer.resume_from_path, ( - "resume ckpt must specify the global_steps" - ) - global_step_folder = self.config.trainer.resume_from_path - if not os.path.isabs(global_step_folder): - working_dir = os.getcwd() - global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") - # set global step - self.global_steps = int(global_step_folder.split("global_step_")[-1]) - - print(f"Setting global step to {self.global_steps}") - print(f"Resuming from {global_step_folder}") - - actor_path = os.path.join(global_step_folder, "actor") - critic_path = os.path.join(global_step_folder, "critic") - # load actor - self.actor_rollout_wg.load_checkpoint( - actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - # load critic - if self.use_critic: - self.critic_wg.load_checkpoint( - critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load - ) - - # load dataloader, - # TODO: from remote not implemented yet - dataloader_local_path = os.path.join(global_step_folder, "data.pt") - if os.path.exists(dataloader_local_path): - dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) - self.train_dataloader.load_state_dict(dataloader_state_dict) - else: - print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch") - - def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): - """Reorder the data on single controller such that each dp rank gets similar total tokens""" - attention_mask = batch.batch["attention_mask"] - batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) - world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) - # reorder based on index. The data will be automatically equally partitioned by dispatch function - global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) - batch.reorder(global_idx) - global_balance_stats = log_seqlen_unbalance( - seqlen_list=global_seqlen_lst, partitions=global_partition_lst, prefix=logging_prefix - ) - metrics.update(global_balance_stats) - - def fit_dpo(self): # Renamed for clarity as standard PPO loop - """ - The training loop of Online DPO using a periodically updated reference model. - The driver process calls worker groups for computation. - Advantage computation is replaced by DPO logic. - """ - import traceback # Ensure traceback is imported - - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - # Initialize logger - logger = None - try: - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True, throw_on_missing=False), - ) - except Exception as e: - print(f"Warning: Failed to initialize logger: {e}") - - self.global_steps = 0 - # Load checkpoint before doing anything - loaded_step = self._load_checkpoint() - self.global_steps = loaded_step + 1 if loaded_step is not None and loaded_step > 0 else 1 - print( - f"Starting Online DPO training from global step {self.global_steps}. " - f"Total steps: {self.total_training_steps}" - ) - print(f"Reference model update frequency: {self.config.trainer.get('ref_update_freq', 'Not Set')}") - - # Check if reference policy is configured correctly for this mode - if not self.use_reference_policy: - print( - "WARNING: 'use_reference_policy' is False. Periodic reference model update requires a " - "reference policy worker. DPO updates might fail or use incorrect logic." - ) - # Consider raising an error if strict adherence is required: - # raise ValueError("Periodic reference model update requires 'use_reference_policy' to be True " - # "and a configured reference worker.") - - # Perform validation before training - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - print("Running validation before Online DPO training...") - val_metrics = self._validate() - pprint(f"Initial validation metrics: {val_metrics}") - if logger and val_metrics: - logger.log(data=val_metrics, step=max(0, self.global_steps - 1)) - if self.config.trainer.get("val_only", False): - print("Validation only mode enabled. Exiting training.") - if logger and hasattr(logger, "finish"): - logger.finish() - return - - # Add tqdm progress bar - progress_bar = tqdm( - total=self.total_training_steps, - initial=self.global_steps, - desc="Online DPO Training Progress", - position=0, - leave=True, - ) - - last_val_metrics = None - should_stop = False - - for epoch in range(self.config.trainer.total_epochs): - if should_stop: - break - print(f"--- Starting Online DPO Epoch {epoch} ---") - try: - train_iterator = iter(self.train_dataloader) - except TypeError: - print("Warning: Dataloader is not iterable.") - train_iterator = self.train_dataloader # Fallback attempt - - for batch_idx, batch_dict in enumerate(train_iterator): - if self.global_steps > self.total_training_steps: - should_stop = True - break - - metrics = {} - timing_raw = {} - step_timer = Timer(logger=None) - ref_log_prob_computed = False # Flag to track if ref log probs were computed - - try: # Outer try-except for the whole step - step_timer.start() - with _timer("step", timing_raw): - batch: DataProto = DataProto.from_single_dict(batch_dict) - current_batch_size = batch.batch.batch_size[0] - print( - f"\n[Step {self.global_steps}, Batch {batch_idx}] Processing batch size: " - f"{current_batch_size}" - ) - - # --- Reference Model Update --- - ref_update_freq = self.config.trainer.get("ref_update_freq", -1) - if ( - self.use_reference_policy - and ref_update_freq > 0 - and self.global_steps % ref_update_freq == 0 - ): - print(f"\n[Step {self.global_steps}] Updating Reference Model Weights from Actor...") - try: - # --- This requires careful implementation with FSDP --- - # 1. Save actor state dict (potentially to CPU memory or disk) - # This needs to be done collectively across actor worker ranks. - # The checkpoint_manager might be adaptable, or use FSDP APIs directly. - # Example placeholder using a conceptual save/load mechanism: - actor_state_path = "/tmp/actor_state_mid" # Temporary path - self.actor_rollout_wg.save_checkpoint(actor_state_path) # Adapt save logic - - # 2. Load the state dict onto the reference model worker group - # This also needs collective loading on the ref worker ranks. - self.ref_policy_wg.load_checkpoint(actor_state_path, None, True) # Adapt load logic - - print(f"[Step {self.global_steps}] Reference Model Weights Updated.") - # Optionally remove the temporary state file - # os.remove(actor_state_path) # Needs rank-aware removal or shared storage - - except Exception as sync_e: - print(f"ERROR during reference model sync at step {self.global_steps}: {sync_e}") - traceback.print_exc() - - # Pop keys for generation - pop_batch_keys = ["input_ids", "attention_mask"] - if "position_ids" in batch.batch: - pop_batch_keys.append("position_ids") - pop_non_tensor_keys = ["raw_prompt_ids"] if "raw_prompt_ids" in batch.non_tensor_batch else [] - if "multi_modal_inputs" in batch.non_tensor_batch.keys(): - pop_non_tensor_keys.extend(["multi_modal_data", "multi_modal_inputs"]) - original_non_tensor_data = batch.non_tensor_batch - gen_batch = batch.pop( - batch_keys=pop_batch_keys, - non_tensor_batch_keys=pop_non_tensor_keys, - ) - gen_batch = gen_batch.repeat( - repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True - ) - # (Add Debug prints for gen_batch if needed) - - # Generate sequences (chosen/rejected pairs) - with _timer("gen", timing_raw): - try: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - # (Add Debug prints for gen_batch_output if needed) - except Exception as gen_e: - print(f"\n!!!!!!!! ERROR DURING GENERATION (Step {self.global_steps}) !!!!!!!!") - print(gen_e) - traceback.print_exc() - print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - step_timer.stop() - continue - - # Combine original prompts with generated sequences - batch.non_tensor_batch = original_non_tensor_data # Restore non-tensor data - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(current_batch_size)], dtype=object - ) - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - batch = batch.union(gen_batch_output) - # (Add Debug prints after union if needed) - - # Compute response mask (needed for ref logprob calc and DPO prep) - batch.batch["response_mask"] = compute_response_mask(batch) - - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - # --- Compute Log Probs for the CURRENT policy (used for KL if enabled, or ActorAsRef - # fallback) --- - # Note: For pure DPO with external ref, this 'old_log_probs' might not be strictly needed - # unless used for other metrics or a fallback. Keep it for now. - with _timer("policy_log_prob", timing_raw): - policy_log_prob_output = self.actor_rollout_wg.compute_log_prob(batch) - batch = batch.union(policy_log_prob_output) # Adds 'old_log_probs' - # (Debug prints for old_log_probs) - - # --- Compute Log Probs using the EXTERNAL Reference Model --- - if self.use_reference_policy: - with _timer("ref_log_prob_dpo", timing_raw): - # print(f"---- [Step {self.global_steps}] DEBUG DPO: Calling compute_ref_log_prob ----") - try: - # 'batch' contains interleaved chosen/rejected sequences - ref_log_prob_output = self.ref_policy_wg.compute_ref_log_prob( - batch - ) # Returns DataProto with 'ref_log_prob' - batch = batch.union( - ref_log_prob_output - ) # Adds 'ref_log_prob' key [batch_size * n, seq_len] - ref_log_prob_computed = True # Mark success - # print(f"---- [Step {self.global_steps}] DEBUG DPO: ref_log_prob tensor shape: " - # f"{batch.batch['ref_log_prob'].shape} ----") - except Exception as ref_e: - print(f"ERROR computing reference log probs at step {self.global_steps}: {ref_e}") - traceback.print_exc() - batch.batch["ref_log_prob"] = None # Mark as failed - ref_log_prob_computed = False - else: - print( - "Warning: Skipping external reference log prob calculation as use_reference_policy " - "is False." - ) - # DPO update will likely fail unless ActorAsRef logic is re-enabled in dp_actor - - # --- Compute Rewards/Scores (used to determine preference) --- - with _timer("reward_calc", timing_raw): - # (Reward calculation logic using RM or reward_fn as before) - # ... Ensure this calculates 'token_level_rewards' or similar ... - if self.use_rm: - reward_tensor_rm = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor_rm) # Adds 'rm_scores' - - reward_extra_infos_dict = {} - try: - if self.reward_fn is None: - # print(f"---- [DEBUG Step {self.global_steps}] ERROR: self.reward_fn is None! " - # f"Using dummy rewards. ----") - # Use rm_scores if available, otherwise zeros - reward_tensor = batch.batch.get( - "rm_scores", torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32) - ) - else: - reward_result = self.reward_fn(batch, return_dict=True) - reward_tensor = reward_result["reward_tensor"] # Final combined reward - reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) - - except Exception: - # print(f'---- [DEBUG Step {self.global_steps}] Error in reward_fn call: {e}. ' - # f'Using dummy rewards. ----') - traceback.print_exc() - reward_tensor = torch.zeros_like(batch.batch["response_mask"], dtype=torch.float32) - reward_extra_infos_dict = {} - - # Use 'token_level_rewards' as the key for preference calculation - batch.batch["token_level_rewards"] = reward_tensor - if reward_extra_infos_dict: - batch.non_tensor_batch.update( - {k: np.array(v) for k, v in reward_extra_infos_dict.items()} - ) - - # --- Determine Preferences --- - # Uses 'token_level_rewards' to determine chosen/rejected based on score - batch = compute_onlineDPO_pref(batch) # Adds 'preferences' key - - # --- Prepare DPO Batch --- - dpo_update_batch_proto = None # Initialize - with _timer("prepare_dpo_batch", timing_raw): - try: - if "preferences" not in batch.batch or batch.batch["preferences"] is None: - raise ValueError("'preferences' key missing or None after compute_onlineDPO_pref.") - - # Check if reference log probs were computed successfully (if needed) - if self.use_reference_policy and not ref_log_prob_computed: - raise ValueError("Reference log probs required but failed to compute.") - - # Check required base keys - required_keys = ["input_ids", "attention_mask", "response_mask"] - for rk in required_keys: - if rk not in batch.batch or batch.batch[rk] is None: - raise KeyError(f"Required key '{rk}' missing from batch for DPO prep.") - - preferences_mask = batch.batch["preferences"] # Shape [batch_size * n] - not_preferences_mask = ~preferences_mask - - # Gather Chosen/Rejected Base Tensors - chosen_input_ids = batch.batch["input_ids"][preferences_mask] - chosen_attention_mask = batch.batch["attention_mask"][preferences_mask] - rejected_input_ids = batch.batch["input_ids"][not_preferences_mask] - rejected_attention_mask = batch.batch["attention_mask"][not_preferences_mask] - chosen_position_ids = ( - batch.batch.get("position_ids")[preferences_mask] - if "position_ids" in batch.batch - else None - ) - rejected_position_ids = ( - batch.batch.get("position_ids")[not_preferences_mask] - if "position_ids" in batch.batch - else None - ) - - # Create Labels - print("WARNING: Creating DPO labels using configured max_prompt_length...") - prompt_len = self.config.data.max_prompt_length - chosen_labels = chosen_input_ids.clone() - chosen_labels[:, :prompt_len] = -100 - rejected_labels = rejected_input_ids.clone() - rejected_labels[:, :prompt_len] = -100 - - # Calculate and Gather Reference Log Probs (Sequence Level) - if self.use_reference_policy: - ref_log_prob_tensor = batch.batch["ref_log_prob"] # Token level [bsz * n, seq_len] - response_mask_full = batch.batch[ - "response_mask" - ] # Response mask [bsz * n, seq_len] - ref_sequence_logps = (ref_log_prob_tensor * response_mask_full).sum( - dim=-1 - ) # Sequence level [bsz * n] - reference_chosen_logps = ref_sequence_logps[preferences_mask] - reference_rejected_logps = ref_sequence_logps[not_preferences_mask] - else: - # If not using external ref, DPO needs ActorAsRef logic in dp_actor - # We won't add the keys here, dp_actor will handle it (or fail if not modified) - print( - "Info: Not adding explicit reference logps to DPO batch " - "(use_reference_policy=False)." - ) - reference_chosen_logps = None # Explicitly None - reference_rejected_logps = None - - # Package Tensors - dpo_tensors = { - "chosen_input_ids": chosen_input_ids, - "chosen_attention_mask": chosen_attention_mask, - "chosen_labels": chosen_labels, - "rejected_input_ids": rejected_input_ids, - "rejected_attention_mask": rejected_attention_mask, - "rejected_labels": rejected_labels, - } - # Conditionally add reference logps if computed - if reference_chosen_logps is not None: - dpo_tensors["reference_chosen_logps"] = reference_chosen_logps - if reference_rejected_logps is not None: - dpo_tensors["reference_rejected_logps"] = reference_rejected_logps - # Add position ids if they exist - if chosen_position_ids is not None: - dpo_tensors["chosen_position_ids"] = chosen_position_ids - if rejected_position_ids is not None: - dpo_tensors["rejected_position_ids"] = rejected_position_ids - - # Prepare Meta Info - dpo_meta = { - "dpo_beta": OmegaConf.select(self.config.algorithm, "dpo_beta", default=0.1), - "dpo_loss_type": OmegaConf.select( - self.config.algorithm, "dpo_loss_type", default="sigmoid" - ), - "dpo_label_smoothing": OmegaConf.select( - self.config.algorithm, "dpo_label_smoothing", default=0.0 - ), - "use_reference_policy": self.use_reference_policy, - "reference_free": not self.use_reference_policy, # False if using external ref - "global_step": self.global_steps, - } - - dpo_update_batch_proto = DataProto.from_dict(tensors=dpo_tensors, meta_info=dpo_meta) - # print(f"---- [Step {self.global_steps}] DEBUG DPO: Prepared DPO Update Batch ----") - # print(f" Keys: {list(dpo_update_batch_proto.batch.keys())}") - # print(f" Meta Info: {dpo_meta}") - - except Exception as e_prep: - print(f"ERROR preparing DPO batch at step {self.global_steps}: {e_prep}") - traceback.print_exc() - dpo_update_batch_proto = None # Skip update on error - - # --- Actor Update Step --- - actor_output = None - if self.config.trainer.critic_warmup <= self.global_steps and dpo_update_batch_proto: - with _timer("update_actor", timing_raw): - # Pass the batch containing reference log probs (if computed) - # The modified update_actor_dpo expects them if reference_free=False - actor_output = self.actor_rollout_wg.update_actor_dpo(dpo_update_batch_proto) - if actor_output and "metrics" in actor_output.meta_info: - metrics.update(reduce_metrics(actor_output.meta_info["metrics"])) - elif dpo_update_batch_proto is None: - print( - f"Skipping actor update at step {self.global_steps} due to DPO batch preparation error." - ) - - # --- Validation and Saving --- - test_freq = OmegaConf.select(self.config.trainer, "test_freq", default=-1) - is_last_step = self.global_steps >= self.total_training_steps - if ( - self.val_reward_fn is not None - and test_freq > 0 - and (is_last_step or self.global_steps % test_freq == 0) - ): - print(f"\nRunning DPO validation at step {self.global_steps}...") - val_timing_raw = {} - with _timer("testing", val_timing_raw): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - if val_metrics: - metrics["time/validation_run"] = val_timing_raw.get("testing", 0) - metrics.update(val_metrics) - else: - print("Validation skipped or returned no metrics.") - - save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1) - if save_freq > 0 and (is_last_step or self.global_steps % save_freq == 0): - print(f"\nSaving DPO checkpoint at step {self.global_steps}...") - with _timer("save_checkpoint", timing_raw): - self._save_checkpoint() # Saves actor (and potentially critic if used elsewhere) - metrics["time/save_checkpoint"] = timing_raw.get("save_checkpoint", 0) - - # --- End main step timer context --- - - # --- Metrics calculation AFTER the 'step' timer block --- - metrics.update(compute_dpo_data_metrics(batch=batch)) # Use DPO-specific metrics - metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) - n_gpus = self.resource_pool_manager.get_n_gpus() - if "step" in timing_raw: - metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) - else: - print( - f"Warning: 'step' key missing from timing_raw at step {self.global_steps}. " - f"Skipping throughput." - ) - - step_timer.stop() - metrics["time/step"] = step_timer.last - - # Log metrics - log_freq = OmegaConf.select(self.config.trainer, "log_freq", default=1) - if logger and self.global_steps % log_freq == 0: - log_payload = metrics.copy() - # Add learning rate to log payload - if actor_output and "actor/lr" in metrics: - log_payload["actor/lr"] = metrics["actor/lr"] - - print(f"[Step {self.global_steps} DPO] Logging Step Payload Keys: {list(log_payload.keys())}") - try: - logger.log(data=log_payload, step=self.global_steps) - except Exception as e: - print(f"Logging failed at step {self.global_steps}: {e}") - - # Update progress bar - postfix_metrics = { - k: f"{v:.3f}" if isinstance(v, float) else v - for k, v in metrics.items() - if isinstance(v, int | float) - } - progress_bar.set_postfix(postfix_metrics) - - except Exception as step_e: - print(f"\n!!!!!!!! ERROR DURING DPO Step {self.global_steps} !!!!!!!!") - print(f"Caught Exception: {step_e}") - traceback.print_exc() - print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - step_timer.stop() - should_stop = True - break - - if is_last_step or should_stop: - print(f"Stopping DPO training at step {self.global_steps}.") - break - - self.global_steps += 1 - progress_bar.update(1) - - # End of epoch handling - if hasattr(self.train_dataloader, "reset"): - try: - self.train_dataloader.reset() - except Exception as e: - print(f"Warning: Failed to reset train dataloader state: {e}") - if should_stop: - break - - # --- Final cleanup and logging --- - progress_bar.close() - final_step = max(0, self.global_steps - 1) - print(f"Online DPO Training finished at step {final_step}.") - # Save final checkpoint - save_freq = OmegaConf.select(self.config.trainer, "save_freq", default=-1) - if not self.config.trainer.get("val_only", False) and (save_freq <= 0 or final_step % save_freq != 0): - print(f"Saving final DPO checkpoint at step {final_step}...") - self._save_checkpoint() - - # Final validation run - if self.val_reward_fn and last_val_metrics is None and not self.config.trainer.get("val_only", False): - print("Running final validation...") - last_val_metrics = self._validate() - if last_val_metrics and logger: - last_val_metrics["final_validation"] = True - try: - logger.log(data=last_val_metrics, step=final_step) - except Exception as e: - print(f"[Final Val Metrics Log Error]: {e}") - - pprint(f"Final validation metrics: {last_val_metrics}") - if logger and hasattr(logger, "finish"): - logger.finish() - print("Online DPO Training Run Complete.") diff --git a/recipe/sppo/README.md b/recipe/sppo/README.md deleted file mode 100644 index f87efa853..000000000 --- a/recipe/sppo/README.md +++ /dev/null @@ -1,50 +0,0 @@ -# SPPO: Self-Play Preference Optimization for Language Model Alignment - -This repository hosts the community implementation for the paper [Self-Play Preference Optimization for Language Model Alignment](https://arxiv.org/abs/2405.00675). SPPO can significantly enhance the performance of an LLM without strong external signals such as responses or preferences from GPT-4. It can outperform the model trained with iterative direct preference optimization (DPO), among other methods. SPPO is theoretically grounded, ensuring that the LLM can converge to the von Neumann winner (i.e., Nash equilibrium) under general, potentially intransitive preference, and empirically validated through extensive evaluations on multiple datasets. - -Paper Authors: [Yue Wu](https://yuewu.us/)\*, [Zhiqing Sun](https://www.cs.cmu.edu/~zhiqings/)\*, [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ)\*, [Kaixuan Ji](https://scholar.google.com/citations?user=FOoKDukAAAAJ), [Yiming Yang](https://www.cs.cmu.edu/~yiming/), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) - -verl Implementation Authors: [Yuhao Yang](https://github.com/yhyang201), [Chenyang Zhao](https://github.com/zhaochenyang20) - -[[Webpage](https://uclaml.github.io/SPPO/)] [[Huggingface](https://huggingface.co/papers/2405.00675)] [[Paper](https://arxiv.org/abs/2405.00675)][[Original Implementation](https://github.com/uclaml/SPPO)] - -## Reproduce the Experiment - -We evaluate the performance of SPPO on the MATH dataset. Starting from an initial score of 46.6 with Qwen2.5-7B-Instruct, we achieve a score of 65.6 after 20 epochs of training, placing our model approximately in the top 20 on the [MATH leaderboard](https://paperswithcode.com/sota/math-word-problem-solving-on-math). It's important to note that verl's internal evaluation metrics may not perfectly align with the official evaluation methodology for Qwen2.5-7B-Instruct. Therefore, for consistency and fair comparison, we report only the results based on verl's evaluation framework. - -``` -git clone git@github.com:volcengine/verl.git -cd verl -python3 -m uv pip install -e ".[sglang]" - -export WANDB_API_KEY= - -python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math -huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct - -export CUDA_VISIBLE_DEVICES=0,1,2,3 -bash recipe/sppo/run_qwen2.5-7b_rm.sh -``` - -Note that the installation would occasionally fail to install flash-attn. If this happens, you can install it manually by running: - -```bash -python3 -m uv pip install wheel -python3 -m uv pip install packaging -python3 -m uv pip install flash-attn --no-build-isolation --no-deps -``` - -## Acknowledgement - -We sincerely thank the contribution and guidance from: - -- [Yue Wu](https://yuewu.us/) -- [Chendong Wang](https://cdwang96.github.io/) -- [Yifan Zhang](https://github.com/yifanzhang-pro) -- [Yongan Xiang](https://github.com/BearBiscuit05) -- [Junrong Lin](https://github.com/ocss884) -- [Yuxuan Tong](https://github.com/tongyx361) -- [Guangming Shen](https://github.com/PeterSH6) -- [Biao He](https://www.linkedin.com/in/biao-he/) -- [Qingquan Song](https://qingquansong.github.io/) -- [Quanquan Gu](https://web.cs.ucla.edu/~qgu/) diff --git a/recipe/sppo/__init__.py b/recipe/sppo/__init__.py deleted file mode 100644 index bc88468e3..000000000 --- a/recipe/sppo/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/recipe/sppo/config/sppo_trainer.yaml b/recipe/sppo/config/sppo_trainer.yaml deleted file mode 100644 index f127e1840..000000000 --- a/recipe/sppo/config/sppo_trainer.yaml +++ /dev/null @@ -1,28 +0,0 @@ -# the sppo config will override default ppo_trainer.yaml - -hydra: - searchpath: - - file://verl/trainer/config - -defaults: - - ppo_trainer - - _self_ - -actor_rollout_ref: - actor: - sppo_eta: 1.0 - optim: - lr_warmup_steps: 15 - rollout: - name: sglang - tensor_model_parallel_size: 2 - gpu_memory_utilization: 0.5 - val_kwargs: - n: 2 # 2 will trigger validation, 1 will bypass - -algorithm: - adv_estimator: null - sppo_eta: 1.0 - -trainer: - log_val_generations: 0 \ No newline at end of file diff --git a/recipe/sppo/dp_actor.py b/recipe/sppo/dp_actor.py deleted file mode 100644 index df14c0b4e..000000000 --- a/recipe/sppo/dp_actor.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os - -import torch - -import verl.utils.torch_functional as verl_F -from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, kl_penalty -from verl.utils.device import get_device_id -from verl.utils.profiler import GPUMemoryLogger -from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import rearrange_micro_batches -from verl.workers.actor.dp_actor import DataParallelPPOActor - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -def compute_sppo_loss( - old_log_prob: torch.Tensor, # (bs, seq_len) - log_prob: torch.Tensor, # (bs, seq_len) - rewards: torch.Tensor, # (bs,) - response_mask: torch.Tensor, # (bs, seq_len) - eta: float = 1.0, - loss_agg_mode: str = "token-mean", -): - """ - SPPO Loss computation. - """ - # Compute log-ratios over masked tokens - log_prob_sum = (log_prob * response_mask).sum(dim=1) # (bs,) - old_log_prob_sum = (old_log_prob * response_mask).sum(dim=1) # (bs,) - log_ratios = log_prob_sum - old_log_prob_sum # (bs,) - - scaled_rewards = eta * (rewards) - loss_vec = (log_ratios - scaled_rewards) ** 2 # (bs,) - - if loss_agg_mode == "token-mean": - sample_mask = response_mask.any(dim=1).float() # (bs,) - loss = verl_F.masked_mean(loss_vec, sample_mask) - - return loss, log_ratios, scaled_rewards - - -class DataParallelSPPOActor(DataParallelPPOActor): - @GPUMemoryLogger(role="dp actor", logger=logger) - def update_policy(self, data: DataProto): - # make sure we are in training mode - self.actor_module.train() - - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error - multi_turn = data.meta_info.get("multi_turn", False) - - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "seq_level_rewards"] - if multi_turn: - select_keys.append("loss_mask") - if self.config.use_kl_loss: - select_keys.append("ref_log_prob") - batch = data.select(batch_keys=select_keys).batch - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - - # Split to make minibatch iterator for updating the actor - # See PPO paper for details. https://arxiv.org/abs/1707.06347 - if has_multi_modal_inputs: - num_mini_batches = data.batch.batch_size[0] // self.config.ppo_mini_batch_size - non_tensor_select_keys = ["multi_modal_inputs"] - dataloader = data.select(select_keys, non_tensor_select_keys).chunk(num_mini_batches) - else: - dataloader = batch.split(self.config.ppo_mini_batch_size) - - metrics = {} - for epoch in range(self.config.ppo_epochs): - for batch_idx, data in enumerate(dataloader): - # split batch into micro_batches - mini_batch = data - if has_multi_modal_inputs: - self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - ) - num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys).chunk(num_micro_batches) - elif self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) - else: - self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - ) - # split batch into micro_batches - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - - self.actor_optimizer.zero_grad() - - for data in micro_batches: - # Support all hardwares - if isinstance(data, DataProto): - data = {**data.batch.to(get_device_id()), **data.non_tensor_batch} - else: - data = data.to(get_device_id()) # actor device is cpu when using offload - responses = data["responses"] - response_length = responses.size(1) - attention_mask = data["attention_mask"] - if multi_turn: - response_mask = data["loss_mask"][:, -response_length:] - else: - response_mask = attention_mask[:, -response_length:] - - old_log_prob = data["old_log_probs"] - rewards = data["seq_level_rewards"] - - entropy_coeff = self.config.entropy_coeff - loss_agg_mode = self.config.loss_agg_mode - eta = self.config.get("sppo_eta", 1.0) - - # all return: (bsz, response_length) - calculate_entropy = False - if entropy_coeff != 0: - calculate_entropy = True - entropy, log_prob = self._forward_micro_batch( - micro_batch=data, temperature=temperature, calculate_entropy=calculate_entropy - ) - - pg_loss, log_ratios, preference = compute_sppo_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - rewards=rewards, - response_mask=response_mask, - eta=eta, - loss_agg_mode=loss_agg_mode, - ) - - if entropy_coeff != 0: - entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - - # compute policy loss - policy_loss = pg_loss - entropy_loss * entropy_coeff - else: - policy_loss = pg_loss - - if self.config.use_kl_loss: - ref_log_prob = data["ref_log_prob"] - # compute kl loss - kld = kl_penalty( - logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type - ) - kl_loss = agg_loss( - loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode - ) - - policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef - metrics["actor/kl_loss"] = kl_loss.detach().item() - metrics["actor/kl_coef"] = self.config.kl_loss_coef - - if self.config.use_dynamic_bsz: - # relative to the dynamic bsz - loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) - else: - loss = policy_loss / self.gradient_accumulation - loss.backward() - - data = { - "actor/loss": loss.detach().item(), - "actor/log_ratio_mean": log_ratios.mean().detach().item(), - "actor/preference_mean": preference.mean().detach().item(), - } - append_to_dict(metrics, data) - - grad_norm = self._optimizer_step() - data = {"actor/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, data) - self.actor_optimizer.zero_grad() - return metrics diff --git a/recipe/sppo/main_sppo.py b/recipe/sppo/main_sppo.py deleted file mode 100644 index d99f4f2dc..000000000 --- a/recipe/sppo/main_sppo.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. -""" - -import os - -import hydra -import ray - -from verl.trainer.ppo.reward import load_reward_manager - -from .sppo_ray_trainer import RaySPPOTrainer - - -@hydra.main(config_path="config", config_name="sppo_trainer", version_base=None) -def main(config): - run_ppo(config) - - -def run_ppo(config) -> None: - # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices - # isolation, will solve in the future - os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "") - if not ray.is_initialized(): - # this is for local ray cluster - ray.init( - runtime_env={ - "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"} - }, - num_cpus=config.ray_init.num_cpus, - ) - - runner = TaskRunner.remote() - ray.get(runner.run.remote(config)) - - -@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head -class TaskRunner: - def run(self, config): - # print initial config - from pprint import pprint - - from omegaconf import OmegaConf - - from verl.utils.fs import copy_to_local - - pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - OmegaConf.resolve(config) - - # download the checkpoint from hdfs - local_path = copy_to_local(config.actor_rollout_ref.model.path) - - # instantiate tokenizer - from verl.utils import hf_processor, hf_tokenizer - - trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - processor = hf_processor(local_path, use_fast=True) # used for multimodal LLM, could be none - - # define worker classes - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: - assert config.critic.strategy in {"fsdp", "fsdp2"} - from verl.single_controller.ray import RayWorkerGroup - - from .sppo_worker import SPPOActorRolloutRefWorker # , CriticWorker - - actor_rollout_cls = SPPOActorRolloutRefWorker - ray_worker_group_cls = RayWorkerGroup - - elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker - - actor_rollout_cls = ActorRolloutRefWorker - ray_worker_group_cls = NVMegatronRayWorkerGroup - - else: - raise NotImplementedError - - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role - - # sppo does not use critic - role_worker_mapping = { - Role.ActorRollout: ray.remote(actor_rollout_cls), - } - - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - } - - # we should adopt a multi-source reward function here - # - for rule-based rm, we directly call a reward score - # - for model-based rm, we call a model - # - for code related prompt, we send to a sandbox if there are test cases - # - finally, we combine all the rewards together - # - The reward type depends on the tag of the data - if config.reward_model.enable: - if config.reward_model.strategy in {"fsdp", "fsdp2"}: - from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == "megatron": - from verl.workers.megatron_workers import RewardModelWorker - else: - raise NotImplementedError - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - mapping[Role.RewardModel] = global_pool_id - - # use reference model - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - role_worker_mapping[Role.RefPolicy] = ray.remote(SPPOActorRolloutRefWorker) - mapping[Role.RefPolicy] = global_pool_id - - reward_fn = load_reward_manager( - config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) - ) - val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - - trainer = RaySPPOTrainer( - config=config, - tokenizer=tokenizer, - processor=processor, - role_worker_mapping=role_worker_mapping, - resource_pool_manager=resource_pool_manager, - ray_worker_group_cls=ray_worker_group_cls, - reward_fn=reward_fn, - val_reward_fn=val_reward_fn, - device_name=config.trainer.device, - ) - trainer.init_workers() - trainer.fit() - - -if __name__ == "__main__": - main() diff --git a/recipe/sppo/run_qwen2.5-7b_rm.sh b/recipe/sppo/run_qwen2.5-7b_rm.sh deleted file mode 100755 index 1a4c02686..000000000 --- a/recipe/sppo/run_qwen2.5-7b_rm.sh +++ /dev/null @@ -1,56 +0,0 @@ -# Discliamer: the model used in the script is only for academic purpose. -set -x - -# Data preparation scripts are available in ``examples/data_preprocess``. -# Example usage: -# -# python3 examples/data_preprocess/math_dataset.py --local_dir ~/data/math -# python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k - -gsm8k_train_path=$HOME/data/math/train.parquet -gsm8k_test_path=$HOME/data/math/test.parquet - -train_files="['$gsm8k_train_path']" -test_files="['$gsm8k_test_path']" - -# prepare model ckpt -huggingface-cli download Qwen/Qwen2.5-7B-Instruct --local-dir $HOME/models/Qwen2.5-7B-Instruct & -# huggingface-cli download sfairXC/FsfairX-LLaMA3-RM-v0.1 --local-dir $HOME/models/FsfairX-LLaMA3-RM-v0.1 & -wait - -python3 -m recipe.sppo.main_sppo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path="$HOME/models/Qwen2.5-7B-Instruct" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='sppo-sglang' \ - trainer.val_before_train=True \ - trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \ - trainer.n_gpus_per_node=4 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - trainer.total_epochs=1000 $@ - # Note that we set lr_warmup_steps = 15 in config/sppo_trainer.yaml - # The experiment will converge to 0.656 on MATH dataset after 20 epochs \ No newline at end of file diff --git a/recipe/sppo/sppo_ray_trainer.py b/recipe/sppo/sppo_ray_trainer.py deleted file mode 100644 index 15e2f9c40..000000000 --- a/recipe/sppo/sppo_ray_trainer.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -FSDP PPO Trainer with Ray-based single controller. -This trainer supports model-agonistic model initialization with huggingface -""" - -import uuid -from copy import deepcopy -from pprint import pprint -from typing import Optional - -import numpy as np -import ray -import torch -from torch.utils.data import Dataset, Sampler -from tqdm import tqdm - -from verl import DataProto -from verl.single_controller.ray import RayWorkerGroup -from verl.trainer.ppo import core_algos -from verl.trainer.ppo.core_algos import agg_loss -from verl.trainer.ppo.metric_utils import reduce_metrics -from verl.trainer.ppo.ray_trainer import ( - AdvantageEstimator, - RayPPOTrainer, - ResourcePoolManager, - Role, - WorkerType, - apply_kl_penalty, - compute_response_mask, -) -from verl.trainer.ppo.reward import compute_reward, compute_reward_async -from verl.utils.profiler.performance import simple_timer -from verl.utils.tracking import ValidationGenerationsLogger - - -def softmean(x: torch.Tensor, beta: float, dim: int = -1, keepdim: bool = False) -> torch.Tensor: - """ - Compute SoftMean_β(x) = (1/β) * log( (1/n) * Σ exp(β * x_i) ) - Falls back to arithmetic mean when β=0. - """ - if beta == 0.0: - return x.mean(dim=dim, keepdim=keepdim) - - # cast beta to tensor on same device/dtype - beta_t = x.new_tensor(beta) - # numerically-stable logsumexp(β x) - lse = torch.logsumexp(x * beta_t, dim=dim, keepdim=keepdim) - n = x.size(dim) - log_n = x.new_tensor(n).log() - - return (lse - log_n) / beta_t - - -def compute_advantage(data: DataProto, beta=1.0): - rewards = data.batch["token_level_rewards"].sum(axis=-1) # (bs, ) - s_mean = softmean(rewards, beta, keepdim=True) # (bs, ) - rewards = rewards - s_mean # (bs, ) - data.batch["seq_level_rewards"] = rewards # (bs, ) - return data - - -class RaySPPOTrainer(RayPPOTrainer): - """ - Note that this trainer runs on the driver process on a single CPU/GPU node. - """ - - # TODO: support each role have individual ray_worker_group_cls, - # i.e., support different backend of different role - def __init__( - self, - config, - tokenizer, - role_worker_mapping: dict[Role, WorkerType], - resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, - processor=None, - reward_fn=None, - val_reward_fn=None, - train_dataset: Optional[Dataset] = None, - val_dataset: Optional[Dataset] = None, - collate_fn=None, - train_sampler: Optional[Sampler] = None, - device_name="cuda", - ): - self.tokenizer = tokenizer - self.processor = processor - self.config = config - self.reward_fn = reward_fn - self.val_reward_fn = val_reward_fn - - self.hybrid_engine = config.actor_rollout_ref.hybrid_engine - assert self.hybrid_engine, "Currently, only support hybrid engine" - - if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" - - self.role_worker_mapping = role_worker_mapping - self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = Role.RefPolicy in role_worker_mapping - self.use_rm = Role.RewardModel in role_worker_mapping - self.ray_worker_group_cls = ray_worker_group_cls - self.validation_generations_logger = ValidationGenerationsLogger() - self.device_name = device_name - - # define in-reward KL control - # kl loss control currently not supported - if config.algorithm.use_kl_in_reward: - self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl) - - self.use_critic = False - - self._validate_config() - self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - - def fit(self): - """ - The training loop of PPO. - The driver process only need to call the compute functions of the - worker group through RPC to construct the PPO dataflow. - The light-weight advantage computation is done on the driver process. - """ - from omegaconf import OmegaConf - - from verl.utils.tracking import Tracking - - logger = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True), - ) - - self.global_steps = 0 - - # load checkpoint before doing anything - self._load_checkpoint() - - # perform validation before training - # currently, we only support validation using the reward_function. - if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - pprint(f"Initial validation metrics: {val_metrics}") - logger.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - - # add tqdm - progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") - - # we start from step 1 - self.global_steps += 1 - last_val_metrics = None - - for epoch in range(self.config.trainer.total_epochs): - for batch_dict in self.train_dataloader: - metrics = {} - timing_raw = {} - batch: DataProto = DataProto.from_single_dict(batch_dict) - - # pop those keys for generation - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] - if "multi_modal_data" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("multi_modal_data") - if "raw_prompt" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("raw_prompt") - if "tools_kwargs" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("tools_kwargs") - gen_batch = batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - - is_last_step = self.global_steps >= self.total_training_steps - - with simple_timer("step", timing_raw): - # generate a batch - with simple_timer("gen", timing_raw): - if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - else: - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) - timing_raw.update(gen_batch_output.meta_info["timing"]) - gen_batch_output.meta_info.pop("timing", None) - - if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: - with simple_timer("gen_max", timing_raw): - gen_baseline_batch = deepcopy(gen_batch) - gen_baseline_batch.meta_info["do_sample"] = False - gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) - - batch = batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(batch) - reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - - batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) - - batch.batch["reward_baselines"] = reward_baseline_tensor - - del gen_baseline_batch, gen_baseline_output - - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object - ) - # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - batch = batch.union(gen_batch_output) - - batch.batch["response_mask"] = compute_response_mask(batch) - # Balance the number of valid tokens across DP ranks. - # NOTE: This usually changes the order of data in the `batch`, - # which won't affect the advantage calculation (since it's based on uid), - # but might affect the loss calculation (due to the change of mini-batching). - # TODO: Decouple the DP balancing and mini-batching. - if self.config.trainer.balance_batch: - self._balance_batch(batch, metrics=metrics) - - # compute global_valid tokens - batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() - - with simple_timer("reward", timing_raw): - # compute reward model score - if self.use_rm: - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - if self.config.reward_model.launch_reward_fn_async: - future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer) - else: - reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) - - # recompute old_log_probs - with simple_timer("old_log_prob", timing_raw): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch["entropys"] - response_masks = batch.batch["response_mask"] - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} - metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) - - if self.use_reference_policy: - # compute reference log_prob - with simple_timer("ref", timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - if self.use_critic: - with simple_timer("values", timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with simple_timer("adv", timing_raw): - # we combine with rule-based rm - reward_extra_infos_dict: dict[str, list] - if self.config.reward_model.launch_reward_fn_async: - reward_tensor, reward_extra_infos_dict = ray.get(future_reward) - batch.batch["token_level_scores"] = reward_tensor - - if reward_extra_infos_dict: - batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) - - # compute rewards. apply_kl_penalty if available - if self.config.algorithm.use_kl_in_reward: - batch, kl_metrics = apply_kl_penalty( - batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty - ) - metrics.update(kl_metrics) - else: - batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - batch.batch["seq_level_rewards"] = batch.batch["token_level_scores"] - - beta = self.config.algorithm.sppo_eta - batch = compute_advantage(batch, beta=beta) - - # update critic - if self.use_critic: - with simple_timer("update_critic", timing_raw): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with simple_timer("update_actor", timing_raw): - batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) - metrics.update(actor_output_metrics) - - # Log rollout generations if enabled - rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) - if rollout_data_dir: - with simple_timer("dump_rollout_generations", timing_raw): - print(batch.batch.keys()) - inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) - scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() - self._dump_generations( - inputs=inputs, - outputs=outputs, - scores=scores, - reward_extra_infos_dict=reward_extra_infos_dict, - dump_path=rollout_data_dir, - ) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - ): - with simple_timer("testing", timing_raw): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - if self.config.trainer.save_freq > 0 and ( - is_last_step or self.global_steps % self.config.trainer.save_freq == 0 - ): - with simple_timer("save_checkpoint", timing_raw): - self._save_checkpoint() - - # training metrics - metrics.update( - { - "training/global_step": self.global_steps, - "training/epoch": epoch, - } - ) - - # TODO: make a canonical logger that supports various backend - logger.log(data=metrics, step=self.global_steps) - - if is_last_step: - pprint(f"Final validation metrics: {last_val_metrics}") - progress_bar.close() - return - - progress_bar.update(1) - self.global_steps += 1 diff --git a/recipe/sppo/sppo_worker.py b/recipe/sppo/sppo_worker.py deleted file mode 100644 index fbe3a6e48..000000000 --- a/recipe/sppo/sppo_worker.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os - -from omegaconf import open_dict - -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.flops_counter import FlopsCounter -from verl.utils.fsdp_utils import offload_fsdp_model_to_cpu, offload_fsdp_optimizer -from verl.utils.import_utils import import_external_libs -from verl.utils.profiler import log_gpu_memory_usage -from verl.workers.fsdp_workers import ActorRolloutRefWorker - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) - - -class SPPOActorRolloutRefWorker(ActorRolloutRefWorker): - """ - This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy - or a hybrid engine based on the config.rollout - """ - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - from .dp_actor import DataParallelSPPOActor - - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - - from omegaconf import OmegaConf - - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - - use_remove_padding = self.config.model.get("use_remove_padding", False) - use_fused_kernels = self.config.model.get("use_fused_kernels", False) - - if self._is_actor or self._is_rollout: - # we need the model for actor and rollout - if self._is_actor: - optim_config = self.config.actor.optim - fsdp_config = self.config.actor.fsdp_config - else: - optim_config = None - fsdp_config = OmegaConf.create() - self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = ( - self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=fsdp_config, - optim_config=optim_config, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, - enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="actor", - ) - ) - - # get the original unwrapped module - self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage("After offload actor model during init", logger=logger) - - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) - # load from checkpoint - if self._is_actor: - OmegaConf.set_struct(self.config.actor, True) - with open_dict(self.config.actor): - self.config.actor.use_remove_padding = use_remove_padding - self.config.actor.use_fused_kernels = use_fused_kernels - self.actor = DataParallelSPPOActor( - config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer - ) - - if self._is_rollout: - self.rollout, self.rollout_sharding_manager = self._build_rollout( - trust_remote_code=self.config.model.get("trust_remote_code", False) - ) - - if self._is_ref: - self.ref_module_fsdp = self._build_model_optimizer( - model_path=self.config.model.path, - fsdp_config=self.config.ref.fsdp_config, - optim_config=None, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="ref", - )[0] - OmegaConf.set_struct(self.config.ref, True) - with open_dict(self.config.ref): - self.config.ref.use_remove_padding = use_remove_padding - self.config.ref.use_fused_kernels = use_fused_kernels - self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) - - if self._is_actor: - self.flops_counter = FlopsCounter(self.actor_model_config) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.actor_module_fsdp, - optimizer=self.actor.actor_optimizer, - lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_config=self.config.actor.checkpoint, - ) diff --git a/requirements-npu.txt b/requirements-npu.txt deleted file mode 100644 index 7d0386937..000000000 --- a/requirements-npu.txt +++ /dev/null @@ -1,21 +0,0 @@ -# requirements.txt records the full set of dependencies for development -accelerate -codetiming -datasets -dill -hydra-core -numpy<2.0.0 -pandas -peft -pyarrow>=15.0.0 -pybind11 -pylatexenc -tensordict>=0.8.0,<=0.9.1,!=0.9.0 -transformers==4.52.4 -ray==2.46.0 -wandb -mathruler -torchdata -einops -qwen_vl_utils -torchvision==0.20.1 diff --git a/requirements.txt b/requirements.txt index 458f251a9..8da830646 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,10 +15,9 @@ pybind11 pylatexenc pre-commit ray[default] -tensordict>=0.8.0,<=0.9.1,!=0.9.0 +tensordict torchdata transformers -# vllm==0.8.4 wandb packaging>=20.0 uvicorn diff --git a/requirements_sglang.txt b/requirements_sglang.txt deleted file mode 100644 index ce9e7d536..000000000 --- a/requirements_sglang.txt +++ /dev/null @@ -1,22 +0,0 @@ -# requirements.txt records the full set of dependencies for development -accelerate -codetiming -datasets -dill -flash-attn -hydra-core -numpy<2.0.0 -pandas -peft -pyarrow>=19.0.0 -pybind11 -pylatexenc -ray[default]>=2.10 -tensordict>=0.8.0,<=0.9.1,!=0.9.0 -torchdata -torchvision -transformers -wandb -sglang[all]==0.4.6.post5 -torch-memory-saver>=0.0.5 -huggingface_hub diff --git a/recipe/langgraph_agent/__init__.py b/scripts/tools/__init__.py similarity index 100% rename from recipe/langgraph_agent/__init__.py rename to scripts/tools/__init__.py diff --git a/scripts/tools/change_tokenizer_config.py b/scripts/tools/change_tokenizer_config.py deleted file mode 100644 index a39d653f3..000000000 --- a/scripts/tools/change_tokenizer_config.py +++ /dev/null @@ -1,91 +0,0 @@ -import json -import os -import shutil -import argparse -from transformers import AutoTokenizer - -def modify_tokenizer_config(input_path, output_path): - """Create a modified copy of the tokenizer config in a new directory""" - # Copy the entire model directory to the new location - print(f"Copying {input_path} to {output_path}") - shutil.copytree(input_path, output_path, dirs_exist_ok=True) - print(f"Copied {input_path} to {output_path}") - # Work with the config in the new directory - config_path = os.path.join(output_path, "tokenizer_config.json") - - # Read the original config - with open(config_path, 'r') as f: - config = json.load(f) - - # 1. Change the EOS token - config['eos_token'] = '<|im_end|>' - - # 2. Modify the chat template to change default system prompt - chat_template = config['chat_template'] - - # Replace the default system prompt - new_system_prompt = "You are a helpful assistant. To answer a query from the user, please first thinks through the question step-by-step inside ..., then provides the final response to user." - - # Update the system prompt in both conditions (with and without tools) - chat_template = chat_template.replace( - "You are a helpful assistant.", - new_system_prompt - ) - - # 3. Add tag after assistant generation prompt - chat_template = chat_template.replace( - "{{- '<|im_start|>assistant\\n' }}", - "{{- '<|im_start|>assistant\\n' }}" - ) - - config['chat_template'] = chat_template - - # Save the modified config in the new directory - with open(config_path, 'w') as f: - json.dump(config, f, indent=2) - - return output_path - -def test_chat_template(model_path): - """Test the modified chat template""" - tokenizer = AutoTokenizer.from_pretrained(model_path) - - # Example conversation - messages = [ - {"role": "user", "content": "What is 2+2?"} - ] - - # Apply chat template - prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - print("\nGenerated prompt:") - print(prompt) - -def parse_args(): - parser = argparse.ArgumentParser( - description='Modify tokenizer config with think tags and copy to new directory' - ) - parser.add_argument( - '--input', '-i', - required=True, - help='Input model directory path' - ) - parser.add_argument( - '--output', '-o', - help='Output directory path (default: input_path-think)', - ) - return parser.parse_args() - -if __name__ == "__main__": - args = parse_args() - if args.input.endswith('/'): - args.input = args.input[:-1] - # If output path not specified, create default - if args.output is None: - args.output = f"{args.input}-think" - - print(f"Modifying tokenizer config for {args.input} and saving to {args.output}") - new_path = modify_tokenizer_config(args.input, args.output) - print(f"Modified model saved to: {new_path}") - - # Test the modified chat template - test_chat_template(new_path) \ No newline at end of file diff --git a/scripts/tools/check_gpu.sh b/scripts/tools/check_gpu.sh deleted file mode 100755 index ceb03b95f..000000000 --- a/scripts/tools/check_gpu.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env bash - -# This script runs on a "single node" to check if there are "other users'" processes occupying the GPU - -MY_UID=$(id -u) -gpu_processes=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader) - -# If there are no GPU processes, return 0 (indicating success) -if [ -z "$gpu_processes" ]; then - echo "No GPU processes found on node $(hostname)." - exit 0 -fi - -# If processes exist, check each one -echo "GPU processes found on node $(hostname):" -while IFS="" read -r pid; do - pid=$(echo "$pid" | xargs) - [ -z "$pid" ] && continue - - # Check process owner - user_of_pid=$(ps -o user= -p "$pid" 2>/dev/null | xargs) - # Process might have exited during query - if [ -z "$user_of_pid" ]; then - continue - fi - - # Get the user's UID - uid_of_pid=$(id -u "$user_of_pid" 2>/dev/null) - if [ -z "$uid_of_pid" ]; then - continue - fi - - # If UID doesn't match current script executor, it means someone else is using the GPU - if [ "$uid_of_pid" != "$MY_UID" ]; then - echo " PID=$pid, user=$user_of_pid (UID=$uid_of_pid) is using GPU!" - # Exit with non-zero status if someone else is using the GPU - exit 1 - fi -done <<< "$gpu_processes" - -# Exit normally if no other users were found after checking all processes -exit 0 \ No newline at end of file diff --git a/scripts/tools/clean_ckpt_training_state.sh b/scripts/tools/clean_ckpt_training_state.sh deleted file mode 100644 index 121acb946..000000000 --- a/scripts/tools/clean_ckpt_training_state.sh +++ /dev/null @@ -1,2 +0,0 @@ -# !/bin/bash - diff --git a/scripts/tools/converter_hf_to_mcore.py b/scripts/tools/converter_hf_to_mcore.py index 0183c1591..6e7cdf2b5 100644 --- a/scripts/tools/converter_hf_to_mcore.py +++ b/scripts/tools/converter_hf_to_mcore.py @@ -17,11 +17,21 @@ import os import warnings from contextlib import contextmanager +from importlib.metadata import version from typing import Any, Callable, ContextManager, Optional import numpy as np import torch import torch.distributed as dist + +try: + # NPU patch + import mindspeed.megatron_adaptor # noqa: F401 + from mindspeed.megatron_adaptor import repatch +except ImportError: + repatch = None + pass + from accelerate import init_empty_weights from megatron.core import dist_checkpointing from megatron.core import parallel_state as mpu @@ -29,6 +39,7 @@ from megatron.core.dist_checkpointing.serialization import StrictHandling from megatron.core.models.gpt.gpt_model import ModelType from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from packaging.version import Version from transformers import AutoConfig from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards @@ -50,6 +61,8 @@ def _init_args(): parser = argparse.ArgumentParser() parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model") + parser.add_argument("--pp_size", type=int, default=1, help="pipeline model parallel size") + parser.add_argument("--ep_size", type=int, default=1, help="expert model parallel size") parser.add_argument("--use_cpu_initialization", action="store_true", help="Whether to use cpu initialization") parser.add_argument("--test", action="store_true", help="Whether to test the conversion") parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust remote code") @@ -114,6 +127,8 @@ def convert_checkpoint_from_transformers_to_megatron( layer_start, layer_end = layer_start_end pp_rank = mpu.get_pipeline_model_parallel_rank() pp_size = mpu.get_pipeline_model_parallel_world_size() + ep_rank = mpu.get_expert_model_parallel_rank() + ep_size = mpu.get_expert_model_parallel_world_size() numel = 0 num_attention_heads = hf_config.num_attention_heads @@ -162,9 +177,19 @@ def convert_checkpoint_from_transformers_to_megatron( numel += safe_copy(hf_layer.mlp.gate.weight, layer.mlp.router.weight) for idx, hf_expert in enumerate(hf_layer.mlp.experts): + num_experts = len(hf_layer.mlp.experts) + num_local_experts = num_experts // ep_size + expert_idx_start = ep_rank * num_local_experts + expert_idx_end = (ep_rank + 1) * num_local_experts + if idx < expert_idx_start or idx >= expert_idx_end: + continue + local_expert_idx = idx - expert_idx_start + fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - numel += safe_copy(fc1_weight, layer.mlp.experts.linear_fc1._parameters[f"weight{idx}"]) - numel += safe_copy(hf_expert.down_proj.weight, layer.mlp.experts.linear_fc2._parameters[f"weight{idx}"]) + numel += safe_copy(fc1_weight, layer.mlp.experts.linear_fc1._parameters[f"weight{local_expert_idx}"]) + numel += safe_copy( + hf_expert.down_proj.weight, layer.mlp.experts.linear_fc2._parameters[f"weight{local_expert_idx}"] + ) if has_share_expert: numel += safe_copy(hf_layer.mlp.shared_expert_gate.weight, layer.mlp.shared_experts.gate_weight) @@ -204,7 +229,11 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel head_dim = hidden_size // num_attention_heads # 1. vision model - hfvision = hfmodel.visual + if Version(version("transformers")) < Version("4.52.0"): + print("Using transformers < 4.52 API to load vision model") + hfvision = hfmodel.visual + else: + hfvision = hfmodel.model.visual mgvision = mgmodel.vision_model vision_hidden_size = mgvision.config.hidden_size vision_num_query_groups = mgvision.config.num_query_groups @@ -255,13 +284,18 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel copied_numel += safe_copy(hfprojector.mlp[2].weight, mgprojector.encoder.linear_fc2.weight) copied_numel += safe_copy(hfprojector.mlp[2].bias, mgprojector.encoder.linear_fc2.bias) n_params = sum([t.numel() for t in hfvision.state_dict().values()]) - assert n_params == copied_numel + assert n_params == copied_numel, f"n_params={n_params} != copied_numel={copied_numel}" # 3. llm [just Qwen2] - hfllm = hfmodel.model + if Version(version("transformers")) < Version("4.52.0"): + print("Using transformers < 4.52 API to load llm") + hfllm = hfmodel.model + else: + hfllm = hfmodel.model.language_model mgllm = mgmodel.language_model copied_numel = 0 copied_numel += safe_copy(hfllm.embed_tokens.weight, mgllm.embedding.word_embeddings.weight) - for mglayer, hflayer in zip(mgllm.decoder.layers, hfllm.layers, strict=True): + layermaps = zip(mgllm.decoder.layers, hfllm.layers, strict=True) + for mglayer, hflayer in layermaps: copied_numel += safe_copy(hflayer.input_layernorm.weight, mglayer.self_attention.linear_qkv.layer_norm_weight) q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size) @@ -289,7 +323,7 @@ def convert_checkpoint_from_transformers_to_megatron_qwen2_5_vl(hfmodel, mgmodel n_params = sum([t.numel() for t in hfllm.state_dict().values()]) - assert n_params == copied_numel + assert n_params == copied_numel, f"n_params={n_params} != copied_numel={copied_numel}" @torch.inference_mode() @@ -307,6 +341,9 @@ def convert_checkpoint_from_transformers_to_megatron_dpskv3( numel: int = 0 pp_rank = mpu.get_pipeline_model_parallel_rank() pp_size = mpu.get_pipeline_model_parallel_world_size() + ep_rank = mpu.get_expert_model_parallel_rank() + ep_size = mpu.get_expert_model_parallel_world_size() + if pp_rank == 0: numel += safe_copy(hf_model.model.embed_tokens.weight, model.embedding.word_embeddings.weight) @@ -353,11 +390,20 @@ def convert_checkpoint_from_transformers_to_megatron_dpskv3( ) if tfconfig.moe_grouped_gemm: for i, hf_expert in enumerate(hf_layer.mlp.experts): + num_experts = len(hf_layer.mlp.experts) + num_local_experts = num_experts // ep_size + expert_idx_start = ep_rank * num_local_experts + expert_idx_end = (ep_rank + 1) * num_local_experts + if i < expert_idx_start or i >= expert_idx_end: + continue + local_expert_idx = i - expert_idx_start + fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight]) - linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, "weight" + str(i)) + linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, "weight" + str(local_expert_idx)) numel += safe_copy(fc1_weight, linear_fc1_weighti) - linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, "weight" + str(i)) - numel += safe_copy(hf_expert.down_proj.weight, linear_fc2_weighti) + linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, "weight" + str(local_expert_idx)) + numel_w2 = safe_copy(hf_expert.down_proj.weight, linear_fc2_weighti) + numel += numel_w2 else: for i, hf_expert in enumerate(hf_layer.mlp.experts): expert = layer.mlp.experts.local_experts[i] @@ -371,7 +417,10 @@ def convert_checkpoint_from_transformers_to_megatron_dpskv3( numel += safe_copy(shared_fc1_weight, layer.mlp.shared_experts.linear_fc1.weight) numel += safe_copy(hf_layer.mlp.shared_experts.down_proj.weight, layer.mlp.shared_experts.linear_fc2.weight) print(f"{pp_rank=} {global_layer_idx=} {layer_idx=} {numel=} numel this layer={numel - numel_cur}") - assert numel - numel_cur == sum([i.numel() for i in hf_layer.state_dict().values()]), "numel mismatch" + numel_hf_one_layer = sum([i.numel() for i in hf_layer.state_dict().values()]) + if hasattr(layer.mlp, "router"): + numel_hf_one_layer -= numel_w2 * 3 * len(hf_layer.mlp.experts) // ep_size * (ep_size - 1) + assert numel - numel_cur == numel_hf_one_layer, "numel mismatch" if pp_rank == pp_size - 1: numel += safe_copy(hf_model.model.norm.weight, model.decoder.final_layernorm.weight) @@ -393,7 +442,9 @@ def support_distributed_convert(hf_config: AutoConfig) -> bool: return False -def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False): +def convert_hf_to_mcore( + hf_model_path, output_path, pp_size=1, ep_size=1, use_cpu_initialization=False, test=False, trust_remote_code=False +): os.makedirs(output_path, exist_ok=True) if len(os.listdir(output_path)) > 0 and not test: print(f"Output path {output_path} is not empty, skipping conversion") @@ -408,28 +459,35 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False torch.distributed.init_process_group("nccl") - rank = dist.get_rank() local_rank = os.getenv("LOCAL_RANK", 0) world_size = dist.get_world_size() get_torch_device().set_device(f"{get_device_name()}:{local_rank}") + if ep_size * pp_size != world_size: + pp_size = world_size + print(f"pp_size is set to {pp_size}") mpu.initialize_model_parallel( tensor_model_parallel_size=1, - pipeline_model_parallel_size=world_size, + pipeline_model_parallel_size=pp_size, virtual_pipeline_model_parallel_size=None, context_parallel_size=1, - expert_model_parallel_size=1, + expert_model_parallel_size=ep_size, ) model_parallel_cuda_manual_seed(0) # init hf config - hf_config = AutoConfig.from_pretrained(hf_model_path) + hf_config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=trust_remote_code) print(hf_config, flush=True) + if repatch: + if hf_config.architectures[0] == "DeepseekV3ForCausalLM": + config_repatch = dict(multi_head_latent_attention=True) + repatch(config_repatch) + if world_size > 1 and not support_distributed_convert(hf_config): raise NotImplementedError(f"distributed conversion is not supported for {hf_config.architectures} yet.") - pipeline_shards = get_dynamic_pipeline_shards(hf_config.num_hidden_layers, world_size) + pipeline_shards = get_dynamic_pipeline_shards(hf_config.num_hidden_layers, pp_size) print(f"Pipeline shards: {pipeline_shards}", flush=True) tfconfig = hf_to_mcore_config( @@ -483,11 +541,13 @@ def megatron_model_provider(pre_process, post_process): ) hf_state_dict = hf_model.state_dict() + pp_rank = mpu.get_pipeline_model_parallel_rank() + # distributed convert if world_size > 1 and support_distributed_convert(hf_config): pipeline_cumsum = np.cumsum(pipeline_shards) - layer_start = 0 if rank == 0 else pipeline_cumsum[rank - 1] - layer_end = pipeline_cumsum[rank] + layer_start = 0 if pp_rank == 0 else pipeline_cumsum[pp_rank - 1] + layer_end = pipeline_cumsum[pp_rank] if "DeepseekV3ForCausalLM" in hf_config.architectures: numel_partial: int = convert_checkpoint_from_transformers_to_megatron_dpskv3( hf_model, model[0].module, hf_config, tfconfig=tfconfig, layer_start_end=(layer_start, layer_end) @@ -540,5 +600,11 @@ def megatron_model_provider(pre_process, post_process): if __name__ == "__main__": args = _init_args() convert_hf_to_mcore( - args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code + args.hf_model_path, + args.output_path, + args.pp_size, + args.ep_size, + args.use_cpu_initialization, + args.test, + args.trust_remote_code, ) diff --git a/scripts/tools/diagnose.py b/scripts/tools/diagnose.py index 174b1f9b5..cb78f9e5c 100644 --- a/scripts/tools/diagnose.py +++ b/scripts/tools/diagnose.py @@ -236,7 +236,7 @@ def _get_gpu_info(): } ) return gpu_count, gpu_info - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): print("Failed to execute nvidia-smi command.") return 0, [] diff --git a/scripts/tools/download_guru.py b/scripts/tools/download_guru.py deleted file mode 100644 index 1db87ec88..000000000 --- a/scripts/tools/download_guru.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -import shutil - -from huggingface_hub import hf_hub_download, list_repo_files - -REPO_ID = "LLM360/guru-RL-92k" -REPO_TYPE = "dataset" -LOCAL_DATA_DIR = "./data" - -all_files = list_repo_files(REPO_ID, repo_type=REPO_TYPE) -split_to_local = {"train": "train", "online_eval": "online_eval", "offline_eval": "offline_eval"} - - -def download_files_from_split(split, local_dir): - parquet_files = [f for f in all_files if f.startswith(f"{split}/") and f.endswith(".parquet")] - print(f"Downloading {len(parquet_files)} files to {local_dir}") - os.makedirs(local_dir, exist_ok=True) - for filename in parquet_files: - print(f"Downloading {filename} to {local_dir}") - hf_hub_download( - repo_id=REPO_ID, - repo_type=REPO_TYPE, - filename=filename, - local_dir=local_dir, - local_dir_use_symlinks=False, - ) - # Remove .cache under each split folder - cache_dir = os.path.join(local_dir, ".cache") - if os.path.exists(cache_dir): - shutil.rmtree(cache_dir) - - -for split, local_dir in split_to_local.items(): - # download_files_from_split(split, os.path.join(LOCAL_DATA_DIR, local_dir)) - download_files_from_split(split, LOCAL_DATA_DIR) diff --git a/scripts/tools/generate_trainer_config.sh b/scripts/tools/generate_trainer_config.sh new file mode 100755 index 000000000..a40f555fd --- /dev/null +++ b/scripts/tools/generate_trainer_config.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash +set -euox pipefail + + +# Define config specifications: "config_name:output_file:config_arg" +CONFIG_SPECS=( + "ppo_trainer:_generated_ppo_trainer.yaml:" + "ppo_megatron_trainer:_generated_ppo_megatron_trainer.yaml:--config-name=ppo_megatron_trainer.yaml" +) + +generate_config() { + local config_name="$1" + local output_file="$2" + local config_arg="$3" + + local target_cfg="verl/trainer/config/${output_file}" + local tmp_header=$(mktemp) + local tmp_cfg=$(mktemp) + + echo "# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh'" > "$tmp_header" + echo "# in which it invokes 'python3 scripts/print_cfg.py --cfg job ${config_arg}' to flatten the 'verl/trainer/config/${config_name}.yaml' config fields into a single file." >> "$tmp_header" + echo "# Do not modify this file directly." >> "$tmp_header" + echo "# The file is usually only for reference and never used." >> "$tmp_header" + echo "" >> "$tmp_header" + + python3 scripts/print_cfg.py --cfg job ${config_arg} > "$tmp_cfg" + + cat "$tmp_header" > "$target_cfg" + sed -n '/^actor_rollout_ref/,$p' "$tmp_cfg" >> "$target_cfg" + + rm "$tmp_cfg" "$tmp_header" + + echo "Generated: $target_cfg" +} + +for spec in "${CONFIG_SPECS[@]}"; do + IFS=':' read -r config_name output_file config_arg <<< "$spec" + generate_config "$config_name" "$output_file" "$config_arg" +done + +for spec in "${CONFIG_SPECS[@]}"; do + IFS=':' read -r config_name output_file config_arg <<< "$spec" + target_cfg="verl/trainer/config/${output_file}" + if ! git diff --exit-code -- "$target_cfg" >/dev/null; then + echo "✖ $target_cfg is out of date. Please regenerate via 'scripts/generate_trainer_config.sh' and commit the changes." + exit 1 + fi +done + +echo "All good" +exit 0 diff --git a/scripts/tools/init_random_model.py b/scripts/tools/init_random_model.py index 2804bc2a2..2bc3ffc1b 100644 --- a/scripts/tools/init_random_model.py +++ b/scripts/tools/init_random_model.py @@ -39,6 +39,11 @@ def _init_args(): parser.add_argument("--hf_model_path", type=str, required=True, help="The path for the huggingface model") parser.add_argument("--new_config_path", type=str, required=True, help="The path for the new config file") parser.add_argument("--output_path", type=str, required=True, help="The path for the output random model") + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="Whether to trust remote code when loading HF model. Disabled by default for security.", + ) args = parser.parse_args() return args @@ -69,9 +74,9 @@ def check_configs(original_config: dict[str, Any], new_config: dict[str, Any]) - ) -def init_random_model(hf_model_path, new_config_path, output_path): - config = AutoConfig.from_pretrained(hf_model_path) - tokenizer = AutoTokenizer.from_pretrained(hf_model_path) +def init_random_model(hf_model_path, new_config_path, output_path, trust_remote_code: bool = False): + config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=trust_remote_code) + tokenizer = AutoTokenizer.from_pretrained(hf_model_path, trust_remote_code=trust_remote_code) config_dict = PretrainedConfig.get_config_dict(hf_model_path)[0] print(config_dict) with open(new_config_path) as f: @@ -80,7 +85,12 @@ def init_random_model(hf_model_path, new_config_path, output_path): config_dict.update(new_config_dict) new_confg = config.from_dict(config_dict) print(f"new_config: {new_confg}") - model = AutoModelForCausalLM.from_config(new_confg) + if trust_remote_code: + model = AutoModelForCausalLM.from_pretrained( + hf_model_path, config=new_confg, trust_remote_code=trust_remote_code + ) + else: + model = AutoModelForCausalLM.from_config(new_confg) model.save_pretrained(output_path) tokenizer.save_pretrained(output_path) new_confg.save_pretrained(output_path) @@ -91,5 +101,8 @@ def init_random_model(hf_model_path, new_config_path, output_path): args = _init_args() check_output_path(args.output_path) init_random_model( - hf_model_path=args.hf_model_path, new_config_path=args.new_config_path, output_path=args.output_path + hf_model_path=args.hf_model_path, + new_config_path=args.new_config_path, + output_path=args.output_path, + trust_remote_code=args.trust_remote_code, ) diff --git a/scripts/tools/install_vllm_sglang_mcore.sh b/scripts/tools/install_vllm_sglang_mcore.sh index e80647694..4ac686764 100755 --- a/scripts/tools/install_vllm_sglang_mcore.sh +++ b/scripts/tools/install_vllm_sglang_mcore.sh @@ -7,36 +7,37 @@ export MAX_JOBS=32 echo "1. install inference frameworks and pytorch they need" if [ $USE_SGLANG -eq 1 ]; then - pip install "sglang[all]==0.4.6.post1" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir + pip install "sglang[all]==0.5.2" --no-cache-dir && pip install torch-memory-saver --no-cache-dir fi -pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata +pip install --no-cache-dir "vllm==0.11.0" echo "2. install basic packages" pip install "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \ - "numpy<2.0.0" "pyarrow>=15.0.0" pandas \ + "numpy<2.0.0" "pyarrow>=15.0.0" pandas "tensordict>=0.8.0,<=0.10.0,!=0.9.0" torchdata \ ray[default] codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \ - pytest py-spy pyext pre-commit ruff + pytest py-spy pre-commit ruff tensorboard + +echo "pyext is lack of maintainace and cannot work with python 3.12." +echo "if you need it for prime code rewarding, please install using patched fork:" +echo "pip install git+https://github.com/ShaohonChen/PyExt.git@py311support" pip install "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1" echo "3. install FlashAttention and FlashInfer" -# Install flash-attn-2.7.4.post1 (cxx11abi=False) -wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \ - pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +# Install flash-attn-2.8.1 (cxx11abi=False) +wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.8cxx11abiFALSE-cp312-cp312-linux_x86_64.whl && \ + pip install --no-cache-dir flash_attn-2.8.1+cu12torch2.8cxx11abiFALSE-cp312-cp312-linux_x86_64.whl -# Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False) -# vllm-0.8.3 does not support flashinfer>=0.2.3 -# see https://github.com/vllm-project/vllm/pull/15777 -wget -nv https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl && \ - pip install --no-cache-dir flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl +pip install --no-cache-dir flashinfer-python==0.3.1 if [ $USE_MEGATRON -eq 1 ]; then echo "4. install TransformerEngine and Megatron" echo "Notice that TransformerEngine installation can take very long time, please be patient" - NVTE_FRAMEWORK=pytorch pip3 install --no-deps git+https://github.com/NVIDIA/TransformerEngine.git@v2.2 - pip3 install --no-deps git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.0rc3 + pip install "onnxscript==0.3.1" + NVTE_FRAMEWORK=pytorch pip3 install --no-deps git+https://github.com/NVIDIA/TransformerEngine.git@v2.6 + pip3 install --no-deps git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.1 fi @@ -48,7 +49,7 @@ pip install opencv-fixer && \ if [ $USE_MEGATRON -eq 1 ]; then echo "6. Install cudnn python package (avoid being overridden)" - pip install nvidia-cudnn-cu12==9.8.0.87 + pip install nvidia-cudnn-cu12==9.10.2.21 fi echo "Successfully installed all packages" diff --git a/scripts/tools/legacy_model_merger.py b/scripts/tools/legacy_model_merger.py index 7049fc65d..a6da5072d 100644 --- a/scripts/tools/legacy_model_merger.py +++ b/scripts/tools/legacy_model_merger.py @@ -44,7 +44,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import Optional, Union import numpy as np import torch @@ -105,14 +105,33 @@ def __init__(self, config: ModelMergerConfig): ) self.hf_model_config_path = config.hf_model_path + # Auto-detect huggingface subdirectory if it exists + huggingface_subdir = os.path.join(self.hf_model_config_path, "huggingface") + if os.path.isdir(huggingface_subdir): + self.hf_model_config_path = huggingface_subdir + self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path) def get_transformers_auto_model_class(self): - if "ForTokenClassification" in self.model_config.architectures[0]: + # Handle case where architectures might be None or empty + if self.model_config.architectures is None or len(self.model_config.architectures) == 0: + # Try to infer from model_type if architectures is missing + model_type = getattr(self.model_config, 'model_type', '').lower() + if 'vision' in model_type or 'vl' in model_type: + return AutoModelForVision2Seq + elif 'causal' in model_type or 'gpt' in model_type or 'llama' in model_type or 'qwen' in model_type: + return AutoModelForCausalLM + else: + raise NotImplementedError( + f"Cannot determine model class: architectures is None and model_type '{model_type}' is not recognized" + ) + + architecture = self.model_config.architectures[0] + if "ForTokenClassification" in architecture: return AutoModelForTokenClassification - elif "ForCausalLM" in self.model_config.architectures[0]: + elif "ForCausalLM" in architecture: return AutoModelForCausalLM - elif "ForConditionalGeneration" in self.model_config.architectures[0]: + elif "ForConditionalGeneration" in architecture: return AutoModelForVision2Seq raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") @@ -207,7 +226,11 @@ def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): del model processor = hf_processor(self.hf_model_config_path) - tokenizer = hf_tokenizer(self.hf_model_config_path) + try: + tokenizer = hf_tokenizer(self.hf_model_config_path) + except Exception as e: + warnings.warn(f"Failed to create tokenizer: {e}. This may affect tokenizer saving", stacklevel=1) + tokenizer = None if processor is not None: print(f"Saving processor to {self.config.target_dir}") processor.save_pretrained(self.config.target_dir) @@ -235,7 +258,7 @@ def _get_world_size(self) -> int: if match: return int(match.group(1)) raise FileNotFoundError( - f"Could not determine world size. No file matching 'model_world_size_(\d+)_rank_0.pt' found in {self.config.local_dir}" + f"Could not determine world size. No file matching 'model_world_size_(\\d+)_rank_0.pt' found in {self.config.local_dir}" ) def _load_rank_zero_state_dict(self, world_size: int) -> dict: @@ -492,7 +515,7 @@ def _merge_across_tp( config: PretrainedConfig, tp_size: int, is_value_model: bool = False, - ) -> torch.Tensor | list[torch.Tensor]: + ) -> Union[torch.Tensor, list[torch.Tensor]]: if "linear_fc1.weight" in key: # if the tensor is gate and proj gate_lst = [] @@ -621,7 +644,7 @@ def _merge_state_dicts( state_dict[hf_name] = merged elif len(merged) == 3: # split qkv - for n, d in zip(["q", "k", "v"], merged, strict=False): + for n, d in zip(["q", "k", "v"], merged): state_dict[hf_name.replace("qkv", n)] = d elif len(merged) == 2: # split gate up diff --git a/scripts/tools/model_merger.py b/scripts/tools/model_merger.py deleted file mode 100644 index 3bd25cae2..000000000 --- a/scripts/tools/model_merger.py +++ /dev/null @@ -1,623 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. - -To merge FSDP checkpoints: -```sh -python scripts/model_merger.py merge \ - --backend fsdp \ - --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -To merge Megatron checkpoints: -```sh -python scripts/model_merger.py merge \ - --backend megatron \ - --tie-word-embedding \ - --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -For more details, please refer to documentation: -https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model -""" - -import argparse -import os -import re -from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional - -import numpy as np -import torch -from accelerate import init_empty_weights -from safetensors.torch import load_file -from torch.distributed._tensor import Placement, Shard -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForTokenClassification, - AutoModelForVision2Seq, - GenerationConfig, - PretrainedConfig, -) - -try: - # for torch 2.5+ - from torch.distributed.tensor import DTensor -except ImportError: - from torch.distributed._tensor import DTensor - -from tqdm import tqdm - -from verl.utils import hf_processor, hf_tokenizer - - -@dataclass -class ModelMergerConfig: - operation: str # 'merge' or 'test' - backend: str - local_dir: str - hf_model_config_path: str - target_dir: Optional[str] = "tmp" - hf_upload_path: Optional[str] = None - private: bool = False - test_hf_dir: Optional[str] = None - tie_word_embedding: bool = False - is_value_model: bool = False - hf_model_path: Optional[str] = None - hf_upload: bool = field(init=False) - - def __post_init__(self): - self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) - if self.operation == "test": - self.target_dir = None - self.hf_upload_path = None - self.private = False - - -class BaseModelMerger(ABC): - def __init__(self, config: ModelMergerConfig): - self.config = config - self.hf_model_config_path = config.hf_model_config_path - - if config.hf_model_path: - print("Warning: --hf_model_path is deprecated and will be removed in a future version. Currently verl will save huggingface model configuration files into checkpoint directories. Therefore, there is no need to provide --hf_model_path. ") - self.hf_model_config_path = config.hf_model_path - - self.model_config = AutoConfig.from_pretrained(self.hf_model_config_path) - - def get_transformers_auto_model_class(self): - if "ForTokenClassification" in self.model_config.architectures[0]: - return AutoModelForTokenClassification - elif "ForCausalLM" in self.model_config.architectures[0]: - return AutoModelForCausalLM - elif "ForConditionalGeneration" in self.model_config.architectures[0]: - return AutoModelForVision2Seq - - raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") - - def patch_model_generation_config(self, model): - """ - The generation_config created from model config may be different to the pretrained model, - this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 - - This function patch the generation_config created from model config to the pretrained model. - """ - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) - except OSError: - print(f"Warning: Generation config file not found in {self.hf_model_config_path}, using a generation config created from the model config.") - return model - - def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - with init_empty_weights(): - model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) - model.to_empty(device="cpu") - model = self.patch_model_generation_config(model) - - print(f"Saving model to {self.config.target_dir}") - model.save_pretrained(self.config.target_dir, state_dict=state_dict) - del state_dict - del model - - processor = hf_processor(self.hf_model_config_path) - tokenizer = hf_tokenizer(self.hf_model_config_path) - if processor is not None: - print(f"Saving processor to {self.config.target_dir}") - processor.save_pretrained(self.config.target_dir) - if tokenizer is not None: - print(f"Saving tokenizer to {self.config.target_dir}") - tokenizer.save_pretrained(self.config.target_dir) - - def upload_to_huggingface(self): - from huggingface_hub import HfApi - - api = HfApi() - api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) - api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") - - @abstractmethod - def merge_and_save(self): - raise NotImplementedError("Subclasses should implement this method") - - -class FSDPModelMerger(BaseModelMerger): - def _get_world_size(self) -> int: - """Extracts the FSDP world_size from checkpoint filenames (e.g., 'model_world_size_8_rank_0.pt').""" - for filename in os.listdir(self.config.local_dir): - match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename) - if match: - return int(match.group(1)) - raise FileNotFoundError(f"Could not determine world size. No file matching 'model_world_size_(\d+)_rank_0.pt' found in {self.config.local_dir}") - - def _load_rank_zero_state_dict(self, world_size: int) -> dict: - return torch.load(Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", map_location="cpu", weights_only=False) - - def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: - """ - Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. - If no DTensor is found, infers a simple FSDP mesh based on world_size. - """ - pivot_key = sorted(list(state_dict.keys()))[0] - weight = state_dict[pivot_key] - - if isinstance(weight, DTensor): - # get sharding info - device_mesh = weight.device_mesh - mesh = device_mesh.mesh - mesh_dim_names = device_mesh.mesh_dim_names - else: - # for non-DTensor - mesh = np.array([world_size], dtype=np.int64) - mesh_dim_names = ("fsdp",) - - return mesh, mesh_dim_names - - def _calculate_shard_configuration(self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...]) -> tuple[int, tuple[int, ...]]: - """Calculates the total number of shards and the shape of the device mesh.""" - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" - - if "tp" in mesh_dim_names: - # TODO: "tp" is not supported yet due to the above assert - total_shards = mesh.shape[-1] * mesh.shape[-2] - mesh_shape = (mesh.shape[-2], mesh.shape[-1]) - else: - total_shards = mesh.shape[-1] - mesh_shape = (mesh.shape[-1],) - - return total_shards, mesh_shape - - def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: - """Merges a list of tensors based on their DTensor placement""" - if placement.is_replicate(): - return tensors[0] - elif placement.is_partial(): - raise NotImplementedError("Partial placement is not supported yet") - elif placement.is_shard(): - return torch.cat(tensors, dim=placement.dim).contiguous() - - raise NotImplementedError(f"Unsupported placement: {placement}") - - def _load_and_merge_state_dicts(self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...]) -> dict[str, torch.Tensor]: - model_state_dict_lst = [None] * total_shards - - def process_one_shard(rank: int, model_state_dict_lst: list): - model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" - state_dict = torch.load(model_path, map_location="cpu", weights_only=False) - model_state_dict_lst[rank] = state_dict - return state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] - for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): - future.result() - - # Merge state dicts from all shards - state_dict = {} - param_placements: dict[str, list] = {} - - for key in set(model_state_dict_lst[0].keys()): - state_dict[key] = [] - for model_state_shard in model_state_dict_lst: - # add tensor shard in order of rank to state_dict[key] - tensor = model_state_shard.pop(key) - if isinstance(tensor, DTensor): - state_dict[key].append(tensor._local_tensor.bfloat16()) - - placements = tuple(tensor.placements) - # replicated placement at dp dimension can be discarded - if mesh_dim_names[0] in ("dp", "ddp"): - placements = placements[1:] - - if key not in param_placements: - param_placements[key] = placements - else: - assert param_placements[key] == placements - else: - state_dict[key].append(tensor.bfloat16()) - - del model_state_dict_lst - - # Merge tensors - for key in sorted(state_dict): - if not isinstance(state_dict[key], list): - print(f"No need to merge key {key}") - continue - if key in param_placements: - # merge shards - placements: tuple[Shard] = param_placements[key] - if len(mesh_shape) == 1: - # 1-D list, FSDP without TP - assert len(placements) == 1 - shards = state_dict[key] - state_dict[key] = self._merge_by_placement(shards, placements[0]) - else: - # 2-D list, FSDP + TP - raise NotImplementedError("FSDP + TP is not supported yet") - else: - state_dict[key] = torch.cat(state_dict[key], dim=0) - - return state_dict - - def merge_and_save(self): - world_size = self._get_world_size() - rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) - - mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) - print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - - total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) - print(f"Processing model shards with {total_shards} {mesh_shape} in total") - - merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._test_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - - hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) - hf_state_dict = hf_model.state_dict() - del hf_model - - hf_model_keys = set(hf_state_dict.keys()) - collected_keys = set(state_dict.keys()) - - missing_keys = hf_model_keys - collected_keys - assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" - - extra_keys = collected_keys - hf_model_keys - assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" - - for key in hf_model_keys: - hf_shape = hf_state_dict[key].shape - collected_shape = state_dict[key].shape - assert hf_shape == collected_shape, f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" - - hf_dtype = hf_state_dict[key].dtype - collected_dtype = state_dict[key].dtype - assert hf_dtype == collected_dtype, f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - - torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) - - print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") - - -class MegatronModelMerger(BaseModelMerger): - def __init__(self, config: ModelMergerConfig): - from verl.utils.megatron_utils import get_hf_config_and_tokenizer_checkpoint_path - - config.hf_model_config_path = get_hf_config_and_tokenizer_checkpoint_path(config.local_dir) - super().__init__(config) - - def _get_tp_pp_rank_from_sharded_dir(self, sharded_dir: str) -> tuple[int, int]: - match = re.match(r"mp_rank_(\d\d)_(\d\d\d)", sharded_dir) - assert match, f"Invalid sharded dir {sharded_dir}" - tp_rank = int(match.group(1)) - pp_rank = int(match.group(2)) - return tp_rank, pp_rank - - def _check_megatron_checkpoint_path(self, model_path: str) -> tuple[list[str], int, int]: - """ - Validates the Megatron checkpoint structure (presence of 'model.pt' in sharded directories). - Determines TP and PP sizes from directory names. - """ - tp_size = 0 - pp_size = 0 - sharded_dirs = sorted(os.listdir(model_path)) - for sharded_dir in sharded_dirs: - assert "model.pt" in os.listdir(Path(model_path) / sharded_dir), f"model.pt not found in {sharded_dir}" - tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) - tp_size = max(tp_size, tp_rank + 1) - pp_size = max(pp_size, pp_rank + 1) - return sharded_dirs, tp_size, pp_size - - def _merge_across_tp(self, key: str, tp_data: list[torch.Tensor], config: PretrainedConfig, tp_size: int, is_value_model: bool = False) -> torch.Tensor | list[torch.Tensor]: - if "linear_fc1.weight" in key: - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - for infer_param in tp_data: - gate, up = infer_param.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - return [gate, up] - - elif "self_attention.linear_qkv." in key and "layer_norm" not in key: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. - q_lst = [] - k_lst = [] - v_lst = [] - assert config.num_attention_heads % config.num_key_value_heads == 0 - num_q_per_kv = config.num_attention_heads // config.num_key_value_heads - assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0 - kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2) - split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] - - for infer_param in tp_data: - num_query_groups_per_partition = config.num_key_value_heads // tp_size - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - - q = torch.cat(q_lst, dim=0) - k = torch.cat(k_lst, dim=0) - v = torch.cat(v_lst, dim=0) - return [q, k, v] - - elif "layer_norm" in key or "layernorm" in key or "output_layer" in key and is_value_model: - return tp_data[0] - else: - dim = 0 - if "linear_fc2.weight" in key or "self_attention.linear_proj" in key: - dim = 1 - return torch.cat(tp_data, dim=dim) - - def _load_state_dicts(self, model_ckpt_path: str, sharded_dirs: list[str], tp_size: int, pp_size: int) -> list[list[dict]]: - model_state_dict_lst = [[None for _ in range(tp_size)] for _ in range(pp_size)] - - def _process_one_megatron_shard(sharded_dir: str): - model_file_path = Path(model_ckpt_path) / sharded_dir / "model.pt" - state_dict = torch.load(model_file_path, map_location="cpu", weights_only=False) - tp_rank, pp_rank = self._get_tp_pp_rank_from_sharded_dir(sharded_dir) - model_state_dict_lst[pp_rank][tp_rank] = state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(_process_one_megatron_shard, sharded_dir) for sharded_dir in sharded_dirs] - for future in tqdm(futures, desc=f"Loading {len(sharded_dirs)} Megatron shards", total=len(sharded_dirs)): - future.result() - - return model_state_dict_lst - - def _merge_state_dicts(self, model_state_dict_lst: list[list[dict]], tp_size: int, pp_size: int) -> dict[str, torch.Tensor]: - state_dict = {} - vpp_size = len(model_state_dict_lst[0][0]) - layers_cum = 0 - - for vpp_rank in range(vpp_size): - for pp_rank in range(pp_size): - layers_handled = 0 - keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys() - for key in keys: - if "extra_state" in key: - continue - if self.config.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") - continue - - new_key = key - if "decoder.layers." in key: - local_layer_no = int(key.split(".")[2]) - layers_handled = max(local_layer_no, layers_handled) - global_layer_no = local_layer_no + layers_cum - new_key_list = key.split(".") - new_key_list[2] = str(global_layer_no) - new_key = ".".join(new_key_list) - - tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)] - merged = self._merge_across_tp(new_key, tp_data, self.model_config, tp_size, self.config.is_value_model) - - if not isinstance(merged, list): - state_dict[new_key] = merged - elif len(merged) == 3: - # split qkv - for n, d in zip(["q", "k", "v"], merged): - state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d - elif len(merged) == 2: - # split gate up - state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0] - state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1] - - layers_cum += layers_handled + 1 # zero based - - return state_dict - - def merge_and_save(self): - from verl.utils.megatron_utils import get_model_checkpoint_path - - model_ckpt_path = get_model_checkpoint_path(self.config.local_dir) - sharded_dirs, tp_size, pp_size = self._check_megatron_checkpoint_path(model_ckpt_path) - print(f"sharded_dirs: {sharded_dirs}, tp_size: {tp_size}, pp_size: {pp_size}, mp_size: {len(sharded_dirs)}") - - model_state_dict_lst = self._load_state_dicts(model_ckpt_path, sharded_dirs, tp_size, pp_size) - merged_state_dict = self._merge_state_dicts(model_state_dict_lst, tp_size, pp_size) - del model_state_dict_lst - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._test_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _test_state_dict(self, state_dict: dict[str, torch.Tensor]): - """ - Compares the merged Megatron state_dict against a reference safetensors model. - Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. - """ - ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") - - params_mapping = [ - # (megatron core gpt model name, vllm model name) - ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), - ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), - ("embedding.word_embeddings", "model.embed_tokens"), - ("self_attention.linear_qkv", "self_attn.qkv_proj"), - ("self_attention.linear_proj", "self_attn.o_proj"), - ("pre_mlp_layernorm", "post_attention_layernorm"), - ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), - ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), - ("mlp.linear_fc1", "mlp.gate_up_proj"), - ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), - ("output_layer", "lm_head"), - ("self_attention.linear_q", "self_attn.q_proj"), - ("self_attention.linear_k", "self_attn.k_proj"), - ("self_attention.linear_v", "self_attn.v_proj"), - ] - - for original_name, loaded_weight in state_dict.items(): - name = self._replace_name(original_name, params_mapping) - if not name or name.endswith(".bias") and name not in ref_state_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embedding and "lm_head.weight" in name: - continue - if name not in ref_state_dict: - raise RuntimeError(f"key: {name} not exist in state_dict") - param = ref_state_dict[name] - assert loaded_weight.dtype == param.dtype - torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2) - - def _replace_name(self, megatron_name: str, name_mapping: list[tuple[str, str]]) -> str: - for m_name, v_name in name_mapping: - if m_name not in megatron_name: - continue - if "layers" in megatron_name: # deal with decoder layers - megatron_name = megatron_name.replace("decoder", "model") - megatron_name_list = megatron_name.split(".") - if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: - param_name_list = megatron_name_list[:3] - param_name_list.append(v_name) - param_name = ".".join(param_name_list) - else: - param_name_list = megatron_name_list[:3] - weight_or_bias = megatron_name_list[-1] - param_name_list.append(v_name) - param_name_list.append(weight_or_bias) - param_name = ".".join(param_name_list) - return param_name - else: - param_name = megatron_name.replace(m_name, v_name) - return param_name - return None # Return None if no mapping found - - -def main(): - parser = argparse.ArgumentParser(description="verl model merger") - subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") - - base_op_parser = argparse.ArgumentParser(add_help=False) - base_op_parser.add_argument("--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model") - base_op_parser.add_argument("--local_dir", type=str, required=True, help="Path to the saved model checkpoints") - base_op_parser.add_argument("--hf_model_path", type=str, default=None, help="(Deprecated) Path to the original Hugging Face model for config.") - base_op_parser.add_argument("--tie-word-embedding", action="store_true", help="Whether to tie word embedding weights (currently only Megatron supported)") - base_op_parser.add_argument("--is-value-model", action="store_true", help="Whether the model is a value model (currently only Megatron supported)") - - merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") - merge_parser.add_argument("--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model") - merge_parser.add_argument("--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model") - merge_parser.add_argument("--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository") - - test_parser = subparsers.add_parser("test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model") - test_parser.add_argument("--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing") - - args = parser.parse_args() - - common_config_args = { - "operation": args.operation, - "backend": args.backend, - "tie_word_embedding": args.tie_word_embedding, - "is_value_model": args.is_value_model, - "local_dir": args.local_dir, - "hf_model_path": args.hf_model_path, - "hf_model_config_path": args.local_dir, - } - - if args.operation == "merge": - config = ModelMergerConfig( - **common_config_args, - target_dir=args.target_dir, - hf_upload_path=args.hf_upload_path, - private=args.private, - test_hf_dir=None, - ) - os.makedirs(config.target_dir, exist_ok=True) - elif args.operation == "test": - config = ModelMergerConfig( - **common_config_args, - test_hf_dir=args.test_hf_dir, - # the following args are not used by test operation - target_dir=None, - hf_upload_path=None, - private=False, - ) - else: - raise NotImplementedError(f"Unknown operation: {args.operation}") - - if config.backend == "fsdp": - merger = FSDPModelMerger(config) - elif config.backend == "megatron": - merger = MegatronModelMerger(config) - else: - raise NotImplementedError(f"Unknown backend: {config.backend}") - - merger.merge_and_save() - - -if __name__ == "__main__": - main() diff --git a/scripts/tools/print_cfg.py b/scripts/tools/print_cfg.py new file mode 100644 index 000000000..287756fb1 --- /dev/null +++ b/scripts/tools/print_cfg.py @@ -0,0 +1,35 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import hydra +except ImportError as e: + raise ImportError("Please install hydra-core via 'pip install hydra-core' and retry.") from e + + +@hydra.main(config_path="../verl/trainer/config", config_name="ppo_trainer", version_base=None) +def main(config): + """Main entry point for PPO training with Hydra configuration management. + + Args: + config_dict: Hydra configuration dictionary containing training parameters. + """ + print(config) + from verl.utils.config import omega_conf_to_dataclass + + profiler_config = omega_conf_to_dataclass(config.critic.profiler) + print(profiler_config) + + +if __name__ == "__main__": + main() diff --git a/scripts/tools/rollout_viewer.py b/scripts/tools/rollout_viewer.py new file mode 100644 index 000000000..eb0314edc --- /dev/null +++ b/scripts/tools/rollout_viewer.py @@ -0,0 +1,565 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import re +import traceback +from pathlib import Path +from typing import Annotated, Optional + +import aiofiles + +try: + import ujson as json +except ImportError: + import json +import typer +from rich.highlighter import ReprHighlighter +from rich.markdown import Markdown +from rich.table import Table +from rich.text import Text +from textual import on +from textual.app import App, ComposeResult +from textual.containers import Horizontal, Vertical, VerticalScroll +from textual.widgets import Input, ProgressBar, Select, SelectionList, Static + +INDEX_KEY = "__IDX" +FILE_SUFFIX = ".jsonl" + + +def check_textual_version(): + # check if textual version is equal to 0.52.1 + import textual + from packaging.version import Version + + if Version(textual.__version__) != Version("0.52.1"): + raise ImportError(f"Textual version {textual.__version__} is not supported, please pip install textual==0.52.1") + + +check_textual_version() + + +async def load_path(p: Path, data: dict, mask_strs: str, idx: int, pbar): + samples = [] + async with aiofiles.open(p, encoding="utf-8") as f: + async for line in f: + d = json.loads(line) + for k in d: + if isinstance(d[k], str): + if mask_strs: + d[k] = re.sub(rf"{mask_strs}", "*", d[k]) + else: + d[k] = json.dumps(d[k], ensure_ascii=False, indent=4) + + d[INDEX_KEY] = len(samples) + samples.append(d) + data[idx] = {"samples": samples} + + print(f"path {p} loaded") + pbar.advance(1) + + +async def load_dir(path: Path, data: dict[int, dict], pbar, mask_strs: str = ""): + paths = list(path.glob(f"*{FILE_SUFFIX}")) + paths = sorted(paths, key=lambda x: int(x.stem)) + + tasks = [load_path(p, data, mask_strs, i, pbar) for i, p in enumerate(paths)] + + await asyncio.gather(*tasks) + + +class Highlighter(ReprHighlighter): + highlights = ReprHighlighter.highlights + [ + r"(?P[][\<\>{}()\|()【】\[\]=`])", + r"\<\|(?P[\w\W]*?)\|\>", + ] + + +def center_word_with_equals_exactly(word: str, total_length: int, char: str = "=") -> str: + if len(word) > total_length: + return word + + padding = total_length - len(word) + left_pad = (padding) // 2 + right_pad = (padding + 1) // 2 + return char * left_pad + " " + word + " " + char * right_pad + + +def highlight_keyword(content: str, keyword: Optional[str]): + if not keyword: + return Text(content) + text = Text() + parts = content.split(keyword) + for i, part in enumerate(parts): + text.append(part, style=None) + if i < len(parts) - 1: + # text.append(keyword, style=Style(color="#d154d1", bgcolor="yellow", bold=True)) + text.append(keyword, style="on #8f51b5") + return text + + +help_doc = """ +⌨️ keybinds: + +- `f/esc`: find/cancel +- `tab/←/→`: change focus +- `j/k`: page down/up +- `g/G`: scroll home/end +- `n/N`: next sample/step +- `p/P`: previous sample/step +- `s`: switch display mode + - plain text + - rich table + +""" + + +class JsonLineViewer(App): + BINDINGS = [ + ("left", "focus_previous", "Focus Previous"), + ("right", "focus_next", "Focus Next"), + ("s", "swith_render", "switch render"), + # control + ("n", "next_sample", "Next Sample"), + ("N", "next_step", "Next Step"), + ("p", "previous_sample", "Previous Sample"), + ("P", "previous_step", "Previous Step"), + # search + ("f", "toggle_search", "find"), + ("enter", "next_search", "find next"), + ("escape", "cancel_search", "cancel find"), + # scroll + ("j", "page_down", "page down"), + ("k", "page_up", "page up"), + ("g", "page_home", "page home"), + ("G", "page_end", "page end"), + ] + + CSS = """ + + Select:focus > SelectCurrent { + border: tall #8f51b5; + } + Select.-expanded > SelectCurrent { + border: tall #8f51b5; + } + #select-container { + width: 15%; + height: 100%; + align: center top; + } + #search-container { + height: 10%; + align: center top; + } + #search-box { + width: 50%; + } + #reqid-box { + width: 50%; + } + """ + + def __init__(self, step_num: int, data: dict[int, dict], pbar): + super().__init__() + self.step_num = step_num + + self.data = data + self.render_table = False + self.selected_step_index = 0 + self.selected_sample_index = 0 + self.pbar = pbar + + self.matches = [] + self.current_match_index = 0 + + self.highlighter = Highlighter() + + first_samples = data[list(data.keys())[0]]["samples"] + # Prepare the initial field filter list (all keys from the first sample) + self.filter_fields = [(f, f, True) for f in first_samples[0].keys()] + + # Internal set used for fast membership checks when we add new fields on the fly. + # We keep it here so that when new columns appear in later steps (e.g. `request_id`), + # they can be added to the UI automatically without restarting the viewer. + self._field_set: set[str] = set(first_samples[0].keys()) + self.sample_num = len(first_samples) + + def compose(self) -> ComposeResult: + with Horizontal(id="search-container"): + yield Input(placeholder="find something...", id="search-box") + yield Input(placeholder="request id...", id="reqid-box") + with Vertical(id="search-container2"): + yield self.pbar + yield Static("", id="search-status") + + with Horizontal(): + with Vertical(id="select-container"): + yield Static("\n") + yield Static( + renderable=Markdown( + help_doc, + ), + markup=False, + ) + yield Static("\n") + yield Select( + id="step-select", + value=0, + prompt="select step", + options=[("step: 1", 0)], + allow_blank=False, + ) + yield Select( + id="sample-select", + value=0, + prompt="select sample", + options=[("sample: 1", 0)], + allow_blank=False, + ) + yield Select( + id="sample-sort", + value=0, + prompt="排序", + options=[ + ("sort", 0), + ("score asc", 1), + ("score desc", 2), + ], + allow_blank=False, + ) + + yield SelectionList[int](("Select ALL", 1, True), id="fields-select-all") + with VerticalScroll(id="scroll-view2"): + yield SelectionList[str](*self.filter_fields, id="fields-select") + with VerticalScroll(id="scroll-view"): + yield Static(id="content", markup=False) + + async def on_mount(self) -> None: + self.step_select = self.query_one("#step-select", Select) + self.sample_select = self.query_one("#sample-select", Select) + self.sample_sort = self.query_one("#sample-sort", Select) + self.content_display = self.query_one("#content", Static) + self.search_box = self.query_one("#search-box", Input) + self.reqid_box = self.query_one("#reqid-box", Input) + self.scroll_view = self.query_one("#scroll-view", VerticalScroll) + self.search_status = self.query_one("#search-status", Static) + self.fields_select = self.query_one("#fields-select", SelectionList) + self.fields_select.border_title = "field filter" + + if self.data: + self.step_select.set_options([(f"step: {i + 1}", i) for i in range(self.step_num)]) + self.sample_select.set_options([(f"sample: {i + 1}", i) for i in range(self.sample_num)]) + self.step_select.focus() + await self.update_content() + + def update_result_options(self, offset: int = 0, sort_desc: Optional[bool] = None): + options = [] + if isinstance(self.selected_step_index, int) and self.selected_step_index < len(self.data): + if self.sample_num is None or sort_desc is not None: + samples = self.data[self.selected_step_index].get("samples", []) + if not samples: + self.selected_sample_index = offset + return + if sort_desc is not None: + samples = sorted( + samples, + key=lambda x: x.get("score", x.get("score_1", 0)), + reverse=sort_desc, + ) + + options = [(f"sample: {r[INDEX_KEY] + 1}", r[INDEX_KEY]) for r in samples] + self.sample_select.set_options(options) + self.sample_num = len(samples) + + if sort_desc is not None and options: + self.selected_sample_index = options[0][1] + else: + self.selected_sample_index = offset + + async def update_content(self, search_keyword: Optional[str] = None): + content = "" + try: + samples = self.data[self.selected_step_index].get("samples", []) + content_dict_full = samples[self.selected_sample_index] + + # Dynamically track any NEW keys that appear and add them to the field filter. + self._update_fields_select(content_dict_full.keys()) + + # Apply field selection filter (only show selected fields) + content_dict = {k: v for k, v in content_dict_full.items() if k in self.fields_select.selected} + if self.render_table: + content = Table("key", "value", show_lines=True) + for k in content_dict: + v = content_dict[k] + v = f"{v}" + content.add_row( + k, + self.highlighter(highlight_keyword(v, search_keyword)), + ) + else: + text = Text() + for k in content_dict: + v = content_dict[k] + s = center_word_with_equals_exactly(k, 64) + f"\n{v}\n" + text.append(highlight_keyword(s, search_keyword)) + content = self.highlighter(text) + except KeyError: + content = f"Loading data asynchronously, progress: {len(self.data)}/{self.step_num} step" + + except Exception: + content = self.highlighter(traceback.format_exc()) + + self.content_display.update(content) + + # --------------------------------------------------------------------- + # Request-ID jump logic + # --------------------------------------------------------------------- + + @on(Input.Submitted, "#reqid-box") + async def on_reqid_submitted(self, event: Input.Submitted) -> None: + """Jump to the sample that has a matching `request_id`.""" + + req_id_raw = event.value.strip() + # Remove hyphens so search is tolerant to different id formats + req_id = req_id_raw.replace("-", "") + if not req_id: + return + + found = False + for step_idx, step_data in self.data.items(): + for sample in step_data.get("samples", []): + sample_id = str(sample.get("request_id", "")) + if sample_id.replace("-", "") == req_id: + # Update selected indices + self.selected_step_index = step_idx + self.step_select.value = step_idx + + # Ensure sample list is updated and select sample + self.update_result_options(offset=sample[INDEX_KEY]) + self.selected_sample_index = sample[INDEX_KEY] + self.sample_select.value = sample[INDEX_KEY] + + await self._clear_search() + await self.update_content() + + found = True + break + if found: + break + + if not found: + self.search_status.update(Text(f"request_id '{req_id_raw}' not found", style="bold red")) + else: + # Keep the typed id in the input box so users see what was searched. + pass + + # --------------------------------------------------------------------- + # Helper: add new fields to SelectionList on-the-fly + # --------------------------------------------------------------------- + + def _update_fields_select(self, keys): + """Add any unseen *keys* to the field-selection widget so they can be toggled. + + The viewer is often launched with only the first step loaded. Later steps may + introduce new columns (e.g. `request_id`). This helper ensures those fields + become visible without requiring a restart. + """ + # Ensure we have the widget (only after on_mount) + if not hasattr(self, "fields_select"): + return + + for k in keys: + if k not in self._field_set: + self._field_set.add(k) + try: + # By default, new fields are selected so they appear immediately. + self.fields_select.add_option(k, k, selected=True) + except Exception: + # Fallback for older textual versions where signature is different. + self.fields_select.add_option((k, k, True)) + + @on(Select.Changed, "#step-select") + async def step_changed(self, event): + self.selected_step_index = event.value + self.update_result_options() + await self.update_content() + + @on(Select.Changed, "#sample-select") + async def sample_changed(self, event): + self.selected_sample_index = event.value + await self._clear_search() + await self.update_content() + + @on(Select.Changed, "#sample-sort") + async def sort_changed(self, event): + v = event.value + self.update_result_options(sort_desc=None if v == 0 else False if v == 1 else True) + await self.update_content() + + @on(SelectionList.SelectedChanged, "#fields-select") + async def fields_changed(self, event): + await self.update_content() + + @on(SelectionList.SelectedChanged, "#fields-select-all") + async def fields_all_changed(self, event): + s = self.query_one("#fields-select-all", SelectionList) + if s.selected: + self.fields_select.select_all() + else: + self.fields_select.deselect_all() + + def action_focus_previous(self): + self.screen.focus_previous() + + def action_focus_next(self): + self.screen.focus_next() + + async def action_next_step(self) -> None: + self.selected_step_index += 1 + if self.selected_step_index >= self.step_num: + self.selected_step_index = 0 + self.step_select.value = self.selected_step_index + self.update_result_options() + await self.update_content() + + async def action_next_sample(self) -> None: + self.selected_sample_index += 1 + if not self.sample_num or self.selected_sample_index >= self.sample_num: + self.selected_sample_index = 0 + self.sample_select.value = self.selected_sample_index + await self._clear_search() + await self.update_content() + + async def action_previous_step(self) -> None: + self.selected_step_index -= 1 + if self.selected_step_index < 0: + self.selected_step_index = self.step_num - 1 + self.step_select.value = self.selected_step_index + self.update_result_options() + await self.update_content() + + async def action_previous_sample(self) -> None: + self.selected_sample_index -= 1 + if self.selected_sample_index < 0: + self.selected_sample_index = self.sample_num - 1 + self.sample_select.value = self.selected_sample_index + await self._clear_search() + await self.update_content() + + async def action_swith_render(self): + self.render_table = not self.render_table + await self.update_content() + + def action_toggle_search(self) -> None: + self.search_box.focus() + + async def action_cancel_search(self) -> None: + self.search_box.value = "" + await self._clear_search() + await self.update_content() + + async def _clear_search(self): + self.matches = [] + self.search_status.update("") + self.current_match_index = 0 + + @on(Input.Submitted, "#search-box") + async def on_search_submitted(self, event: Input.Submitted) -> None: + self.matches = [] + self.current_match_index = 0 + if event.value: + await self.update_content(event.value) + renderable = self.content_display.render() + if isinstance(renderable, Table): + return + + assert isinstance(renderable, Text) + console = self.content_display._console + lines = renderable.wrap(console, self.scroll_view.container_size.width) + line_idx_recorded = set() + for line_idx, line in enumerate(lines): + if line_idx in line_idx_recorded: + continue + if event.value in line: + self.matches.append( + { + "line": line_idx, + "word": event.value, + } + ) + line_idx_recorded.add(line_idx) + self.scroll_view.focus() + await self.action_next_search() + + async def action_next_search(self) -> None: + if not self.matches or self.current_match_index >= len(self.matches): + return + + target_line = self.matches[self.current_match_index]["line"] + self.scroll_view.scroll_to(x=0, y=target_line * 1, animate=False) + self.current_match_index = (self.current_match_index + 1) % len(self.matches) + self.search_status.update( + Text( + f"Find :{self.current_match_index + 1}/{len(self.matches)}", + style="bold on #8f51b5", + ) + ) + + def action_page_up(self): + self.scroll_view.scroll_page_up(animate=False) + + def action_page_down(self): + self.scroll_view.scroll_page_down(animate=False) + + def action_page_home(self): + self.scroll_view.scroll_home(animate=False) + + def action_page_end(self): + self.scroll_view.scroll_end(animate=False) + + +async def _run(path: Path, mask_str: str): + assert path.exists(), f"{path} not exist" + + paths = list(path.glob(f"*{FILE_SUFFIX}")) + paths = sorted(paths, key=lambda x: int(x.stem)) + + if not paths: + raise ValueError(f"no available reward dump files under f{path}") + + print(f"get jsonl file nums: {len(paths)}") + + pbar = ProgressBar(total=len(paths), name="data load progress") + data = {} + await load_path(paths[0], data, mask_str, 0, pbar) + app = JsonLineViewer(step_num=len(paths), data=data, pbar=pbar) + await asyncio.gather(load_dir(path, data, pbar, mask_str), app.run_async()) + + +app = typer.Typer() + + +@app.command(help="launch TUI APP") +def run( + rollout_data_dir: Path, + mask_str: Annotated[str, typer.Option(help="string that will be masked to *")] = r"<\|image_pad\|>|<\|imgpad\|>", +): + loop = asyncio.get_event_loop() + loop.run_until_complete(_run(rollout_data_dir, mask_str)) + + +if __name__ == "__main__": + app() diff --git a/scripts/tools/serve_llm_as_verifier.sh b/scripts/tools/serve_llm_as_verifier.sh deleted file mode 100644 index 15f6689e1..000000000 --- a/scripts/tools/serve_llm_as_verifier.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=server_llm_as_verifier -#SBATCH --partition=main -#SBATCH --nodes=1 -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=64 -#SBATCH --gres=gpu:8 -#SBATCH --time=720:00:00 -#SBATCH --output=slurm/serve_llm_as_verifier_%j.log -#SBATCH --error=slurm/serve_llm_as_verifier_%j.err - - -# (1) detect this node’s primary IP -NODE_IP=$(hostname -I | awk '{print $1}') -echo "Detected NODE_IP = $NODE_IP" - -# (2) export judge URL for downstream clients -export STEM_LLM_JUDGE_URL="http://${NODE_IP}:8000" -echo "STEM_LLM_JUDGE_URL=$STEM_LLM_JUDGE_URL" - -# (3) launch the vLLM server bound to that IP -vllm serve TIGER-Lab/general-verifier --host "$NODE_IP" --data-parallel-size 8 diff --git a/scripts/train/dapo_7b_math_fsdp2_4_4.sh b/scripts/train/dapo_7b_math_fsdp2_4_4.sh new file mode 100644 index 000000000..c4018a4d8 --- /dev/null +++ b/scripts/train/dapo_7b_math_fsdp2_4_4.sh @@ -0,0 +1,320 @@ +#!/bin/bash +#SBATCH --job-name=dapo-7b-math-fsdp2-4-4 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main + +# NOTE: added by Reasoning360. +export MATH_LLM_JUDGE_URL=http://azure-uk-hpc-H200-instance-033:8000 + +export CONDA_BIN_PATH=/lustrefs/users/varad.pimpalkhute/anaconda3/envs/sync-rl-v5/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800 +export OMPI_MCA_coll_hcoll_enable=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-4' + +export VERL_USE_THREAD_TIMEOUT=false + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl_old"} +# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} + + +# TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" +train_file_list=() +test_file_list_impossible_questions=() + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + # "codegen__deduped_leetcode2k_2.4k.parquet" + # "codegen__deduped_livecodebench_599.parquet" + # "codegen__deduped_primeintellect_9.6k.parquet" + # "codegen__deduped_taco_11.1k.parquet" + # "ifbench__fixed_85.6k.parquet" + # "logic__arcagi1_297.parquet" + # "logic__arcagi2_653.parquet" + # "logic__barc_3.4k.parquet" + # "logic__graph_logical_dataset_1.4k.parquet" + # "logic__ordering_puzzle_dataset_2.9k.parquet" + # "logic__reasoning_gym_40.6k.parquet" + # "logic__synlogic_12.1k.parquet" + # "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + # "math__combined_118.2k.part2.parquet" + # "omni_math_4.43k.parquet" + # "simulation__codeio_fixed_12.1k.parquet" + # "stem__nemotron_13.3k.parquet" + # "stem__web_31.7k.parquet" + # "table__hitab_7.4k.parquet" + # "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." + +# Search for each dataset in all subdirectories +# NOTE: added by Reasoning360. Exclude impossible_questions from training data +for dataset in "${dataset_names[@]}"; do + for subdir in "131k_context_questions" "main_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +unset IFS + +for dataset in "${dataset_names[@]}"; do + for subdir in "impossible_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + test_file_list_impossible_questions+=("'$file_path'") + fi + done +done + +test_file_list_impossible_questions+=("${RAY_DATA_HOME}/data/aime-2024.parquet") + +IFS=, +test_files_impossible_questions="[${test_file_list_impossible_questions[*]}]" +unset IFS + +echo "Test files for impossible questions: ${#test_file_list_impossible_questions[@]}" +echo "Total training files found: ${#train_files[@]}" + + +# =================== Ray node setup =================== + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES VERL_USE_THREAD_TIMEOUT=${VERL_USE_THREAD_TIMEOUT} \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES VERL_USE_THREAD_TIMEOUT=${VERL_USE_THREAD_TIMEOUT} \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 +# =================== Ray node setup end =================== + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=1 +sp_size=1 +fsdp_size=2 + +# Fully async specific parameters +NNODES=${NNODES:-1} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +n_gpus_rollout=4 +n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout)) + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*100))) +test_freq=10 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +"${CONDA_BIN_PATH}python3" -m reasoning360.recipe.fully_async_policy.fully_async_main \ + data.train_files="${train_files}" \ + data.val_files="${test_files_impossible_questions}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + reward_model.reward_manager=dapo \ + reward_model.launch_reward_fn_async=True \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${n_gpus_training}" \ + rollout.nnodes="${NNODES}" \ + rollout.n_gpus_per_node="${n_gpus_rollout}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True \ No newline at end of file diff --git a/scripts/train/dapo_7b_math_fsdp2_8_8.sh b/scripts/train/dapo_7b_math_fsdp2_8_8.sh new file mode 100644 index 000000000..aa7bd0ab4 --- /dev/null +++ b/scripts/train/dapo_7b_math_fsdp2_8_8.sh @@ -0,0 +1,320 @@ +#!/bin/bash +#SBATCH --job-name=dapo-7b-math-fsdp2-8-8 +#SBATCH --nodes=2 +#SBATCH --ntasks=2 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main + +# NOTE: added by Reasoning360. +export MATH_LLM_JUDGE_URL=http://azure-uk-hpc-H200-instance-033:8000 + +export CONDA_BIN_PATH=/lustrefs/users/varad.pimpalkhute/anaconda3/envs/sync-rl-v5/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800 +export OMPI_MCA_coll_hcoll_enable=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-8-8' + +export VERL_USE_THREAD_TIMEOUT=false + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl_old"} +# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} + + +# TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" +train_file_list=() +test_file_list_impossible_questions=() + +# List of datasets to include (filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + # "codegen__deduped_leetcode2k_2.4k.parquet" + # "codegen__deduped_livecodebench_599.parquet" + # "codegen__deduped_primeintellect_9.6k.parquet" + # "codegen__deduped_taco_11.1k.parquet" + # "ifbench__fixed_85.6k.parquet" + # "logic__arcagi1_297.parquet" + # "logic__arcagi2_653.parquet" + # "logic__barc_3.4k.parquet" + # "logic__graph_logical_dataset_1.4k.parquet" + # "logic__ordering_puzzle_dataset_2.9k.parquet" + # "logic__reasoning_gym_40.6k.parquet" + # "logic__synlogic_12.1k.parquet" + # "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + # "math__combined_118.2k.part2.parquet" + # "omni_math_4.43k.parquet" + # "simulation__codeio_fixed_12.1k.parquet" + # "stem__nemotron_13.3k.parquet" + # "stem__web_31.7k.parquet" + # "table__hitab_7.4k.parquet" + # "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." + +# Search for each dataset in all subdirectories +# NOTE: added by Reasoning360. Exclude impossible_questions from training data +for dataset in "${dataset_names[@]}"; do + for subdir in "131k_context_questions" "main_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +for dataset in "${dataset_names[@]}"; do + for subdir in "impossible_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + test_file_list_impossible_questions+=("'$file_path'") + fi + done +done + +test_file_list_impossible_questions+=("'${RAY_DATA_HOME}/data/aime-2024.parquet'") + +echo "Test files for impossible questions: ${#test_file_list_impossible_questions[@]}" +echo "Total training files found: ${#train_file_list[@]}" + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +test_files="[${test_file_list_impossible_questions[*]}]" +unset IFS + +echo "Test files for impossible questions: ${test_files}" +echo "Training files: ${train_files}" + + +# =================== Ray node setup =================== + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES VERL_USE_THREAD_TIMEOUT=${VERL_USE_THREAD_TIMEOUT} \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES VERL_USE_THREAD_TIMEOUT=${VERL_USE_THREAD_TIMEOUT} \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 +# =================== Ray node setup end =================== + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=1 +sp_size=1 +fsdp_size=2 + +# Fully async specific parameters +NNODES=${NNODES:-2} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +n_gpus_rollout=4 +n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout)) + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*100))) +test_freq=10 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +"${CONDA_BIN_PATH}python3" -m recipe.fully_async_policy.fully_async_main \ + data.train_files="${train_files}" \ + data.val_files="${test_files}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + reward_model.reward_manager=dapo \ + reward_model.launch_reward_fn_async=True \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=False \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${n_gpus_training}" \ + rollout.nnodes="${NNODES}" \ + rollout.n_gpus_per_node="${n_gpus_rollout}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True \ No newline at end of file diff --git a/scripts/train/example_multinode_rl_llama3.1_70b_distill_fsdp.sh b/scripts/train/example_multinode_rl_llama3.1_70b_distill_fsdp.sh deleted file mode 100644 index 4f7c2f2d4..000000000 --- a/scripts/train/example_multinode_rl_llama3.1_70b_distill_fsdp.sh +++ /dev/null @@ -1,268 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=example-multinode-rl-llama3.1-70b-distill-fsdp -#SBATCH --nodes=32 -#SBATCH --ntasks=32 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:8 -#SBATCH --cpus-per-task=128 -#SBATCH --mem=0 -#SBATCH --output=slurm/%x-%j.out -#SBATCH --error=slurm/%x-%j.err -#SBATCH --exclusive -#SBATCH --time=720:00:00 - - -# =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch -export STEM_LLM_JUDGE_URL="" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain - -# =================== Cluster Environment =================== -export NCCL_DEBUG=info -export NCCL_ALGO=NVLSTree -export NCCL_IBEXT_DISABLE=1 -export NCCL_NVLS_ENABLE=1 -export NCCL_IB_HCA=mlx5 -export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export CUDA_LAUNCH_BLOCKING=1 - -# Get the list of allocated nodes -nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) -echo "Nodes to check: ${nodes[@]}" - -# We'll track PIDs so we can wait on them and detect errors -declare -A pids -export head_node=${nodes[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) -port=6379 -address_head=$head_node_ip:$port - -export worker_num=$SLURM_NNODES -export HYDRA_FULL_ERROR=1 -export VLLM_USE_V1=0 - -# =================== Data Mixture =================== -SHARED_DATA_PATH=./data -TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ -TEST_DATA_DIR=${SHARED_DATA_PATH}/online_eval/ - -# Math (train) -math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet -# Math (test) -math_test_path=${TEST_DATA_DIR}/math__math_500.parquet -aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet -amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet - -# Code (train) -leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet -livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet -primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet -taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet -# Code (test) -humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet -mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_200.parquet -livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet - -# Logic (train) -arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet -arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet -barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet -graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet -ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet -zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet -# Logic (test) -ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_100.parquet -zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet -arcagi_test_path=${TEST_DATA_DIR}/logic__arcagi1_200.parquet - -# Simulation (train) -codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet -# Simulation (test) -codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet - -# Table (train) -hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet -multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet -# Table (test) -multihier_test_path=${TEST_DATA_DIR}/table__multihier_200.parquet -hitab_test_path=${TEST_DATA_DIR}/table__hitab_200.parquet - -# Stem (train) -webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet -# Stem (test) -supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet - -train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed -test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed - -# =================== Model =================== -BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Llama-70B - -# =================== Logging =================== -WANDB_PROJECT=Reasoning360 -WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} - -# If RESUME_CKPT_DIR is not empty, resume from the checkpoint -if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then - WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" -fi - - -# =================== Ray start =================== -# ray stop at all nodes -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop - -sleep 10 -# Remove existing Ray cluster -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster - -# Start Ray head node -srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ - ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & - -sleep 10 - -# Start Ray worker nodes -for ((i = 1; i < worker_num; i++)); do - node_i=${nodes[$i]} - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ - ${CONDA_BIN_PATH}ray start --address "$address_head" \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & -done -sleep 10 - - -# =================== RL Config =================== -# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.2 - -max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 32)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=False -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n -gen_prompt_bsz=$((train_prompt_bsz * 1)) -n_resp_per_prompt=16 -train_prompt_mini_bsz=32 # model grad update batchsize - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout - -# Training config -sp_size=4 -gen_tp=4 -gen_max_num_seqs=1024 -infer_micro_batch_size=null -train_micro_batch_size=null -use_dynamic_bsz=True -actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow -infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow -offload=True - -# =================== Start RL training =================== -"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ - --config-path=config \ - --config-name="dapo_fsdp_config.yaml" \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.prompt_key=prompt \ - data.truncation='right' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - data.gen_batch_size=${gen_prompt_bsz} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.actor.strategy="fsdp" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.optim.min_lr_ratio=0. \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.model.path=$BASE_MODEL \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.rollout.multi_turn.enable=False \ - actor_rollout_ref.rollout.mode="sync" \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - reward_model.reward_manager=async_multi_process \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ - trainer.val_before_train=True \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=$worker_num \ - trainer.save_freq=10 \ - trainer.test_freq=10 \ - trainer.total_epochs=5 \ - trainer.log_val_generations=50 \ - trainer.resume_mode=auto \ - trainer.max_actor_ckpt_to_keep=2 \ No newline at end of file diff --git a/scripts/train/example_multinode_rl_llama3.1_70b_distill_megatron.sh b/scripts/train/example_multinode_rl_llama3.1_70b_distill_megatron.sh deleted file mode 100644 index 6a33c683c..000000000 --- a/scripts/train/example_multinode_rl_llama3.1_70b_distill_megatron.sh +++ /dev/null @@ -1,290 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=example-multinode-rl-llama3.1-70b-distill-megatron -#SBATCH --nodes=32 -#SBATCH --ntasks=32 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:8 -#SBATCH --cpus-per-task=128 -#SBATCH --mem=0 -#SBATCH --output=slurm/%x-%j.out -#SBATCH --error=slurm/%x-%j.err -#SBATCH --exclusive -#SBATCH --time=720:00:00 - - -# =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="" -export STEM_LLM_JUDGE_URL="" - -# =================== Environment =================== -export NCCL_DEBUG=info -export NCCL_ALGO=NVLSTree -export NCCL_IBEXT_DISABLE=1 -export NCCL_NVLS_ENABLE=1 -export NCCL_IB_HCA=mlx5 -export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export CUDA_LAUNCH_BLOCKING=1 - - -# Get the list of allocated nodes -nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) -echo "Nodes to check: ${nodes[@]}" - -# We'll track PIDs so we can wait on them and detect errors -declare -A pids - -export head_node=${nodes[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) -port=6379 -address_head=$head_node_ip:$port - -export worker_num=$SLURM_NNODES -export HYDRA_FULL_ERROR=1 -export VLLM_USE_V1=0 -export RAY_record_ref_creation_sites=1 # NOTE(yonghao): DEBUG code -# export GLOO_SOCKET_IFNAME=ens10f0np0 - - -# =================== Data Mixture =================== -SHARED_DATA_PATH=./data -TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ -TEST_DATA_DIR=${SHARED_DATA_PATH}/online_eval/ - -# Math (train) -math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet -# Math (test) -math_test_path=${TEST_DATA_DIR}/math__math_500.parquet -aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet -amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet - -# Code (train) -leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet -livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet -primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet -taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet -# Code (test) -humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet -mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_200.parquet -livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet - -# Logic (train) -arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet -arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet -barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet -graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet -ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet -zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet -# Logic (test) -ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_100.parquet -zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet -arcagi_test_path=${TEST_DATA_DIR}/logic__arcagi1_200.parquet - -# Simulation (train) -codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet -# Simulation (test) -codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet - -# Table (train) -hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet -multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet -# Table (test) -multihier_test_path=${TEST_DATA_DIR}/table__multihier_200.parquet -hitab_test_path=${TEST_DATA_DIR}/table__hitab_200.parquet - -# Stem (train) -webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet -# Stem (test) -supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet - -train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed -test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed - -# =================== Model =================== -BASE_MODEL="deepseek-ai/DeepSeek-R1-Distill-Llama-70B" - -# =================== Logging =================== -WANDB_PROJECT=Reasoning360 -WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} - -# If RESUME_CKPT_DIR is not empty, resume from the checkpoint -if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then - WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" -fi - - -# =================== Ray start =================== -# ray stop at all nodes -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_MEGATRON_PATH}ray stop - -sleep 10 -# Remove existing Ray cluster -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster - -# Start Ray head node -srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ - ${CONDA_BIN_MEGATRON_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & - -sleep 10 - -# Start Ray worker nodes -for ((i = 1; i < worker_num; i++)); do - node_i=${nodes[$i]} - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ - ${CONDA_BIN_MEGATRON_PATH}ray start --address "$address_head" \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & -done -sleep 10 - - -# =================== RL Config =================== -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.2 - -max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 32)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=False -filter_groups_metric=acc -max_num_gen_batches=16 -train_prompt_bsz=256 # grad accum bsz; real grad accum bsz: train_prompt_bsz * rollout.n -gen_prompt_bsz=$((train_prompt_bsz * 1)) # rollout bsz, i.e., the x-axis in RL plot -n_resp_per_prompt=16 -train_prompt_mini_bsz=8 - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout - -# Generation config -gen_tp=4 -gen_max_num_seqs=1024 - -# Megatron trainer config -train_tp=8 -train_pp=2 -sp_size=8 -offload=True - -# Batch size -use_dynamic_bsz=True -train_micro_batch_size=null -train_micro_batch_size_per_gpu_placeholder=1 # can't be null, as in ray_trainer.py ```minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``` -infer_micro_batch_size_per_gpu_placeholder=8 # can't be null, as in megatron_worker.py ```assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and `log_prob_micro_batch_size` should not be None at the same time."``` -# NOTE: this one is for per gpu, so it times sp_size (defined later) -# actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 )) -actor_ppo_max_token_len=8192 -infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 )) - - -# NOTE(yonghao): all other parts (weights, optimizer states) exists across stages (training, generation) -# while this one only lives during a training iteration. -grad_offload=True -#### - -# =================== Start RL training =================== -"${CONDA_BIN_MEGATRON_PATH}python" -m recipe.dapo.main_dapo \ - --config-path=config \ - --config-name="dapo_megatron_config.yaml" \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.prompt_key=prompt \ - data.truncation='right' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - data.gen_batch_size=${gen_prompt_bsz} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.actor.strategy="megatron" \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.lr_warmup_init=0.0 \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.actor.optim.lr_decay_style=constant \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.optim.min_lr=0. \ - actor_rollout_ref.actor.optim.clip_grad=1.0 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_micro_batch_size_per_gpu_placeholder} \ - actor_rollout_ref.actor.megatron.param_offload=${offload} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ - actor_rollout_ref.actor.megatron.grad_offload=${grad_offload} \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.actor.megatron.context_parallel_size=${sp_size} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_micro_batch_size_per_gpu_placeholder} \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.ref.megatron.context_parallel_size=${sp_size} \ - actor_rollout_ref.ref.megatron.param_offload=${offload} \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.65 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_micro_batch_size_per_gpu_placeholder} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.model.path=$BASE_MODEL \ - +actor_rollout_ref.model.use_remove_padding=True \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - reward_model.reward_manager=async_multi_process \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=$worker_num \ - trainer.save_freq=-1 \ - trainer.test_freq=10 \ - trainer.total_epochs=5 \ - trainer.log_val_generations=50 \ - trainer.resume_mode=auto \ - trainer.max_actor_ckpt_to_keep=2 \ No newline at end of file diff --git a/scripts/train/example_multinode_rl_qwen2.5_32b_base_fsdp.sh b/scripts/train/example_multinode_rl_qwen2.5_32b_base_fsdp.sh deleted file mode 100644 index e39ea8624..000000000 --- a/scripts/train/example_multinode_rl_qwen2.5_32b_base_fsdp.sh +++ /dev/null @@ -1,268 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=example-multinode-rl-qwen2.5-32b-base-fsdp -#SBATCH --nodes=8 -#SBATCH --ntasks=8 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:8 -#SBATCH --cpus-per-task=96 -#SBATCH --mem=0 -#SBATCH --output=slurm/%x-%j.out -#SBATCH --error=slurm/%x-%j.err -#SBATCH --exclusive -#SBATCH --time=720:00:00 - - -# =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch -export STEM_LLM_JUDGE_URL="" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain - -# =================== Cluster Environment =================== -export NCCL_DEBUG=info -export NCCL_ALGO=NVLSTree -export NCCL_IBEXT_DISABLE=1 -export NCCL_NVLS_ENABLE=1 -export NCCL_IB_HCA=mlx5 -export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export CUDA_LAUNCH_BLOCKING=1 - -# Get the list of allocated nodes -nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) -echo "Nodes to check: ${nodes[@]}" - -# We'll track PIDs so we can wait on them and detect errors -declare -A pids -export head_node=${nodes[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) -port=6379 -address_head=$head_node_ip:$port - -export worker_num=$SLURM_NNODES -export HYDRA_FULL_ERROR=1 -export VLLM_USE_V1=0 - -# =================== Data Mixture =================== -SHARED_DATA_PATH=./data -TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ -TEST_DATA_DIR=${SHARED_DATA_PATH}/online_eval/ - -# Math (train) -math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet -# Math (test) -math_test_path=${TEST_DATA_DIR}/math__math_500.parquet -aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet -amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet - -# Code (train) -leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet -livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet -primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet -taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet -# Code (test) -humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet -mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_200.parquet -livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet - -# Logic (train) -arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet -arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet -barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet -graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet -ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet -zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet -# Logic (test) -ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_100.parquet -zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet -arcagi_test_path=${TEST_DATA_DIR}/logic__arcagi1_200.parquet - -# Simulation (train) -codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet -# Simulation (test) -codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet - -# Table (train) -hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet -multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet -# Table (test) -multihier_test_path=${TEST_DATA_DIR}/table__multihier_200.parquet -hitab_test_path=${TEST_DATA_DIR}/table__hitab_200.parquet - -# Stem (train) -webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet -# Stem (test) -supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet - -train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed -test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed - -# =================== Model =================== -BASE_MODEL=Qwen/Qwen2.5-32B - -# =================== Logging =================== -WANDB_PROJECT=Reasoning360 -WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} - -# If RESUME_CKPT_DIR is not empty, resume from the checkpoint -if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then - WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" -fi - - -# =================== Ray start =================== -# ray stop at all nodes -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop - -sleep 10 -# Remove existing Ray cluster -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster - -# Start Ray head node -srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ - ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & - -sleep 10 - -# Start Ray worker nodes -for ((i = 1; i < worker_num; i++)); do - node_i=${nodes[$i]} - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ - ${CONDA_BIN_PATH}ray start --address "$address_head" \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & -done -sleep 10 - - -# =================== RL Config =================== -# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.2 - -max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=False -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n -gen_prompt_bsz=$((train_prompt_bsz * 1)) -n_resp_per_prompt=16 -train_prompt_mini_bsz=64 # model grad update batchsize - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout - -# Training config -sp_size=1 -gen_tp=2 -gen_max_num_seqs=1024 -infer_micro_batch_size=null -train_micro_batch_size=null -use_dynamic_bsz=True -actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward & backward but note memory overflow -infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up modelforward, but note memory overflow -offload=True - -# =================== Start RL training =================== -"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ - --config-path=config \ - --config-name="dapo_fsdp_config.yaml" \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.prompt_key=prompt \ - data.truncation='right' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - data.gen_batch_size=${gen_prompt_bsz} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.actor.strategy="fsdp" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.optim.min_lr_ratio=0. \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.model.path=$BASE_MODEL \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.rollout.multi_turn.enable=False \ - actor_rollout_ref.rollout.mode="sync" \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - reward_model.reward_manager=async_multi_process \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ - trainer.val_before_train=True \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=$worker_num \ - trainer.save_freq=10 \ - trainer.test_freq=10 \ - trainer.total_epochs=5 \ - trainer.log_val_generations=50 \ - trainer.resume_mode=auto \ - trainer.max_actor_ckpt_to_keep=2 \ No newline at end of file diff --git a/scripts/train/example_multinode_rl_qwen2.5_32b_base_megatron.sh b/scripts/train/example_multinode_rl_qwen2.5_32b_base_megatron.sh deleted file mode 100644 index bb88bb9a1..000000000 --- a/scripts/train/example_multinode_rl_qwen2.5_32b_base_megatron.sh +++ /dev/null @@ -1,287 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=example-multinode-rl-qwen2.5-32b-base-megatron -#SBATCH --nodes=8 -#SBATCH --ntasks=8 -#SBATCH --ntasks-per-node=1 -#SBATCH --gres=gpu:8 -#SBATCH --cpus-per-task=96 -#SBATCH --mem=0 -#SBATCH --output=slurm/%x-%j.out -#SBATCH --error=slurm/%x-%j.err -#SBATCH --exclusive -#SBATCH --time=720:00:00 - - -# =================== Frequently Used Variables =================== -RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch -export STEM_LLM_JUDGE_URL="" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain - -# =================== Cluster Environment =================== -export NCCL_DEBUG=info -export NCCL_ALGO=NVLSTree -export NCCL_IBEXT_DISABLE=1 -export NCCL_NVLS_ENABLE=1 -export NCCL_IB_HCA=mlx5 -export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export CUDA_LAUNCH_BLOCKING=1 - -# Get the list of allocated nodes -nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) -echo "Nodes to check: ${nodes[@]}" - -# We'll track PIDs so we can wait on them and detect errors -declare -A pids -export head_node=${nodes[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) -port=6379 -address_head=$head_node_ip:$port - -export worker_num=$SLURM_NNODES -export HYDRA_FULL_ERROR=1 -export VLLM_USE_V1=0 - -# =================== Data Mixture =================== -SHARED_DATA_PATH=./data -TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ -TEST_DATA_DIR=${SHARED_DATA_PATH}/online_eval/ - -# Math (train) -math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet -# Math (test) -math_test_path=${TEST_DATA_DIR}/math__math_500.parquet -aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet -amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet - -# Code (train) -leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet -livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet -primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet -taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet -# Code (test) -humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet -mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_200.parquet -livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet - -# Logic (train) -arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet -arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet -barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet -graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet -ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet -zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet -# Logic (test) -ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_100.parquet -zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet -arcagi_test_path=${TEST_DATA_DIR}/logic__arcagi1_200.parquet - -# Simulation (train) -codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet -# Simulation (test) -codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet - -# Table (train) -hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet -multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet -# Table (test) -multihier_test_path=${TEST_DATA_DIR}/table__multihier_200.parquet -hitab_test_path=${TEST_DATA_DIR}/table__hitab_200.parquet - -# Stem (train) -webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet -# Stem (test) -supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet - -train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed -test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed - -# =================== Model =================== -BASE_MODEL=Qwen/Qwen2.5-32B - -# =================== Logging =================== -WANDB_PROJECT=Reasoning360 -WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/} - -# If RESUME_CKPT_DIR is not empty, resume from the checkpoint -if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then - WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" -fi - - -# =================== Ray start =================== -# ray stop at all nodes -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_MEGATRON_PATH}ray stop - -sleep 10 -# Remove existing Ray cluster -srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster - -# Start Ray head node -srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ - ${CONDA_BIN_MEGATRON_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & - -sleep 10 - -# Start Ray worker nodes -for ((i = 1; i < worker_num; i++)); do - node_i=${nodes[$i]} - echo "Starting WORKER $i at $node_i" - srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ - ${CONDA_BIN_MEGATRON_PATH}ray start --address "$address_head" \ - --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & -done -sleep 10 - - -# =================== RL Config =================== -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.2 - -max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=False -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=512 # grad accum bsz; real grad accum bsz: train_prompt_bsz * rollout.n -gen_prompt_bsz=$((train_prompt_bsz * 1)) # rollout bsz, i.e., the x-axis in RL plot -n_resp_per_prompt=16 -train_prompt_mini_bsz=64 - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout - -# Generation config -gen_tp=2 -gen_max_num_seqs=1024 - -# Megatron trainer config -train_tp=8 -train_pp=1 -sp_size=2 -offload=True - -# Batch size -use_dynamic_bsz=True -train_micro_batch_size=null -train_micro_batch_size_per_gpu_placeholder=1 # can't be null, as in ray_trainer.py ```minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``` -infer_micro_batch_size_per_gpu_placeholder=1 # can't be null, as in megatron_worker.py ```assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and `log_prob_micro_batch_size` should not be None at the same time."``` -# NOTE: this one is for per gpu, so it times sp_size (defined later) -# actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1 )) -actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2 )) -infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2 )) - - -# NOTE(yonghao): all other parts (weights, optimizer states) exists across stages (training, generation) -# while this one only lives during a training iteration. -grad_offload=True -#### - -# =================== Start RL training =================== -"${CONDA_BIN_MEGATRON_PATH}python" -m recipe.dapo.main_dapo \ - --config-path=config \ - --config-name="dapo_megatron_config.yaml" \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.prompt_key=prompt \ - data.truncation='right' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - data.gen_batch_size=${gen_prompt_bsz} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.actor.strategy="megatron" \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.lr_warmup_init=0.0 \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.actor.optim.lr_decay_style=constant \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.optim.min_lr=0. \ - actor_rollout_ref.actor.optim.clip_grad=1.0 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_micro_batch_size_per_gpu_placeholder} \ - actor_rollout_ref.actor.megatron.param_offload=${offload} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \ - actor_rollout_ref.actor.megatron.grad_offload=${grad_offload} \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.actor.megatron.context_parallel_size=${sp_size} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_micro_batch_size_per_gpu_placeholder} \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ - actor_rollout_ref.ref.megatron.context_parallel_size=${sp_size} \ - actor_rollout_ref.ref.megatron.param_offload=${offload} \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${infer_micro_batch_size_per_gpu_placeholder} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.model.path=$BASE_MODEL \ - +actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.rollout.multi_turn.enable=False \ - actor_rollout_ref.rollout.mode="sync" \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - reward_model.reward_manager=async_multi_process \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ - trainer.log_val_generations=50 \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=$worker_num \ - trainer.save_freq=10 \ - trainer.test_freq=10 \ - trainer.total_epochs=5 \ - trainer.log_val_generations=50 \ - trainer.resume_mode=auto \ - trainer.max_actor_ckpt_to_keep=2 \ No newline at end of file diff --git a/scripts/train/example_singlenode_rl_qwen2.5_7b_base_fsdp.sh b/scripts/train/example_singlenode_rl_qwen2.5_7b_base_fsdp.sh deleted file mode 100644 index 52c0a0f60..000000000 --- a/scripts/train/example_singlenode_rl_qwen2.5_7b_base_fsdp.sh +++ /dev/null @@ -1,235 +0,0 @@ -#!/bin/bash - -# =================== User-Configurable Settings =================== -# --- Execution Environment --- -NUM_GPUS=8 # Set the number of GPUs to use on this node - -# --- Resuming & Logging --- -RESUME_CKPT_DIR_NAME="" # Fill in the W&B experiment name to resume from, otherwise leave empty to start from scratch -WANDB_PROJECT="Reasoning360" # Your wandb project name - -# --- External Services --- -export STEM_LLM_JUDGE_URL="" # Optional: Fill in the llm-as-judge hosted URL for 'STEM' domain evaluation - -# =================== Environment Setup =================== -export NCCL_DEBUG=info -export CUDA_DEVICE_MAX_CONNECTIONS=1 -# export CUDA_LAUNCH_BLOCKING=1 # Uncomment for easier debugging of CUDA errors - -export HYDRA_FULL_ERROR=1 -export VLLM_USE_V1=0 - -# =================== Data Mixture =================== -SHARED_DATA_PATH=./data -TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train/ -TEST_DATA_DIR=${SHARED_DATA_PATH}/online_eval/ - -# Math (train) -math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet -# Math (test) -math_test_path=${TEST_DATA_DIR}/math__math_500.parquet -aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet -amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet - -# Code (train) -leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet -livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet -primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet -taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet -# Code (test) -humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet -mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_200.parquet -livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet - -# Logic (train) -arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet -arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet -barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet -graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet -ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet -zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet -# Logic (test) -ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_100.parquet -zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet -arcagi_test_path=${TEST_DATA_DIR}/logic__arcagi1_200.parquet - -# Simulation (train) -codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet -# Simulation (test) -codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_200.parquet - -# Table (train) -hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet -multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet -# Table (test) -multihier_test_path=${TEST_DATA_DIR}/table__multihier_200.parquet -hitab_test_path=${TEST_DATA_DIR}/table__hitab_200.parquet - -# Stem (train) -webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet -# Stem (test) -supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet - -train_files="['${math_train_path}']" # Use math as example, add to more tasks as needed -test_files="['${math_test_path}','${aime_test_path}']" # Use math as example, add to more tasks as needed - -# =================== Model =================== -BASE_MODEL=Qwen/Qwen2.5-7B - -# =================== Logging =================== -# Generate a unique experiment name if not resuming -if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then - WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" -else - TIMESTAMP=$(date +%Y%m%d-%H%M%S) - WANDB_EXPERIMENT_NAME="single-node-${TIMESTAMP}-${BASE_MODEL##*/}" -fi - -# =================== Ray Start (Single Node) =================== -# Stop any previous Ray instances -${CONDA_BIN_PATH}ray stop -f - -# Start a new Ray cluster on the local machine -# The number of CPUs is often best left for Ray to determine automatically. -echo "Starting Ray on the local node with ${NUM_GPUS} GPUs..." -${CONDA_BIN_PATH}ray start --head --num-gpus ${NUM_GPUS} --include-dashboard=True --dashboard-port 8265 -sleep 5 - - -# =================== RL Config =================== -# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. - -adv_estimator=grpo - -use_kl_in_reward=False -kl_coef=0.0 -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.2 - -max_prompt_length=$((1024 * 4)) -max_response_length=$((1024 * 8)) -enable_overlong_buffer=False -overlong_buffer_len=$((1024 * 4)) -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=False -filter_groups_metric=acc -max_num_gen_batches=10 -train_prompt_bsz=512 # on-policy model update batchsize: train_prompt_bsz * rollout.n -gen_prompt_bsz=$((train_prompt_bsz * 1)) -n_resp_per_prompt=16 -train_prompt_mini_bsz=64 # model grad update batchsize - -# Algorithm -temperature=1.0 -top_p=1.0 -top_k=-1 # 0 for HF rollout, -1 for vLLM rollout - -# Training config -# NOTE: sp_size and gen_tp are parallelism settings. -# sp_size: Sequence Parallelism size. -# gen_tp: Tensor Parallelism size for vLLM generation. -# For a 32B model on 8 GPUs, TP=2 is a reasonable starting point. Adjust if you have memory issues. -sp_size=1 -gen_tp=2 -gen_max_num_seqs=1024 -infer_micro_batch_size=null -train_micro_batch_size=null -use_dynamic_bsz=True -actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward & backward but note memory overflow -infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2)) # increase this to speed up model forward, but note memory overflow -offload=True - -# =================== Start RL training =================== -# Ensure your python environment (e.g., conda) is activated before running this script. -echo "Starting training..." -python -m recipe.dapo.main_dapo \ - --config-path=config \ - --config-name="dapo_fsdp_config.yaml" \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.prompt_key=prompt \ - data.truncation='right' \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - data.train_batch_size=${train_prompt_bsz} \ - data.gen_batch_size=${gen_prompt_bsz} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - actor_rollout_ref.actor.clip_ratio_c=10.0 \ - actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ - actor_rollout_ref.actor.strategy="fsdp" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ - actor_rollout_ref.actor.optim.weight_decay=0.1 \ - actor_rollout_ref.actor.optim.warmup_style=constant \ - actor_rollout_ref.actor.optim.min_lr_ratio=0. \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.grad_clip=1.0 \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ - actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ - actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ - actor_rollout_ref.rollout.enable_chunked_prefill=True \ - actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ - actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ - actor_rollout_ref.rollout.temperature=${temperature} \ - actor_rollout_ref.rollout.top_p=${top_p} \ - actor_rollout_ref.rollout.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ - actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ - actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ - actor_rollout_ref.rollout.val_kwargs.n=1 \ - actor_rollout_ref.rollout.val_kwargs.do_sample=True \ - actor_rollout_ref.model.path=$BASE_MODEL \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.rollout.multi_turn.enable=False \ - actor_rollout_ref.rollout.mode="sync" \ - +actor_rollout_ref.model.override_config.attention_dropout=0. \ - +actor_rollout_ref.model.override_config.embd_pdrop=0. \ - +actor_rollout_ref.model.override_config.resid_pdrop=0. \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - reward_model.reward_manager=async_multi_process \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - trainer.logger=['console','wandb'] \ - trainer.project_name=${WANDB_PROJECT} \ - trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ - trainer.val_before_train=True \ - trainer.n_gpus_per_node=${NUM_GPUS} \ - trainer.nnodes=1 \ - trainer.save_freq=10 \ - trainer.test_freq=10 \ - trainer.total_epochs=10 \ - trainer.log_val_generations=50 \ - trainer.resume_mode=auto \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 7f56b8667..000000000 --- a/setup.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# setup.py is the fallback installation script when pyproject.toml does not work -import os -from pathlib import Path - -from setuptools import find_packages, setup - -version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) - -with open(os.path.join(version_folder, "verl/version/version")) as f: - __version__ = f.read().strip() - -install_requires = [ - "accelerate", - "codetiming", - "datasets", - "dill", - "hydra-core", - "numpy<2.0.0", - "pandas", - "peft", - "pyarrow>=19.0.0", - "pybind11", - "pylatexenc", - "ray[default]>=2.41.0", - "torchdata", - "tensordict>=0.8.0,<=0.9.1,!=0.9.0", - "transformers", - "wandb", - "packaging>=20.0", - # NOTE: added by Reasoning360 - "langdetect", - "immutabledict", - "nltk", - "polars" -] - -TEST_REQUIRES = ["pytest", "pre-commit", "py-spy", "pytest-asyncio"] -PRIME_REQUIRES = ["pyext"] -GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"] -GPU_REQUIRES = ["liger-kernel", "flash-attn", "nvitop",] # NOTE: nvitop is added by Reasoning360 -MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency -VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.8.5"] -SGLANG_REQUIRES = [ - "tensordict>=0.8.0,<=0.9.1,!=0.9.0", - "sglang[srt,openai]==0.4.6.post5", - "torch-memory-saver>=0.0.5", - "torch==2.6.0", -] -TRL_REQUIRES = ["trl<=0.9.6"] -MCORE_REQUIRES = ["mbridge"] - -extras_require = { - "test": TEST_REQUIRES, - "prime": PRIME_REQUIRES, - "geo": GEO_REQUIRES, - "gpu": GPU_REQUIRES, - "math": MATH_REQUIRES, - "vllm": VLLM_REQUIRES, - "sglang": SGLANG_REQUIRES, - "trl": TRL_REQUIRES, - "mcore": MCORE_REQUIRES, -} - - -this_directory = Path(__file__).parent -long_description = (this_directory / "README.md").read_text() - -setup( - name="verl", - version=__version__, - package_dir={"": "."}, - packages=find_packages(where="."), - url="https://github.com/volcengine/verl", - license="Apache 2.0", - author="Bytedance - Seed - MLSys", - author_email="zhangchi.usc1992@bytedance.com, gmsheng@connect.hku.hk", - description="verl: Volcano Engine Reinforcement Learning for LLM", - install_requires=install_requires, - extras_require=extras_require, - package_data={ - "": ["version/*"], - "verl": ["trainer/config/*.yaml"], - }, - include_package_data=True, - long_description=long_description, - long_description_content_type="text/markdown", -) diff --git a/src/reasoning360.egg-info/PKG-INFO b/src/reasoning360.egg-info/PKG-INFO new file mode 100644 index 000000000..52d5cadab --- /dev/null +++ b/src/reasoning360.egg-info/PKG-INFO @@ -0,0 +1,254 @@ +Metadata-Version: 2.4 +Name: reasoning360 +Version: 0.1.0 +Summary: Reasoning360 extension for Verl +Classifier: Programming Language :: Python :: 3 +Requires-Python: >=3.9 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: verl +Requires-Dist: hydra-core +Requires-Dist: langdetect +Requires-Dist: immutabledict +Requires-Dist: nltk +Requires-Dist: polars +Requires-Dist: nvitop +Dynamic: license-file + +# Reasoning360 + +

+ + Paper + + + Dataset + + + Model + + + Wandb Log + + +

+ + +This is the official repository of Reasoning360, a project dedicated to *open research on large-scale reasoning models*. The repository currently includes data processing and filtering tools, reinforcement learning (RL) training pipelines, and an evaluation suite. It's initialized from [veRL](https://github.com/volcengine/verl). + +## 🔥News ++ *08/20/2025*: The full [wandb logs](https://wandb.ai/mbzuai-llm/Guru/) for Guru-7B/32B training is public now. ++ Our paper to analyze and improve multi-domain RL for LLM reasoning with Guru data "[Revisiting Reinforcement Learning for LLM Reasoning from A Cross-Domain Perspective](https://arxiv.org/abs/2506.14965)" is out on arxiv. Also, we release the [model](https://huggingface.co/LLM360/guru-32B) and [data](https://huggingface.co/datasets/LLM360/guru-RL-92k). ++ The ready-to-train 92K Guru RL data across six domains is released under [LLM360 huggingface](https://huggingface.co/datasets/LLM360/guru_RL). + + +--- +## Table of Contents +- [Installation](#installation) +- [Data preparation](#data-preparation) +- [RL Training](#rl-training) + - [(1) Download data](#1-download-data) + - [(2) \[Optional\] Customize chat template](#2-optional-customize-chat-template) + - [(3) \[Optional\] SandboxFusion Code Execution](#3-optional-sandboxfusion-code-execution) + - [(4) Train](#4-train) +- [Evaluation](#evaluation) +- [Contributing](#contributing) + - [Add a new dataset for training (or evaluation)](#add-a-new-dataset-for-training-or-evaluation) + - [Pre-commit](#pre-commit) + - [Pull Request](#pull-request) + + +--- + +## Installation + +```bash +git clone git@github.com:LLM360/Reasoning360.git +cd Reasoning360 +``` + + +```bash +conda create -n Reasoning360 python=3.12 +conda activate Reasoning360 +conda install -c nvidia/label/cuda-12.4.0 cuda-toolkit cuda-nvcc +pip install uv # using uv to install packages is faster than pip +uv pip install torch==2.6.0 +uv pip install flash-attn==2.7.3 --no-build-isolation +uv pip install -e .[gpu,math,vllm,test] +``` + +Alternatively, you can refer to verl [installment guidance](https://verl.readthedocs.io/en/latest/index.html) for setup. + + +--- +## Data preparation +The full ready-to-train 92K Guru RL data is already released under [LLM360 huggingface](https://huggingface.co/datasets/LLM360/guru_RL)! If you would like to build (or experience) the data pipeline from scratch, we also provide detailed guidances for [data preparation](./data_preprocess/README.md) and [filtering by data difficulty levels](./model_filtering/README.md). + +Quick data check: +```python +import json +from datasets import load_dataset + +# Load dataset +train_data = load_dataset("LLM360/guru-RL-92k", split="train", streaming=True) + +print(f"Columns: {train_data.column_names}") +print(f"First item: {next(iter(train_data))}") +``` + +--- +## RL Training +### (1) Download data +Download the data and prepare them into `.parquet`, the expected default format in training script. We provide a simple script to download and organize Guru data `scripts/tools/download_guru.py`, with all dataset files for training, online & offline evaluation to local directories. +By defauly, training files will be put in `./data/train`. Online evaluation files will be put in `./data/online_eval`. Offline evaluation files will be put in `./data/offline_eval`. + +### (2) [Optional] Customize chat template +Run `tools/change_tokenizer_config.py` if you want to apply 'think'-aware chat template. Now only the 'Qwen' families are supported. +```python +python tools/change_tokenizer_config.py -i -o +``` + +### (3) [Optional] SandboxFusion Code Execution + +SandboxFusion provides secure code execution for training and evaluation. It supports both containerized SLURM deployment and local installation. + +#### Quick Setup + +**Option 1: SLURM Container (Recommended for production)** +```bash +# Download container +enroot import docker://varad0309/code_sandbox:server + +# Deploy with SLURM +sbatch scripts/sandbox/run_server.sbatch +``` + +**Option 2: Local Installation (Development only)** +```bash +git clone https://github.com/bytedance/SandboxFusion.git +cd SandboxFusion +poetry install +make run-online +``` + +#### Configuration + +Configure sandbox servers in your training script: + +```bash +# Single server +export SANDBOX_FUSION_SERVERS="fs-mbz-gpu-044" + +# Multiple servers (load balancing) +export SANDBOX_FUSION_SERVERS="fs-mbz-gpu-044,fs-mbz-gpu-045" +``` + +Or programmatically: +```python +from verl.utils.reward_score.coder1.sandboxfusion_exec import code_exec_sandboxfusion + +# Single server +success, output = code_exec_sandboxfusion( + code="print('Hello')", + sandbox_servers="fs-mbz-gpu-044" +) + +# Multiple servers +success, output = code_exec_sandboxfusion( + code="print('Hello')", + sandbox_servers=["fs-mbz-gpu-044", "fs-mbz-gpu-045"] +) +``` + +For detailed setup instructions, see [`verl/utils/reward_score/coder1/README.md`](verl/utils/reward_score/coder1/README.md). + + +### (4) Train +We provide the multi-node training slurm script using a `math3k` subset data for ablation, not the full data. Change the `SHARED_DATA_PATH` upon your data path. +```bash +sbatch scripts/train/example_multinode_rl_qwen2.5_32b_base_fsdp.sh +``` + +If you need to train on the full data or include STEM data in Guru, host the llm-as-verifier model first before launching the training. +```bash +sbatch scripts/tools/serve_llm_as_verifier.sh +``` +Then fill in the `export STEM_LLM_JUDGE_URL=""` by the llm-as-verifier server IP. It uses one GPU node to serve a 1.5B [general-verifier](https://huggingface.co/TIGER-Lab/general-verifier) now. + +--- +## Evaluation +We provide a evaluation suite of of 17 tasks supporting multi-node inference based on [verl](https://github.com/volcengine/verl). For quick start, run +```bash +sbatch scripts/offline_eval/example_multinode_eval_guru7b.sh +``` +Please refer to `scripts/offline_eval/README.md` if you would like to know and customize evaluation details. + +--- +## Contributing +### Add a new dataset for training (or evaluation) + +**Step1: Data preprocessing script** + +In preprocessing, we will process the data into a list of dictionaries, and then save it into a parquet file. + +1. Prompt preprocessing + + We need to process the raw question into a prompt ready to be fed to the LLM. An example is [[1](data_preprocess/math/dapo_or1_merge_deduped.py)]. + + Each data point is processed into a dict, and we need to specify the prompt within the data dict: + ``` + "prompt": [{ + "role": "user", + "content": prompt + }], + ``` + + Note that, when we use verl to train the model, it will turn into a prompt string with `apply_chat_template`. + + Note that: + - You will probably need to add some task-specific instruction in the `question`. E.g., for math, we concatenate the raw problem with `Please output the final answer within \\boxed{}.`, so that it's easy to extract the answer from model output. + - You don't need to instruct the model to "think step by step" or "wrap your thinking process in `` `<\think>`". This should be taken care by verl during training with `apply_chat_template`. To enable this, we have a [script](scripts/tools/change_tokenizer_config.py) to modify the chat template of a huggingface model (currently only tested on Qwen). + - Please add an instruction under the README of `data_preprocess` + +2. Reward function + + We need to specify the information regarding reward calculation for the new dataset. + + This typically includes three keys in the dict: `data_source`, `reward_model["ground_truth"]`, `extra_info`. + + In our training, we use [`default_compute_score`](verl/utils/reward_score/__init__.py#L17), which routes the reward computing to a specific reward function implementation based on `data_source`. `ground_truth` and `extra_info` will be passed as arguments. + +**Step2: Reward function** + +Please look at [`default_compute_score`](verl/utils/reward_score/__init__.py#L17). You can write your own reward function for the task, and import it here. It's highly recommended to add a timeout module to avoid the training being stuck by a corner case of reward function ([example](verl/utils/reward_score/zebra_puzzle.py)). + +**Step3: Training script** + +Verify the inclusion of a new dataset by actually training models with it. Please refer the template script in this repo. + +### Pre-commit + +We use pre-commit to enforce code formatting. Before committing, make sure you have run the pre-commit checks. +```bash +pre-commit install +pre-commit run --all-files +``` + +### Pull Request + +Please make a pull request including the data preprocessing script, reward function, and the training script. + + +## Citation +If you find the repo helpful, please cite: +``` +@misc{cheng2025revisiting, + title = {Revisiting Reinforcement Learning for LLM Reasoning from A Cross-Domain Perspective}, + author = {Zhoujun Cheng and Shibo Hao and Tianyang Liu and Fan Zhou and Yutao Xie and Feng Yao and Yuexin Bian and Yonghao Zhuang and Nilabjo Dey and Yuheng Zha and Yi Gu and Kun Zhou and Yuqi Wang and Yuan Li and Richard Fan and Jianshu She and Chengqian Gao and Abulhair Saparov and Haonan Li and Taylor W. Killian and Mikhail Yurochkin and Zhengzhong Liu and Eric P. Xing and Zhiting Hu}, + journal = {arXiv preprint arXiv:2506.14965}, + year = {2025}, + doi = {10.48550/arXiv.2506.14965}, + url = {https://arxiv.org/abs/2506.14965} +} +``` diff --git a/src/reasoning360.egg-info/SOURCES.txt b/src/reasoning360.egg-info/SOURCES.txt new file mode 100644 index 000000000..e0b653618 --- /dev/null +++ b/src/reasoning360.egg-info/SOURCES.txt @@ -0,0 +1,148 @@ +LICENSE +README.md +pyproject.toml +src/reasoning360/__init__.py +src/reasoning360.egg-info/PKG-INFO +src/reasoning360.egg-info/SOURCES.txt +src/reasoning360.egg-info/dependency_links.txt +src/reasoning360.egg-info/requires.txt +src/reasoning360.egg-info/top_level.txt +src/reasoning360/experimental/agent_loop/agent_loop.py +src/reasoning360/recipe/__init__.py +src/reasoning360/recipe/fully_async_policy/detach_utils.py +src/reasoning360/recipe/fully_async_policy/fsdp2_utils.py +src/reasoning360/recipe/fully_async_policy/fsdp_workers.py +src/reasoning360/recipe/fully_async_policy/fully_async_main.py +src/reasoning360/recipe/fully_async_policy/fully_async_rollouter.py +src/reasoning360/recipe/fully_async_policy/fully_async_trainer.py +src/reasoning360/recipe/fully_async_policy/megatron_utils.py +src/reasoning360/recipe/fully_async_policy/megatron_worker.py +src/reasoning360/recipe/fully_async_policy/message_queue.py +src/reasoning360/recipe/fully_async_policy/param_sync.py +src/reasoning360/recipe/fully_async_policy/ray_trainer.py +src/reasoning360/recipe/fully_async_policy/agent_loop/__init__.py +src/reasoning360/recipe/fully_async_policy/agent_loop/agent_loop copy.py +src/reasoning360/recipe/fully_async_policy/agent_loop/agent_loop.py +src/reasoning360/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py +src/reasoning360/recipe/fully_async_policy/agent_loop/partial_tool_agent_loop.py +src/reasoning360/recipe/fully_async_policy/vllm_rollout/__init__.py +src/reasoning360/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py +src/reasoning360/trainer/__init__.py +src/reasoning360/trainer/main_generation.py +src/reasoning360/trainer/main_ppo.py +src/reasoning360/trainer/config/__init__.py +src/reasoning360/trainer/config/algorithm.py +src/reasoning360/trainer/config/config.py +src/reasoning360/trainer/ppo/metric_utils.py +src/reasoning360/trainer/ppo/ray_trainer.py +src/reasoning360/trainer/ppo/reward.py +src/reasoning360/utils/__init__.py +src/reasoning360/utils/dataset/rl_dataset.py +src/reasoning360/utils/dataset/sft_dataset.py +src/reasoning360/utils/reward_score/__init__.py +src/reasoning360/utils/reward_score/arcagi.py +src/reasoning360/utils/reward_score/codeio.py +src/reasoning360/utils/reward_score/deepmath.py +src/reasoning360/utils/reward_score/deepmath_test.py +src/reasoning360/utils/reward_score/geo3k.py +src/reasoning360/utils/reward_score/gpqa.py +src/reasoning360/utils/reward_score/graph_dataset.py +src/reasoning360/utils/reward_score/gsm8k.py +src/reasoning360/utils/reward_score/math.py +src/reasoning360/utils/reward_score/math_batch.py +src/reasoning360/utils/reward_score/math_dapo.py +src/reasoning360/utils/reward_score/math_verify.py +src/reasoning360/utils/reward_score/naive_dapo.py +src/reasoning360/utils/reward_score/nemotron_stem.py +src/reasoning360/utils/reward_score/nemotron_stem_test.py +src/reasoning360/utils/reward_score/puzzles_dataset.py +src/reasoning360/utils/reward_score/search_r1_like_qa_em.py +src/reasoning360/utils/reward_score/supergpqa.py +src/reasoning360/utils/reward_score/tablereason.py +src/reasoning360/utils/reward_score/zebra_puzzle.py +src/reasoning360/utils/reward_score/coder1/__init__.py +src/reasoning360/utils/reward_score/coder1/bwrap_exec.py +src/reasoning360/utils/reward_score/coder1/ces_exec.py +src/reasoning360/utils/reward_score/coder1/docker_exec.py +src/reasoning360/utils/reward_score/coder1/firejail_exec.py +src/reasoning360/utils/reward_score/coder1/kira_exec.py +src/reasoning360/utils/reward_score/coder1/sandboxfusion_exec.py +src/reasoning360/utils/reward_score/coder1/unsafe_local_exec.py +src/reasoning360/utils/reward_score/coder1/utils.py +src/reasoning360/utils/reward_score/cruxeval/__init__.py +src/reasoning360/utils/reward_score/cruxeval/cruxeval.py +src/reasoning360/utils/reward_score/cruxeval/utils.py +src/reasoning360/utils/reward_score/ifbench/__init__.py +src/reasoning360/utils/reward_score/ifbench/check_ifbench_data.py +src/reasoning360/utils/reward_score/ifbench/instructions.py +src/reasoning360/utils/reward_score/ifbench/instructions_registry.py +src/reasoning360/utils/reward_score/ifbench/instructions_util.py +src/reasoning360/utils/reward_score/ifbench/split_fixed_data.py +src/reasoning360/utils/reward_score/ifbench/test_ifbench.py +src/reasoning360/utils/reward_score/ifeval/__init__.py +src/reasoning360/utils/reward_score/ifeval/instructions.py +src/reasoning360/utils/reward_score/ifeval/instructions_registry.py +src/reasoning360/utils/reward_score/ifeval/instructions_util.py +src/reasoning360/utils/reward_score/livebench/__init__.py +src/reasoning360/utils/reward_score/livebench/util.py +src/reasoning360/utils/reward_score/livebench/data_analysis/cta/utils.py +src/reasoning360/utils/reward_score/livebench/data_analysis/tablejoin/utils.py +src/reasoning360/utils/reward_score/livebench/data_analysis/tablereformat/utils.py +src/reasoning360/utils/reward_score/livebench/reasoning/house_traversal/utils.py +src/reasoning360/utils/reward_score/livebench/reasoning/spatial/utils.py +src/reasoning360/utils/reward_score/livebench/reasoning/web_of_lies_v2/utils.py +src/reasoning360/utils/reward_score/livebench/reasoning/web_of_lies_v3/utils.py +src/reasoning360/utils/reward_score/livebench/reasoning/zebra_puzzle/utils.py +src/reasoning360/utils/reward_score/livebench/writing/connections/utils.py +src/reasoning360/utils/reward_score/livebench/writing/plot_unscrambling/utils.py +src/reasoning360/utils/reward_score/livebench/writing/typos/utils.py +src/reasoning360/utils/reward_score/math_llm_judge/__init__.py +src/reasoning360/utils/reward_score/math_llm_judge/grader.py +src/reasoning360/utils/reward_score/math_llm_judge/math_normalize.py +src/reasoning360/utils/reward_score/orz/__init__.py +src/reasoning360/utils/reward_score/orz/math_utils.py +src/reasoning360/utils/reward_score/orz/math_utils_sync.py +src/reasoning360/utils/reward_score/prime_code/__init__.py +src/reasoning360/utils/reward_score/prime_code/testing_util.py +src/reasoning360/utils/reward_score/prime_code/utils.py +src/reasoning360/utils/reward_score/prime_math/__init__.py +src/reasoning360/utils/reward_score/prime_math/grader.py +src/reasoning360/utils/reward_score/prime_math/math_normalize.py +src/reasoning360/utils/reward_score/reasoning_gym/__init__.py +src/reasoning360/utils/reward_score/sandbox_fusion/__init__.py +src/reasoning360/utils/reward_score/sandbox_fusion/utils.py +src/reasoning360/utils/reward_score/stem_llm_judge/__init__.py +src/reasoning360/utils/reward_score/synlogic/__init__.py +src/reasoning360/utils/reward_score/synlogic/arrow_maze_verifier.py +src/reasoning360/utils/reward_score/synlogic/boolean_expressions_verifier.py +src/reasoning360/utils/reward_score/synlogic/campsite_verifier.py +src/reasoning360/utils/reward_score/synlogic/data.py +src/reasoning360/utils/reward_score/synlogic/dyck_language_errors_verifier.py +src/reasoning360/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py +src/reasoning360/utils/reward_score/synlogic/dyck_language_verifier.py +src/reasoning360/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py +src/reasoning360/utils/reward_score/synlogic/goods_exchange_verifier.py +src/reasoning360/utils/reward_score/synlogic/math_path_verifier.py +src/reasoning360/utils/reward_score/synlogic/minesweeper_verifier.py +src/reasoning360/utils/reward_score/synlogic/norinori_verifier.py +src/reasoning360/utils/reward_score/synlogic/number_wall_verifier.py +src/reasoning360/utils/reward_score/synlogic/numbrix_verifier.py +src/reasoning360/utils/reward_score/synlogic/object_counting_verifier.py +src/reasoning360/utils/reward_score/synlogic/object_properties_verifier.py +src/reasoning360/utils/reward_score/synlogic/operation_verifier.py +src/reasoning360/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py +src/reasoning360/utils/reward_score/synlogic/space_reasoning_tree_verifier.py +src/reasoning360/utils/reward_score/synlogic/space_reasoning_verifier.py +src/reasoning360/utils/reward_score/synlogic/star_placement_puzzle_verifier.py +src/reasoning360/utils/reward_score/synlogic/synlogic.py +src/reasoning360/utils/reward_score/synlogic/time_sequence_verifier.py +src/reasoning360/utils/reward_score/synlogic/verifier.py +src/reasoning360/utils/reward_score/synlogic/web_of_lies_verifier.py +src/reasoning360/utils/reward_score/synlogic/word_sorting_mistake_verifier.py +src/reasoning360/utils/reward_score/synlogic/word_sorting_verifier.py +src/reasoning360/utils/reward_score/synlogic/wordscapes_verifier.py +src/reasoning360/workers/reward_manager/__init__.py +src/reasoning360/workers/reward_manager/async_mp.py +src/reasoning360/workers/reward_manager/dapo.py +src/reasoning360/workers/reward_manager/llm_judge.py +src/reasoning360/workers/reward_manager/naive_parallel.py \ No newline at end of file diff --git a/src/reasoning360.egg-info/dependency_links.txt b/src/reasoning360.egg-info/dependency_links.txt new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/reasoning360.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/reasoning360.egg-info/requires.txt b/src/reasoning360.egg-info/requires.txt new file mode 100644 index 000000000..30707e323 --- /dev/null +++ b/src/reasoning360.egg-info/requires.txt @@ -0,0 +1,7 @@ +verl +hydra-core +langdetect +immutabledict +nltk +polars +nvitop diff --git a/src/reasoning360.egg-info/top_level.txt b/src/reasoning360.egg-info/top_level.txt new file mode 100644 index 000000000..93dbbffeb --- /dev/null +++ b/src/reasoning360.egg-info/top_level.txt @@ -0,0 +1 @@ +reasoning360 diff --git a/verl/py.typed b/src/reasoning360/__init__.py similarity index 100% rename from verl/py.typed rename to src/reasoning360/__init__.py diff --git a/verl/utils/reward_score/cruxeval/cruxeval.py b/src/reasoning360/recipe/__init__.py similarity index 100% rename from verl/utils/reward_score/cruxeval/cruxeval.py rename to src/reasoning360/recipe/__init__.py diff --git a/src/reasoning360/recipe/fully_async_policy/README.md b/src/reasoning360/recipe/fully_async_policy/README.md new file mode 100644 index 000000000..72a135faa --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/README.md @@ -0,0 +1,525 @@ +# Recipe: Fully Async Policy Trainer + +**Author:** `https://github.com/meituan-search` + +Last updated: 10/18/2025. + +This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter, +supporting asynchronous sample generation and training. +Under this system, we achieved a 2.35x-2.67x performance improvement when training the Qwen2.5-7B model with 128 GPUs, +without significantly affecting the results. + +## Introduction + +### Background + +The separated rollout and train architecture, compared to the colocate architecture, can allocate resources more +flexibly and design more flexible training logic, thereby addressing issues such as low GPU utilization and training +efficiency caused by long-tail problems. +The one_step_off_policy alleviates the problem of long rollout times and achieves some gains in training efficiency by +designing a separated architecture and performing asynchronous training between rollout and train for one round. +However, it forcibly uses data from one round of asynchronous training, which is not flexible enough and cannot +completely eliminate the impact of long-tail on training efficiency. +In other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow, asynchronous training and streaming training have +been implemented based on the separated architecture and have achieved gains. +We borrow from their methods and implemented them in VERL. The fully_async_policy supports asynchronous, streaming, and +partial +rollout training. +By reasonably setting parameters such as resource allocation and parameter synchronization frequency, fully_async_policy +can significantly improve training efficiency. + +> Magistral https://arxiv.org/abs/2506.10910 +> +> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language +> Reasoning https://arxiv.org/abs/2505.24298 +> +> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream +> Generation https://arxiv.org/abs/2504.15930 +> +> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663 +> + +### Core Contributions + +* **Resource Isolation**: Unlike using hybrid_engine, Rollouter and Trainer use separate computing resources and need to + specify the resources they occupy separately. +* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples. +* **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to + multiple steps, making the asynchronous solution more flexible. +* **NCCL Parameter Synchronization**: Uses NCCL communication primitives for parameter communication between Rollouter + and Trainer. +* **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single + sample as the minimum transmission unit. +* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it + supports training with samples generated by old parameters. +* **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter + synchronization, by adding `sleep() and resume()` logic, it + saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for + ongoing tasks to finish during parameter synchronization. + +Currently, the supported usage mode is megatron/fsdp+vllm. vllm must use the server mode based on AgentLoop. + +## Design + +The overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four +parts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer. + +![fully_async_policy_structure]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true) + +1. Rollouter generates sequences sample by sample and puts the generated samples into the MessageQueue, with the + production speed controlled by freshness. +2. MessageQueue is used to temporarily store samples generated by Rollouter. +3. Trainer fetches samples from MessageQueue sample by sample. After fetching `require_batches*ppo_mini_batch_size` + samples, it will perform training. After training for async_training.trigger_parameter_sync_step rounds, it triggers + a parameter synchronization with Rollouter. +4. ParameterSynchronizer implements the NCCL synchronous parameter synchronization capability. + +The source of benefits compared to the base scheme lies in the fact that in the colocate case, using more resources for +rollout cannot solve the idleness caused by long-tail samples. +After we perform resource isolation, the time for rollout and train may be longer than before (because fewer resources +are used), +but the overlap in their time consumption reduces the end-to-end time consumption. + +![fully_async_policy_revenue]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true) + +## Usage + +### Parameter Description + +| super params | implication | +|-----------------------------------------------|------------------------------------------------------------------------------------------------| +| `trainer.nnodes` | Number of nodes for Trainer | +| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer | +| `rollout.nnodes` | Number of nodes for Rollouter | +| `rollout.n_gpus_per_node` | Number of GPUs per node for Rollouter | +| `data.train_batch_size` | In the fully async strategy, this value is not effective (default is 0) | +| `data.gen_batch_size` | In the fully async strategy, uses streaming sample production logic (default is 1) | +| `rollout.total_rollout_steps` | Total number of rollout samples | +| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation | +| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus | +| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once | +| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization | +| `async_training.staleness_threshold` | Freshness control | +| `async_training.partial_rollout` | Whether to perform partial_rollout | +| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout | +| `async_training.compute_prox_log_prob` | Whether to compute log_prob using the training model's parameters during the training phase. | | + +**Further Explanation:** + +* `rollout.total_rollout_steps` + + Compared to colocate, the quantity can be aligned by multiplying train_batch_size and step: + `rollout.total_rollout_steps = data.train_batch_size * step`. + +* `async_training.trigger_parameter_sync_step` + + In the fully async strategy, it indicates how many local updates the Trainer performs (i.e., how many times it fetches + `require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with Rollouter. + Between every two parameter synchronizations between Rollouter and Trainer, the Trainer will process + `trigger_parameter_sync_step* require_batches*ppo_mini_batch_size` samples. + To fairly compare speed with colocate, trigger_parameter_sync_step should be set to + `data.train_batch_size / (require_batches * ppo_mini_batch_size)`. + +* `async_training.staleness_threshold` + + In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used. + + * staleness_threshold=0, indicates synchronous training. + Rollouter will generate a fixed number of samples between two parameter updates, the sample count is: + $$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$ + * staleness_threshold>0, indicates asynchronous training, can be set to a decimal for more flexible asynchronous + calls. + Rollouter will generate at most the following number of samples between two parameter updates: + $$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$ + + num_staleness_sample represents the number of stale samples generated in excess during the last rollout. + + Since it's a streaming system, rollout continues to generate and trainer continues to consume. If rollouter is slower, + trainer will trigger parameter synchronization earlier, and rollouter will not actually produce rollout_num samples. + When rollout is fast enough, setting staleness_threshold to 1 is basically equivalent to one_step_off policy. + To avoid too many expired samples affecting training accuracy, it is recommended to set this value to less than 1. + +* `async_training.partial_rollout` + + partial_rollout only actually takes effect when staleness_threshold>0. + +* `async_training.use_rollout_log_probs` + + In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to + the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling, + old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm + correctness. In the fully + async strategy, we default to old_log_prob being calculated by rollout rather than by trainer. + +* `async_training.require_batches` + + In streaming training, require_batches should be set to 1, indicating that training is performed after producing + enough ppo_mini_batch_size samples. + In actual testing, we found that if fewer samples are issued at once, due to the order of data distribution, it can + cause training instability and longer response lengths. + Here, we additionally provide require_batches for streaming distribution and control the number of samples + participating in training at once. + +* `async_training.compute_prox_log_prob` (experimental) + + During the training process, we observed that metrics and response lengths may become unstable in the later + stages of training. To mitigate this issue, we can use + the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) + technique for importance sampling. To utilize Rollout Importance Sampling, we need to compute log_prob using + the training engine, which requires enabling this switch. + Additionally, when compute_prox_log_prob and Rollout Importance Sampling are enabled under mode d + (async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`. + +### Supported Modes + +1. on policy pipeline: + 1. **trigger_parameter_sync_step=1, staleness_threshold=0** + 2. Rollouter produces `require_batches*ppo_mini_batch_size` samples at once, Trainer fetches these samples for + training, and after training completes, Trainer and Rollouter perform a parameter synchronization; + 3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill + idle resources, causing some resource waste. + 4. As shown in figure a; + +2. stream off policy pipeline: + 1. **trigger_parameter_sync_step>1, staleness_threshold=0** + 2. Synchronous streaming training will be performed. Rollouter produces + `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at once, Trainer performs a local + training every time it fetches `require_batches*ppo_mini_batch_size` samples, and after training + trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization; + 3. Compared to a, since more samples are generated at once, resource idleness will be lower. + 4. In one step training, there will be two periods of resource idleness: when fetching the first batch of samples, + train waits for `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter + update, rollout waits for training to complete. + 5. As shown in figure b; + +3. async stream pipeline with stale samples: + 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False** + 2. After each parameter update, Rollouter will plan to produce at most rollout_num samples (in practice, the number + of samples generated may be less than this value depending on rollout speed). + 3. If the rollout process is relatively fast, Rollouter will generate some additional samples num_stale_samples + before parameter synchronization for immediate use by Trainer after synchronization. + When triggering parameter synchronization, if Rollouter has ongoing tasks, it will wait for the tasks to complete + and not add new tasks; + 4. Compared to b, except for the first step training, subsequent training will not have the time to wait for the + first batch rollout to finish, but will have the time to wait for active tasks to finish. + 5. As shown in figure c; + +4. async stream pipeline with partial rollout: + 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True** + 2. Compared to c, when triggering parameter synchronization, if Rollouter has samples being produced, it will + interrupt the rollout process and perform parameter synchronization. The interrupted samples will continue to be + generated after synchronization. This reduces the time to wait for active tasks to finish. + 3. As shown in figure d; + +![fully_async_policy_mode]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true) + +### Key Metrics + +| metrics | implication | +|------------------------------------------------|--------------------------------------------------------------------------------------------------------| +| `trainer/idle_ratio` | Trainer idle rate | +| `rollouter/idle_ratio` | Rollouter idle rate | +| `fully_async/count/stale_samples_processed` | Total number of old samples used in training | +| `fully_async/count/stale_trajectory_processed` | Total number of old trajectories used in training (one sample produces rollout.n trajectories) | +| `fully_async/partial/total_partial_num` | Number of partial samples processed by Trainer between two trigger_parameter_sync_step | +| `fully_async/partial/partial_ratio` | Ratio of partial samples processed by Trainer between two trigger_parameter_sync_step | +| `fully_async/partial/max_partial_span` | Maximum parameter span of partial samples processed by Trainer between two trigger_parameter_sync_step | + +### Parameter Tuning Recommendations + +* Resource Allocation and Adjustment: + * Reasonable resource allocation is the prerequisite for achieving good training efficiency. The ideal resource + allocation should make the rollout time and train time close, thereby minimizing pipeline bubbles in the entire + training process, + avoiding resource idleness, and ensuring Trainer does not use old samples. In real training scenarios, resource + allocation can be adjusted based on the idle time of rollout and train during actual training, + which can be obtained from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and + trainer/idle_ratio is low, + Trainer resources should be increased and Rollouter resources should be reduced, and vice versa. + +* Key Parameters: + * staleness_threshold: Setting it too high will cause more old samples to be used, affecting model performance. It + is recommended to set it to less than 1. + * require_batches: The closer to 1, the closer to a pure streaming process, the smaller the training bubbles, and + the faster the acceleration effect that can be achieved in terms of speed, but it will affect the order of sample + processing; + * trigger_parameter_sync_step: The smaller the setting, the closer to on policy, but it will cause frequent + parameter synchronization. Long-tail samples waste resources that cannot be filled by short samples, resulting in + low resource utilization. + The larger the setting, the higher the computational efficiency, but the accuracy will be affected by off policy. + * rollout.test_freq: It will occupy Rollouter resources and is not recommended to be set too small. + +* Mode Selection: By adjusting different parameters, the Fully Async architecture supports optimization acceleration at + different levels, suitable for tasks in different scenarios. + * For small-scale tasks that need to ensure training stability and on-policy nature, and have low speed + requirements, the on policy pipeline mode (Mode 1) can be tried. + * For scenarios that need to improve training throughput but are sensitive to staleness, the stream off policy + pipeline mode can be tried. That is, by + setting trigger_parameter_sync_step>1 to improve training efficiency, but still maintaining the synchronization + mechanism (staleness_threshold=0) (Mode 2). + * For large-scale tasks with high training speed requirements and can tolerate a certain degree of off-policy and + staleness, setting staleness_threshold> + 0 and partial_rollout=True can improve training efficiency, using the async stream pipeline mode (Mode 3 or 4). + +### Quick Start + +```shell +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*400))) +test_freq=10 +staleness_threshold=0 +trigger_parameter_sync_step=16 +partial_rollout=False + + +python -m recipe.fully_async_policy.fully_async_main \ + train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.partial_rollout="${partial_rollout}" +``` + +## Experiments + +### Asynchronous Training on 7B Model + +We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy under long candidates and multiple resources. +Using the `async stream pipeline with stale samples` strategy, we achieved about 2x performance improvement on 32 cards, +64 cards, and 128 cards without significantly affecting experimental results. + +* Machine: H20 +* Model: Qwen2.5-Math-7B +* Rollout length: max_response_length FSDP2: 28K tokens; +* Algorithm: DAPO +* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet +* Engine: vllm+FSDP2 +* rollout.n: 16 +* ppo_mini_batch_size: 32 +* test_freq: 20 + +* colocate sync: + * step: 400 + * train_batch_size: 512 + +* fully_async_policy + * total_rollout_steps: 512*400 + * require_batches: 4 + * trigger_parameter_sync_step: 4 + * staleness_threshold: 0.5 + * partial_rollout: True + +| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:---------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-------------------------------:| +| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 269.80 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313
last: 0.2448 | +| fully_async_policy | 16:16 | 294.77 | 21.26 | \ | 313.81 | 7h 58m
(1.72x) | 16h 21m
(1.70x) | 1d 0h 53m
(2.31x) | 1d 9h 26m
(2.66x) | max: 0.3302
last: 0.2333 | +| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365
last: 0.2333 | +| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m
(2.09x) | 10h 14m
(2.03x) | 16h 58m
(1.83x) | 21h 40m
(1.92x) | max: 0.3677
last: 0.3406 | +| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573
last: 0.2958 | +| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m
(2.67x) | 6h 46m
(2.65x) | 10h 53m
(2.67x) | 17h 22m
(2.35x) | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg + +### 128-card 7B Asynchronous Mode Experiment + +We used Qwen2.5-Math-7B to verify the effects of various modes supported by fully async. +We can see that the benefit brought by streaming is approximately 1.6x, and after combining staleness and +partial_rollout, the benefit reaches 2.35x. + +| mode | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:-------------------------------------------------------------------------------------------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:| +| colocate sync | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573
last: 0.2958 | +| `stream off policy pipeline`
(+fully async: trigger_parameter_sync_step= 4,
require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844
last: 0.2604 | +| `async stream pipeline with stale samples`
(+staleness_threshold=0.5) | | | | | | | | | | +| `async stream pipeline with partial rollout`
(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg + +### 128-card Stale Ablation Experiment + +Under the `async stream pipeline with partial rollout` mode, we verified the impact of staleness settings on training +efficiency. +We found that the larger the staleness, the more obvious the final gains. +We also noticed that the times for staleness values of 0.3 and 0.5 are quite close, because as the training steps +increase, the response length changes significantly, causing training instability. +Further analysis and optimization are needed for this issue. + +| staleness_threshold | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| +| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844
last: 0.2604 | +| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542
last: 0.2979 | +| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469
last: 0.2865 | +| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg + +### 128-card 7B require_batches Ablation Experiment + +In multiple tests, we found that the number of samples issued each time in streaming affects the response length during +training, which in turn affects training time. We verified the impact on results by modifying +`async_training.require_batches`. + +| require_batches | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | acc/mean@1 | +|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| +| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349
last: 0.326 | +| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351
last: 0.3406 | +| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521
last: 0.3521 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg + +### 30B Model Mode Experiment + +We achieved a 1.7x performance improvement with `async stream pipeline with staleness samples` strategy on the +Qwen3-30B-A3B-Base model compared to the colocate setup. It is worth noting that this is far from the upper limit of +performance gains achievable through asynchrony. Firstly, the comparative experiments used a maximum response length of +only 8k, which is much shorter than the 20k sequence length in previous experiments, resulting in a less pronounced +rollout tail effect. Secondly, we adopted a highly skewed resource allocation, with rollout using 96 GPUs and trainer +using 32 GPUs, which is not an optimal configuration. During the experiments, we observed that the current verl +implementation imposes certain constraints, such as requiring data to be evenly divisible by the number of GPUs, making +resource adjustment less flexible. Additionally, as asynchronous training and deployment accelerate, the performance gap +is gradually narrowing. Therefore, enabling more flexible resource allocation and dynamic resource adjustment in the +future will be our next focus. + +* Machine: H20 +* Model: Qwen3-30B-A3B-Base +* Rollout length: max_response_length : 8K tokens; +* Algorithm: GRPO +* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet +* Engine: vllm+Megatron +* rollout.n: 16 +* ppo_mini_batch_size: 128 +* test_freq: 20 + +* colocate sync: + * step:400 + * train_batch_size: 512 + +* fully_async_policy + * total_rollout_steps: 512*400 + * trigger_parameter_sync_step: 512/128 = 4 + * staleness_threshold: 0.5 + * partial_rollout: True + +| Training Mode | Resource Allocation | Step | Gen | Old Log Prob | Ref | Update Actor | Total Time 100 Step | Total Time 200 Step | Total Time 300 Step | Total Time 400 Step | Acc/Mean@1 | +|--------------------|---------------------|--------|--------|--------------|-------|--------------|---------------------|---------------------|---------------------|---------------------|-----------------------------| +| Colocate Sync | 128 | 497.89 | 348.05 | 28.73 | 20.86 | 86.27 | 13h 36m | 1d 3h 48m | 1d 19h 4m | 2d 11h 39m | max: 0.3500
last: 0.3208 | +| Fully Async Policy | 96:32 | 282.75 | 22.06 | \ | 50.05 | 206.63 | 6h 45m (2.01x) | 14h 48m (1.88x) | 1d 0h 9m (1.78x) | 1d 10h 41m (1.72x) | max: 0.3813
last: 0.3448 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-30B?nw=nwuserhouzg | | | + +## Multi-Turn Tool Calling + +Referencing **recipe/retool** and **ToolAgentLoop**, we implemented **AsyncPartialToolAgentLoop**, a multi-turn +tool-calling loop that supports partial_rollout for **fully_async_policy**. + +### Core Design + +`AsyncPartialToolAgentLoop` inherits from `ToolAgentLoop` and is adapted for the asynchronous training mode of +`fully_async_policy`. When `partial_rollout=True`, the Rollouter interrupts ongoing generation tasks before +synchronizing parameters with the Trainer. `AsyncPartialToolAgentLoop` is capable of: + +1. **Interrupting Tasks**: Responding to an interrupt signal to save the current state. Currently, interruptions occur + during the `GENERATING` process or after other states have completed. +2. **Resuming Tasks**: Resuming execution from the saved state after parameter synchronization is complete, rather than + starting over. + +### How to Use + +RL training with multi-turn tool calling in `fully_async_policy` is similar to `recipe/retool`. It is enabled by +specifying `multi_turn` configurations in the config file. + +1. **SFT Stage**: First, the model should undergo SFT to learn how to follow tool-calling format instructions. +2. **Multi-turn Configuration**: In the `fully_async_policy` training configuration, set the following parameters: + ```yaml + actor_rollout_ref: + rollout: + multi_turn: + enable: True # AsyncPartialToolAgentLoop will be used by default in fully_async_policy mode + # Other multi_turn related configurations + ``` +3. **Async Parameters**: To improve efficiency, enable `partial_rollout` and `staleness_threshold` when using multi-turn + tool calling: + ```yaml + async_training: + partial_rollout: True + staleness_threshold: 0.5 + # Other async parameters + ``` +4. **Example**: See `recipe/fully_async_policy/shell/dapo_7b_async_retool.sh`. + +### Experimental Results + +To validate the performance of `fully_async_policy` on multi-turn tool-calling tasks, we compared it with the standard +`colocate` synchronous mode. Key parameter settings are as follows. + +* **SFT Model**: Based on `Qwen2.5-7B-Instruct`, trained for 6 epochs on the `ReTool-SFT` dataset +* **RL Algorithm**: DAPO +* **Dataset**: + * Train: `DAPO-Math-17k` + * Test: `aime_2025` +* **Resource and Mode Comparison**: + * `colocate sync`: 32 H20 gpus + * `fully_async_policy`: 16 gpus for Trainer + 16 gpus for Rollouter +* **Key Configurations**: + 1. **Tool Calling Configuration**: + * `multi_turn.enable: True` + * `multi_turn.max_user_turns: 16` + * `multi_turn.max_assistant_turns: 16` + * `multi_turn.tool_config_path: recipe/retool/sandbox_fusion_tool_config.yaml` + 2. **`colocate sync` Configuration**: + * `ppo_mini_batch_size: 16` + * `train_batch_size: 64` + 3. **`fully_async_policy` Configuration**: + * `ppo_mini_batch_size: 16` + * `trigger_parameter_sync_step: 4` + * `require_batches: 1` + * `staleness_threshold: 1` + * `partial_rollout: True` + +| training mode | Resource allocation | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | aime_2025
acc/mean@30 | +|:--------------------:|:---------------------:|:---------:|:---------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:-------------------------------:| +| colocate | 32 | 375.47 | 228.03 | 35.19 | 111.84 | 9h 46m | 22h 28m | start:0.1078
last:0.2056 | +| fully_async_policy | 16: 16 | 221.36 | 40.59 | \ | 179.58 | 6h 19m
(1.55x) | 14h 4m
(1.60x) | start:0.11
last:0.2044 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-multiturn-tool?nw=nwuserhouzg + +## Future Plans + +* GRPO experiments +* Megatron adaptation +* SGLang integration +* Transfer queue integration +* Asynchronous parameter synchronization +* AReaL asynchronous algorithm implementation +* TPPO algorithm implementation +* Multi-turn and Tool support \ No newline at end of file diff --git a/verl/single_controller/base/__init__.py b/src/reasoning360/recipe/fully_async_policy/agent_loop/__init__.py similarity index 61% rename from verl/single_controller/base/__init__.py rename to src/reasoning360/recipe/fully_async_policy/agent_loop/__init__.py index b24bd9942..ef46df0e5 100644 --- a/verl/single_controller/base/__init__.py +++ b/src/reasoning360/recipe/fully_async_policy/agent_loop/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .worker import Worker -from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup +from .agent_loop import FullyAsyncAgentLoopManager +from .partial_single_turn_agent_loop import PartialSingleTurnAgentLoop +from .partial_tool_agent_loop import AsyncPartialToolAgentLoop -__all__ = ["Worker", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"] +_ = [PartialSingleTurnAgentLoop, AsyncPartialToolAgentLoop] +__all__ = [FullyAsyncAgentLoopManager] diff --git a/src/reasoning360/recipe/fully_async_policy/agent_loop/agent_loop.py b/src/reasoning360/recipe/fully_async_policy/agent_loop/agent_loop.py new file mode 100644 index 000000000..d4e258124 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/agent_loop/agent_loop.py @@ -0,0 +1,358 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +import os +from typing import Any, Optional, Sequence + +import hydra +import numpy as np +import ray +from omegaconf import DictConfig + +from reasoning360.recipe.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopManager, + AgentLoopOutput, + AgentLoopWorkerBase, + AsyncLLMServerManager, + _agent_loop_registry, + _DummyConfig, + get_trajectory_info, +) +from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config +from verl.protocol import DataProto +from verl.single_controller.ray import RayWorkerGroup +from verl.utils.rollout_trace import ( + rollout_trace_attr, + rollout_trace_op, +) + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class FullyAsyncLLMServerManager(AsyncLLMServerManager): + @rollout_trace_op + async def generate_for_partial( + self, + request_id, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + image_data: Optional[list[Any]] = None, + ) -> tuple[list[Any], list[Any], Any] | tuple[Sequence[int], list[float], bool]: + """Generate tokens from prompt ids, used for async partial. + + Args: + request_id (str): request id for sticky session. + prompt_ids (List[int]): List of prompt token ids. + sampling_params (Dict[str, Any]): Sampling parameters for the chat completion. + + Returns: + output: A tuple representing the generation output. + - Element 0 (Sequence[int]): Generated response token IDs. + - Element 1 (list[float]): Log probabilities for the response token IDs. + - Element 2 (bool): A flag or status indicating cancellation. + """ + server = self._choose_server(request_id) + output = await server.generate_for_partial.remote( + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + image_data=image_data, + ) + return output + + +@ray.remote +class FullyAsyncAgentLoopWorker(AgentLoopWorkerBase): + def __init__( + self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], reward_router_address: str = None + ): + self.server_manager = FullyAsyncLLMServerManager(config, server_handles) + super().__init__(config, server_handles, reward_router_address) + # A shared cancellation event for all agent loops running on this worker. + self.cancellation_event = asyncio.Event() + + async def generate_sequences_no_post( + self, batch: DataProto, partial_output_list: Optional[list[AgentLoopOutput]] + ) -> tuple[list[AgentLoopOutput], bool] | tuple[DataProto, bool]: + """Generate sequences from agent loop. + + Args: + batch (DataProto): Input batch. + partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result. + + Returns: + list[AgentLoopOutput]: List of agent loop outputs, one per sample in the batch. + """ + config = self.config.actor_rollout_ref.rollout + sampling_params = dict( + temperature=config.temperature, + top_p=config.top_p, + repetition_penalty=1.0, + logprobs=config.calculate_log_probs, + ) + + # override sampling params for validation + if batch.meta_info.get("validate", False): + sampling_params["top_p"] = config.val_kwargs.top_p + sampling_params["temperature"] = config.val_kwargs.temperature + + # by default, we assume it's a single turn agent + if "agent_name" not in batch.non_tensor_batch: + batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object) + + if "index" in batch.non_tensor_batch: + index = batch.non_tensor_batch["index"] + else: + index = np.arange(len(batch)) + + trajectory_info = await get_trajectory_info( + batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False) + ) + + if not partial_output_list: + partial_output_list = [None] * len(batch) + try: + tasks = [] + for i in range(len(batch)): + kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()} + kwargs["output"] = partial_output_list[i] + tasks.append( + asyncio.create_task(self._partial_run_agent_loop(sampling_params, trajectory_info[i], **kwargs)) + ) + outputs = await asyncio.gather(*tasks) + except Exception: + logger.exception("_partial_run_agent_loop failed") + raise + + is_cancel = any(output.extra_fields.get("is_cancel", False) for output in outputs) + if not is_cancel: + output = self._postprocess(outputs) + output = self._addition_process(output) + return output, is_cancel + return outputs, is_cancel + + def _addition_process(self, output: DataProto): + """collect metirics""" + metrics = output.meta_info.pop("metrics") # List[Dict[str, str]] + processing_times_list = [item["generate_sequences"] for item in metrics] + tool_calls_times_list = [item["tool_calls"] for item in metrics] + output.non_tensor_batch["processing_times"] = processing_times_list + output.non_tensor_batch["tool_calls_times"] = tool_calls_times_list + return output + + async def _partial_run_agent_loop( + self, + sampling_params: dict[str, Any], + trajectory: dict[str, Any], + *, + agent_name: str, + **kwargs, + ) -> AgentLoopOutput: + # Completed, return directly + if kwargs["output"] is not None and not kwargs["output"].extra_fields.get("is_cancel", False): + logger.info("In _partial_run_agent_loop, already completed, return derictly!") + return kwargs["output"] + try: + with rollout_trace_attr( + step=trajectory["step"], + sample_index=trajectory["sample_index"], + rollout_n=trajectory["rollout_n"], + validate=trajectory["validate"], + name="agent_loop", + ): + assert agent_name in _agent_loop_registry, ( + f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" + ) + + agent_loop_config = _agent_loop_registry[agent_name] + agent_loop = hydra.utils.instantiate( + config=agent_loop_config, + trainer_config=_DummyConfig(config=self.config), + server_manager=self.server_manager, + tokenizer=self.tokenizer, + processor=self.processor, + ) + output: AgentLoopOutput = await agent_loop.run( + sampling_params, cancellation_event=self.cancellation_event, **kwargs + ) + + # Preserve important per-sample metadata (e.g., for metric aggregation). + # In async pipelines the final DataProto `non_tensor_batch` is reconstructed from + # `_InternalAgentLoopOutput.extra_fields`, so we must carry `data_source` through it. + if "data_source" in kwargs and "data_source" not in output.extra_fields: + output.extra_fields["data_source"] = kwargs["data_source"] + + if not output.extra_fields.get("is_cancel", False): + kwargs.pop("output", None) + output = await self._agent_loop_postprocess(output, **kwargs) + + return output + except Exception: + logger.exception("Agent_loop run failed") + raise + + async def cancel_agent_loops(self): + """Set the shared cancellation event to stop all agent loops.""" + self.cancellation_event.set() + + async def resume_agent_loops(self): + """Clear the shared cancellation event.""" + self.cancellation_event.clear() + + +class FullyAsyncAgentLoopManager(AgentLoopManager): + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None): + self.config = config + self.worker_group = worker_group + self.reward_model_manager = None + self.reward_router_address = None + self.agent_loop_workers_class = FullyAsyncAgentLoopWorker + self.rollout_replica_class = FullyAsyncvLLMReplica + + self.rm_wg = rm_wg + self.rollout_replicas = None + self.server_handles = None + self.server_addresses = None + self.agent_loop_workers = None + + def generate_sequences(self, prompts: DataProto) -> DataProto: + """Override to use cancellation-aware worker method.""" + self.wake_up() + if self.reward_model_manager: + self.reward_model_manager.wake_up() + + chunkes = prompts.chunk(len(self.agent_loop_workers)) + outputs_with_cancel = ray.get( + [ + worker.generate_sequences_no_post.remote(chunk, None) + for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True) + ] + ) + outputs = [out[0] for out in outputs_with_cancel] + output = DataProto.concat(outputs) + + self.sleep() + if self.reward_model_manager: + self.reward_model_manager.sleep() + + # Metrics already in non_tensor_batch from _addition_process, add timing placeholder + output.meta_info = {"timing": {}, **outputs[0].meta_info} + return output + + @classmethod + async def create(cls, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None): + instance = cls(config, worker_group, rm_wg) + await instance._async_init() + return instance + + async def _async_init(self): + if self.config.reward_model.enable and self.config.reward_model.enable_resource_pool: + from verl.experimental.reward import RewardModelManager + + self.reward_model_manager = RewardModelManager(self.config.reward_model, self.rm_wg) + self.reward_router_address = self.reward_model_manager.get_router_address() + + await self._initialize_llm_servers_async() + self._init_agent_loop_workers() + + async def _initialize_llm_servers_async(self): + rollout_world_size = ( + self.config.actor_rollout_ref.rollout.tensor_model_parallel_size + * self.config.actor_rollout_ref.rollout.data_parallel_size + * self.config.actor_rollout_ref.rollout.pipeline_model_parallel_size + ) + world_size = ( + self.worker_group.world_size + if self.worker_group + else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes + ) + num_replicas = world_size // rollout_world_size + + rollout_config = self.config.actor_rollout_ref.rollout + model_config = self.config.actor_rollout_ref.model + self.rollout_replicas = [ + self.rollout_replica_class( + replica_rank=replica_rank, + config=rollout_config, + model_config=model_config, + gpus_per_node=self.config.trainer.n_gpus_per_node, + ) + for replica_rank in range(num_replicas) + ] + + if self.worker_group: + await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas]) + else: + await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas]) + + self.server_handles = [server._server_handle for server in self.rollout_replicas] + self.server_addresses = [server._server_address for server in self.rollout_replicas] + + print(f"AgentLoopManager: {self.server_addresses}") + # Update Prometheus configuration with server addresses + if rollout_config.prometheus.enable: + if rollout_config.disable_log_stats: + raise ValueError("PROMETHEUS needs disable_log_stats==False, but it is currently True.") + await asyncio.to_thread(update_prometheus_config, rollout_config.prometheus, self.server_addresses) + + async def generate_single_sample_async( + self, + sample: DataProto, + partial_output_list: Optional[list[AgentLoopOutput]], + ) -> tuple[list[AgentLoopOutput], bool] | tuple[DataProto, bool]: + """ + Asynchronously process a single sample + + Args: + sample: Single sample data + partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result. + + Returns: + list[AgentLoopOutput]: Processing results + """ + worker = self._select_best_worker() + output_future = worker.generate_sequences_no_post.remote(sample, partial_output_list) + return await asyncio.wrap_future(output_future.future()) + + def _select_best_worker(self): + """Select the best worker, simple round-robin load balancing""" + if not hasattr(self, "_worker_index"): + self._worker_index = 0 + + worker = self.agent_loop_workers[self._worker_index] + self._worker_index = (self._worker_index + 1) % len(self.agent_loop_workers) + return worker + + async def cancel(self): + worker_cancel_tasks = [worker.cancel_agent_loops.remote() for worker in self.agent_loop_workers] + rollout_cancel_tasks = [replica.cancel() for replica in self.rollout_replicas] + await asyncio.gather(*rollout_cancel_tasks, *worker_cancel_tasks) + + async def resume(self): + rollout_resume_tasks = [replica.resume() for replica in self.rollout_replicas] + worker_resume_tasks = [worker.resume_agent_loops.remote() for worker in self.agent_loop_workers] + await asyncio.gather(*rollout_resume_tasks, *worker_resume_tasks) + + async def wake_up(self): + await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas]) + + async def sleep(self): + await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas]) + + async def clear_kv_cache(self): + await asyncio.gather(*[replica.clear_kv_cache() for replica in self.rollout_replicas]) diff --git a/src/reasoning360/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py b/src/reasoning360/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py new file mode 100644 index 000000000..9dee10e9f --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py @@ -0,0 +1,115 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.experimental.agent_loop import AgentLoopBase +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register +from verl.utils.profiler import simple_timer + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@register("partial_single_turn_agent") +class PartialSingleTurnAgentLoop(AgentLoopBase): + """Naive agent loop thareciperecipet only do single turn chat completion.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length + self.response_length = self.config.actor_rollout_ref.rollout.response_length + self.apply_chat_template_kwargs = self.config.data.get("apply_chat_template_kwargs", {}) + + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + output: Optional[AgentLoopOutput] = kwargs.get("output", None) + messages = list(kwargs["raw_prompt"]) + param_version = kwargs.get("param_version", 0) + + metrics = {} + request_id = uuid4().hex + image_data = (kwargs.get("multi_modal_data") or {}).get("image", None) + + param_version_start = param_version + param_version_end = param_version + + if not output: + # TODO(baiyan): it is supposed to use the correct processor, + # but I found the async training would hang if use_correct_processor=True. + # so we use the tokenizer to tokenize the prompt for now. + use_correct_processor = False + if self.processor is not None and use_correct_processor: + + def get_prompt_ids(): + raw_prompt = self.processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=False, + **self.apply_chat_template_kwargs, + ) + model_inputs = self.processor(text=[raw_prompt], images=image_data, return_tensors="pt") + return model_inputs.pop("input_ids").squeeze(0).tolist() + + prompt_ids = await self.loop.run_in_executor(None, get_prompt_ids) + else: + prompt_ids = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs + ), + ) + else: + if output.extra_fields.get("is_cancel", False): + # Resume the paused sample, + # add the result directly after prompt_ids, + # and reset generate_sequences metric + prompt_ids = output.prompt_ids + output.response_ids + metrics["generate_sequences"] = output.metrics.generate_sequences + param_version_start = output.extra_fields.get("param_version_start", param_version) + else: + # In the same batch of samples, + # some are canceled and some are not. + # The samples without partial rollout are returned directly. + return output + with simple_timer("generate_sequences", metrics): + response_ids, response_logprobs, is_cancel = await self.server_manager.generate_for_partial( + request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params, image_data=image_data + ) + if not output: + response_mask = [1] * len(response_ids) + else: + # Pause the sample to be resumed, add the output result to response_ids, and reset response_mask + prompt_ids = output.prompt_ids + response_logprobs = output.response_logprobs + response_logprobs + response_ids = output.response_ids + response_ids + response_mask = [1] * len(response_ids) + if len(response_ids) >= self.response_length: + is_cancel = False + + return AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=response_mask[: self.response_length], + response_logprobs=response_logprobs[: self.response_length], + num_turns=2, + metrics=metrics, + extra_fields={ + "is_cancel": is_cancel, + "param_version_start": param_version_start, + "param_version_end": param_version_end, + }, + # multi_modal_data={"image": image_data} if image_data is not None else {}, + ) diff --git a/src/reasoning360/recipe/fully_async_policy/agent_loop/partial_tool_agent_loop.py b/src/reasoning360/recipe/fully_async_policy/agent_loop/partial_tool_agent_loop.py new file mode 100644 index 000000000..eb1040dfc --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/agent_loop/partial_tool_agent_loop.py @@ -0,0 +1,279 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import copy +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput, register +from verl.experimental.agent_loop.tool_agent_loop import AgentData, AgentState, ToolAgentLoop +from verl.utils.profiler import simple_timer + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@register("async_partial_tool_agent") +class AsyncPartialToolAgentLoop(ToolAgentLoop): + """ + Support for partial rollout with multiple tool invocations in Agent Loop + + """ + + def __init__(self, trainer_config, **kwargs): + super().__init__(trainer_config, **kwargs) + self.enable_partial_rollout = trainer_config.config.async_training.get("partial_rollout", False) + + # async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + async def run( + self, sampling_params: dict[str, Any], *, cancellation_event: asyncio.Event = None, **kwargs + ) -> AgentLoopOutput: + """ + Main entrance, supports interruption/recovery + + Args: + sampling_params: Sampling parameters + cancellation_event: cancellationn sginal + **kwargs: Contains output (for recovery), raw_prompt, param_version, etc. + + Returns: + AgentLoopOutput: Include the is_cancel flag + """ + param_version = kwargs.get("param_version", 0) + agent_data = None + state = None + + # 1. check whether is the partial task + output: Optional[AgentLoopOutput] = kwargs.get("output", None) + if output and output.extra_fields.get("is_cancel", False): + agent_data, state = self._restore_from_output(output) + + logger.info(f"[PartialToolAgent] Resuming from {state.value}") + else: + if output and not output.extra_fields.get("is_cancel", False): + # Completed, return directly + return output + + agent_data = await self._init_agent_data(kwargs, param_version) + state = AgentState.PENDING + logger.info("[PartialToolAgent] Start from scratch") + # 2. run state machine + state = await self._run_state_machine(agent_data, state, sampling_params, cancellation_event) + + # 3. bulid output + if state == AgentState.TERMINATED: + return self._build_completed_output(agent_data, param_version) + else: + # build cancelled output + return self._build_cancelled_output(agent_data, state) + + async def _init_agent_data(self, kwargs: dict, param_version: int) -> AgentData: + messages = list(kwargs["raw_prompt"]) + image_data = copy.deepcopy(kwargs.get("multi_modal_data", {}).get("image", None)) + metrics = {} + request_id = uuid4().hex + tools_kwargs = kwargs.get("tools_kwargs", {}) + + # Initialize interaction if needed + interaction = None + interaction_kwargs = {} + if self.interaction_config_file: + interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"] + if "name" not in interaction_kwargs: + raise ValueError("'name' key is required in interaction_kwargs") + interaction_name = interaction_kwargs["name"] + if interaction_name not in self.interaction_map: + raise ValueError( + f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: " + f"{list(self.interaction_map.keys())}" + ) + interaction = self.interaction_map[interaction_name] + await interaction.start_interaction(request_id, **interaction_kwargs) + # Create AgentData instance to encapsulate all state + agent_data = AgentData( + messages=messages, + image_data=image_data, + metrics=metrics, + request_id=request_id, + tools_kwargs=tools_kwargs, + interaction=interaction, + interaction_kwargs=interaction_kwargs, + ) + + # additional param version record + agent_data.extra_fields["param_version_start"] = param_version + agent_data.extra_fields["param_version_end"] = param_version + + return agent_data + + def _restore_from_output(self, output: AgentLoopOutput) -> tuple[AgentData, AgentState]: + """restore AgentState and AgentData from output""" + agent_data = output.extra_fields.get("agent_data", None) + agent_state = output.extra_fields.get("agent_state", None) + if agent_data is None or agent_state is None: + raise ValueError(f"Unexpected situation: agent_data is {agent_data}, agent_state is {agent_state}") + return agent_data, agent_state + + async def _run_state_machine( + self, + agent_data: AgentData, + state: AgentState, + sampling_params: dict[str, Any], + cancellation_event: asyncio.Event = None, + ) -> AgentState: + """ + State machine. + Currently, interruptions are only supported to occur in the GENERATING state or other states have ended. + """ + # State machine loop + while state != AgentState.TERMINATED: + if cancellation_event and cancellation_event.is_set(): + logger.info(f"[PartialToolAgent] Cancellation detected. Interrupted before/at state: {state.value}") + return state + if state == AgentState.PENDING: + state = await self._handle_pending_state(agent_data, sampling_params) + elif state == AgentState.GENERATING: + state = await self._handle_generating_state_partial(agent_data, sampling_params) + elif state == AgentState.PROCESSING_TOOLS: + state = await self._handle_processing_tools_state(agent_data) + elif state == AgentState.INTERACTING: + state = await self._handle_interacting_state(agent_data) + else: + logger.error(f"[PartialToolAgent] Invalid state: {state}") + return AgentState.TERMINATED + + return AgentState.TERMINATED + + async def _handle_generating_state_partial( + self, agent_data: AgentData, sampling_params: dict[str, Any], ignore_termination: bool = False + ) -> AgentState: + """ + Handle GENERATING state, support partial rollout + """ + add_messages: list[dict[str, Any]] = [] + + with simple_timer("generate_sequences", agent_data.metrics): + # partial interface + if self.enable_partial_rollout: + response_ids, log_probs, is_cancel = await self.server_manager.generate_for_partial( + request_id=agent_data.request_id, + prompt_ids=agent_data.prompt_ids, + sampling_params=sampling_params, + image_data=agent_data.image_data, + ) + + if is_cancel: + # Save the generated parts + agent_data.response_ids = response_ids + agent_data.prompt_ids += agent_data.response_ids + agent_data.response_mask += [1] * len(response_ids) + if log_probs: + agent_data.response_logprobs += log_probs + if not ignore_termination and len(agent_data.response_mask) >= self.response_length: + # If response_length has reached the limit, + # it is considered to have ended normally. + agent_data.assistant_turns += 1 + return AgentState.TERMINATED + return AgentState.GENERATING + else: + # original generate interface + output = await self.server_manager.generate( + request_id=agent_data.request_id, + prompt_ids=agent_data.prompt_ids, + sampling_params=sampling_params, + image_data=agent_data.image_data, + ) + response_ids = output.token_ids + log_probs = output.log_probs + + agent_data.assistant_turns += 1 + agent_data.response_ids = response_ids + agent_data.prompt_ids += agent_data.response_ids + agent_data.response_mask += [1] * len(agent_data.response_ids) + if log_probs: + agent_data.response_logprobs += log_probs + + if not ignore_termination and len(agent_data.response_mask) >= self.response_length: + return AgentState.TERMINATED + if self.max_assistant_turns and agent_data.assistant_turns >= self.max_assistant_turns: + return AgentState.TERMINATED + if self.max_user_turns and agent_data.user_turns >= self.max_user_turns: + return AgentState.TERMINATED + + # Extract tool calls + _, agent_data.tool_calls = await self.tool_parser.extract_tool_calls(agent_data.response_ids) + + # Handle interaction if needed + if self.interaction_config_file: + assistant_message = await self.loop.run_in_executor( + None, lambda: self.tokenizer.decode(agent_data.response_ids, skip_special_tokens=True) + ) + add_messages.append({"role": "assistant", "content": assistant_message}) + agent_data.messages.extend(add_messages) + + # Determine next state + if agent_data.tool_calls: + return AgentState.PROCESSING_TOOLS + elif self.interaction_config_file: + return AgentState.INTERACTING + else: + return AgentState.TERMINATED + + def _build_completed_output(self, agent_data: AgentData, param_version: int) -> AgentLoopOutput: + """build completed output""" + response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :] + prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)] + multi_modal_data = {"image": agent_data.image_data} if agent_data.image_data is not None else {} + output = AgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=agent_data.response_mask[: self.response_length], + multi_modal_data=multi_modal_data, + response_logprobs=agent_data.response_logprobs[: self.response_length] + if agent_data.response_logprobs + else None, + num_turns=agent_data.user_turns + agent_data.assistant_turns + 1, + metrics=agent_data.metrics, + extra_fields={}, + ) + output.extra_fields.update( + { + "turn_scores": agent_data.turn_scores, + "tool_rewards": agent_data.tool_rewards, + "is_cancel": False, + "param_version_start": agent_data.extra_fields["param_version_start"], + "param_version_end": param_version, + } + ) + return output + + def _build_cancelled_output(self, agent_data: AgentData, state: AgentState) -> AgentLoopOutput: + """build cancelled output""" + return AgentLoopOutput( + prompt_ids=[], + response_ids=[], + response_mask=[], + multi_modal_data={}, + response_logprobs=None, + num_turns=0, + metrics=agent_data.metrics, + extra_fields={ + "is_cancel": True, + "agent_data": agent_data, + "agent_state": state, + }, + ) diff --git a/src/reasoning360/recipe/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml b/src/reasoning360/recipe/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml new file mode 100644 index 000000000..c88819864 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml @@ -0,0 +1,68 @@ +hydra: + searchpath: + - file://src/reasoning360/trainer/config + +defaults: + - ppo_megatron_trainer + - _self_ + +async_training: + + # Maximum samples staleness threshold + staleness_threshold: 0.1 + + # Frequency of parameter synchronization between rollouter and trainer, + # One step means trainer obtains a batch of required samples + trigger_parameter_sync_step: 4 + + # The number of ppo_mini_batches that the FullyAsyncTrainer obtains once + require_batches: 1 + + # When synchronizing parameters, whether to interrupt rollouter and perform partial rollout + partial_rollout: True + + # Whether to use rollout log probs for training + use_rollout_log_probs: True + + # compute_prox_log_prob + compute_prox_log_prob: False + +# Rollout config +rollout: + + # Number of nodes used in the rollout + nnodes: 1 + + # Number of GPUs per node + n_gpus_per_node: 8 + + # number of responses (i.e. num sample times). > 1 for grpo + n: 4 + + # total rollout samples # TODO rename to total_rollout_samples + total_rollout_steps: 100 + + # Number of epochs in training + total_epochs: 10 + + # Test frequency, how many times a parameter update triggers a validation + test_freq: 1 + +data: + # Number of samples generated, currently only support 1 + gen_batch_size: 1 + +actor_rollout_ref: + actor: + # Whether to use rollout log probs for training + use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True} + rollout: + agent: + # Fully-async recipe uses partial rollouts by default; ensure the default + # agent loop name matches a registered agent loop implementation. + default_agent_loop: partial_single_turn_agent + +# Use Reasoning360 reward scoring inside verl's RewardLoopWorker/DAPORewardLoopManager +custom_reward_function: + path: src/reasoning360/utils/reward_score/__init__.py + name: default_compute_score diff --git a/src/reasoning360/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml b/src/reasoning360/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml new file mode 100644 index 000000000..fa9355d7e --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml @@ -0,0 +1,68 @@ +hydra: + searchpath: + - file://src/reasoning360/trainer/config + +defaults: + - ppo_trainer + - _self_ + +async_training: + + # Maximum samples staleness threshold + staleness_threshold: 0.1 + + # Frequency of parameter synchronization between rollouter and trainer, + # One step means trainer obtains a batch of required samples + trigger_parameter_sync_step: 4 + + # The number of ppo_mini_batches that the FullyAsyncTrainer obtains once + require_batches: 1 + + # When synchronizing parameters, whether to interrupt rollouter and perform partial rollout + partial_rollout: True + + # Whether to use rollout log probs for training + use_rollout_log_probs: True + + # compute_prox_log_prob + compute_prox_log_prob: False + +# Rollout config +rollout: + + # Number of nodes used in the rollout + nnodes: 1 + + # Number of GPUs per node + n_gpus_per_node: 8 + + # number of responses (i.e. num sample times). > 1 for grpo + n: 4 + + # total rollout samples # TODO rename to total_rollout_samples + total_rollout_steps: 100 + + # Number of epochs in training + total_epochs: 10 + + # Test frequency, how many times a parameter update triggers a validation + test_freq: 1 + +data: + # Number of samples generated, currently only support 1 + gen_batch_size: 1 + +actor_rollout_ref: + actor: + # Whether to use rollout log probs for training + use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True} + rollout: + agent: + # Fully-async recipe uses partial rollouts by default; ensure the default + # agent loop name matches a registered agent loop implementation. + default_agent_loop: partial_single_turn_agent + +# Use Reasoning360 reward scoring inside verl's RewardLoopWorker/DAPORewardLoopManager +custom_reward_function: + path: src/reasoning360/utils/reward_score/__init__.py + name: default_compute_score diff --git a/src/reasoning360/recipe/fully_async_policy/detach_utils.py b/src/reasoning360/recipe/fully_async_policy/detach_utils.py new file mode 100644 index 000000000..6ca363771 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/detach_utils.py @@ -0,0 +1,359 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +import torch + +from verl import DataProto +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput +from reasoning360.trainer.ppo.ray_trainer import compute_response_mask + + +@dataclass +class RolloutSample: + """Enhanced rollout sample containing both original batch info and AgentLoopOutput""" + + # Original batch information + full_batch: Any + + # AgentLoopOutput from generation + agent_loop_output_list: list[AgentLoopOutput] + + # Metadata + sample_id: str + epoch: int + + # Processing metadata + processing_times: list[float] + tool_calls: list[float] + param_version: int + param_version_start: list[int] + param_version_end: list[int] + rollout_status: dict[str, Any] + + +@dataclass +class ValidateMetrics: + """Metrics for validation""" + + timing_raw: dict[str, Any] + metrics: Optional[dict[str, Any]] = None + global_steps: Optional[int] = None + param_version: Optional[int] = None + + +def prepare_single_generation_data(batch_dict, config) -> DataProto: + """ + Similar to the logic of ray_trainer._prepare_generate_batch, but for a single sample. + Separate the data used for generation from the original data. + + Returns: + tuple: (original_batch_dict, gen_data_for_single_sample) + """ + + full_batch = DataProto.from_single_dict(batch_dict) + + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + + full_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + # Setting selected agent, that supports partial + if config.actor_rollout_ref.rollout.multi_turn.enable: + full_batch.non_tensor_batch["agent_name"] = np.array( + ["async_partial_tool_agent"] * len(full_batch), dtype=object + ) + else: + full_batch.non_tensor_batch["agent_name"] = np.array( + ["partial_single_turn_agent"] * len(full_batch), dtype=object + ) + + # Add global step count to generated data + full_batch = full_batch.repeat(repeat_times=config.actor_rollout_ref.rollout.n, interleave=True) + return full_batch + + +def assemble_batch_from_rollout_samples( + rollout_samples: list[RolloutSample], tokenizer, config, balance_batch=None +) -> DataProto: + """ + Assemble gen_batch_output from RolloutSample objects + Assembles batches from RolloutSample objects, similar to the _post_generate_batch logic in ray_trainer. + + Args: + rollout_samples: List of RolloutSample objects + tokenizer: Tokenizer instance + config: Configuration object containing trainer settings + balance_batch: Whether to balance the batch (simplified version) + + Returns: + DataProto: Assembled gen_batch_output + + Raises: + ValueError: If rollout_samples is empty + """ + start_time = time.time() + + if not rollout_samples: + raise ValueError("Empty rollout_samples provided for batch assembly") + + print(f"[BatchUtils] Assembling batch from {len(rollout_samples)} RolloutSample objects") + + rollout_samples_batch = [] + processing_times = [] + tool_calls = [] + rollout_status = rollout_samples[0].rollout_status + # Add a prefix to all rollout_status keys + rollout_status = {f"fully_async/{key}": value for key, value in rollout_status.items()} + + for rs in rollout_samples: + rollout_samples_batch.append(rs.full_batch) + final_batch = DataProto.concat(rollout_samples_batch) + + # Calculate response_mask (if not present) + if "response_mask" not in final_batch.batch.keys(): + final_batch.batch["response_mask"] = compute_response_mask(final_batch) + + if balance_batch: + balance_batch(final_batch, metrics={}) + + # Calculate the global valid token number + if "attention_mask" in final_batch.batch: + final_batch.meta_info["global_token_num"] = torch.sum(final_batch.batch["attention_mask"], dim=-1).tolist() + + processing_times = final_batch.non_tensor_batch["processing_times"] + tool_calls = final_batch.non_tensor_batch["tool_calls_times"] + # Collect statistics + + processing_time_stats = { + "processing_time/avg": np.mean(processing_times), + "processing_time/max": np.max(processing_times), + "processing_time/min": np.min(processing_times), + "processing_time/tp50": np.percentile(processing_times, 50), + "processing_time/tp99": np.percentile(processing_times, 99), + "processing_time/tp95": np.percentile(processing_times, 95), + } + tool_calls_stats = {} + if len(tool_calls) > 0: + tool_calls_stats = { + "timing_s/agent_loop/tool_calls/max": np.max(tool_calls), + "timing_s/agent_loop/tool_calls/min": np.min(tool_calls), + "timing_s/agent_loop/tool_calls/mean": np.mean(tool_calls), + } + processing_time_stats = {f"fully_async/{key}": value for key, value in processing_time_stats.items()} + + param_version_start = final_batch.non_tensor_batch["param_version_start"] + param_version_end = final_batch.non_tensor_batch["param_version_end"] + param_version_diff = [abs(a - b) for a, b in zip(param_version_end, param_version_start, strict=False)] + num_diff0 = param_version_diff.count(0) + partial_stats = { + "fully_async/partial/total_partial_num": len(param_version_diff) - num_diff0, + "fully_async/partial/partial_ratio": (len(param_version_diff) - num_diff0) / len(param_version_diff), + "fully_async/partial/max_partial_span": max(param_version_diff), + } + # add meta_info + param_versions = [rs.param_version for rs in rollout_samples] + trajectorys_param_versions = final_batch.non_tensor_batch["param_version_end"] + + final_batch.meta_info.update( + { + "rollout_param_versions": param_versions, + "param_version_diversity": len(set(param_versions)) if param_versions else 0, + "trajectory_param_versions": trajectorys_param_versions, + **processing_time_stats, + **rollout_status, + **partial_stats, + **tool_calls_stats, + } + ) + + print(f"[BatchUtils] Batch assembly completed in {time.time() - start_time:.2f}s") + + return final_batch + + +class MetricsAggregator: + """Metrics aggregator, used to combine metrics from multiple training steps""" + + def __init__(self, total_gpus: int): + # Store all values ​​for each metric + self.metric_values: dict[str, list[float]] = defaultdict(list) + # Store the number of samples at each step for weighted averaging + self.sample_counts: list[int] = [] + # Store the timestamp of each step for time-related calculations + self.timestamps: list[float] = [] + # Step Count + self.step_count = 0 + # total num gpus used + self.total_gpus = total_gpus + + # Metric aggregation rule configuration + self.aggregation_rules = self._init_aggregation_rules() + + def _init_aggregation_rules(self) -> dict[str, dict[str, list[str]]]: + """Initialize metrics aggregation rules""" + return { + # Time-Based metrics, can add metrics here + "time_sum": ["perf/time_per_step"], + "min": ["timing_s/agent_loop/tool_calls/min"], + "avg": ["timing_s/agent_loop/tool_calls/mean"], + "max": ["timing_s/agent_loop/tool_calls/max"], + "last": [ + "fully_async/count/total_generated_samples", + "fully_async/count/stale_samples_processed", + "fully_async/count/stale_trajectory_processed", + "fully_async/count/current_param_version", + "fully_async/count/dropped_stale_samples", + "training/global_step", # TODO change name to: total_step + ], + } + + def add_step_metrics(self, metrics: dict[str, Any], sample_count: int, timestamp: float = None): + """Adding a single-step metrics""" + if timestamp is None: + timestamp = time.time() + + self.sample_counts.append(sample_count) + self.timestamps.append(timestamp) + self.step_count += 1 + + # Store all metrics values + for key, value in metrics.items(): + if isinstance(value, int | float | np.number): + self.metric_values[key].append(float(value)) + elif isinstance(value, torch.Tensor): + self.metric_values[key].append(float(value.item())) + + def _get_aggregation_type(self, metric_name: str) -> str: + """Determine the aggregation type based on the metric name""" + for agg_type, metric_list in self.aggregation_rules.items(): + if metric_name in metric_list: + return agg_type + + metric_lower = metric_name.lower() + if any(keyword in metric_lower for keyword in ["timing_s/"]): + return "time_sum" + if any(keyword in metric_lower for keyword in ["mean", "avg", "average"]): + return "avg" + if any(keyword in metric_lower for keyword in ["max", "maximum"]): + return "max" + if any(keyword in metric_lower for keyword in ["min", "minimum"]): + return "min" + if any(keyword in metric_lower for keyword in ["sum", "total"]): + return "sum" + if any(keyword in metric_lower for keyword in ["weighted_avg"]): + return "weighted_avg" + + return "avg" + + def _aggregate_single_metric(self, metric_name: str, values: list[float]) -> float: + """Aggregating a single metric""" + if not values: + return 0.0 + + agg_type = self._get_aggregation_type(metric_name) + + if agg_type == "last": + return values[-1] + + elif agg_type == "weighted_avg": + # Weighted average + if len(values) != len(self.sample_counts): + # If the lengths do not match, use a simple average + return sum(values) / len(values) + + total_samples = sum(self.sample_counts) + if total_samples == 0: + return sum(values) / len(values) + + weighted_sum = sum(v * c for v, c in zip(values, self.sample_counts, strict=False)) + return weighted_sum / total_samples + + elif agg_type == "sum" or agg_type == "time_sum": + return sum(values) + + elif agg_type == "avg": + return sum(values) / len(values) + + elif agg_type == "max": + return max(values) + + elif agg_type == "min": + return min(values) + + else: + # Default average + return sum(values) / len(values) + + def get_aggregated_metrics(self) -> dict[str, Any]: + """aggregated metrics""" + t = time.time() + if self.step_count == 0: + return {} + + aggregated = {} + + # Aggregate all metrics + for metric_name, values in self.metric_values.items(): + aggregated[metric_name] = self._aggregate_single_metric(metric_name, values) + + # Aggregate special metrics + aggregated = self._special_metrics_aggergate(aggregated) + + print(f"aggregated metrics done. cost {time.time() - t}") + + return aggregated + + def _special_metrics_aggergate(self, aggregated: dict[str, Any]) -> dict[str, Any]: + """calculate special metrics""" + + # global_seqlen/minmax_diff + if "global_seqlen/minmax_diff" in aggregated.keys(): + aggregated["global_seqlen/minmax_diff"] = aggregated["global_seqlen/max"] - aggregated["global_seqlen/min"] + + # perf/throughput + REQUIRED_PERF_KEYS = {"perf/throughput", "perf/total_num_tokens", "perf/time_per_step"} + if REQUIRED_PERF_KEYS.issubset(aggregated): + aggregated["perf/throughput"] = aggregated["perf/total_num_tokens"] / ( + aggregated["perf/time_per_step"] * self.total_gpus + ) + + # trainer/idle_ratio + if "timing_s/gen" in aggregated.keys() and "timing_s/step" in aggregated.keys(): + aggregated["trainer/idle_ratio"] = aggregated["timing_s/gen"] / aggregated["timing_s/step"] + + return aggregated + + def reset(self): + """Reset Aggregator""" + self.metric_values.clear() + self.sample_counts.clear() + self.timestamps.clear() + self.step_count = 0 + + def get_current_stats(self) -> dict[str, Any]: + """Get statistics about the current aggregation state (for debugging)""" + return { + "step_count": self.step_count, + "metric_count": len(self.metric_values), + "total_samples": sum(self.sample_counts), + "metric_names": list(self.metric_values.keys()), + } diff --git a/src/reasoning360/recipe/fully_async_policy/fsdp2_utils.py b/src/reasoning360/recipe/fully_async_policy/fsdp2_utils.py new file mode 100644 index 000000000..1f1856596 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/fsdp2_utils.py @@ -0,0 +1,125 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.distributed as dist +from packaging import version +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec + +if version.parse(torch.__version__) < version.parse("2.6"): + raise RuntimeError("PyTorch 2.6 or higher is required to use fstp_utils.") + + +def fsdp2_sharded_save_to_cpu( + model: torch.nn.Module, +) -> tuple[dict[str, tuple[torch.Tensor, DTensorSpec]], DTensorSpec]: + """ + Sharded Save: Each process only saves the local DTensor shard from its own GPU to CPU memory. + + Args: + model: FSDP2-wrapped model whose parameters are of DTensor type. + + Returns: + cpu_sharded_state: Dictionary of CPU shards for the current process. + Key = parameter name, Value = (CPU shard tensor, original DTensorSpec) + global_spec: DTensorSpec of the first parameter (used to verify global rules during loading) + """ + cpu_sharded_state = {} + global_spec = None # Record global sharding rules (all parameters follow the same spec) + + for param_name, param in model.named_parameters(): + # Only process sharded parameters of DTensor type (core parameters of FSDP2) + if not isinstance(param, DTensor): + # Save non-sharded parameters (e.g., running_mean of BatchNorm) as local data + cpu_tensor = param.detach().cpu() + cpu_sharded_state[param_name] = (cpu_tensor, None) + continue + + # Record global sharding rules (take spec of the first DTensor to ensure consistency) + if global_spec is None: + global_spec = param._spec + assert hasattr(global_spec, "device_mesh"), "DTensorSpec must contain 'device_mesh' attribute" + assert hasattr(global_spec, "placements"), "DTensorSpec must contain 'placements' attribute" + + # 1. Extract local shard data from the current GPU (_local_tensor) + local_gpu_tensor = param._local_tensor # Local shard attribute defined in your DTensor class + # 2. Move to CPU memory and detach from computation graph + local_cpu_tensor = local_gpu_tensor.detach().cpu() + # 3. Save CPU shard + original DTensorSpec (ensure sharding rules remain unchanged) + cpu_sharded_state[param_name] = (local_cpu_tensor, param._spec) + + assert global_spec is not None, "No DTensor-type parameters found in the model. FSDP2 sharding may not be enabled." + return cpu_sharded_state, global_spec + + +def fsdp2_sharded_load_from_cpu( + model: torch.nn.Module, + cpu_sharded_state: dict[str, tuple[torch.Tensor, Optional[DTensorSpec]]], + target_spec: DTensorSpec, +) -> None: + """ + Sharded Load: Each process only loads the CPU shard it is responsible for to the GPU, + keeping sharding rules unchanged. + + Args: + model: FSDP2 model to be restored (must have the same structure as when saved) + cpu_sharded_state: Shard data read from CPU memory by the current process + (from fsdp2_sharded_save_to_cpu) + target_spec: Global DTensorSpec from saving (used to verify sharding rule consistency) + """ + # Verify device_mesh consistency (core: ensure loaded shards map to original GPUs) + current_device_mesh = None + for param in model.parameters(): + if isinstance(param, DTensor): + current_device_mesh = param._spec.device_mesh + break + assert current_device_mesh is not None, "DTensor parameters not initialized in the model to be loaded" + assert current_device_mesh == target_spec.device_mesh, ( + f"device_mesh mismatch during loading! Original: {target_spec.device_mesh}, Current: {current_device_mesh}" + ) + + for param_name, param in model.named_parameters(): + # Skip parameters not in the saved state (e.g., newly added parameters) + if param_name not in cpu_sharded_state: + continue + + # Extract CPU shard data and original Spec + local_cpu_tensor, saved_spec = cpu_sharded_state[param_name] + + # Handle different parameter types: DTensor sharded parameters vs. regular parameters + if isinstance(param, DTensor): + # 1. Verify sharding rule consistency (placements must match original Spec) + assert saved_spec is not None, f"DTensorSpec missing in saved state for parameter {param_name}" + assert saved_spec.placements == target_spec.placements, ( + f"Sharding strategy mismatch for parameter {param_name} (conflicts with global rules)!" + ) + + # 2. Move CPU shard data to the current GPU (device of param._local_tensor) + target_device = param._local_tensor.device + local_gpu_tensor = local_cpu_tensor.to(target_device) + + # 3. Restore to DTensor's local shard (directly copy to _local_tensor, keep spec unchanged) + param._local_tensor.copy_(local_gpu_tensor) + + else: + # Regular parameters: load directly to original device + target_device = param.device + param.data.copy_(local_cpu_tensor.to(target_device)) + + # Process synchronization: ensure all processes complete loading before proceeding + dist.barrier() diff --git a/src/reasoning360/recipe/fully_async_policy/fsdp_workers.py b/src/reasoning360/recipe/fully_async_policy/fsdp_workers.py new file mode 100644 index 000000000..db746dc97 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/fsdp_workers.py @@ -0,0 +1,161 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch +import torch.distributed +from omegaconf import DictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from .fsdp2_utils import fsdp2_sharded_load_from_cpu, fsdp2_sharded_save_to_cpu +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils.device import ( + get_device_name, + get_torch_device, +) +from verl.utils.fsdp_utils import ( + fsdp_version, + load_fsdp_model_to_gpu, + offload_fsdp_model_to_cpu, +) +from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + +__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"] + + +def get_inference_model(rollout): + """ + get models according to different types of inference_engine + Args: + rollout: rollout object + Returns: + model: model object + """ + inference_engine = rollout.inference_engine + if hasattr(inference_engine, "llm_engine"): + inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + elif hasattr(inference_engine, "worker"): + inference_model = inference_engine.worker.model_runner.model + else: + raise AttributeError( + f"Unsupported inference_engine type: {type(inference_engine)}. " + f"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute)." + ) + return inference_model + + +class DetachNcclSync(AsyncActorRolloutRefWorker): + def _get_actor_params(self): + pass + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def sync_rollout_weights(self): + assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine + assert hasattr(self, "_weights_info") and self._weights_info is not None + + if self._is_actor and self._is_offload_param: + load_fsdp_model_to_gpu(self.actor_module_fsdp) + params = self._get_actor_params() if self._is_actor else None + if self._is_rollout: + inference_model = get_inference_model(self.rollout) + + from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader + + patch_vllm_moe_model_weight_loader(inference_model) + for key, shape, dtype in self._weights_info: + tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) + if self._is_actor: + assert key in params + origin_data = params[key] + if hasattr(origin_data, "full_tensor"): + origin_data = origin_data.full_tensor() + if torch.distributed.get_rank() == 0: + tensor.copy_(origin_data) + from ray.util.collective import collective + + collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") + if self._is_rollout: + inference_model.load_weights([(key, tensor)]) + + if self._is_actor and self._is_offload_param: + offload_fsdp_model_to_cpu(self.actor_module_fsdp) + get_torch_device().empty_cache() + + +class DetachActorWorker(DetachNcclSync): + def _get_actor_params(self): + assert self._is_actor + params = self.actor_module_fsdp.state_dict() + from verl.utils.model import convert_weight_keys + + params = convert_weight_keys( + params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + return params + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_actor_weights_info(self): + assert self._is_actor + if hasattr(self, "_weights_info"): + return self._weights_info + if fsdp_version(self.actor_module_fsdp) == 1: + from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType + + FSDP.set_state_dict_type( + self.actor_module_fsdp, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), + ) + params = self._get_actor_params() + ret = [] + for key, tensor in params.items(): + ret.append((key, tensor.size(), tensor.dtype)) + self._weights_info = ret + return ret + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_model_to_cpu(self, n): + if not hasattr(self, "cpu_saved_models"): + self.cpu_saved_models = {} + self.cpu_saved_models[n] = fsdp2_sharded_save_to_cpu(self.actor_module_fsdp) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def restore_model_from_cpu(self, n): + if n in self.cpu_saved_models: + cpu_sharded_state, global_spec = self.cpu_saved_models[n] + fsdp2_sharded_load_from_cpu(self.actor_module_fsdp, cpu_sharded_state, global_spec) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def clear_cpu_model(self, n): + if n in self.cpu_saved_models: + del self.cpu_saved_models[n] + + +class DetachAsyncRolloutWorker(DetachNcclSync): + def __init__(self, config: DictConfig, role: str): + print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") + ActorRolloutRefWorker.__init__(self, config, role) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_actor_weights_info(self, weights_info): + assert self._is_rollout + self._weights_info = weights_info diff --git a/src/reasoning360/recipe/fully_async_policy/fully_async_main.py b/src/reasoning360/recipe/fully_async_policy/fully_async_main.py new file mode 100644 index 000000000..35be8dbc5 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/fully_async_main.py @@ -0,0 +1,301 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import socket +import threading +from pprint import pprint + +import hydra +import ray +from omegaconf import OmegaConf + +from .fully_async_rollouter import FullyAsyncRollouter +from .fully_async_trainer import FullyAsyncTrainer +from .message_queue import MessageQueue, MessageQueueClient +from reasoning360.trainer.ppo.ray_trainer import ResourcePoolManager +from verl.trainer.ppo.utils import Role +from verl.utils.fs import copy_to_local + + +def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager: + """ + Create resource pool manager + + Args: + config: Configuration object + roles: List of roles that need to create resource pools + + Returns: + ResourcePoolManager: Resource pool manager + """ + resource_pool_spec = {} + mapping = {} + + # Actor/Critic resource pool + if any(role in roles for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]): + assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0" + assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0" + + trainer_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes + resource_pool_spec["trainer_pool"] = trainer_pool + + # Map training-related roles to the same resource pool + for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]: + if role in roles: + mapping[role] = "trainer_pool" + + # Rollout resource pool + if Role.Rollout in roles: + assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0" + assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0" + + rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes + resource_pool_spec["rollout_pool"] = rollout_pool + mapping[Role.Rollout] = "rollout_pool" + + return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + +def create_role_worker_mapping(config): + """ + Create mapping from roles to worker classes + + Args: + config: Configuration object + + Returns: + dict: Mapping from roles to worker classes + """ + # Select worker class based on strategy + if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]: + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from .fsdp_workers import ( + CriticWorker, + DetachActorWorker, + DetachAsyncRolloutWorker, + ) + from verl.single_controller.ray import RayWorkerGroup + + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == "megatron": + assert config.critic.strategy == "megatron" + from .megatron_worker import CriticWorker, DetachActorWorker, DetachAsyncRolloutWorker + from verl.single_controller.ray import RayWorkerGroup + + ray_worker_group_cls = RayWorkerGroup + else: + raise NotImplementedError(f"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}") + + role_worker_mapping = { + Role.Actor: ray.remote(DetachActorWorker), + Role.Rollout: ray.remote(DetachAsyncRolloutWorker), + Role.Critic: ray.remote(CriticWorker), + } + + if config.reward_model.enable: + if config.reward_model.strategy in ["fsdp", "fsdp2"]: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + + # Add reference policy (if KL loss or reward is required) + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker) + + return role_worker_mapping, ray_worker_group_cls + + +@ray.remote(num_cpus=1) +class FullyAsyncTaskRunner: + """ + Ray remote class for executing distributed PPO training tasks. + """ + + def __init__(self): + self.running = False + self.components = {} + self.shutdown_event = threading.Event() + + def run(self, config): + print("[ASYNC MAIN] Starting fully async PPO training...") + self._initialize_components(config) + self._run_training_loop() + + def _initialize_components(self, config) -> None: + print(f"[ASYNC MAIN] TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + print("[ASYNC MAIN] Initializing model and tokenizer...") + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + self.components["tokenizer"] = tokenizer + self.components["processor"] = processor + self.components["config"] = config + + print("[ASYNC MAIN] Creating worker mapping and resource pools...") + role_worker_mapping, ray_worker_group_cls = create_role_worker_mapping(config) + self.components["role_worker_mapping"] = role_worker_mapping + self.components["ray_worker_group_cls"] = ray_worker_group_cls + + print("[ASYNC MAIN] Creating FullyAsyncRollouter...") + self._create_rollouter(config) + + print("[ASYNC MAIN] Creating FullyAsyncTrainer...") + self._create_trainer(config) + + # sync total_train_steps between rollouter and trainer + total_train_steps = ray.get(self.components["rollouter"].get_total_train_steps.remote()) + print(f"total_train_steps {total_train_steps}") + ray.get(self.components["trainer"].set_total_train_steps.remote(total_train_steps)) + + # max_queue_size + max_queue_size = ray.get(self.components["rollouter"].get_max_queue_size.remote()) + print(f"[ASYNC MAIN] Creating MessageQueue... max_queue_size {max_queue_size}") + message_queue = MessageQueue.remote(config, max_queue_size) + message_queue_client = MessageQueueClient(message_queue) + self.components["message_queue"] = message_queue + self.components["message_queue_client"] = message_queue_client + + ray.get(self.components["rollouter"].set_message_queue_client.remote(self.components["message_queue_client"])) + ray.get(self.components["trainer"].set_message_queue_client.remote(self.components["message_queue_client"])) + + print("[ASYNC MAIN] Setting up parameter synchronization...") + from .param_sync import ParameterSynchronizer + + param_synchronizer = ParameterSynchronizer.remote( + config=config, + trainer=self.components["trainer"], + rollouter=self.components["rollouter"], + mq=self.components["message_queue_client"], + ) + ray.get(self.components["trainer"].set_parameter_synchronizer.remote(param_synchronizer)) + + # load checkpoint and sync parameter before doing anything + val_before_train = config.trainer.get("val_before_train", True) + # param_version resume from ckpt or default 0 + param_version = ray.get(self.components["trainer"].load_checkpoint.remote()) + ray.get(self.components["rollouter"].load_checkpoint.remote()) + ray.get(param_synchronizer.sync_weights.remote(version=param_version, validate=val_before_train)) + ray.get(param_synchronizer.wait_last_valid.remote()) + + self.components["param_synchronizer"] = param_synchronizer + print("[ASYNC MAIN] All components initialized successfully") + + def _create_rollouter(self, config) -> None: + rollouter = FullyAsyncRollouter.remote( + config=config, + tokenizer=self.components["tokenizer"], + role_worker_mapping={Role.Rollout: self.components["role_worker_mapping"][Role.Rollout]}, + resource_pool_manager=create_resource_pool_manager(config, roles=[Role.Rollout]), + ray_worker_group_cls=self.components["ray_worker_group_cls"], + processor=self.components["processor"], + device_name=config.trainer.device, + ) + + ray.get(rollouter.init_workers.remote()) + ray.get(rollouter.set_max_required_samples.remote()) + + self.components["rollouter"] = rollouter + print("[ASYNC MAIN] Rollouter created and initialized successfully") + + def _create_trainer(self, config) -> None: + trainer_role_mapping = { + role: worker_cls + for role, worker_cls in self.components["role_worker_mapping"].items() + if role != Role.Rollout + } + + trainer = FullyAsyncTrainer.remote( + config=config, + tokenizer=self.components["tokenizer"], + role_worker_mapping=trainer_role_mapping, + resource_pool_manager=create_resource_pool_manager(config, roles=list(trainer_role_mapping.keys())), + ray_worker_group_cls=self.components["ray_worker_group_cls"], + processor=self.components["processor"], + device_name=config.trainer.device, + ) + + ray.get(trainer.init_workers.remote()) + self.components["trainer"] = trainer + print("[ASYNC MAIN] FullyAsyncTrainer created and initialized successfully") + + def _run_training_loop(self): + self.running = True + + print("[ASYNC MAIN] Starting Rollouter and Trainer...") + rollouter_future = self.components["rollouter"].fit.remote() + trainer_future = self.components["trainer"].fit.remote() + + futures = [rollouter_future, trainer_future] + + try: + while futures: + # Use ray.wait to monitor all futures and return when any one is completed. + done_futures, remaining_futures = ray.wait(futures, num_returns=1, timeout=None) + + for future in done_futures: + try: + ray.get(future) + print("[ASYNC MAIN] One component completed successfully") + except Exception as e: + print(f"[ASYNC MAIN] Component failed with error: {e}") + for remaining_future in remaining_futures: + ray.cancel(remaining_future) + raise e + + futures = remaining_futures + + except Exception as e: + print(f"[ASYNC MAIN] Training failed: {e}") + for future in futures: + ray.cancel(future) + raise + finally: + asyncio.run(self.components["message_queue_client"].clear_queue()) + print("[ASYNC MAIN] Training completed or interrupted") + + +@hydra.main(config_path="config", config_name="fully_async_ppo_trainer", version_base=None) +def main(config): + from reasoning360.trainer.main_ppo import run_ppo + + # Ensure async training config exists + if not hasattr(config, "async_training"): + raise RuntimeError("must set async_training config") + from time import time + + start_time = time() + run_ppo(config, task_runner_class=FullyAsyncTaskRunner) + print(f"total time: {time() - start_time:.2f} seconds") + + +if __name__ == "__main__": + main() diff --git a/src/reasoning360/recipe/fully_async_policy/fully_async_rollouter.py b/src/reasoning360/recipe/fully_async_policy/fully_async_rollouter.py new file mode 100644 index 000000000..f9a06e81f --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/fully_async_rollouter.py @@ -0,0 +1,717 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import os +import time +from pprint import pformat + +import numpy as np +import ray +import torch +from ray import ObjectRef + +from .detach_utils import ( + RolloutSample, + ValidateMetrics, + prepare_single_generation_data, +) +from .message_queue import MessageQueueClient +from .ray_trainer import FullyAsyncRayPPOTrainer +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from reasoning360.trainer.ppo.ray_trainer import ResourcePoolManager +from reasoning360.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import Role, WorkerType +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path +from verl.utils.profiler import marked_timer +from verl.utils.tracking import ValidationGenerationsLogger + + +@ray.remote(num_cpus=10, max_concurrency=100) +class FullyAsyncRollouter(FullyAsyncRayPPOTrainer): + """ + Asynchronous sample generator, responsible for continuously generating training samples + and putting them into MessageQueue + Based on the mature implementation improvements of OneStepOffRayTrainer + """ + + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + device_name=None, + ): + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + self.val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + + assert not self.hybrid_engine + assert self.config.data.train_batch_size == 0, "train_batch_size must be zero" + assert self.config.data.gen_batch_size == 1, "gen_batch_size must be one" + assert self.config.async_training.staleness_threshold >= 0, "staleness_threshold must larger than 0" + assert self.config.async_training.trigger_parameter_sync_step >= 1, ( + "trigger_parameter_sync_step must larger than 1" + ) + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + self.validation_generations_logger = ValidationGenerationsLogger( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + ) + + self.ref_in_actor = False + self.kl_ctrl_in_reward = False + self.use_critic = False + self.use_reference_policy = False + self.use_rm = False + + print("[FullyAsyncRollouter] Creating datasets...") + from reasoning360.trainer.main_ppo import create_rl_dataset, create_rl_sampler + from reasoning360.utils.dataset.rl_dataset import collate_fn + + train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor) + val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor) + train_sampler = create_rl_sampler(config.data, train_dataset) + + self._validate_config() + print(f"[FullyAsyncRollouter] Rollouter _create_dataloader...\n{train_dataset}\n{val_dataset}") + + self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) + + # ==================== fully async config ==================== + + self.total_rollout_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + if self.config.rollout.total_rollout_steps is not None: + self.total_rollout_steps = min(self.config.rollout.total_rollout_steps, self.total_rollout_steps) + print(f"[FullyAsyncRollouter] Total rollout steps: {self.total_rollout_steps}") + self.total_train_steps = None + + # Rollouter parameter configuration + self.message_queue_client = None + + # Worker groups: rollout_wg is same to actor_rollout_wg + self.rollout_wg = None + self.actor_rollout_wg = None + self.async_rollout_manager = None + + # Config + self.staleness_threshold: float = config.async_training.get("staleness_threshold", 1) + # required_samples use ppo_mini_batch_size*require_batches as the minimum number of samples. + self.require_batches = config.async_training.require_batches + self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches + self.max_required_samples = None + self.max_concurrent_samples = None + # queue size + self.max_queue_size = None + + # Statistics + self.current_param_version = 0 + self.total_generated_samples = 0 + self.staleness_samples = 0 + self.dropped_stale_samples = 0 + self.processed_sample_count = 0 + # we start from step 1 + self.global_steps = 1 + self.idle_start_time = None + self.version_start_time = None + + # Concurrency control + # Modified by self.pause() or self._should_pause_generation() + self.paused = False + self.running = True + self.monitor_loop_trigger = True + + # Add dataloader lock + self.dataloader_lock = asyncio.Lock() + + # Initialize async queues + self.pending_queue = asyncio.Queue(maxsize=128) + self.active_tasks = set() + self.cancel_queue = asyncio.Queue() + + def _init_async_objects(self): + # Initialize asyncio synchronization primitives. + # We let asyncio.Condition create the Lock internally to ensure they share the same Event Loop. + # This avoids 'ValueError: loop argument must agree with lock' which can occur in Ray environments + # where the lock's captured loop (get_running_loop) differs from Condition's default loop check. + # Explicitly passing the loop is deprecated/removed in Python 3.10+, so this reverse-initialization + # is the most robust workaround. + self.condition = asyncio.Condition() + self.lock = self.condition._lock + + async def set_message_queue_client(self, message_queue_client: MessageQueueClient): + """Set message queue client""" + async with self.lock: + self.message_queue_client = message_queue_client + + async def set_max_required_samples(self): + async with self.lock: + self.max_required_samples = int( + self.required_samples + * (self.staleness_threshold + 1) + * self.config.async_training.trigger_parameter_sync_step + ) + self.total_train_steps = int( + self.total_rollout_steps + / (self.required_samples * self.config.async_training.trigger_parameter_sync_step) + ) + + self.max_concurrent_samples = len(self.async_rollout_manager.server_handles) * 16 + self.max_concurrent_samples = min(self.max_concurrent_samples, self.max_required_samples) + self.max_queue_size = self.max_required_samples + + print( + f"[FullyAsyncRollouter] required_samples : {self.required_samples} " + f"max_required_samples: {self.max_required_samples} " + f"max_queue_size: {self.max_queue_size} " + f"total_train_steps: {self.total_train_steps} " + f"total_rollout_steps: {self.total_rollout_steps} " + f"max_concurrent_samples: {self.max_concurrent_samples} " + ) + + def get_rollout_wg(self): + """Get rollout worker group""" + return self.rollout_wg + + def get_max_queue_size(self): + return self.max_queue_size + + def get_total_train_steps(self): + return self.total_train_steps + + async def update_param_version(self, version: int, validate: bool = False, global_steps: int = 0): + """Update current parameter version""" + async with self.lock: + old_version = self.current_param_version + self.current_param_version = version + # every time param change, reset staleness_samples + self.staleness_samples = ( + len(self.active_tasks) + self.cancel_queue.qsize() + await self.message_queue_client.get_queue_size() + ) + timing_raw = {} + idle_ratio = None + if self.idle_start_time is not None and self.version_start_time is not None: + rollout_active_time = self.idle_start_time - self.version_start_time + rollout_version_time = time.time() - self.version_start_time + idle_ratio = 1 - rollout_active_time / rollout_version_time + timing_raw["rollouter/active_time"] = rollout_active_time + timing_raw["rollouter/version_time"] = rollout_version_time + timing_raw["rollouter/idle_ratio"] = idle_ratio + self.idle_start_time = None + print( + f"[FullyAsyncRollouter][Public][update_param_version] " + f"Parameter version updated from {old_version} to {version} " + f",reset staleness_samples to: {self.staleness_samples}" + f",idle_ratio: {idle_ratio}" + ) + val_metrics = None + if ( + self.val_reward_fn is not None + and self.config.rollout.test_freq > 0 + and self.current_param_version % self.config.rollout.test_freq == 0 + and self.current_param_version > 0 # don't test here in the initial parameter sync + ) or (validate and self.val_reward_fn is not None): + with marked_timer("rollouter/validate_time", timing_raw, color="green"): + val_metrics: dict = self._validate() + data = ValidateMetrics( + timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version + ) + await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data)) + + self.version_start_time = time.time() + + async def save_checkpoint(self, local_global_step_folder: str): + # WARNING!: Due to the asynchronous nature, there are some in-flight samples + # (pending/cancel/result queue and message queue). + # Therefore, directly saving the state of the dataloader will result in losing these + # samples when resuming training. + # TODO: Implement dataloader recovery without losing in-flight samples. + from verl.utils.fs import local_mkdir_safe + + # save dataloader + local_mkdir_safe(local_global_step_folder) + dataloader_local_path = os.path.join(local_global_step_folder, "data.pt") + async with self.dataloader_lock: + dataloader_state_dict = self.train_dataloader.state_dict() + torch.save(dataloader_state_dict, dataloader_local_path) + print(f"[FullyAsyncRollouter] Saved dataloader checkpoint to {dataloader_local_path}") + + def load_checkpoint(self): + """Load checkpoint including dataloader state based on resume mode""" + + if self.config.trainer.resume_mode == "disable": + print("[FullyAsyncRollouter] Resume mode is disabled, starting from scratch") + return 0 + + # Determine checkpoint folder path + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("[FullyAsyncRollouter] Load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + + global_step_folder = find_latest_ckpt_path(checkpoint_folder) + + # Find and validate global_step_folder based on resume mode + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("[FullyAsyncRollouter] Training from scratch (no checkpoint found)") + return 0 + elif self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), ( + "[FullyAsyncRollouter] resume_from_path must be str type" + ) + assert "global_step_" in self.config.trainer.resume_from_path, ( + "[FullyAsyncRollouter] resume_from_path must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + else: + raise ValueError(f"[FullyAsyncRollouter] Unknown resume_mode: {self.config.trainer.resume_mode}") + + print(f"[FullyAsyncRollouter] Loading checkpoint from: {global_step_folder}") + + # Extract and set global step + trainer_global_steps = int(global_step_folder.split("global_step_")[-1]) + self.global_steps = ( + trainer_global_steps * self.required_samples * self.config.async_training.trigger_parameter_sync_step + 1 + ) + print(f"[FullyAsyncRollouter] Setting global_steps to {self.global_steps}") + + # Load dataloader state + dataloader_local_path = os.path.join(global_step_folder, "data.pt") + if os.path.exists(dataloader_local_path): + dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False) + self.train_dataloader.load_state_dict(dataloader_state_dict) + print(f"[FullyAsyncRollouter] Loaded dataloader state from {dataloader_local_path}") + else: + print( + f"[FullyAsyncRollouter] Warning: No dataloader state found at {dataloader_local_path}, " + f"will start from scratch" + ) + + def _validate_config(self): + # Validate asynchronous training configuration + if not hasattr(self.config, "async_training"): + raise ValueError("[FullyAsyncRollouter] Missing async_training configuration") + assert self.config.actor_rollout_ref.rollout.calculate_log_probs, "must rollout calculate log_probs" + + async def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self._init_async_objects() + self._init_resource_pools() + self._create_worker_classes() + self._init_worker_groups() + self._init_models() + await self._init_async_rollout_manager() + + def _create_actor_rollout_classes(self): + # only create rollout + for role in [Role.Rollout]: + resource_pool = self.resource_pool_manager.get_resource_pool(role) + role_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[role], + config=self.config.actor_rollout_ref, + role=str(role), + ) + self.resource_pool_to_cls[resource_pool][str(role)] = role_cls + + def _init_models(self): + self.rollout_wg = self.all_wg[str(Role.Rollout)] + self.rollout_wg.init_model() + self.actor_rollout_wg = self.rollout_wg + + def _create_continuous_iterator(self): + """ + Create a continuous data iterator across epoch + """ + for epoch in range(self.config.rollout.total_epochs): + iterator = iter(self.train_dataloader) + for batch_dict in iterator: + yield epoch, batch_dict + + async def _init_async_rollout_manager(self): + # create async rollout manager and request scheduler + assert self.config.actor_rollout_ref.rollout.mode == "async" + from .agent_loop import FullyAsyncAgentLoopManager + + self.async_rollout_mode = True + self.async_rollout_manager = await FullyAsyncAgentLoopManager.create( + config=self.config, + worker_group=self.rollout_wg, + ) + + # Add samples to the pending_queue + async def _feed_samples(self): + continuous_iterator = self._create_continuous_iterator() + + for epoch, batch_dict in continuous_iterator: + # Similar to _prepare_generate_batch: Separate data + full_batch = prepare_single_generation_data(batch_dict, self.config) + + sample_id = f"sample_{epoch}_{self.global_steps}" + + rollout_sample = RolloutSample( + full_batch=full_batch, + agent_loop_output_list=[None] * self.config.actor_rollout_ref.rollout.n, + sample_id=sample_id, + epoch=epoch, + param_version=0, + param_version_start=[], + param_version_end=[], + processing_times=[], + tool_calls=[], + rollout_status={}, + ) + + await self.pending_queue.put(rollout_sample) + + # Check if have reached the last step + if self.global_steps >= self.total_rollout_steps: + print( + f"[FullyAsyncRollouter][Feed] " + f"Maximum count has been reached, stop adding new samples" + f"{self.global_steps} >= {self.total_rollout_steps}" + ) + break + + self.global_steps += 1 + + # End signal + await self.pending_queue.put("DONE") + print(f"[FullyAsyncRollouter][Feed] Sample addition is complete, {self.global_steps} samples have been added") + + async def _processor_worker(self): + """ + Streaming worker coroutines, a sample is submitted for processing without waiting for batches + """ + while True: + if self.paused or await self._should_pause_generation(): + print( + "[FullyAsyncRollouter][Processor] Received pause signal, waiting for remaining tasks to return..." + ) + async with self.lock: + self.paused = True + while self.active_tasks: + async with self.lock: + # After acquiring the lock, the number of active_tasks may change, need to be verified again + if self.active_tasks: + done_tasks, self.active_tasks = await asyncio.wait( + self.active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + + async with self.lock: + while self.paused: + self.idle_start_time = time.time() + await self.condition.wait() + continue + + simple_from_cancel_queue = False + if not self.cancel_queue.empty(): + rollout_sample = await self.cancel_queue.get() + simple_from_cancel_queue = True + else: + rollout_sample = await self.pending_queue.get() + self.staleness_samples += 1 + + if rollout_sample == "DONE": + print( + "[FullyAsyncRollouter][Processor] Received end signal, waiting for remaining tasks to complete..." + ) + while self.active_tasks: + async with self.lock: + if self.active_tasks: + done_tasks, self.active_tasks = await asyncio.wait( + self.active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + break + + # Check whether the number of concurrent tasks exceeds the limit + while len(self.active_tasks) >= self.max_concurrent_samples: + async with self.lock: + if self.active_tasks: + done_tasks, self.active_tasks = await asyncio.wait( + self.active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + + # Submit single sample processing + async with self.lock: + # After the pause is over, the lock is acquired and it is necessary + # to determine whether it is the pause phase, otherwise continue to wait + while self.paused: + await self.condition.wait() + task = asyncio.create_task( + self._process_single_sample_streaming(rollout_sample), + name=rollout_sample.sample_id, + ) + self.active_tasks.add(task) + + if simple_from_cancel_queue: + self.cancel_queue.task_done() + else: + self.pending_queue.task_done() + + async def _process_single_sample_streaming(self, rollout_sample: RolloutSample): + """Process a single sample streamingly""" + # Calling asynchronous generation methods + rollout_sample.full_batch.non_tensor_batch["param_version"] = [self.current_param_version] * len( + rollout_sample.full_batch + ) + ret, is_cancel = await self.async_rollout_manager.generate_single_sample_async( + rollout_sample.full_batch, rollout_sample.agent_loop_output_list + ) + if not is_cancel: + rollout_sample.full_batch = ret + rollout_sample.full_batch.non_tensor_batch["uid"] = np.array( + [f"uid_{rollout_sample.sample_id}"] * len(rollout_sample.full_batch), dtype=object + ) + rollout_sample.param_version = self.current_param_version + rollout_sample.rollout_status = await self.get_statistics() + rollout_sample.agent_loop_output_list = [] + + success = await self.message_queue_client.put_sample( + sample=ray.cloudpickle.dumps(rollout_sample), + param_version=rollout_sample.param_version, + ) + if success: + self.total_generated_samples += 1 + else: + self.dropped_stale_samples += 1 + else: + rollout_sample.agent_loop_output_list = ret + await self.cancel_queue.put(rollout_sample) + + self.processed_sample_count += 1 + + async def _streaming_generation_main(self): + """The main entry method for stream processing""" + + if self.async_rollout_manager is None: + await self._init_async_rollout_manager() + + # Start the streaming loop + print(f"[FullyAsyncRollouter] Start streaming mode, maximum concurrent samples: {self.max_concurrent_samples}") + + # Start sample feed coroutine, streaming process coroutine + self.feed_task = asyncio.create_task(self._feed_samples()) + self.processor_task = asyncio.create_task(self._processor_worker()) + + try: + # Wait for sample feed to complete + # Use asyncio.wait to monitor all tasks. If processor exits early, + # detect it instead of blocking on feed_task (it might be stuck on a full queue). + done, pending = await asyncio.wait( + [self.feed_task, self.processor_task], return_when=asyncio.FIRST_COMPLETED + ) + + for task in done: + if task.exception(): + raise task.exception() + + if self.feed_task not in done: + raise RuntimeError("Processor task exited prematurely") + + print("[FullyAsyncRollouter] Sample feed completed") + + # Wait for streaming to complete + await self.processor_task + print("[FullyAsyncRollouter] Streaming process completed") + + except Exception as e: + print(f"[FullyAsyncRollouter] Streaming process exception:{e}") + + finally: + if self.processor_task: + self.processor_task.cancel() + + await asyncio.gather(self.processor_task, return_exceptions=True) + + # Send a finish signal + await self.message_queue_client.put_sample( + sample=None, + param_version=self.current_param_version, + ) + + async with self.lock: + self.running = False + + async def fit(self): + """ + Start the async rollouter - entry point that sets up and runs async tasks + Main async fit method that coordinates all coroutines + """ + + print("[FullyAsyncRollouter] Starting FullyAsyncRollouter...") + + if self.message_queue_client is None: + raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.") + + # Set the running status flag + async with self.lock: + self.paused = False + self.running = True + + # Create the main asynchronous task + generation_task = asyncio.create_task(self._streaming_generation_main()) + monitor_task = asyncio.create_task(self._async_monitor_loop()) + + try: + # Run build and monitoring tasks concurrently + await asyncio.gather(generation_task, monitor_task, return_exceptions=True) + except Exception as e: + print(f"[FullyAsyncRollouter] Asynchronous task execution error: {e}") + finally: + if not generation_task.done(): + generation_task.cancel() + if not monitor_task.done(): + monitor_task.cancel() + + # Wait for the task to complete + await asyncio.gather(generation_task, monitor_task, return_exceptions=True) + + print("[FullyAsyncRollouter] Rollouter fit completed") + + async def _async_monitor_loop(self): + """ + Async coroutine for monitoring: + Function 1: Log information output + Function 2: Trigger rollout recovery + """ + last_stats_time = time.time() + stats_interval = 60.0 + check_interval = 10.0 + + while True: + async with self.lock: + if not self.running: + break + await asyncio.sleep(check_interval) + # Print statistics periodically + current_time = time.time() + if current_time - last_stats_time >= stats_interval: + stats = await self.get_statistics() + print(f"[FullyAsyncRollouter][MonitorLoop][Statistics] {pformat(stats)}") + last_stats_time = current_time + + # Trigger rollout recovery + if self.monitor_loop_trigger: + if not await self._should_pause_generation(): + async with self.lock: + self.paused = False + self.condition.notify_all() + + async def _should_pause_generation(self) -> bool: + """Determine whether the build should be paused""" + queue_stats = self.message_queue_client.get_statistics_sync() + queue_size = queue_stats["queue_size"] + + if queue_size >= self.max_queue_size: + if not self.paused: + print( + f"[FullyAsyncRollouter][ShouldPause] " + f"due to full queue: size={queue_size}, max={self.max_queue_size}" + ) + return True + + if self.staleness_samples >= self.max_required_samples: + if not self.paused: + print( + "[FullyAsyncRollouter][ShouldPause] " + f"due to " + f"staleness_samples {self.staleness_samples} >= max_required_samples {self.max_required_samples} " + ) + return True + + return False + + async def pause(self): + """pause rollout""" + print("[FullyAsyncRollouter][Public][Pause]") + async with self.lock: + self.paused = True + # Cancel all rollout tasks + if self.config.async_training.partial_rollout: + await self.async_rollout_manager.cancel() + if self.active_tasks: + await asyncio.gather(*self.active_tasks, return_exceptions=True) + self.active_tasks.clear() + print("[FullyAsyncRollouter][Public][Pause] All active tasks completed") + await self.async_rollout_manager.clear_kv_cache() + self.monitor_loop_trigger = False + + async def resume(self, dependency_ref: ObjectRef = None): + if dependency_ref is not None: + ray.get(dependency_ref) + print("[FullyAsyncRollouter][Public][Resume]") + async with self.lock: + if self.config.async_training.partial_rollout: + await self.async_rollout_manager.resume() + self.paused = False + self.monitor_loop_trigger = True + self.condition.notify_all() + + async def get_statistics(self) -> dict: + queue_stats = self.message_queue_client.get_statistics_sync() + + stats = { + # monitor stats + "monitor/active_tasks_size": len(self.active_tasks), + "monitor/queue/pending_queue_size": self.pending_queue.qsize(), + "monitor/queue/cancel_queue_size": self.cancel_queue.qsize(), + "monitor/queue/mq_queue_size": queue_stats["queue_size"], + # counting stats + "count/current_param_version": self.current_param_version, + "count/total_generated_samples": self.total_generated_samples, + "count/staleness_samples": self.staleness_samples, + "count/dropped_stale_samples": self.dropped_stale_samples, + # static stats + "static/max_required_samples": self.max_required_samples, + "static/required_samples": self.required_samples, + "static/staleness_threshold": self.staleness_threshold, + "static/max_queue_size": self.max_queue_size, + "static/max_concurrent_samples": self.max_concurrent_samples, + } + + return stats diff --git a/src/reasoning360/recipe/fully_async_policy/fully_async_trainer.py b/src/reasoning360/recipe/fully_async_policy/fully_async_trainer.py new file mode 100644 index 000000000..669450204 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/fully_async_trainer.py @@ -0,0 +1,505 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +from datetime import datetime +from pprint import pprint +from typing import Any + +import ray +from omegaconf import OmegaConf +from tqdm import tqdm + +from .detach_utils import ( + MetricsAggregator, + ValidateMetrics, + assemble_batch_from_rollout_samples, +) +from .message_queue import MessageQueueClient +from .ray_trainer import FullyAsyncRayPPOTrainer +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.trainer.ppo import core_algos +from reasoning360.trainer.ppo.ray_trainer import ResourcePoolManager +from reasoning360.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model +from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.debug import marked_timer + + +@ray.remote(num_cpus=10) +class FullyAsyncTrainer(FullyAsyncRayPPOTrainer): + """ + A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training. + Based on an improved implementation of OneStepOffRayTrainer + """ + + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + device_name=None, + ): + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + self.val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert not self.hybrid_engine + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(self.role_worker_mapping) + self.use_rm = need_reward_model(self.role_worker_mapping) + self.use_critic = need_critic(self.config) + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + + # if ref_in_actor is True, the reference policy will be actor without lora applied + self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + + # define in-reward KL control + # kl loss control currently not suppoorted + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) + + # ==================== fully async config ==================== + + self.message_queue_client = None + self.param_synchronizer = None + + # Statistics + # we start from step 1 + self.global_steps = 1 + self.local_trigger_step = 1 + self.processed_samples = 0 + self.stale_samples_processed = 0 + self.stale_trajectory_processed = 0 + self.current_param_version = 0 + self.total_train_steps = None + self.progress_bar = None + self.trigger_parameter_sync_step = config.async_training.trigger_parameter_sync_step + self.last_ckpt_version = 0 + + # required_samples use ppo_mini_batch_size*require_batches as the minimum number of samples. + self.require_batches = config.async_training.require_batches + self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches + self.compute_prox_log_prob = self.config.async_training.compute_prox_log_prob + total_gpus = ( + config.trainer.nnodes * config.trainer.n_gpus_per_node + + config.rollout.nnodes * config.rollout.n_gpus_per_node + ) + self.metrics_aggregator = MetricsAggregator(total_gpus=total_gpus) + + def set_message_queue_client(self, message_queue_client: MessageQueueClient): + """Set message queue client""" + self.message_queue_client = message_queue_client + + def set_parameter_synchronizer(self, param_synchronizer): + """Set parameter synchronizer""" + self.param_synchronizer = param_synchronizer + + def set_total_train_steps(self, total_train_steps): + self.total_train_steps = total_train_steps + self.progress_bar = tqdm(total=self.total_train_steps, initial=0, desc="Training Progress") + + def get_actor_wg(self): + """Get actor worker group""" + return self.actor_wg + + def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]: + """ + Get samples from message queue and compose gen_batch_output + Uses a loop to continuously collect samples until enough are gathered + + Returns: + tuple: (epoch, batch_dict, gen_batch_output) + """ + print( + f"[FullyAsyncTrainer] Requesting {self.required_samples} samples from queue", + flush=True, + ) + + # Collect samples using a simple loop calling get_sample + consumer_start = time.time() + queue_samples = [] + queue_len = 0 + while len(queue_samples) < self.required_samples: + # Get a single sample and wait until there is a sample or None is received + sample, queue_len = self.message_queue_client.get_sample_sync() + + if sample is None: + print( + f"[FullyAsyncTrainer] Detected termination signal (None), stopping sample collection. " + f"Collected {len(queue_samples)}/{self.required_samples} samples" + ) + break + + queue_samples.append(sample) + + if len(queue_samples) % 64 == 0: + print( + f"[FullyAsyncTrainer] Collected {len(queue_samples)}/{self.required_samples} samples. " + f"mq_len: {queue_len}" + ) + + consumer_end = time.time() + + if not queue_samples or len(queue_samples) < self.required_samples: + print("[FullyAsyncTrainer] not enough samples collected after loop") + return None, None + total_wait_time = consumer_end - consumer_start + + print( + f"[FullyAsyncTrainer] Loop collection completed: {len(queue_samples)}/{self.required_samples} samples, " + f"total wait time: {total_wait_time:.2f} seconds." + f"mq_len: {queue_len}" + ) + + queue_samples = [ray.cloudpickle.loads(x) for x in queue_samples] + # Assemble batch - now working directly with RolloutSample objects + if self.config.trainer.balance_batch: + batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, self._balance_batch) + else: + batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, None) + + batch.meta_info["fully_async/total_wait_time"] = total_wait_time + return 0, batch + + def _create_actor_rollout_classes(self): + # create actor + for role in [Role.Actor]: + resource_pool = self.resource_pool_manager.get_resource_pool(role) + role_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[role], + config=self.config.actor_rollout_ref, + role=str(role), + ) + self.resource_pool_to_cls[resource_pool][str(role)] = role_cls + + def _init_models(self): + if self.use_critic: + self.critic_wg = self.all_wg[str(Role.Critic)] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = self.all_wg[str(Role.RewardModel)] + self.rm_wg.init_model() + + self.actor_wg = self.all_wg[str(Role.Actor)] + self.actor_wg.init_model() + self.actor_rollout_wg = self.actor_wg # to be compatible with the functions that not be modified + + def _init_async_rollout_manager(self): + pass + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + print("[FullyAsyncTrainer] Starting FullyAsyncTrainer...") + if self.message_queue_client is None: + raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.") + if self.param_synchronizer is None: + raise ValueError("param_synchronizer client not set. Call set_parameter_synchronizer() first.") + + from verl.utils.tracking import Tracking + + self.logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.max_steps_duration = 0 + + # get validate data before training + self._log_validation_data() + + # Use queue mode, no need for traditional dataloader iterator + # Initialize to get the first batch of data + while True: + metrics = {} + timing_raw = {} + + with marked_timer("step", timing_raw): + with marked_timer("gen", timing_raw, color="red"): + epoch, batch = self._get_samples_from_queue() + if batch is None: + break + self._collect_metrics_from_samples(batch, metrics) + batch, reward_extra_infos_dict = self._process_batch_common( + batch, metrics, timing_raw, self.local_trigger_step if self.compute_prox_log_prob else None + ) + self._log_rollout(batch, reward_extra_infos_dict, timing_raw) + + self._collect_metrics(batch, 0, metrics, timing_raw) + self.metrics_aggregator.add_step_metrics( + metrics=metrics, sample_count=self.required_samples, timestamp=time.time() + ) + # Trigger parameter synchronization after training step + time_str = datetime.now().strftime("%H:%M:%S.%f")[:-3] + print( + f"[FullyAsyncTrainer] global_steps: {self.global_steps} " + f"local_trigger_step: {self.local_trigger_step} " + f"trigger_parameter_sync_step: {self.trigger_parameter_sync_step} " + f"{time_str}" + ) + self._trigger_parameter_sync_after_step(global_steps=self.global_steps) + self._log_validation_data() + self._check_save_checkpoint(timing_raw) + self.global_steps += 1 + + # final parameter sync and validate + # 1. waiting remaining validate task + ray.get(self.param_synchronizer.wait_last_valid.remote()) + self._log_validation_data() + # 2. perform addtional parameter_sync and validate if trainer already updated + if self.current_param_version % self.config.rollout.test_freq != 0 or self.local_trigger_step > 1: + self._trigger_parameter_sync_after_step(validate=True, global_steps=self.global_steps) + ray.get(self.param_synchronizer.wait_last_valid.remote()) + self._log_validation_data() + self.progress_bar.close() + + self._check_save_checkpoint(timing_raw) + + def _check_save_checkpoint(self, timing_raw): + if self.current_param_version == self.last_ckpt_version: + return + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. The current step number is a multiple of the save frequency. + # 3. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + self.current_param_version % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + self.last_ckpt_version = self.current_param_version + + def _save_checkpoint(self): + # Warning: Currently, to align the training process and metrics of colocate, + # we use current_param_version instead of global step. + # This can be logically aligned with the original self.global_steps of colocate + # and is used for metrics and ckpt. which means that the parameter synchronization + # from trainer to rollouter will increase by 1 each time. + + # path: given_path + `/global_step_{global_steps}` + `/actor` + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.current_param_version}" + ) + + print(f"[FullyAsyncTrainer] local_global_step_folder: {local_global_step_folder}") + actor_local_path = os.path.join(local_global_step_folder, "actor") + + actor_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join( + self.config.trainer.default_hdfs_dir, f"global_step_{self.current_param_version}", "actor" + ) + ) + + remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False) + if remove_previous_ckpt_in_save: + print( + "[FullyAsyncTrainer] Warning: remove_previous_ckpt_in_save is deprecated," + + " set max_actor_ckpt_to_keep=1 and max_critic_ckpt_to_keep=1 instead" + ) + max_actor_ckpt_to_keep = ( + self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + max_critic_ckpt_to_keep = ( + self.config.trainer.get("max_critic_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1 + ) + + self.actor_rollout_wg.save_checkpoint( + actor_local_path, actor_remote_path, self.current_param_version, max_ckpt_to_keep=max_actor_ckpt_to_keep + ) + + if self.use_critic: + critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic)) + critic_remote_path = ( + None + if self.config.trainer.default_hdfs_dir is None + else os.path.join( + self.config.trainer.default_hdfs_dir, f"global_step_{self.current_param_version}", str(Role.Critic) + ) + ) + self.critic_wg.save_checkpoint( + critic_local_path, + critic_remote_path, + self.current_param_version, + max_ckpt_to_keep=max_critic_ckpt_to_keep, + ) + ray.get(self.param_synchronizer.rollouter_save_checkpoint.remote(local_global_step_folder)) + # latest checkpointed iteration tracker (for atomic usage) + local_latest_checkpointed_iteration = os.path.join( + self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" + ) + with open(local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.current_param_version)) + + def load_checkpoint(self): + if self.config.trainer.resume_mode == "disable": + # NOTE: while there is no checkpoint to load, we still need to offload the model and optimizer to CPU + self.actor_rollout_wg.load_checkpoint(None) + return 0 + + # load from hdfs + if self.config.trainer.default_hdfs_dir is not None: + raise NotImplementedError("load from hdfs is not implemented yet") + else: + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + if not os.path.isabs(checkpoint_folder): + working_dir = os.getcwd() + checkpoint_folder = os.path.join(working_dir, checkpoint_folder) + global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest + + # find global_step_folder + if self.config.trainer.resume_mode == "auto": + if global_step_folder is None: + print("[FullyAsyncTrainer] Training from scratch") + self.actor_rollout_wg.load_checkpoint(None) + return 0 + else: + if self.config.trainer.resume_mode == "resume_path": + assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type" + assert "global_step_" in self.config.trainer.resume_from_path, ( + "resume ckpt must specify the global_steps" + ) + global_step_folder = self.config.trainer.resume_from_path + if not os.path.isabs(global_step_folder): + working_dir = os.getcwd() + global_step_folder = os.path.join(working_dir, global_step_folder) + print(f"[FullyAsyncTrainer] Load from checkpoint folder: {global_step_folder}") + # set global step + self.current_param_version = int(global_step_folder.split("global_step_")[-1]) + self.global_steps = self.current_param_version * self.trigger_parameter_sync_step + 1 + self.last_ckpt_version = self.current_param_version + print( + f"[FullyAsyncTrainer] Setting global step to {self.global_steps}, " + f"current_param_version to {self.current_param_version}" + ) + print(f"[FullyAsyncTrainer] Resuming from {global_step_folder}") + + actor_path = os.path.join(global_step_folder, "actor") + critic_path = os.path.join(global_step_folder, str(Role.Critic)) + # load actor + self.actor_rollout_wg.load_checkpoint( + actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + # load critic + if self.use_critic: + self.critic_wg.load_checkpoint( + critic_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load + ) + return self.current_param_version + + def _collect_metrics_from_samples(self, batch, metrics): + """ + Collect metrics from samples + """ + if hasattr(batch, "meta_info") and batch.meta_info: + samples_param_versions = batch.meta_info["rollout_param_versions"] + stale_count = sum(1 for v in samples_param_versions if self.current_param_version - v >= 1) + self.stale_samples_processed += stale_count + trajectory_param_versions = batch.meta_info["trajectory_param_versions"] + stale_traj_count = sum(1 for v in trajectory_param_versions if self.current_param_version - v >= 1) + self.stale_trajectory_processed += stale_traj_count + metrics.update( + { + "fully_async/count/stale_samples_processed": self.stale_samples_processed, + "fully_async/count/stale_trajectory_processed": self.stale_trajectory_processed, + "fully_async/count/current_param_version": self.current_param_version, + } + ) + for key, value in batch.meta_info.items(): + if key.startswith("fully_async") or key.startswith("timing_s"): + metrics[key] = value + + def _trigger_parameter_sync_after_step(self, validate: bool = False, global_steps: int = None): + """ + Trigger parameter synchronization after training step + This ensures rollouter always uses the latest trained parameters + """ + if self.local_trigger_step < self.trigger_parameter_sync_step and not validate: + self.local_trigger_step += 1 + return + + self.current_param_version += 1 + self.local_trigger_step = 1 + self.logger.log( + data=self.metrics_aggregator.get_aggregated_metrics(), + step=self.current_param_version, + ) + self.progress_bar.update(1) + self.metrics_aggregator.reset() + timing_param_sync = {} + with marked_timer("timing_s/wait_last_valid", timing_param_sync): + ray.get(self.param_synchronizer.wait_last_valid.remote()) + with marked_timer("timing_s/param_sync", timing_param_sync): + ray.get( + self.param_synchronizer.sync_weights.remote( + self.current_param_version, validate=validate, global_steps=global_steps + ) + ) + self.logger.log(data=timing_param_sync, step=self.current_param_version) + + def _log_validation_data(self): + """ + Log validation data + """ + val_data = self.message_queue_client.get_validate_sync() + if not val_data: + return + + val_metrics: ValidateMetrics = ray.cloudpickle.loads(val_data) + if val_metrics.metrics: + self.logger.log(data=val_metrics.metrics, step=val_metrics.param_version) + pprint( + f"[FullyAsyncTrainer] parameter version: {val_metrics.param_version} " + f"Validation metrics: {val_metrics.metrics}" + ) + self.logger.log(data=val_metrics.timing_raw, step=val_metrics.param_version) diff --git a/src/reasoning360/recipe/fully_async_policy/megatron_utils.py b/src/reasoning360/recipe/fully_async_policy/megatron_utils.py new file mode 100644 index 000000000..9f5380f25 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/megatron_utils.py @@ -0,0 +1,99 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core.distributed import DistributedDataParallel as DDP + + +@torch.no_grad() +def copy_megatron_model_to_cpu(models): + """ + Copy Megatron model parameters to CPU memory (non-destructive copy). + Unlike offload_megatron_model_to_cpu which moves data, this function creates + independent copies on CPU while keeping GPU data intact. + + Args: + models: List of model chunks (DDP-wrapped or unwrapped) + + Returns: + dict: CPU state containing copied parameters and buffers + """ + cpu_state = {} + + for model_idx, model_chunk in enumerate(models): + if isinstance(model_chunk, DDP): + # Handle DDP-wrapped models + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + buffer_states = [] + + for buffers in model_chunk_all_buffers: + buffer_list = [] + for buffer in buffers: + buffer_state = {} + + # Copy parameter data to CPU + if buffer.param_data.storage().size() > 0: + buffer_state["param_data"] = buffer.param_data.data.cpu().clone().pin_memory() + + buffer_list.append(buffer_state) + buffer_states.append(buffer_list) + + cpu_state[f"model_chunk_{model_idx}"] = {"buffer_states": buffer_states, "is_ddp": True} + else: + # Handle non-DDP models (ref module) + model_state = {} + for name, param in model_chunk.named_parameters(): + param_state = {"data": param.data.cpu().clone().pin_memory()} + model_state[name] = param_state + + cpu_state[f"model_chunk_{model_idx}"] = {"model_state": model_state, "is_ddp": False} + + return cpu_state + + +@torch.no_grad() +def restore_megatron_model_from_cpu(models, cpu_state): + """ + Restore Megatron model parameters from CPU memory back to GPU. + + Args: + models: List of model chunks to restore to + cpu_state: CPU state dict returned from copy_megatron_model_to_cpu + """ + for model_idx, model_chunk in enumerate(models): + chunk_key = f"model_chunk_{model_idx}" + if chunk_key not in cpu_state: + continue + + chunk_state = cpu_state[chunk_key] + + if chunk_state["is_ddp"] and isinstance(model_chunk, DDP): + # Restore DDP buffers + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + buffer_states = chunk_state["buffer_states"] + + for buffers, buffer_list in zip(model_chunk_all_buffers, buffer_states, strict=False): + for buffer, buffer_state in zip(buffers, buffer_list, strict=False): + # Restore parameter data + if "param_data" in buffer_state: + buffer.param_data.data.copy_(buffer_state["param_data"].to(buffer.param_data.device)) + + elif not chunk_state["is_ddp"] and not isinstance(model_chunk, DDP): + # Restore non-DDP models + model_state = chunk_state["model_state"] + for name, param in model_chunk.named_parameters(): + if name in model_state: + param_state = model_state[name] + param.data.copy_(param_state["data"].to(param.device)) diff --git a/src/reasoning360/recipe/fully_async_policy/megatron_worker.py b/src/reasoning360/recipe/fully_async_policy/megatron_worker.py new file mode 100644 index 000000000..0ca85f6f9 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/megatron_worker.py @@ -0,0 +1,156 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# Copyright 2025 NVIDIA Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch +import torch.distributed +from omegaconf import DictConfig + +from .megatron_utils import copy_megatron_model_to_cpu, restore_megatron_model_from_cpu +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils.device import ( + get_device_name, + get_torch_device, +) +from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator +from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + +__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"] + + +def get_inference_model(rollout): + """ + get models according to different types of inference_engine + Args: + rollout: rollout object + Returns: + model: model object + """ + inference_engine = rollout.inference_engine + if hasattr(inference_engine, "llm_engine"): + inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + elif hasattr(inference_engine, "worker"): + inference_model = inference_engine.worker.model_runner.model + else: + raise AttributeError( + f"Unsupported inference_engine type: {type(inference_engine)}. " + f"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute)." + ) + return inference_model + + +class DetachNcclSync(AsyncActorRolloutRefWorker): + def _get_actor_params(self): + pass + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def sync_rollout_weights(self): + assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine + assert hasattr(self, "_weights_info") and self._weights_info is not None + if self._is_actor and self._is_offload_param: + load_megatron_model_to_gpu(self.actor_module) + params_generator = self._get_actor_params_generator() if self._is_actor else None + if self._is_rollout: + inference_model = get_inference_model(self.rollout) + from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader + + patch_vllm_moe_model_weight_loader(inference_model) + for key, shape, dtype in self._weights_info: + if self._is_actor: + weight_key, weight = next(params_generator) + assert key == weight_key + assert shape == weight.size() + assert dtype == weight.dtype + + tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) + if self._is_actor and torch.distributed.get_rank() == 0: + tensor.copy_(weight) + from ray.util.collective import collective + + collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") + if self._is_rollout: + inference_model.load_weights([(key, tensor)]) + if self._is_actor and self._is_offload_param: + offload_megatron_model_to_cpu(self.actor_module) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def save_model_to_cpu(self, n): + if not hasattr(self, "cpu_saved_models"): + self.cpu_saved_models = {} + self.cpu_saved_models[n] = copy_megatron_model_to_cpu(self.actor.actor_module) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def restore_model_from_cpu(self, n): + if n in self.cpu_saved_models: + restore_megatron_model_from_cpu(self.actor.actor_module, self.cpu_saved_models[n]) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def clear_cpu_model(self, n): + if n in self.cpu_saved_models: + del self.cpu_saved_models[n] + + +class DetachActorWorker(DetachNcclSync): + def _get_actor_params_generator(self): + assert self._is_actor + if self.bridge is not None: + generator = self.bridge.export_weights(self.actor.actor_module) + else: + generator = per_tensor_generator( + self.actor.actor_module, + self.actor_model_config, + self.weight_converter, + self.tf_config, + self.layer_name_mapping, + ) + + return generator + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_actor_weights_info(self): + assert self._is_actor + if hasattr(self, "_weights_info"): + return self._weights_info + if self._is_offload_param: + load_megatron_model_to_gpu(self.actor_module) + params_generator = self._get_actor_params_generator() + ret = [] + for key, tensor in params_generator: + ret.append((key, tensor.size(), tensor.dtype)) + + self._weights_info = ret + # Here, we only call this function at the beginning, + # and immediately afterwards we call sync_rollout_weights. + # So we no longer call offload in this. + return ret + + +class DetachAsyncRolloutWorker(DetachNcclSync): + def __init__(self, config: DictConfig, role: str): + print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") + ActorRolloutRefWorker.__init__(self, config, role) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_actor_weights_info(self, weights_info): + assert self._is_rollout + self._weights_info = weights_info diff --git a/src/reasoning360/recipe/fully_async_policy/message_queue.py b/src/reasoning360/recipe/fully_async_policy/message_queue.py new file mode 100644 index 000000000..85860c6f2 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/message_queue.py @@ -0,0 +1,265 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +from collections import deque +from typing import Any + +import ray +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) + + +@ray.remote(num_cpus=2, max_concurrency=20) +class MessageQueue: + """ + Simplified Ray-based asynchronous message queue for communication between Rollouter and Trainer + """ + + def __init__(self, config: DictConfig, max_queue_size: int = 1000): + self.config = config + if max_queue_size is None: + raise ValueError(f"max_queue_size cannot be None, got: {max_queue_size}") + self.max_queue_size = int(max_queue_size) + self.queue = deque(maxlen=self.max_queue_size) + self.current_param_version = 0 + + self.val_queue = deque() + + try: + if hasattr(config, "async_training") and config.async_training is not None: + self.staleness_threshold = getattr(config.async_training, "staleness_threshold", 3) + else: + self.staleness_threshold = 3 + except (AttributeError, RecursionError): + self.staleness_threshold = 3 + + # Asyncio for message handling + self.running = True + + # async safe + self._lock = asyncio.Lock() + self._consumer_condition = asyncio.Condition(self._lock) + + # statistic message + self.total_produced = 0 + self.total_consumed = 0 + self.dropped_samples = 0 + + print( + f"[MessageQueue] initialized with max_queue_size={max_queue_size}," + f"staleness_threshold={self.staleness_threshold}" + ) + + async def put_sample(self, sample: Any, param_version: int) -> bool: + """ + Put a batch sample into the queue + + Args: + sample: Sample data + param_version: Parameter version number + + Returns: + bool: Whether the sample was successfully put into the queue + """ + async with self._lock: + # If queue is full, remove the oldest sample (rarely happens) + is_drop = False + if len(self.queue) >= self.max_queue_size: + self.queue.popleft() + self.dropped_samples += 1 + is_drop = True + logger.warning("Queue full, dropped sample") + self.queue.append(sample) + self.total_produced += 1 + + # Notify waiting consumers + self._consumer_condition.notify_all() + + if self.total_produced % 100 == 0: + print(f"MessageQueue stats: produced={self.total_produced}, queue_size={len(self.queue)}") + if is_drop: + return False + return True + + async def get_sample(self) -> Any | None: + """ + Get a single sample from the queue, wait until one is available + + Returns: + Any: Single sample data or None if queue is closed + """ + async with self._lock: + while len(self.queue) == 0 and self.running: + await self._consumer_condition.wait() + + # If queue is closed and empty, return None + if not self.running and len(self.queue) == 0: + return None + + # Get one sample + data = self.queue.popleft() + self.total_consumed += 1 + return data, len(self.queue) + + async def update_param_version(self, version: int): + """Update current parameter version""" + async with self._lock: + old_version = self.current_param_version + self.current_param_version = version + print(f"Parameter version updated from {old_version} to {version}") + + async def get_queue_size(self) -> int: + """Get current queue length""" + async with self._lock: + return len(self.queue) + + async def get_statistics(self) -> dict[str, Any]: + """Get queue statistics""" + async with self._lock: + return { + "queue_size": len(self.queue), + "total_produced": self.total_produced, + "total_consumed": self.total_consumed, + "dropped_samples": self.dropped_samples, + "current_param_version": self.current_param_version, + "staleness_threshold": self.staleness_threshold, + "max_queue_size": self.max_queue_size, + } + + async def clear_queue(self): + """Clear the queue""" + async with self._lock: + cleared_count = len(self.queue) + self.queue.clear() + logger.info(f"Cleared {cleared_count} samples from queue") + + async def shutdown(self): + """Shutdown the message queue""" + async with self._lock: + self.running = False + # Notify all waiting coroutines so they can exit + self._consumer_condition.notify_all() + logger.info("MessageQueue shutdown") + + async def get_memory_usage(self) -> dict: + """Get memory usage statistics""" + async with self._lock: + # Estimate memory usage of samples in queue + import sys + + total_size = 0 + sample_count = len(self.queue) + + if sample_count > 0: + # Estimate size of a single sample (simplified estimation) + sample = list(self.queue)[0] + try: + sample_size = sys.getsizeof(sample) + # Since we now store RolloutSample directly, estimate based on its components + if hasattr(sample, "original_batch_dict") and sample.original_batch_dict: + # Estimate batch data size + batch_data = sample.original_batch_dict.get("batch", {}) + sample_size += len(batch_data) * 1000 # Roughly estimate 1KB per batch entry + if hasattr(sample, "agent_loop_output"): + # Estimate AgentLoopOutput size + sample_size += 5000 # Roughly estimate 5KB for AgentLoopOutput + total_size = sample_size * sample_count + except Exception: + total_size = sample_count * 15000 # Roughly estimate 15KB per RolloutSample + + return { + "queue_samples": sample_count, + "estimated_memory_bytes": total_size, + "estimated_memory_mb": total_size / (1024 * 1024), + } + + async def put_validate(self, data): + async with self._lock: + self.val_queue.append(data) + + async def get_validate(self): + async with self._lock: + if self.val_queue: + return self.val_queue.popleft() + else: + return None + + +class MessageQueueClient: + """Asyncio-compatible MessageQueue client for communicating with MessageQueue Actor""" + + def __init__(self, queue_actor: Any): + self.queue_actor = queue_actor + + async def put_sample(self, sample: Any, param_version: int) -> bool: + """Put batch into queue (async)""" + future = self.queue_actor.put_sample.remote(sample, param_version) + return await asyncio.wrap_future(future.future()) + + async def put_validate(self, data: Any) -> bool: + future = self.queue_actor.put_validate.remote(data) + return await asyncio.wrap_future(future.future()) + + def get_validate_sync(self) -> Any | None: + return ray.get(self.queue_actor.get_validate.remote()) + + async def get_sample(self) -> Any | None: + """Get single sample from queue, wait until one is available (async)""" + future = self.queue_actor.get_sample.remote() + return await asyncio.wrap_future(future.future()) + + async def get_queue_size(self) -> int: + """Get queue size (async)""" + future = self.queue_actor.get_queue_size.remote() + return await asyncio.wrap_future(future.future()) + + async def get_statistics(self) -> dict[str, Any]: + """Get statistics (async)""" + future = self.queue_actor.get_statistics.remote() + return await asyncio.wrap_future(future.future()) + + async def clear_queue(self): + """Clear queue (async)""" + future = self.queue_actor.clear_queue.remote() + await asyncio.wrap_future(future.future()) + + async def shutdown(self): + """Shutdown queue (async)""" + future = self.queue_actor.shutdown.remote() + await asyncio.wrap_future(future.future()) + + async def get_memory_usage(self) -> dict: + """Get memory usage statistics (async)""" + future = self.queue_actor.get_memory_usage.remote() + return await asyncio.wrap_future(future.future()) + + # Synchronous version of the method (deprecated) + def put_sample_sync(self, sample: Any, param_version: int) -> bool: + """Put batch into queue (sync - deprecated, use put_sample instead)""" + return ray.get(self.queue_actor.put_sample.remote(sample, param_version)) + + def get_sample_sync(self) -> Any | None: + """Get single sample from queue (sync - deprecated, use get_sample instead)""" + return ray.get(self.queue_actor.get_sample.remote()) + + def get_statistics_sync(self) -> dict[str, Any]: + """Get statistics (sync - deprecated, use get_statistics instead)""" + return ray.get(self.queue_actor.get_statistics.remote()) + + def update_param_version_sync(self, version: int): + """Update parameter version (async)""" + return ray.get(self.queue_actor.update_param_version.remote(version)) diff --git a/src/reasoning360/recipe/fully_async_policy/param_sync.py b/src/reasoning360/recipe/fully_async_policy/param_sync.py new file mode 100644 index 000000000..d67ff67fd --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/param_sync.py @@ -0,0 +1,113 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +import ray +from ray.util.collective import collective + +from verl.utils.device import get_nccl_backend + +logger = logging.getLogger(__name__) + + +@ray.remote +class ParameterSynchronizer: + """ + Unified parameter synchronizer, responsible for synchronizing model parameters between actor and rollout + Based on the mature synchronization mode implementation of one_step_off_policy + Merges the functions of the original multiple synchronizer classes + """ + + def __init__(self, config, trainer, rollouter, mq): + self.config = config + self.trainer = trainer + self.rollouter = rollouter + self.mq_client = mq + self.actor_wg = ray.get(trainer.get_actor_wg.remote()) + self.rollout_wg = ray.get(rollouter.get_rollout_wg.remote()) + + # Basic attributes + self.weights_info = None + self.sync_group_initialized = False + self.sync_group_name = "actor_rollout" + self.wait_last_update = None + self.wait_last_resume = None + + # Statistics + self.current_version = 0 + + self._init_weights_info() + self._init_sync_group() + + def get_current_param_version(self) -> int: + """Get current parameter version number""" + return self.current_version + + def get_weights_info(self): + """Get weights info""" + return self.weights_info + + def _init_weights_info(self): + self.weights_info = self.actor_wg.get_actor_weights_info()[0] + self.rollout_wg.set_actor_weights_info(self.weights_info) + + def _init_sync_group(self): + print("[ParameterSynchronizer] Initializing parameter synchronization group...") + actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers + collective.create_collective_group( + actor_rollout_workers, + len(actor_rollout_workers), + list(range(0, len(actor_rollout_workers))), + backend=get_nccl_backend(), + group_name=self.sync_group_name, + ) + + def sync_weights(self, version, validate=False, global_steps=0): + """Sync weights between trainer and rollouter, and update parameter version""" + start_time = time.time() + + self.current_version = version + print(f"[ParameterSynchronizer] Starting weight synchronization (version {self.current_version})...") + + ray.get(self.rollouter.pause.remote()) + + print(f"[ParameterSynchronizer] rollout paused. cost {time.time() - start_time:.2f} seconds") + # Update MQ version + self.mq_client.update_param_version_sync(version) + + # sync weights + self.actor_wg.sync_rollout_weights() + ray.get(self.rollout_wg.sync_rollout_weights()) + end_time = time.time() + print(f"[ParameterSynchronizer] sync_weights success. cost {end_time - start_time:.2f} seconds") + + # Async Update rollout version & validation + self.wait_last_update = self.rollouter.update_param_version.remote(version, validate, global_steps) + self.wait_last_resume = self.rollouter.resume.remote(self.wait_last_update) + + def wait_last_valid(self): + print("[ParameterSynchronizer] Waiting last sync and validate...") + start_time = time.time() + if self.wait_last_update: + ray.get(self.wait_last_update) + if self.wait_last_resume: + ray.get(self.wait_last_resume) + print(f"[ParameterSynchronizer] Wait last validate cost: {time.time() - start_time:.2f} seconds") + + def rollouter_save_checkpoint(self, local_global_step_folder: str): + """Trigger rollout to save checkpoint(dataloader)""" + print(f"[ParameterSynchronizer] Triggering checkpoint save at {local_global_step_folder} ...") + return ray.get(self.rollouter.save_checkpoint.remote(local_global_step_folder)) diff --git a/src/reasoning360/recipe/fully_async_policy/ray_trainer.py b/src/reasoning360/recipe/fully_async_policy/ray_trainer.py new file mode 100644 index 000000000..4028996ea --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/ray_trainer.py @@ -0,0 +1,550 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PPO Trainer with Ray-based single controller. +This trainer supports model-agonistic model initialization with huggingface +""" + +import uuid +from copy import deepcopy +from pprint import pprint + +import numpy as np +import ray +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from verl import DataProto +from verl.experimental.dataset.sampler import AbstractCurriculumSampler +from verl.single_controller.ray import RayClassWithInitArgs +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss +from reasoning360.trainer.ppo.metric_utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + compute_difficulty_histogram_metrics, +) +from reasoning360.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask +from reasoning360.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.trainer.ppo.utils import Role +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.debug import marked_timer +from verl.utils.metric import ( + reduce_metrics, +) +from verl.utils.rollout_skip import RolloutSkip + + +class FullyAsyncRayPPOTrainer(RayPPOTrainer): + def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self._init_resource_pools() + self._create_worker_classes() + self._init_worker_groups() + self._init_models() + self._init_async_rollout_manager() + + def _init_resource_pools(self): + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + def _create_worker_classes(self): + self._create_actor_rollout_classes() + self._create_critic_class() + self._create_reference_policy_class() + self._create_reward_model_class() + + def _create_actor_rollout_classes(self): + raise NotImplementedError + + def _create_critic_class(self): + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls + + def _create_reference_policy_class(self): + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + # profile_option=self.config.trainer.npu_profile.options, + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + def _create_reward_model_class(self): + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls + + def _init_worker_groups(self): + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. + all_wg = {} + wg_kwargs = {} # Setting up kwargs for RayWorkerGroup + if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: + wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name + + for resource_pool, class_dict in self.resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = self.ray_worker_group_cls( + resource_pool=resource_pool, + ray_cls_with_init=worker_dict_cls, + **wg_kwargs, + ) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + self.all_wg = all_wg + + def _init_models(self): + if self.use_critic: + self.critic_wg = self.all_wg[str(Role.Critic)] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = self.all_wg[str(Role.RewardModel)] + self.rm_wg.init_model() + + # we should create rollout at the end so that vllm can have a better estimation of kv cache memory + self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)] + self.actor_rollout_wg.init_model() + + def _init_async_rollout_manager(self): + pass + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from omegaconf import OmegaConf + + from verl.utils.tracking import Tracking + + logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.global_steps = 0 + + # load checkpoint before doing anything + self._load_checkpoint() + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + + batch, gen_batch = self._prepare_generate_batch(batch_dict) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + batch = self._post_generate_batch(batch, gen_batch_output, metrics) + batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw) + self._log_rollout(batch, reward_extra_infos_dict, timing_raw) + + last_val_metrics = self._validate_metrics(is_last_step, last_val_metrics, metrics, timing_raw) + self._check_save_checkpoint(is_last_step, timing_raw) + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + self._collect_metrics(batch, epoch, metrics, timing_raw) + self._post_batch_processing(batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + def _prepare_generate_batch(self, batch_dict): + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + return batch, gen_batch + + def _post_generate_batch(self, batch, gen_batch_output, metrics): + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. + if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + return batch + + def _process_batch_common(self, batch, metrics, timing_raw, local_trigger_step=None): + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + # NOTE: + # Do NOT pass `self.reward_fn` across processes when it may capture dynamically + # imported functions (e.g. from `custom_reward_function`), because Ray workers + # won't have the `custom_module` import available on unpickle. + # Instead, let the remote task load the reward manager locally from `config`. + future_reward = compute_reward_async.remote( + data=batch, + config=self.config, + tokenizer=self.tokenizer, + reward_fn=None, + ) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + with marked_timer("old_log_prob", timing_raw, color="blue"): + + def compute_old_log_prob(batch): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=actor_config.loss_agg_mode, + loss_scale_factor=actor_config.loss_scale_factor, + ) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + return batch + + async_training = self.config.get("async_training", None) + if async_training and async_training.use_rollout_log_probs: + # If local_triger_step == 1, load the training engine's parameters to the CPU + # and save a copy for subsequent MIS use. + # If local_trigger_step == 2, 3, ..., restore the parameters of version 1 to calculate the old_log_prob, + # then restore the parameters of the current version. + if local_trigger_step == 1: + self.actor_rollout_wg.save_model_to_cpu(1) + batch = compute_old_log_prob(batch) + elif local_trigger_step is not None: + self.actor_rollout_wg.save_model_to_cpu(local_trigger_step) + self.actor_rollout_wg.restore_model_from_cpu(1) + batch = compute_old_log_prob(batch) + self.actor_rollout_wg.restore_model_from_cpu(local_trigger_step) + self.actor_rollout_wg.clear_cpu_model(local_trigger_step) + else: + batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + + else: + batch = compute_old_log_prob(batch) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # Compute rollout correction weights centrally (once per batch) + # This corrects for off-policy issues (policy mismatch, model staleness, etc.) + # Also computes off-policy diagnostic metrics (KL, PPL, etc.) + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + + # compute advantages, executed on the driver process + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + return batch, reward_extra_infos_dict + + def _log_rollout(self, batch, reward_extra_infos_dict, timing_raw): + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch] + + if "request_id" in batch.non_tensor_batch: + reward_extra_infos_dict.setdefault( + "request_id", + batch.non_tensor_batch["request_id"].tolist(), + ) + + self._dump_generations( + inputs=inputs, + outputs=outputs, + gts=sample_gts, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=rollout_data_dir, + ) + + def _validate_metrics(self, is_last_step, last_val_metrics, metrics, timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + return last_val_metrics + + def _collect_metrics(self, batch, epoch, metrics, timing_raw): + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_difficulty_histogram_metrics(batch=batch, config=self.config)) + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + def _post_batch_processing(self, batch: DataProto): + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/recipe/prime/__init__.py b/src/reasoning360/recipe/fully_async_policy/vllm_rollout/__init__.py similarity index 91% rename from recipe/prime/__init__.py rename to src/reasoning360/recipe/fully_async_policy/vllm_rollout/__init__.py index 6b76ea65c..9cd3ed5b8 100644 --- a/recipe/prime/__init__.py +++ b/src/reasoning360/recipe/fully_async_policy/vllm_rollout/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 PRIME team and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/reasoning360/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py b/src/reasoning360/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py new file mode 100644 index 000000000..6a44fdc07 --- /dev/null +++ b/src/reasoning360/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py @@ -0,0 +1,149 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +from typing import Any, Optional, Sequence + +import ray +from ray.actor import ActorHandle +from vllm import SamplingParams +from vllm.inputs import TokensPrompt +from vllm.outputs import RequestOutput + +from verl.workers.config import HFModelConfig, RolloutConfig +from verl.workers.rollout.replica import RolloutMode +from verl.workers.rollout.vllm_rollout.vllm_async_server import ( + _qwen2_5_vl_dedup_image_tokens, + vLLMHttpServerBase, + vLLMReplica, +) + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) + + +@ray.remote(num_cpus=1) +class vLLMHttpServerForPartial(vLLMHttpServerBase): + def __init__( + self, + config: RolloutConfig, + model_config: HFModelConfig, + rollout_mode: RolloutMode, + workers: list[ActorHandle], + replica_rank: int, + node_rank: int, + gpus_per_node: int, + nnodes: int, + ): + super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes) + + # for cancel LLMServer + self.paused = False + self.lock = asyncio.Lock() + self.cancel_event: dict[str, asyncio.Event] = {} + self.req_output: dict[str, Optional[RequestOutput]] = {} + + async def _generate_step( + self, + prompt_ids: list[int], + sampling_params: dict[str, Any], + request_id: str, + image_data: Optional[list[Any]] = None, + ): + max_tokens = self.config.max_model_len - len(prompt_ids) + sampling_params["logprobs"] = 1 + sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0)) + sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params) + prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor) + prompt = TokensPrompt( + prompt_token_ids=prompt_ids, multi_modal_data={"image": image_data} if image_data else None + ) + generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id) + + # Get final response + async for output in generator: + self.req_output[request_id] = output + assert self.req_output[request_id] is not None + + async def generate_for_partial( + self, + prompt_ids: list[int], + sampling_params: dict[str, Any], + request_id: str, + image_data: Optional[list[Any]] = None, + ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]: + async with self.lock: + if self.paused: + # After cancel, all tasks will return directly and wait for the next submission + return [], [], True + self.req_output[request_id]: Optional[RequestOutput] = None + self.cancel_event[request_id] = asyncio.Event() + cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait()) + generation_handle = asyncio.create_task( + self._generate_step(prompt_ids, sampling_params, request_id, image_data) + ) + + done, pend = await asyncio.wait([generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED) + + for task in done: + await task + + for task in pend: + task.cancel() + + async with self.lock: + if self.req_output[request_id] is None: + return [], [], True + token_ids = self.req_output[request_id].outputs[0].token_ids + log_probs: list[float] = [] + for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs): + # In sampling_params, logprobs is set to 1, which should return 1, + # but in practice there are multiple. Take the log_prob corresponding to token_id + token_id = self.req_output[request_id].outputs[0].token_ids[i] + log_probs.append(x[token_id].logprob) + is_cancel = generation_handle not in done + self.cancel_event.pop(request_id, None) + self.req_output.pop(request_id, None) + return token_ids, log_probs, is_cancel + + async def cancel(self): + async with self.lock: + self.paused = True + for request_id in self.cancel_event: + self.cancel_event[request_id].set() + + async def resume(self): + async with self.lock: + self.paused = False + + +class FullyAsyncvLLMReplica(vLLMReplica): + def __init__( + self, + replica_rank: int, + config: RolloutConfig, + model_config: HFModelConfig, + gpus_per_node: int = 8, + is_reward_model: bool = False, + ): + super().__init__(replica_rank, config, model_config, gpus_per_node, is_reward_model) + self.server_class = vLLMHttpServerForPartial + + async def cancel(self): + """Cancel each rollout server.""" + await asyncio.gather(*[server.cancel.remote() for server in self.servers]) + + async def resume(self): + """Resume each rollout server.""" + await asyncio.gather(*[server.resume.remote() for server in self.servers]) diff --git a/src/reasoning360/trainer/__init__.py b/src/reasoning360/trainer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/verl/workers/reward_model/__init__.py b/src/reasoning360/trainer/config/__init__.py similarity index 79% rename from verl/workers/reward_model/__init__.py rename to src/reasoning360/trainer/config/__init__.py index db412bd24..402475c3f 100644 --- a/verl/workers/reward_model/__init__.py +++ b/src/reasoning360/trainer/config/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import BasePPORewardModel +from . import algorithm, config +from .algorithm import * # noqa: F401 +from .config import * # noqa: F401 -__all__ = ["BasePPORewardModel"] +__all__ = config.__all__ + algorithm.__all__ diff --git a/src/reasoning360/trainer/config/_generated_ppo_megatron_trainer.yaml b/src/reasoning360/trainer/config/_generated_ppo_megatron_trainer.yaml new file mode 100644 index 000000000..b40d462d4 --- /dev/null +++ b/src/reasoning360/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -0,0 +1,664 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job --config-name=ppo_megatron_trainer.yaml' to flatten the 'verl/trainer/config/ppo_megatron_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + optim: + _target_: verl.workers.config.McoreOptimizerConfig + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + optimizer: adam + lr_warmup_init: 0.0 + lr_decay_steps: null + lr_decay_style: constant + min_lr: 0.0 + weight_decay_incr_style: constant + lr_wsd_decay_style: exponential + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: false + override_optimizer_config: {} + megatron: + _target_: verl.workers.config.McoreEngineConfig + param_offload: false + grad_offload: false + optimizer_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: true + use_dist_checkpointing: false + dist_checkpointing_path: null + dist_checkpointing_prefix: '' + seed: 42 + override_ddp_config: {} + override_transformer_config: + recompute_granularity: null + recompute_modules: + - core_attn + recompute_method: null + recompute_num_layers: null + attention_backend: flash + override_mcore_model_config: {} + use_mbridge: false + vanilla_mbridge: true + use_remove_padding: true + forward_only: false + dtype: bfloat16 + _target_: verl.workers.config.McoreActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: megatron + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + freeze_vision_tower: false + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + loss_scale_factor: null + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level1 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + data_loader_seed: 42 + load_weight: true + ref: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: megatron + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level1 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + megatron: + _target_: verl.workers.config.McoreEngineConfig + param_offload: ${oc.select:actor_rollout_ref.actor.megatron.param_offload,False} + grad_offload: false + optimizer_offload: false + tensor_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.tensor_model_parallel_size,1} + expert_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_model_parallel_size,1} + expert_tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_tensor_parallel_size,null} + pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.pipeline_model_parallel_size,1} + virtual_pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size,null} + context_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.context_parallel_size,1} + sequence_parallel: true + use_distributed_optimizer: true + use_dist_checkpointing: false + dist_checkpointing_path: null + dist_checkpointing_prefix: '' + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_ddp_config: {} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + override_mcore_model_config: {} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + forward_only: true + dtype: bfloat16 + _target_: verl.workers.config.McoreActorConfig + load_weight: true + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + mode: async + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: false + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 2 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + enable_chunked_prefill: true + enable_prefix_caching: true + load_format: dummy + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + over_sample_rate: 0 + multi_stage_wake_up: false + engine_kwargs: + vllm: {} + sglang: {} + val_kwargs: + _target_: verl.workers.config.SamplingConfig + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + _target_: verl.workers.config.MultiTurnConfig + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + update_weights_bucket_megabytes: 512 + trace: + _target_: verl.workers.config.TraceConfig + backend: null + token2text: false + max_samples_per_step_per_worker: null + skip_rollout: false + skip_dump_dir: /tmp/rollout_dump + skip_tokenizer_init: true + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + model: + _target_: verl.workers.config.HFModelConfig + path: ~/models/deepseek-llm-7b-chat + hf_config_path: null + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null + external_lib: null + override_config: + model_config: {} + moe_config: + freeze_moe_router: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + lora: + type: lora + rank: 0 + alpha: 32 + dropout: 0.0 + target_modules: + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + exclude_modules: [] + dropout_position: pre + lora_A_init_method: xavier + lora_B_init_method: zero + a2a_experimental: false + dtype: null + adapter_path: null + freeze_vision_model: true + freeze_vision_projection: true + freeze_language_model: true + hybrid_engine: true + nccl_timeout: 600 +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, + null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null + dataloader_num_workers: 8 + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} +reward_manager: + _target_: verl.trainer.config.config.RewardManagerConfig + source: register + name: ${oc.select:reward_model.reward_manager,naive} + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager +critic: + optim: + _target_: verl.workers.config.McoreOptimizerConfig + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + optimizer: adam + lr_warmup_init: 0.0 + lr_decay_steps: null + lr_decay_style: constant + min_lr: 0.0 + weight_decay_incr_style: constant + lr_wsd_decay_style: exponential + lr_wsd_decay_steps: null + use_checkpoint_opt_param_scheduler: false + override_optimizer_config: {} + megatron: + _target_: verl.workers.config.McoreEngineConfig + param_offload: false + grad_offload: false + optimizer_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: true + use_dist_checkpointing: false + dist_checkpointing_path: null + dist_checkpointing_prefix: '' + seed: 42 + override_ddp_config: {} + override_transformer_config: + recompute_granularity: null + recompute_modules: + - core_attn + recompute_method: null + recompute_num_layers: null + attention_backend: flash + override_mcore_model_config: {} + use_mbridge: false + vanilla_mbridge: true + use_remove_padding: true + forward_only: false + dtype: bfloat16 + _target_: verl.workers.config.McoreCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: megatron + enable: null + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: + model_config: {} + moe_config: + freeze_moe_router: false + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.trainer.config.BaseModelConfig + lora: + type: lora + rank: 0 + alpha: 32 + dropout: 0.0 + target_modules: + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + exclude_modules: [] + dropout_position: pre + lora_A_init_method: xavier + lora_B_init_method: zero + a2a_experimental: false + dtype: null + adapter_path: null + freeze_vision_model: true + freeze_vision_projection: true + freeze_language_model: true + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level1 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + nccl_timeout: 600 + load_weight: true + data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} +reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 0 + nnodes: 0 + strategy: megatron + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + override_config: {} + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + nccl_timeout: 600 + megatron: + _target_: verl.workers.config.MegatronEngineConfig + param_offload: false + tensor_model_parallel_size: 1 + expert_model_parallel_size: 1 + expert_tensor_parallel_size: null + pipeline_model_parallel_size: 1 + virtual_pipeline_model_parallel_size: null + context_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: false + use_dist_checkpointing: false + dist_checkpointing_path: null + dist_checkpointing_prefix: '' + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + dtype: bfloat16 + load_weight: true +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + rollout_rs_threshold_lower: null + rollout_token_veto_threshold: null + bypass_mode: false + use_policy_gradient: false + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 +custom_reward_function: + path: null + name: compute_score +trainer: + balance_batch: true + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + del_local_ckpt_after_load: false + val_before_train: true + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + rollout_data_dir: null + use_legacy_worker_impl: auto +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null + steps: null + profile_continuous_steps: false + save_path: outputs/profile + global_tool_config: + nsys: + discrete: false + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + torch_memory: + trace_alloc_max_entries: 100000 + stack_depth: 32 + context: all + stacks: all + kw_args: {} +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/src/reasoning360/trainer/config/_generated_ppo_trainer.yaml b/src/reasoning360/trainer/config/_generated_ppo_trainer.yaml new file mode 100644 index 000000000..d37965dbc --- /dev/null +++ b/src/reasoning360/trainer/config/_generated_ppo_trainer.yaml @@ -0,0 +1,591 @@ +# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' +# in which it invokes 'python3 scripts/print_cfg.py --cfg job ' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file. +# Do not modify this file directly. +# The file is usually only for reference and never used. + +actor_rollout_ref: + actor: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + warmup_style: null + override_optimizer_config: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + _target_: verl.workers.config.FSDPActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + freeze_vision_tower: false + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + loss_scale_factor: null + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: false + use_torch_compile: true + kl_loss_coef: 0.001 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level1 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + ref: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level1 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: true + strategy: fsdp + dtype: bfloat16 + _target_: verl.workers.config.FSDPActorConfig + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + mode: async + temperature: 1.0 + top_k: -1 + top_p: 1 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: false + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 2 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + enable_chunked_prefill: true + enable_prefix_caching: true + load_format: dummy + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: null + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 1 + over_sample_rate: 0 + multi_stage_wake_up: false + engine_kwargs: + vllm: {} + sglang: {} + val_kwargs: + _target_: verl.workers.config.SamplingConfig + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + _target_: verl.workers.config.MultiTurnConfig + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + update_weights_bucket_megabytes: 512 + trace: + _target_: verl.workers.config.TraceConfig + backend: null + token2text: false + max_samples_per_step_per_worker: null + skip_rollout: false + skip_dump_dir: /tmp/rollout_dump + skip_tokenizer_init: true + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + layered_summon: false + model: + _target_: verl.workers.config.HFModelConfig + path: ~/models/deepseek-llm-7b-chat + hf_config_path: null + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + hybrid_engine: true + nccl_timeout: 600 +data: + tokenizer: null + use_shm: false + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, + null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null + dataloader_num_workers: 8 + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: false + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} +reward_manager: + _target_: verl.trainer.config.config.RewardManagerConfig + source: register + name: ${oc.select:reward_model.reward_manager,naive} + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager +critic: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + warmup_style: null + override_optimizer_config: null + model: + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: {} + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.workers.config.FSDPCriticModelCfg + use_shm: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + _target_: verl.workers.config.FSDPCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + enable: null + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level1 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 +reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 0 + nnodes: 0 + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + override_config: {} + use_shm: false + use_remove_padding: false + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + ulysses_sequence_parallel_size: 1 +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + rollout_rs_threshold_lower: null + rollout_token_veto_threshold: null + bypass_mode: false + use_policy_gradient: false + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 +custom_reward_function: + path: null + name: compute_score +trainer: + balance_batch: true + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: + - console + - wandb + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: true + val_only: false + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + use_legacy_worker_impl: auto +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null + steps: null + profile_continuous_steps: false + save_path: outputs/profile + global_tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: false + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + torch_memory: + trace_alloc_max_entries: 100000 + stack_depth: 32 + context: all + stacks: all + kw_args: {} +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/src/reasoning360/trainer/config/actor/actor.yaml b/src/reasoning360/trainer/config/actor/actor.yaml new file mode 100644 index 000000000..f5f1d15ee --- /dev/null +++ b/src/reasoning360/trainer/config/actor/actor.yaml @@ -0,0 +1,242 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# Target class for this configuration +_target_: verl.workers.config.ActorConfig + +# Number of rollouts per update (mirrors actor rollout_n) +rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + +# the abstract actor configs +# fsdp, fsdp2 or megatron. must be set. +strategy: ??? + +# Split each sample into sub-batches of this size for PPO +ppo_mini_batch_size: 256 + +# [Deprecated] Global micro batch size +ppo_micro_batch_size: null + +# Local per-GPU micro batch size +ppo_micro_batch_size_per_gpu: null + +# Whether to automatically adjust batch size at runtime +# oc.select: the default val for ref.log_prob_use_dynamic_bsz +use_dynamic_bsz: false + +# Max tokens per GPU in one PPO batch; affects gradient accumulation +# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length} +# oc.select: the default val for ref.log_prob_max_token_len_per_gpu +ppo_max_token_len_per_gpu: 16384 + +# PPO clip ratio +clip_ratio: 0.2 + +# Lower bound for asymmetric clipping (used in dual-clip PPO) +clip_ratio_low: 0.2 + +# Upper bound for asymmetric clipping (used in dual-clip PPO) +clip_ratio_high: 0.2 + +# Whether to freeze vision model, if set true, it will be freeze vision model +freeze_vision_tower: false + +# policy loss config +policy_loss: + + # # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.PolicyLossConfig + + # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 + loss_mode: "vanilla" + + # Ratio of tokens to be clipped for clip-cov loss + clip_cov_ratio: 0.0002 + + # Lower bound for clip-cov loss + clip_cov_lb: 1.0 + + # Upper bound for clip-cov loss + clip_cov_ub: 5.0 + + # Ratio of tokens to be applied kl penalty for kl-cov loss + kl_cov_ratio: 0.0002 + + # KL divergence penalty coefficient + ppo_kl_coef: 0.1 + +# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C +clip_ratio_c: 3.0 + +# Loss aggregation mode: "token-mean", "seq-mean-token-sum", "seq-mean-token-mean", or "seq-mean-token-sum-norm" +loss_agg_mode: token-mean + +# Scale factor for "seq-mean-token-sum-norm" loss aggregation mode. +# If null, uses response_length. Set to a constant to ensure consistent normalization. +loss_scale_factor: null + +# Entropy regularization coefficient in PPO loss +entropy_coeff: 0 + +# When true, the actor forward will request entropy from the model +calculate_entropy: false + +# Whether to use KL loss instead of KL reward penalty. True for GRPO +use_kl_loss: false + +# Whether to use torch.compile() +# oc.select: the default val for ref.use_torch_compile +use_torch_compile: true + +# KL loss coefficient when use_kl_loss is enabled. For GRPO +kl_loss_coef: 0.001 + +# Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full" +kl_loss_type: low_var_kl + +# Number of PPO epochs per batch +ppo_epochs: 1 + +# Shuffle training data across PPO epochs +shuffle: false + +# checkpoint configs +checkpoint: + + # Target dataclass for this configuration + _target_: verl.trainer.config.CheckpointConfig + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # For more flexibility, you can specify the contents to load from the checkpoint. + # .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg + load_contents: ${.save_contents} + + # Whether to save checkpoints asynchronously. Only effective for Megatron as of now. + async_save: False + +# optimizer configs +optim: + + # Learning rate + lr: 1e-6 + + # Warmup steps ratio (used if lr_warmup_steps is 0 or negative) + lr_warmup_steps_ratio: 0.0 + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 + + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: -1 + + +# Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) +use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + +# profile the actor model in `update_policy` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on Actor + enable: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config which only related to the role + tool_config: + + # nsys tool config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + + # npu config + npu: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: [] + + # Collection level, optional values: level_none, level0, level1, level2. + level: "level1" + + # Whether to automatically parse the data. + analysis: True + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # torch profiler config + torch: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # start profile mini-batch in training + # NOTICE: different with global steps config which refers to iteration + # This field only related with mini-batch + step_start: 0 + + # stop profile mini-batch in training + step_end: null + + # torch memory profiler config + torch_memory: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + # Maximum number of memory allocation entries to track + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + + # Stack trace depth for memory allocations + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + +# Router replay configuration for MoE models +router_replay: + + # Target dataclass for this configuration + _target_: verl.workers.config.RouterReplayConfig + + # Router replay mode: disabled, R2, R3 + # - R2: Use R2 routing strategy (record mode) + # - R3: Use R3 routing strategy (record mode) + mode: disabled + + # File path to save recorded routing decisions + # Required when mode is 'record', 'R2', or 'R3' + record_file: null + + # File path to load recorded routing decisions for replay + # Required when mode is 'replay' + replay_file: null + diff --git a/src/reasoning360/trainer/config/actor/dp_actor.yaml b/src/reasoning360/trainer/config/actor/dp_actor.yaml new file mode 100644 index 000000000..742ea5488 --- /dev/null +++ b/src/reasoning360/trainer/config/actor/dp_actor.yaml @@ -0,0 +1,43 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# defaults specify the default config from each component +defaults: + + # fsdp optimizer config + - ../optim@optim: fsdp + + # fsdp engine config + - ../engine@fsdp_config: fsdp + + # dp actor config, inheriting from trainer/config/actor/actor.yaml + - actor + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# Target class for this configuration +_target_: verl.workers.config.FSDPActorConfig + +# TODO(haibin.lin): switch to fsdp2 +strategy: fsdp + +# Gradient clipping for actor updates, specific to the strategy. +grad_clip: 1.0 + +# Sequence parallelism size for Ulysses-style model parallelism +# oc.select: the default val for ref.ulysses_sequence_parallel_size +# [DEPRECATED] use fsdp_config.ulysses_sequence_parallel_size instead +ulysses_sequence_parallel_size: 1 + +# calculate entropy with chunking to reduce memory peak +entropy_from_logits_with_chunking: False + +# recompute entropy +entropy_checkpointing: False + +# Whether to remove padding tokens in inputs during training +use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} \ No newline at end of file diff --git a/src/reasoning360/trainer/config/actor/megatron_actor.yaml b/src/reasoning360/trainer/config/actor/megatron_actor.yaml new file mode 100644 index 000000000..a632fe438 --- /dev/null +++ b/src/reasoning360/trainer/config/actor/megatron_actor.yaml @@ -0,0 +1,20 @@ +# megatron actor config, inheriting from trainer/config/actor/actor.yaml +defaults: + # megatron optimizer config + - ../optim@optim: megatron + + # megatron engine config + - ../engine@megatron: megatron + + - actor + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +_target_: verl.workers.config.McoreActorConfig + +strategy: megatron + +data_loader_seed: 42 + +load_weight: True diff --git a/src/reasoning360/trainer/config/algorithm.py b/src/reasoning360/trainer/config/algorithm.py new file mode 100644 index 000000000..a7c86da02 --- /dev/null +++ b/src/reasoning360/trainer/config/algorithm.py @@ -0,0 +1,463 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from verl.base_config import BaseConfig + +__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig", "RolloutCorrectionConfig"] + + +@dataclass +class KLControlConfig(BaseConfig): + """Configuration for KL control. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + type (str): Type of KL control. Can be "fixed" or "adaptive". + kl_coef (float): Initial coefficient for KL penalty. + horizon (int): Horizon value for adaptive controller. + target_kl (float): Target KL divergence for adaptive controller. + """ + + type: str = "fixed" + kl_coef: float = 0.001 + horizon: int = 10000 + target_kl: float = 0.1 + + +@dataclass +class FilterGroupsConfig(BaseConfig): + """Configuration for filter groups (used in DAPO and Entropy). + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + enable (bool): Whether to enable filter groups. + metric (Optional[str]): Metric to use for filtering: "acc", "score", "seq_reward", "seq_final_reward", etc. + max_num_gen_batches (int): Non-positive values mean no upper limit. + """ + + enable: bool = False + metric: Optional[str] = None + max_num_gen_batches: int = 0 + + +@dataclass +class RolloutCorrectionConfig(BaseConfig): + """Configuration for Rollout Correction (addresses off-policy issues in RL training). + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Rollout Correction handles off-policiness from multiple sources: + 1. Policy mismatch: Rollout policy (e.g., vLLM BF16) vs Training policy (e.g., FSDP FP32) + 2. Model update staleness: Rollout data collected from older policy checkpoints + 3. General off-policy scenarios: Any distribution shift between data collection and training + + For more details, see: + "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch" + https://richardli.xyz/rl-collapse + + This typed config replaces the old dict-based approach and provides: + - Type safety and validation + - Clear documentation of all parameters + - Named factory methods for common presets (TIS, MIS, etc.) + - Sensible defaults + + Args: + rollout_is (Optional[str]): IS weight aggregation level. + - None: No IS weights (metrics only) + - "token": Per-token IS weights (low variance, biased) + - "sequence": Per-sequence IS weights (unbiased, high variance) + Default: "sequence" + + rollout_is_threshold (float): Upper threshold for IS weight truncation/rejection. + Typical range: 1.5-5.0 for token level, 2.0-10.0 for sequence level. + Default: 2.0 + + rollout_rs (Optional[str]): Rejection sampling aggregation level. + - None: No rejection sampling + - "token": Reject individual tokens with outlier ratios + - "sequence": Reject entire sequences with outlier ratios + - "geometric": Geometric mean aggregation (threshold: 1.0002-1.001) + Default: None (use IS weights without rejection) + + rollout_rs_threshold (Optional[float]): Upper threshold for rejection sampling. + - If None and rollout_rs is enabled, uses rollout_is_threshold + - Tokens/sequences with ratio > threshold are masked out + Default: None (uses rollout_is_threshold when rollout_rs is enabled) + + rollout_rs_threshold_lower (Optional[float]): Lower threshold for rejection sampling. + - If None, uses reciprocal of upper threshold (1/upper) + - Tokens/sequences with ratio < threshold are masked out + Default: None (auto-computed as reciprocal) + + rollout_token_veto_threshold (Optional[float]): Per-token veto for catastrophic outliers. + - Checks unclamped per-token ratios before safety bounds + - If ANY token has ratio < threshold, entire sequence is rejected + - Independent of rollout_is and rollout_rs settings + - Typical values: 1e-4 to 1e-6 when enabled + Default: None (disabled) + + bypass_mode (bool): Operating mode - bypass or decoupled. + - True: Bypass mode - reuse rollout_log_prob as old_log_prob (2 policies) + - False: Decoupled mode - compute old_log_prob separately (3 policies) + Default: False (decoupled mode) + + use_policy_gradient (bool): Loss function type. + - Requires bypass_mode=True + - True: Policy gradient loss (no PPO clipping) + - False: PPO loss (with clipping) + Default: False (PPO loss) + + rollout_is_batch_normalize (bool): Apply batch normalization to IS weights. + - True: Normalize IS weights to have mean=1.0 within each batch + - False: Use raw (truncated) IS weights (standard) + - Reduces variance by ensuring average weight is 1.0 per batch + - Only affects IS weight values, not rejection sampling + Default: False (no batch normalization) + + Example: + # Create with defaults + config = RolloutCorrectionConfig() + + # Decoupled PPO mode presets (3 policies: π_rollout, π_old, π_θ) + # IS weights correct for gap between π_old and π_rollout + config = RolloutCorrectionConfig.decoupled_token_is() # Token-TIS + config = RolloutCorrectionConfig.decoupled_seq_is() # Seq-TIS + config = RolloutCorrectionConfig.decoupled_seq_is_rs() # Seq-MIS + config = RolloutCorrectionConfig.decoupled_geo_rs() # Geo-RS + config = RolloutCorrectionConfig.geo_rs_seq_tis() # Geo-RS-Seq-TIS + + # Bypass PPO mode (2 policies: π_rollout = π_old, π_θ) + # No IS correction needed since π_old = π_rollout + config = RolloutCorrectionConfig.ppo_is_bypass() # PPO with rollout as anchor + + # Bypass PG mode presets (2 policies, no PPO clipping) + # IS weights computed on-the-fly as π_θ / π_rollout + config = RolloutCorrectionConfig.pg_is() # Seq-TIS + PG + config = RolloutCorrectionConfig.pg_rs() # Geo-RS + PG + config = RolloutCorrectionConfig.pg_geo_rs_seq_tis() # Geo-RS-Seq-TIS + PG + + Reference: + Liu, Li, Fu, Wang, Liu, Shen (2025) + "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch" + https://richardli.xyz/rl-collapse + """ + + rollout_is: Optional[str] = "sequence" + rollout_is_threshold: float = 2.0 + rollout_rs: Optional[str] = None + rollout_rs_threshold: Optional[float] = None + rollout_rs_threshold_lower: Optional[float] = None + rollout_token_veto_threshold: Optional[float] = None + bypass_mode: bool = False + use_policy_gradient: bool = False + rollout_is_batch_normalize: bool = False + + @classmethod + def decoupled_token_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig": + """Decoupled Mode with Token-level Importance Sampling. + + IS weight correction at token level in decoupled mode (three policies). + + Args: + threshold (float): Upper threshold for IS weights. Default: 2.0 + + Returns: + RolloutCorrectionConfig configured for decoupled mode with token-level IS + """ + return cls(rollout_is="token", rollout_is_threshold=threshold, rollout_rs=None) + + @classmethod + def decoupled_seq_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig": + """Decoupled Mode with Sequence-level Importance Sampling. + + IS weight correction at sequence level in decoupled mode (three policies). + + Args: + threshold (float): Upper threshold for IS weights. Default: 2.0 + + Returns: + RolloutCorrectionConfig configured for decoupled mode with sequence-level IS + """ + return cls(rollout_is="sequence", rollout_is_threshold=threshold, rollout_rs=None) + + @classmethod + def decoupled_seq_is_rs( + cls, + is_threshold: float = 2.0, + rs_threshold: float = 2.0, + rs_threshold_lower: Optional[float] = None, + ) -> "RolloutCorrectionConfig": + """Decoupled Mode with Sequence-level IS + Rejection Sampling. + + Sequence-level IS with sequence-level rejection sampling in decoupled mode. + Rejects entire sequences based on sequence-level IS weight. + + Args: + is_threshold (float): Upper threshold for IS weights. Default: 2.0 + rs_threshold (float): Upper threshold for rejection sampling. Default: 2.0 + rs_threshold_lower (Optional[float]): Lower threshold for rejection sampling. + If None, auto-computed as reciprocal of rs_threshold. Default: None + + Returns: + RolloutCorrectionConfig configured for decoupled mode with sequence IS + RS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=is_threshold, + rollout_rs="sequence", + rollout_rs_threshold=rs_threshold, + rollout_rs_threshold_lower=rs_threshold_lower, + ) + + @classmethod + def decoupled_geo_rs( + cls, + rs_threshold: float = 1.001, + rs_threshold_lower: Optional[float] = None, + veto_threshold: float = 1e-4, + ) -> "RolloutCorrectionConfig": + """Decoupled Mode with Geometric Rejection Sampling. + + Uses geometric mean for rejection sampling at sequence level in decoupled mode, + with additional veto mechanism. Geometric mean is extremely sensitive to outliers, + requiring very tight thresholds close to 1.0. + + Args: + rs_threshold (float): Geometric RS threshold (upper). Default: 1.001 (±0.1%) + rs_threshold_lower (Optional[float]): Geometric RS threshold (lower). + If None, auto-computed as reciprocal of rs_threshold. Default: None + veto_threshold (float): Per-token veto threshold. Default: 1e-4 + + Returns: + RolloutCorrectionConfig configured for decoupled mode with geometric RS + veto + """ + return cls( + rollout_is=None, + rollout_rs="geometric", + rollout_rs_threshold=rs_threshold, + rollout_rs_threshold_lower=rs_threshold_lower, + rollout_token_veto_threshold=veto_threshold, + ) + + @classmethod + def ppo_is_bypass(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig": + """PPO with IS Correction in Bypass Mode. + + Skips old_log_prob computation by reusing rollout_log_prob. + PPO clips against rollout policy instead of true old policy. + + Args: + threshold (float): Upper threshold for IS weights. Default: 2.0 + + Returns: + RolloutCorrectionConfig configured for PPO_IS bypass mode + """ + return cls( + rollout_is="token", + rollout_is_threshold=threshold, + rollout_rs=None, + bypass_mode=True, + use_policy_gradient=False, + ) + + @classmethod + def pg_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig": + """Policy Gradient with IS Correction. + + Uses policy gradient loss with explicit IS correction. + No PPO clipping. + + Args: + threshold (float): Upper threshold for IS weights. Default: 2.0 + + Returns: + RolloutCorrectionConfig configured for PG with IS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=threshold, + rollout_rs=None, + bypass_mode=True, + use_policy_gradient=True, + ) + + @classmethod + def pg_rs( + cls, + rs_threshold: float = 1.001, + rs_threshold_lower: Optional[float] = None, + veto_threshold: float = 1e-4, + ) -> "RolloutCorrectionConfig": + """Policy Gradient with Rejection Sampling (Geo-RS). + + Policy gradient with geometric rejection sampling (no IS weights) in bypass mode. + Skips old_log_prob computation for faster execution. + + Solves the "Length Trap" problem where standard IS estimators penalize long sequences. + Suitable for reasoning models (CoT) and agents with long action sequences. + + Args: + rs_threshold (float): Geometric RS threshold (upper). Default: 1.001 (±0.1%) + rs_threshold_lower (Optional[float]): Geometric RS threshold (lower). + If None, auto-computed as reciprocal of rs_threshold. Default: None + veto_threshold (float): Per-token veto threshold. Default: 1e-4 + + Returns: + RolloutCorrectionConfig configured for PG with Geo-RS + """ + return cls( + rollout_is=None, + rollout_rs="geometric", + rollout_rs_threshold=rs_threshold, + rollout_rs_threshold_lower=rs_threshold_lower, + rollout_token_veto_threshold=veto_threshold, + bypass_mode=True, + use_policy_gradient=True, + ) + + @classmethod + def geo_rs_seq_tis( + cls, + is_threshold: float = 2.0, + rs_threshold: float = 1.001, + rs_threshold_lower: Optional[float] = None, + veto_threshold: Optional[float] = 1e-4, + ) -> "RolloutCorrectionConfig": + """Geometric RS with Sequence-level Truncated IS (Geo-RS-Seq-TIS). + + Combines the Geometric Filter (length-invariant validity check) with + Clipped Sequence Weight (debiasing). + + Suitable for reasoning models (CoT, o1-style) and agents that need to + think for many steps without collapsing. + + Args: + is_threshold (float): Upper threshold for sequence IS weights. Default: 2.0 + rs_threshold (float): Geometric RS threshold (upper). Default: 1.001 (±0.1%) + rs_threshold_lower (Optional[float]): Geometric RS threshold (lower). + If None, auto-computed as reciprocal of rs_threshold. Default: None + veto_threshold (Optional[float]): Per-token veto threshold. Default: 1e-4 + + Returns: + RolloutCorrectionConfig configured for Geo-RS-Seq-TIS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=is_threshold, + rollout_rs="geometric", + rollout_rs_threshold=rs_threshold, + rollout_rs_threshold_lower=rs_threshold_lower, + rollout_token_veto_threshold=veto_threshold, + ) + + @classmethod + def pg_geo_rs_seq_tis( + cls, + is_threshold: float = 2.0, + rs_threshold: float = 1.001, + rs_threshold_lower: Optional[float] = None, + veto_threshold: Optional[float] = 1e-4, + ) -> "RolloutCorrectionConfig": + """Policy Gradient with Geo-RS-Seq-TIS (Bypass mode). + + Combines geometric rejection with sequence-level IS + in bypass mode with policy gradient loss (no PPO clipping). + + Suitable for reasoning models (CoT, o1-style) and agents when you want + bypass mode efficiency. + + Args: + is_threshold (float): Upper threshold for sequence IS weights. Default: 2.0 + rs_threshold (float): Geometric RS threshold (upper). Default: 1.001 (±0.1%) + rs_threshold_lower (Optional[float]): Geometric RS threshold (lower). + If None, auto-computed as reciprocal of rs_threshold. Default: None + veto_threshold (Optional[float]): Per-token veto threshold. Default: 1e-4 + + Returns: + RolloutCorrectionConfig configured for PG with Geo-RS-Seq-TIS + """ + return cls( + rollout_is="sequence", + rollout_is_threshold=is_threshold, + rollout_rs="geometric", + rollout_rs_threshold=rs_threshold, + rollout_rs_threshold_lower=rs_threshold_lower, + rollout_token_veto_threshold=veto_threshold, + bypass_mode=True, + use_policy_gradient=True, + ) + + @classmethod + def disabled(cls) -> "RolloutCorrectionConfig": + """Disabled - Metrics Only Mode. + + Computes and logs off-policy metrics without applying correction. + + Returns: + RolloutCorrectionConfig with all correction disabled + """ + return cls(rollout_is=None, rollout_rs=None) + + +@dataclass +class AlgoConfig(BaseConfig): + """Configuration for the algorithm. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + gamma (float): Discount factor for future rewards. + lam (float): Trade-off between bias and variance in the GAE estimator. + adv_estimator (str): Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. + norm_adv_by_std_in_grpo (bool): Whether to normalize advantages by std (specific to GRPO). + use_kl_in_reward (bool): Whether to enable in-reward KL penalty. + kl_penalty (str): How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full". + kl_ctrl (KLControlConfig): KL control configuration. + use_pf_ppo (bool): Whether to enable preference feedback PPO. + pf_ppo (dict[str, Any]): Preference feedback PPO settings. + filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy + rollout_correction (Optional[RolloutCorrectionConfig]): Rollout Correction configuration. + Addresses off-policy issues from policy mismatch, model staleness, and general distribution shifts. + + Set to None to disable entirely. Use factory methods for common presets: + - RolloutCorrectionConfig.decoupled_token_is() - Decoupled mode with token-level IS + - RolloutCorrectionConfig.decoupled_seq_is() - Decoupled mode with sequence-level IS + - RolloutCorrectionConfig.decoupled_seq_is_rs() - Decoupled mode with sequence IS + RS + - RolloutCorrectionConfig.decoupled_geo_rs() - Decoupled mode with geometric RS + veto + - RolloutCorrectionConfig.ppo_is_bypass() - Bypass mode (skips old_log_prob) + - RolloutCorrectionConfig.pg_is() - Policy gradient with IS + - RolloutCorrectionConfig.pg_rs() - Policy gradient with RS + + For backward compatibility, you can still pass a dict, which will be converted to + RolloutCorrectionConfig automatically. + """ + + gamma: float = 1.0 + lam: float = 1.0 + adv_estimator: str = "gae" + norm_adv_by_std_in_grpo: bool = True + use_kl_in_reward: bool = False + kl_penalty: str = "kl" + kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig) + use_pf_ppo: bool = False + pf_ppo: dict[str, Any] = field(default_factory=dict) + filter_groups: Optional[FilterGroupsConfig] = None + # Rollout Correction: corrects off-policy issues (policy mismatch, model staleness, distribution shifts) + # Set to None to disable, use RolloutCorrectionConfig presets (e.g., .tis(), .mis()), or pass dict + rollout_correction: Optional[RolloutCorrectionConfig] = None diff --git a/src/reasoning360/trainer/config/algorithm/rollout_correction.yaml b/src/reasoning360/trainer/config/algorithm/rollout_correction.yaml new file mode 100644 index 000000000..7c958c5ee --- /dev/null +++ b/src/reasoning360/trainer/config/algorithm/rollout_correction.yaml @@ -0,0 +1,30 @@ +# Rollout Correction: corrects off-policy distribution shifts +# See documentation: docs/algo/rollout_corr.md +# Use presets: RolloutCorrectionConfig.decoupled_seq_is(), .pg_is(), etc. + +# IS aggregation level: null (disabled), "token" (per-token), "sequence" (per-sequence) +rollout_is: null + +# Upper threshold for IS weight truncation (typical: 2.0-5.0) +rollout_is_threshold: 2.0 + +# RS aggregation level: null (disabled), "token", "sequence", "geometric" +rollout_rs: null + +# Upper threshold for rejection sampling (null = use rollout_is_threshold) +rollout_rs_threshold: null + +# Lower threshold for rejection sampling (null = auto-compute as 1/upper) +rollout_rs_threshold_lower: null + +# Per-token veto threshold for catastrophic outliers (null = disabled) +rollout_token_veto_threshold: null + +# Operating mode: false = Decoupled (3 policies), true = Bypass (2 policies) +bypass_mode: false + +# Loss function: false = PPO with clipping, true = Policy gradient (no clipping) +use_policy_gradient: false + +# Batch normalize IS weights: false = raw weights, true = normalize to mean=1.0 +rollout_is_batch_normalize: false diff --git a/src/reasoning360/trainer/config/config.py b/src/reasoning360/trainer/config/config.py new file mode 100644 index 000000000..bd323d09d --- /dev/null +++ b/src/reasoning360/trainer/config/config.py @@ -0,0 +1,129 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from verl.base_config import BaseConfig + +__all__ = ["CheckpointConfig", "ProfileConfig", "BaseModelConfig"] + + +@dataclass +class CheckpointConfig(BaseConfig): + """Configuration for model checkpointing. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + save_contents (list[str]): What to include in saved checkpoints. + Options: 'model', 'optimizer', 'extra', 'hf_model'. + load_contents (list[str]): Contents to load from checkpoint. Defaults to same as save_contents. + async_save (bool): Whether to save checkpoints asynchronously. Only implemented for Megatron as of now. + """ + + save_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"]) + load_contents: list[str] = field(default_factory=lambda: ["model", "optimizer", "extra"]) + async_save: bool = False + + +@dataclass +class ProfileConfig(BaseConfig): + """Configuration for profiling. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + profile_ranks (Optional[list[int]]): List of ranks to profile. None means all ranks. + step_start (int): Starting step for profiling. + step_end (int): Ending step for profiling. + save_path (Optional[str]): Path to save profiling results. + """ + + profile_ranks: Optional[list[int]] = None + step_start: int = -1 + step_end: int = -1 + save_path: Optional[str] = None + + +@dataclass +class BaseModelConfig(BaseConfig): + """Base configuration for a model. + Contains core settings for loading and initializing a pretrained model checkpoint. + + Args: + path (str): Path to pretrained model weights. + tokenizer_path (Optional[str]): Tokenizer path (defaults to actor's model path if not set). + override_config (dict): Hugging Face config override. + external_lib (Optional[str]): External model implementation (optional). + trust_remote_code (bool): Whether to trust remote code from Hugging Face models. + lora (dict[str, Any]): LoRA configuration dictionary. + """ + + path: str = "~/models/deepseek-llm-7b-chat" + tokenizer_path: Optional[str] = None + override_config: dict[str, Any] = field(default_factory=dict) + external_lib: Optional[str] = None + trust_remote_code: bool = False + lora: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ModuleConfig(BaseConfig): + """Configuration for external Python module, which can be loaded, executed (and optionally, ``import``ed). + + Args: + path (str, optional): Path to the module file to load and execute. + name (str, optional): Name of the module to ``import``. Format: ``"import.path.to.module"``. + If ``None``, the module will be loaded with a hased name and + will not be added to ``sys.modules``, thus can not be ``import``ed as ``name``. + """ + + path: Optional[str] = None + name: Optional[str] = None + + +@dataclass +class RewardManagerConfig(BaseConfig): + """Configuration for reward manager. + + A reward manager defines the mechanism of computing rule-based reward and handling different reward sources. + + Args: + source (str): Source of the reward manager. Options: ``"register"``, ``"importlib"``. Default: ``"register"``. + name (str, optional): + - When ``source`` is ``"register"``, the name is used in `get_reward_manager_cls(name)``. + See ``verl/experimental/reward/reward_manager.py`` for options. Default: ``"naive"``. + - When ``source`` is ``"importlib"``, the name is used in ``getattr(module, name)``, + e.g., ``"DAPORewardManager"``. + module (ModuleConfig, optional): Optional configuration for the external module defining the reward manager, + """ + + source: str = "register" + name: str = "naive" + module: Optional[ModuleConfig] = field(default_factory=ModuleConfig) + + def __post_init__(self): + super().__post_init__() + if self.source == "register": + from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY + + assert self.name in REWARD_MANAGER_REGISTRY, ( + f"Reward manager is not registered: {self.name=} ,{REWARD_MANAGER_REGISTRY.keys()=}" + ) + elif self.source == "importlib": + # NOTE: The existence is not checked since it depends on which machine the config is initialized on. + assert self.module is not None and self.module.path is not None, ( + "When source is importlib, module.path should be set." + ) diff --git a/src/reasoning360/trainer/config/critic/critic.yaml b/src/reasoning360/trainer/config/critic/critic.yaml new file mode 100644 index 000000000..f201a34b4 --- /dev/null +++ b/src/reasoning360/trainer/config/critic/critic.yaml @@ -0,0 +1,176 @@ +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.CriticConfig + +# Number of rollouts per update (mirrors actor rollout_n) +rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + +# fsdp or fsdp2 strategy used for critic model training +strategy: ??? + +# whether to enable the critic worker. +# by default it is only enabled if advantage estimator is gae +# set it to True manually if you always want to enable critic worker +enable: null + +# optimizer configs +optim: + + # Learning rate + lr: 1e-5 + + # Warmup steps ratio; total steps will be injected at runtime + lr_warmup_steps_ratio: 0.0 + + # Total training steps (must be overridden at runtime) + total_training_steps: -1 + + # Weight decay + weight_decay: 0.01 + + # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. + lr_warmup_steps: -1 + + +# model config for the critic +model: + + # Path to pretrained model weights + path: ~/models/deepseek-llm-7b-chat + + # Tokenizer path (defaults to actor's model path) + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + + # Hugging Face config override + override_config: {} + + # External model implementation (optional) + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + + # Whether to trust remote code from Hugging Face models + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + +# PPO mini-batch size per update +ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + +# [Deprecated] Global micro batch size +ppo_micro_batch_size: null + +# Local per-GPU micro batch size +ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + +# Whether to automatically adjust batch size at runtime +use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# Max tokens per GPU in one PPO batch (doubled for critic) +ppo_max_token_len_per_gpu: 32768 + +# Max token length per GPU in forward pass +forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + +# Number of PPO epochs per batch +ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + +# Shuffle training data across PPO epochs +shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + +# PPO value function clipping range +cliprange_value: 0.5 + +# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" +loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + +# checkpoint configs +checkpoint: + + # Target dataclass for this configuration + _target_: verl.trainer.config.CheckpointConfig + + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ['model', 'optimizer', 'extra'] + + # What to include when loading checkpoints + load_contents: ${.save_contents} + + # Whether to save checkpoints asynchronously. Only effective for Megatron as of now. + async_save: False + +# profile the critic model in `update_critic` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch, torch_memory + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on Critic + enable: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config which only related to the role + tool_config: + + # nsys tool config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + + # npu config + npu: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: [] + + # Collection level, optional values: level_none, level0, level1, level2. + level: "level1" + + # Whether to automatically parse the data. + analysis: True + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # torch profiler config + torch: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # start profile mini-batch in training + # NOTICE: different with global steps config which refers to iteration + # This field only related with mini-batch + step_start: 0 + + # stop profile mini-batch in training + step_end: null + + # torch memory profiler config + torch_memory: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + # Maximum number of memory allocation entries to track + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + + # Stack trace depth for memory allocations + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + \ No newline at end of file diff --git a/verl/trainer/config/critic/dp_critic.yaml b/src/reasoning360/trainer/config/critic/dp_critic.yaml similarity index 53% rename from verl/trainer/config/critic/dp_critic.yaml rename to src/reasoning360/trainer/config/critic/dp_critic.yaml index 88efe143a..c040a3224 100644 --- a/verl/trainer/config/critic/dp_critic.yaml +++ b/src/reasoning360/trainer/config/critic/dp_critic.yaml @@ -7,66 +7,42 @@ # defaults specify the default config from each component defaults: + # fsdp optimizer config + - ../optim@optim: fsdp + + # fsdp engine config + - ../engine@model.fsdp_config: fsdp + # dp actor config, inheriting from trainer/config/critic/critic.yaml - critic # load the reference default config, then apply the fields in the current yaml - _self_ -strategy: fsdp - -# optimizer configs -optim: - - # Learning rate - lr: 1e-5 +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.FSDPCriticConfig - # Minimum LR ratio for cosine schedule - min_lr_ratio: null - - # LR warmup style: "constant" or "cosine" - warmup_style: constant +# distribution strategy. Options: fsdp (deprecating), fsdp2 +strategy: fsdp # model config for the critic model: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.FSDPCriticModelCfg + # Whether to use shared memory for loading the model use_shm: False + # Enable gradient checkpointing to save memory + enable_gradient_checkpointing: True + # Offload activations to CPU to reduce GPU memory usage enable_activation_offload: False # Use remove padding optimization (saves compute) use_remove_padding: False - # FSDP-specific config - fsdp_config: - - # Whether to offload model parameters to CPU - param_offload: False - - # Whether to offload optimizer state to CPU - optimizer_offload: False - - # Only for FSDP2: offload param/grad/optimizer during train - offload_policy: False - - # Only for FSDP2: Reshard after forward pass to reduce memory footprint - reshard_after_forward: True - - # Policy for wrapping layers with FSDP - wrap_policy: - - # Minimum number of parameters to trigger wrapping - min_num_params: 0 - - # Number of GPUs in each FSDP shard group; -1 means auto - fsdp_size: -1 - - # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather - # before the current forward computation. - forward_prefetch: False - # Set to positive value to enable LoRA (e.g., 32) lora_rank: 0 @@ -77,13 +53,14 @@ model: target_modules: all-linear # Forward-only batch size during inference (global) -forward_micro_batch_size: ${critic.ppo_micro_batch_size} +forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} # Forward-only batch size during inference (per GPU) -forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} +forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} # Sequence parallelism size for Ulysses-style model parallelism +# [DEPRECATED] use fsdp_config.ulysses_sequence_parallel_size instead ulysses_sequence_parallel_size: 1 # Gradient clipping for critic updates -grad_clip: 1.0 \ No newline at end of file +grad_clip: 1.0 diff --git a/src/reasoning360/trainer/config/critic/megatron_critic.yaml b/src/reasoning360/trainer/config/critic/megatron_critic.yaml new file mode 100644 index 000000000..61f2877f2 --- /dev/null +++ b/src/reasoning360/trainer/config/critic/megatron_critic.yaml @@ -0,0 +1,101 @@ +# defaults specify the default config from each component +defaults: + + # megatron optimizer config + - ../optim@optim: megatron + + # megatron engine config + - ../engine@megatron: megatron + + # dp actor config, inheriting from trainer/config/critic/critic.yaml + - critic + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.McoreCriticConfig + +strategy: megatron + +# seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron +nccl_timeout: 600 + +# model config for the critic +model: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.BaseModelConfig + + # override default empty mapping + override_config: + + model_config: {} + + moe_config: + + freeze_moe_router: False + + # LoRA (Low-Rank Adaptation) configuration for parameter-efficient fine-tuning + lora: + # LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora" + type: lora + + # LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA + rank: 0 # typical values: 8, 16, 32, 64 + + # Weighting factor for the low-rank projection. Defaults to 32 + alpha: 32 + + # Dropout rate for the low-rank projection. Defaults to 0.0 + dropout: 0.0 + + # A list of module names to apply LoRA to. + # For fused LoRA, Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. + # For canonical LoRA: ["linear_q", "linear_k", "linear_v", "linear_proj", "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] + # - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections in self-attention + # - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention + # - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP + # - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP + # Target modules can also contain wildcards. For example, you can specify + # target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers + target_modules: + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + + # A list of module names not to apply LoRa to. It will match all nn.Linear & nn.Linear-adjacent modules whose name + # does not match any string in exclude_modules. If used, will require target_modules to be empty list or null + exclude_modules: [] + + # Position for applying dropout, can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre' + dropout_position: pre + + # Initialization method for the low-rank matrix A. Defaults to "xavier". + lora_A_init_method: xavier + + # Initialization method for the low-rank matrix B. Defaults to "zero". + lora_B_init_method: zero + + # Enables the experimental All-to-All (A2A) communication strategy. Defaults to False + a2a_experimental: False + + # Parameter data type for LoRA weights. Default to null, which will use model's dtype. + dtype: null + + # Path to pre-trained LoRA adapter weights (null to train from scratch) + adapter_path: null + + # VLMLoRA additionally allows the user to specify whether the language or vision models should be frozen. + # For example, a common finetuning workload for multimodal models is to apply adapters to language model and fully + # finetune the vision model. + freeze_vision_model: True + freeze_vision_projection: True + freeze_language_model: True + +# Whether to load initial weights +load_weight: True + +# seed for data loader +data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} diff --git a/src/reasoning360/trainer/config/engine/fsdp.yaml b/src/reasoning360/trainer/config/engine/fsdp.yaml new file mode 100644 index 000000000..667854de5 --- /dev/null +++ b/src/reasoning360/trainer/config/engine/fsdp.yaml @@ -0,0 +1,62 @@ +# Target class for this configuration +_target_: verl.workers.config.FSDPEngineConfig + +# policy for wrapping the model +wrap_policy: + + # Minimum number of parameters to trigger wrapping a layer with FSDP + min_num_params: 0 + +# Whether to offload model parameters to CPU (trades speed for memory) +# Note that this differs from the offload_policy in FSDP +param_offload: false + +# Whether to offload optimizer state to CPU +# Note that this differs from the offload_policy in FSDP +optimizer_offload: false + +# Only for FSDP2: offload param/grad/optimizer during train +offload_policy: false + +# Only for FSDP2: Reshard after forward pass to reduce memory footprint +reshard_after_forward: true + +# Number of GPUs in each FSDP shard group; -1 means auto +fsdp_size: -1 + +# Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather +# before the current forward computation. +forward_prefetch: False + +# model dtype of fsdp +model_dtype: fp32 + +# Whether to use original parameters in fsdp. Only avaiable in fsdp1 +use_orig_params: false + +# Random seed for reproducibility. +seed: 42 + +# Whether to enable full determinism for distributed training, only for debugging. +full_determinism: false + +# ulysses sequence parallel size +ulysses_sequence_parallel_size: 1 + +# Whether to use entropy_from_logits_with_chunking in fsdp. +entropy_from_logits_with_chunking: false + +# Whether to use torch compile in fsdp. +use_torch_compile: true + +# Whether to use entropy checkpointing in fsdp. +entropy_checkpointing: false + +# Whether to use forward only in fsdp. +forward_only: false + +# fsdp or fsdp2 +strategy: fsdp + +# Mixed precision training param dtype +dtype: bfloat16 # ["bfloat16", "float16"] diff --git a/src/reasoning360/trainer/config/engine/megatron.yaml b/src/reasoning360/trainer/config/engine/megatron.yaml new file mode 100644 index 000000000..84601f5a3 --- /dev/null +++ b/src/reasoning360/trainer/config/engine/megatron.yaml @@ -0,0 +1,90 @@ +# Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs +_target_: verl.workers.config.McoreEngineConfig + +# Whether to offload model parameters to CPU +param_offload: False + +# Whether to offload gradients to CPU +grad_offload: False + +# Whether to offload optimizer state to CPU +optimizer_offload: False + +# tensor model parallel size +tensor_model_parallel_size: 1 + +# expert model parallel size +expert_model_parallel_size: 1 + +# expert tensor parallel size (null to be same as TP) +expert_tensor_parallel_size: null + +# pipeline model parallel size +pipeline_model_parallel_size: 1 + +# virtual pipeline model parallel size +virtual_pipeline_model_parallel_size: null + +# context parallel size +context_parallel_size: 1 + +# sequence parallel +sequence_parallel: True + +# Whether to use distributed optimizer +use_distributed_optimizer: True + +# Whether to use distributed checkpointing +use_dist_checkpointing: False + +# distributed checkpointing path +dist_checkpointing_path: null + +# distributed checkpointing prefix, e.g. Nemo2 will append prefix 'module.' to the state dict keys +dist_checkpointing_prefix: '' + +# oc.select: default val for ref.megatron.seed +seed: 42 + +# Allow to override Distributed Data Parallel (DDP) config +override_ddp_config: {} + +# additional transformer config like: num_layers_in_first(/last)_pipeline_stage +# oc.select: default val for ref.megatron.override_transformer_config +override_transformer_config: + # Recompute configuration, same as in megatron.training.arguments + # default use minimal performance-interference recompute methods + # Recompute granualarity, choices: ["full", "selective"] + recompute_granularity: null + + # Recompute modules, multiple choices: ["core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe"] + # Please use correct module in matched model + recompute_modules: ["core_attn"] + + # 'uniform', 'block' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + recompute_method: null + + # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention + recompute_num_layers: null + + # Attention backend to use (flash,fused,unfused,local,auto). Defaults to auto in mcore, flash in verl + attention_backend: flash + +override_mcore_model_config: {} + +# oc.select: default val for ref.megatron.use_mbridge +use_mbridge: False + +# oc.select: default val for ref.megatron.vanilla_mbridge +vanilla_mbridge: True + +# whether to use thd format (sequence packing), if not, use bshd format, padding the input_ids to the longest sequence length +use_remove_padding: True + +# whether to use forward only +forward_only: False + +# Mixed precision training param dtype +dtype: bfloat16 # ["bfloat16", "float16"] diff --git a/verl/trainer/config/evaluation.yaml b/src/reasoning360/trainer/config/evaluation.yaml similarity index 59% rename from verl/trainer/config/evaluation.yaml rename to src/reasoning360/trainer/config/evaluation.yaml index efca03da4..6a88d77f1 100644 --- a/verl/trainer/config/evaluation.yaml +++ b/src/reasoning360/trainer/config/evaluation.yaml @@ -9,6 +9,7 @@ custom_reward_function: path: null name: compute_score -ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. +ray_kwargs: + ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. timeline_json_file: null diff --git a/verl/trainer/config/generation.yaml b/src/reasoning360/trainer/config/generation.yaml similarity index 77% rename from verl/trainer/config/generation.yaml rename to src/reasoning360/trainer/config/generation.yaml index c19cfed95..e542d6159 100644 --- a/verl/trainer/config/generation.yaml +++ b/src/reasoning360/trainer/config/generation.yaml @@ -4,16 +4,17 @@ trainer: device: cuda data: - path: ./data/test/simulation__cruxeval-o_800.parquet + path: ~/data/rlhf/math/test.parquet prompt_key: prompt n_samples: 5 - output_path: ./output/simulation__cruxeval-o_800.parquet + output_path: /opt/tiger/math_Qwen2-7B-Instruct.parquet batch_size: 128 model: - path: Qwen/Qwen2-7B-Instruct + path: ~/models/Qwen2-7B-Instruct external_lib: null rollout: + _target_: verl.workers.config.RolloutConfig name: vllm mode: sync # sync: LLM, async: AsyncLLM temperature: 1.0 @@ -27,8 +28,9 @@ rollout: ignore_eos: False enforce_eager: True free_cache_engine: True - load_format: dummy_dtensor + load_format: auto tensor_model_parallel_size: 1 + data_parallel_size: 1 max_num_batched_tokens: 8192 max_model_len: null max_num_seqs: 1024 @@ -50,6 +52,7 @@ actor: fsdp_size: -1 forward_prefetch: False # FSDP1 forward_prefetch configuration -ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. +ray_kwargs: + ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. timeline_json_file: null diff --git a/src/reasoning360/trainer/config/model/hf_model.yaml b/src/reasoning360/trainer/config/model/hf_model.yaml new file mode 100644 index 000000000..6d02b8eac --- /dev/null +++ b/src/reasoning360/trainer/config/model/hf_model.yaml @@ -0,0 +1,67 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +_target_: verl.workers.config.HFModelConfig + +# path to the huggingface model +path: ~/models/deepseek-llm-7b-chat + +# config to the huggingface config. In case it is not the same as path +hf_config_path: null + +# path to the huggingface tokenizer. In case it is not the same as path +tokenizer_path: null + +# whether to use shared memory for model loading +use_shm: False + +# whether to trust remote code. +trust_remote_code: False + +# custom chat template for the model +custom_chat_template: null + +# whether to use external libs for the model +external_lib: null + +# override hf config +override_config: {} + +# whether to enable gradient checkpointing. Only valid when we use hf model definition +enable_gradient_checkpointing: True + +# whether to enable activation offload. Only valid when we use hf model definition +enable_activation_offload: False + +# whether to use remove padding. Only valid when we use hf model definition +use_remove_padding: False + +# Set to positive value to enable LoRA (e.g., 32) +lora_rank: 0 + +# LoRA scaling factor +lora_alpha: 16 + +# Target modules for LoRA adaptation +target_modules: all-linear + +# Exclude modules from LoRA adaptation +exclude_modules: null + +# Path to pre-trained LoRA adapter to load for continued training +lora_adapter_path: null + +# whether to use liger. Only valid when we use hf model definition +use_liger: False + +# whether to use fused kernels. +use_fused_kernels: False + +# fused kernel options. +fused_kernel_options: + + # the implementation backend for fused kernels. + impl_backend: torch diff --git a/verl/trainer/config/npu_profile/npu_profile.yaml b/src/reasoning360/trainer/config/npu_profile/npu_profile.yaml similarity index 73% rename from verl/trainer/config/npu_profile/npu_profile.yaml rename to src/reasoning360/trainer/config/npu_profile/npu_profile.yaml index b61260375..52bb52d3f 100644 --- a/verl/trainer/config/npu_profile/npu_profile.yaml +++ b/src/reasoning360/trainer/config/npu_profile/npu_profile.yaml @@ -4,6 +4,11 @@ options: # Storage path of collected data. save_path: ./profiler_data + # The roles that will be profiled. Only takes effect in discrete mode. + # optional values: all, rollout_generate, actor_compute_log_prob, actor_update and ref_compute_log_prob. + # "all" means all roles will be profiled. + roles: ["all"] + # Collection level, optional values: level_none, level0, level1, level2. level: level1 diff --git a/src/reasoning360/trainer/config/optim/fsdp.yaml b/src/reasoning360/trainer/config/optim/fsdp.yaml new file mode 100644 index 000000000..a7dd99b1e --- /dev/null +++ b/src/reasoning360/trainer/config/optim/fsdp.yaml @@ -0,0 +1,50 @@ +# Target class for this configuration +_target_: verl.workers.config.FSDPOptimizerConfig + +# Optimizer class name (e.g., "AdamW", "AdamW8bit", "_AdamW", "Adam") +optimizer: AdamW + +# Module path to import optimizer +# Examples: "torch.optim", "torchao.optim", "bitsandbytes.optim" +optimizer_impl: torch.optim + +# Learning rate +lr: 1e-3 + +# LR warmup steps ratio +lr_warmup_steps_ratio: 0.0 + +# Total training steps +total_training_steps: -1 + +# Weight decay +weight_decay: 0.01 + +# LR warmup steps +lr_warmup_steps: -1 + +# Betas for Adam optimizer +betas: [0.9, 0.999] + +# Clip gradient +clip_grad: 1.0 + +# Minimum LR ratio for cosine schedule +min_lr_ratio: 0.0 + +# Number of cosine cycles in LR schedule +num_cycles: 0.5 + +# LR scheduler type: "constant" or "cosine" +lr_scheduler_type: constant + +# deprecated +warmup_style: null + +# Additional optimizer-specific keyword arguments +# Example for torchao with bf16 stochastic rounding: +# optimizer_impl: torchao.optim +# optimizer: _AdamW +# override_optimizer_config: +# bf16_stochastic_round: true +override_optimizer_config: null diff --git a/src/reasoning360/trainer/config/optim/megatron.yaml b/src/reasoning360/trainer/config/optim/megatron.yaml new file mode 100644 index 000000000..c3e49b7df --- /dev/null +++ b/src/reasoning360/trainer/config/optim/megatron.yaml @@ -0,0 +1,49 @@ +_target_: verl.workers.config.McoreOptimizerConfig + +# Learning rate +lr: 1e-3 + +# LR warmup steps ratio +lr_warmup_steps_ratio: 0.0 + +# Total training steps +total_training_steps: -1 + +# Weight decay +weight_decay: 0.01 + +# LR warmup steps +lr_warmup_steps: -1 + +# Betas for Adam optimizer +betas: [0.9, 0.999] + +# Clip gradient +clip_grad: 1.0 + +# optimizer type +optimizer: adam + +# initial learning rate for warmup, default to 0.0 +lr_warmup_init: 0.0 + +lr_decay_steps: null + +# select from constant/linear/cosine/inverse_square_root +lr_decay_style: constant + +# minimum learning rate, default to 0.0 +min_lr: 0.0 + +# select from constant/linear/cosine +weight_decay_incr_style: constant + +# select from constant/exponential/cosine +lr_wsd_decay_style: exponential + +lr_wsd_decay_steps: null + +# use checkpoint optimizer parameter scheduler +use_checkpoint_opt_param_scheduler: False + +override_optimizer_config: {} diff --git a/src/reasoning360/trainer/config/ppo_megatron_trainer.yaml b/src/reasoning360/trainer/config/ppo_megatron_trainer.yaml new file mode 100644 index 000000000..9d9959aea --- /dev/null +++ b/src/reasoning360/trainer/config/ppo_megatron_trainer.yaml @@ -0,0 +1,240 @@ +# specify the default per-component configs +defaults: + # @.: + # actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml + - actor@actor_rollout_ref.actor: megatron_actor + # data: trainer/config/data/legacy_data.yaml + - data@data: legacy_data + # (Rule-based) Reward manager config. + - reward_manager@reward_manager + # load the reference default config, then apply the fields in the current yaml + # Reference model config. + # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. + - ref@actor_rollout_ref.ref: megatron_ref + # Rollout model config. + - rollout@actor_rollout_ref.rollout: rollout + # Model config. + - model@actor_rollout_ref.model: hf_model + # Critic model config. + - critic@critic: megatron_critic + # Reward model config. + - reward_model@reward_model: megatron_reward_model + # Rollout correction config. + - algorithm@algorithm.rollout_correction: rollout_correction + - _self_ + +actor_rollout_ref: + hybrid_engine: True + + nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron + + model: + override_config: + model_config: {} + moe_config: + freeze_moe_router: False + + use_fused_kernels: False # Whether to use custom fused kernels (PostProcessing, for memory efficiency) + + trust_remote_code: False + + # Whether to remove padding tokens in inputs during training + use_remove_padding: false + + # LoRA (Low-Rank Adaptation) configuration for parameter-efficient fine-tuning + lora: + # LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora" + type: lora + + # LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA + rank: 0 # typical values: 8, 16, 32, 64 + + # Weighting factor for the low-rank projection. Defaults to 32 + alpha: 32 + + # Dropout rate for the low-rank projection. Defaults to 0.0 + dropout: 0.0 + + # A list of module names to apply LoRA to. + # For fused LoRA, Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. + # For canonical LoRA: ["linear_q", "linear_k", "linear_v", "linear_proj", "linear_fc1_up", "linear_fc1_gate", "linear_fc2"] + # - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections in self-attention + # - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention + # - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP + # - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP + # Target modules can also contain wildcards. For example, you can specify + # target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers + target_modules: + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + + # A list of module names not to apply LoRa to. It will match all nn.Linear & nn.Linear-adjacent modules whose name + # does not match any string in exclude_modules. If used, will require target_modules to be empty list or None + exclude_modules: [] + + # Position for applying dropout, can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'pre' + dropout_position: pre + + # Initialization method for the low-rank matrix A. Defaults to "xavier". + lora_A_init_method: xavier + + # Initialization method for the low-rank matrix B. Defaults to "zero". + lora_B_init_method: zero + + # Enables the experimental All-to-All (A2A) communication strategy. Defaults to False + a2a_experimental: False + + # Parameter data type for LoRA weights. Default to null, which will use model's dtype. + dtype: null + + # Path to pre-trained LoRA adapter weights (null to train from scratch) + adapter_path: null + + # VLMLoRA additionally allows the user to specify whether the language or vision models should be frozen. + # For example, a common finetuning workload for multimodal models is to apply adapters to language model and fully + # finetune the vision model. + freeze_vision_model: True + freeze_vision_projection: True + freeze_language_model: True + + rollout: + quantization: null + + layer_name_map: + qkv_layer_name: qkv + gate_proj_layer_name: gate_up + +custom_reward_function: + path: null + name: compute_score + +algorithm: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + norm_adv_by_std_in_grpo: True + use_kl_in_reward: False + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: False + pf_ppo: + reweight_method: pow # ["pow", "max_min", "max_random"] + weight_pow: 2.0 + +trainer: + balance_batch: True + total_epochs: 30 + total_training_steps: null + project_name: verl_examples + experiment_name: gsm8k + logger: ["console", "wandb"] + log_val_generations: 0 + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + esi_redundant_time: 0 + + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or disable or resume_path if resume_from_path is set + resume_from_path: null + del_local_ckpt_after_load: False + val_before_train: True + test_freq: -1 + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + # The timeout for ray worker group to wait for the register center to be ready + ray_wait_register_center_timeout: 300 + device: cuda + # Directory for logging rollout data; no dump if null + rollout_data_dir: null + + # whether to use legacy worker implementation + # mode: "auto", "enable", or "disable" + use_legacy_worker_impl: auto + +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null # choose between nsys, npu, torch, torch_memory + steps: null # profile steps + profile_continuous_steps: False + save_path: "outputs/profile" # profiler saving path + # Specific tool configs, can use +profiler.tool_config.[tool].xxx to config + global_tool_config: + # nsys config + nsys: + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the torch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none + + # enable memory visualization for debugging memory usage + torch_memory: + # Maximum number of allocation entries to record + trace_alloc_max_entries: 100_000 + # The depth of the call stack to capture for each allocation + stack_depth: 32 + # 'alloc': records only allocation events || 'state': records memory state changes || 'all': records both. + context: "all" + # 'python': records Python stacks || 'cpp': records C++ stacks (available in some versions) || 'all': records both. + stacks: "all" + # devices, record_context etc. + kw_args: {} + +# configs for TransferQueue +transfer_queue: + # Whether to enable transfer queue + enable: False + +ray_kwargs: + ray_init: + num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. + timeline_json_file: null diff --git a/verl/trainer/config/ppo_trainer.yaml b/src/reasoning360/trainer/config/ppo_trainer.yaml similarity index 53% rename from verl/trainer/config/ppo_trainer.yaml rename to src/reasoning360/trainer/config/ppo_trainer.yaml index 925872739..c226d2d06 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/src/reasoning360/trainer/config/ppo_trainer.yaml @@ -11,12 +11,12 @@ defaults: # actor_rollout_ref.actor: trainer/config/actor/dp_actor.yaml - actor@actor_rollout_ref.actor: dp_actor - # trainer.npu_profile: trainer/config/npu_profile/npu_profile.yaml - - npu_profile@trainer.npu_profile: npu_profile - # data: trainer/config/data/legacy_data.yaml - data@data: legacy_data + # (Rule-based) Reward manager config. + - reward_manager@reward_manager + # Reference model config. # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. - ref@actor_rollout_ref.ref: dp_ref @@ -24,12 +24,18 @@ defaults: # Rollout model config. - rollout@actor_rollout_ref.rollout: rollout + # Model config. + - model@actor_rollout_ref.model: hf_model + # Critic model config. - critic@critic: dp_critic # Reward model config. - reward_model@reward_model: dp_reward_model + # Rollout correction config. + - algorithm@algorithm.rollout_correction: rollout_correction + # load the reference default config, then apply the fields in the current yaml # self config override anything above - _self_ @@ -40,90 +46,15 @@ actor_rollout_ref: # Whether it's a hybrid engine, currently only supports hybrid engine hybrid_engine: true - # common configs for the model - model: - - # Huggingface model path. This can be either local path or HDFS path. - path: ~/models/deepseek-llm-7b-chat - - # Custom chat template for the model. - custom_chat_template: null - - # Whether to use shared memory (SHM) for accelerating the loading of model weights - use_shm: false - - # Additional Python packages to register huggingface models/tokenizers. - external_lib: null - - # Used to override model's original configurations, mainly dropout - override_config: {} - - # Enable gradient checkpointing for actor - enable_gradient_checkpointing: true - - # Enable activation offloading for actor - enable_activation_offload: false - - # Whether to remove padding tokens in inputs during training - use_remove_padding: false - - # Set to positive value to enable LoRA (e.g., 32) - lora_rank: 0 - - # LoRA scaling factor - lora_alpha: 16 - - # Target modules to apply LoRA. Options: "all-linear" (not recommended for VLMs) or - # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj] - target_modules: all-linear - - # Exclude modules from applying Lora. Similar usage to target_modules and Peft. - # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora. - exclude_modules: null - - # Whether to use Liger for linear layer fusion - use_liger: false - - # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) - use_fused_kernels: false - - # Options for fused kernels. If use_fused_kernels is true, this will be used. - fused_kernel_options: - - # Implementation backend for fused kernels. Options: "triton" or "torch". - impl_backend: torch - - # Whether to enable loading a remote code model - trust_remote_code: false + # Timeout for operations executed against the process group + nccl_timeout: 600 # Rollout model config. rollout: - # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. - enable_chunked_prefill: True - - # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc. - # safetensors (for huge model, and set use_shm=True); dummy_dtensor: randomly init model weight - load_format: dummy_dtensor - # for huge model, layered summon can save memory (prevent OOM) but make it slower layered_summon: False - # profiler configs - profiler: - - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.utils.profiler.ProfilerConfig - - # True for each task has its own database, False for all tasks in one training step share one database. - discrete: False - - # Whether to profile all ranks. - all_ranks: False - - # The ranks that will be profiled. [] or [0,1,...] - ranks: [] - # custom reward function definition custom_reward_function: @@ -137,7 +68,7 @@ custom_reward_function: # config for the algorithm algorithm: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs _target_: verl.trainer.config.AlgoConfig # Discount factor for future rewards @@ -161,7 +92,7 @@ algorithm: # KL control configuration kl_ctrl: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs _target_: verl.trainer.config.KLControlConfig # KL control type: "fixed" or "adaptive" @@ -169,12 +100,6 @@ algorithm: # Initial coefficient for KL penalty kl_coef: 0.001 - horizon: 10000 - target_kl: 0.1 - filter_groups: # NOTE: added by Reasoning360 - enable: False # We try to avoid forgetting to set enable - metric: null # acc / score / seq_reward / seq_final_reward / ... - max_num_gen_batches: 0 # Non-positive values mean no upper limit # Horizon value for adaptive controller (if enabled) horizon: 10000 @@ -188,9 +113,6 @@ algorithm: # Preference feedback PPO settings pf_ppo: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.PFPPOConfig - # Method for reweighting samples: "pow", "max_min", or "max_random" reweight_method: pow @@ -209,49 +131,6 @@ trainer: # Total training steps (can be set explicitly or derived from epochs) total_training_steps: null - # The steps that will be profiled. null means no profiling. null or [1,2,5,...] - profile_steps: null - - # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. - ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html - ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html - controller_nsight_options: - - # Select the API(s) to be traced. - trace: "cuda,nvtx,cublas,ucx" - - # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". - cuda-memory-usage: "true" - - # CUDA graphs will be traced as a whole - cuda-graph-trace: "graph" - - # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. - worker_nsight_options: - - # Select the API(s) to be traced. - trace: "cuda,nvtx,cublas,ucx" - - # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". - cuda-memory-usage: "true" - - # CUDA graphs will be traced as a whole - cuda-graph-trace: "graph" - - # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. - capture-range: "cudaProfilerApi" - - # Specify the desired behavior when a capture range ends. - # In verl we need the orch.cuda.profiler.start/stop pair to repeats n times. - # valid values are "repeat-shutdown:n" or null. - # For normal whole step profiling, n = len(profile_steps); - # but for discrete profiling, n = len(profile_steps) * Number(subtasks). - # Or you can just leave it null and the program will use n = len(profile_steps) * 6; - capture-range-end: null - - # Send signal to the target application's process group. We let the program to exit by itself. - kill: none - # Project name for experiment tracking (e.g., wandb) project_name: verl_examples @@ -259,7 +138,7 @@ trainer: experiment_name: gsm8k # Logging backends to use: "console", "wandb", etc. - logger: [ 'console', 'wandb' ] + logger: ["console", "wandb"] # Number of generations to log during validation log_val_generations: 0 @@ -328,11 +207,114 @@ trainer: # Device to run training on (e.g., "cuda", "cpu") device: cuda -# configs related to ray initialization -ray_init: + # whether to use legacy worker implementation + # mode: "auto", "enable", or "disable" + use_legacy_worker_impl: auto + +# profiler configs +global_profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # Profiling tool: choose between nsys, npu, torch, torch_memory + tool: null + + # profile steps + steps: null + + # Whether to combine continuous steps into one database. + ## If True, worker.profiler.discrete must be False, [1,2] in one, [5] in another. + ## If False, [1] in one, [2] in another, [5] in another. + profile_continuous_steps: False + + # Path to save profiling contents + save_path: "outputs/profile" + + # Specific tool configs, can use +profiler.tool_config.[tool].xxx to config + global_tool_config: + + # nsys config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. + ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html + ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html + controller_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. + worker_nsight_options: + + # Select the API(s) to be traced. + trace: "cuda,nvtx,cublas,ucx" + + # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". + cuda-memory-usage: "true" + + # CUDA graphs will be traced as a whole + cuda-graph-trace: "graph" + + # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. + capture-range: "cudaProfilerApi" + + # Specify the desired behavior when a capture range ends. + # In verl we need the torch.cuda.profiler.start/stop pair to repeats n times. + # valid values are "repeat-shutdown:n" or null. + # For normal whole step profiling, n = len(profile_steps); + # but for discrete profiling, n = len(profile_steps) * Number(subtasks). + # Or you can just leave it null and the program will use n = len(profile_steps) * 6; + capture-range-end: null + + # Send signal to the target application's process group. We let the program to exit by itself. + kill: none + + # enable memory visualization for debugging memory usage + torch_memory: + + # Maximum number of allocation entries to record + trace_alloc_max_entries: 100_000 + + # The depth of the call stack to capture for each allocation + stack_depth: 32 + + # 'alloc': records only allocation events || 'state': records memory state changes || 'all': records both. + context: "all" + + # 'python': records Python stacks || 'cpp': records C++ stacks (available in some versions) || 'all': records both. + stacks: "all" + + # devices, record_context etc. + kw_args: {} + +# configs for TransferQueue +transfer_queue: + + # Whether to enable transfer queue + enable: False + +# configs related to ray +ray_kwargs: + + # configs related to ray initialization + ray_init: - # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM. - num_cpus: null + # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM. + num_cpus: null # Path to save Ray timeline JSON for performance profiling timeline_json_file: null diff --git a/verl/trainer/config/ref/dp_ref.yaml b/src/reasoning360/trainer/config/ref/dp_ref.yaml similarity index 53% rename from verl/trainer/config/ref/dp_ref.yaml rename to src/reasoning360/trainer/config/ref/dp_ref.yaml index 13b604718..64b7d2abb 100644 --- a/verl/trainer/config/ref/dp_ref.yaml +++ b/src/reasoning360/trainer/config/ref/dp_ref.yaml @@ -3,29 +3,21 @@ defaults: # dp ref config, inheriting from trainer/config/ref/ref.yaml - ref + + # fsdp engine config + - ../engine@fsdp_config: fsdp # load the reference default config, then apply the fields in the current yaml - _self_ -# config for FSDP strategy -fsdp_config: - - # whether to offload parameters in FSDP - param_offload: False - - # whether to perform reshard after model forward to save memory. - # only for fsdp2, [True, False, int between 1 and fsdp_size] - reshard_after_forward: True +# Target class for this configuration +_target_: verl.workers.config.FSDPActorConfig - # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather - # before the current forward computation. - forward_prefetch: False - - # the wrap policy for FSDP model - wrap_policy: +# fsdp config +fsdp_config: - # minimum number of params in a wrapped module - min_num_params: 0 + # ref model is forward only + forward_only: True # sequence parallel size # same as actor_rollout_ref.actor.ulysses_sequence_parallel_size if it exists, otherwise 1 diff --git a/src/reasoning360/trainer/config/ref/megatron_ref.yaml b/src/reasoning360/trainer/config/ref/megatron_ref.yaml new file mode 100644 index 000000000..ca1fbb3c0 --- /dev/null +++ b/src/reasoning360/trainer/config/ref/megatron_ref.yaml @@ -0,0 +1,30 @@ +# megatron ref config, inheriting from trainer/config/ref/ref.yaml +defaults: + - ref + + # megatron engine config + - ../engine@megatron: megatron + + # load the reference default config, then apply the fields in the current yaml + - _self_ + +_target_: verl.workers.config.McoreActorConfig + +strategy: megatron + +megatron: + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + tensor_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.tensor_model_parallel_size,1} + pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.pipeline_model_parallel_size,1} + virtual_pipeline_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size,null} + context_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.context_parallel_size,1} + expert_model_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_model_parallel_size,1} + expert_tensor_parallel_size: ${oc.select:actor_rollout_ref.actor.megatron.expert_tensor_parallel_size,null} + param_offload: ${oc.select:actor_rollout_ref.actor.megatron.param_offload,False} + forward_only: True + +load_weight: True diff --git a/src/reasoning360/trainer/config/ref/ref.yaml b/src/reasoning360/trainer/config/ref/ref.yaml new file mode 100644 index 000000000..ec566c25b --- /dev/null +++ b/src/reasoning360/trainer/config/ref/ref.yaml @@ -0,0 +1,121 @@ +# Number of rollouts per update (mirrors actor rollout_n) +rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + +# actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default +strategy: ${actor_rollout_ref.actor.strategy} + +# whether to enable torch.compile +# same as actor_rollout_ref.actor.use_torch_compile if it exists, otherwise 1 +use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + +# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] +# The batch size for one forward pass in the computation of log_prob. Global batch size. +log_prob_micro_batch_size: null + +# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. +log_prob_micro_batch_size_per_gpu: null + +# enable dynamic batch size (sequence packing) for log_prob computation +# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false +log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + +# the max token length per GPU +# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 +log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + +# profile the ref model in `compute_log_prob` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # choices: nsys, npu, torch, torch_memory + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on Ref + enable: False + + # Whether to profile all ranks. + all_ranks: False + + # The ranks that will be profiled. [] or [0,1,...] + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config which only related to the role + tool_config: + + # nsys tool config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + + # npu config + npu: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NPUToolConfig + + # Contents to profile, can be empty + # options: npu, cpu, memory, shapes, module, stack + contents: [] + + # Collection level, optional values: level_none, level0, level1, level2. + level: "level1" + + # Whether to automatically parse the data. + analysis: True + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # torch profiler config + torch: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + + # start profile mini-batch in training + # NOTICE: different with global steps config which refers to iteration + # This field only related with mini-batch + step_start: 0 + + # stop profile mini-batch in training + step_end: null + + # torch memory profiler config + torch_memory: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + # Maximum number of memory allocation entries to track + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + + # Stack trace depth for memory allocations + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + +# Router replay configuration for MoE models +router_replay: + + # Target dataclass for this configuration + _target_: verl.workers.config.RouterReplayConfig + + # Router replay mode: disabled, R2, R3 + # - R2: Use R2 routing strategy (record mode) + # - R3: Use R3 routing strategy (record mode) + mode: disabled + + # File path to save recorded routing decisions + # Required when mode is 'record', 'R2', or 'R3' + record_file: null + + # File path to load recorded routing decisions for replay + # Required when mode is 'replay' + replay_file: null \ No newline at end of file diff --git a/src/reasoning360/trainer/config/reward_manager.yaml b/src/reasoning360/trainer/config/reward_manager.yaml new file mode 100644 index 000000000..3e55a1daf --- /dev/null +++ b/src/reasoning360/trainer/config/reward_manager.yaml @@ -0,0 +1,8 @@ +# See `verl/trainer/config/config.py:RewardManagerConfig` for more details. +_target_: verl.trainer.config.config.RewardManagerConfig +source: register +name: ${oc.select:reward_model.reward_manager,naive} +module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager diff --git a/verl/trainer/config/reward_model/dp_reward_model.yaml b/src/reasoning360/trainer/config/reward_model/dp_reward_model.yaml similarity index 94% rename from verl/trainer/config/reward_model/dp_reward_model.yaml rename to src/reasoning360/trainer/config/reward_model/dp_reward_model.yaml index d9a837032..fff1f9f1f 100644 --- a/verl/trainer/config/reward_model/dp_reward_model.yaml +++ b/src/reasoning360/trainer/config/reward_model/dp_reward_model.yaml @@ -29,8 +29,12 @@ model: # FSDP-specific config fsdp_config: + # Target configuration dataclass + _target_: verl.workers.config.FSDPEngineConfig + # Policy for wrapping layers with FSDP wrap_policy: + # Minimum number of parameters to trigger wrapping min_num_params: 0 diff --git a/verl/trainer/config/reward_model/megatron_reward_model.yaml b/src/reasoning360/trainer/config/reward_model/megatron_reward_model.yaml similarity index 61% rename from verl/trainer/config/reward_model/megatron_reward_model.yaml rename to src/reasoning360/trainer/config/reward_model/megatron_reward_model.yaml index 2c5d35cd5..ea585075e 100644 --- a/verl/trainer/config/reward_model/megatron_reward_model.yaml +++ b/src/reasoning360/trainer/config/reward_model/megatron_reward_model.yaml @@ -15,6 +15,10 @@ nccl_timeout: 600 # Megatron parallelism & checkpointing config megatron: + + # Target configuration dataclass + _target_: verl.workers.config.MegatronEngineConfig + # Whether to offload model parameters to CPU param_offload: False @@ -24,7 +28,7 @@ megatron: # Number of GPUs in expert model parallel group expert_model_parallel_size: 1 - # Expert tensor parallel size + # Expert tensor parallel size (null to be same as TP) expert_tensor_parallel_size: null # Number of pipeline model parallel stages @@ -48,14 +52,25 @@ megatron: # Path for distributed checkpoints dist_checkpointing_path: null + # distributed checkpointing prefix, e.g. Nemo2 will append prefix 'module.' to the state dict keys + dist_checkpointing_prefix: '' + # RNG seed for megatron - seed: ${actor_rollout_ref.actor.megatron.seed} + seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} # Any overrides to transformer config - override_transformer_config: {} + override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} # Whether to use mbridge for faster comms - use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} + use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} + + # Whether to use mbridge instead of Megatron-Bridge + vanilla_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.vanilla_mbridge,True} + + # Whether to use thd format (sequence packing), if not, use bshd format, padding the input_ids to the longest sequence length + use_remove_padding: ${oc.select:actor_rollout_ref.actor.megatron.use_remove_padding,True} + + dtype: bfloat16 # Whether to load weights (default True) load_weight: True \ No newline at end of file diff --git a/verl/trainer/config/reward_model/reward_model.yaml b/src/reasoning360/trainer/config/reward_model/reward_model.yaml similarity index 68% rename from verl/trainer/config/reward_model/reward_model.yaml rename to src/reasoning360/trainer/config/reward_model/reward_model.yaml index 698343955..14abdbbc5 100644 --- a/verl/trainer/config/reward_model/reward_model.yaml +++ b/src/reasoning360/trainer/config/reward_model/reward_model.yaml @@ -6,14 +6,20 @@ # If False, the following parameters are not effective enable: False +# Whether to deploy the model to a separate resource pool. +# If true, n_gpus_per_node & nnodes will be used to determine the resource node. +enable_resource_pool: False +n_gpus_per_node: 0 +nnodes: 0 + # FSDP strategy: "fsdp" or "fsdp2" strategy: ??? # model config for reward scoring model: - # Input tokenizer. If the reward model’s chat template is inconsistent with the policy, - # we need to first decode to plaintext, then apply the rm’s chat_template. + # Input tokenizer. If the reward model's chat template is inconsistent with the policy, + # we need to first decode to plaintext, then apply the rm's chat_template. # Then score with RM. If chat_templates are consistent, it can be set to null. # set this to null if the chat template is identical input_tokenizer: ${actor_rollout_ref.model.path} @@ -28,6 +34,9 @@ model: # Whether to enable loading a remote code model, default to False trust_remote_code: False + # override hf config + override_config: {} + # [Deprecated] Global micro batch size # will be deprecated, use micro_batch_size_per_gpu micro_batch_size: null @@ -44,9 +53,8 @@ use_dynamic_bsz: ${critic.use_dynamic_bsz} # Maximum number of tokens per GPU in one forward pass forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} -# Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources. -# Default is naive. If all verification functions are multiprocessing-safe, -# the reward manager can be set to prime for parallel verification. +# Deprecated. Use `reward_manager.name` instead. See `verl/trainer/config/reward_manager.yaml` for details. +# Kept for backward compatibility. reward_manager: naive # Whether to launch custom reward function asynchronously during log_prob @@ -65,17 +73,27 @@ sandbox_fusion: # Max memory limit for each sandbox process in MB memory_limit_mb: 1024 -# profiler configs +# profile the reward model in `compute_reward` profiler: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs _target_: verl.utils.profiler.ProfilerConfig - # True for each task has its own database, False for all tasks in one training step share one database. - discrete: False + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch + tool: ${oc.select:global_profiler.tool,null} + # whether enable profile on ref + enable: False + # Whether to profile all ranks. all_ranks: False # The ranks that will be profiled. [] or [0,1,...] - ranks: [] \ No newline at end of file + ranks: [] + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} \ No newline at end of file diff --git a/verl/trainer/config/rollout/rollout.yaml b/src/reasoning360/trainer/config/rollout/rollout.yaml similarity index 54% rename from verl/trainer/config/rollout/rollout.yaml rename to src/reasoning360/trainer/config/rollout/rollout.yaml index fc3af80d4..968d9e112 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/src/reasoning360/trainer/config/rollout/rollout.yaml @@ -1,8 +1,11 @@ +# Target class for this configuration +_target_: verl.workers.config.RolloutConfig + # actor_rollout_ref.rollout.name: hf/vllm/sglang. The default value will be removed in the future -name: vllm +name: ??? # sync: LLM, async: AsyncLLM -mode: sync +mode: async # Sampling temperature for rollout. temperature: 1.0 @@ -31,15 +34,30 @@ gpu_memory_utilization: 0.5 # Whether to ignore EOS and continue generating after EOS is hit. ignore_eos: False -# Whether to disable CUDA graph. Default True to allow cache freeing. -enforce_eager: True +# Whether to disable CUDA graph. Default False to best performance. +enforce_eager: False + +# batch size of cudagraph to capture. Require enforce_eager: False to use this option +# Since cudagraph in inference engine can not be offloaded during update policy, +# you can use smaller batch size to save memory used in cuda graph, eg: [1 ,2, 4, 8, 16, 32] +# supported engines: vllm +cudagraph_capture_sizes: null -# Whether to free engine KVCache after generation. Set enforce_eager=True when enabled. +# Whether to free engine KVCache after generation. free_cache_engine: True # TP size for rollout. Not effective for hf tensor_model_parallel_size: 2 +# DP size for rollout +data_parallel_size: 1 + +# EP size for rollout +expert_parallel_size: 1 + +# PP size for rollout. +pipeline_model_parallel_size: 1 + # max number of tokens in a batch max_num_batched_tokens: 8192 @@ -49,6 +67,16 @@ max_model_len: null # max length of sequences max_num_seqs: 1024 +# may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. +enable_chunked_prefill: True + +# Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. +enable_prefix_caching: True + +# Which loader to use for rollout model weights: dummy, hf, megatron, etc. +# safetensors (for huge model, and set use_shm=True); dummy: randomly init model weight +load_format: dummy + # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. log_prob_micro_batch_size: null @@ -73,30 +101,30 @@ do_sample: True # number of responses (i.e. num sample times). > 1 for grpo n: 1 -# Whether to wake up inference engine in multi-stage. (Wake up model weights first, then resume kv cache) +# The over_sample_rate parameter controls the early termination threshold for training rollouts, +# where the system will abort remaining requests when (1 - over_sample_rate) * total_requests completions are reached. +over_sample_rate: 0 + +# Whether to wake up inference engine in multi-stage for SGLang +# to reduce peak memory during training-rollout transition. +# This is only effective for SGLang rollout. multi_stage_wake_up: false -# Extra inference engine arguments (vllm, sglang). +# Extra inference engine arguments (vllm, sglang), please refer vllm/sglang official doc for detail engine_kwargs: - # for vllm - vllm: - - # Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB). - swap_space: null - - # Whether to disable the preprocessor cache for multimodel models. - disable_mm_preprocessor_cache: False + # vllm engine config + vllm: {} - # for sglang - sglang: - - # The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default. - attention_backend: null + # sglang engine config + sglang: {} # Sampling parameters used during validation. val_kwargs: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.SamplingConfig + # sampling parameters for validation # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. top_k: -1 @@ -116,6 +144,9 @@ val_kwargs: # Multi-turn interaction config for tools or chat. multi_turn: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.MultiTurnConfig + # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well enable: False @@ -140,9 +171,6 @@ multi_turn: # null for no interaction interaction_config_path: null - # null for default callback - completion_callback: null - # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. @@ -161,15 +189,25 @@ multi_turn: # Format of the multi-turn interaction. Options: hermes, llama3_json, ... format: hermes + # Number of repeat rollouts for each interaction + num_repeat_rollouts: null + # support logging rollout prob for debugging purpose +# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling calculate_log_probs: False # [Experimental] agent loop based rollout configs agent: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.AgentLoopConfig + # Number of agent loop workers num_workers: 8 + # default agent loop to use if `agent_name` not set in RL dataset + default_agent_loop: single_turn_agent + # custom agent loop config path, which should contain list of configs to intialize AgentLoop instances. # https://hydra.cc/docs/advanced/instantiate_objects/overview/ # @@ -185,6 +223,9 @@ agent: # custom async server configs custom_async_server: + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.CustomAsyncServerConfig + # Path to the custom async server implementation path: null @@ -207,9 +248,80 @@ update_weights_bucket_megabytes: 512 # trace rollout data trace: - + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.TraceConfig + # trace backend, support mlflow, weave backend: null # whether translate token id to text in output token2text: False + + # Maximum number of unique samples to trace per agent worker per training step. + # If null, all samples are traced. If set to N, each agent loop worker will randomly + # select N unique samples to trace (including all their rollouts for GRPO). + # Total traces per step = max_samples_per_step_per_worker * num_workers * n_rollouts_per_sample + max_samples_per_step_per_worker: null + +# When enabled (True), the trainer will attempt to load previously generated rollout data from the specified directory instead of computing new rollouts. +# If no cached data is found or loading fails, new rollouts will be generated and automatically saved. +# This feature is useful for debugging or when you want to reuse computation results across multiple runs. +skip_rollout: False + +# Specifies the filesystem path where rollout data should be cached when skip_rollout is enabled. +# Note: Giving path under /tmp/ray/session* is not recommended as these are temporary Ray cluster directories. +skip_dump_dir: /tmp/rollout_dump + +# Whether to skip tokenizer initialization for rollout engine +# When enabled (True), the rollout assume token in token out for generation +skip_tokenizer_init: True + +# Whether to enable rollout routing replay for MoE models +# When enabled (True), the rollout will record the routing decisions. +enable_rollout_routing_replay: False + + +# profile the rollout model in `generate_sequence` +profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # profiler tool, default same as profiler.tool in global config + # choices: nsys, npu, torch + tool: ${oc.select:global_profiler.tool,null} + + # whether enable profile on ref + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + + # Whether to profile all ranks. + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + + # The ranks that will be profiled. [] or [0,1,...] + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + + # profile results saving path + save_path: ${oc.select:global_profiler.save_path,null} + + # specific tool config + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + +# prometheus configuration for vllm/sglang server mode +prometheus: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.workers.config.PrometheusConfig + + # whether enable prometheus on server mode rollout + enable: false + + # Port number that Prometheus listens on, default is 9090 + port: 9090 + + # Path to Prometheus configuration file + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + + # Specify served_model_name to avoid displaying overly long model paths in Grafana + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + diff --git a/verl/trainer/config/sft_trainer.yaml b/src/reasoning360/trainer/config/sft_trainer.yaml similarity index 63% rename from verl/trainer/config/sft_trainer.yaml rename to src/reasoning360/trainer/config/sft_trainer.yaml index c3af1a48f..b2308e39e 100644 --- a/verl/trainer/config/sft_trainer.yaml +++ b/src/reasoning360/trainer/config/sft_trainer.yaml @@ -1,9 +1,15 @@ +defaults: + - optim: fsdp + - _self_ + data: train_batch_size: 256 micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu micro_batch_size_per_gpu: 4 # this is also val batch size train_files: ~/data/gsm8k/train.parquet val_files: ~/data/gsm8k/test.parquet + train_max_samples: -1 # set to -1 to use full dataset + val_max_samples: -1 # set to -1 to use full dataset # Single-turn settings prompt_key: question response_key: answer @@ -23,6 +29,7 @@ data: path: null name: null use_shm: False + apply_chat_template_kwargs: {} model: partial_pretrain: ~/models/gemma-1.1-7b-it use_shm: False @@ -44,7 +51,7 @@ optim: lr: 1e-5 betas: [0.9, 0.95] weight_decay: 0.01 - warmup_steps_ratio: 0.1 + lr_warmup_steps_ratio: 0.1 clip_grad: 1.0 lr_scheduler: cosine ulysses_sequence_parallel_size: 1 @@ -52,16 +59,33 @@ use_remove_padding: False trainer: default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} default_hdfs_dir: null - resume_path: null project_name: gsm8k-sft experiment_name: test total_epochs: 4 total_training_steps: null logger: [ 'console', 'wandb' ] seed: 1 - save_freq: -1 test_freq: -1 nnodes: 1 n_gpus_per_node: 8 - max_ckpt_to_keep: null # TODO + max_ckpt_to_keep: null # Maximum number of checkpoints to keep, set to null to keep all + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (used when resume_mode is "resume_path" or "auto") + resume_from_path: null + + # Checkpoint configuration + checkpoint: + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ["model", "optimizer", "extra"] + + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${trainer.checkpoint.save_contents} + device: cuda diff --git a/src/reasoning360/trainer/config/sft_trainer_engine.yaml b/src/reasoning360/trainer/config/sft_trainer_engine.yaml new file mode 100644 index 000000000..f11b3bf8f --- /dev/null +++ b/src/reasoning360/trainer/config/sft_trainer_engine.yaml @@ -0,0 +1,80 @@ +# Format checks enforced on CI: +# 1. Comments must appear above each field. +# 2. There must be a blank line between each field. +# 3. Inline comments (after a field on the same line) are not allowed. +# 4. Indentation level is respected for nested fields. + +# @.: + +defaults: + - model@model: hf_model + - engine@engine: fsdp + - optim@optim: fsdp + - _self_ + +data: + train_batch_size: 256 # global batch size + micro_batch_size_per_gpu: 4 # this is also val batch size + max_token_len_per_gpu: 8192 + use_dynamic_bsz: True + train_files: ~/data/gsm8k/train.parquet + val_files: null + train_max_samples: -1 # set to -1 to use full dataset + val_max_samples: -1 # set to -1 to use full dataset + # Multi-turn settings + messages_key: messages # Key for messages list in multi-turn mode + tools_key: tools # Key for tools list in multi-turn mode + enable_thinking_key: enable_thinking # Whether to enable thinking in multi-turn mode + pad_mode: no_padding + # for right padding + max_length: 1024 + truncation: error + balance_dp_token: False # to be implement + custom_cls: + path: null + name: null + use_shm: False + apply_chat_template_kwargs: {} + + # MultiTurnSFTDataset apply_chat_template to each turn separately and concat `input_ids` + # as a whole sequence, which may not equal to apply_chat_template to whole messages at once. + # For example, Qwen Thinking series models add tags to last turn, please check + # your tokenizer chat template settings. + # Set to True to ignore input_ids mismatch and use the concatenated input_ids as the final input_ids. + ignore_input_ids_mismatch: False + +# Checkpoint configuration +checkpoint: + _target_: verl.trainer.config.CheckpointConfig + # What to include in saved checkpoints + # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + save_contents: ["model", "optimizer", "extra"] + + # For more flexibility, you can specify the contents to load from the checkpoint. + load_contents: ${checkpoint.save_contents} + +trainer: + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + default_hdfs_dir: null + project_name: gsm8k-sft + experiment_name: test + total_epochs: 4 + total_training_steps: null + logger: [ 'console', 'wandb' ] + seed: 1 + save_freq: -1 + test_freq: -1 + max_ckpt_to_keep: null # Maximum number of checkpoints to keep, set to null to keep all + + # Resume mode: "auto", "disable", or "resume_path" + # "auto": resume from last checkpoint if available + # "disable": start from scratch + # "resume_path": resume from a user-defined path + resume_mode: auto + + # Path to resume training from (used when resume_mode is "resume_path" or "auto") + resume_from_path: null + device: cuda + + nnodes: 1 + n_gpus_per_node: 1 diff --git a/verl/trainer/main_generation.py b/src/reasoning360/trainer/main_generation.py similarity index 77% rename from verl/trainer/main_generation.py rename to src/reasoning360/trainer/main_generation.py index 394927146..943f020e3 100644 --- a/verl/trainer/main_generation.py +++ b/src/reasoning360/trainer/main_generation.py @@ -15,8 +15,8 @@ Generate responses given a dataset of prompts """ -import json import os +import json import hydra import numpy as np @@ -40,7 +40,7 @@ from verl.utils.model import compute_position_id_with_mask from verl.workers.fsdp_workers import ActorRolloutRefWorker - +# NOTE: added by Reasoning360 def merge_responses(responses): """Merge multiple response lists into one""" merged = [] @@ -49,6 +49,7 @@ def merge_responses(responses): return merged +# NOTE: added by Reasoning360 def extract_content(p): """Extract content from prompt (handle both string and list formats)""" if isinstance(p, str): @@ -60,7 +61,7 @@ def extract_content(p): return p[0].get("content", "") return str(p) - +# NOTE: added by Reasoning360 def merge_aime_responses(dataset, output_lst, prompt_key="prompt", response_key="responses"): """Merge responses for AIME dataset based on prompt content""" # Convert to pandas DataFrame if it's not already @@ -105,10 +106,13 @@ def main(config): def run_generation(config) -> None: if not ray.is_initialized(): # this is for local ray cluster - ray.init( - runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}, - num_cpus=config.ray_init.num_cpus, - ) + default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}} + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) ray.get(main_task.remote(config)) @@ -149,6 +153,10 @@ def main_task(config): assert config.data.n_samples >= 1, "n_samples should always >= 1" # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary) + # dataset = pd.read_parquet(config.data.path) + # chat_lst = dataset[config.data.prompt_key].tolist() + + # NOTE: modified by Reasoning360 is_polars_df = False if "livecodebench" in config.data.path: import polars as pl @@ -169,6 +177,8 @@ def main_task(config): chat_lst = chat_lst * config.data.n_samples ground_truth_lst = ground_truth_lst * config.data.n_samples + chat_lst = [chat.tolist() for chat in chat_lst] + tokenizer.padding_side = "left" if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token @@ -182,18 +192,12 @@ def main_task(config): ) wg.init_model() - # NOTE: updated by Reasoning360. Sample n times together + # NOTE: the following is modified by Reasoning360. total_samples = len(chat_lst) # chat_lst is repeated config_batch_size = config.data.batch_size num_batch = -(-total_samples // config_batch_size) - output_lst = [] - # total_samples = len(dataset) - # config_batch_size = config.data.batch_size - # num_batch = -(-total_samples // config_batch_size) - # output_lst = [[] for _ in range(config.data.n_samples)] - for batch_idx in range(num_batch): print(f"[{batch_idx + 1}/{num_batch}] Start to process.") batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size] @@ -298,6 +302,63 @@ def main_task(config): with open(config.data.output_path.replace(".parquet", f"_{model_name}.json"), "w", encoding="utf-8") as f: json.dump(result_list, f, indent=2, ensure_ascii=False) + # total_samples = len(dataset) + # config_batch_size = config.data.batch_size + # apply_chat_template_kwargs = config.data.get("apply_chat_template_kwargs", {}) + # num_batch = -(-total_samples // config_batch_size) + # output_lst = [[] for _ in range(config.data.n_samples)] + + # for batch_idx in range(num_batch): + # print(f"[{batch_idx + 1}/{num_batch}] Start to process.") + # batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size] + # inputs = tokenizer.apply_chat_template( + # batch_chat_lst, + # add_generation_prompt=True, + # padding=True, + # truncation=True, + # max_length=config.rollout.prompt_length, + # return_tensors="pt", + # return_dict=True, + # tokenize=True, + # **apply_chat_template_kwargs, + # ) + # input_ids = inputs["input_ids"] + # attention_mask = inputs["attention_mask"] + # position_ids = compute_position_id_with_mask(attention_mask) + # batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} + + # data = DataProto.from_dict(batch_dict) + # data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) + + # # START TO GENERATE FOR n_samples TIMES + # print(f"[{batch_idx + 1}/{num_batch}] Start to generate.") + # for n_sample in range(config.data.n_samples): + # output_padded = wg.generate_sequences(data_padded) + # output = unpad_dataproto(output_padded, pad_size=pad_size) + + # output_texts = [] + # for i in range(len(output)): + # data_item = output[i] + # prompt_length = data_item.batch["prompts"].shape[-1] + # valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + # valid_response_ids = data_item.batch["responses"][:valid_response_length] + # response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True) + # output_texts.append(response_str) + + # output_lst[n_sample].extend(output_texts) + + # # convert output_lst from (n_samples, n_data) to (n_data, n_sampels) + # output_lst = np.array(output_lst, dtype=object) + # output_lst = np.transpose(output_lst, axes=(1, 0)).tolist() + + # # add to the data frame + # dataset["responses"] = output_lst + + # # write to a new parquet + # output_dir = os.path.dirname(config.data.output_path) + # makedirs(output_dir, exist_ok=True) + # dataset.to_parquet(config.data.output_path) + if __name__ == "__main__": main() diff --git a/verl/trainer/main_ppo.py b/src/reasoning360/trainer/main_ppo.py similarity index 51% rename from verl/trainer/main_ppo.py rename to src/reasoning360/trainer/main_ppo.py index f2a1433d5..ab6604b53 100644 --- a/verl/trainer/main_ppo.py +++ b/src/reasoning360/trainer/main_ppo.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +Note that we don't combine the main with ray_trainer as ray_trainer is used by other mpain. """ import os @@ -23,11 +23,13 @@ from omegaconf import OmegaConf from verl.experimental.dataset.sampler import AbstractSampler -from verl.trainer.constants_ppo import PPO_RAY_RUNTIME_ENV -from verl.trainer.ppo.ray_trainer import RayPPOTrainer -from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.constants_ppo import get_ppo_ray_runtime_env +from reasoning360.trainer.ppo.ray_trainer import RayPPOTrainer +from reasoning360.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import need_critic, need_reference_policy +from verl.utils.config import validate_config from verl.utils.device import is_cuda_available -from verl.utils.import_utils import load_extern_type +from verl.utils.import_utils import load_extern_object @hydra.main(config_path="config", config_name="ppo_trainer", version_base=None) @@ -41,13 +43,14 @@ def main(config): # Define a function to run the PPO-like training process -def run_ppo(config) -> None: +def run_ppo(config, task_runner_class=None) -> None: """Initialize Ray cluster and run distributed PPO training process. Args: config: Training configuration object containing all necessary parameters for distributed PPO training including Ray initialization settings, model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. """ # Check if Ray is not initialized if not ray.is_initialized(): @@ -55,79 +58,96 @@ def run_ppo(config) -> None: # Set environment variables in the runtime environment to control tokenizer parallelism, # NCCL debug level, VLLM logging level, and allow runtime LoRA updating # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration - ray.init( - runtime_env=PPO_RAY_RUNTIME_ENV, - num_cpus=config.ray_init.num_cpus, - ) + default_runtime_env = get_ppo_ray_runtime_env() + ray_init_kwargs = config.ray_kwargs.get("ray_init", {}) + runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {}) + + if config.transfer_queue.enable: + # Add runtime environment variables for transfer queue + runtime_env_vars = runtime_env_kwargs.get("env_vars", {}) + runtime_env_vars["TRANSFER_QUEUE_ENABLE"] = "1" + runtime_env_kwargs["env_vars"] = runtime_env_vars + + runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs) + ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env}) + print(f"ray init kwargs: {ray_init_kwargs}") + ray.init(**OmegaConf.to_container(ray_init_kwargs)) + + if task_runner_class is None: + task_runner_class = ray.remote(num_cpus=1)(TaskRunner) # please make sure main_task is not scheduled on head # Create a remote instance of the TaskRunner class, and # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete if ( is_cuda_available - and config.trainer.get("profile_steps") is not None - and len(config.trainer.get("profile_steps", [])) > 0 + and config.global_profiler.tool == "nsys" + and config.global_profiler.get("steps") is not None + and len(config.global_profiler.get("steps", [])) > 0 ): - nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options) - runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + from verl.utils.import_utils import is_nvtx_available + + assert is_nvtx_available(), "nvtx is not available in CUDA platform. Please 'pip3 install nvtx'" + nsight_options = OmegaConf.to_container( + config.global_profiler.global_tool_config.nsys.controller_nsight_options + ) + runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() else: - runner = TaskRunner.remote() + runner = task_runner_class.remote() ray.get(runner.run.remote(config)) # [Optional] get the path of the timeline trace file from the configuration, default to None # This file is used for performance analysis - timeline_json_file = config.ray_init.get("timeline_json_file", None) + timeline_json_file = config.ray_kwargs.get("timeline_json_file", None) if timeline_json_file: ray.timeline(filename=timeline_json_file) -@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: """Ray remote class for executing distributed PPO training tasks. This class encapsulates the main training logic and runs as a Ray remote actor to enable distributed execution across multiple nodes and GPUs. + + Attributes: + role_worker_mapping: Dictionary mapping Role enums to Ray remote worker classes + mapping: Dictionary mapping Role enums to resource pool IDs for GPU allocation """ - def run(self, config): - """Execute the main PPO training workflow. + def __init__(self): + self.role_worker_mapping = {} + self.mapping = {} - This method sets up the distributed training environment, initializes - workers, datasets, and reward functions, then starts the training process. + def add_actor_rollout_worker(self, config): + """Add actor rollout worker based on the actor strategy.""" + from verl.single_controller.ray import RayWorkerGroup + from verl.trainer.ppo.ray_trainer import Role - Args: - config: Training configuration object containing all parameters needed - for setting up and running the PPO training process. - """ - # Print the initial configuration. `resolve=True` will evaluate symbolic values. - from pprint import pprint - - from omegaconf import OmegaConf + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") - from verl.utils.fs import copy_to_local - - print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") - pprint(OmegaConf.to_container(config, resolve=True)) - OmegaConf.resolve(config) + # use new model engine implementation + if use_legacy_worker_impl == "disable": + from verl.workers.engine_workers import ActorRolloutRefWorker - # Download the checkpoint from HDFS to the local machine. - # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on - local_path = copy_to_local( - config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) - ) - - # Instantiate the tokenizer and processor. - from verl.utils import hf_processor, hf_tokenizer - - trust_remote_code = config.data.get("trust_remote_code", False) - tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - # Used for multimodal LLM, could be None - processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + actor_rollout_cls = ActorRolloutRefWorker + ray_worker_group_cls = RayWorkerGroup + # NOTE: In new model engine, ref policy and actor rollout are in same ActorRolloutRefWorker, + # while in legacy model engine, ref policy is in a separate ActorRolloutRefWorker. + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role = Role.ActorRolloutRef + else: + role = Role.ActorRollout + self.role_worker_mapping[role] = ray.remote(actor_rollout_cls) + self.mapping[role] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls + + if config.actor_rollout_ref.rollout.mode == "sync": + raise ValueError( + "Rollout mode 'sync' has been removed. Please set " + "`actor_rollout_ref.rollout.mode=async` to use the native server rollout." + ) - # Define worker classes based on the actor strategy. if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}: - assert config.critic.strategy in {"fsdp", "fsdp2"} - from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker actor_rollout_cls = ( AsyncActorRolloutRefWorker @@ -137,38 +157,131 @@ def run(self, config): ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == "megatron": - assert config.actor_rollout_ref.actor.strategy == config.critic.strategy - from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker actor_rollout_cls = ( AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker ) - ray_worker_group_cls = NVMegatronRayWorkerGroup + ray_worker_group_cls = RayWorkerGroup else: raise NotImplementedError - from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + self.role_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls) + self.mapping[Role.ActorRollout] = "global_pool" + return actor_rollout_cls, ray_worker_group_cls - # Map roles to their corresponding remote worker classes. - role_worker_mapping = { - Role.ActorRollout: ray.remote(actor_rollout_cls), - Role.Critic: ray.remote(CriticWorker), - } + def add_critic_worker(self, config): + """Add critic worker to role mapping.""" + if config.critic.strategy in {"fsdp", "fsdp2"}: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable"]: + from verl.workers.fsdp_workers import CriticWorker + elif use_legacy_worker_impl == "disable": + from verl.workers.engine_workers import CriticWorker + + print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + elif config.critic.strategy == "megatron": + from verl.workers.megatron_workers import CriticWorker + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import Role + + self.role_worker_mapping[Role.Critic] = ray.remote(CriticWorker) + self.mapping[Role.Critic] = "global_pool" + + def init_resource_pool_mgr(self, config): + """Initialize resource pool manager.""" - # Define the resource pool specification. - # Map roles to the resource pool. global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, } - mapping = { - Role.ActorRollout: global_pool_id, - Role.Critic: global_pool_id, - } + # TODO Here you can use the new registration method to support dynamic registration of roles + if config.reward_model.enable_resource_pool: + if config.reward_model.n_gpus_per_node <= 0: + raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0") + if config.reward_model.nnodes <= 0: + raise ValueError("config.reward_model.nnodes must be greater than 0") + + reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes + resource_pool_spec["reward_pool"] = reward_pool + + from reasoning360.trainer.ppo.ray_trainer import ResourcePoolManager + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping) + return resource_pool_manager + + def add_reward_model_worker(self, config): + """Add reward model worker if enabled.""" + from verl.trainer.ppo.ray_trainer import Role + + if config.reward_model.enable: + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl in ["auto", "enable", "disable"]: + if config.reward_model.strategy in {"fsdp", "fsdp2"}: + from verl.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == "megatron": + from verl.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + # elif use_legacy_worker_impl == "disable": + # from verl.workers.engine_workers import RewardModelWorker + # + # print("Using new worker implementation") + else: + raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}") + + self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + if config.reward_model.enable_resource_pool: + self.mapping[Role.RewardModel] = "reward_pool" + else: + self.mapping[Role.RewardModel] = "global_pool" + + def add_ref_policy_worker(self, config, ref_policy_cls): + """Add reference policy worker if KL loss or KL reward is used.""" + from verl.trainer.ppo.ray_trainer import Role + + # Ref policy has been fused into ActorRolloutRefWorker in new model engine, + # we don't need to add a separate ref policy worker group. + use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto") + if use_legacy_worker_impl == "disable": + return + + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + self.role_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) + self.mapping[Role.RefPolicy] = "global_pool" + + def run(self, config): + """Execute the main PPO training workflow. + + This method sets up the distributed training environment, initializes + workers, datasets, and reward functions, then starts the training process. + + Args: + config: Training configuration object containing all parameters needed + for setting up and running the PPO training process. + """ + # Print the initial configuration. `resolve=True` will evaluate symbolic values. + from pprint import pprint + + from omegaconf import OmegaConf + + from verl.utils.fs import copy_to_local + + print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config) + self.add_critic_worker(config) # We should adopt a multi-source reward function here: # - for rule-based rm, we directly call a reward score @@ -176,20 +289,31 @@ def run(self, config): # - for code related prompt, we send to a sandbox if there are test cases # finally, we combine all the rewards together # The reward type depends on the tag of the data - if config.reward_model.enable: - if config.reward_model.strategy in {"fsdp", "fsdp2"}: - from verl.workers.fsdp_workers import RewardModelWorker - elif config.reward_model.strategy == "megatron": - from verl.workers.megatron_workers import RewardModelWorker - else: - raise NotImplementedError - role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) - mapping[Role.RewardModel] = global_pool_id + self.add_reward_model_worker(config) # Add a reference policy worker if KL loss or KL reward is used. - if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: - role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker) - mapping[Role.RefPolicy] = global_pool_id + self.add_ref_policy_worker(config, actor_rollout_cls) + + # validate config + validate_config( + config=config, + use_reference_policy=need_reference_policy(self.role_worker_mapping), + use_critic=need_critic(config), + ) + + # Download the checkpoint from HDFS to the local machine. + # `use_shm` determines whether to use shared memory, which could lead to faster model loading if turned on + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + + # Instantiate the tokenizer and processor. + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) # Load the reward manager for training and validation. reward_fn = load_reward_manager( @@ -198,13 +322,28 @@ def run(self, config): val_reward_fn = load_reward_manager( config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) ) - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - from verl.utils.dataset.rl_dataset import collate_fn + resource_pool_manager = self.init_resource_pool_mgr(config) + + from reasoning360.utils.dataset.rl_dataset import collate_fn # Create training and validation datasets. - train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor, is_train=True) - val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor, is_train=False) + train_dataset = create_rl_dataset( + config.data.train_files, + config.data, + tokenizer, + processor, + is_train=True, + max_samples=config.data.get("train_max_samples", -1), + ) + val_dataset = create_rl_dataset( + config.data.val_files, + config.data, + tokenizer, + processor, + is_train=False, + max_samples=config.data.get("val_max_samples", -1), + ) train_sampler = create_rl_sampler(config.data, train_dataset) # Initialize the PPO trainer. @@ -212,7 +351,7 @@ def run(self, config): config=config, tokenizer=tokenizer, processor=processor, - role_worker_mapping=role_worker_mapping, + role_worker_mapping=self.role_worker_mapping, resource_pool_manager=resource_pool_manager, ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, @@ -221,15 +360,15 @@ def run(self, config): val_dataset=val_dataset, collate_fn=collate_fn, train_sampler=train_sampler, - device_name=config.trainer.device, ) # Initialize the workers of the trainer. trainer.init_workers() + # Start the training process. trainer.fit() -def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True): +def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=True, max_samples: int = -1): """Create a dataset. Arguments: @@ -243,13 +382,13 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr """ from torch.utils.data import Dataset - from verl.utils.dataset.rl_dataset import RLHFDataset + from reasoning360.utils.dataset.rl_dataset import RLHFDataset # Check if a custom dataset class is specified in the data configuration # and if the path to the custom class is provided if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None: # Dynamically load the custom dataset class - dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) + dataset_cls = load_extern_object(data_config.custom_cls.path, data_config.custom_cls.name) # Verify that the custom dataset class inherits from torch.utils.data.Dataset if not issubclass(dataset_cls, Dataset): raise TypeError( @@ -262,7 +401,6 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr dataset_cls = DynamicGenDataset print("Using DynamicGenDataset for data generation.") - else: # Use the default RLHFDataset class if no custom class is specified dataset_cls = RLHFDataset @@ -274,6 +412,7 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr tokenizer=tokenizer, processor=processor, config=data_config, + max_samples=max_samples, ) return dataset @@ -290,10 +429,13 @@ def create_rl_sampler(data_config, dataset): sampler (Sampler): The sampler. """ import torch - from torch.utils.data import RandomSampler, SequentialSampler + from torch.utils.data import SequentialSampler + + # torch.utils.data.RandomSampler could not recover properly + from torchdata.stateful_dataloader.sampler import RandomSampler if data_config.sampler is not None and data_config.sampler.get("class_path", None) is not None: - curriculum_class = load_extern_type( + curriculum_class = load_extern_object( data_config.sampler.class_path, data_config.sampler.class_name, ) @@ -312,7 +454,9 @@ def create_rl_sampler(data_config, dataset): # If shuffling is enabled in the data configuration, create a random sampler. elif data_config.shuffle: train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(data_config.get("seed", 1)) + seed = data_config.get("seed") + if seed is not None: + train_dataloader_generator.manual_seed(seed) sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator) else: # If shuffling is disabled, use a sequential sampler to iterate through the dataset in order. diff --git a/verl/trainer/ppo/metric_utils.py b/src/reasoning360/trainer/ppo/metric_utils.py similarity index 79% rename from verl/trainer/ppo/metric_utils.py rename to src/reasoning360/trainer/ppo/metric_utils.py index ed691654a..85dd968c2 100644 --- a/verl/trainer/ppo/metric_utils.py +++ b/src/reasoning360/trainer/ppo/metric_utils.py @@ -26,9 +26,11 @@ from verl import DataProto from verl.utils.import_utils import deprecated + # NOTE: added by Reasoning360. _scores_tables = {} # Global dictionary to store wandb tables + @deprecated("verl.utils.metric.reduce_metrics") def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: """ @@ -121,6 +123,20 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, prompt_length = response_info["prompt_length"] response_length = response_info["response_length"] + aborted_mask = (response_length == 0).bool() + non_aborted_mask = ~aborted_mask + + non_aborted_sequence_score = sequence_score[non_aborted_mask] + non_aborted_sequence_reward = sequence_reward[non_aborted_mask] + + score_mean = torch.mean(non_aborted_sequence_score).detach().item() + score_max = torch.max(non_aborted_sequence_score).detach().item() + score_min = torch.min(non_aborted_sequence_score).detach().item() + + reward_mean = torch.mean(non_aborted_sequence_reward).detach().item() + reward_max = torch.max(non_aborted_sequence_reward).detach().item() + reward_min = torch.min(non_aborted_sequence_reward).detach().item() + valid_adv = torch.masked_select(advantages, response_mask) valid_returns = torch.masked_select(returns, response_mask) @@ -133,19 +149,39 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, # NOTE: added by Reasoning360. Group response lengths and rewards by data source data_source_response_lengths = defaultdict(list) data_source_scores = defaultdict(list) - for i, data_source in enumerate(batch.non_tensor_batch['data_source']): + + # `data_source` is optional; metrics must never crash training if it's missing. + # data_sources = batch.non_tensor_batch.get("data_source", None) + data_sources = batch.non_tensor_batch["data_source"] + # if data_sources is not None: + for i, data_source in enumerate(data_sources): data_source_response_lengths[data_source].append(response_length[i].item()) data_source_scores[data_source].append(sequence_score[i].item()) + # Aborted samples and non-aborted response length statistics + # response_length_non_aborted/*: statistics computed on non-aborted samples only + aborted_ratio = torch.mean(aborted_mask.float()).detach().item() + + non_aborted_response_length = response_length[non_aborted_mask] + if non_aborted_response_length.numel() > 0: + non_aborted_response_length_mean = torch.mean(non_aborted_response_length).detach().item() + non_aborted_response_length_max = torch.max(non_aborted_response_length).detach().item() + non_aborted_response_length_min = torch.min(non_aborted_response_length).detach().item() + non_aborted_response_length_clip_ratio = ( + torch.mean(torch.eq(non_aborted_response_length, max_response_length).float()).detach().item() + ) + else: + raise ValueError("All samples are aborted, this should not happen.") + metrics = { # score - "critic/score/mean": torch.mean(sequence_score).detach().item(), - "critic/score/max": torch.max(sequence_score).detach().item(), - "critic/score/min": torch.min(sequence_score).detach().item(), + "critic/score/mean": score_mean, + "critic/score/max": score_max, + "critic/score/min": score_min, # reward - "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), - "critic/rewards/max": torch.max(sequence_reward).detach().item(), - "critic/rewards/min": torch.min(sequence_reward).detach().item(), + "critic/rewards/mean": reward_mean, + "critic/rewards/max": reward_max, + "critic/rewards/min": reward_min, # adv "critic/advantages/mean": torch.mean(valid_adv).detach().item(), "critic/advantages/max": torch.max(valid_adv).detach().item(), @@ -173,6 +209,15 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, "response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float()) .detach() .item(), + # response length (non-aborted only) + # These statistics exclude aborted samples to avoid skew from zeros + "response_length_non_aborted/mean": non_aborted_response_length_mean, + "response_length_non_aborted/max": non_aborted_response_length_max, + "response_length_non_aborted/min": non_aborted_response_length_min, + "response_length_non_aborted/clip_ratio": non_aborted_response_length_clip_ratio, + # aborted ratio + # Fraction of samples whose response length is zero + "response/aborted_ratio": aborted_ratio, # prompt length "prompt_length/mean": torch.mean(prompt_length).detach().item(), "prompt_length/max": torch.max(prompt_length).detach().item(), @@ -187,25 +232,37 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True) -> dict[str, metrics["num_turns/max"] = num_turns.max() metrics["num_turns/mean"] = num_turns.mean() - # Add data source specific response length metrics + if "tool_call_counts" in batch.non_tensor_batch: + tool_call_counts = batch.non_tensor_batch["tool_call_counts"] + metrics["tool_call_counts/min"] = tool_call_counts.min() + metrics["tool_call_counts/max"] = tool_call_counts.max() + metrics["tool_call_counts/mean"] = tool_call_counts.mean() + + # NOTE: added by Reasoning360. Add data source specific response length metrics for data_source, lengths in data_source_response_lengths.items(): lengths_tensor = torch.tensor(lengths) - metrics.update({ - f"response_length/{data_source}/mean": torch.mean(lengths_tensor).item(), - f"response_length/{data_source}/max": torch.max(lengths_tensor).item(), - f"response_length/{data_source}/min": torch.min(lengths_tensor).item(), - f"response_length/{data_source}/clip_ratio": torch.mean(torch.eq(lengths_tensor, max_response_length).float()).item(), - }) - - # Add data source specific reward metrics + metrics.update( + { + f"response_length/{data_source}/mean": torch.mean(lengths_tensor).item(), + f"response_length/{data_source}/max": torch.max(lengths_tensor).item(), + f"response_length/{data_source}/min": torch.min(lengths_tensor).item(), + f"response_length/{data_source}/clip_ratio": torch.mean( + torch.eq(lengths_tensor, max_response_length).float() + ).item(), + } + ) + + # NOTE: added by Reasoning360. Add data source specific reward metrics for data_source, scores in data_source_scores.items(): scores_tensor = torch.tensor(scores) - metrics.update({ - f"critic/scores/{data_source}/mean": torch.mean(scores_tensor).item(), - f"critic/scores/{data_source}/max": torch.max(scores_tensor).item(), - f"critic/scores/{data_source}/min": torch.min(scores_tensor).item(), - f"critic/scores/{data_source}/std": torch.std(scores_tensor).item(), - }) + metrics.update( + { + f"critic/scores/{data_source}/mean": torch.mean(scores_tensor).item(), + f"critic/scores/{data_source}/max": torch.max(scores_tensor).item(), + f"critic/scores/{data_source}/min": torch.min(scores_tensor).item(), + f"critic/scores/{data_source}/std": torch.std(scores_tensor).item(), + } + ) return metrics @@ -366,7 +423,7 @@ def calc_maj_val(data: list[dict[str, Any]], vote_key: str, val_key: str) -> flo def process_validation_metrics( - data_sources: list[str], sample_inputs: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 + data_sources: list[str], sample_uids: list[str], infos_dict: dict[str, list[Any]], seed: int = 42 ) -> dict[str, dict[str, dict[str, float]]]: """ Process validation metrics into a structured format with statistical analysis. @@ -375,9 +432,10 @@ def process_validation_metrics( various statistical measures including means, standard deviations, best/worst values, and majority voting results. It also performs bootstrap sampling to estimate statistics for different sample sizes. + Args: data_sources: List of data source identifiers for each sample. - sample_inputs: List of input prompts corresponding to each sample. + sample_uids: List of sample uids corresponding to each sample. infos_dict: Dictionary mapping variable names to lists of values for each sample. seed: Random seed for bootstrap sampling. Defaults to 42. @@ -403,23 +461,23 @@ def process_validation_metrics( Example: >>> data_sources = ["source1", "source1", "source2"] - >>> sample_inputs = ["prompt1", "prompt1", "prompt2"] + >>> sample_uids = ["uid1", "uid1", "uid2"] >>> infos_dict = {"score": [0.8, 0.9, 0.7], "pred": ["A", "A", "B"]} - >>> result = process_validation_metrics(data_sources, sample_inputs, infos_dict) + >>> result = process_validation_metrics(data_sources, sample_uids, infos_dict) >>> # result will contain statistics for each data source and variable """ # Group metrics by data source, prompt and variable - data_src2prompt2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + data_src2uid2var2vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) for sample_idx, data_source in enumerate(data_sources): - prompt = sample_inputs[sample_idx] - var2vals = data_src2prompt2var2vals[data_source][prompt] + uid = sample_uids[sample_idx] + var2vals = data_src2uid2var2vals[data_source][uid] for var_name, var_vals in infos_dict.items(): var2vals[var_name].append(var_vals[sample_idx]) # Calculate metrics for each group - data_src2prompt2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) - for data_source, prompt2var2vals in data_src2prompt2var2vals.items(): - for prompt, var2vals in prompt2var2vals.items(): + data_src2uid2var2metric = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + for data_source, uid2var2vals in data_src2uid2var2vals.items(): + for uid, var2vals in uid2var2vals.items(): for var_name, var_vals in var2vals.items(): if isinstance(var_vals[0], str): continue @@ -456,21 +514,21 @@ def process_validation_metrics( ) metric[f"maj@{n}/mean"], metric[f"maj@{n}/std"] = maj_n_mean, maj_n_std - data_src2prompt2var2metric[data_source][prompt][var_name] = metric + data_src2uid2var2metric[data_source][uid][var_name] = metric - # Aggregate metrics across prompts - data_src2var2metric2prompt_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - for data_source, prompt2var2metric in data_src2prompt2var2metric.items(): - for prompt, var2metric in prompt2var2metric.items(): + # Aggregate metrics across uids + data_src2var2metric2uid_vals = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for data_source, uid2var2metric in data_src2uid2var2metric.items(): + for uid, var2metric in uid2var2metric.items(): for var_name, metric in var2metric.items(): for metric_name, metric_val in metric.items(): - data_src2var2metric2prompt_vals[data_source][var_name][metric_name].append(metric_val) + data_src2var2metric2uid_vals[data_source][var_name][metric_name].append(metric_val) data_src2var2metric2val = defaultdict(lambda: defaultdict(lambda: defaultdict(float))) - for data_source, var2metric2prompt_vals in data_src2var2metric2prompt_vals.items(): - for var_name, metric2prompt_vals in var2metric2prompt_vals.items(): - for metric_name, prompt_vals in metric2prompt_vals.items(): - data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(prompt_vals) + for data_source, var2metric2uid_vals in data_src2var2metric2uid_vals.items(): + for var_name, metric2uid_vals in var2metric2uid_vals.items(): + for metric_name, uid_vals in metric2uid_vals.items(): + data_src2var2metric2val[data_source][var_name][metric_name] = np.mean(uid_vals) return data_src2var2metric2val diff --git a/verl/trainer/ppo/ray_trainer.py b/src/reasoning360/trainer/ppo/ray_trainer.py similarity index 66% rename from verl/trainer/ppo/ray_trainer.py rename to src/reasoning360/trainer/ppo/ray_trainer.py index 84b6753f6..b122d3e1c 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/src/reasoning360/trainer/ppo/ray_trainer.py @@ -24,7 +24,6 @@ from collections import defaultdict from copy import deepcopy from dataclasses import dataclass, field -from enum import Enum from pprint import pprint from typing import Optional @@ -39,45 +38,29 @@ from verl import DataProto from verl.experimental.dataset.sampler import AbstractCurriculumSampler from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto -from verl.single_controller.base import Worker from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.config import AlgoConfig from verl.trainer.ppo import core_algos from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss -from verl.trainer.ppo.metric_utils import ( +from reasoning360.trainer.ppo.metric_utils import ( compute_data_metrics, - compute_difficulty_histogram_metrics, # NOTE: added by Reasoning360 compute_throughout_metrics, + compute_difficulty_histogram_metrics, # NOTE: added by Reasoning360 compute_timing_metrics, process_validation_metrics, ) -from verl.trainer.ppo.reward import compute_reward, compute_reward_async +from reasoning360.trainer.ppo.reward import compute_reward, compute_reward_async +from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi +from verl.utils.config import omega_conf_to_dataclass from verl.utils.debug import marked_timer -from verl.utils.metric import ( - reduce_metrics, -) -from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance +from verl.utils.metric import reduce_metrics +from verl.utils.rollout_skip import RolloutSkip +from verl.utils.seqlen_balancing import calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger -WorkerType = type[Worker] - - -class Role(Enum): - """ - To create more roles dynamically, you can subclass Role and add new members - """ - - Actor = 0 - Rollout = 1 - ActorRollout = 2 - Critic = 3 - RefPolicy = 4 - RewardModel = 5 - ActorRolloutRef = 6 - @dataclass class ResourcePoolManager: @@ -99,11 +82,11 @@ def create_resource_pool(self): """ for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): # max_colocate_count means the number of WorkerGroups (i.e. processes) in each RayResourcePool - # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. + # For FSDP backend, using max_colocate_count=3: actor_critic_ref, rollout, reward model (optional) # For Megatron backend, we recommend using max_colocate_count>1 # that can utilize different WorkerGroup for differnt models resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=3, name_prefix=resource_pool_name ) self.resource_pool_dict[resource_pool_name] = resource_pool @@ -119,7 +102,7 @@ def get_n_gpus(self) -> int: def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" - node_available_resources = ray.state.available_resources_per_node() + node_available_resources = ray._private.state.available_resources_per_node() node_available_gpus = { node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) for node, node_info in node_available_resources.items() @@ -135,21 +118,6 @@ def _check_resource_available(self): f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}" ) - # check each resource pool can be satisfied, O(#resource_pools * #nodes) - for resource_pool_name, process_on_nodes in self.resource_pool_spec.items(): - num_gpus, num_nodes = process_on_nodes[0], len(process_on_nodes) - for node, available_gpus in node_available_gpus.items(): - if available_gpus >= num_gpus: - node_available_gpus[node] -= num_gpus - num_nodes -= 1 - if num_nodes == 0: - break - if num_nodes > 0: - raise ValueError( - f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}" - + "cannot be satisfied in this ray cluster" - ) - def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): """Apply KL penalty to the token-level rewards. @@ -161,7 +129,6 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, data (DataProto): The data containing batched model outputs and inputs. kl_ctrl (core_algos.AdaptiveKLController): Controller for adaptive KL penalty. kl_penalty (str, optional): Type of KL penalty to apply. Defaults to "kl". - multi_turn (bool, optional): Whether the data is from a multi-turn conversation. Defaults to False. Returns: tuple: A tuple containing: @@ -257,12 +224,13 @@ def compute_advantage( if config.get("use_pf_ppo", False): data = core_algos.compute_pf_ppo_reweight_data( data, - config.pf_ppo.reweight_method, - config.pf_ppo.weight_pow, + config.pf_ppo.get("reweight_method"), + config.pf_ppo.get("weight_pow"), ) elif adv_estimator == AdvantageEstimator.GRPO: # Initialize the mask for GRPO calculation grpo_calculation_mask = data.batch["response_mask"] + # Call compute_grpo_outcome_advantage with parameters matching its definition advantages, returns = core_algos.compute_grpo_outcome_advantage( token_level_rewards=data.batch["token_level_rewards"], @@ -297,7 +265,7 @@ class RayPPOTrainer: This trainer orchestrates distributed PPO training across multiple nodes and GPUs, managing actor rollouts, critic training, and reward computation with Ray backend. - Supports various model architectures including FSDP, Megatron, and vLLM integration. + Supports various model architectures including FSDP, Megatron, vLLM, and SGLang integration. """ # TODO: support each role have individual ray_worker_group_cls, @@ -308,7 +276,7 @@ def __init__( tokenizer, role_worker_mapping: dict[Role, WorkerType], resource_pool_manager: ResourcePoolManager, - ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup, processor=None, reward_fn=None, val_reward_fn=None, @@ -316,7 +284,7 @@ def __init__( val_dataset: Optional[Dataset] = None, collate_fn=None, train_sampler: Optional[Sampler] = None, - device_name="cuda", + device_name=None, ): """ Initialize distributed PPO trainer with Ray backend. @@ -335,7 +303,7 @@ def __init__( val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None. collate_fn: Function to collate data samples into batches. train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None. - device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda". + device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to None. """ # Store the tokenizer for text processing @@ -349,231 +317,64 @@ def __init__( assert self.hybrid_engine, "Currently, only support hybrid engine" if self.hybrid_engine: - assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}" + assert Role.ActorRollout in role_worker_mapping or Role.ActorRolloutRef in role_worker_mapping, ( + f"{role_worker_mapping.keys()=}" + ) self.role_worker_mapping = role_worker_mapping self.resource_pool_manager = resource_pool_manager - self.use_reference_policy = Role.RefPolicy in role_worker_mapping - self.use_rm = Role.RewardModel in role_worker_mapping + self.use_reference_policy = need_reference_policy(self.role_worker_mapping) + self.use_rm = need_reward_model(self.role_worker_mapping) + self.use_critic = need_critic(self.config) self.ray_worker_group_cls = ray_worker_group_cls - self.device_name = device_name - self.validation_generations_logger = ValidationGenerationsLogger() + self.device_name = device_name if device_name else self.config.trainer.device + self.validation_generations_logger = ValidationGenerationsLogger( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + ) # if ref_in_actor is True, the reference policy will be actor without lora applied - self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + self.ref_in_actor = ( + config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + or config.actor_rollout_ref.model.get("lora_adapter_path") is not None + ) # define in-reward KL control # kl loss control currently not suppoorted if self.config.algorithm.use_kl_in_reward: self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) - if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE: - self.use_critic = True - elif self.config.algorithm.adv_estimator in [ - AdvantageEstimator.GRPO, - AdvantageEstimator.GRPO_PASSK, - AdvantageEstimator.REINFORCE_PLUS_PLUS, - AdvantageEstimator.REMAX, - AdvantageEstimator.RLOO, - AdvantageEstimator.OPO, - AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE, - AdvantageEstimator.GPG, - ]: - self.use_critic = False - else: - raise NotImplementedError - - self._validate_config() self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler) - def _validate_config(self): - config = self.config - # number of GPUs total - n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes - if config.actor_rollout_ref.actor.strategy == "megatron": - model_parallel_size = ( - config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size - * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size - ) - assert ( - n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0 - ), ( - f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times " - f"context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})" - ) - megatron_dp = n_gpus // ( - model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size - ) - minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu - else: - minimal_bsz = n_gpus - - # 1. Check total batch size for data correctness - real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n - assert real_train_batch_size % minimal_bsz == 0, ( - f"real_train_batch_size ({real_train_batch_size}) must be divisible by minimal possible batch size " - f"({minimal_bsz})" - ) - - # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" - # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". - def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): - """Validate mutually exclusive micro batch size configuration options. - - Ensures that users don't set both deprecated micro_batch_size and - the new micro_batch_size_per_gpu parameters simultaneously. - - Args: - mbs: Deprecated micro batch size parameter value. - mbs_per_gpu: New micro batch size per GPU parameter value. - name (str): Configuration section name for error messages. - - Raises: - ValueError: If both parameters are set or neither is set. - """ - settings = { - "actor_rollout_ref.actor": "micro_batch_size", - "critic": "micro_batch_size", - "reward_model": "micro_batch_size", - "actor_rollout_ref.ref": "log_prob_micro_batch_size", - "actor_rollout_ref.rollout": "log_prob_micro_batch_size", - } - - if name in settings: - param = settings[name] - param_per_gpu = f"{param}_per_gpu" - - if mbs is None and mbs_per_gpu is None: - raise ValueError( - f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'." - ) - - if mbs is not None and mbs_per_gpu is not None: - raise ValueError( - f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove " - f"'{name}.{param}' because only '*_{param_per_gpu}' is supported (the former is deprecated)." - ) - - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.actor.ppo_micro_batch_size, - config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "actor_rollout_ref.actor", - ) - - if self.use_reference_policy: - # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.ref.log_prob_micro_batch_size, - config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.ref", - ) - - # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu - check_mutually_exclusive( - config.actor_rollout_ref.rollout.log_prob_micro_batch_size, - config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, - "actor_rollout_ref.rollout", - ) - - if self.use_critic and not config.critic.use_dynamic_bsz: - # Check for critic micro-batch size conflicts - check_mutually_exclusive( - config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic" - ) - - # Check for reward model micro-batch size conflicts - if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: - check_mutually_exclusive( - config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model" - ) - - # Actor - # check if train_batch_size is larger than ppo_mini_batch_size - # if NOT dynamic_bsz, we must ensure: - # ppo_mini_batch_size is divisible by ppo_micro_batch_size - # ppo_micro_batch_size * sequence_parallel_size >= n_gpus - if not config.actor_rollout_ref.actor.use_dynamic_bsz: - assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size - sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) - if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: - assert ( - config.actor_rollout_ref.actor.ppo_mini_batch_size - % config.actor_rollout_ref.actor.ppo_micro_batch_size - == 0 - ) - assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus - - assert config.actor_rollout_ref.actor.loss_agg_mode in [ - "token-mean", - "seq-mean-token-sum", - "seq-mean-token-mean", - "seq-mean-token-sum-norm", - ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}" - - if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss: - print("NOTICE: You have both enabled in-reward kl and kl loss.") - - # critic - if self.use_critic and not config.critic.use_dynamic_bsz: - assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size - sp_size = config.critic.get("ulysses_sequence_parallel_size", 1) - if config.critic.ppo_micro_batch_size is not None: - assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 - assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus - - # Check if use_remove_padding is enabled when using sequence parallelism for fsdp - if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"} and ( - config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 - or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1 - ): - assert config.actor_rollout_ref.model.use_remove_padding, ( - "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`." - ) - - if self.use_critic and config.critic.strategy in {"fsdp", "fsdp2"}: - if config.critic.get("ulysses_sequence_parallel_size", 1) > 1: - assert config.critic.model.use_remove_padding, ( - "When using sequence parallelism for critic, you must enable `use_remove_padding`." - ) - - if config.data.get("val_batch_size", None) is not None: - print( - "WARNING: val_batch_size is deprecated." - + " Validation datasets are sent to inference engines as a whole batch," - + " which will schedule the memory themselves." - ) - - # check eval config - if config.actor_rollout_ref.rollout.val_kwargs.do_sample: - assert config.actor_rollout_ref.rollout.temperature > 0, ( - "validation gen temperature should be greater than 0 when enabling do_sample" - ) - - print("[validate_config] All configuration checks passed successfully!") - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]): """ Creates the train and validation dataloaders. """ # TODO: we have to make sure the batch size is divisible by the dp size - from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler + from reasoning360.trainer.main_ppo import create_rl_dataset, create_rl_sampler if train_dataset is None: train_dataset = create_rl_dataset( - self.config.data.train_files, self.config.data, self.tokenizer, self.processor + self.config.data.train_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("train_max_samples", -1), ) if val_dataset is None: val_dataset = create_rl_dataset( - self.config.data.val_files, self.config.data, self.tokenizer, self.processor + self.config.data.val_files, + self.config.data, + self.tokenizer, + self.processor, + max_samples=self.config.data.get("val_max_samples", -1), ) self.train_dataset, self.val_dataset = train_dataset, val_dataset if train_sampler is None: train_sampler = create_rl_sampler(self.config.data, self.train_dataset) if collate_fn is None: - from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn + from reasoning360.utils.dataset.rl_dataset import collate_fn as default_collate_fn collate_fn = default_collate_fn @@ -627,7 +428,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl except Exception as e: print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") - def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path): + def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path): """Dump rollout/validation samples as JSONL.""" os.makedirs(dump_path, exist_ok=True) filename = os.path.join(dump_path, f"{self.global_steps}.jsonl") @@ -636,6 +437,7 @@ def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, du base_data = { "input": inputs, "output": outputs, + "gts": gts, "score": scores, "step": [self.global_steps] * n, } @@ -654,6 +456,38 @@ def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, du print(f"Dumped generations to {filename}") + def _log_rollout_data( + self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str + ): + """Log rollout data to disk. + Args: + batch (DataProto): The batch containing rollout data + reward_extra_infos_dict (dict): Additional reward information to log + timing_raw (dict): Timing information for profiling + rollout_data_dir (str): Directory path to save the rollout data + """ + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch] + + reward_extra_infos_to_dump = reward_extra_infos_dict.copy() + if "request_id" in batch.non_tensor_batch: + reward_extra_infos_dict.setdefault( + "request_id", + batch.non_tensor_batch["request_id"].tolist(), + ) + + self._dump_generations( + inputs=inputs, + outputs=outputs, + gts=sample_gts, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_to_dump, + dump_path=rollout_data_dir, + ) + def _maybe_log_val_generations(self, inputs, outputs, scores): """Log a table of validation samples to the configured logger (wandb or swanlab)""" @@ -678,21 +512,45 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): # Log to each configured logger self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) + def _get_gen_batch(self, batch: DataProto) -> DataProto: + reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys() + + # pop those keys for generation + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys + gen_batch = batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop), + ) + + # For agent loop, we need reward model keys to compute score. + if self.async_rollout_mode: + gen_batch.non_tensor_batch.update(batch.non_tensor_batch) + + return gen_batch + def _validate(self): data_source_lst = [] - dataset_lst = [] reward_extra_infos_dict: dict[str, list] = defaultdict(list) + # NOTE: added by Reasoning360. + dataset_lst = [] + # Lists to collect samples for the table sample_inputs = [] sample_outputs = [] + sample_gts = [] sample_scores = [] sample_turns = [] + sample_uids = [] for test_data in self.val_dataloader: test_batch = DataProto.from_single_dict(test_data) - # NOTE: print statements in this loop added by Reasoning360 are temporarily disabled - # print(f"Shape of test_batch: {test_batch.batch['input_ids'].shape}") + + if "uid" not in test_batch.non_tensor_batch: + test_batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object + ) # repeat test batch test_batch = test_batch.repeat( @@ -708,24 +566,14 @@ def _validate(self): # TODO: Can we keep special tokens except for padding tokens? input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] sample_inputs.extend(input_texts) + sample_uids.extend(test_batch.non_tensor_batch["uid"]) - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] - if "multi_modal_data" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("multi_modal_data") - if "raw_prompt" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("raw_prompt") - if "tools_kwargs" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("tools_kwargs") - if "interaction_kwargs" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("interaction_kwargs") - if "agent_name" in test_batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("agent_name") - test_gen_batch = test_batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, - ) + ground_truths = [ + item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch + ] + sample_gts.extend(ground_truths) + test_gen_batch = self._get_gen_batch(test_batch) test_gen_batch.meta_info = { "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, @@ -762,21 +610,19 @@ def _validate(self): test_batch.meta_info["validate"] = True # evaluate using reward_function + if self.val_reward_fn is None: + raise ValueError("val_reward_fn must be provided for validation.") result = self.val_reward_fn(test_batch, return_dict=True) reward_tensor = result["reward_tensor"] scores = reward_tensor.sum(-1).cpu().tolist() - # print(f"Shape of reward_tensor: {reward_tensor.shape}") - sample_scores.extend(scores) reward_extra_infos_dict["reward"].extend(scores) - print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") if "reward_extra_info" in result: for key, lst in result["reward_extra_info"].items(): reward_extra_infos_dict[key].extend(lst) - print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") - - # NOTE: Added by Reasoning360: Collect dataset information. TODO: maybe replicated usage with the data_source_lst and can be removed? + + # NOTE: added by Reasoning360. Collect dataset information. TODO: maybe replicated usage with the data_source_lst and can be removed? datasets = [] for i in range(reward_tensor.shape[0]): dataset = "unknown" @@ -801,6 +647,7 @@ def _validate(self): self._dump_generations( inputs=sample_inputs, outputs=sample_outputs, + gts=sample_gts, scores=sample_scores, reward_extra_infos_dict=reward_extra_infos_dict, dump_path=val_data_dir, @@ -811,17 +658,16 @@ def _validate(self): # NOTE: Added by Reasoning360: Calculate the mean reward for each data source and dataset data_sources = np.concatenate(data_source_lst, axis=0) + data_sources = np.concatenate(data_source_lst, axis=0) - datasets = np.concatenate(dataset_lst, axis=0) # Concatenate datasets - - data_src2var2metric2val = process_validation_metrics(data_sources, sample_inputs, reward_extra_infos_dict) + data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict) metric_dict = {} for data_source, var2metric2val in data_src2var2metric2val.items(): core_var = "acc" if "acc" in var2metric2val else "reward" for var_name, metric2val in var2metric2val.items(): n_max = max([int(name.split("@")[-1].split("/")[0]) for name in metric2val.keys()]) for metric_name, metric_val in metric2val.items(): - # NOTE: Added by Reasoning360: Add std to the metric name. + # NOTE: added by Reasoning360. Add std metrics if ( (var_name == core_var) and any(metric_name.startswith(pfx) for pfx in ["mean", "maj", "best", "std"]) @@ -832,8 +678,8 @@ def _validate(self): metric_sec = "val-aux" pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}" metric_dict[pfx] = metric_val - - # NOTE: Added by Reasoning360: Calculate the mean reward for each data source and dataset + + # NOTE: added by Reasoning360. Calculate the mean reward for each data source and dataset data_source_dataset_reward = {} for i in range(len(sample_scores)): data_source = data_sources[i] @@ -842,17 +688,17 @@ def _validate(self): if key not in data_source_dataset_reward: data_source_dataset_reward[key] = [] data_source_dataset_reward[key].append(sample_scores[i]) - + if len(sample_turns) > 0: sample_turns = np.concatenate(sample_turns) metric_dict["val-aux/num_turns/min"] = sample_turns.min() metric_dict["val-aux/num_turns/max"] = sample_turns.max() metric_dict["val-aux/num_turns/mean"] = sample_turns.mean() - # Record the mean reward for each data source and dataset + # NOTE: added by Reasoning360. Record the mean reward for each data source and dataset for (data_source, dataset), rewards in data_source_dataset_reward.items(): metric_dict[f"val/test_score/{data_source}/{dataset}"] = np.mean(rewards) - + return metric_dict def init_workers(self): @@ -867,41 +713,41 @@ def init_workers(self): self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} # create actor and rollout + actor_role = Role.ActorRolloutRef if Role.ActorRolloutRef in self.role_worker_mapping else Role.ActorRollout if self.hybrid_engine: - resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout) + resource_pool = self.resource_pool_manager.get_resource_pool(actor_role) actor_rollout_cls = RayClassWithInitArgs( - cls=self.role_worker_mapping[Role.ActorRollout], + cls=self.role_worker_mapping[actor_role], config=self.config.actor_rollout_ref, - role="actor_rollout", - profile_option=self.config.trainer.npu_profile.options, + role=str(actor_role), ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + self.resource_pool_to_cls[resource_pool][str(actor_role)] = actor_rollout_cls else: - raise NotImplementedError() + raise NotImplementedError # create critic if self.use_critic: resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) - critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls # create reference policy if needed - if self.use_reference_policy: + if self.use_reference_policy and Role.RefPolicy in self.role_worker_mapping: resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) ref_policy_cls = RayClassWithInitArgs( self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, - role="ref", - profile_option=self.config.trainer.npu_profile.options, + role=str(Role.RefPolicy), ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls # create a reward model if reward_fn is None if self.use_rm: # we create a RM here resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) - self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls # initialize WorkerGroup # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, @@ -912,40 +758,50 @@ def init_workers(self): wg_kwargs = {} # Setting up kwargs for RayWorkerGroup if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout - if OmegaConf.select(self.config.trainer, "profile_steps") is not None: - wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "profile_steps") - assert OmegaConf.select(self.config.trainer, "worker_nsight_options") is not None, ( - "worker_nsight_options must be set when profile_steps is set" - ) - wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( - OmegaConf.select(self.config.trainer, "worker_nsight_options") - ) + if OmegaConf.select(self.config.global_profiler, "steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps") + # Only require nsight worker options when tool is nsys + if OmegaConf.select(self.config.global_profiler, "tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) + wg_kwargs["device_name"] = self.device_name for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) wg_dict = self.ray_worker_group_cls( resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, - device_name=self.device_name, **wg_kwargs, ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) if self.use_critic: - self.critic_wg = all_wg["critic"] + self.critic_wg = all_wg[str(Role.Critic)] self.critic_wg.init_model() if self.use_reference_policy and not self.ref_in_actor: - self.ref_policy_wg = all_wg["ref"] - self.ref_policy_wg.init_model() + if str(Role.RefPolicy) in all_wg: + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + else: + # Model engine: ActorRolloutRefWorker + assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}" + self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)] + self.rm_wg = None + # initalization of rm_wg will be deprecated in the future if self.use_rm: - self.rm_wg = all_wg["rm"] + self.rm_wg = all_wg[str(Role.RewardModel)] self.rm_wg.init_model() # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg = all_wg[str(actor_role)] self.actor_rollout_wg.init_model() # create async rollout manager and request scheduler @@ -954,9 +810,15 @@ def init_workers(self): from verl.experimental.agent_loop import AgentLoopManager self.async_rollout_mode = True + if self.config.reward_model.enable and self.config.reward_model.enable_resource_pool: + rm_resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + else: + rm_resource_pool = None + self.async_rollout_manager = AgentLoopManager( config=self.config, worker_group=self.actor_rollout_wg, + rm_resource_pool=rm_resource_pool, ) def _save_checkpoint(self): @@ -994,11 +856,13 @@ def _save_checkpoint(self): ) if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic)) critic_remote_path = ( None if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + else os.path.join( + self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", str(Role.Critic) + ) ) self.critic_wg.save_checkpoint( critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep @@ -1011,6 +875,15 @@ def _save_checkpoint(self): torch.save(dataloader_state_dict, dataloader_local_path) # latest checkpointed iteration tracker (for atomic usage) + if ( + hasattr(self.config.actor_rollout_ref.actor.checkpoint, "async_save") + and self.config.actor_rollout_ref.actor.checkpoint.async_save + ) or ( + "async_save" in self.config.actor_rollout_ref.actor.checkpoint + and self.config.actor_rollout_ref.actor.checkpoint["async_save"] + ): + print("skip write latest_checkpointed_iteration.txt when async_save is True") + return local_latest_checkpointed_iteration = os.path.join( self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt" ) @@ -1025,7 +898,7 @@ def _load_checkpoint(self): if self.config.trainer.default_hdfs_dir is not None: raise NotImplementedError("load from hdfs is not implemented yet") else: - checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path + checkpoint_folder = self.config.trainer.default_local_dir # TODO: check path if not os.path.isabs(checkpoint_folder): working_dir = os.getcwd() checkpoint_folder = os.path.join(working_dir, checkpoint_folder) @@ -1054,7 +927,7 @@ def _load_checkpoint(self): print(f"Resuming from {global_step_folder}") actor_path = os.path.join(global_step_folder, "actor") - critic_path = os.path.join(global_step_folder, "critic") + critic_path = os.path.join(global_step_folder, str(Role.Critic)) # load actor self.actor_rollout_wg.load_checkpoint( actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load @@ -1079,11 +952,11 @@ def _start_profiling(self, do_profile: bool) -> None: if do_profile: self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) if self.use_reference_policy: - self.ref_policy_wg.start_profile() + self.ref_policy_wg.start_profile(profile_step=self.global_steps) if self.use_critic: - self.critic_wg.start_profile() + self.critic_wg.start_profile(profile_step=self.global_steps) if self.use_rm: - self.rm_wg.start_profile() + self.rm_wg.start_profile(profile_step=self.global_steps) def _stop_profiling(self, do_profile: bool) -> None: """Stop profiling for all worker groups if profiling is enabled.""" @@ -1096,15 +969,35 @@ def _stop_profiling(self, do_profile: bool) -> None: if self.use_rm: self.rm_wg.stop_profile() - def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen"): + def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False): """Reorder the data on single controller such that each dp rank gets similar total tokens""" attention_mask = batch.batch["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = batch.batch["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,) + workload_lst = calculate_workload(global_seqlen_lst) world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) + if keep_minibatch: + # Decouple the DP balancing and mini-batching. + minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size") + minibatch_num = len(workload_lst) // minibatch_size + global_partition_lst = [[] for _ in range(world_size)] + for i in range(minibatch_num): + rearrange_minibatch_lst = get_seqlen_balanced_partitions( + workload_lst[i * minibatch_size : (i + 1) * minibatch_size], + k_partitions=world_size, + equal_size=True, + ) + for j, part in enumerate(rearrange_minibatch_lst): + global_partition_lst[j].extend([x + minibatch_size * i for x in part]) + else: + global_partition_lst = get_seqlen_balanced_partitions( + workload_lst, k_partitions=world_size, equal_size=True + ) + # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel. + for idx, partition in enumerate(global_partition_lst): + partition.sort(key=lambda x: (workload_lst[x], x)) + ordered_partition = partition[::2] + partition[1::2][::-1] + global_partition_lst[idx] = ordered_partition # reorder based on index. The data will be automatically equally partitioned by dispatch function global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) batch.reorder(global_idx) @@ -1136,6 +1029,8 @@ def fit(self): # load checkpoint before doing anything self._load_checkpoint() + current_epoch = self.global_steps // len(self.train_dataloader) + # perform validation before training # currently, we only support validation using the reward_function. if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): @@ -1146,6 +1041,10 @@ def fit(self): if self.config.trainer.get("val_only", False): return + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + # add tqdm progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") @@ -1154,59 +1053,59 @@ def fit(self): last_val_metrics = None self.max_steps_duration = 0 - for epoch in range(self.config.trainer.total_epochs): + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(current_epoch, self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) metrics = {} timing_raw = {} - do_profile = ( - self.global_steps in self.config.trainer.profile_steps - if self.config.trainer.profile_steps is not None - else False - ) with marked_timer("start_profile", timing_raw): - self._start_profiling(do_profile) - + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) batch: DataProto = DataProto.from_single_dict(batch_dict) + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature - # pop those keys for generation - batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] - non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] - if "multi_modal_data" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("multi_modal_data") - if "raw_prompt" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("raw_prompt") - if "tools_kwargs" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("tools_kwargs") - if "interaction_kwargs" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("interaction_kwargs") - if "index" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("index") - if "agent_name" in batch.non_tensor_batch: - non_tensor_batch_keys_to_pop.append("agent_name") - - gen_batch = batch.pop( - batch_keys=batch_keys_to_pop, - non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + # add uid to batch + batch.non_tensor_batch["uid"] = np.array( + [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object ) + gen_batch = self._get_gen_batch(batch) + # pass global_steps to trace gen_batch.meta_info["global_steps"] = self.global_steps - gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + gen_batch_output = gen_batch.repeat( + repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True + ) is_last_step = self.global_steps >= self.total_training_steps - with marked_timer("step", timing_raw): # generate a batch with marked_timer("gen", timing_raw, color="red"): if not self.async_rollout_mode: - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) else: - gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) + timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + with marked_timer("gen_max", timing_raw, color="purple"): gen_baseline_batch = deepcopy(gen_batch) gen_baseline_batch.meta_info["do_sample"] = False @@ -1215,18 +1114,22 @@ def fit(self): else: gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) batch = batch.union(gen_baseline_output) - reward_baseline_tensor = self.reward_fn(batch) + # compute reward model score on batch + rm_scores = None + if self.use_rm and "rm_scores" not in batch.batch.keys(): + rm_scores = self.rm_wg.compute_rm_score(batch) + batch = batch.union(rm_scores) + reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn) reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) - batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + keys_to_pop = set(gen_baseline_output.batch.keys()) + if rm_scores is not None: + keys_to_pop.update(rm_scores.batch.keys()) + batch.pop(batch_keys=list(keys_to_pop)) batch.batch["reward_baselines"] = reward_baseline_tensor - del gen_baseline_batch, gen_baseline_output - - batch.non_tensor_batch["uid"] = np.array( - [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object - ) + del rm_scores, gen_baseline_batch, gen_baseline_output # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) @@ -1237,7 +1140,6 @@ def fit(self): # NOTE: This usually changes the order of data in the `batch`, # which won't affect the advantage calculation (since it's based on uid), # but might affect the loss calculation (due to the change of mini-batching). - # TODO: Decouple the DP balancing and mini-batching. if self.config.trainer.balance_batch: self._balance_batch(batch, metrics=metrics) @@ -1246,54 +1148,58 @@ def fit(self): with marked_timer("reward", timing_raw, color="yellow"): # compute reward model score - if self.use_rm: + if self.use_rm and "rm_scores" not in batch.batch.keys(): reward_tensor = self.rm_wg.compute_rm_score(batch) batch = batch.union(reward_tensor) if self.config.reward_model.launch_reward_fn_async: - future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer) + future_reward = compute_reward_async.remote( + data=batch, config=self.config, tokenizer=self.tokenizer + ) else: reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) - # recompute old_log_probs - with marked_timer("old_log_prob", timing_raw, color="blue"): - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) - entropys = old_log_prob.batch["entropys"] - response_masks = batch.batch["response_mask"] - loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode - entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) - old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} - metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) - - if "rollout_log_probs" in batch.batch.keys(): - # TODO: we may want to add diff of probs too. - rollout_old_log_probs = batch.batch["rollout_log_probs"] - actor_old_log_probs = batch.batch["old_log_probs"] - attention_mask = batch.batch["attention_mask"] - responses = batch.batch["responses"] - response_length = responses.size(1) - response_mask = attention_mask[:, -response_length:] - - rollout_probs = torch.exp(rollout_old_log_probs) - actor_probs = torch.exp(actor_old_log_probs) - rollout_probs_diff = torch.abs(rollout_probs - actor_probs) - rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) - rollout_probs_diff_max = torch.max(rollout_probs_diff) - rollout_probs_diff_mean = torch.mean(rollout_probs_diff) - rollout_probs_diff_std = torch.std(rollout_probs_diff) - metrics.update( - { - "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), - "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), - "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), - } + # Operating Mode Selection: + # - Bypass mode: Sets old_log_probs = rollout_log_probs (2 policies: π_rollout, π_θ) + # - Decoupled mode: Recomputes old_log_probs as proximal anchor (3 policies: π_rollout, π_old, π_θ) + # Note: π_old computed once per data batch, serves as stable reference during mini-batch updates + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False) + if bypass_recomputing_logprobs: # Use `rollout_log_probs` + from verl.trainer.ppo.rollout_corr_helper import apply_rollout_correction + + apply_rollout_correction( + batch=batch, + rollout_corr_config=rollout_corr_config, + policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss, + ) + else: # Recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + actor_config = self.config.actor_rollout_ref.actor + entropy_agg = agg_loss( + loss_mat=entropys, + loss_mask=response_masks, + loss_agg_mode=actor_config.loss_agg_mode, + loss_scale_factor=actor_config.loss_scale_factor, ) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' if self.use_reference_policy: # compute reference log_prob - with marked_timer("ref", timing_raw, color="olive"): + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): if not self.ref_in_actor: ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) else: @@ -1325,8 +1231,22 @@ def fit(self): else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - # compute advantages, executed on the driver process + # Compute rollout correction: IS weights, rejection sampling, and metrics + # Only runs in decoupled mode (computes once per batch using stable π_old) + # In bypass mode, this is skipped - actor computes metrics from evolving π_θ vs π_rollout + if ( + rollout_corr_config is not None + and "rollout_log_probs" in batch.batch + and not bypass_recomputing_logprobs # Only in decoupled mode + ): + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + # Compute IS weights, apply rejection sampling, compute metrics + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) + # compute advantages, executed on the driver process norm_adv_by_std_in_grpo = self.config.algorithm.get( "norm_adv_by_std_in_grpo", True ) # GRPO adv normalization factor @@ -1352,7 +1272,10 @@ def fit(self): if self.config.trainer.critic_warmup <= self.global_steps: # update actor with marked_timer("update_actor", timing_raw, color="red"): - batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + rollout_config = self.config.actor_rollout_ref.rollout + batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable + # TODO: Make "temperature" single source of truth from generation. + batch.meta_info["temperature"] = rollout_config.temperature actor_output = self.actor_rollout_wg.update_actor(batch) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) @@ -1360,54 +1283,53 @@ def fit(self): # Log rollout generations if enabled rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir: - with marked_timer("dump_rollout_generations", timing_raw, color="green"): - inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) - outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) - scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() - self._dump_generations( - inputs=inputs, - outputs=outputs, - scores=scores, - reward_extra_infos_dict=reward_extra_infos_dict, - dump_path=rollout_data_dir, - ) - - # validate - if ( - self.val_reward_fn is not None - and self.config.trainer.test_freq > 0 - and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) - ): - with marked_timer("testing", timing_raw, color="green"): - val_metrics: dict = self._validate() - if is_last_step: - last_val_metrics = val_metrics - metrics.update(val_metrics) - - # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. - esi_close_to_expiration = should_save_ckpt_esi( - max_steps_duration=self.max_steps_duration, - redundant_time=self.config.trainer.esi_redundant_time, - ) - # Check if the conditions for saving a checkpoint are met. - # The conditions include a mandatory condition (1) and - # one of the following optional conditions (2/3/4): - # 1. The save frequency is set to a positive value. - # 2. It's the last training step. - # 3. The current step number is a multiple of the save frequency. - # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. - if self.config.trainer.save_freq > 0 and ( - is_last_step - or self.global_steps % self.config.trainer.save_freq == 0 - or esi_close_to_expiration - ): - if esi_close_to_expiration: - print("Force saving checkpoint: ESI instance expiration approaching.") - with marked_timer("save_checkpoint", timing_raw, color="green"): - self._save_checkpoint() + self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) + + # validate + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. + # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() with marked_timer("stop_profile", timing_raw): - self._stop_profiling(do_profile) + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile steps_duration = timing_raw["step"] self.max_steps_duration = max(self.max_steps_duration, steps_duration) @@ -1420,12 +1342,14 @@ def fit(self): } ) # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + # NOTE: added by Reasoning360 metrics.update(compute_difficulty_histogram_metrics(batch=batch, config=self.config)) + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) # TODO: implement actual tflpo and theoretical tflpo n_gpus = self.resource_pool_manager.get_n_gpus() metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + # Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation # this is experimental and may be changed/removed in the future in favor of a general-purpose one if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): @@ -1436,7 +1360,18 @@ def fit(self): progress_bar.update(1) self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + if is_last_step: + if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): + self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) pprint(f"Final validation metrics: {last_val_metrics}") progress_bar.close() return diff --git a/src/reasoning360/trainer/ppo/reward.py b/src/reasoning360/trainer/ppo/reward.py new file mode 100644 index 000000000..29e818957 --- /dev/null +++ b/src/reasoning360/trainer/ppo/reward.py @@ -0,0 +1,239 @@ +# Copyright 2025 Individual Contributor: Thibaut Barroyer +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import importlib.util +import inspect +import multiprocessing +import os +import sys +import warnings +from functools import partial +from typing import TYPE_CHECKING, Any, Optional, cast + +import ray +import torch + +from reasoning360.utils.reward_score import default_compute_score +from verl.utils.transferqueue_utils import tqbridge + +if TYPE_CHECKING: + from omegaconf import DictConfig + + from verl import DataProto + from verl.experimental.reward.reward_loop.base import RewardLoopManagerBase + from verl.trainer.config.config import ModuleConfig, RewardManagerConfig + from verl.workers.reward_manager.abstract import AbstractRewardManager, RawRewardFn +else: + try: + from verl.experimental.reward.reward_loop.base import RewardLoopManagerBase + except ImportError: + RewardLoopManagerBase = None # type: ignore[assignment,misc] + + +def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs): + """Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence. + + This function is used to merge additional keyword arguments with the original function's arguments. + """ + merged_kwargs = {**kwargs, **extra_kwargs} + return raw_fn(*args, **merged_kwargs) + + +async def _call_with_kwargs_async(raw_fn, extra_kwargs, *args, **kwargs): + """Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence. + + This function is used to merge additional keyword arguments with the original function's arguments. + """ + merged_kwargs = {**kwargs, **extra_kwargs} + return await raw_fn(*args, **merged_kwargs) + + +def get_custom_reward_fn(config: DictConfig) -> Optional[RawRewardFn]: + """Load and return a custom reward function from external file. + + Dynamically imports a reward function from a specified file path and wraps + it with additional keyword arguments from the configuration. + + Args: + config (dict): Configuration dictionary containing custom_reward_function + settings with 'path', 'name', and 'reward_kwargs' fields. + + Returns: + callable or None: Wrapped reward function with merged kwargs, or None + if no custom reward function is configured. + + Raises: + FileNotFoundError: If the specified reward function file doesn't exist. + RuntimeError: If there's an error loading the module from file. + AttributeError: If the specified function name isn't found in the module. + """ + + reward_fn_config = config.get("custom_reward_function") or {} + file_path = reward_fn_config.get("path") + if not file_path: + return None + + function_name = reward_fn_config.get("name") + assert function_name is not None + + module = sys.modules.get("custom_module", None) + if module is None: + if not os.path.exists(file_path): + raise FileNotFoundError(f"Reward function file '{file_path}' not found.") + + spec = importlib.util.spec_from_file_location("custom_module", file_path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + try: + sys.modules["custom_module"] = module + assert spec.loader is not None + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e + + if not hasattr(module, function_name): + raise AttributeError(f"Reward function '{function_name}' not found in '{module.__file__}'.") + + print(f"using customized reward function '{function_name}' from '{module.__file__}'") + raw_fn = getattr(module, function_name) + + reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) + + if not inspect.iscoroutinefunction(raw_fn): + return partial(_call_with_kwargs, raw_fn, reward_kwargs) + else: + return partial(_call_with_kwargs_async, raw_fn, reward_kwargs) + + +def load_reward_manager( + config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any +) -> AbstractRewardManager: + """ + Load and initialize a reward manager based on the configuration. + + Args: + config: PPO trainer configuration object containing reward_model fields. + tokenizer: Tokenizer object used for processing text. + num_examine: Number of samples to examine. + **reward_kwargs: Additional keyword arguments for the reward manager. + + Returns: + An instance of the specified reward manager class. + """ + + # Try to get a custom reward function based on the configuration + # user defined reward manager can be registered in custom_reward_fn + compute_score = get_custom_reward_fn(config) + final_compute_score = compute_score + + reward_manager_cfg: RewardManagerConfig = config.reward_manager + reward_manager_cls: type[AbstractRewardManager] + if reward_manager_cfg.source == "register": + # Use verl's registry to avoid name-collisions between verl and local Reasoning360 + # reward manager implementations (both may register e.g. "dapo"). + from verl.workers.reward_manager import get_reward_manager_cls + + reward_manager_cls = get_reward_manager_cls(reward_manager_cfg.name) + elif reward_manager_cfg.source == "importlib": + from verl.utils.import_utils import load_extern_object + + module_cfg: ModuleConfig | None = reward_manager_cfg.module + assert module_cfg is not None and module_cfg.path is not None, ( + f"Module path is required when {reward_manager_cfg.source=}, but got {module_cfg=}" + ) + reward_manager_cls_name = reward_manager_cfg.name + reward_manager_cls = cast( + type[AbstractRewardManager], + load_extern_object(module_path=module_cfg.path, object_name=reward_manager_cls_name), + ) + + if compute_score is None: + sandbox_config = config.reward_model.get("sandbox_fusion") + sandbox_url = sandbox_config.get("url") if sandbox_config else None + memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024) if sandbox_config else 1024 + if sandbox_url: + sandbox_manager = multiprocessing.Manager() + # Create a semaphore to control concurrent access to the sandbox + _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) + final_compute_score = partial( + default_compute_score, + sandbox_fusion_url=sandbox_url, + concurrent_semaphore=_concurrent_semaphore, + memory_limit_mb=memory_limit_mb, + ) + else: + final_compute_score = default_compute_score + + # Instantiate and return the reward manager with the specified parameters + # RewardLoopManagerBase subclasses (like RateLimitedRewardLoopManager) don't accept num_examine + # while AbstractRewardManager subclasses (like NaiveRewardManager) do + if RewardLoopManagerBase is not None and issubclass(reward_manager_cls, RewardLoopManagerBase): + # RewardLoopManagerBase-based managers use a different signature + return reward_manager_cls( + config=config, + tokenizer=tokenizer, + compute_score=final_compute_score, + **reward_kwargs, + ) + else: + # Traditional AbstractRewardManager-based managers + return reward_manager_cls( + tokenizer=tokenizer, + num_examine=num_examine, + compute_score=final_compute_score, + reward_fn_key=config.data.reward_fn_key, + **reward_kwargs, + ) + + +@tqbridge(put_data=False) +def compute_reward(data: DataProto, reward_fn: AbstractRewardManager) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Compute reward for a batch of data. + Args: + data: DataProto object containing the input data. + reward_fn: Reward function to compute the reward. + Returns: + Tuple of reward tensor and extra info dictionary. + """ + try: + reward_result = reward_fn(data, return_dict=True) + reward_tensor = reward_result["reward_tensor"] + reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) + except Exception as e: + print(f"Error in reward_fn: {e}") + reward_tensor = reward_fn(data) + reward_extra_infos_dict = {} + + return reward_tensor, reward_extra_infos_dict + + +@ray.remote(num_cpus=1) +def compute_reward_async(data: DataProto, config=None, tokenizer=None, reward_fn=None): + """ + Load the reward manager and compute the reward for a batch of data. + This is meant to be run in a separate Ray worker. + """ + if reward_fn is None: + assert config is not None and tokenizer is not None, ( + "config and tokenizer must not be None when reward_fn is None" + ) + + warnings.warn("using config and tokenizer with compute_reward_async is deprecated", stacklevel=2) + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + + return compute_reward(data, reward_fn) diff --git a/src/reasoning360/utils/__init__.py b/src/reasoning360/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/reasoning360/utils/dataset/rl_dataset.py b/src/reasoning360/utils/dataset/rl_dataset.py new file mode 100644 index 000000000..aa86ccd75 --- /dev/null +++ b/src/reasoning360/utils/dataset/rl_dataset.py @@ -0,0 +1,489 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The file is temporarily reverted by Reasoning360 to use `dataframe` rather than `Dataset`, +# to support heterogeneous keys of multi-domain data + +import copy +import logging +import os +import re +import traceback +from collections import defaultdict +from typing import Optional + +import datasets +import pandas as pd +import numpy as np +import torch +from omegaconf import DictConfig, ListConfig +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer, ProcessorMixin + +import verl.utils.torch_functional as verl_F +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__name__) + + +def collate_fn(data_list: list[dict]) -> dict: + """ + Collate a batch of sample dicts into batched tensors and arrays. + + Args: + data_list: List of dicts mapping feature names to torch.Tensor or other values. + + Returns: + Dict where tensor entries are stacked into a torch.Tensor of shape + (batch_size, \\*dims) and non-tensor entries are converted to + np.ndarray of dtype object with shape (batch_size,). + """ + tensors = defaultdict(list) + non_tensors = defaultdict(list) + + for data in data_list: + for key, val in data.items(): + if isinstance(val, torch.Tensor): + tensors[key].append(val) + else: + non_tensors[key].append(val) + + for key, val in tensors.items(): + tensors[key] = torch.stack(val, dim=0) + + for key, val in non_tensors.items(): + non_tensors[key] = np.fromiter(val, dtype=object, count=len(val)) + + return {**tensors, **non_tensors} + + +class RLHFDataset(Dataset): + """ + Load and preprocess RLHF data from Parquet files. + + - Caches files locally. + - Reads into a HuggingFace Dataset and tokenizes prompts. + - Optionally handles images/videos via a ProcessorMixin. + - Filters prompts over a max length. + - Supports resuming from checkpoints. + + Args: + data_files (str or list): Path(s) to Parquet file(s). + tokenizer (PreTrainedTokenizer): For the tokenization of text to token IDs. + config (DictConfig): Options like cache_dir, prompt_key, max_prompt_length, truncation, etc. + processor (ProcessorMixin, optional): Multimodal preprocessor for images/videos. + """ + + def __init__( + self, + data_files: str | list[str], + tokenizer: PreTrainedTokenizer, + config: DictConfig, + processor: Optional[ProcessorMixin] = None, + max_samples: int = -1, + ): + if not isinstance(data_files, list | ListConfig): + data_files = [data_files] + + self.data_files = copy.deepcopy(data_files) + self.original_data_files = copy.deepcopy(data_files) # use for resume + self.tokenizer = tokenizer + self.processor = processor + self.max_samples = max_samples + self.config = config + + self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf")) + self.prompt_key = config.get("prompt_key", "prompt") + self.image_key = config.get("image_key", "images") + self.video_key = config.get("video_key", "videos") + self.image_patch_size = config.get("image_patch_size", 14) + self.max_prompt_length = config.get("max_prompt_length", 1024) + self.return_raw_chat = config.get("return_raw_chat", False) + self.return_full_prompt = config.get("return_full_prompt", False) + self.truncation = config.get("truncation", "error") + self.filter_overlong_prompts = config.get("filter_overlong_prompts", True) + self.apply_chat_template_kwargs = config.get("apply_chat_template_kwargs", {}) + + self.tool_config_path = config.get("tool_config_path", None) + self.tool_schemas = None + if self.tool_config_path: + try: + from verl.tools.utils.tool_registry import initialize_tools_from_config + + tool_list = initialize_tools_from_config(self.tool_config_path) + # match ToolAgentLoop behaviour: model_dump to plain dicts + self.tool_schemas = [ + tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list + ] + except Exception as e: + logger.warning("Failed to initialize tools from %s: %s", self.tool_config_path, e) + self.tool_schemas = None + + self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) + self.num_workers = min(self.num_workers, os.cpu_count()) if self.num_workers is not None else None + self.use_shm = config.get("use_shm", False) + self.chat_template_func = config.get("chat_template_func", None) + self.need_tools_kwargs = config.get("need_tools_kwargs", False) + self.filter_prompts = config.get("filter_prompts", True) + self.serialize_dataset = False + self.return_multi_modal_inputs = config.get("return_multi_modal_inputs", True) + self.shuffle = config.get("shuffle", False) + self.seed = config.get("seed") + + self._download() + self._read_files_and_tokenize() + + def _download(self, use_origin_parquet=False): + from verl.utils.fs import copy_to_local + + data_files = self.data_files if not use_origin_parquet else self.original_data_files + for i, parquet_file in enumerate(data_files): + self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir, use_shm=self.use_shm) + + def _read_files_and_tokenize(self): + dataframes = [] + for parquet_file in self.data_files: + # read parquet files and cache + # dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"] + try: + dataframe = pd.read_parquet(parquet_file) + except Exception: + # if pandas fails (most likely due to nested columns), use polars + # NOTE: added by Reasoning360 + import polars as pl + dataframe = pl.read_parquet(parquet_file).to_pandas() + dataframes.append(dataframe) + + # NOTE: added by Reasoning360 + # self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes) + self.dataframe = pd.concat(dataframes) + + print(f"dataset len: {len(self.dataframe)}") + + # Safely check if apply_chat_template exists in dataframe + # NOTE: added by Reasoning360 + if "apply_chat_template" not in self.dataframe: + print("Warning: apply_chat_template column not found in dataframe. Defaulting to True.") + self.dataframe["apply_chat_template"] = [True] * len(self.dataframe) + + self.dataframe = self.maybe_filter_out_long_prompts(self.dataframe) + + def maybe_filter_out_long_prompts(self, dataframe: datasets.Dataset = None): + # filter out too long prompts + if self.filter_overlong_prompts: + tokenizer = self.tokenizer + processor = self.processor + prompt_key = self.prompt_key + image_key = self.image_key + video_key = self.video_key + + if processor is not None: + from verl.utils.dataset.vision_utils import process_image, process_video + + def doc2len(doc) -> int: + try: + messages = self._build_messages(doc) + # pass tool schemas if available so the processor can format prompts + apply_kwargs = dict(**self.apply_chat_template_kwargs) + if self.tool_schemas is not None: + apply_kwargs["tools"] = self.tool_schemas + + raw_prompt = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False, **apply_kwargs + ) + if image_key in doc and doc[image_key]: + images = [ + process_image(image, image_patch_size=self.image_patch_size) for image in doc[image_key] + ] + else: + images = None + + if video_key in doc and doc[video_key]: + videos, video_metadata = zip( + *[ + process_video( + video, image_patch_size=self.image_patch_size, return_video_metadata=True + ) + for video in doc[video_key] + ], + strict=True, + ) + videos = list(videos) + video_metadata = list(video_metadata) + videos_kwargs = {"video_metadata": video_metadata, "do_sample_frames": False} + else: + videos = None + videos_kwargs = {} + + return len( + processor(text=[raw_prompt], images=images, videos=videos, videos_kwargs=videos_kwargs)[ + "input_ids" + ][0] + ) + except Exception: + print("Error processing one of the samples, skipping...") + traceback.print_exc() + return self.max_prompt_length + 1 + + else: + + def doc2len(doc) -> int: + try: + apply_kwargs = dict(**self.apply_chat_template_kwargs) + if self.tool_schemas is not None: + apply_kwargs["tools"] = self.tool_schemas + + return len( + tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True, **apply_kwargs) + ) + except Exception: + print("Error processing one of the samples, skipping...") + traceback.print_exc() + return self.max_prompt_length + 1 + + # NOTE: added by Reasoning360 + # Handle both pandas DataFrame and datasets.Dataset + if isinstance(dataframe, pd.DataFrame): + # For pandas DataFrame, use apply with axis=1 to process each row + print(f"Filtering pandas DataFrame with {len(dataframe)} rows...") + from tqdm import tqdm + tqdm.pandas(desc=f"Filtering prompts longer than {self.max_prompt_length} tokens") + keep_mask = dataframe.progress_apply( + lambda row: doc2len(row.to_dict()) <= self.max_prompt_length, + axis=1 + ) + dataframe = dataframe[keep_mask].reset_index(drop=True) + else: + # For datasets.Dataset, use the standard filter method + dataframe = dataframe.filter( + lambda doc: doc2len(doc) <= self.max_prompt_length, + num_proc=self.num_workers, + desc=f"Filtering prompts longer than {self.max_prompt_length} tokens", + ) + + print(f"filter dataset len: {len(dataframe)}") + return dataframe + + def resume_dataset_state(self): + self.serialize_dataset = not hasattr(self, "original_data_files") + # resume dataframe if not it's serialized in data.pt + if not self.serialize_dataset: + self._download(use_origin_parquet=True) # download and resume from original parquet files + self._read_files_and_tokenize() + else: + print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") + + def __len__(self): + return len(self.dataframe) + + def _build_messages(self, example: dict): + messages: list = example.pop(self.prompt_key) + + if self.image_key in example or self.video_key in example: + for message in messages: + content = message["content"] + content_list = [] + segments = re.split("(|
')[-1].strip() + # with open("solution_str_Qwen3-4B.txt_maze", "a") as f: + # f.write("data_source: " + data_source + '\n') + # f.write("solution_str: " + solution_str + '\n') + # f.write("form_solution: " + form_solution + '\n') + # f.write('-'*32 + '\n') + data = Data.from_json_str(extra_info["game_data_str"]) + verifier = verifier_classes[data_source.replace("synlogic_", "")]() + res = verifier.verify(data, form_solution) + if res: + res = 1.0 + else: + res = 0.0 + else: raise NotImplementedError(f"Reward function is not implemented for {data_source=}") diff --git a/verl/utils/reward_score/arcagi.py b/src/reasoning360/utils/reward_score/arcagi.py similarity index 100% rename from verl/utils/reward_score/arcagi.py rename to src/reasoning360/utils/reward_score/arcagi.py diff --git a/verl/utils/reward_score/codeio.py b/src/reasoning360/utils/reward_score/codeio.py similarity index 100% rename from verl/utils/reward_score/codeio.py rename to src/reasoning360/utils/reward_score/codeio.py diff --git a/verl/utils/reward_score/coder1/README.md b/src/reasoning360/utils/reward_score/coder1/README.md similarity index 100% rename from verl/utils/reward_score/coder1/README.md rename to src/reasoning360/utils/reward_score/coder1/README.md diff --git a/verl/utils/reward_score/coder1/__init__.py b/src/reasoning360/utils/reward_score/coder1/__init__.py similarity index 100% rename from verl/utils/reward_score/coder1/__init__.py rename to src/reasoning360/utils/reward_score/coder1/__init__.py diff --git a/verl/utils/reward_score/coder1/bwrap_exec.py b/src/reasoning360/utils/reward_score/coder1/bwrap_exec.py similarity index 100% rename from verl/utils/reward_score/coder1/bwrap_exec.py rename to src/reasoning360/utils/reward_score/coder1/bwrap_exec.py diff --git a/verl/utils/reward_score/coder1/ces_exec.py b/src/reasoning360/utils/reward_score/coder1/ces_exec.py similarity index 100% rename from verl/utils/reward_score/coder1/ces_exec.py rename to src/reasoning360/utils/reward_score/coder1/ces_exec.py diff --git a/verl/utils/reward_score/coder1/docker_exec.py b/src/reasoning360/utils/reward_score/coder1/docker_exec.py similarity index 100% rename from verl/utils/reward_score/coder1/docker_exec.py rename to src/reasoning360/utils/reward_score/coder1/docker_exec.py diff --git a/verl/utils/reward_score/coder1/firejail_exec.py b/src/reasoning360/utils/reward_score/coder1/firejail_exec.py similarity index 100% rename from verl/utils/reward_score/coder1/firejail_exec.py rename to src/reasoning360/utils/reward_score/coder1/firejail_exec.py diff --git a/verl/utils/reward_score/coder1/kira_exec.py b/src/reasoning360/utils/reward_score/coder1/kira_exec.py similarity index 100% rename from verl/utils/reward_score/coder1/kira_exec.py rename to src/reasoning360/utils/reward_score/coder1/kira_exec.py diff --git a/verl/utils/reward_score/coder1/sandboxfusion_exec.py b/src/reasoning360/utils/reward_score/coder1/sandboxfusion_exec.py similarity index 100% rename from verl/utils/reward_score/coder1/sandboxfusion_exec.py rename to src/reasoning360/utils/reward_score/coder1/sandboxfusion_exec.py diff --git a/verl/utils/reward_score/coder1/unsafe_local_exec.py b/src/reasoning360/utils/reward_score/coder1/unsafe_local_exec.py similarity index 100% rename from verl/utils/reward_score/coder1/unsafe_local_exec.py rename to src/reasoning360/utils/reward_score/coder1/unsafe_local_exec.py diff --git a/verl/utils/reward_score/coder1/utils.py b/src/reasoning360/utils/reward_score/coder1/utils.py similarity index 100% rename from verl/utils/reward_score/coder1/utils.py rename to src/reasoning360/utils/reward_score/coder1/utils.py diff --git a/verl/utils/reward_score/cruxeval/__init__.py b/src/reasoning360/utils/reward_score/cruxeval/__init__.py similarity index 100% rename from verl/utils/reward_score/cruxeval/__init__.py rename to src/reasoning360/utils/reward_score/cruxeval/__init__.py diff --git a/src/reasoning360/utils/reward_score/cruxeval/cruxeval.py b/src/reasoning360/utils/reward_score/cruxeval/cruxeval.py new file mode 100644 index 000000000..e69de29bb diff --git a/verl/utils/reward_score/cruxeval/utils.py b/src/reasoning360/utils/reward_score/cruxeval/utils.py similarity index 100% rename from verl/utils/reward_score/cruxeval/utils.py rename to src/reasoning360/utils/reward_score/cruxeval/utils.py diff --git a/src/reasoning360/utils/reward_score/deepmath.py b/src/reasoning360/utils/reward_score/deepmath.py new file mode 100644 index 000000000..3b6a5f32a --- /dev/null +++ b/src/reasoning360/utils/reward_score/deepmath.py @@ -0,0 +1,225 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Optional, Union + + +def compute_score(solution_str: str, ground_truth: str, extra_info: Optional[dict] = None) -> float: + """Compute the score for DeepMath dataset solutions. + + Args: + solution_str: The model's solution/answer + ground_truth: The correct answer from the dataset + extra_info: Optional additional information (e.g., difficulty, topic) + + Returns: + float: 1.0 if correct, 0.0 otherwise + """ + try: + # Extract answer from solution if it's in boxed format + extracted_answer = extract_boxed_answer(solution_str) + if extracted_answer is None: + # Try to extract from common answer patterns + extracted_answer = extract_answer_patterns(solution_str) + + if extracted_answer is None: + # Use the full solution string as last resort + extracted_answer = solution_str.strip() + + # Normalize both answers for comparison + normalized_solution = normalize_math_answer(extracted_answer) + normalized_ground_truth = normalize_math_answer(ground_truth) + + # Check if answers are equivalent + if is_equivalent(normalized_solution, normalized_ground_truth): + return 1.0 + + # Additional check for numerical equivalence + if is_numerically_equivalent(normalized_solution, normalized_ground_truth): + return 1.0 + + return 0.0 + except Exception as e: + print(f"Error in DeepMath scoring: {e}") + return 0.0 + + +def extract_boxed_answer(text: str) -> Optional[str]: + """Extract answer from \\boxed{...} format.""" + # Look for the last boxed expression + pattern = r"\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}" + matches = re.findall(pattern, text) + if matches: + return matches[-1] + + # Also check for \boxed without braces + pattern2 = r"\\boxed\s+([^\s]+)" + matches2 = re.findall(pattern2, text) + if matches2: + return matches2[-1] + + return None + + +def extract_answer_patterns(text: str) -> Optional[str]: + """Extract answer from common answer patterns.""" + patterns = [ + r"(?:final answer|answer)[\s:]*(?:is)?[\s:]*([^\n.]+)", + r"(?:evaluates to|equals to|is equal to)[\s:]*([^\n.]+)", + r"therefore[\s,]+([^\n.]+)", + r"thus[\s,]+([^\n.]+)", + r"hence[\s,]+([^\n.]+)", + r"=\s*([^\n]+)$", # Last equals sign + r"(?:limit|integral|sum|product)[\s\w]*(?:evaluates to|is|equals)[\s:]*([^\n.]+)", + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + # Clean the extracted answer + answer = matches[-1].strip() + # Remove trailing punctuation but keep mathematical symbols + answer = answer.rstrip('.,;:') + return answer + + # Try to find any number at the end of the text + number_pattern = r"(?:^|\s)([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?|\d+/\d+)(?:\s*$|\s*[.,;]?\s*$)" + matches = re.findall(number_pattern, text) + if matches: + return matches[-1].strip() + + return None + + +def normalize_math_answer(answer: str) -> str: + """Normalize mathematical expressions for comparison.""" + # Remove whitespace + answer = answer.strip() + answer = re.sub(r'\s+', '', answer) + + # Remove dollar signs + answer = answer.replace('$', '') + + # Normalize LaTeX commands + answer = answer.replace('\\left', '') + answer = answer.replace('\\right', '') + answer = answer.replace('\\Big', '') + answer = answer.replace('\\big', '') + answer = answer.replace('\\cdot', '*') + answer = answer.replace('\\times', '*') + answer = answer.replace('\\div', '/') + + # Handle fractions + answer = normalize_fractions(answer) + + # Remove trailing punctuation + answer = answer.rstrip('.,;:') + + return answer + + +def normalize_fractions(text: str) -> str: + """Normalize fraction representations.""" + # Convert \frac{a}{b} to a/b for simple cases + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + + def frac_replacer(match): + num, den = match.groups() + # For simple numeric fractions, compute the value + try: + num_val = float(eval(num)) + den_val = float(eval(den)) + if den_val != 0: + result = num_val / den_val + # Return as integer if it's a whole number + if result == int(result): + return str(int(result)) + return str(result) + except: + pass + return f"({num})/({den})" + + text = re.sub(frac_pattern, frac_replacer, text) + + # Also handle tfrac and dfrac + text = text.replace('\\tfrac', '\\frac') + text = text.replace('\\dfrac', '\\frac') + + return text + + +def is_equivalent(answer1: str, answer2: str) -> bool: + """Check if two normalized answers are equivalent.""" + # Direct string comparison + if answer1 == answer2: + return True + + # Case-insensitive comparison for text answers + if answer1.lower() == answer2.lower(): + return True + + # Check common mathematical equivalences + equivalences = [ + ('infinity', '\\infty'), + ('inf', '\\infty'), + ('undefined', 'dne'), + ('doesnotexist', 'dne'), + ('none', 'dne'), + ] + + a1_lower = answer1.lower() + a2_lower = answer2.lower() + + for eq1, eq2 in equivalences: + if (eq1 in a1_lower and eq2 in a2_lower) or (eq2 in a1_lower and eq1 in a2_lower): + return True + + return False + + +def is_numerically_equivalent(answer1: str, answer2: str, tolerance: float = 1e-9) -> bool: + """Check if two answers are numerically equivalent.""" + try: + # Try to evaluate as numerical expressions + val1 = evaluate_expression(answer1) + val2 = evaluate_expression(answer2) + + if val1 is not None and val2 is not None: + return abs(val1 - val2) < tolerance + except: + pass + + return False + + +def evaluate_expression(expr: str) -> Optional[float]: + """Safely evaluate a mathematical expression.""" + try: + # Remove common LaTeX commands that might remain + expr = expr.replace('\\pi', '3.141592653589793') + expr = expr.replace('\\e', '2.718281828459045') + expr = expr.replace('^', '**') + + # Only allow safe operations + allowed_names = { + 'abs': abs, + 'min': min, + 'max': max, + } + + # Evaluate the expression safely + result = eval(expr, {"__builtins__": {}}, allowed_names) + return float(result) + except: + return None \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/deepmath_test.py b/src/reasoning360/utils/reward_score/deepmath_test.py new file mode 100644 index 000000000..7fde18db5 --- /dev/null +++ b/src/reasoning360/utils/reward_score/deepmath_test.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +Test script for DeepMath integration +""" + +import sys +sys.path.append('/mnt/weka/home/jianshu.she/IFM/Reasoning360') + +from verl.utils.reward_score import default_compute_score + +def test_deepmath_scoring(): + """Test DeepMath scoring functionality""" + + print("Testing DeepMath scoring integration...") + + # Test cases + test_cases = [ + { + "solution": "\\boxed{0}", + "ground_truth": "0", + "expected": 1.0, + "description": "Exact match with boxed answer" + }, + { + "solution": "The limit evaluates to 0", + "ground_truth": "0", + "expected": 1.0, + "description": "Text extraction" + }, + { + "solution": "\\boxed{\\frac{2}{3}}", + "ground_truth": "2/3", + "expected": 1.0, + "description": "Fraction equivalence" + }, + { + "solution": "\\boxed{42}", + "ground_truth": "24", + "expected": 0.0, + "description": "Wrong answer" + }, + { + "solution": "The answer is \\infty", + "ground_truth": "infinity", + "expected": 1.0, + "description": "Infinity equivalence" + } + ] + + print("\nRunning test cases:") + print("=" * 60) + + all_passed = True + for i, test in enumerate(test_cases, 1): + try: + # Test with different data source identifiers + for data_source in ["deepmath", "DeepMath", "zwhe99/DeepMath-103K"]: + score = default_compute_score( + data_source=data_source, + solution_str=test["solution"], + ground_truth=test["ground_truth"] + ) + + passed = abs(score - test["expected"]) < 0.001 + + if not passed: + print(f"❌ Test {i} FAILED ({data_source}): {test['description']}") + print(f" Expected: {test['expected']}, Got: {score}") + all_passed = False + break + else: + print(f"✅ Test {i} PASSED: {test['description']}") + + except Exception as e: + print(f"❌ Test {i} ERROR: {test['description']}") + print(f" Error: {e}") + all_passed = False + + print("=" * 60) + if all_passed: + print("✅ All tests passed!") + else: + print("❌ Some tests failed") + + return all_passed + + +if __name__ == "__main__": + success = test_deepmath_scoring() + sys.exit(0 if success else 1) \ No newline at end of file diff --git a/verl/utils/reward_score/geo3k.py b/src/reasoning360/utils/reward_score/geo3k.py similarity index 100% rename from verl/utils/reward_score/geo3k.py rename to src/reasoning360/utils/reward_score/geo3k.py diff --git a/verl/utils/reward_score/gpqa.py b/src/reasoning360/utils/reward_score/gpqa.py similarity index 100% rename from verl/utils/reward_score/gpqa.py rename to src/reasoning360/utils/reward_score/gpqa.py diff --git a/verl/utils/reward_score/graph_dataset.py b/src/reasoning360/utils/reward_score/graph_dataset.py similarity index 100% rename from verl/utils/reward_score/graph_dataset.py rename to src/reasoning360/utils/reward_score/graph_dataset.py diff --git a/verl/utils/reward_score/gsm8k.py b/src/reasoning360/utils/reward_score/gsm8k.py similarity index 100% rename from verl/utils/reward_score/gsm8k.py rename to src/reasoning360/utils/reward_score/gsm8k.py diff --git a/verl/utils/reward_score/ifbench/__init__.py b/src/reasoning360/utils/reward_score/ifbench/__init__.py similarity index 100% rename from verl/utils/reward_score/ifbench/__init__.py rename to src/reasoning360/utils/reward_score/ifbench/__init__.py diff --git a/verl/utils/reward_score/ifbench/check_ifbench_data.py b/src/reasoning360/utils/reward_score/ifbench/check_ifbench_data.py similarity index 100% rename from verl/utils/reward_score/ifbench/check_ifbench_data.py rename to src/reasoning360/utils/reward_score/ifbench/check_ifbench_data.py diff --git a/verl/utils/reward_score/ifbench/instructions.py b/src/reasoning360/utils/reward_score/ifbench/instructions.py similarity index 100% rename from verl/utils/reward_score/ifbench/instructions.py rename to src/reasoning360/utils/reward_score/ifbench/instructions.py diff --git a/verl/utils/reward_score/ifbench/instructions_registry.py b/src/reasoning360/utils/reward_score/ifbench/instructions_registry.py similarity index 100% rename from verl/utils/reward_score/ifbench/instructions_registry.py rename to src/reasoning360/utils/reward_score/ifbench/instructions_registry.py diff --git a/verl/utils/reward_score/ifbench/instructions_util.py b/src/reasoning360/utils/reward_score/ifbench/instructions_util.py similarity index 100% rename from verl/utils/reward_score/ifbench/instructions_util.py rename to src/reasoning360/utils/reward_score/ifbench/instructions_util.py diff --git a/verl/utils/reward_score/ifbench/split_fixed_data.py b/src/reasoning360/utils/reward_score/ifbench/split_fixed_data.py similarity index 100% rename from verl/utils/reward_score/ifbench/split_fixed_data.py rename to src/reasoning360/utils/reward_score/ifbench/split_fixed_data.py diff --git a/verl/utils/reward_score/ifbench/test_ifbench.py b/src/reasoning360/utils/reward_score/ifbench/test_ifbench.py similarity index 100% rename from verl/utils/reward_score/ifbench/test_ifbench.py rename to src/reasoning360/utils/reward_score/ifbench/test_ifbench.py diff --git a/verl/utils/reward_score/ifeval/__init__.py b/src/reasoning360/utils/reward_score/ifeval/__init__.py similarity index 100% rename from verl/utils/reward_score/ifeval/__init__.py rename to src/reasoning360/utils/reward_score/ifeval/__init__.py diff --git a/verl/utils/reward_score/ifeval/instructions.py b/src/reasoning360/utils/reward_score/ifeval/instructions.py similarity index 100% rename from verl/utils/reward_score/ifeval/instructions.py rename to src/reasoning360/utils/reward_score/ifeval/instructions.py diff --git a/verl/utils/reward_score/ifeval/instructions_registry.py b/src/reasoning360/utils/reward_score/ifeval/instructions_registry.py similarity index 100% rename from verl/utils/reward_score/ifeval/instructions_registry.py rename to src/reasoning360/utils/reward_score/ifeval/instructions_registry.py diff --git a/verl/utils/reward_score/ifeval/instructions_util.py b/src/reasoning360/utils/reward_score/ifeval/instructions_util.py similarity index 100% rename from verl/utils/reward_score/ifeval/instructions_util.py rename to src/reasoning360/utils/reward_score/ifeval/instructions_util.py diff --git a/verl/utils/reward_score/livebench/__init__.py b/src/reasoning360/utils/reward_score/livebench/__init__.py similarity index 100% rename from verl/utils/reward_score/livebench/__init__.py rename to src/reasoning360/utils/reward_score/livebench/__init__.py diff --git a/verl/utils/reward_score/livebench/data_analysis/cta/utils.py b/src/reasoning360/utils/reward_score/livebench/data_analysis/cta/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/data_analysis/cta/utils.py rename to src/reasoning360/utils/reward_score/livebench/data_analysis/cta/utils.py diff --git a/verl/utils/reward_score/livebench/data_analysis/tablejoin/utils.py b/src/reasoning360/utils/reward_score/livebench/data_analysis/tablejoin/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/data_analysis/tablejoin/utils.py rename to src/reasoning360/utils/reward_score/livebench/data_analysis/tablejoin/utils.py diff --git a/verl/utils/reward_score/livebench/data_analysis/tablereformat/utils.py b/src/reasoning360/utils/reward_score/livebench/data_analysis/tablereformat/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/data_analysis/tablereformat/utils.py rename to src/reasoning360/utils/reward_score/livebench/data_analysis/tablereformat/utils.py diff --git a/verl/utils/reward_score/livebench/reasoning/house_traversal/utils.py b/src/reasoning360/utils/reward_score/livebench/reasoning/house_traversal/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/reasoning/house_traversal/utils.py rename to src/reasoning360/utils/reward_score/livebench/reasoning/house_traversal/utils.py diff --git a/verl/utils/reward_score/livebench/reasoning/spatial/utils.py b/src/reasoning360/utils/reward_score/livebench/reasoning/spatial/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/reasoning/spatial/utils.py rename to src/reasoning360/utils/reward_score/livebench/reasoning/spatial/utils.py diff --git a/verl/utils/reward_score/livebench/reasoning/web_of_lies_v2/utils.py b/src/reasoning360/utils/reward_score/livebench/reasoning/web_of_lies_v2/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/reasoning/web_of_lies_v2/utils.py rename to src/reasoning360/utils/reward_score/livebench/reasoning/web_of_lies_v2/utils.py diff --git a/verl/utils/reward_score/livebench/reasoning/web_of_lies_v3/utils.py b/src/reasoning360/utils/reward_score/livebench/reasoning/web_of_lies_v3/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/reasoning/web_of_lies_v3/utils.py rename to src/reasoning360/utils/reward_score/livebench/reasoning/web_of_lies_v3/utils.py diff --git a/verl/utils/reward_score/livebench/reasoning/zebra_puzzle/utils.py b/src/reasoning360/utils/reward_score/livebench/reasoning/zebra_puzzle/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/reasoning/zebra_puzzle/utils.py rename to src/reasoning360/utils/reward_score/livebench/reasoning/zebra_puzzle/utils.py diff --git a/verl/utils/reward_score/livebench/util.py b/src/reasoning360/utils/reward_score/livebench/util.py similarity index 100% rename from verl/utils/reward_score/livebench/util.py rename to src/reasoning360/utils/reward_score/livebench/util.py diff --git a/verl/utils/reward_score/livebench/writing/connections/utils.py b/src/reasoning360/utils/reward_score/livebench/writing/connections/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/writing/connections/utils.py rename to src/reasoning360/utils/reward_score/livebench/writing/connections/utils.py diff --git a/verl/utils/reward_score/livebench/writing/plot_unscrambling/utils.py b/src/reasoning360/utils/reward_score/livebench/writing/plot_unscrambling/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/writing/plot_unscrambling/utils.py rename to src/reasoning360/utils/reward_score/livebench/writing/plot_unscrambling/utils.py diff --git a/verl/utils/reward_score/livebench/writing/typos/utils.py b/src/reasoning360/utils/reward_score/livebench/writing/typos/utils.py similarity index 100% rename from verl/utils/reward_score/livebench/writing/typos/utils.py rename to src/reasoning360/utils/reward_score/livebench/writing/typos/utils.py diff --git a/verl/utils/reward_score/math.py b/src/reasoning360/utils/reward_score/math.py similarity index 100% rename from verl/utils/reward_score/math.py rename to src/reasoning360/utils/reward_score/math.py diff --git a/verl/utils/reward_score/math_batch.py b/src/reasoning360/utils/reward_score/math_batch.py similarity index 100% rename from verl/utils/reward_score/math_batch.py rename to src/reasoning360/utils/reward_score/math_batch.py diff --git a/verl/utils/reward_score/math_dapo.py b/src/reasoning360/utils/reward_score/math_dapo.py similarity index 100% rename from verl/utils/reward_score/math_dapo.py rename to src/reasoning360/utils/reward_score/math_dapo.py diff --git a/verl/utils/reward_score/math_llm_judge/__init__.py b/src/reasoning360/utils/reward_score/math_llm_judge/__init__.py similarity index 98% rename from verl/utils/reward_score/math_llm_judge/__init__.py rename to src/reasoning360/utils/reward_score/math_llm_judge/__init__.py index 7d2b20cae..41d08c01c 100644 --- a/verl/utils/reward_score/math_llm_judge/__init__.py +++ b/src/reasoning360/utils/reward_score/math_llm_judge/__init__.py @@ -398,14 +398,15 @@ def llm_check_answer(model_output: str, ground_truth: str, question: str) -> boo # use llm to check if the answer is correct # url = "http://176.56.200.81:30000/v1/chat/completions" - url = os.getenv("MATH_LLM_JUDGE_URL") - if not url: + url_base = os.getenv("MATH_LLM_JUDGE_URL") + if not url_base: raise ValueError("MATH_LLM_JUDGE_URL is not set") + url = url_base.rstrip("/") + "/v1/chat/completions" prompt = input_template.format(QUESTION=question, STUDENT_ANSWER=model_output, REFERENCE_ANSWER=ground_truth) data = { - "model": "Qwen/Qwen2.5-32B-Instruct", + "model": "openai/gpt-oss-120b", "messages": [{"role": "user", "content": prompt}], } response = requests.post(url, json=data) @@ -423,7 +424,7 @@ def llm_check_answer(model_output: str, ground_truth: str, question: str) -> boo def compute_score(model_output: str, ground_truth: str, extra_info: dict) -> bool: - question = extra_info["question"] + question = extra_info["original_question"] model_output = str(model_output) ground_truth = str(ground_truth) @@ -447,5 +448,4 @@ def compute_score(model_output: str, if is_matched and not is_correct: # use llm to check if the answer is correct is_correct = llm_check_answer(extracted_model_output, ground_truth, question) - return is_correct, 1, extracted_model_output diff --git a/verl/utils/reward_score/math_llm_judge/grader.py b/src/reasoning360/utils/reward_score/math_llm_judge/grader.py similarity index 100% rename from verl/utils/reward_score/math_llm_judge/grader.py rename to src/reasoning360/utils/reward_score/math_llm_judge/grader.py diff --git a/verl/utils/reward_score/math_llm_judge/math_normalize.py b/src/reasoning360/utils/reward_score/math_llm_judge/math_normalize.py similarity index 100% rename from verl/utils/reward_score/math_llm_judge/math_normalize.py rename to src/reasoning360/utils/reward_score/math_llm_judge/math_normalize.py diff --git a/verl/utils/reward_score/math_verify.py b/src/reasoning360/utils/reward_score/math_verify.py similarity index 100% rename from verl/utils/reward_score/math_verify.py rename to src/reasoning360/utils/reward_score/math_verify.py diff --git a/verl/utils/reward_score/naive_dapo.py b/src/reasoning360/utils/reward_score/naive_dapo.py similarity index 96% rename from verl/utils/reward_score/naive_dapo.py rename to src/reasoning360/utils/reward_score/naive_dapo.py index d26a1dd72..5dcc11440 100644 --- a/verl/utils/reward_score/naive_dapo.py +++ b/src/reasoning360/utils/reward_score/naive_dapo.py @@ -21,6 +21,7 @@ from pylatexenc import latex2text from sympy.parsing import sympy_parser import os +import threading from .prime_math import math_normalize from .prime_math.grader import math_equal @@ -158,6 +159,12 @@ def handler(signum, frame): raise TimeoutError("Operation timed out!") def wrapper(*args, **kwargs): + # Python signals can only be installed/handled in the main thread. + # Reward scoring may run inside ThreadPoolExecutor (e.g., async reward loop), + # so avoid SIGALRM in worker threads instead of crashing. + if threading.current_thread() is not threading.main_thread(): + return func(*args, **kwargs) + old_handler = signal.getsignal(signal.SIGALRM) signal.signal(signal.SIGALRM, handler) signal.alarm(timeout_seconds) @@ -184,12 +191,13 @@ def _sympy_parse(expr: str): ) +# @timeout(timeout_seconds=5) def _parse_latex(expr: str) -> str: """Attempts to parse latex to an expression sympy can read.""" expr = expr.replace("\\tfrac", "\\frac") expr = expr.replace("\\dfrac", "\\frac") expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. - expr = latex2text.LatexNodes2Text().latex_to_text(expr) + # expr = latex2text.LatexNodes2Text().latex_to_text(expr) # Replace the specific characters that this parser uses. expr = expr.replace("√", "sqrt") diff --git a/src/reasoning360/utils/reward_score/nemotron_stem.py b/src/reasoning360/utils/reward_score/nemotron_stem.py new file mode 100644 index 000000000..f1f982f4d --- /dev/null +++ b/src/reasoning360/utils/reward_score/nemotron_stem.py @@ -0,0 +1,90 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import re + + +def extract_solution(solution_str, method='strict'): + """ + Extract the final answer choice from an LLM's response to a multiple-choice nemotron_stem question. + + Args: + solution_str (str): The full text response from the LLM + method (str): 'strict' for exact format matching, 'flexible' for more lenient matching + + Returns: + str: The extracted answer choice (A, B, C, or D) or None if not found + """ + assert method in ['strict', 'flexible'] + + if method == 'strict': + # First try to find answer in boxed format + boxed_match = re.search(r"\\boxed\{([A-D])\}", solution_str) + if boxed_match: + return boxed_match.group(1) + + # Then try standard "Answer:" format + answer_match = re.search(r"(?i)Answer[ \t]*:[ \t]*\$?([A-D])\$?", solution_str) + if answer_match: + return answer_match.group(1) + + # Try to find single letter answers at the end + end_match = re.search(r"\b([A-D])\b(?!.*\b[A-D]\b)", solution_str) + if end_match: + return end_match.group(1) + + return None + + elif method == 'flexible': + # Look for answers in parentheses + answer = re.findall(r"\(([A-D])\)", solution_str) + if answer: + return answer[-1] # Return the last found answer + + # Look for boxed answers + boxed_answer = re.findall(r"\\boxed\{([A-D])\}", solution_str) + if boxed_answer: + return boxed_answer[-1] + + # Look for any A, B, C, D pattern + general_answer = re.findall(r"\b([A-D])\b", solution_str) + if general_answer: + return general_answer[-1] + + return None + + +def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1., extra_info=None): + """The scoring function for nemotron_stem dataset. + + Args: + solution_str: the solution text + ground_truth: the ground truth answer (A, B, C, or D) + method: the method to extract the solution, choices are 'strict' and 'flexible' + format_score: the score for the format when answer is extractable but wrong + score: the score for the correct answer + extra_info: additional information (not used in this implementation) + + Returns: + dict: A dictionary containing 'score' and 'acc' keys + """ + answer = extract_solution(solution_str=solution_str, method=method) + if answer is None: + return {'score': 0, 'acc': 0} + else: + if answer == ground_truth: + return {'score': score, 'acc': 1.} + else: + return {'score': format_score, 'acc': 0.} \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/nemotron_stem_test.py b/src/reasoning360/utils/reward_score/nemotron_stem_test.py new file mode 100644 index 000000000..96c815f5d --- /dev/null +++ b/src/reasoning360/utils/reward_score/nemotron_stem_test.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +import pandas as pd +from verl.utils.reward_score import default_compute_score +from verl.utils.reward_score.nemotron_stem import compute_score, extract_solution + + +def test_extract_solution(): + """Test the extract_solution function with various response formats.""" + print("Testing extract_solution function...") + + test_cases = [ + ("The answer is A", "A"), + ("Answer: B", "B"), + ("\\boxed{C}", "C"), + ("After careful analysis, the answer is D.", "D"), + ("I think (C) is correct", "C"), + ("No clear answer", None), + ("The final answer is \\boxed{A}.", "A"), + ("Answer: C", "C"), + ] + + for response, expected in test_cases: + result = extract_solution(response, method='strict') + print(f"Input: '{response}' -> Expected: {expected}, Got: {result}") + assert result == expected, f"Failed for '{response}': expected {expected}, got {result}" + + print("extract_solution tests passed!\n") + + +def test_compute_score(): + """Test the compute_score function.""" + print("Testing compute_score function...") + + # Test correct answer + result = compute_score("Answer: A", "A") + print(f"Correct answer test: {result}") + assert result == {'score': 1.0, 'acc': 1.0} + + # Test incorrect answer + result = compute_score("Answer: B", "A") + print(f"Incorrect answer test: {result}") + assert result == {'score': 0.0, 'acc': 0.0} + + # Test no extractable answer + result = compute_score("I don't know", "A") + print(f"No answer test: {result}") + assert result == {'score': 0, 'acc': 0} + + print("compute_score tests passed!\n") + + +def test_default_compute_score(): + """Test the default_compute_score function with nemotron_stem data source.""" + print("Testing default_compute_score with nemotron_stem...") + + # Test with stem_nemotron data source + result = default_compute_score("stem_nemotron", "Answer: C", "C") + print(f"stem_nemotron correct: {result}") + assert result == {'score': 1.0, 'acc': 1.0} + + result = default_compute_score("stem_nemotron", "Answer: A", "C") + print(f"stem_nemotron incorrect: {result}") + assert result == {'score': 0.0, 'acc': 0.0} + + # Test with nemotron_stem data source + result = default_compute_score("nemotron_stem", "\\boxed{B}", "B") + print(f"nemotron_stem correct: {result}") + assert result == {'score': 1.0, 'acc': 1.0} + + print("default_compute_score tests passed!\n") + + +def test_real_data(): + """Test with real nemotron_stem data.""" + print("Testing with real nemotron_stem data...") + + try: + # Load a sample of the test data + df = pd.read_parquet('/mnt/sharefs/users/jianshu.she/nemotron_stem/test_data_final.parquet') + sample = df.head(5) + + print(f"Testing with {len(sample)} samples from real data...") + + for idx, row in sample.iterrows(): + data_source = row['data_source'] + response = row['response'] + ground_truth = row['reward_model']['ground_truth'] + + # Test with our implementation + result = default_compute_score(data_source, response, ground_truth) + print(f"Sample {idx}: response='{response}', ground_truth='{ground_truth}', score={result}") + + print("Real data test completed!\n") + + except Exception as e: + print(f"Could not test with real data: {e}") + + +if __name__ == "__main__": + test_extract_solution() + test_compute_score() + test_default_compute_score() + test_real_data() + print("All tests passed!") \ No newline at end of file diff --git a/verl/utils/reward_score/orz/__init__.py b/src/reasoning360/utils/reward_score/orz/__init__.py similarity index 100% rename from verl/utils/reward_score/orz/__init__.py rename to src/reasoning360/utils/reward_score/orz/__init__.py diff --git a/verl/utils/reward_score/orz/math_utils.py b/src/reasoning360/utils/reward_score/orz/math_utils.py similarity index 100% rename from verl/utils/reward_score/orz/math_utils.py rename to src/reasoning360/utils/reward_score/orz/math_utils.py diff --git a/verl/utils/reward_score/orz/math_utils_sync.py b/src/reasoning360/utils/reward_score/orz/math_utils_sync.py similarity index 100% rename from verl/utils/reward_score/orz/math_utils_sync.py rename to src/reasoning360/utils/reward_score/orz/math_utils_sync.py diff --git a/verl/utils/reward_score/prime_code/README.md b/src/reasoning360/utils/reward_score/prime_code/README.md similarity index 100% rename from verl/utils/reward_score/prime_code/README.md rename to src/reasoning360/utils/reward_score/prime_code/README.md diff --git a/verl/utils/reward_score/prime_code/__init__.py b/src/reasoning360/utils/reward_score/prime_code/__init__.py similarity index 100% rename from verl/utils/reward_score/prime_code/__init__.py rename to src/reasoning360/utils/reward_score/prime_code/__init__.py diff --git a/verl/utils/reward_score/prime_code/testing_util.py b/src/reasoning360/utils/reward_score/prime_code/testing_util.py similarity index 100% rename from verl/utils/reward_score/prime_code/testing_util.py rename to src/reasoning360/utils/reward_score/prime_code/testing_util.py diff --git a/verl/utils/reward_score/prime_code/utils.py b/src/reasoning360/utils/reward_score/prime_code/utils.py similarity index 100% rename from verl/utils/reward_score/prime_code/utils.py rename to src/reasoning360/utils/reward_score/prime_code/utils.py diff --git a/verl/utils/reward_score/prime_math/__init__.py b/src/reasoning360/utils/reward_score/prime_math/__init__.py similarity index 100% rename from verl/utils/reward_score/prime_math/__init__.py rename to src/reasoning360/utils/reward_score/prime_math/__init__.py diff --git a/verl/utils/reward_score/prime_math/grader.py b/src/reasoning360/utils/reward_score/prime_math/grader.py similarity index 100% rename from verl/utils/reward_score/prime_math/grader.py rename to src/reasoning360/utils/reward_score/prime_math/grader.py diff --git a/verl/utils/reward_score/prime_math/math_normalize.py b/src/reasoning360/utils/reward_score/prime_math/math_normalize.py similarity index 100% rename from verl/utils/reward_score/prime_math/math_normalize.py rename to src/reasoning360/utils/reward_score/prime_math/math_normalize.py diff --git a/verl/utils/reward_score/puzzles_dataset.py b/src/reasoning360/utils/reward_score/puzzles_dataset.py similarity index 100% rename from verl/utils/reward_score/puzzles_dataset.py rename to src/reasoning360/utils/reward_score/puzzles_dataset.py diff --git a/verl/utils/reward_score/reasoning_gym/__init__.py b/src/reasoning360/utils/reward_score/reasoning_gym/__init__.py similarity index 100% rename from verl/utils/reward_score/reasoning_gym/__init__.py rename to src/reasoning360/utils/reward_score/reasoning_gym/__init__.py diff --git a/verl/utils/reward_score/sandbox_fusion/__init__.py b/src/reasoning360/utils/reward_score/sandbox_fusion/__init__.py similarity index 100% rename from verl/utils/reward_score/sandbox_fusion/__init__.py rename to src/reasoning360/utils/reward_score/sandbox_fusion/__init__.py diff --git a/verl/utils/reward_score/sandbox_fusion/utils.py b/src/reasoning360/utils/reward_score/sandbox_fusion/utils.py similarity index 100% rename from verl/utils/reward_score/sandbox_fusion/utils.py rename to src/reasoning360/utils/reward_score/sandbox_fusion/utils.py diff --git a/verl/utils/reward_score/search_r1_like_qa_em.py b/src/reasoning360/utils/reward_score/search_r1_like_qa_em.py similarity index 100% rename from verl/utils/reward_score/search_r1_like_qa_em.py rename to src/reasoning360/utils/reward_score/search_r1_like_qa_em.py diff --git a/verl/utils/reward_score/stem_llm_judge/__init__.py b/src/reasoning360/utils/reward_score/stem_llm_judge/__init__.py similarity index 100% rename from verl/utils/reward_score/stem_llm_judge/__init__.py rename to src/reasoning360/utils/reward_score/stem_llm_judge/__init__.py diff --git a/verl/utils/reward_score/supergpqa.py b/src/reasoning360/utils/reward_score/supergpqa.py similarity index 100% rename from verl/utils/reward_score/supergpqa.py rename to src/reasoning360/utils/reward_score/supergpqa.py diff --git a/src/reasoning360/utils/reward_score/synlogic/__init__.py b/src/reasoning360/utils/reward_score/synlogic/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/reasoning360/utils/reward_score/synlogic/arrow_maze_verifier.py b/src/reasoning360/utils/reward_score/synlogic/arrow_maze_verifier.py new file mode 100644 index 000000000..14f5977f8 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/arrow_maze_verifier.py @@ -0,0 +1,354 @@ +import json +from typing import List, Dict, Tuple +from .verifier import Verifier +from .data import Data +import re + +class ArrowMazeVerifier(Verifier): + """ + 箭头迷宫游戏验证器 + + 验证条件: + 1. 判断answer grid的大小是否和question grid一致 + 2. 判断answer grid中数字格子是否和question grid中数字格子一致 + 3. 判断question grid空格("X")在answer grid中是否被箭头填满 + 4. 判断箭头符号是否合法: + 上(↑)、下(↓)、左(←)、右(→)或对角线方向(↖、↗、↘、↙) + 5. 判断answer grid中非空格("X")和非数字的部分,即预填的箭头,是否和question grid一致 + 6. 迷宫有个隐藏的条件是所有箭头都能被射线箭头串覆盖到 + 7. 每个数字起点发出的射线箭头串总长度等于该数字 + """ + + # 定义合法的箭头符号 + VALID_ARROWS = {"↑", "↓", "←", "→", "↖", "↗", "↘", "↙"} + + # 定义箭头符号和其对应的方向 + ARROWS_DIRECTIONS = { + "↑": (-1, 0), # 上 + "↓": (1, 0), # 下 + "←": (0, -1), # 左 + "→": (0, 1), # 右 + "↖": (-1, -1), # 左上 + "↗": (-1, 1), # 右上 + "↘": (1, 1), # 右下 + "↙": (1, -1) # 左下 + } + + def verify(self, data: Data, test_solution_str: str) -> bool: + + """ + 验证箭头迷宫的答案是否正确 + + @param data: 游戏数据 + @param test_solution_str: 测试答案字符串 (JSON格式的二维数组) + @return: 答案是否正确 + """ + test_answer_str = self.extract_answer(test_solution_str) + if not test_answer_str: + # print("答案为空,验证失败") + return False + + try: + # 解析测试答案 + test_answer = json.loads(test_answer_str) + + # 获取原始迷宫 + question_grid = data.metadata["maze"] + + # 检查答案是否符合要求 + if not self._verify_grid_size(test_answer, question_grid): + # print("答案网格大小与题目不匹配") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案网格大小与题目不匹配" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_number_positions(test_answer, question_grid): + # print("答案中数字位置或值与题目不匹配") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中数字位置或值与题目不匹配" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_all_blanks_filled(test_answer, question_grid): + # print("答案中有空格未被填满") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中有空格未被填满" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_arrow_symbols(test_answer): + # print("答案中包含非法箭头符号") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中包含非法箭头符号" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_prefilled_arrows(test_answer, question_grid): + # print("答案中预填箭头与题目不一致") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中预填箭头与题目不一致" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_arrow_rays(test_answer): + # print("答案中存在未被射线覆盖的箭头") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中存在未被射线覆盖的箭头" + '\n') + f.write('-'*32 + '\n') + return False + + if not self._verify_number_rays(test_answer): + # print("答案中数字的射线箭头串总数不符合要求") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("答案中数字的射线箭头串总数不符合要求" + '\n') + f.write('-'*32 + '\n') + return False + + # 所有验证都通过 + # print("验证通过!") + return True + + except Exception as e: + # print(f"验证过程中出错: {e}") + with open("solution_str_Qwen3-4B.txt_maze_verify", "a") as f: + f.write("test_solution_str: " + test_solution_str + '\n') + f.write("test_answer_str: " + test_answer_str + '\n') + f.write("question_grid: " + str(data.metadata["maze"]) + '\n') + f.write("验证过程中出错" + str(e) + '\n') + f.write('-'*32 + '\n') + return False + + def _verify_grid_size(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: + """ + 验证答案网格大小是否与题目一致 + + @param test_answer: 测试答案网格 + @param question_grid: 题目网格 + @return: 网格大小是否一致 + """ + if len(test_answer) != len(question_grid): + return False + + for i in range(len(test_answer)): + if len(test_answer[i]) != len(question_grid[i]): + return False + + return True + + def _verify_number_positions(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: + """ + 验证答案中数字位置和值是否与题目一致 + + @param test_answer: 测试答案网格 + @param question_grid: 题目网格 + @return: 数字位置和值是否一致 + """ + for i in range(len(question_grid)): + for j in range(len(question_grid[i])): + if question_grid[i][j].isdigit(): + if test_answer[i][j] != question_grid[i][j]: + return False + return True + + def _verify_all_blanks_filled(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: + """ + 验证所有空格是否都被填满 + + @param test_answer: 测试答案网格 + @param question_grid: 题目网格 + @return: 所有空格是否被填满 + """ + for i in range(len(question_grid)): + for j in range(len(question_grid[i])): + if question_grid[i][j] == "X" and test_answer[i][j] == "X": + return False + return True + + def _verify_arrow_symbols(self, test_answer: List[List[str]]) -> bool: + """ + 验证箭头符号是否合法 + + @param test_answer: 测试答案网格 + @return: 箭头符号是否合法 + """ + for i in range(len(test_answer)): + for j in range(len(test_answer[i])): + cell = test_answer[i][j] + if not cell.isdigit() and cell != "X" and cell not in self.VALID_ARROWS: + return False + return True + + def _verify_prefilled_arrows(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool: + """ + 验证预填的箭头是否与题目一致 + + @param test_answer: 测试答案网格 + @param question_grid: 题目网格 + @return: 预填箭头是否一致 + """ + for i in range(len(question_grid)): + for j in range(len(question_grid[i])): + cell = question_grid[i][j] + if not cell.isdigit() and cell != "X": + if test_answer[i][j] != cell: + return False + return True + + def _verify_arrow_rays(self, test_answer: List[List[str]]) -> bool: + """ + 验证所有箭头是否都能被射线箭头串覆盖到 + + @param test_answer: 测试答案网格 + @return: 所有箭头是否都能被射线覆盖 + """ + n = len(test_answer) + m = len(test_answer[0]) if n > 0 else 0 + + # 创建覆盖标记数组 + covered = [[False for _ in range(m)] for _ in range(n)] + + # 标记数字位置为已覆盖 + for i in range(n): + for j in range(m): + if test_answer[i][j].isdigit(): + covered[i][j] = True + + # 从每个数字出发,沿各个方向延伸射线,标记覆盖到的箭头 + for i in range(n): + for j in range(m): + if test_answer[i][j].isdigit(): + # 检查所有方向 + for arrow_symbol, (di, dj) in self.ARROWS_DIRECTIONS.items(): + ni, nj = i + di, j + dj + # 沿该方向延伸,直到边界或非匹配箭头 + while 0 <= ni < n and 0 <= nj < m and test_answer[ni][nj] == arrow_symbol: + covered[ni][nj] = True + ni += di + nj += dj + + # 检查所有箭头是否都被覆盖 + for i in range(n): + for j in range(m): + if test_answer[i][j] in self.VALID_ARROWS and not covered[i][j]: + return False + + return True + + def _verify_number_rays(self, test_answer: List[List[str]]) -> bool: + """ + 验证每个数字起点发出的射线箭头串总长度是否等于该数字 + + @param test_answer: 测试答案网格 + @return: 每个数字的射线箭头串是否符合要求 + """ + n = len(test_answer) + m = len(test_answer[0]) if n > 0 else 0 + + for i in range(n): + for j in range(m): + if test_answer[i][j].isdigit(): + number = int(test_answer[i][j]) + arrow_count = self._count_arrow_rays(test_answer, i, j) + if arrow_count != number: + return False + + return True + + def _count_arrow_rays(self, grid: List[List[str]], i: int, j: int) -> int: + """ + 计算从数字出发的所有射线箭头串中箭头总数 + + @param grid: 网格 + @param i: 数字行索引 + @param j: 数字列索引 + @return: 箭头总数 + """ + n = len(grid) + m = len(grid[0]) if n > 0 else 0 + count = 0 + + # 检查所有方向 + for arrow_symbol, (di, dj) in self.ARROWS_DIRECTIONS.items(): + ni, nj = i + di, j + dj + ray_length = 0 + + # 沿该方向计数连续的相同箭头 + while 0 <= ni < n and 0 <= nj < m and grid[ni][nj] == arrow_symbol: + ray_length += 1 + ni += di + nj += dj + + count += ray_length + + return count + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 (JSON格式的二维数组) + """ + if not test_solution: + return "" + # 尝试匹配Python代码块 + import re + code_block_patterns = [ + r'```python\s*\n(.*?\[.*?\].*?)\n```', # 标准Python代码块 + r'```\s*\n(.*?\[.*?\].*?)\n```', # 无语言标记的代码块 + r'```(.*?\[.*?\].*?)```' # 无换行的代码块 + ] + + for pattern in code_block_patterns: + matches = re.findall(pattern, test_solution, re.DOTALL) + if matches: + # 获取最后一个匹配项 + code_block = matches[-1].strip() + try: + # 尝试解析为Python列表 + grid = eval(code_block) + # 验证格式是否为二维数组 + if isinstance(grid, list) and all(isinstance(row, list) for row in grid): + return json.dumps(grid) + except Exception as e: + # print(f"解析代码块失败: {e}") + continue + + # 如果没有找到有效的代码块,尝试直接寻找列表 + list_pattern = r'\[\s*\[.*?\]\s*\]' + matches = re.findall(list_pattern, test_solution, re.DOTALL) + if matches: + try: + # 尝试解析为Python列表 + grid = eval(matches[-1]) + # 验证格式是否为二维数组 + if isinstance(grid, list) and all(isinstance(row, list) for row in grid): + return json.dumps(grid) + except Exception as e: + pass + # print(f"解析列表失败: {e}") + + # 如果上述方法都失败,返回空字符串 + return "" \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/boolean_expressions_verifier.py b/src/reasoning360/utils/reward_score/synlogic/boolean_expressions_verifier.py new file mode 100644 index 000000000..8edab88f2 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/boolean_expressions_verifier.py @@ -0,0 +1,54 @@ +import re +from .data import Data +from .verifier import Verifier + +class BooleanExpressionsVerifier(Verifier): + """ + 验证器用于布尔表达式游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + test_answer = self.extract_answer(test_answer) + if test_answer is None: + return False + # 提取所有字母(a-z和A-Z) + test_answer_letters = re.findall(r'[a-zA-Z]', test_answer) + ground_truth_letters = re.findall(r'[a-zA-Z]', data.answer) + test_answer_letters = self.lower(test_answer_letters) + ground_truth_letters = self.lower(ground_truth_letters) + # 转换为集合进行比较 + test_set = set(test_answer_letters) + ground_truth_set = set(ground_truth_letters) + + return test_set == ground_truth_set + except Exception as e: + print("NOTE!!! parse error!!!! (BooleanExpressions)", e) + return False + + def lower(self, answer_list): + return [answer.lower() for answer in answer_list] + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/campsite_verifier.py b/src/reasoning360/utils/reward_score/synlogic/campsite_verifier.py new file mode 100644 index 000000000..4aee1a5b2 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/campsite_verifier.py @@ -0,0 +1,183 @@ +from .data import Data +from .verifier import Verifier +import re +import ast +from typing import List, Set, Tuple, Dict + + +class CampsiteVerifier(Verifier): + """ + Verifier for Campsite game + """ + def verify(self, data: Data, test_solution: str): + try: + test_answer = self.extract_answer(test_solution) + original_grid = data.metadata["grid"] + row_constraints = data.metadata["row_constraints"] + col_constraints = data.metadata["col_constraints"] + n = data.metadata["n"] + m = data.metadata["m"] + + if not test_answer: + return False + + if len(test_answer) != n or any(len(row) != m for row in test_answer): + return False + + if not self._check_trees_unchanged(original_grid, test_answer): + return False + + if not self._check_row_constraints(test_answer, row_constraints): + return False + + if not self._check_col_constraints(test_answer, col_constraints): + return False + + if not self._check_tents_not_adjacent(test_answer): + return False + + if not self._check_tent_tree_matching(test_answer): + return False + + return True + + except Exception as e: + print(f"Verification error (Campsite): {e}") + return False + + def _extract_grid(self, test_answer: str) -> List[List[str]]: + """从回答中提取网格""" + grid_pattern = r'\[\s*\[.*?\]\s*\]' + match = re.search(grid_pattern, test_answer, re.DOTALL) + if match: + try: + grid_str = match.group(0) + return ast.literal_eval(grid_str) + except: + pass + + return None + + def _check_trees_unchanged(self, original_grid: List[List[str]], test_answer: List[List[str]]) -> bool: + """检查树木位置是否保持不变""" + for i in range(len(original_grid)): + for j in range(len(original_grid[0])): + if original_grid[i][j] == 'T' and test_answer[i][j] != 'T': + return False + if original_grid[i][j] != 'T' and test_answer[i][j] == 'T': + return False + return True + + def _check_row_constraints(self, grid: List[List[str]], row_constraints: List[int]) -> bool: + """检查行约束条件""" + for i in range(len(grid)): + tent_count = sum(1 for cell in grid[i] if cell == 'C') + if tent_count != row_constraints[i]: + return False + return True + + def _check_col_constraints(self, grid: List[List[str]], col_constraints: List[int]) -> bool: + """检查列约束条件""" + for j in range(len(grid[0])): + tent_count = sum(1 for i in range(len(grid)) if grid[i][j] == 'C') + if tent_count != col_constraints[j]: + return False + return True + + def _check_tents_not_adjacent(self, grid: List[List[str]]) -> bool: + """检查帐篷之间是否相邻(包括对角线)""" + n = len(grid) + m = len(grid[0]) if n > 0 else 0 + + for i in range(n): + for j in range(m): + if grid[i][j] == 'C': + # 检查周围8个方向是否有其他帐篷 + for di in [-1, 0, 1]: + for dj in [-1, 0, 1]: + if di == 0 and dj == 0: + continue + ni, nj = i + di, j + dj + if 0 <= ni < n and 0 <= nj < m and grid[ni][nj] == 'C': + return False + + return True + + def _check_tent_tree_matching(self, grid: List[List[str]]) -> bool: + """ + 检查帐篷与树木的一一匹配关系: + 1. 每个帐篷必须与一棵树正交相邻 + 2. 每棵树只能与一个帐篷匹配 + 3. 每个帐篷只能与一棵树匹配 + 4. 帐篷和树的数量必须相等 + """ + n = len(grid) + m = len(grid[0]) if n > 0 else 0 + + tents = [] + trees = [] + for i in range(n): + for j in range(m): + if grid[i][j] == 'C': + tents.append((i, j)) + elif grid[i][j] == 'T': + trees.append((i, j)) + + if len(tents) != len(trees): + return False + + tent_to_trees = {} + tree_to_tents = {} + + for tent_i, tent_j in tents: + tent_to_trees[(tent_i, tent_j)] = [] + for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + tree_i, tree_j = tent_i + di, tent_j + dj + if 0 <= tree_i < n and 0 <= tree_j < m and grid[tree_i][tree_j] == 'T': + tent_to_trees[(tent_i, tent_j)].append((tree_i, tree_j)) + + for tree_i, tree_j in trees: + tree_to_tents[(tree_i, tree_j)] = [] + for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + tent_i, tent_j = tree_i + di, tree_j + dj + if 0 <= tent_i < n and 0 <= tent_j < m and grid[tent_i][tent_j] == 'C': + tree_to_tents[(tree_i, tree_j)].append((tent_i, tent_j)) + + for tent in tents: + if not tent_to_trees[tent]: + return False + + tent_matched = {} + tree_matched = {} + + def dfs(tent): + for tree in tent_to_trees[tent]: + if tree in visited: + continue + visited.add(tree) + + if tree not in tree_matched or dfs(tree_matched[tree]): + tent_matched[tent] = tree + tree_matched[tree] = tent + return True + return False + + for tent in tents: + visited = set() + if tent not in tent_matched: + if not dfs(tent): + return False + + return len(tent_matched) == len(tents) and len(tree_matched) == len(trees) + + def extract_answer(self, test_solution: str): + """从模型回答中提取解决方案""" + grid_pattern = r'\[\s*\[.*?\]\s*\]' + match = re.search(grid_pattern, test_solution, re.DOTALL) + if match: + try: + grid_str = match.group(0) + return ast.literal_eval(grid_str) + except: + pass + return "" \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/data.py b/src/reasoning360/utils/reward_score/synlogic/data.py new file mode 100644 index 000000000..8f7d2998d --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/data.py @@ -0,0 +1,51 @@ +import json + +class Data: + """ + Data class for game/corpus + @param question: question of the game/corpus + @param answer: answer of the game/corpus + @param difficulty: difficulty of the game/corpus, from 1 to 10 + """ + def __init__(self, question: str, answer: str, difficulty: int = 1, metadata: dict = None, **kwargs): + self.question = question + self.answer = answer + self.difficulty = difficulty + self.metadata = metadata + self.gpt_response = "" + + def to_json(self): + return { + "question": self.question, + "answer": self.answer, + "difficulty": self.difficulty, + "metadata": self.metadata, + "gpt_response": self.gpt_response + } + + def to_json_str(self): + return json.dumps(self.to_json(), ensure_ascii=False) + + @classmethod + def from_json_str(cls, json_str): + json_data = json.loads(json_str) + return cls(**json_data) + + @classmethod + def from_json_dict(cls, json_dict): + instance = cls(**json_dict) + if 'gpt_response' in json_dict: + instance.gpt_response = json_dict['gpt_response'] + return instance + + @classmethod + def from_jsonl_file(cls, file_path): + data_list = [] + with open(file_path, "r") as f: + for line in f: + json_data = json.loads(line) + instance = cls(**json_data) + if 'gpt_response' in json_data: + instance.gpt_response = json_data['gpt_response'] + data_list.append(instance) + return data_list \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/dyck_language_errors_verifier.py b/src/reasoning360/utils/reward_score/synlogic/dyck_language_errors_verifier.py new file mode 100644 index 000000000..3bfbdac76 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/dyck_language_errors_verifier.py @@ -0,0 +1,91 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re + + +class DyckLanguageErrorsVerifier(Verifier): + """ + 验证器用于检查括号闭合错误识别游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution=test_answer) + # 获取正确答案 + if data.metadata["is_valid"]: + correct_answer = "-1" # 合法序列对应-1 + else: + correct_answer = str(data.metadata["first_error_pos"]) + + # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'") + + # 清理和标准化答案 + test_answer = test_answer.strip() + + # 检查-1答案(合法序列) + if correct_answer == "-1": + # 如果正确答案是-1(合法序列),只接受-1作为回答 + if test_answer == "-1": + is_correct = True + else: + is_correct = False + else: + # 正确答案是位置数字,需要验证模型回答也是相同数字 + try: + is_correct = (int(test_answer) == int(correct_answer)) + except (ValueError, TypeError): + # 如果模型回答不是有效数字,验证失败 + is_correct = False + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + + except Exception as e: + print(f"Verification error (DyckLanguageErrors): {e}") + return False + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 + """ + answer_str = test_solution + if answer_str is None: + import re + # 清理回答文本 + solution = test_solution.strip() if test_solution else "" + + # 提取所有数字(包括负数) + numbers = re.findall(r'-?\d+', solution) + if numbers: + # 优先返回"-1"(如果存在) + if "-1" in numbers: + return "-1" + # 否则返回找到的第一个非负整数 + for num in numbers: + if num.isdigit() and int(num) >= 0: + return num + # 如果只有负数,返回第一个 + return numbers[0] + + # 检查是否表示合法 + + + # 默认返回空字符串 + return "" + elif any(keyword in answer_str.lower() for keyword in ["合法", "valid", "correct"]): + return "-1" + else: + return answer_str \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py b/src/reasoning360/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py new file mode 100644 index 000000000..03f2b95f9 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/dyck_language_reasoning_errors_verifier.py @@ -0,0 +1,130 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re + + +class DyckLanguageReasoningErrorsVerifier(Verifier): + """ + Dyck语言推理错误识别验证器 + """ + def verify(self, data: Data, test_answer: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的答案字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution=test_answer) + # 获取元数据中的正确答案 + correct_indices = data.metadata["error_indices"] + # 格式化为正确的答案字符串格式 + expected_answer = self._format_answer(correct_indices) + + # print(f"验证: 模型答案='{test_answer}', 正确答案='{expected_answer}'") + + # 检查不明确的答案 + if "不确定" in test_answer or "不知道" in test_answer or "unclear" in test_answer.lower(): + # print("验证结果: 错误") + return False + + # 清理模型答案,允许一定的格式变化 + cleaned_test_answer = self._standardize_answer(test_answer) + + if not correct_indices and (cleaned_test_answer == "" or cleaned_test_answer.lower() in ["无问题", "no", "无错误", "no error", "no errors", "no mistakes", "all correct"]): + # 如果没有错误,且模型回答是空字符串或表示无问题,则正确 + is_correct = True + else: + # 将两个答案转换为数字集合进行比较 + test_error_indices = self._extract_error_indices(cleaned_test_answer) + expected_error_indices = set(correct_indices) + + # 检查两个集合是否相同 + is_correct = test_error_indices == expected_error_indices + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + + except Exception as e: + print(f"Verification error (DyckLanguageReasoningErrors): {e}") + return False + + def _standardize_answer(self, answer: str) -> str: + """ + 标准化答案字符串 + + @param answer: 原始答案字符串 + @return: 标准化后的答案字符串 + """ + # 如果答案为空或仅包含空白字符 + if not answer or answer.strip() == "": + return "" + + # 如果答案表示没有错误 + if answer.lower() in ["无问题", "no", "无错误", "no error", "no errors", "no mistakes", "all correct"]: + return "" + + # 替换中文逗号为英文逗号 + answer = answer.replace(',', ',') + # 移除所有非数字和逗号的字符 + answer = re.sub(r'[^0-9,]', '', answer) + + return answer + + def _extract_error_indices(self, answer: str) -> set: + """ + 从答案字符串中提取错误索引集合 + + @param answer: 答案字符串 + @return: 错误索引集合 + """ + if not answer or answer.strip() == "": + return set() + + # 先替换中文逗号为英文逗号 + answer = answer.replace(',', ',') + + # 分割逗号分隔的数字字符串,并转换为整数集合 + try: + return set(int(idx) for idx in answer.split(',') if idx and idx.strip().isdigit()) + except ValueError: + return set() + + def _format_answer(self, error_indices: list) -> str: + """ + 格式化答案字符串 + + @param error_indices: 错误步骤的索引列表 + @return: 格式化的答案字符串 + """ + if not error_indices: + return "" + + # 按照数字大小排序 + sorted_indices = sorted(error_indices) + + # 转换为字符串,使用英文逗号分隔 + return ",".join(map(str, sorted_indices)) + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取错误索引 + + @param test_solution: 模型的完整回答 + """ + answer_str = test_solution + return answer_str + +if __name__ == "__main__": + test_response = '''\nOkay, so I'm trying to figure out if there are any errors in the reasoning process provided for checking the proper closure of brackets in the sequence: {}[][][]{}{}([]). Let me go through each thought step by step and see where things might have gone wrong.\n\nStarting with Thought 2: the initial stack is empty, which makes sense because we haven't processed any characters yet. So that seems correct.\n\nThought 3: The first character is '{', so it's pushed onto the stack. The stack becomes '{'. Hmm, but in the thought, it's written as '{[', which is incorrect. The stack after the first character should only have '{'. So this thought is wrong because it adds an extra '[' that wasn't there yet. Wait, no, actually, the input sequence starts with '{', so maybe there's a mistake in the thought. Let me recheck. Oh, looking again, no, the input sequence is {}[][][]{}{}([)]}, so the first two characters are '{' and '}'. So after processing '{', the stack should be '{', not '{['. So Thought 3 is wrong because it mistakenly pushes '[' instead of just '{'.\n\nMoving to Thought 4: Next character is '}' which correctly matches the top '{' in the stack, so the stack becomes empty. This part is correct.\n\nThought 5: The next character is '[', so stack becomes '['. Correct.\n\nThought 6: Next is ']', which pops the '[', so stack is empty. Correct.\n\nThought 7: Next is '[', stack becomes '['. Correct.\n\nThought 8: Next is ']', so stack is empty. Correct.\n\nThought 9: Next is '[', stack becomes '['. Correct.\n\nThought 10: Next is ']' which pops the '[', so stack should be empty again. However, in the thought, it says the stack becomes ']', which is incorrect. So Thought 10 is wrong because it added a ']' instead of popping it off. Wait no, the character is ']', which should pop the top '[', so the stack should go from '[' to empty. But the thought says the stack becomes ']' which is wrong. So Thought 10 is incorrect.\n\nThought 11: Next is '{', stack is now '{'. Correct.\n\nThought 12: Next is '}', popping the '{', stack is empty. Correct.\n\nThought 13: Next is '{', stack becomes '{'. Correct.\n\nThought 14: Next is '}', stack is empty. Correct.\n\nThought 15: Next is '(', stack becomes '('. Correct.\n\nThought 16: Next is '[', so stack becomes '(['. Correct.\n\nThought 17: Next is ']', which pops the '[', so stack becomes '('. Correct.\n\nThought 18: Next is ')', which pops the '(', so stack is empty. Correct.\n\nThought 19: After processing all characters, stack is empty. So the sequence is balanced.\n\nWait, but the input sequence actually ends with ']})', right? The sequence given is {}[][][]{}{}([]). So after processing '(' and '[', we have '(', then '['. Then when we process ']', it pops the '[', stack becomes '('. Then next is ')', which pops '(', stack is empty. So everything is correct up to that point.\n\nBut in the earlier steps, the issue was in Thought 10. Because when processing the fifth character, which is '[', followed by ']', the stack should go back to empty. So Thought 10 says the stack becomes ']', which is incorrect. It should go back to empty.\n\nWait, let me recount the steps:\n\nThe input is {}[][][]{}{}([]). So let's break it down:\n\n1. '{' : stack becomes '{'\n2. '}' : stack empty\n3. '[' : stack becomes '['\n4. ']' : stack empty\n5. '[' : '['\n6. ']' : stack empty\n7. '[' : '['\n8. ']' : stack empty\n9. '{' : '{'\n10. '}' : stack empty\n11. '{' : '{'\n12. '}' : stack empty\n13. '(' : '('\n14. '[' : '(['\n15. ']' : '('\n16. ')' : stack empty.\n\nWait, but in the original thought process, after step 10, the character is ']' which pops the '[', making the stack empty, but the thought says stack is ']'. So that's definitely incorrect.\n\nAlso, in step 3, the thought says after '{' the stack is '{[', which is incorrect because the next character is '}', not another '['. So step 3 is wrong because it added an extra '[' instead of just '{' and then processing '}' in step 4 correctly.\n\nSo the errors are in Thought 3 and Thought 10.\n\n\nThe errors are in Thought 3 and Thought 10 because Thought 3 incorrectly adds an extra '[' instead of just '{' and Thought 10 mistakenly leaves the stack as ']' instead of empty after popping.\n\n[3,20]''' + metadata = {"trace_id": "77db72eb-a9db-46cd-96ea-5a49eba78792", "dyck_sequence": "{}[][][]{}{}([])", "thoughts": ["Thought 1: 我们应该逐个处理输入并跟踪栈的配置。", "Thought 2: 栈: 空", "Thought 3: { ; 栈: {[", "Thought 4: } ; 栈: 空", "Thought 5: [ ; 栈: [", "Thought 6: ] ; 栈: 空", "Thought 7: [ ; 栈: [", "Thought 8: ] ; 栈: 空", "Thought 9: [ ; 栈: [", "Thought 10: ] ; 栈: ]", "Thought 11: { ; 栈: {", "Thought 12: } ; 栈: 空", "Thought 13: { ; 栈: {", "Thought 14: } ; 栈: 空", "Thought 15: ( ; 栈: (", "Thought 16: [ ; 栈: ([", "Thought 17: ] ; 栈: (", "Thought 18: ) ; 栈: 空", "Thought 19: 现在,我们已经到达结尾。最终栈是空的。"], "error_indices": [3, 10], "n_types": 3, "total_length": 15, "n_errors": 2} + test_data = Data(question="", answer="", metadata=metadata) + test_verifier = DyckLanguageReasoningErrorsVerifier() + extracted_answer = test_verifier.extract_answer(test_response) + print(extracted_answer) + print(test_verifier.verify(data=test_data, test_answer=test_response)) \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/dyck_language_verifier.py b/src/reasoning360/utils/reward_score/synlogic/dyck_language_verifier.py new file mode 100644 index 000000000..e986f66c5 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/dyck_language_verifier.py @@ -0,0 +1,82 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re + + +class DyckLanguageVerifier(Verifier): + """ + 验证器用于检查Dyck Language游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str) -> bool: + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + # 获取元数据中的完整序列 + full_sequence = data.metadata["full_sequence"] + + # print(f"验证: 模型答案='{test_answer}', 完整序列='{full_sequence}'") + + # 从模型回答中提取答案 + extracted_answer = self.extract_answer(test_answer) + + # 检查答案是否完全匹配 + is_correct = (extracted_answer == full_sequence) + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + + return is_correct + + except Exception as e: + print(f"Verification error (DyckLanguage): {e}") + return False + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取括号序列答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 + """ + if not test_solution: + return "" + + # print(f"原始回答:\n{test_solution}") + + def clean_text(text: str) -> str: + """清理文本,处理转义字符和空白字符""" + # 移除所有空白字符(包括换行符、制表符等) + text = ''.join(text.split()) + + # 处理转义序列 + text = text.replace('\\n', '') + text = text.replace('\\t', '') + text = text.replace('\\r', '') + text = text.replace('\\\\', '\\') + + # 如果文本被引号包围,且引号不是括号序列的一部分,则移除外层引号 + if len(text) >= 2: + if text.startswith('"') and text.endswith('"'): + text = text[1:-1] + elif text.startswith("'") and text.endswith("'"): + text = text[1:-1] + + return text + + return clean_text(test_solution) + +if __name__ == "__main__": + test_response = '''填写后的完整序列应为“([])({})([()])”。\n\n检查一下长度是否正确:\n\n原序列长度为11字符,补充3个字符,总长度14。\n\n这样,整个序列是合法的。\n
\n\n([])({})([()])''' + metadata = {"trace_id": "38aeede4-d5d7-4863-91d2-df1fd99f491b", "full_sequence": "([])({})([()])", "question_sequence": "([])({})([(", "n_types": 3, "total_length": 14, "fill_length": 3, "nesting_depth": 0} + test_data = Data(question="", answer="", metadata=metadata) + test_verifier = DyckLanguageVerifier() + extracted_answer = test_verifier.extract_answer(test_response) + print(extracted_answer) + print(test_verifier.verify(data=test_data, test_answer=test_response)) \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py b/src/reasoning360/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py new file mode 100644 index 000000000..b4c8bdb25 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py @@ -0,0 +1,126 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re + +class BuggyTableVerifier(Verifier): + """ + Verifier for the BuggyTable game. + Checks if the submitted answer matches the expected answer. + """ + def extract_answer(self, answer: str) -> str: + """ + Public method to extract and normalize an answer string from LLM output. + Delegates to the private _extract_answer method. + + @param answer: The answer string to normalize + @return: The normalized answer string + """ + return self._extract_answer(answer) + + def verify(self, data: Data, test_answer: str) -> bool: + """ + Verify whether the test answer is consistent with the expected answer + for the buggy table query. + + @param data: Data object containing the expected answer + @param test_answer: The answer provided by the LLM to verify + @return: bool indicating whether the answer is correct + """ + # Extract the expected answer from the Data object + expected_answer = data.answer if data and hasattr(data, 'answer') else "" + + # For empty strings, compare directly + if not expected_answer and not test_answer: + return True + + # Extract and normalize both answers + normalized_expected = self._extract_answer(expected_answer) + normalized_test = self._extract_answer(test_answer) + + # Direct comparison of normalized answers + return normalized_expected == normalized_test + + def _is_raw_numeric_answer(self, value: str) -> bool: + """ + Check if a string represents a plain numeric answer without additional context. + This is used to validate raw input format. + + @param value: The string to check + @return: True if the string is a simple numeric value + """ + # Remove whitespace + value = value.strip() + + # Simple pattern match for a number (optionally with sign and decimal point) + import re + return bool(re.match(r'^-?\d+(\.\d+)?$', value)) + + def _raw_has_exactly_two_decimals(self, value: str) -> bool: + """ + Check if a raw numeric string has exactly 2 decimal places. + This is used to validate the format of the raw answer. + + @param value: The string to check + @return: True if the string has exactly 2 decimal places + """ + # Remove whitespace + value = value.strip() + + # Split on decimal point + parts = value.replace('-', '', 1).split('.') + + # Check if there is exactly one decimal point and two digits after it + return len(parts) == 2 and len(parts[1]) == 2 + + def _is_numeric(self, value: str) -> bool: + """ + Check if a string represents a valid number (including negative numbers and decimals). + + @param value: The string to check + @return: True if the string represents a valid number + """ + # Remove negative sign if present + value = value.strip() + if value.startswith('-'): + value = value[1:] + # Check if remaining string is a valid decimal number + return value.replace('.', '', 1).isdigit() + + def _has_exactly_two_decimals(self, value: str) -> bool: + """ + Check if a number string has exactly 2 decimal places. + + @param value: The number string to check + @return: True if the number has exactly 2 decimal places + """ + # Remove negative sign if present + value = value.strip() + if value.startswith('-'): + value = value[1:] + + # Split into whole and decimal parts + parts = value.split('.') + if len(parts) != 2: + return False + + # Check if decimal part has exactly 2 digits + return len(parts[1]) == 2 + + def _extract_answer(self, answer: str) -> str: + """ + Extract and normalize an answer string from LLM output. + Only finds values with exactly two decimal places. + + @param answer: The answer string to normalize + @return: The normalized answer string + """ + # Convert to string and normalize + normalized = str(answer).strip() if answer is not None else "" + + # Try to find numbers with exactly two decimal places + exact_matches = re.findall(r'-?\d+\.\d{2}\b', normalized) + if exact_matches: + return exact_matches[-1] # Return the last match with exactly two decimals + + # If no exact two-decimal match found, return the original string + return normalized \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/goods_exchange_verifier.py b/src/reasoning360/utils/reward_score/synlogic/goods_exchange_verifier.py new file mode 100644 index 000000000..922e92689 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/goods_exchange_verifier.py @@ -0,0 +1,218 @@ +import re +from .data import Data +from .verifier import Verifier + +class GoodsExchangeVerifier(Verifier): + """ + 验证器用于检查物品交换游戏的答案是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution) + # 获取元数据中的正确答案 + correct_answer = data.metadata["owns_after"] + + # print(f"验证: 模型答案='{test_answer}', 正确答案='{correct_answer}'") + + # 解析模型答案 + model_ownership = self._parse_answer(test_answer) + # 解析正确答案 + correct_ownership = self._parse_answer(correct_answer) + + # 比较两个答案是否完全一致 + is_correct = self._compare_answers(model_ownership, correct_ownership) + + # if is_correct: + # print("验证结果: 正确") + # else: + # print("验证结果: 错误") + # # 打印详细的不匹配信息 + # self._print_difference(model_ownership, correct_ownership) + + return is_correct + + except Exception as e: + print(f"Verification error (GoodsExchange): {e}") + return False + + def _parse_answer(self, answer_str): + """ + 解析答案字符串为物品归属字典 + + @param answer_str: 答案字符串,格式为"(('人1','物品1'),('人2','物品2'),...)"或"(人1,物品1),(人2,物品2),..." + @return: 归属关系字典 {人: 物品} + """ + if not answer_str: + return {} + + result = {} + try: + # 预处理:只处理最外层的空格,保留内部结构 + answer_str = answer_str.strip() + + # 尝试使用 eval 解析 Python tuple 格式 + pairs = eval(answer_str) + if isinstance(pairs, tuple): + for pair in pairs: + if isinstance(pair, tuple) and len(pair) == 2: + person, item = pair + # 处理每个值中的空格:移除两端空格 + result[person.strip()] = item.strip() + return result + except Exception as e: + # 如果 eval 失败,记录错误并尝试解析旧格式 + print(f"NOTE!!! parse error!!!! (GoodsExchange 1): {e}") + + # 移除最外层的括号(如果有) + if answer_str.startswith('('): + answer_str = answer_str[1:] + if answer_str.endswith(')'): + answer_str = answer_str[:-1] + + # 更健壮的手动解析逻辑 + person_item_pairs = [] + current_pair = "" + bracket_count = 0 + + # 更智能地分割答案字符串 + for char in answer_str: + if char == '(': + bracket_count += 1 + current_pair += char + elif char == ')': + bracket_count -= 1 + current_pair += char + if bracket_count == 0: + person_item_pairs.append(current_pair) + current_pair = "" + elif char == ',' and bracket_count == 0: + # 跳过顶层逗号 + continue + else: + current_pair += char + + # 处理每一对 + for pair in person_item_pairs: + pair = pair.strip() + # 移除括号 + if pair.startswith('('): + pair = pair[1:] + if pair.endswith(')'): + pair = pair[:-1] + + # 拆分人和物品 + try: + # 使用更健壮的分割方法 + parts = [] + quote_count = 0 + current = "" + + for char in pair: + if char in "\"'" and (len(current) == 0 or current[-1] != '\\'): + quote_count = 1 - quote_count + + if char == ',' and quote_count == 0: + parts.append(current.strip()) + current = "" + else: + current += char + + if current: + parts.append(current.strip()) + + if len(parts) >= 2: + person = parts[0].strip().strip("'\"") + item = parts[1].strip().strip("'\"") + result[person] = item + except Exception as e: + print(f"NOTE!!! parse error!!!! (GoodsExchange 2): {e}") + + return result + + def _compare_answers(self, model_ownership, correct_ownership): + """ + 比较两个归属关系字典是否相同 + + @param model_ownership: 模型回答的归属关系 + @param correct_ownership: 正确的归属关系 + @return: 是否完全一致 + """ + # 检查人数是否相同 + if len(model_ownership) != len(correct_ownership): + return False + + # 创建小写人名到原始人名的映射 + model_lower_to_original = {person.lower(): person for person in model_ownership} + + # 检查每个人的物品是否一致 + for person in correct_ownership: + # 如果模型答案中没有这个人(不区分大小写) + if person.lower() not in model_lower_to_original: + return False + + # 获取模型答案中对应的原始人名 + model_person = model_lower_to_original[person.lower()] + + # 如果人的物品不匹配(不区分大小写) + if model_ownership[model_person].lower() != correct_ownership[person].lower(): + return False + + return True + + def _print_difference(self, model_ownership, correct_ownership): + """ + 打印两个归属关系之间的差异 + + @param model_ownership: 模型回答的归属关系 + @param correct_ownership: 正确的归属关系 + """ + print("\n差异详情:") + + # 创建小写人名到原始人名的映射 + model_lower_to_original = {person.lower(): person for person in model_ownership} + correct_lower_to_original = {person.lower(): person for person in correct_ownership} + + # 检查正确答案中的每个人 + for person in correct_ownership: + person_lower = person.lower() + if person_lower not in model_lower_to_original: + # print(f" - 模型答案中缺少: {person}") + pass + else: + model_person = model_lower_to_original[person_lower] + # if model_ownership[model_person].lower() != correct_ownership[person].lower(): + # print(f" - {person}: 模型答案={model_ownership[model_person]}, 正确答案={correct_ownership[person]}") + + # 检查模型答案中的额外人员 + # for person in model_ownership: + # if person.lower() not in correct_lower_to_original: + # print(f" - 模型答案中多余: {person}") + + def extract_answer(self, text): + """从文本中提取答案。 + + Args: + text (str): 输入文本 + + Returns: + str: 提取的答案,格式为 "(('人1','物品1'),('人2','物品2'),...)" + """ + if not text: + return "" + + # 尝试从 Python markdown 代码块中提取 + code_block_pattern = r'```python\s*\n(.*?)\n```' + code_blocks = re.findall(code_block_pattern, text, re.DOTALL) + if code_blocks: + # 使用最后一个代码块 + last_block = code_blocks[-1].strip() + if last_block.startswith("(") and last_block.endswith(")"): + return last_block + return "" \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/math_path_verifier.py b/src/reasoning360/utils/reward_score/synlogic/math_path_verifier.py new file mode 100644 index 000000000..d0df1c4a8 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/math_path_verifier.py @@ -0,0 +1,100 @@ +import re +import json +import numpy as np +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + + +class MathPathVerifier(Verifier): + """ + 验证器用于检查math_path填充游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的运算表达式 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution=test_answer) + except Exception as e: + print(f"NOTE!!! parse error!!!! (MathPath): {e}") + return False + + try: + # 解析元数据 + metadata = data.metadata + ref_expr = metadata["ref_expr"] + query_expr = metadata["query_expr"] + + # 验证数字是否被篡改,数字是否在0-9之间。 + test_tmp = test_answer.replace(' ', '').strip() + query_tmp = query_expr.replace(' ', '').strip() + ref_tmp = ref_expr.replace(' ', '').strip() + query_nums = [x for x in query_tmp if '0'<=x<='9' or x=='?'] + test_nums = [x for x in test_tmp if '0'<=x<='9'] + if len(query_nums)!=len(test_nums): + # print(f"所填数字数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + else: + for ind, x in enumerate(query_nums): + if x=='?': + continue + if x!=test_nums[ind]: + # print(f"表达式数字被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + + query_symbols = [x for x in query_tmp if x in ['+', '-', '*', '/', '%']] + test_symbols = [x for x in test_tmp if x in ['+', '-', '*', '/', '%']] + if len(query_symbols)!=len(test_symbols): + # print(f"表达式运算符号数量不匹配!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + else: + for ind, x in enumerate(query_symbols): + if x!=test_symbols[ind]: + # print(f"表达式运算符号被篡改!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + + # 验证回答中的等式是否成立 + try: + tmp = test_tmp.replace('=', '==') + if not eval(tmp): + # print(f"等式不成立!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + except: + # print(f"运算表达式错误!原始:{ref_tmp},query:{query_tmp},模型:{test_tmp}") + return False + + + # 所有检查都通过 + # print("验证结果: 正确") + return True + + except Exception as e: + print(f"Verification error (MathPath): {e}") + return False + + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取答案(字符表达式) + + @param test_solution: 模型的完整回答 + @return: 提取的矩阵答案字符串 + """ + if not test_solution: + return "" + # 尝试提取Python代码块中的矩阵 + code_block_pattern = r'\[\[(.*?)\]\]' + code_matches = re.findall(code_block_pattern, test_solution) + + if code_matches: + # 使用最后一个匹配内容 + operation_expression = code_matches[-1].strip() + return operation_expression + + # 如果所有方法都失败,返回空字符串 + return "" + \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/minesweeper_verifier.py b/src/reasoning360/utils/reward_score/synlogic/minesweeper_verifier.py new file mode 100644 index 000000000..1b2791403 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/minesweeper_verifier.py @@ -0,0 +1,61 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +from typing import List, Tuple + + +class MinesweeperVerifier(Verifier): + """ + Verifier for Minesweeper puzzle + 扫雷游戏验证器 + """ + def verify(self, data: Data, test_solution: str, **kwargs): + try: + # 从解答中提取地雷坐标 + predicted_mines = self.extract_answer(test_solution) + + # 从metadata中获取确定性地雷坐标 + expected_mines = data.metadata["current_mines"] + + # 验证提取的坐标是否正确 + if set(tuple(mine) for mine in predicted_mines) == set(tuple(mine) for mine in expected_mines): + return True + + return False + + except Exception as e: + # 如果验证过程中发生任何错误,返回False + print(f"Verification error (Minesweeper): {e}") + return False + + def extract_answer(self, response: str) -> List[Tuple[int, int]]: + """从模型的响应中提取地雷坐标 + Extract mine coordinates from the model's response""" + patterns = [ + r'\[\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)(?:\s*,\s*\(\s*\d+\s*,\s*\d+\s*\))*\s*\]', # [(0,1),(2,3)] + r'\[\s*\[\s*(\d+)\s*,\s*(\d+)\s*\](?:\s*,\s*\[\s*\d+\s*,\s*\d+\s*\])*\s*\]', # [[0,1],[2,3]] + r'\(\s*(\d+)\s*,\s*(\d+)\s*\)(?:\s*,\s*\(\s*\d+\s*,\s*\d+\s*\))*', # (0,1),(2,3) + ] + + for pattern in patterns: + coords = [] + for match in re.finditer(pattern, response): + try: + # 提取所有坐标对 + coord_pattern = r'(?:\(|\[)\s*(\d+)\s*,\s*(\d+)\s*(?:\)|\])' + for coord_match in re.finditer(coord_pattern, match.group(0)): + i, j = int(coord_match.group(1)), int(coord_match.group(2)) + coords.append((i, j)) + except Exception: + continue + + if coords: + return coords + + # 如果没有找到坐标,尝试查找可能是坐标的任何数字 + number_pairs = re.findall(r'(\d+)[^\d]+(\d+)', response) + if number_pairs: + return [(int(i), int(j)) for i, j in number_pairs] + + return [] \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/norinori_verifier.py b/src/reasoning360/utils/reward_score/synlogic/norinori_verifier.py new file mode 100644 index 000000000..95cc98ed1 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/norinori_verifier.py @@ -0,0 +1,188 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +from collections import defaultdict + +class NorinoriVerifier(Verifier): + """ + Norinori 游戏的验证器 + 检查提交的答案是否符合 Norinori 游戏规则 + """ + + def __init__(self): + super().__init__() + + def verify(self, data: Data, test_solution: str): + """ + 验证 Norinori 游戏的答案 + + 参数: + data -- 游戏数据,包含区域网格等信息 + test_solution -- 用户提交的答案,应为多米诺坐标列表 + + 返回: + bool -- 答案是否正确 + """ + try: + # 从游戏数据中获取区域网格 + region_grid = data.metadata["region_grid"] + n = len(region_grid) + + # 解析答案 + dominoes = self._parse_answer(test_solution) + if dominoes is None: + return False + + # 检查多米诺形状 + if not self._check_domino_shapes(dominoes): + return False + + # 创建覆盖网格 + covered = [[False for _ in range(n)] for _ in range(n)] + for domino in dominoes: + for i, j in domino: + # 转换为0-indexed + i -= 1 + j -= 1 + if i < 0 or i >= n or j < 0 or j >= n: + return False # 坐标超出范围 + if covered[i][j]: + return False # 格子被多次覆盖 + covered[i][j] = True + + # 检查多米诺之间是否相邻 + if not self._check_domino_adjacency(dominoes, n): + return False + + # 检查每个区域是否恰好有两个格子被覆盖 + region_coverage = defaultdict(int) + for i in range(n): + for j in range(n): + if covered[i][j] and region_grid[i][j] != "X": + region_coverage[region_grid[i][j]] += 1 + + for region, count in region_coverage.items(): + if count != 2: + return False + + # 检查所有阴影格子是否被覆盖 + for i in range(n): + for j in range(n): + if region_grid[i][j] == "X" and not covered[i][j]: + return False + + return True + except Exception as e: + print(f"Verification error (Norinori): {e}") + return False + + def _parse_answer(self, test_solution: str): + """ + 解析答案字符串,提取多米诺坐标 + + 参数: + test_solution -- 答案字符串 + + 返回: + list -- 多米诺坐标列表,如果格式不正确则返回None + """ + try: + # 使用正则表达式提取坐标对 + pattern = r'\[\((\d+),\s*(\d+)\),\s*\((\d+),\s*(\d+)\)\]' + matches = re.findall(pattern, test_solution) + + if not matches: + # 尝试另一种可能的格式 + pattern = r'\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)' + matches = re.findall(pattern, test_solution) + + dominoes = [] + for match in matches: + i1, j1, i2, j2 = map(int, match) + dominoes.append([(i1, j1), (i2, j2)]) + + return dominoes + except Exception as e: + print(f"NOTE!!! parse error!!!! (Norinori): {e}") + return None + + def _check_domino_shapes(self, dominoes): + """ + 检查所有多米诺是否都是1×2或2×1的形状 + + 参数: + dominoes -- 多米诺坐标列表 + + 返回: + bool -- 是否所有多米诺都符合形状要求 + """ + for domino in dominoes: + if len(domino) != 2: + return False + + (i1, j1), (i2, j2) = domino + + # 检查是否为1×2或2×1 + if not ((i1 == i2 and abs(j1 - j2) == 1) or + (j1 == j2 and abs(i1 - i2) == 1)): + return False + + return True + + def _check_domino_adjacency(self, dominoes, n): + """ + 检查多米诺之间是否相邻 + + 参数: + dominoes -- 多米诺坐标列表 + n -- 网格大小 + + 返回: + bool -- 是否所有多米诺都不相邻 + """ + # 创建一个网格来标记每个多米诺的位置 + grid = [[-1 for _ in range(n+2)] for _ in range(n+2)] # 加2是为了处理边界 + + for idx, domino in enumerate(dominoes): + for i, j in domino: + # 转换为0-indexed并考虑边界 + grid[i][j] = idx + + # 检查每个多米诺是否与其他多米诺相邻 + for idx, domino in enumerate(dominoes): + for i, j in domino: + for di, dj in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + ni, nj = i + di, j + dj + if 1 <= ni <= n and 1 <= nj <= n: # 检查是否在网格内 + if grid[ni][nj] != -1 and grid[ni][nj] != idx: + return False # 发现相邻的多米诺 + + return True + + def extract_answer(self, test_solution: str, strict=False): + """ + 从回答中提取答案 + + 参数: + test_solution -- 用户的回答 + strict -- 是否严格模式 + + 返回: + str -- 提取的答案 + """ + # 尝试找到答案部分 + answer_patterns = [ + r'\[\s*\[\s*\(\s*\d+\s*,\s*\d+\s*\)\s*,\s*\(\s*\d+\s*,\s*\d+\s*\)\s*\]', # 寻找格式如 [[(1,2), (1,3)], ...] 的答案 + r'答案是\s*(.*?)\s*$', # 中文格式 + r'answer is\s*(.*?)\s*$', # 英文格式 + r'solution is\s*(.*?)\s*$' # 另一种英文格式 + ] + + for pattern in answer_patterns: + matches = re.findall(pattern, test_solution, re.IGNORECASE | re.DOTALL) + if matches: + # 返回最后一个匹配项,通常是最终答案 + return matches[-1] + + # 如果没有找到明确的答案格式,返回整个解答 + return test_solution \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/number_wall_verifier.py b/src/reasoning360/utils/reward_score/synlogic/number_wall_verifier.py new file mode 100644 index 000000000..541c0720a --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/number_wall_verifier.py @@ -0,0 +1,226 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +from collections import deque + +class NumberWallVerifier(Verifier): + """ + Verifier for Number Wall puzzle + 数字墙拼图验证器 + """ + def verify(self, data: Data, test_solution: str, **kwargs): + try: + # 提取答案网格 + solution_grid = self.extract_answer(test_solution) + if not solution_grid: + # print("Failed to extract solution grid") + return False + + # 提取元数据 + original_grid = data.metadata["grid"] + n = data.metadata["n"] + + # 检查网格尺寸 + if len(solution_grid) != n: + # print(f"Solution grid has incorrect number of rows: {len(solution_grid)} != {n}") + return False + + for row in solution_grid: + if len(row) != n: + # print(f"Solution grid has incorrect number of columns: {len(row)} != {n}") + return False + + # 检查每个单元格只包含数字、"X"或"A" + for cell in row: + if not (isinstance(cell, int) or cell in ["X", "A"]): + # print(f"Invalid cell content: {cell}") + return False + + # 检查原始数字是否保留 + if not self._check_original_numbers(original_grid, solution_grid): + # print("Original numbers not preserved") + return False + + # 检查墙壁布局是否有效(没有2×2或更大的连续墙块) + if not self._check_wall_layout(solution_grid): + # print("Invalid wall layout (2x2 or larger continuous wall blocks found)") + return False + + # 检查岛屿划分是否有效 + if not self._check_islands(solution_grid): + # print("Invalid island division") + return False + + # 检查是否有斜线边 + if not self._check_diagonal_borders(solution_grid): + # print("Invalid solution: islands have diagonal borders") + return False + + return True + + except Exception as e: + # 如果验证过程中发生任何错误,返回False + print(f"Verification error (NumberWall): {e}") + return False + + def _check_original_numbers(self, original_grid, solution_grid): + """检查原始数字是否在解决方案中保留""" + for i in range(len(original_grid)): + for j in range(len(original_grid[i])): + if isinstance(original_grid[i][j], int): + if original_grid[i][j] != solution_grid[i][j]: + # print(f"Original number at ({i},{j}) changed: {original_grid[i][j]} -> {solution_grid[i][j]}") + return False + return True + + def _check_wall_layout(self, grid): + """检查墙壁布局是否有效(没有2×2或更大的连续墙块)""" + n = len(grid) + for i in range(n - 1): + for j in range(n - 1): + if (grid[i][j] == "A" and grid[i][j+1] == "A" and + grid[i+1][j] == "A" and grid[i+1][j+1] == "A"): + # print(f"Found 2x2 wall block at ({i},{j})") + return False + return True + + def _check_islands(self, grid): + """检查岛屿划分是否有效""" + n = len(grid) + visited = set() + + for i in range(n): + for j in range(n): + if (i, j) not in visited and grid[i][j] != "A": + # 发现一个新岛屿 + island_cells = [] + island_number = None + queue = deque([(i, j)]) + visited.add((i, j)) + + while queue: + r, c = queue.popleft() + island_cells.append((r, c)) + + if isinstance(grid[r][c], int): + if island_number is not None: + # 岛屿有多个数字 + # print(f"Island contains multiple numbers: {island_number} and {grid[r][c]}") + return False + island_number = grid[r][c] + + for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + nr, nc = r + dr, c + dc + if (0 <= nr < n and 0 <= nc < n and + (nr, nc) not in visited and + grid[nr][nc] != "A"): + queue.append((nr, nc)) + visited.add((nr, nc)) + + if island_number is None: + # 岛屿没有数字 + # print(f"Island at ({i},{j}) has no number") + return False + + if len(island_cells) != island_number: + # 岛屿大小与数字不匹配 + # print(f"Island size ({len(island_cells)}) doesn't match number ({island_number})") + return False + + return True + + def _check_diagonal_borders(self, grid): + """检查是否有斜线边(对角相邻的不同岛屿)""" + n = len(grid) + + # 标记所有岛屿 + island_map = {} # 映射格子坐标到岛屿ID + island_id = 0 + visited = set() + + for i in range(n): + for j in range(n): + if grid[i][j] != "A" and (i, j) not in visited: + # 发现一个新岛屿 + queue = deque([(i, j)]) + visited.add((i, j)) + + while queue: + r, c = queue.popleft() + island_map[(r, c)] = island_id + + for dr, dc in [(0, 1), (1, 0), (0, -1), (-1, 0)]: + nr, nc = r + dr, c + dc + if (0 <= nr < n and 0 <= nc < n and + grid[nr][nc] != "A" and (nr, nc) not in visited): + queue.append((nr, nc)) + visited.add((nr, nc)) + + island_id += 1 + + # 检查斜线边 + for i in range(n - 1): + for j in range(n - 1): + # 检查2x2方格中的对角格子 + if (grid[i][j] != "A" and grid[i+1][j+1] != "A" and + grid[i][j+1] == "A" and grid[i+1][j] == "A"): + # 对角格子属于不同岛屿,存在斜线边 + if island_map.get((i, j)) != island_map.get((i+1, j+1)): + # print(f"Found diagonal border at ({i},{j}) and ({i+1},{j+1})") + return False + + # 检查另一对对角格子 + if (grid[i][j+1] != "A" and grid[i+1][j] != "A" and + grid[i][j] == "A" and grid[i+1][j+1] == "A"): + # 对角格子属于不同岛屿,存在斜线边 + if island_map.get((i, j+1)) != island_map.get((i+1, j)): + # print(f"Found diagonal border at ({i},{j+1}) and ({i+1},{j})") + return False + + return True + + + def extract_answer(self, response: str): + """从模型的响应中提取答案网格""" + # 在响应中寻找网格表示 + # 修改正则表达式以匹配字符串形式的数字 + grid_pattern = r'\[\s*\[(?:\s*(?:"[XA]"|\'[XA]\'|[0-9]+|"[0-9]+"|\'[0-9]+\')\s*,\s*)*\s*(?:"[XA]"|\'[XA]\'|[0-9]+|"[0-9]+"|\'[0-9]+\')\s*\]\s*(?:,\s*\[(?:\s*(?:"[XA]"|\'[XA]\'|[0-9]+|"[0-9]+"|\'[0-9]+\')\s*,\s*)*\s*(?:"[XA]"|\'[XA]\'|[0-9]+|"[0-9]+"|\'[0-9]+\')\s*\]\s*)*\]' + matches = re.findall(grid_pattern, response) + + if matches: + # 尝试解析最后一个匹配项 + grid_str = matches[-1] + + try: + # 尝试清理字符串,替换可能导致问题的字符 + cleaned_grid_str = grid_str.replace('\n', '').replace('\r', '').strip() + grid = json.loads(cleaned_grid_str) + + # 将字符串数字转换为整数 + for i in range(len(grid)): + for j in range(len(grid[i])): + if isinstance(grid[i][j], str) and grid[i][j].isdigit(): + grid[i][j] = int(grid[i][j]) + + return grid + except json.JSONDecodeError as e: + # 尝试使用 ast.literal_eval 作为备选方案 + try: + import ast + grid = ast.literal_eval(cleaned_grid_str) + + # 将字符串数字转换为整数 + for i in range(len(grid)): + for j in range(len(grid[i])): + if isinstance(grid[i][j], str) and grid[i][j].isdigit(): + grid[i][j] = int(grid[i][j]) + + return grid + except Exception as e2: + print(f"NOTE!!! parse error!!!! (NumberWall): {e2}") + else: + # print("No grid pattern found in the response") + pass + + return None \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/numbrix_verifier.py b/src/reasoning360/utils/reward_score/synlogic/numbrix_verifier.py new file mode 100644 index 000000000..f142bc7b7 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/numbrix_verifier.py @@ -0,0 +1,103 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import ast +import numpy as np + +class NumbrixVerifier(Verifier): + """ + Numbrix 游戏的验证器 + 验证提交的解答是否符合 Numbrix 游戏规则 + """ + def verify(self, data: Data, test_solution: str): + try: + # 提取答案网格 + test_grid = self.extract_answer(test_solution) + if not test_grid: + return False + + # 获取原始谜题和网格大小 + original_grid = data.metadata["grid"] + n = len(original_grid) + n_squared = n * n + + # 检查网格大小是否正确 + if len(test_grid) != n or any(len(row) != n for row in test_grid): + return False + + # 检查是否包含所有数字 1 到 n² + flattened_grid = [cell for row in test_grid for cell in row] + if sorted(flattened_grid) != list(range(1, n_squared + 1)): + return False + + # 检查是否保留了原始提示数字 + for i in range(n): + for j in range(n): + if original_grid[i][j] != "X" and test_grid[i][j] != original_grid[i][j]: + return False + + # 检查连续数字是否正交相邻 + for num in range(1, n_squared): + # 找到当前数字的位置 + current_pos = None + next_pos = None + for i in range(n): + for j in range(n): + if test_grid[i][j] == num: + current_pos = (i, j) + elif test_grid[i][j] == num + 1: + next_pos = (i, j) + + if current_pos is None or next_pos is None: + return False + + # 检查是否正交相邻(曼哈顿距离为1) + i1, j1 = current_pos + i2, j2 = next_pos + manhattan_distance = abs(i1 - i2) + abs(j1 - j2) + if manhattan_distance != 1: + return False + + return True + except Exception as e: + print(f"Verification error (Numbrix): {e}") + return False + + def extract_answer(self, test_solution: str, strict=False): + """从模型回答中提取网格""" + try: + import ast + import re + # 尝试找到 Python 列表格式的答案 + # 寻找形如 [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 的模式 + pattern = r'\[\s*\[\s*\d+.*?\]\s*\]' + matches = re.finditer(pattern, test_solution, re.DOTALL) + match = None + + # 获取最后一个匹配项 + for m in matches: + match = m + if not match: + return None + + # 提取匹配的文本并尝试解析为 Python 对象 + grid_text = match.group(0) + + # 清理文本,确保它是有效的 Python 列表 + # 移除可能导致解析错误的字符 + grid_text = grid_text.replace("'", "").replace('"', "") + + # 解析为 Python 对象 + grid = ast.literal_eval(grid_text) + + # 确保是二维列表且所有元素都是整数 + if not isinstance(grid, list) or not all(isinstance(row, list) for row in grid): + return None + + if not all(isinstance(cell, int) for row in grid for cell in row): + return None + + return grid + except Exception as e: + print(f"NOTE!!! parse error!!!! (Numbrix): {e}") + return None \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/object_counting_verifier.py b/src/reasoning360/utils/reward_score/synlogic/object_counting_verifier.py new file mode 100644 index 000000000..e63e4fd97 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/object_counting_verifier.py @@ -0,0 +1,45 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + + +class ObjectCountingVerifier(Verifier): + """ + 验证器用于物品计数游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = int(data.answer) + parsed_answer = self.extract_answer(test_answer) + with open("solution_str_OC.txt", "a") as f: + f.write("data.answer: " + data.answer + '\n') + f.write("test_answer: " + test_answer + '\n') + f.write("parsed_answer" + parsed_answer + '\n') + f.write('-'*32 + '\n') + + if parsed_answer is None: + return False + return int(parsed_answer) == ground_truth + + except Exception as e: + print(f"NOTE!!! parse error!!!! (ObjectCounting): {e}") + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return answer_str + \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/object_properties_verifier.py b/src/reasoning360/utils/reward_score/synlogic/object_properties_verifier.py new file mode 100644 index 000000000..f8ad06ccd --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/object_properties_verifier.py @@ -0,0 +1,40 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + + +class ObjectPropertiesVerifier(Verifier): + """ + 验证器用于物品拥有游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = int(data.answer) + parsed_answer = int(self.extract_answer(test_answer)) + + if parsed_answer is None: + return False + return int(parsed_answer) == ground_truth + + except Exception as e: + print(f"NOTE!!! parse error!!!! (ObjectProperties): {e}") + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\Box{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return None + \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/operation_verifier.py b/src/reasoning360/utils/reward_score/synlogic/operation_verifier.py new file mode 100644 index 000000000..7f730ce8a --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/operation_verifier.py @@ -0,0 +1,47 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + + +class OperationVerifier(Verifier): + """ + 验证器用于物品计数游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = math_verify.parse(data.answer) + parsed_answer = math_verify.parse(test_answer) + + if parsed_answer is None: + return False + return math_verify.verify(parsed_answer, ground_truth) + except Exception as e: + print(f"NOTE!!! parse error!!!! (OperationVerifier): {e}") + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py b/src/reasoning360/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py new file mode 100644 index 000000000..524f5859b --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py @@ -0,0 +1,169 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +import ast + + +class SkyscraperPuzzleVerifier(Verifier): + """ + 摩天楼游戏验证器,用于验证模型提供的解答是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否符合摩天楼游戏的规则 + + @param data: 包含游戏信息的Data对象 + @param test_answer: 游戏类提取的网格数据 + @return: 回答是否正确的布尔值 + """ + try: + # 获取游戏元数据 + metadata = data.metadata + n = metadata['n'] + top = metadata['top'] + bottom = metadata['bottom'] + left = metadata['left'] + right = metadata['right'] + + self.n = n + test_answer = self.extract_answer(test_solution) + + # print(f"验证: 游戏规模 {n}×{n}") + # print(f"上方提示: {top}") + # print(f"下方提示: {bottom}") + # print(f"左侧提示: {left}") + # print(f"右侧提示: {right}") + + # 使用提取好的网格数据 + grid = test_answer + + # 检查网格是否是字符串,如果是,说明提取失败 + if isinstance(grid, str): + # print("无法提取有效网格") + return False + + print("提取的网格:") + for row in grid: + print(row) + + # 检查网格规模 + if len(grid) != n or any(len(row) != n for row in grid): + # print(f"网格规模不正确,应为 {n}×{n}") + return False + + # 检查数字范围 (1 到 n) + for i in range(n): + for j in range(n): + if not isinstance(grid[i][j], int) or grid[i][j] < 1 or grid[i][j] > n: + # print(f"位置 ({i+1},{j+1}) 的值 {grid[i][j]} 不在有效范围内 (1-{n})") + return False + + # 检查每行唯一性 + for i in range(n): + if len(set(grid[i])) != n: + # print(f"第 {i+1} 行包含重复数字") + return False + + # 检查每列唯一性 + for j in range(n): + column = [grid[i][j] for i in range(n)] + if len(set(column)) != n: + # print(f"第 {j+1} 列包含重复数字") + return False + + # 检查从上方观察 + for j in range(n): + visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n)]) + if visible_count != top[j]: + # print(f"从上方看第 {j+1} 列可见楼数为 {visible_count},应为 {top[j]}") + return False + + # 检查从下方观察 + for j in range(n): + visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n-1, -1, -1)]) + if visible_count != bottom[j]: + # print(f"从下方看第 {j+1} 列可见楼数为 {visible_count},应为 {bottom[j]}") + return False + + # 检查从左侧观察 + for i in range(n): + visible_count = self._count_visible_skyscrapers(grid[i]) + if visible_count != left[i]: + # print(f"从左侧看第 {i+1} 行可见楼数为 {visible_count},应为 {left[i]}") + return False + + # 检查从右侧观察 + for i in range(n): + visible_count = self._count_visible_skyscrapers(grid[i][::-1]) + if visible_count != right[i]: + # print(f"从右侧看第 {i+1} 行可见楼数为 {visible_count},应为 {right[i]}") + return False + + # 所有检查通过 + # print("所有验证规则通过!") + return True + + except Exception as e: + print(f"Verification error (SkyscraperPuzzle): {e}") + return False + + def _count_visible_skyscrapers(self, heights): + """ + 计算从一个方向看过去能看到的摩天楼数量 + + @param heights: 从观察方向依次排列的摩天楼高度列表 + @return: 可见的摩天楼数量 + """ + visible_count = 0 + max_height = 0 + + for height in heights: + if height > max_height: + visible_count += 1 + max_height = height + + return visible_count + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取网格数据 + + @param test_solution: 模型的完整回答 + @return: 提取的解答网格数据 + """ + try: + n = self.n + + # 从 ```python 代码块中提取 + code_block_pattern = r"```python\s*\n([\s\S]*?)\n\s*```" + code_blocks = re.findall(code_block_pattern, test_solution) + + if code_blocks: + # 取第一个代码块(通常只有一个) + code_block = code_blocks[0].strip() + try: + # 直接解析代码块 + grid = ast.literal_eval(code_block) + # 验证是否为有效的n×n网格 + if (isinstance(grid, list) and + len(grid) == n and + all(isinstance(row, list) and len(row) == n for row in grid)): + return grid + except Exception: + # 如果直接解析失败,尝试移除注释后再解析 + code_without_comments = re.sub(r'#.*$', '', code_block, flags=re.MULTILINE) + try: + grid = ast.literal_eval(code_without_comments.strip()) + if (isinstance(grid, list) and + len(grid) == n and + all(isinstance(row, list) and len(row) == n for row in grid)): + return grid + except Exception: + pass + + # 如果提取失败,返回原始答案 + return test_solution + except Exception as e: + print(f"NOTE!!! parse error!!!! (SkyscraperPuzzle): {e}") + return test_solution \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/space_reasoning_tree_verifier.py b/src/reasoning360/utils/reward_score/synlogic/space_reasoning_tree_verifier.py new file mode 100644 index 000000000..abc165d5c --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/space_reasoning_tree_verifier.py @@ -0,0 +1,44 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + +class SpaceReasoningTreeVerifier(Verifier): + """ + 验证器用于空间推理树游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + test_answer = self.extract_answer(test_answer) + if test_answer is None: + return False + test_answer = test_answer.replace(",", ",").replace(" ", "") + ground_truth = data.answer.replace(",", ",").replace(" ", "") + test_set = set(test_answer.split(",")) + ground_truth_set = set(ground_truth.split(",")) + return test_set == ground_truth_set + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/space_reasoning_verifier.py b/src/reasoning360/utils/reward_score/synlogic/space_reasoning_verifier.py new file mode 100644 index 000000000..249f2dc08 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/space_reasoning_verifier.py @@ -0,0 +1,41 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + + +class SpaceReasoningVerifier(Verifier): + """ + 验证器用于空间推理游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + test_answer = self.extract_answer(test_answer) + if test_answer is None: + return False + return test_answer.lower() == data.answer.lower() + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/star_placement_puzzle_verifier.py b/src/reasoning360/utils/reward_score/synlogic/star_placement_puzzle_verifier.py new file mode 100644 index 000000000..98715e19a --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/star_placement_puzzle_verifier.py @@ -0,0 +1,160 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +import ast + +import re + +class StarPlacementPuzzleVerifier(Verifier): + """ + 星星放置游戏验证器,用于验证模型提供的解答是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否符合星星放置游戏的规则 + + @param data: 包含游戏信息的Data对象 + @param star_coords: 通过extract_answer提取的星星坐标字典 {区域: [(行,列), ...]} + @return: 回答是否正确的布尔值 + """ + try: + star_coords = self.extract_answer(test_solution) + # 获取游戏元数据 + metadata = data.metadata + n = metadata['n'] + k = metadata['k'] + region_grid = metadata['region_grid'] + + # print(f"验证: 游戏规模 {n}×{n}, 每行/列/区域星星数量: {k}") + + # 检查是否有有效的星星坐标 + if not star_coords: + # print("无法从回答中提取有效的星星坐标") + return False + + # 创建一个表示星星位置的网格 + star_grid = [[0 for _ in range(n)] for _ in range(n)] + for region, coords in star_coords.items(): + for coord in coords: + row, col = coord + if row < 0 or row >= n or col < 0 or col >= n: + # print(f"无效坐标: ({row},{col}) - 超出网格范围") + return False + star_grid[row][col] = 1 + + # 打印星星网格以便调试 + # print("星星网格:") + for row in star_grid: + print(''.join(['* ' if cell == 1 else '. ' for cell in row])) + + # 1. 检查每行是否有k颗星星 + for i in range(n): + stars_in_row = sum(star_grid[i]) + if stars_in_row != k: + # print(f"行 {i+1} 有 {stars_in_row} 颗星星,应该有 {k} 颗") + return False + + # 2. 检查每列是否有k颗星星 + for j in range(n): + stars_in_col = sum(star_grid[i][j] for i in range(n)) + if stars_in_col != k: + # print(f"列 {j+1} 有 {stars_in_col} 颗星星,应该有 {k} 颗") + return False + + # 3. 检查每个区域是否有k颗星星 + regions = {} + for i in range(n): + for j in range(n): + region = region_grid[i][j] + if region not in regions: + regions[region] = [] + regions[region].append((i, j)) + + for region, cells in regions.items(): + stars_in_region = sum(star_grid[i][j] for i, j in cells) + if stars_in_region != k: + # print(f"区域 {region} 有 {stars_in_region} 颗星星,应该有 {k} 颗") + return False + + # 4. 检查星星是否互不相邻(水平、垂直、对角线) + for i in range(n): + for j in range(n): + if star_grid[i][j] == 1: + # 检查周围8个方向 + for di in [-1, 0, 1]: + for dj in [-1, 0, 1]: + if di == 0 and dj == 0: + continue # 跳过自身 + ni, nj = i + di, j + dj + if 0 <= ni < n and 0 <= nj < n and star_grid[ni][nj] == 1: + # print(f"星星在 ({i},{j}) 与星星在 ({ni},{nj}) 相邻") + return False + + # 所有检查通过 + # print("所有验证规则通过!") + return True + + except Exception as e: + print(f"Verification error (StarPlacementPuzzle): {e}") + return False + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取星星坐标 + + @param test_solution: 模型的完整回答 + @return: 提取的星星坐标字典 {区域: [(行,列), ...]} + """ + try: + # 从Python代码块中提取 + python_match = re.search(r'```python\s*\n(.*?)\n\s*```', test_solution, re.DOTALL) + if not python_match: + # print("回答中没有找到```python代码块") + return None + + code_content = python_match.group(1) + + # 尝试从Python代码中提取字典 + try: + # 先尝试直接提取字典内容 + dict_match = re.search(r'\{[^{}]*\}', code_content, re.DOTALL) + if dict_match: + dict_str = dict_match.group(0) + try: + # 将字符串转换为字典 + coords_dict = ast.literal_eval(dict_str) + # 如果成功且是字典类型,继续处理 + if isinstance(coords_dict, dict): + # 将坐标减1(因为用户输入的坐标是1-索引) + result = {} + for region, coords in coords_dict.items(): + result[region] = [(row-1, col-1) for row, col in coords] + return result + except (ValueError, SyntaxError) as e: + print(f"NOTE!!! parse error!!!! (StarPlacementPuzzle): {e}") + + # 如果上面的方法失败,尝试解析变量赋值 + assign_match = re.search(r'(\w+)\s*=\s*(\{[^{}]*\})', code_content, re.DOTALL) + if assign_match: + dict_str = assign_match.group(2) + try: + # 将字符串转换为字典 + coords_dict = ast.literal_eval(dict_str) + # 如果成功且是字典类型,继续处理 + if isinstance(coords_dict, dict): + # 将坐标减1(因为用户输入的坐标是1-索引) + result = {} + for region, coords in coords_dict.items(): + result[region] = [(row-1, col-1) for row, col in coords] + return result + except (ValueError, SyntaxError) as e: + print(f"NOTE!!! parse error!!!! (StarPlacementPuzzle): {e}") + except Exception as e: + print(f"NOTE!!! parse error!!!! (StarPlacementPuzzle): {e}") + + return None + + except Exception as e: + print(f"NOTE!!! parse error!!!! (StarPlacementPuzzle): {e}") + return None \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/synlogic.py b/src/reasoning360/utils/reward_score/synlogic/synlogic.py new file mode 100644 index 000000000..e5658878c --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/synlogic.py @@ -0,0 +1,93 @@ +import os +import sys +print(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +# from .game_of_24.scripts.game_of_24_verifier import GameOf24Verifier +# from .cryptarithm.scripts.cryptarithm_verifier import CryptarithmVerifier +# from .survo.scripts.survo_verifier import SurvoVerifier +from .campsite_verifier import CampsiteVerifier +from .skyscraper_puzzle_verifier import SkyscraperPuzzleVerifier +from .web_of_lies_verifier import WebOfLiesVerifier +from .goods_exchange_verifier import GoodsExchangeVerifier +# from .sudoku.scripts.sudoku_verifier import SudokuVerifier +# from corpus.misc.tasks.zebra_puzzle.scripts.zebra_puzzle_verifier import ZebraPuzzleVerifier +# from corpus.misc.tasks.bbeh.scripts.bbeh_verifier import BBEHVerifier +# from corpus.misc.tasks.arc_agi.scripts.arc_agi_verifier import ArcAGIVerifier +from .object_properties_verifier import ObjectPropertiesVerifier +from .object_counting_verifier import ObjectCountingVerifier +from .star_placement_puzzle_verifier import StarPlacementPuzzleVerifier +from .arrow_maze_verifier import ArrowMazeVerifier +# from .kukurasu.scripts.kukurasu_verifier import KukurasuVerifier +from .number_wall_verifier import NumberWallVerifier +from .numbrix_verifier import NumbrixVerifier +from .norinori_verifier import NorinoriVerifier +from .minesweeper_verifier import MinesweeperVerifier +from .operation_verifier import OperationVerifier +from .word_sorting_mistake_verifier import WordSortingMistakeVerifier +from .math_path_verifier import MathPathVerifier +from .boolean_expressions_verifier import BooleanExpressionsVerifier +from .space_reasoning_verifier import SpaceReasoningVerifier +from .space_reasoning_tree_verifier import SpaceReasoningTreeVerifier +from .word_sorting_verifier import WordSortingVerifier +# from corpus.misc.tasks.gpqa.scripts.gpqa_verifier import GPQAVerifier +# from .cipher.scripts.cipher_verifier import CipherVerifier +from .time_sequence_verifier import TimeSequenceVerifier +from .wordscapes_verifier import WordscapesVerifier +# from corpus.misc.tasks.bbh.scripts.boolean_expressions_verifier import BBHBooleanExpressionsVerifier +# from corpus.misc.tasks.bbh.scripts.causal_judgement_verifier import BBHCausalJudgementVerifier # yes no +# from corpus.misc.tasks.bbh.scripts.date_understanding_verifier import BBHDateUnderstandingVerifier # multi-choice +# from corpus.misc.tasks.bbh.scripts.dyck_languages_verifier import BBHDyckLanguagesVerifier +# from corpus.misc.tasks.bbh.scripts.formal_fallacies_verifier import BBHFormalFallaciesVerifier +# from corpus.misc.tasks.bbh.scripts.multistep_arithmetic_two_verifier import BBHMultistepArithmeticVerifier # number +# from corpus.misc.tasks.bbh.scripts.sports_understanding_verifier import BBHSportsUnderstandingVerifier +# from corpus.misc.tasks.bbh.scripts.web_of_lies_verifier import BBHWebOfLiesVerifier +# from corpus.misc.tasks.bbh.scripts.word_sorting_verifier import BBHWordSortingVerifier +from .game_of_buggy_tables_verifier import BuggyTableVerifier +# from .calcudoko.scripts.calcudoko_verifier import CalcudokoVerifier +from .dyck_language_verifier import DyckLanguageVerifier +from .dyck_language_errors_verifier import DyckLanguageErrorsVerifier +from .dyck_language_reasoning_errors_verifier import DyckLanguageReasoningErrorsVerifier +# from .futoshiki.scripts.futoshiki_verifier import FutoshikiVerifier + +# NOTE: Add new tasks in alphabetical order +verifier_classes = { + "arrow_maze": ArrowMazeVerifier, + "boolean_expressions": BooleanExpressionsVerifier, + "buggy_tables": BuggyTableVerifier, + # "calcudoko": CalcudokoVerifier, + "campsite": CampsiteVerifier, + # "cipher": CipherVerifier, + # "cryptarithm": CryptarithmVerifier, + "dyck_language": DyckLanguageVerifier, + "dyck_language_errors": DyckLanguageErrorsVerifier, + "dyck_language_reasoning_errors": DyckLanguageReasoningErrorsVerifier, + # "futoshiki": FutoshikiVerifier, + "goods_exchange": GoodsExchangeVerifier, + # "gpqa_diamond": GPQAVerifier, + # "kukurasu": KukurasuVerifier, + "math_path": MathPathVerifier, + # "arc_agi": ArcAGIVerifier, + # "arc_agi_2": ArcAGIVerifier, + # "mathador": GameOf24Verifier, + "minesweeper": MinesweeperVerifier, + "norinori": NorinoriVerifier, + "number_wall": NumberWallVerifier, + "numbrix": NumbrixVerifier, + "object_counting": ObjectCountingVerifier, + "object_properties": ObjectPropertiesVerifier, + "operation": OperationVerifier, + "skyscraper_puzzle": SkyscraperPuzzleVerifier, + "space_reasoning": SpaceReasoningVerifier, + "space_reasoning_tree": SpaceReasoningTreeVerifier, + "star_placement_puzzle": StarPlacementPuzzleVerifier, + # "sudoku": SudokuVerifier, + # "survo": SurvoVerifier, + "time_sequence": TimeSequenceVerifier, + "web_of_lies": WebOfLiesVerifier, + "word_sorting": WordSortingVerifier, + "word_sorting_mistake": WordSortingMistakeVerifier, + "wordscapes": WordscapesVerifier, + # "zebra_puzzle": ZebraPuzzleVerifier, + # ** bbeh_classes, + # ** bbh_classes, +} \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/time_sequence_verifier.py b/src/reasoning360/utils/reward_score/synlogic/time_sequence_verifier.py new file mode 100644 index 000000000..711a43d86 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/time_sequence_verifier.py @@ -0,0 +1,69 @@ +import json +import numpy as np +from .data import Data +from .verifier import Verifier +import re + +class TimeSequenceVerifier(Verifier): + """ + 验证器用于验证 time sequence 的答案是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的答案,格式为数字列表 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution) + # 解析元数据 + metadata = data.metadata + true_answers = metadata['records']['answers'] + + # 解析模型给出的列表 + try: + test_list = json.loads(test_answer.replace(",", ",")) + except: + print(f"NOTE!!! parse error!!!! (TimeSequence 1): {e}") + return False + + try: + if test_list[0]!=true_answers['answer_maxLen']: + # print(f"最长会议时间不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + return False + if test_list[1]!=true_answers['answer_nums']: + # print(f"可选会议数量不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + return False + except: + print(f"NOTE!!! parse error!!!! (TimeSequence 2): {e}") + return False + + # 所有检查都通过 + # print("验证结果: 正确") + return True + except Exception as e: + print(f"Verification error (TimeSequence): {e}") + return False + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取答案(矩阵) + + @param test_solution: 模型的完整回答 + @return: 提取答案列表 + """ + if not test_solution: + return "" + + # 尝试提取列表 + matrix_pattern = r'\[.*?\]' + matrix_matches = re.findall(matrix_pattern, test_solution, re.DOTALL) + if matrix_matches: + # 使用最后一个匹配的列表 + # print(matrix_matches) + return matrix_matches[-1].strip() + + # 如果失败,返回空字符串 + return "" \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/verifier.py b/src/reasoning360/utils/reward_score/synlogic/verifier.py new file mode 100644 index 000000000..498e87a82 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/verifier.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from .data import Data + +class Verifier(ABC): + """ + Base class for verifier + """ + def __init__(self): + pass + + @abstractmethod + def verify(self, data: Data, test_answer: str): + """ + Verify whether the test answer is consistent with the gold answer + @param data: Data + @param test_answer: str + @return: bool + """ + raise NotImplementedError("Verifier.verify() is not implemented") + + @abstractmethod + def extract_answer(self, test_solution: str): + """ + Extract the answer from the test solution + @param test_solution: str + @return: str + """ + raise NotImplementedError("Verifier.extract_answer() is not implemented") + +import re + +THOUGHT_DELIMITER_START = "" +THOUGHT_DELIMITER_END = "" + +def _extract_answer(text): + # 定义正则表达式模式,匹配 之间的内容 + pattern = r'(.*?)' + + # 使用 re.search 查找第一个匹配项 + match = re.search(pattern, text, re.DOTALL) + + # 如果找到匹配项,返回匹配的内容 + if match: + return match.group(1).strip() + else: + return None + +def _extract_solution_with_thought(solution_str): + + model_output = solution_str + + if THOUGHT_DELIMITER_END in solution_str: + model_output = solution_str.split(THOUGHT_DELIMITER_END)[1] + + predict_answer = _extract_answer(model_output) + + + if predict_answer is not None: + return predict_answer + else: + return "" + + +class ExactMatchVerifier(Verifier): + """ + Verifier for Exact Match + """ + def verify(self, data: Data, test_solution: str): + try: + test_answer = self.extract_answer(test_solution) + ground_truth = data.answer + correct = test_answer == ground_truth + if correct: + acc_score = 1.0 + else: + acc_score = 0 + + return acc_score + except: + return False + + def extract_answer(self, test_solution: str): + return _extract_solution_with_thought(solution_str=test_solution) \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/web_of_lies_verifier.py b/src/reasoning360/utils/reward_score/synlogic/web_of_lies_verifier.py new file mode 100644 index 000000000..94be44fed --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/web_of_lies_verifier.py @@ -0,0 +1,135 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + +class WebOfLiesVerifier(Verifier): + """ + 验证器用于检查谎言之网游戏的答案是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution) + # 获取预期答案和测试答案 + expected_answer = data.answer.lower() + + # 清理测试答案 + test_answer = test_answer.lower() + + # 提取预期答案中的真假值 + expected_truths = self._parse_answer(expected_answer) + + # 提取测试答案中的真假值 + test_truths = self._parse_answer(test_answer) + + # print(f"验证: 预期答案={expected_truths}, 模型答案={test_truths}") + + # 检查答案列表长度是否匹配 + if len(expected_truths) != len(test_truths): + # print(f"验证失败: 答案长度不匹配,预期 {len(expected_truths)},实际 {len(test_truths)}") + return False + + # 检查每个位置的答案是否匹配 + for i, (expected, actual) in enumerate(zip(expected_truths, test_truths)): + if expected != actual: + # print(f"验证失败: 第 {i+1} 个答案不匹配,预期 {expected},实际 {actual}") + return False + + # print("验证成功: 所有答案匹配") + return True + + except Exception as e: + print(f"Verification error (WebOfLies): {e}") + return False + + def _parse_answer(self, answer_str): + """ + 从答案字符串中解析出真假值列表 + + @param answer_str: 答案字符串 + @return: 真假值列表,True表示说真话,False表示说谎话 + """ + # 尝试匹配英文答案格式 (yes/no) + yes_pattern = r'yes|true|truth' + no_pattern = r'no|false|lie' + + # 尝试匹配中文答案格式 (是/否) + cn_yes_pattern = r'是|真话|真' + cn_no_pattern = r'否|假话|假|谎' + + # 组合模式 + yes_patterns = f'({yes_pattern}|{cn_yes_pattern})' + no_patterns = f'({no_pattern}|{cn_no_pattern})' + + # 根据答案字符串中的关键词确定真假值 + truths = [] + + # 寻找所有可能的yes/no或是/否答案 + all_answers = re.findall(rf'{yes_patterns}|{no_patterns}', answer_str) + + for match in all_answers: + # match是一个元组,需要找到非空的元素 + match_str = next((m for m in match if m), '') + + if re.search(yes_pattern, match_str) or re.search(cn_yes_pattern, match_str): + truths.append(True) + elif re.search(no_pattern, match_str) or re.search(cn_no_pattern, match_str): + truths.append(False) + + return truths + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 + """ + if not test_solution: + return "" + # 中文模式 + cn_patterns = [ + r'答案是[::]\s*\*\*([^*]+)\*\*[.。]*$', # 匹配"答案是:**是,否,是**"格式 + ] + + # 英文模式 + en_patterns = [ + r'[Tt]he answer is[::=]\s*\*\*([^*]+)\*\*[.。]*$', # 匹配"The answer is: **yes, no, yes**"格式 + ] + + # 尝试匹配所有模式 + patterns = cn_patterns + en_patterns + + for pattern in patterns: + matches = re.findall(pattern, test_solution, re.DOTALL) + if matches: + return matches[-1].strip() + + # 如果上面的模式都没匹配到,尝试更宽松的匹配 + # 查找最后一行中的加粗文本 + lines = test_solution.strip().split('\n') + if lines: + last_line = lines[-1].strip() + bold_match = re.search(r'\*\*([^*]+)\*\*', last_line) + if bold_match: + return bold_match.group(1).strip() + + # 尝试匹配"答案是"或"The answer is"后面的文本 + answer_match = re.search(r'(?:答案是|[Tt]he answer is)[::=]?\s*(.*?)(?:[.。]|$)', last_line) + if answer_match: + return answer_match.group(1).strip() + + # 如果没有找到格式化的答案,尝试直接匹配yes/no或是/否序列 + yes_no_pattern = r'(?:\b(?:yes|no|是|否)\b[,,\s]*)+' + matches = re.findall(yes_no_pattern, test_solution.lower()) + if matches: + return matches[-1].strip() + + # 如果没有匹配到任何模式,返回空字符串 + return "" \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/word_sorting_mistake_verifier.py b/src/reasoning360/utils/reward_score/synlogic/word_sorting_mistake_verifier.py new file mode 100644 index 000000000..8216f7605 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/word_sorting_mistake_verifier.py @@ -0,0 +1,45 @@ +import re +from .data import Data +from .verifier import Verifier + +class WordSortingMistakeVerifier(Verifier): + """ + 验证器用于word sorting mistake的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = data.answer if data.answer is not None else "No" + parsed_answer = self.extract_answer(test_answer) + + if parsed_answer is None: + return False + + if parsed_answer.isdigit(): + try: + return int(parsed_answer) == int(ground_truth) + except Exception as e: + return False + else: + return parsed_answer.lower() == ground_truth.lower() + except Exception as e: + print(f"NOTE!!! parse error!!!! (WordSortingMistake): {e}") + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\boxed{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return None + \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/word_sorting_verifier.py b/src/reasoning360/utils/reward_score/synlogic/word_sorting_verifier.py new file mode 100644 index 000000000..567581086 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/word_sorting_verifier.py @@ -0,0 +1,43 @@ +import re +from .data import Data +from .verifier import Verifier + +class WordSortingVerifier(Verifier): + """ + 验证器用于单词排序游戏的答案是否正确 + """ + def str2list(self, answer_str): + # 替换中文逗号为英文逗号,并删除所有空格 + answer_str = answer_str.replace(",", ",").replace(" ", "") + return [w.strip() for w in answer_str.split(",")] + + def verify(self, data: Data, test_answer: str): + try: + ground_truth = self.str2list(data.answer) + parsed_answer = self.str2list(self.extract_answer(test_answer)) + + if parsed_answer is None: + return False + return parsed_answer == ground_truth + + except Exception as e: + print(f"NOTE!!! parse error!!!! (WordSorting): {e}") + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\Box{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return None \ No newline at end of file diff --git a/src/reasoning360/utils/reward_score/synlogic/wordscapes_verifier.py b/src/reasoning360/utils/reward_score/synlogic/wordscapes_verifier.py new file mode 100644 index 000000000..2f30a7842 --- /dev/null +++ b/src/reasoning360/utils/reward_score/synlogic/wordscapes_verifier.py @@ -0,0 +1,159 @@ +""" +Wordscapes verifier module for the reasonreason framework. +""" + +import json +import re +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + +debug_mode = False + +class WordscapesVerifier(Verifier): + """ + Verifier for Wordscapes game + """ + def verify(self, data, test_solution: str): + """ + Verify whether the test answer is consistent with the gold answer + + Args: + data: WordscapesData + test_solution: str containing the solution + + Returns: + float: Score between 0 and 1 + """ + try: + extracted_answer = self.extract_answer(test_solution) + if not extracted_answer: + print("NOTE!!! parse error!!!! (Wordscapes): {e}") + return False + + if debug_mode: + for row in extracted_answer: + print(" ".join(cell if cell != " " else "_" for cell in row)) + + # Get grid, across_words, and down_words from data + grid = data.metadata["grid"] + across_words = data.metadata["across_words"] + down_words = data.metadata["down_words"] + + # Validate grid dimensions + if len(extracted_answer) != len(grid): + # print(f"Grid height mismatch: expected {len(grid)}, got {len(extracted_answer)}") + return False + + for i in range(len(grid)): + if len(extracted_answer[i]) != len(grid[i]): + # print(f"Grid width mismatch at row {i}: expected {len(grid[i])}, got {len(extracted_answer[i])}") + return False + + # Check if the answer respects the grid layout (X for letters, 0 for empty) + for i in range(len(grid)): + for j in range(len(grid[i])): + if grid[i][j] == "0" and extracted_answer[i][j].strip(): + # print(f"Expected empty space at position ({i},{j}), got '{extracted_answer[i][j]}'") + return False + if grid[i][j] == "X" and not extracted_answer[i][j].strip(): + # print(f"Expected letter at position ({i},{j}), got empty space") + return False + + # Verify across words + for word in across_words: + found = False + for i in range(len(extracted_answer)): + row_str = ''.join(extracted_answer[i]).replace(' ', '').lower() + if word.lower() in row_str: + found = True + break + if not found and word: + # print(f"Across word '{word}' not found in the grid") + return 0 + + # Verify down words + for word in down_words: + found = False + for j in range(len(extracted_answer[0])): + col = [] + for i in range(len(extracted_answer)): + if j < len(extracted_answer[i]): + col.append(extracted_answer[i][j]) + col_str = ''.join(col).replace(' ', '').lower() + if word.lower() in col_str: + found = True + break + if not found and word: # Only check if word is not empty + # print(f"Down word '{word}' not found in the grid") + return False + + # All checks passed + return True + except Exception as e: + print(f"Verification error (Wordscapes): {e}") + return False + + def extract_answer(self, test_solution: str): + """ + Extract the answer from the test solution + + Args: + test_solution: str + + Returns: + list: 2D grid of the answer or None if extraction fails + """ + try: + # Remove thoughts if present + if THOUGHT_DELIMITER_START in test_solution and THOUGHT_DELIMITER_END in test_solution: + # Extract only the part after the thoughts + thought_end_pos = test_solution.rfind(THOUGHT_DELIMITER_END) + if thought_end_pos >= 0: + test_solution = test_solution[thought_end_pos + len(THOUGHT_DELIMITER_END):] + + # Clean up the response and find the grid pattern + # Look for a pattern like [[...]] or [[[...]]] + grid_pattern = re.search(r'\[\s*\[(?:\s*\[)?(.+?)(?:\]\s*)?\]\s*\]', test_solution, re.DOTALL) + if not grid_pattern: + return None + + grid_text = grid_pattern.group(1) + + # Handle various formats + rows = [] + + # Check if rows are separated by commas + split_rows = re.split(r'\],\s*\[', grid_text) + + for row_text in split_rows: + # Clean the row text and extract characters + row_text = row_text.strip().strip('[],') + + # Extract quoted characters: "X" or 'X' or just X + chars = [] + + # Look for quoted strings or standalone characters + char_matches = re.findall(r'\"([^\"]*)\"|\'([^\']*)\'|([^,\s]+)', row_text) + + for match in char_matches: + # Take the first non-empty group from each match + char = next((x for x in match if x), "") + + # Handle numeric or empty values (0, "", '') + if char == "0" or char == "": + char = " " + + chars.append(char) + + if chars: # Only add non-empty rows + rows.append(chars) + + # Make sure we have a valid grid + if not rows or not all(rows): + return None + + return rows + + except Exception as e: + print(f"NOTE!!! parse error!!!! (Wordscapes): {e}") + return None + \ No newline at end of file diff --git a/verl/utils/reward_score/tablereason.py b/src/reasoning360/utils/reward_score/tablereason.py similarity index 100% rename from verl/utils/reward_score/tablereason.py rename to src/reasoning360/utils/reward_score/tablereason.py diff --git a/verl/utils/reward_score/zebra_puzzle.py b/src/reasoning360/utils/reward_score/zebra_puzzle.py similarity index 100% rename from verl/utils/reward_score/zebra_puzzle.py rename to src/reasoning360/utils/reward_score/zebra_puzzle.py diff --git a/verl/workers/reward_manager/__init__.py b/src/reasoning360/workers/reward_manager/__init__.py similarity index 63% rename from verl/workers/reward_manager/__init__.py rename to src/reasoning360/workers/reward_manager/__init__.py index 173cf1bb8..b2dfc8c24 100644 --- a/verl/workers/reward_manager/__init__.py +++ b/src/reasoning360/workers/reward_manager/__init__.py @@ -12,27 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .registry import get_reward_manager_cls, register # noqa: I001 -from .batch import BatchRewardManager +from verl.workers.reward_manager.registry import get_reward_manager_cls, register # noqa: I001 +from verl.workers.reward_manager.batch import BatchRewardManager +from verl.workers.reward_manager.naive import NaiveRewardManager +from verl.workers.reward_manager.prime import PrimeRewardManager +# NOTE: added by Reasoning360. from .dapo import DAPORewardManager -from .naive import NaiveRewardManager -from .prime import PrimeRewardManager - -# Added by Reasoning360 +from .llm_judge import LLMJudgeRewardManager from .naive_parallel import NaiveParallelRewardManager from .async_mp import AsyncMultiProcessRewardManager -from .llm_judge import LLMJudgeRewardManager # Note(haibin.lin): no need to include all reward managers here in case of complicated dependencies __all__ = [ "BatchRewardManager", - "DAPORewardManager", "NaiveRewardManager", "PrimeRewardManager", - "register", - "get_reward_manager_cls", - # Added by Reasoning360 + # NOTE: added by Reasoning360. + "DAPORewardManager", + "LLMJudgeRewardManager", "NaiveParallelRewardManager", "AsyncMultiProcessRewardManager", - "LLMJudgeRewardManager", + "register", + "get_reward_manager_cls", ] + +# Import experimental reward managers to ensure they are registered +try: + from verl.experimental.reward.reward_loop.limited import RateLimitedRewardLoopManager # noqa: F401 + + __all__.append("RateLimitedRewardLoopManager") +except ImportError: + pass # Optional dependency, may not be available diff --git a/verl/workers/reward_manager/async_mp.py b/src/reasoning360/workers/reward_manager/async_mp.py similarity index 99% rename from verl/workers/reward_manager/async_mp.py rename to src/reasoning360/workers/reward_manager/async_mp.py index 09b913c74..6eece6b42 100644 --- a/verl/workers/reward_manager/async_mp.py +++ b/src/reasoning360/workers/reward_manager/async_mp.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# NOTE: added by Reasoning360 + import asyncio from collections import defaultdict from concurrent.futures import ProcessPoolExecutor @@ -21,7 +23,7 @@ import torch from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from reasoning360.utils.reward_score import _default_compute_score from verl.workers.reward_manager import register diff --git a/verl/workers/reward_manager/dapo.py b/src/reasoning360/workers/reward_manager/dapo.py similarity index 87% rename from verl/workers/reward_manager/dapo.py rename to src/reasoning360/workers/reward_manager/dapo.py index cb8b5cf22..91a5ac28c 100644 --- a/verl/workers/reward_manager/dapo.py +++ b/src/reasoning360/workers/reward_manager/dapo.py @@ -17,12 +17,17 @@ import torch from verl import DataProto -from verl.utils.reward_score import default_compute_score +from reasoning360.utils.reward_score import _default_compute_score from verl.workers.reward_manager import register +from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY +from verl.workers.reward_manager.abstract import AbstractRewardManager + +if "dapo" in REWARD_MANAGER_REGISTRY: + del REWARD_MANAGER_REGISTRY["dapo"] @register("dapo") -class DAPORewardManager: +class DAPORewardManager(AbstractRewardManager): """The reward manager.""" def __init__( @@ -36,7 +41,7 @@ def __init__( ) -> None: self.tokenizer = tokenizer self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or default_compute_score + self.compute_score = compute_score or _default_compute_score self.reward_fn_key = reward_fn_key self.overlong_buffer_cfg = overlong_buffer_cfg self.max_resp_len = max_resp_len @@ -55,7 +60,9 @@ def __call__(self, data: DataProto, return_dict: bool = False): # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn if "rm_scores" in data.batch.keys(): if return_dict: - return {"reward_tensor": data.batch["rm_scores"]} + reward_extra_keys = data.meta_info.get("reward_extra_keys", []) + reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys} + return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info} else: return data.batch["rm_scores"] @@ -89,7 +96,11 @@ def __call__(self, data: DataProto, return_dict: bool = False): data_source = data_item.non_tensor_batch[self.reward_fn_key] - extra_info = data_item.non_tensor_batch.get("extra_info", None) + extra_info = data_item.non_tensor_batch.get("extra_info", {}) + + rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {}) + + extra_info["rollout_reward_scores"] = rollout_reward_scores result = self.compute_score( data_source=data_source, diff --git a/verl/workers/reward_manager/llm_judge.py b/src/reasoning360/workers/reward_manager/llm_judge.py similarity index 98% rename from verl/workers/reward_manager/llm_judge.py rename to src/reasoning360/workers/reward_manager/llm_judge.py index cb44354da..e241c20cd 100644 --- a/verl/workers/reward_manager/llm_judge.py +++ b/src/reasoning360/workers/reward_manager/llm_judge.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# NOTE: added by Reasoning360 + import asyncio from concurrent.futures import ProcessPoolExecutor from functools import partial @@ -19,7 +21,7 @@ import torch from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from reasoning360.utils.reward_score import _default_compute_score async def single_compute_score(evaluation_func, data_source, solution_str, ground_truth, extra_info, executor, timeout=300.): diff --git a/verl/workers/reward_manager/naive_parallel.py b/src/reasoning360/workers/reward_manager/naive_parallel.py similarity index 96% rename from verl/workers/reward_manager/naive_parallel.py rename to src/reasoning360/workers/reward_manager/naive_parallel.py index 4aad2fa71..b3851058e 100644 --- a/verl/workers/reward_manager/naive_parallel.py +++ b/src/reasoning360/workers/reward_manager/naive_parallel.py @@ -1,8 +1,9 @@ from verl import DataProto -from verl.utils.reward_score import _default_compute_score +from reasoning360.utils.reward_score import _default_compute_score import torch from multiprocessing import Pool +# NOTE: added by Reasoning360 class NaiveParallelRewardManager: diff --git a/tests/README.md b/tests/README.md deleted file mode 100644 index 479f06933..000000000 --- a/tests/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Tests layout - -Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: -- `tests/trainer` for testing functionality related to `verl/trainer` -- `tests/models` for testing functionality related to `verl/models` -- ... - -There are a few folders with `special_` prefix, created for special purposes: -- `special_distributed`: unit tests that must run with multiple GPUs -- `special_e2e`: end-to-end tests with training/generation scripts -- `special_npu`: tests for NPUs -- `special_sanity`: a suite of quick sanity tests -- `special_standalone`: a set of test that are designed to run in dedicated environments - -Accelerators for tests -- By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. -- For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. - -# Workflow layout - -All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: -1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` -2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` -3. End-to-end tests: `e2e_*.yml` -4. Unit tests - - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` - - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. - - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when - - new workflow yaml is added to `.github/workflows` - - new tests are added to workflow mentioned in 2. \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/tests/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/experimental/agent_loop/agent_utils.py b/tests/experimental/agent_loop/agent_utils.py deleted file mode 100644 index 3c708c42c..000000000 --- a/tests/experimental/agent_loop/agent_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import ray -from omegaconf import DictConfig - -from verl.experimental.agent_loop import AgentLoopManager -from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role -from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker - - -def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup: - # =========================== 1. Create hybrid ActorRollout workers =========================== - actor_rollout_cls = ( - AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker - ) - role_worker_mapping = { - Role.ActorRollout: ray.remote(actor_rollout_cls), - } - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - } - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - resource_pool_manager.create_resource_pool() - resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} - - # create actor and rollout - resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" - ) - resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls - - all_wg = {} - for resource_pool, class_dict in resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - actor_rollout_wg = all_wg["actor_rollout"] - actor_rollout_wg.init_model() - - if config.actor_rollout_ref.rollout.mode == "sync": - return actor_rollout_wg - - # =========================== 2. Create AgentLoopManager =========================== - agent_loop_manager = AgentLoopManager( - config=config, - worker_group=actor_rollout_wg, - ) - - return agent_loop_manager diff --git a/tests/experimental/agent_loop/test_basic_agent_loop.py b/tests/experimental/agent_loop/test_basic_agent_loop.py deleted file mode 100644 index 14deb01f0..000000000 --- a/tests/experimental/agent_loop/test_basic_agent_loop.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -import os -from typing import Any - -import numpy as np -import pytest -import ray -from omegaconf import DictConfig -from transformers.utils import get_json_schema - -from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager -from verl.experimental.agent_loop.agent_loop import get_trajectory_info -from verl.protocol import DataProto -from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema -from verl.utils import hf_tokenizer - - -@pytest.fixture -def init_config() -> DictConfig: - from hydra import compose, initialize_config_dir - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - config = compose(config_name="ppo_trainer") - model_path = "Qwen/Qwen2.5-1.5B-Instruct" - config.actor_rollout_ref.model.path = model_path - config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.prompt_length = 4096 - config.actor_rollout_ref.rollout.response_length = 4096 - config.actor_rollout_ref.rollout.n = 4 - config.actor_rollout_ref.rollout.agent.num_workers = 2 - - # test sleep/wake_up with fsdp offload - config.actor_rollout_ref.actor.fsdp_config.param_offload = True - config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True - - return config - - -def test_single_turn(init_config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - agent_loop_manager = init_agent_loop_manager(init_config) - - raw_prompts = [ - [ - { - "role": "user", - "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", - } - ], - [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], - ] - batch = DataProto( - non_tensor_batch={ - "raw_prompt": np.array(raw_prompts), - "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)), - }, - ) - n = init_config.actor_rollout_ref.rollout.n - batch = batch.repeat(n) - result = agent_loop_manager.generate_sequences(prompts=batch) - assert len(result) == len(raw_prompts) * n - - # check result - seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) - assert result.batch["input_ids"].size(1) == seq_len - assert result.batch["attention_mask"].size(1) == seq_len - assert result.batch["position_ids"].size(1) == seq_len - - # check turns - num_turns = result.non_tensor_batch["__num_turns__"] - assert np.all(num_turns == 2) - - print("Test passed!") - ray.shutdown() - - -class WeatherTool(BaseTool): - def get_current_temperature(self, location: str, unit: str = "celsius"): - """Get current temperature at a location. - - Args: - location: The location to get the temperature for, in the format "City, State, Country". - unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) - - Returns: - the temperature, the location, and the unit in a dict - """ - print(f"[DEBUG] get_current_temperature: {location}, {unit}") - return { - "temperature": 26.1, - "location": location, - "unit": unit, - } - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - schema = get_json_schema(self.get_current_temperature) - return OpenAIFunctionToolSchema(**schema) - - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - try: - result = self.get_current_temperature(**parameters) - return json.dumps(result), 0, {} - except Exception as e: - return str(e), 0, {} - - -class WeatherToolWithData(BaseTool): - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - schema = get_json_schema(self.get_temperature_date) - return OpenAIFunctionToolSchema(**schema) - - def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): - """Get temperature at a location and date. - - Args: - location: The location to get the temperature for, in the format "City, State, Country". - date: The date to get the temperature for, in the format "Year-Month-Day". - unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) - - Returns: - the temperature, the location, the date and the unit in a dict - """ - print(f"[DEBUG] get_temperature_date: {location}, {date}, {unit}") - return { - "temperature": 25.9, - "location": location, - "date": date, - "unit": unit, - } - - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - try: - result = self.get_temperature_date(**parameters) - return json.dumps(result), 0, {} - except Exception as e: - return str(e), 0, {} - - -def test_tool_agent(init_config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - # =========================== 1. Init rollout manager =========================== - tool_config = { - "tools": [ - { - "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherTool", - "config": {"type": "native"}, - }, - { - "class_name": "tests.experimental.agent_loop.test_basic_agent_loop.WeatherToolWithData", - "config": {"type": "native"}, - }, - ] - } - tool_config_path = "/tmp/tool_config.json" - with open(tool_config_path, "w") as f: - json.dump(tool_config, f) - - n = 2 - init_config.actor_rollout_ref.rollout.n = n - init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path - init_config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls = 2 - agent_loop_manager = init_agent_loop_manager(init_config) - - # =========================== 2. Generate sequences =========================== - raw_prompts = [ - [ - {"role": "user", "content": "How are you?"}, - ], - [ - {"role": "user", "content": "What's the temperature in Los Angeles now?"}, - ], - [ - {"role": "user", "content": "What's the temperature in New York now?"}, - ], - [ - { - "role": "system", - "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" - "Current Date: 2024-09-30", - }, - {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, - ], - ] - batch = DataProto( - non_tensor_batch={ - "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), - "agent_name": np.array(["tool_agent"] * len(raw_prompts)), - }, - ) - batch = batch.repeat(n) - result = agent_loop_manager.generate_sequences(prompts=batch) - assert len(result) == len(raw_prompts) * n - - # Check turns - num_turns = result.non_tensor_batch["__num_turns__"] - print(f"num_turns: {num_turns}") - for i in range(len(num_turns)): - if i // n == 0: - # [user, assistant] - assert num_turns[i] == 2 - else: - # [user, assistant, tool, assistant] - assert num_turns[i] == 4 - - # Check response_mask - tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) - responses = result.batch["responses"] - response_mask = result.batch["response_mask"] - attention_mask = result.batch["attention_mask"] - assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" - response_length = response_mask.size(1) - - for i in range(len(responses)): - # response with tool response - valid_tokens = responses[i][attention_mask[i][-response_length:].bool()] - response_with_obs = tokenizer.decode(valid_tokens) - - # response without tool response - valid_tokens = responses[i][response_mask[i].bool()] - response_without_obs = tokenizer.decode(valid_tokens) - - assert "" not in response_without_obs, ( - f"found in response: {response_without_obs}" - ) - assert "" not in response_without_obs, ( - f"found in response: {response_without_obs}" - ) - print("=========================") - print(response_with_obs) - print("---") - print(response_without_obs) - - print("Test passed!") - ray.shutdown() - - -@pytest.mark.asyncio -async def test_get_trajectory_info(): - """Tests the get_trajectory_info method.""" - # Initialize the class to set up class-level attributes - step = 10 - index = [1, 1, 3, 3] - expected_info = [ - {"step": step, "sample_index": 1, "rollout_n": 0, "validate": False}, - {"step": step, "sample_index": 1, "rollout_n": 1, "validate": False}, - {"step": step, "sample_index": 3, "rollout_n": 0, "validate": False}, - {"step": step, "sample_index": 3, "rollout_n": 1, "validate": False}, - ] - - trajectory_info = await get_trajectory_info(step, index, validate=False) - - assert trajectory_info == expected_info diff --git a/tests/interactions/__init__.py b/tests/interactions/__init__.py deleted file mode 100644 index b6db0fcef..000000000 --- a/tests/interactions/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/interactions/test_gsm8k_interaction.py b/tests/interactions/test_gsm8k_interaction.py deleted file mode 100644 index bc16877c2..000000000 --- a/tests/interactions/test_gsm8k_interaction.py +++ /dev/null @@ -1,421 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import patch - -import pytest - -from verl.interactions.gsm8k_interaction import Gsm8kInteraction - - -class TestGsm8kInteraction: - """Test cases for Gsm8kInteraction class.""" - - def setup_method(self): - """Set up test environment before each test method.""" - self.config = {"name": "gsm8k"} - self.interaction = Gsm8kInteraction(self.config) - - def test_init(self): - """Test Gsm8kInteraction initialization.""" - assert self.interaction._instance_dict == {} - assert self.interaction.config == self.config - assert self.interaction.name == "gsm8k" - - @pytest.mark.asyncio - async def test_start_interaction_with_instance_id(self): - """Test start_interaction with provided instance_id.""" - instance_id = "test_instance" - ground_truth = "42" - - result_id = await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - assert result_id == instance_id - assert instance_id in self.interaction._instance_dict - assert self.interaction._instance_dict[instance_id]["response"] == "" - assert self.interaction._instance_dict[instance_id]["ground_truth"] == ground_truth - assert self.interaction._instance_dict[instance_id]["reward"] == 0.0 - - @pytest.mark.asyncio - async def test_start_interaction_without_instance_id(self): - """Test start_interaction without provided instance_id (auto-generated).""" - ground_truth = "42" - - result_id = await self.interaction.start_interaction(ground_truth=ground_truth) - - assert result_id is not None - assert len(result_id) == 36 # UUID4 length - assert result_id in self.interaction._instance_dict - assert self.interaction._instance_dict[result_id]["ground_truth"] == ground_truth - - @pytest.mark.asyncio - async def test_start_interaction_without_ground_truth(self): - """Test start_interaction without ground_truth parameter.""" - instance_id = "test_instance" - - result_id = await self.interaction.start_interaction(instance_id=instance_id) - - assert result_id == instance_id - assert self.interaction._instance_dict[instance_id]["ground_truth"] is None - - @pytest.mark.asyncio - async def test_generate_response_correct_answer_with_prefix(self): - """Test generate_response with correct answer already having #### prefix.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - messages = [{"role": "user", "content": "#### 42"}] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is True - assert response == "Your response is correct!" - assert reward == 1.0 - assert metadata == {} - assert self.interaction._instance_dict[instance_id]["response"] == "#### 42" - - @pytest.mark.asyncio - async def test_generate_response_correct_answer_without_prefix(self): - """Test generate_response with correct answer missing #### prefix.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - messages = [{"role": "user", "content": "42"}] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is True - assert response == "Your response is correct!" - assert reward == 1.0 - assert self.interaction._instance_dict[instance_id]["response"] == "#### 42" - - @pytest.mark.asyncio - async def test_generate_response_incorrect_answer(self): - """Test generate_response with incorrect answer.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - messages = [{"role": "user", "content": "24"}] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is False - assert response == "Your response is incorrect! You need to reflect on your answer and try again." - assert reward == 0.0 - assert self.interaction._instance_dict[instance_id]["response"] == "#### 24" - - @pytest.mark.asyncio - async def test_generate_response_multiple_messages(self): - """Test generate_response with multiple messages (should use last user message).""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - messages = [ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "Let me think about this..."}, - {"role": "user", "content": "#### 42"}, - ] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is True - assert response == "Your response is correct!" - assert self.interaction._instance_dict[instance_id]["response"] == "#### 42" - - @pytest.mark.asyncio - async def test_generate_response_no_user_message(self): - """Test generate_response with no user messages.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - messages = [{"role": "assistant", "content": "Hello!"}] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is False - assert self.interaction._instance_dict[instance_id]["response"] == "#### " - - @pytest.mark.asyncio - async def test_calculate_score_direct_call(self): - """Test calculate_score method directly.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - # Set a response - self.interaction._instance_dict[instance_id]["response"] = "#### 42" - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0) as mock_compute: - score = await self.interaction.calculate_score(instance_id) - - assert score == 1.0 - mock_compute.assert_called_once_with("#### 42", "42", method="flexible", format_score=0.0, score=1.0) - - @pytest.mark.asyncio - async def test_calculate_score_with_kwargs(self): - """Test calculate_score method with additional kwargs.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - # Set a response - self.interaction._instance_dict[instance_id]["response"] = "#### 24" - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0) as mock_compute: - score = await self.interaction.calculate_score(instance_id, extra_param="test") - - assert score == 0.0 - mock_compute.assert_called_once_with("#### 24", "42", method="flexible", format_score=0.0, score=1.0) - - @pytest.mark.asyncio - async def test_finalize_interaction(self): - """Test finalize_interaction method.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - assert instance_id in self.interaction._instance_dict - - await self.interaction.finalize_interaction(instance_id) - - assert instance_id not in self.interaction._instance_dict - - @pytest.mark.asyncio - async def test_finalize_interaction_with_kwargs(self): - """Test finalize_interaction method with additional kwargs.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - assert instance_id in self.interaction._instance_dict - - await self.interaction.finalize_interaction(instance_id, extra_param="test") - - assert instance_id not in self.interaction._instance_dict - - @pytest.mark.asyncio - async def test_finalize_nonexistent_interaction(self): - """Test finalize_interaction with non-existent instance_id.""" - instance_id = "nonexistent_instance" - - # This should raise KeyError - with pytest.raises(KeyError): - await self.interaction.finalize_interaction(instance_id) - - @pytest.mark.asyncio - async def test_full_interaction_workflow_correct(self): - """Test complete interaction workflow with correct answer.""" - ground_truth = "42" - - # Start interaction - instance_id = await self.interaction.start_interaction(ground_truth=ground_truth) - - # Generate response with correct answer - messages = [{"role": "user", "content": "42"}] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is True - assert reward == 1.0 - - # Finalize interaction - await self.interaction.finalize_interaction(instance_id) - assert instance_id not in self.interaction._instance_dict - - @pytest.mark.asyncio - async def test_full_interaction_workflow_incorrect(self): - """Test complete interaction workflow with incorrect answer.""" - ground_truth = "42" - - # Start interaction - instance_id = await self.interaction.start_interaction(ground_truth=ground_truth) - - # Generate response with incorrect answer - messages = [{"role": "user", "content": "24"}] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is False - assert reward == 0.0 - - # Continue with another attempt - messages.append({"role": "assistant", "content": response}) - messages.append({"role": "user", "content": "42"}) - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is True - assert reward == 1.0 - - # Finalize interaction - await self.interaction.finalize_interaction(instance_id) - assert instance_id not in self.interaction._instance_dict - - @pytest.mark.asyncio - async def test_multiple_concurrent_interactions(self): - """Test multiple concurrent interaction instances.""" - ground_truth_1 = "42" - ground_truth_2 = "24" - - # Start multiple interactions - instance_id_1 = await self.interaction.start_interaction(ground_truth=ground_truth_1) - instance_id_2 = await self.interaction.start_interaction(ground_truth=ground_truth_2) - - assert len(self.interaction._instance_dict) == 2 - assert instance_id_1 in self.interaction._instance_dict - assert instance_id_2 in self.interaction._instance_dict - - # Test responses for both instances - messages_1 = [{"role": "user", "content": "42"}] - messages_2 = [{"role": "user", "content": "24"}] - - with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]): - should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1) - should_terminate_2, _, reward_2, _ = await self.interaction.generate_response(instance_id_2, messages_2) - - assert should_terminate_1 is True - assert should_terminate_2 is True - assert reward_1 == 1.0 - assert reward_2 == 1.0 - - # Finalize both interactions - await self.interaction.finalize_interaction(instance_id_1) - await self.interaction.finalize_interaction(instance_id_2) - - assert len(self.interaction._instance_dict) == 0 - - @pytest.mark.asyncio - async def test_edge_case_empty_messages(self): - """Test edge case with empty messages list.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - messages = [] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is False - assert reward == 0.0 - assert self.interaction._instance_dict[instance_id]["response"] == "#### " - - @pytest.mark.asyncio - async def test_edge_case_message_without_content(self): - """Test edge case with message without content field.""" - instance_id = "test_instance" - ground_truth = "42" - - # Setup instance - await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth) - - messages = [ - {"role": "user"} # Missing content field - ] - - with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0): - should_terminate, response, reward, metadata = await self.interaction.generate_response( - instance_id, messages - ) - - assert should_terminate is False - assert reward == 0.0 - assert self.interaction._instance_dict[instance_id]["response"] == "#### None" - - def test_inheritance_from_base_interaction(self): - """Test that Gsm8kInteraction properly inherits from BaseInteraction.""" - from verl.interactions.base import BaseInteraction - - assert isinstance(self.interaction, BaseInteraction) - - # Test that all required methods are implemented - assert hasattr(self.interaction, "start_interaction") - assert hasattr(self.interaction, "generate_response") - assert hasattr(self.interaction, "calculate_score") - assert hasattr(self.interaction, "finalize_interaction") - - # Test that methods are callable - assert callable(self.interaction.start_interaction) - assert callable(self.interaction.generate_response) - assert callable(self.interaction.calculate_score) - assert callable(self.interaction.finalize_interaction) - - def test_name_attribute_initialization(self): - """Test name attribute initialization with different configs.""" - # Test with explicit name in config - config_with_name = {"name": "custom_gsm8k"} - interaction_with_name = Gsm8kInteraction(config_with_name) - assert interaction_with_name.name == "custom_gsm8k" - - # Test with default name when not provided in config - config_without_name = {} - interaction_without_name = Gsm8kInteraction(config_without_name) - assert interaction_without_name.name == "interaction_agent" # Default from BaseInteraction - - # Test that name is accessible as attribute - assert hasattr(self.interaction, "name") - assert self.interaction.name == "gsm8k" diff --git a/tests/interactions/test_interaction_registry.py b/tests/interactions/test_interaction_registry.py deleted file mode 100644 index 7fe193b52..000000000 --- a/tests/interactions/test_interaction_registry.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile - -import pytest -from omegaconf import OmegaConf - -from verl.interactions.base import BaseInteraction -from verl.interactions.gsm8k_interaction import Gsm8kInteraction -from verl.interactions.utils.interaction_registry import ( - get_interaction_class, - initialize_interactions_from_config, -) - - -class TestInteractionRegistry: - def test_get_interaction_class(self): - """Test getting interaction class by name.""" - # Test getting base interaction class - base_cls = get_interaction_class("verl.interactions.base.BaseInteraction") - assert base_cls == BaseInteraction - - # Test getting gsm8k interaction class - gsm8k_cls = get_interaction_class("verl.interactions.gsm8k_interaction.Gsm8kInteraction") - assert gsm8k_cls == Gsm8kInteraction - - def test_initialize_single_interaction_from_config(self): - """Test initializing single interaction from config.""" - # Create temporary config file - config_content = { - "interaction": [ - { - "name": "test_gsm8k", - "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", - "config": {}, - } - ] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(config_content, f.name) - temp_config_path = f.name - - try: - interaction_map = initialize_interactions_from_config(temp_config_path) - - # Check that interaction was created - assert len(interaction_map) == 1 - assert "test_gsm8k" in interaction_map - assert isinstance(interaction_map["test_gsm8k"], Gsm8kInteraction) - assert interaction_map["test_gsm8k"].name == "test_gsm8k" - finally: - os.unlink(temp_config_path) - - def test_initialize_multiple_interactions_from_config(self): - """Test initializing multiple interactions from config.""" - config_content = { - "interaction": [ - { - "name": "gsm8k_solver", - "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", - "config": {}, - }, - { - "name": "base_agent", - "class_name": "verl.interactions.base.BaseInteraction", - "config": {"custom_param": "test_value"}, - }, - ] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(config_content, f.name) - temp_config_path = f.name - - try: - interaction_map = initialize_interactions_from_config(temp_config_path) - - # Check that both interactions were created - assert len(interaction_map) == 2 - assert "gsm8k_solver" in interaction_map - assert "base_agent" in interaction_map - - # Check types - assert isinstance(interaction_map["gsm8k_solver"], Gsm8kInteraction) - assert isinstance(interaction_map["base_agent"], BaseInteraction) - - # Check names were injected - assert interaction_map["gsm8k_solver"].name == "gsm8k_solver" - assert interaction_map["base_agent"].name == "base_agent" - - # Check custom config was passed - assert interaction_map["base_agent"].config.get("custom_param") == "test_value" - finally: - os.unlink(temp_config_path) - - def test_initialize_interaction_without_explicit_name(self): - """Test that interaction name is derived from class name when not specified.""" - config_content = { - "interaction": [{"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(config_content, f.name) - temp_config_path = f.name - - try: - interaction_map = initialize_interactions_from_config(temp_config_path) - - # Check that interaction name was derived from class name - assert len(interaction_map) == 1 - assert "gsm8k" in interaction_map # Should be "gsm8k" after removing "interaction" suffix - assert isinstance(interaction_map["gsm8k"], Gsm8kInteraction) - assert interaction_map["gsm8k"].name == "gsm8k" - finally: - os.unlink(temp_config_path) - - def test_initialize_empty_config(self): - """Test initializing from empty config.""" - config_content = {"interaction": []} - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(config_content, f.name) - temp_config_path = f.name - - try: - interaction_map = initialize_interactions_from_config(temp_config_path) - assert len(interaction_map) == 0 - finally: - os.unlink(temp_config_path) - - def test_invalid_class_name(self): - """Test handling of invalid class name.""" - config_content = { - "interaction": [{"name": "invalid", "class_name": "invalid.module.InvalidClass", "config": {}}] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(config_content, f.name) - temp_config_path = f.name - - try: - with pytest.raises(ModuleNotFoundError): - initialize_interactions_from_config(temp_config_path) - finally: - os.unlink(temp_config_path) - - def test_duplicate_interaction_names(self): - """Test handling of duplicate interaction names.""" - config_content = { - "interaction": [ - {"name": "duplicate", "class_name": "verl.interactions.base.BaseInteraction", "config": {}}, - { - "name": "duplicate", - "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", - "config": {}, - }, - ] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(config_content, f.name) - temp_config_path = f.name - - try: - with pytest.raises(ValueError, match="Duplicate interaction name 'duplicate' found"): - initialize_interactions_from_config(temp_config_path) - finally: - os.unlink(temp_config_path) - - def test_auto_name_generation_edge_cases(self): - """Test automatic name generation for various class name patterns.""" - config_content = { - "interaction": [ - {"class_name": "verl.interactions.base.BaseInteraction", "config": {}}, - {"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}, - ] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(config_content, f.name) - temp_config_path = f.name - - try: - interaction_map = initialize_interactions_from_config(temp_config_path) - - # Check that names were generated correctly - assert len(interaction_map) == 2 - assert "base" in interaction_map # BaseInteraction -> base - assert "gsm8k" in interaction_map # Gsm8kInteraction -> gsm8k - finally: - os.unlink(temp_config_path) diff --git a/tests/kill_github_tests.sh b/tests/kill_github_tests.sh deleted file mode 100644 index 5c76d7658..000000000 --- a/tests/kill_github_tests.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -if [ "$#" -ne 1 ]; then - echo "Usage: $0 YOUR_GITHUB_TOKEN" - echo "Please provide exactly one input argument for your github token." - exit 1 -fi - -# Set your GitHub repository details -OWNER="volcengine" -REPO="verl" -TOKEN=$1 - -# API URL for workflow runs -API_URL="https://api.github.com/repos/$OWNER/$REPO/actions/runs?status=queued" - -# Check required commands -command -v jq >/dev/null 2>&1 || { echo "jq is required but not installed. Aborting."; exit 1; } - -# Get queued workflow runs -response=$(curl -s -H "Authorization: token $TOKEN" -H "Accept: application/vnd.github.v3+json" "$API_URL") - -# Run this for debugging -# echo $response - -# Extract run IDs -queued_run_ids=$(echo "$response" | jq -r '.workflow_runs[] | .id') - -if [ -z "$queued_run_ids" ]; then - echo "No queued workflow runs found." - exit 0 -fi - -# Cancel each queued run -for run_id in $queued_run_ids; do - echo "Cancelling run $run_id" - cancel_url="https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/cancel" - curl -s -X POST -H "Authorization: token $TOKEN" -H "Accept: application/vnd.github.v3+json" "$cancel_url" -done - -echo "Cancelled all queued workflow runs." diff --git a/tests/models/test_transformer.py b/tests/models/test_transformer.py deleted file mode 100644 index 111230a8a..000000000 --- a/tests/models/test_transformer.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input -from transformers import ( - AutoModelForCausalLM, - AutoModelForTokenClassification, - GemmaConfig, - LlamaConfig, - MistralConfig, - Qwen2Config, -) - -from verl.utils.model import compute_position_id_with_mask, create_random_mask -from verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean - -# TODO(sgm): add more models for test -# we only need one scale for each model -test_configs = [ - LlamaConfig(num_hidden_layers=1), - MistralConfig(num_hidden_layers=1), - GemmaConfig(num_hidden_layers=1), - Qwen2Config(num_hidden_layers=1), -] - - -def test_hf_casual_models(): - batch_size = 4 - seqlen = 128 - response_length = 127 - - for config in test_configs: - # config = AutoConfig.from_pretrained(test_case) - with torch.device("cuda"): - model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) - model = model.to(device="cuda") - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") - attention_mask = create_random_mask( - input_ids=input_ids, - max_ratio_of_left_padding=0.1, - max_ratio_of_valid_token=0.8, - min_ratio_of_valid_token=0.5, - ) - position_ids = compute_position_id_with_mask( - attention_mask - ) # TODO(sgm): we can construct the position_ids_rmpad here - - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_rmpad = model( - input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False - ).logits # (1, total_nnz, vocab_size) - - origin_logits = model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False - ).logits - origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask) - - logits_rmpad = logits_rmpad.squeeze(0) - log_probs = log_probs_from_logits_all_rmpad( - input_ids_rmpad=input_ids_rmpad, - logits_rmpad=logits_rmpad, - indices=indices, - batch_size=batch_size, - seqlen=seqlen, - response_length=response_length, - ) # (batch, seqlen) - origin_log_probs = log_probs_from_logits_all_rmpad( - input_ids_rmpad=input_ids_rmpad, - logits_rmpad=origin_logits_rmpad, - indices=origin_logits_indices, - batch_size=batch_size, - seqlen=seqlen, - response_length=response_length, - ) # (batch, seqlen) - - torch.testing.assert_close( - masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]), - masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]), - atol=1e-2, - rtol=1e-5, - ) - print("Check pass") - - -def test_hf_value_models(): - batch_size = 4 - seqlen = 128 - - for config in test_configs: - # config = AutoConfig.from_pretrained(test_case) - config.num_labels = 1 - config.classifier_dropout = 0 - config.hidden_dropout = 0 - with torch.device("cuda"): - model = AutoModelForTokenClassification.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) - model = model.to(device="cuda") - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") - attention_mask = create_random_mask( - input_ids=input_ids, - max_ratio_of_left_padding=0.1, - max_ratio_of_valid_token=0.8, - min_ratio_of_valid_token=0.5, - ) - position_ids = compute_position_id_with_mask( - attention_mask - ) # TODO(sgm): we can construct the position_ids_rmpad here - - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - origin_logits = model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False - ).logits - - # input with input_ids_rmpad and postition_ids to enable flash attention varlen - rmpad_logits = model( - input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False - ).logits # (1, total_nnz, 1) - rmpad_logits = rmpad_logits.squeeze(0) - pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen) - - torch.testing.assert_close( - masked_mean(pad_logits, attention_mask[:, :, None]), - masked_mean(origin_logits, attention_mask[:, :, None]), - atol=1e-2, - rtol=1e-5, - ) - print("Value model check pass") - - -if __name__ == "__main__": - test_hf_casual_models() - test_hf_value_models() diff --git a/tests/models/test_transformers_ulysses.py b/tests/models/test_transformers_ulysses.py deleted file mode 100644 index 111b35ec9..000000000 --- a/tests/models/test_transformers_ulysses.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import contextlib -import copy -from dataclasses import dataclass - -import pytest -import torch -import torch.distributed -from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input -from torch.distributed import init_device_mesh -from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config - -from verl.models.transformers.monkey_patch import apply_monkey_patch -from verl.protocol import DataProto -from verl.utils.distributed import initialize_global_process_group -from verl.utils.model import compute_position_id_with_mask, create_random_mask -from verl.utils.ulysses import ( - gather_outputs_and_unpad, - get_ulysses_sequence_parallel_world_size, - set_ulysses_sequence_parallel_group, - ulysses_pad_and_slice_inputs, -) -from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager - -# TODO(sgm): add more models for test -# we only need one scale for each model - - -@dataclass -class SequenceParallelConfig: - config: PretrainedConfig - sp_size: int - is_valid: bool - - -def test_configs(): - return [ - SequenceParallelConfig( - LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True - ), - SequenceParallelConfig( - Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), - sp_size=4, - is_valid=True, - ), - SequenceParallelConfig( - Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), - sp_size=8, - is_valid=False, - ), - SequenceParallelConfig( - Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True - ), - SequenceParallelConfig( - Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True - ), - ] - - -def sync_model_parameters_global(layer): - # synchronize weights - for p in layer.parameters(): - torch.distributed.broadcast(tensor=p.data, src=0) - - -@pytest.mark.parametrize("test_config", test_configs()) -def test_hf_casual_fwd_bwd(test_config): - if not torch.distributed.is_initialized(): - initialize_global_process_group() - - context = contextlib.nullcontext() if test_config.is_valid else pytest.raises(AssertionError) - with context: - world_size = torch.distributed.get_world_size() - _hf_casual_fwd_bwd(test_config.config, test_config.sp_size, world_size // test_config.sp_size) - - # TODO: seems not work, will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort` - # torch.distributed.destroy_process_group() - - -def _hf_casual_fwd(config, sp_size, dp_size): - assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" - - ulysses_device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp") - ) - sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh) - - batch_size = 1 - seqlen = 128 - # response_length = 127 - - # patch before load - with torch.device("cuda"): - model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) - apply_monkey_patch(model, sp_size) - model = model.to(device="cuda") - sync_model_parameters_global(model) - - # different rank will generate different input_ids following fsdp - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") - attention_mask = create_random_mask( - input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8 - ) - position_ids = compute_position_id_with_mask( - attention_mask - ) # TODO(sgm): we can construct the position_ids_rmpad here - - model_inputs = { - "input_ids": input_ids.cuda(), - "attention_mask": attention_mask.cuda(), - "position_ids": position_ids.int().cuda(), - } - - model_inputs = DataProto.from_dict(model_inputs) - - # 1. perform ulysses forward - with sharding_manager: - model_inputs = sharding_manager.preprocess_data(model_inputs) - input_ids = model_inputs.batch["input_ids"] - attention_mask = model_inputs.batch["attention_mask"] - position_ids = model_inputs.batch["position_ids"] - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # slice input tensor for ulysses - # input_ids are padded and sliced - # postition_ids are only padded but not sliced - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() - ) - - # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_split_in_seq = model( - input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False - ).logits # (1, total_nnz/n, vocab_size) - - # all_gather output - logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) - - # 2. perform normal forward - set_ulysses_sequence_parallel_group(None) - logits_rmpad_local = model( - input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False - ).logits # (1, total_nnz, vocab_size) - - mean_local = logits_rmpad_local.mean() - mean_full = logits_full.mean() - torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5) - - -def _hf_casual_fwd_bwd(config, sp_size, dp_size): - assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" - - ulysses_device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp") - ) - sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh) - - batch_size = 1 - seqlen = 128 - # response_length = 127 - - # patch before load - with torch.device("cuda"): - model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) - apply_monkey_patch(model, sp_size) - model = model.to(device="cuda") - sync_model_parameters_global(model) - - # different rank will generate different input_ids following fsdp - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda") - attention_mask = create_random_mask( - input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8 - ) - position_ids = compute_position_id_with_mask( - attention_mask - ) # TODO(sgm): we can construct the position_ids_rmpad here - - model_inputs = { - "input_ids": input_ids.cuda(), - "attention_mask": attention_mask.cuda(), - "position_ids": position_ids.int().cuda(), - } - - model_inputs = DataProto.from_dict(model_inputs) - - # 1. perform ulysses forward - with sharding_manager: - model_inputs = sharding_manager.preprocess_data(model_inputs) - input_ids = model_inputs.batch["input_ids"] - attention_mask = model_inputs.batch["attention_mask"] - position_ids = model_inputs.batch["position_ids"] - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - # unpad the position_ids to align the rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # slice input tensor for ulysses - # input_ids are padded and sliced - # postition_ids are only padded but not sliced - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() - ) - - # input with input_ids_rmpad and postition_ids to enable flash attention varlen - logits_split_in_seq = model( - input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False - ).logits # (1, total_nnz/n, vocab_size) - - # all_gather output - logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size) - - # 2. perform normal forward - set_ulysses_sequence_parallel_group(None) - input_ids_full = copy.deepcopy(input_ids_rmpad) - position_ids_full = copy.deepcopy(position_ids_rmpad) - model_no_sp = copy.deepcopy(model) - logits_rmpad_local = model_no_sp( - input_ids_full, position_ids=position_ids_full, use_cache=False - ).logits # (1, total_nnz, vocab_size) - - mean_local = logits_rmpad_local.mean() - mean_full = logits_full.mean() - - mean_full.backward() - mean_local.backward() - - # 3. check the gradients - grad = model.model.layers[0].self_attn.q_proj.weight.grad - grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad - torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5) - torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5) - - -if __name__ == "__main__": - pytest.main([__file__, "-svv"]) diff --git a/tests/single_controller/__init__.py b/tests/single_controller/__init__.py deleted file mode 100644 index 1cd1e8433..000000000 --- a/tests/single_controller/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/single_controller/base/test_decorator.py b/tests/single_controller/base/test_decorator.py deleted file mode 100644 index 5447d65ce..000000000 --- a/tests/single_controller/base/test_decorator.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -import verl.single_controller.base.decorator as decorator_module -from verl.single_controller.base.decorator import ( - DISPATCH_MODE_FN_REGISTRY, - Dispatch, - _check_dispatch_mode, - get_predefined_dispatch_fn, - register_dispatch_mode, - update_dispatch_mode, -) - - -@pytest.fixture -def reset_dispatch_registry(): - # Store original state - original_registry = DISPATCH_MODE_FN_REGISTRY.copy() - yield - # Reset registry after test - decorator_module.DISPATCH_MODE_FN_REGISTRY.clear() - decorator_module.DISPATCH_MODE_FN_REGISTRY.update(original_registry) - - -def test_register_new_dispatch_mode(reset_dispatch_registry): - # Test registration - def dummy_dispatch(worker_group, *args, **kwargs): - return args, kwargs - - def dummy_collect(worker_group, output): - return output - - register_dispatch_mode("TEST_MODE", dummy_dispatch, dummy_collect) - - # Verify enum extension - _check_dispatch_mode(Dispatch.TEST_MODE) - - # Verify registry update - assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == { - "dispatch_fn": dummy_dispatch, - "collect_fn": dummy_collect, - } - # Clean up - Dispatch.remove("TEST_MODE") - - -def test_update_existing_dispatch_mode(reset_dispatch_registry): - # Store original implementation - original_mode = Dispatch.ONE_TO_ALL - - # New implementations - def new_dispatch(worker_group, *args, **kwargs): - return args, kwargs - - def new_collect(worker_group, output): - return output - - # Test update= - update_dispatch_mode(original_mode, new_dispatch, new_collect) - - # Verify update - assert get_predefined_dispatch_fn(original_mode)["dispatch_fn"] == new_dispatch - assert get_predefined_dispatch_fn(original_mode)["collect_fn"] == new_collect diff --git a/tests/single_controller/check_worker_alive/main.py b/tests/single_controller/check_worker_alive/main.py deleted file mode 100644 index cbdee9a8d..000000000 --- a/tests/single_controller/check_worker_alive/main.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys -import time - -import ray - -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup - - -@ray.remote -class TestActor(Worker): - def __init__(self) -> None: - super().__init__() - - @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) - def foo(self, wait_time): - time.sleep(wait_time) - sys.exit(1) - - -if __name__ == "__main__": - wait_time = int(os.getenv("WAIT_TIME", "10")) - - ray.init() - - # test single-node-no-partition - print("test single-node-no-partition") - resource_pool = RayResourcePool([2], use_gpu=False) - class_with_args = RayClassWithInitArgs(cls=TestActor) - - print("create worker group") - wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="test") - - wg.start_worker_aliveness_check(1) - time.sleep(1) - - print(time.time(), "start foo") - - _ = wg.foo(wait_time) - print("foo started") - - print( - time.time(), - f"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time", - ) - time.sleep(wait_time * 6) - - ray.shutdown() diff --git a/tests/single_controller/detached_worker/README.md b/tests/single_controller/detached_worker/README.md deleted file mode 100644 index b06c4c614..000000000 --- a/tests/single_controller/detached_worker/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Detached Worker -## How to run (Only on a single node) -- Start a local ray cluster: -```bash -ray start --head --port=6379 -``` -- Run the server -```bash -python3 server.py -``` -- On another terminal, Run the client -```bash -python3 client.py -``` diff --git a/tests/single_controller/detached_worker/client.py b/tests/single_controller/detached_worker/client.py deleted file mode 100644 index 52f2c7242..000000000 --- a/tests/single_controller/detached_worker/client.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -In client, we can get the server handler and send RPC request -""" - -import ray -import torch -from server import Trainer -from tensordict import TensorDict - -from verl import DataProto -from verl.single_controller.ray import RayClassWithInitArgs -from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup - - -def compute_position_id_with_mask(mask): - return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None) - - -if __name__ == "__main__": - ray.init(address="auto", namespace="verl") - # get the worker group using names - worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"] - cls_with_init_args = RayClassWithInitArgs(cls=Trainer) - worker_group = NVMegatronRayWorkerGroup.from_detached( - worker_names=worker_names, ray_cls_with_init=cls_with_init_args - ) - - batch_size = 16 - sequence_length = 1024 - - # give Trainer some data to train - input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda") - attention_mask = torch.ones_like(input_ids) - position_ids = compute_position_id_with_mask(attention_mask) - - data = DataProto( - batch=TensorDict( - {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}, - batch_size=batch_size, - ), - meta_info={}, - ) - - output = worker_group.train_model(data) - - print(output) diff --git a/tests/single_controller/detached_worker/run.sh b/tests/single_controller/detached_worker/run.sh deleted file mode 100644 index a3c638793..000000000 --- a/tests/single_controller/detached_worker/run.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash -ray start --head --port=6379 -python3 server.py -python3 client.py -ray stop --force \ No newline at end of file diff --git a/tests/single_controller/detached_worker/server.py b/tests/single_controller/detached_worker/server.py deleted file mode 100644 index 57e555a3a..000000000 --- a/tests/single_controller/detached_worker/server.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Server starts a Trainer. Client sends data to the server to train. -""" - -import os - -os.environ["MEGATRON_USE_CUDA_TIMER"] = "0" -os.environ["MEGATRON_START_PROCESS_TIMER"] = "False" -os.environ["NCCL_DEBUG"] = "WARN" - -import ray -import torch -from megatron.core import parallel_state as mpu -from megatron.core import tensor_parallel -from megatron.core.models.gpt.gpt_model import ModelType -from omegaconf import OmegaConf -from tensordict import TensorDict -from torch import nn -from transformers import LlamaConfig - -from verl import DataProto -from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.base.megatron.worker import MegatronWorker -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool -from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup -from verl.utils.megatron.optimizer import get_megatron_optimizer -from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config - - -@ray.remote -class Trainer(MegatronWorker): - def __init__(self): - super().__init__() - - if not torch.distributed.is_initialized(): - rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group(backend="nccl") - torch.cuda.set_device(rank) - - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - mpu.initialize_model_parallel( - tensor_model_parallel_size=2, - pipeline_model_parallel_size=1, - virtual_pipeline_model_parallel_size=None, - pipeline_model_parallel_split_rank=None, - use_sharp=False, - context_parallel_size=1, - expert_model_parallel_size=1, - nccl_communicator_config_path=None, - ) - tensor_parallel.model_parallel_cuda_manual_seed(10) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - actor_model_config = LlamaConfig( - vocab_size=256, - hidden_size=2048, - intermediate_size=5504, - num_hidden_layers=24, - num_attention_heads=16, - num_key_value_heads=16, - ) - - megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16) - self.megatron_config = megatron_config - - def megatron_actor_model_provider(pre_process, post_process): - # vpp is not supported yet because it will hang for some reason. Need debugging - # this_megatron_config = copy.deepcopy(megatron_config) - # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank - parallel_model = ParallelLlamaForCausalLMRmPadPP( - config=actor_model_config, - megatron_config=megatron_config, - pre_process=pre_process, - post_process=post_process, - ) - parallel_model.cuda() - return parallel_model - - actor_module = get_model( - model_provider_func=megatron_actor_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - ) - actor_module = nn.ModuleList(actor_module) - - optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0}) - - optim_config = init_megatron_optim_config(optim_config) - self.optimizer_config = optim_config - actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config) - - self.model = actor_module[0] - self.optimizer = actor_optimizer - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - def train_model(self, data: DataProto) -> DataProto: - input_ids = data.batch["input_ids"] - attention_mask = data.batch["attention_mask"] - position_ids = data.batch["position_ids"] - - self.optimizer.zero_grad() - self.model.zero_grad_buffer( - zero_buffer=(not self.optimizer_config.use_distributed_optimizer) - ) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm - # update for 1 iteration - output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits - output.mean().backward() - - update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step( - self.megatron_config, self.megatron_config.timers - ) - - return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0])) - - -if __name__ == "__main__": - ray.init(address="auto", namespace="verl") - - resource_pool = RayResourcePool(process_on_nodes=[2], detached=True) - cls_with_init_args = RayClassWithInitArgs(cls=Trainer) - worker_group = NVMegatronRayWorkerGroup( - resource_pool=resource_pool, - ray_cls_with_init=cls_with_init_args, - name_prefix="trainer", - detached=True, - ) - - worker_group.init_model() - - worker_names = worker_group.worker_names - print(worker_names) diff --git a/tests/single_controller/test_auto_padding_on_cpu.py b/tests/single_controller/test_auto_padding_on_cpu.py deleted file mode 100644 index f2c441243..000000000 --- a/tests/single_controller/test_auto_padding_on_cpu.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import ray -import torch - -from verl import DataProto -from verl.protocol import DataProtoConfig -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup - -# or set env var VERL_AUTO_PADDING = "1" / "true" -DataProtoConfig.auto_padding = True - - -@ray.remote -class Actor(Worker): - def __init__(self) -> None: - super().__init__() - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def add(self, data: DataProto): - data.batch["a"] += self.rank - return data - - -def test_auto_padding(): - ray.init(num_cpus=100) - - chunk_size = 4 - actor_cls = RayClassWithInitArgs(cls=Actor) - resource_pool = RayResourcePool(process_on_nodes=[chunk_size], use_gpu=False) - actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) - - # test locally first - for test_size in range(4, 20): - local_data = DataProto.from_dict({"a": torch.zeros(test_size)}, {"na": np.zeros(test_size, dtype=object)}) - # print(f"before padding, local_data = {local_data}") - padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0 - local_data.padding(padding_size) - # print(f"after padding, local_data = {local_data}") - assert len(local_data) == len(local_data) + len(local_data) % chunk_size, ( - f"expecting padded length to be {len(local_data) + len(local_data) % chunk_size}, but got {len(local_data)}" - ) - chunked = local_data.chunk(chunk_size) - assert len(chunked) == chunk_size, f"during test_size = {test_size}, expecting {chunk_size}, got {chunked}" - for dp in chunked: - assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), ( - f"test size = {test_size}, expecting dp to be length of " - f"{test_size // chunk_size + bool(test_size % chunk_size)}, but got {len(dp)}: {dp} {chunked}" - ) - - # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO - data = DataProto.from_dict({"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}) - output = actor_wg.add(data) - - print(output.batch["a"]) - assert len(output) == 10 - - data = DataProto.from_dict({"a": torch.zeros(1)}, {"na": np.array([str(i) for i in range(1)], dtype=object)}) - output = actor_wg.add(data) - - print(output.batch["a"]) - assert len(output) == 1 - - data = DataProto.from_dict({"a": torch.zeros(8)}, {"na": np.array([str(i) for i in range(8)], dtype=object)}) - output = actor_wg.add(data) - - print(output.batch["a"]) - assert len(output) == 8 - - # test data proto specific config - DataProtoConfig.auto_padding = False - - data = DataProto.from_dict( - {"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True - ) - output = actor_wg.add(data) - print(output.batch["a"]) - assert len(output) == 10 - - data = DataProto.from_single_dict( - {"a": torch.zeros(1), "na": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True - ) - output = actor_wg.add(data) - - print(output.batch["a"]) - assert len(output) == 1 - - data = DataProto.from_single_dict({"a": torch.zeros(8), "na": np.array([str(i) for i in range(8)], dtype=object)}) - output = actor_wg.add(data) - - print(output.batch["a"]) - assert len(output) == 8 - - ray.shutdown() - - -if __name__ == "__main__": - test_auto_padding() diff --git a/tests/single_controller/test_colocated_workers.py b/tests/single_controller/test_colocated_workers.py deleted file mode 100644 index cdaa74768..000000000 --- a/tests/single_controller/test_colocated_workers.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import ray - -from verl import DataProto -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray.base import ( - RayClassWithInitArgs, - RayResourcePool, - RayWorkerGroup, - create_colocated_worker_cls, -) - - -@ray.remote -class Actor(Worker): - def __init__(self) -> None: - super().__init__() - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def add(self, data: DataProto): - data.batch["a"] += self.rank - return data - - -@ray.remote -class Critic(Worker): - def __init__(self, config) -> None: - super().__init__() - self.config = config - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - async def sub(self, data: DataProto): - data.batch["a"] -= self.config["b"] - return data - - -def test_colocated_workers(): - ray.init() - - import torch - - data = DataProto.from_dict({"a": torch.zeros(10)}) - # create separate workers on the same resource pool - actor_cls = RayClassWithInitArgs(cls=Actor) - critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10}) - resource_pool = RayResourcePool(process_on_nodes=[2]) - - actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) - critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls) - - expected_actor_output = actor_wg.add(data) - expected_critic_output = critic_wg.sub(data) - - # create colocated workers - cls_dict = {"actor": actor_cls, "critic": critic_cls} - ray_cls_with_init = create_colocated_worker_cls(cls_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) - spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys()) - - colocated_actor_wg = spawn_wg["actor"] - colocated_critic_wg = spawn_wg["critic"] - - actor_output = colocated_actor_wg.add(data) - critic_output = colocated_critic_wg.sub(data) - - torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0) - torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0) - - ray.shutdown() diff --git a/tests/single_controller/test_colocated_workers_fused.py b/tests/single_controller/test_colocated_workers_fused.py deleted file mode 100644 index 93b1a728e..000000000 --- a/tests/single_controller/test_colocated_workers_fused.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import ray - -from verl import DataProto -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray.base import ( - RayClassWithInitArgs, - RayResourcePool, - RayWorkerGroup, - create_colocated_worker_cls_fused, -) - - -@ray.remote -class Actor(Worker): - def __init__(self) -> None: - super().__init__() - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def add(self, data: DataProto): - data.batch["a"] += self.rank - return data - - -@ray.remote -class Critic(Worker): - def __init__(self, config) -> None: - super().__init__() - self.config = config - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def sub(self, data: DataProto): - data.batch["a"] -= self.config["b"] - return data - - -def test_colocated_workers_fused(): - ray.init() - - import torch - - data = DataProto.from_dict({"a": torch.zeros(10)}) - # create separate workers on the same resource pool - actor_cls = RayClassWithInitArgs(cls=Actor) - critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10}) - resource_pool = RayResourcePool(process_on_nodes=[2]) - - actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls) - critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls) - - expected_actor_output = actor_wg.add(data) - expected_critic_output = critic_wg.sub(data) - - # create colocated workers - cls_dict = {"actor": actor_cls, "critic": critic_cls} - ray_cls_with_init = create_colocated_worker_cls_fused(cls_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) - spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys()) - - colocated_actor_wg = spawn_wg["actor"] - colocated_critic_wg = spawn_wg["critic"] - - actor_output = colocated_actor_wg.add(data) - critic_output = colocated_critic_wg.sub(data) - - torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0) - torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0) - - ray.shutdown() diff --git a/tests/single_controller/test_data_transfer.py b/tests/single_controller/test_data_transfer.py deleted file mode 100644 index 13777b0bd..000000000 --- a/tests/single_controller/test_data_transfer.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -In this test, we instantiate a data parallel worker with 8 GPUs -""" - -import ray -import tensordict -import torch -from codetiming import Timer -from torch import distributed as dist - -from verl import DataProto -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup -from verl.utils.ray_utils import parallel_put - - -@ray.remote -class DummyWorker(Worker): - def __init__(self): - super().__init__() - dist.init_process_group() - - @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) - def do_nothing(self, data): - for key in data.batch.keys(): - data.batch[key] += 1 - if tensordict.__version__ >= "0.5.0": - data.batch = data.batch.consolidate() - return data - - -def test_data_transfer(): - ray.init() - # construct resource pool - resource_pool = RayResourcePool([8]) - cls_with_init = RayClassWithInitArgs(cls=DummyWorker) - # construct worker group - wg = RayWorkerGroup(resource_pool, cls_with_init) - - # this is real dataset size - batch_size = 4096 - seqlen = 32768 - - data_dict = {} - - for i in range(2): - data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen)) - - data = DataProto.from_dict(tensors=data_dict) - - print(data) - - # we manually split data here and send to each worker - data_list = data.chunk(wg.world_size) - - for i in range(wg.world_size): - # consolidate is necessary - if tensordict.__version__ >= "0.5.0": - data_list[i].batch = data_list[i].batch.consolidate() - - with Timer(name="ray.pickle", initial_text=True): - for i in range(wg.world_size): - ray.cloudpickle.pickle.dumps(data_list[i]) - - with Timer(name="raw.pickle", initial_text=True): - import pickle - - for i in range(wg.world_size): - pickle.dumps(data_list[i]) - - # we put in advance - with Timer(name="put", initial_text=True): - # takes around 40 seconds - data_list_ref = parallel_put(data_list) - # for i in range(wg.world_size): - # data_list[i] = ray.put(data_list[i]) - - with Timer(name="launch", initial_text=True): - output_ref = wg.do_nothing(data_list_ref) - - with Timer(name="get", initial_text=True): - # takes around 40 seconds - output_lst = ray.get(output_ref) - - for input_data, output_data in zip(data_list, output_lst, strict=True): - for key in input_data.batch.keys(): - assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), ( - input_data.batch[key], - output_data.batch[key], - key, - ) - - ray.shutdown() diff --git a/tests/single_controller/test_decorator_on_cpu.py b/tests/single_controller/test_decorator_on_cpu.py deleted file mode 100644 index 4dfec6331..000000000 --- a/tests/single_controller/test_decorator_on_cpu.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import time - -import pytest -import ray -import torch -from tensordict import TensorDict - -from verl.protocol import DataProto, DataProtoFuture -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.base.worker import Worker -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup - - -# Pytest fixture for Ray setup/teardown -@pytest.fixture -def ray_init_shutdown(): - ray.init(num_cpus=100) - yield - ray.shutdown() - - -# Define a simple worker for testing -@ray.remote -class DecoratorTestWorker(Worker): - def __init__(self, initial_value=0): - super().__init__() - self.value = initial_value - # Simulate some setup if needed - time.sleep(0.1) # Ensure worker init completes - - # Test method for synchronous DP compute (default behavior) - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def dp_compute(self, data: DataProto) -> DataProto: - time.sleep(0.1) # Simulate work - rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype) - data.batch["output"] = data.batch["input"] + self.value + rank_value - return data - - # Test async def method with DP compute (default behavior) - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) - async def async_dp_compute(self, data: DataProto) -> DataProto: - # Simulate async work - await asyncio.sleep(0.1) # Simulate async work - rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype) - data.batch["output_async"] = data.batch["input"] * 2 + self.value + rank_value - return data - - -# Test function for synchronous DP compute -def test_decorator_dp_compute(ray_init_shutdown): - """ - Tests the default behavior of a synchronous decorated method with DP_COMPUTE_PROTO. - Verifies the result correctness. - """ - num_workers = 2 - resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1) # Use CPU for simplicity - cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10) - worker_group = RayWorkerGroup( - resource_pool, cls_with_args, name_prefix=f"decorator_test_sync_dp_{int(time.time())}" - ) - - # Prepare input data (size 4, for 2 workers) - input_tensor = torch.arange(4, dtype=torch.float32) - data = DataProto(batch=TensorDict({"input": input_tensor}, batch_size=[4])) - - # Call the decorated method - output = worker_group.dp_compute(data) - - # Assert the result correctness - assert isinstance(output, DataProto), "Expected DataProto result" - assert "output" in output.batch.keys() - assert len(output) == len(data), "Output length should match input length" - - # Expected output calculation for DP_COMPUTE_PROTO with 2 workers - # Worker 0 gets data[0:2], Worker 1 gets data[2:4] - # Worker 0 adds initial_value(10) + rank(0) = 10 - # Worker 1 adds initial_value(10) + rank(1) = 11 - expected_output_part1 = torch.tensor([0, 1], dtype=torch.float32) + 10 + 0 - expected_output_part2 = torch.tensor([2, 3], dtype=torch.float32) + 10 + 1 - expected_output = torch.cat([expected_output_part1, expected_output_part2]) - - torch.testing.assert_close(output.batch["output"], expected_output, msg="Sync DP compute output data mismatch") - - -# Test function for async def method with DP compute -def test_decorator_async_function(ray_init_shutdown): - """ - Tests the decorator with an `async def` method using DP_COMPUTE_PROTO. - Verifies that the call returns a future and the result is correct after .get(). - """ - num_workers = 2 - resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1) - cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5) - worker_group = RayWorkerGroup( - resource_pool, cls_with_args, name_prefix=f"decorator_test_async_dp_{int(time.time())}" - ) - - # Prepare input data (size 4, for 2 workers) - input_tensor = torch.arange(4, dtype=torch.float32) - data = DataProto(batch=TensorDict({"input": input_tensor}, batch_size=[4])) - - # Call the async decorated method - this should return a future - future_output: DataProtoFuture = worker_group.async_dp_compute(data) - - # Assert that the call returned a future - assert isinstance(future_output, DataProtoFuture), "Expected DataProtoFuture for async def call" - - # Get the result (this should block) - result_data = future_output.get() - - # Assert the result correctness - assert isinstance(result_data, DataProto) - assert "output_async" in result_data.batch.keys() - assert len(result_data) == len(data), "Output length should match input length" - - # Expected output calculation for DP_COMPUTE_PROTO with 2 workers - # Worker 0 gets data[0:2], Worker 1 gets data[2:4] - # Worker 0 calculates: input * 2 + initial_value(5) + rank(0) - # Worker 1 calculates: input * 2 + initial_value(5) + rank(1) - expected_output_part1 = (torch.tensor([0, 1], dtype=torch.float32) * 2) + 5 + 0 - expected_output_part2 = (torch.tensor([2, 3], dtype=torch.float32) * 2) + 5 + 1 - expected_output = torch.cat([expected_output_part1, expected_output_part2]) - - torch.testing.assert_close( - result_data.batch["output_async"], expected_output, msg="Async DP compute output data mismatch" - ) diff --git a/tests/single_controller/test_driverfunc_to_worker.py b/tests/single_controller/test_driverfunc_to_worker.py deleted file mode 100644 index a38d790d6..000000000 --- a/tests/single_controller/test_driverfunc_to_worker.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import ray -import torch -from tensordict import TensorDict - -from verl import DataProto -from verl.single_controller.base.worker import Worker -from verl.single_controller.ray import RayWorkerGroup -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool - -os.environ["RAY_DEDUP_LOGS"] = "0" -os.environ["NCCL_DEBUG"] = "WARN" - - -@ray.remote -class ModelActor(Worker): - def __init__(self): - pass - - -class HackSelf: - def __init__(self): - pass - - -def get_aux_metrics(self, test_proto): - sequence_ids = test_proto.batch["sequence_ids"] - decode_count = [] - for i in range(sequence_ids.size(0)): - decode_count.append(len(sequence_ids[i].tolist())) - ret_proto = DataProto( - batch=TensorDict( - {"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0) - ) - ) - return ret_proto - - -def test(): - # construct model - ray.init() - - # create 2 workers, each hold a GPU - resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a") - - class_with_args = RayClassWithInitArgs(cls=ModelActor) - shard_wg = RayWorkerGroup(resource_pool, class_with_args) - - test_bs = 8 - test_proto = DataProto( - TensorDict( - { - "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64), - }, - batch_size=test_bs, - ), - meta_info={"query_length": 1536}, - ) - - # Sharding among different ranks - ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto) - - # compare execute on driver - hs = HackSelf() - ret_proto2 = get_aux_metrics(hs, test_proto) - - torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"]) - - ray.shutdown() diff --git a/tests/single_controller/test_fused_workers_on_cpu.py b/tests/single_controller/test_fused_workers_on_cpu.py deleted file mode 100644 index 527ddc102..000000000 --- a/tests/single_controller/test_fused_workers_on_cpu.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import ray - -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray.base import ( - RayClassWithInitArgs, - RayResourcePool, - RayWorkerGroup, - create_colocated_worker_raw_cls, -) - - -@ray.remote -class Actor(Worker): - def __init__(self) -> None: - super().__init__() - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def add(self, x): - x += self.rank - return x - - -@ray.remote -class Critic(Worker): - def __init__(self, val) -> None: - super().__init__() - self.val = val - - @register(dispatch_mode=Dispatch.ALL_TO_ALL) - def sub(self, x): - x -= self.val - return x - - -actor_cls = RayClassWithInitArgs(cls=Actor) -critic_cls = RayClassWithInitArgs(cls=Critic, val=10) -cls_dict = {"actor": actor_cls, "critic": critic_cls} -FusedBaseClass = create_colocated_worker_raw_cls(cls_dict) - - -@ray.remote -class HybridWorker(FusedBaseClass): - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def foo(self, x): - return self.critic.sub(self.actor.add(x)) - - -def test_fused_workers(): - ray.init(num_cpus=100) - - # create separate workers on the same resource pool - process_on_nodes = [2] - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=False) - - # create colocated workers - hybrid_cls_with_init = RayClassWithInitArgs(cls=HybridWorker) - hybrid_cls_with_init.fused_worker_used = True - - fused_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init) - fused_wg.fuse(cls_dict.keys()) - - x = fused_wg.actor.add(0.1) - print(x) - y = fused_wg.critic.sub(x) - print(y) - z = fused_wg.foo(0.1) - print(z) - for i, j in zip(y, z, strict=True): - assert i == j - - ray.shutdown() - - -if __name__ == "__main__": - test_fused_workers() diff --git a/tests/single_controller/test_high_level_scheduling_api.py b/tests/single_controller/test_high_level_scheduling_api.py deleted file mode 100644 index 52cc7c7df..000000000 --- a/tests/single_controller/test_high_level_scheduling_api.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import ray - -from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool - - -@ray.remote -class TestActor(Worker): - # TODO: pass *args and **kwargs is bug prone and not very convincing - def __init__(self, cuda_visible_devices=None) -> None: - super().__init__(cuda_visible_devices) - - def get_node_id(self): - return ray.get_runtime_context().get_node_id() - - -def test(): - ray.init() - - # test single-node-no-partition - print("test single-node-no-partition") - resource_pool = RayResourcePool([8], use_gpu=True) - - class_with_args = RayClassWithInitArgs(cls=TestActor) - - print("create actor worker group") - actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor") - print("create critic worker group") - critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="hight_level_api_critic") - print("create rm worker group") - rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm") - print("create ref worker group") - ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref") - - assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - - del actor_wg - del critic_wg - del rm_wg - del ref_wg - - [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()] - print("wait 5s to remove placemeng_group") - time.sleep(5) - # test single-node-multi-partition - - print("test single-node-multi-partition") - rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm") - ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref") - total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool) - - assert rm_resource_pool.world_size == 4 - assert ref_resource_pool.world_size == 4 - assert total_resource_pool.world_size == 8 - - actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor") - critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic") - rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm") - ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref") - - assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] - assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)] - assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)] - - ray.shutdown() diff --git a/tests/single_controller/test_ray_collectives.py b/tests/single_controller/test_ray_collectives.py deleted file mode 100644 index 3722a8f80..000000000 --- a/tests/single_controller/test_ray_collectives.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Test for using ray collective group. -Suppose we Actor and Rollout. Actor contains 4 workers and Rollout contains 2 workers. We established a Worker to -Rollout relationship by using collective groups -Actor: rank 0, 1 - Rollout rank 0 -Rollout rank 2, 3 - Rollout rank 1 -Then, we initiate 4 p2p comms from actor to rollout -""" - -import ray -import ray.util.collective as collective -import torch - -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup - - -@ray.remote -class Actor(Worker): - @register(Dispatch.ONE_TO_ALL) - def init(self): - remote_rank = self.rank // 2 - self.group_name = f"A{self.rank}_R{remote_rank}" - collective.init_collective_group(world_size=2, rank=0, backend="nccl", group_name=self.group_name) - - @register(Dispatch.ONE_TO_ALL, blocking=False) - def send_tensors(self): - tensor = torch.ones(size=(4,), dtype=torch.float32, device="cuda") * self.rank - collective.send(tensor=tensor, dst_rank=1, group_name=self.group_name) - - -@ray.remote -class Rollout(Worker): - @register(Dispatch.ONE_TO_ALL) - def init(self): - self.remote_first_rank = self.rank * 2 - self.remote_second_rank = self.remote_first_rank + 1 - self.first_group_name = f"A{self.remote_first_rank}_R{self.rank}" - self.second_group_name = f"A{self.remote_second_rank}_R{self.rank}" - - collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.first_group_name) - collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.second_group_name) - - @register(Dispatch.ONE_TO_ALL, blocking=False) - def receive_tensors(self): - self.tensor1 = torch.randn(size=(4,), dtype=torch.float32, device="cuda") - self.tensor2 = torch.randn(size=(4,), dtype=torch.float32, device="cuda") - - collective.recv(self.tensor1, src_rank=0, group_name=self.first_group_name) - collective.recv(self.tensor2, src_rank=0, group_name=self.second_group_name) - - @register(Dispatch.ONE_TO_ALL) - def get_tensors(self): - return {f"src_{self.remote_first_rank}": self.tensor1, f"src_{self.remote_second_rank}": self.tensor2} - - -def test_ray_collective_group(): - ray.init() - - actor_resource_pool = RayResourcePool([4]) - rollout_resource_pool = RayResourcePool([2]) - - actor_cls = RayClassWithInitArgs(cls=Actor) - rollout_cls = RayClassWithInitArgs(cls=Rollout) - - actor_wg = RayWorkerGroup( - resource_pool=actor_resource_pool, ray_cls_with_init=actor_cls, name_prefix="collective_group_actor" - ) - rollout_wg = RayWorkerGroup( - resource_pool=rollout_resource_pool, ray_cls_with_init=rollout_cls, name_prefix="collective_group_rollout" - ) - - actor_wg.init() - rollout_wg.init() - - out1 = actor_wg.send_tensors() - out2 = rollout_wg.receive_tensors() - - # block to wait - ray.get(out1) - ray.get(out2) - - output = rollout_wg.get_tensors() - - rollout_0_output = output[0] - rollout_1_output = output[1] - - output = rollout_0_output | rollout_1_output - - print(output) - - for i in range(4): - assert torch.sum(output[f"src_{i}"]).item() == 4 * i - - ray.shutdown() - - -if __name__ == "__main__": - test_ray_collective_group() diff --git a/tests/single_controller/test_ray_local_envs_on_cpu.py b/tests/single_controller/test_ray_local_envs_on_cpu.py deleted file mode 100644 index ee6c0cbed..000000000 --- a/tests/single_controller/test_ray_local_envs_on_cpu.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -e2e test verl.single_controller.ray -""" - -import os - -import ray - -from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup - - -@ray.remote -class TestActor(Worker): - def __init__(self) -> None: - super().__init__() - - def getenv(self, key): - val = os.getenv(key, f"{key} not set") - return val - - -def test_basics(): - ray.init(num_cpus=100) - - # create 4 workers, each hold a GPU - resource_pool = RayResourcePool([4], use_gpu=False) - class_with_args = RayClassWithInitArgs(cls=TestActor) - - worker_group = RayWorkerGroup( - resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic" - ) - - output = worker_group.execute_all_sync("getenv", key="RAY_LOCAL_WORLD_SIZE") - assert output == ["4", "4", "4", "4"] - - output = worker_group.execute_all_sync("getenv", key="RAY_LOCAL_RANK") - assert set(output) == set(["0", "1", "2", "3"]) - - ray.shutdown() - - -if __name__ == "__main__": - test_basics() diff --git a/tests/single_controller/test_ray_utils_on_cpu.py b/tests/single_controller/test_ray_utils_on_cpu.py deleted file mode 100644 index e36497d21..000000000 --- a/tests/single_controller/test_ray_utils_on_cpu.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import ray - -from verl.utils.ray_utils import parallel_put - - -# Initialize Ray for testing if not already done globally -@pytest.fixture() -def init_ray(): - ray.init(num_cpus=4) - yield - ray.shutdown() - - -def test_parallel_put_basic(init_ray): - data = [1, "hello", {"a": 2}, [3, 4]] - refs = parallel_put(data) - assert len(refs) == len(data) - retrieved_data = [ray.get(ref) for ref in refs] - assert retrieved_data == data - - -def test_parallel_put_empty(init_ray): - data = [] - with pytest.raises(AssertionError): - _ = parallel_put(data) - - -def test_parallel_put_workers(init_ray): - data = list(range(20)) - # Test with specific number of workers - refs = parallel_put(data, max_workers=4) - assert len(refs) == len(data) - retrieved_data = [ray.get(ref) for ref in refs] - assert retrieved_data == data - # Test with default workers (should cap) - refs_default = parallel_put(data) - assert len(refs_default) == len(data) - retrieved_data_default = [ray.get(ref) for ref in refs_default] - assert retrieved_data_default == data diff --git a/tests/single_controller/test_rvdz.py b/tests/single_controller/test_rvdz.py deleted file mode 100644 index 7dea12f95..000000000 --- a/tests/single_controller/test_rvdz.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import ray - - -@ray.remote -class TestWorker: - def __init__(self, rank, world_size, group_name): - self.rank = rank - self.world_size = world_size - self.group_name = group_name - self.communicator = None - - def init(self): - from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray - - self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name) - - def test(self): - if self.communicator is None: - return None - return self.communicator.rank_id() - - -def test_rvdz(): - ray.init() - - group_name = "test_group" - world_size = 2 - - workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)] - - ray.get([worker.init.remote() for worker in workers]) - - ranks = ray.get([worker.test.remote() for worker in workers]) - - assert ranks == [0, 1], f"expecting [0, 1], got {ranks}" - - ray.shutdown() diff --git a/tests/single_controller/test_worker_group_basics.py b/tests/single_controller/test_worker_group_basics.py deleted file mode 100644 index 5c4823dfb..000000000 --- a/tests/single_controller/test_worker_group_basics.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -e2e test verl.single_controller.ray -""" - -import ray -import torch - -from verl.single_controller.base.decorator import Dispatch, Execute, collect_all_to_all, register -from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup - - -def two_to_all_dispatch_fn(worker_group, *args, **kwargs): - """ - Assume the input is a list of 2. Duplicate the input interleaved and pass to each worker. - """ - for arg in args: - assert len(arg) == 2 - for i in range(worker_group.world_size - 2): - arg.append(arg[i % 2]) - for k, v in kwargs.items(): - assert len(v) == 2 - for i in range(worker_group.world_size - 2): - v.append(v[i % 2]) - return args, kwargs - - -@ray.remote -class TestActor(Worker): - # TODO: pass *args and **kwargs is bug prone and not very convincing - def __init__(self, x) -> None: - super().__init__() - self._x = x - - def foo(self, y): - return self._x + y - - @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) - def foo_rank_zero(self, x, y): - return self._x + y + x - - @register(Dispatch.ONE_TO_ALL, blocking=False) - def foo_one_to_all(self, x, y): - return self._x + y + x - - @register(Dispatch.ALL_TO_ALL, blocking=False) - def foo_all_to_all(self, x, y): - return self._x + y + x - - @register(dispatch_mode={"dispatch_fn": two_to_all_dispatch_fn, "collect_fn": collect_all_to_all}) - def foo_custom(self, x, y): - return self._x + y + x - - -@ray.remote(num_gpus=0.1) -def remote_call_wg(worker_names): - class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) - worker_group = RayWorkerGroup.from_detached( - worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None - ) - print(worker_group.worker_names) - - output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6]) - assert output_ref == [8, 10, 8, 10] - - output_ref = worker_group.foo_rank_zero(x=1, y=2) - assert output_ref == 5 - - return worker_group.worker_names - - -def add_one(data): - data = data.to("cuda") - data += 1 - data = data.to("cpu") - return data - - -def test_basics(): - ray.init(num_cpus=100) - - # create 4 workers, each hold a GPU - resource_pool = RayResourcePool([4], use_gpu=True) - class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) - - worker_group = RayWorkerGroup( - resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic" - ) - - print(worker_group.worker_names) - - # this will wait for all the results - output = worker_group.execute_all_sync("foo", y=3) - assert output == [5, 5, 5, 5] - - # this is a list of object reference. It won't block. - output_ref = worker_group.execute_all_async("foo", y=4) - print(output_ref) - - assert ray.get(output_ref) == [6, 6, 6, 6] - - output_ref = worker_group.foo_one_to_all(x=1, y=2) - assert ray.get(output_ref) == [5, 5, 5, 5] - - output_ref = worker_group.foo_all_to_all(x=[1, 2, 3, 4], y=[5, 6, 7, 8]) - assert ray.get(output_ref) == [8, 10, 12, 14] - - print(ray.get(remote_call_wg.remote(worker_group.worker_names))) - - output = worker_group.execute_func_rank_zero(add_one, torch.ones(2, 2)) - torch.testing.assert_close(output, torch.ones(2, 2) + 1) - - ray.shutdown() - - -if __name__ == "__main__": - test_basics() diff --git a/tests/single_controller/test_worker_group_torch.py b/tests/single_controller/test_worker_group_torch.py deleted file mode 100644 index a601c43da..000000000 --- a/tests/single_controller/test_worker_group_torch.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -os.environ["RAY_DEDUP_LOGS"] = "0" -os.environ["NCCL_DEBUG"] = "WARN" - -import ray -import torch -import torch.distributed - -from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup - - -@ray.remote -class TestAllGatherActor(Worker): - def __init__(self, size) -> None: - super().__init__() - self.size = size - - def init(self): - torch.distributed.init_process_group() - self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda") - self.tensor += self.rank - - def all_gather(self): - world_size = self._world_size - output = torch.zeros( - size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device - ) - torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) - return output - - -@ray.remote -class TestAllGatherActorV2(Worker): - def __init__(self, size) -> None: - super().__init__() - self.size = size - - torch.distributed.init_process_group() - self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device="cuda") - self.tensor += self.rank - - def all_gather(self): - world_size = self._world_size - output = torch.zeros( - size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device - ) - torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False) - return output - - -def test_all_gather_torch(): - """ - In this test, we instantiate 4 GPUs in a group and test the all_gather - """ - ray.init() - - # create 4 workers, each hold a GPU - resource_pool = RayResourcePool([4], use_gpu=True) - class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2) - - worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch") - - worker_group.execute_all_sync("init") - output = worker_group.execute_all_sync("all_gather") - for i in range(1, len(output)): - assert torch.all(output[i] == output[0]) - - output = output[0].cpu() - print(output) - assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64)) - - ray.shutdown() - - -def test_all_gather_torch_v2(): - """ - In this test, we instantiate 4 GPUs in a group and test the all_gather - """ - ray.init() - - # create 4 workers, each hold a GPU - resource_pool = RayResourcePool([4], use_gpu=True) - class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2) - - worker_group = RayWorkerGroup(resource_pool, class_with_args, name_prefix="worker_group_torch") - - output = worker_group.execute_all_sync("all_gather") - for i in range(1, len(output)): - assert torch.all(output[i] == output[0]) - - output = output[0].cpu() - print(output) - assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64)) - - ray.shutdown() diff --git a/tests/special_distributed/README.md b/tests/special_distributed/README.md deleted file mode 100644 index f2f865e8b..000000000 --- a/tests/special_distributed/README.md +++ /dev/null @@ -1 +0,0 @@ -This folder is reserved for unit tests (instead of end-to-end tests) that require multiple GPUs. diff --git a/tests/special_distributed/run_all.sh b/tests/special_distributed/run_all.sh deleted file mode 100644 index c34edf222..000000000 --- a/tests/special_distributed/run_all.sh +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/usr/bin/env bash - -set -e -x -torchrun --nproc-per-node=4 --standalone tests/special_distributed/test_tensor_dict.py \ No newline at end of file diff --git a/tests/special_distributed/test_fsdp_ckpt.py b/tests/special_distributed/test_fsdp_ckpt.py deleted file mode 100644 index 49dceb7c1..000000000 --- a/tests/special_distributed/test_fsdp_ckpt.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import shutil -import tempfile - -import torch -import torch.distributed -from torch.distributed import init_device_mesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision, ShardingStrategy -from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config - -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.distributed import initialize_global_process_group -from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2 - - -def test_fsdp_ckpt(strategy="fsdp"): - assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test" - local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) - - model_name = "Qwen/Qwen2.5-0.5B-Instruct" - config = Qwen2Config(num_hidden_layers=1) - - with torch.device("cuda"): - model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) - model = model.to(device="cuda") - - # Wrap model with FSDP - if strategy == "fsdp": - mixed_precision = MixedPrecision( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 - ) - - model = FSDP( - model, - use_orig_params=False, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - device_mesh=device_mesh, - ) - else: - mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True - ) - fsdp_kwargs = { - "mesh": device_mesh, - "mp_policy": mp_policy, - } - apply_fsdp2(model, fsdp_kwargs, {}) - - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) - - # Create checkpoint manager - tokenizer = AutoTokenizer.from_pretrained(model_name) - checkpoint_manager = FSDPCheckpointManager( - model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer - ) - - # Generate sample input - batch_size = 2 - seq_len = 32 - vocab_size = 32000 - # First input for initial update - input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") - attention_mask1 = torch.ones_like(input_ids1) - - # Second input for verification - input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") - attention_mask2 = torch.ones_like(input_ids2) - - # Step 1: Initial update and save checkpoint - outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1) - loss1 = outputs1.logits.mean() - loss1.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Save checkpoint after first update - temp_dir = tempfile.mkdtemp() - checkpoint_path = os.path.join(temp_dir, "checkpoint") - checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0) - - # Step 2: Second update and forward pass - outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2) - loss2 = outputs2.logits.mean() - loss2.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Record logits after second update - with torch.no_grad(): - logits_before_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits - - # Step 3: Load checkpoint and repeat second update - checkpoint_manager.load_checkpoint(checkpoint_path) - - # Repeat the second update with same input - outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2) - loss3 = outputs3.logits.mean() - loss3.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Record logits after loaded checkpoint and update - with torch.no_grad(): - logits_after_load = model(input_ids=input_ids2, attention_mask=attention_mask2).logits - - # Step 4: Verify outputs match - torch.testing.assert_close(logits_before_load, logits_after_load, atol=0.0, rtol=0.0) - print("Checkpoint save/load test passed!") - - # Cleanup - shutil.rmtree(temp_dir) - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - strategy = os.environ.get("STRATEGY", "fsdp") - test_fsdp_ckpt(strategy=strategy) diff --git a/tests/special_distributed/test_tensor_dict.py b/tests/special_distributed/test_tensor_dict.py deleted file mode 100644 index 0a7f8039d..000000000 --- a/tests/special_distributed/test_tensor_dict.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -os.environ["NCCL_DEBUG"] = "WARN" - -import numpy as np -import torch -import torch.distributed - -from verl.protocol import DataProto, all_gather_data_proto -from verl.utils.distributed import initialize_global_process_group - - -def test_all_gather_data_proto(): - device_mesh = torch.distributed.device_mesh.init_device_mesh("cuda", mesh_shape=[2, 2], mesh_dim_names=["dp", "tp"]) - - global_rank = torch.distributed.get_rank() - - obs = torch.tensor([[1 * global_rank, 2 * global_rank + 1], [3 * global_rank, 4 * global_rank + 1]]) - - labels = ["a", "b"] if global_rank % 2 == 0 else ["b", "a"] - labels = np.array(labels, dtype=object) - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - - all_gather_data_proto(data=data, process_group=device_mesh.get_group("dp")) - - if global_rank == 0: - expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda") - expected_labels = ["a", "b", "a", "b"] - elif global_rank == 1: - expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda") - expected_labels = ["b", "a", "b", "a"] - elif global_rank == 2: - expected_obs = torch.tensor([[0, 1], [0, 1], [2, 5], [6, 9]], device="cuda") - expected_labels = ["a", "b", "a", "b"] - elif global_rank == 3: - expected_obs = torch.tensor([[1, 3], [3, 5], [3, 7], [9, 13]], device="cuda") - expected_labels = ["b", "a", "b", "a"] - - torch.testing.assert_close(data.batch["obs"], expected_obs, atol=0, rtol=0) - assert (data.non_tensor_batch["labels"] == expected_labels).all() - assert data.meta_info == {"info": "test_info"} - - -def test_vocab_parallel_entropy(): - from megatron.core import parallel_state as mpu - - from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy - from verl.utils.profiler import log_gpu_memory_usage - from verl.utils.torch_functional import entropy_from_logits - - mpu.initialize_model_parallel( - tensor_model_parallel_size=2, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None - ) - - batch_size = 2 - seqlen = 128 - vocab_size = 155136 - - logits = torch.randn(batch_size * seqlen, vocab_size, device="cuda", requires_grad=True) - target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device="cuda", dtype=torch.int64) - - # broadcast across tp - torch.distributed.broadcast( - logits, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() - ) - torch.distributed.broadcast( - target, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() - ) - - tp_rank = mpu.get_tensor_model_parallel_rank() - vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size() - - # get the local logits of each tp - vocab_parallel_logits = ( - logits.clone().detach()[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp].requires_grad_() - ) - logits.grad = None - vocab_parallel_logits.grad = None - - log_gpu_memory_usage("begin") - output_entropy = vocab_parallel_entropy(vocab_parallel_logits) - log_gpu_memory_usage("after forward") - grad_output = torch.randn_like(output_entropy) - output_entropy.backward(grad_output) - log_gpu_memory_usage("after backward") - - target_entropy = entropy_from_logits(logits) - torch.testing.assert_close(output_entropy, target_entropy) - target_entropy.backward(grad_output) - torch.testing.assert_close( - logits.grad[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits.grad - ) - # make sure logits is not altered - torch.testing.assert_close( - logits[:, tp_rank * vocab_size_per_tp : (tp_rank + 1) * vocab_size_per_tp], vocab_parallel_logits - ) - - if mpu.get_tensor_model_parallel_rank() == 0: - print("test_vocab_parallel_entropy passes") - - mpu.destroy_model_parallel() - - -if __name__ == "__main__": - local_rank, rank, world_size = initialize_global_process_group() - test_all_gather_data_proto() - test_vocab_parallel_entropy() diff --git a/tests/special_e2e/README.md b/tests/special_e2e/README.md deleted file mode 100644 index 3c295e844..000000000 --- a/tests/special_e2e/README.md +++ /dev/null @@ -1 +0,0 @@ -This folder is reserved for end-to-end tests that typically require multiple GPUs. diff --git a/tests/special_e2e/__init__.py b/tests/special_e2e/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/tests/special_e2e/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/special_e2e/check_custom_rwd_fn.py b/tests/special_e2e/check_custom_rwd_fn.py deleted file mode 100644 index 8d77a5372..000000000 --- a/tests/special_e2e/check_custom_rwd_fn.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - - -def check_congratulations_in_file(output_file): - with open(output_file) as f: - output = f.read() - - success_message = "Congratulations!!! You have called my_reward_function successfully!!!" - assert success_message in output, f"Success message of my_reward_function not found in {output_file}" - print("Check passes") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--output_file", required=True, type=str) - - args = parser.parse_args() - - check_congratulations_in_file(args.output_file) diff --git a/tests/special_e2e/check_results.py b/tests/special_e2e/check_results.py deleted file mode 100644 index 9453282fb..000000000 --- a/tests/special_e2e/check_results.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -import numpy as np - - -def extract_reward_from_line(line): - # TODO: this function needs error handling - try: - key_vals = line.split(" - ") - for key_val in key_vals: - key, val = key_val.split(":") - if key == "critic/rewards/mean": - reward = float(val) - return reward - return -np.inf - except Exception: - return -np.inf - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--output_file", required=True, type=str) - parser.add_argument("--target", type=float, default=0.2, help="target reward score") - - args = parser.parse_args() - - with open(args.output_file) as f: - output = f.read().split("\n") - - best_reward = -np.inf - for line in output: - if line.startswith("step"): - reward = extract_reward_from_line(line) - if reward > best_reward: - best_reward = reward - - print(f"Best reward is {best_reward}") - assert best_reward > args.target, f"Best reward must be greater than {args.target}. best_reward: {best_reward}" - print("Check passes") diff --git a/tests/special_e2e/envs/__init__.py b/tests/special_e2e/envs/__init__.py deleted file mode 100644 index eb85e22f3..000000000 --- a/tests/special_e2e/envs/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .digit_completion import DigitCompletion - -__all__ = ["DigitCompletion"] diff --git a/tests/special_e2e/envs/digit_completion/__init__.py b/tests/special_e2e/envs/digit_completion/__init__.py deleted file mode 100644 index 80893ae41..000000000 --- a/tests/special_e2e/envs/digit_completion/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from transformers import AutoTokenizer, LlamaConfig - -from .task import DigitCompletion, generate_ground_truth_response -from .tokenizer import CharTokenizer - -AutoTokenizer.register(LlamaConfig, CharTokenizer, exist_ok=True) - -__all__ = ["DigitCompletion", "generate_ground_truth_response", "CharTokenizer"] diff --git a/tests/special_e2e/envs/digit_completion/task.py b/tests/special_e2e/envs/digit_completion/task.py deleted file mode 100644 index c3643a86b..000000000 --- a/tests/special_e2e/envs/digit_completion/task.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Task and environment definition for digit completion.""" - -import numpy as np - - -class DigitCompletion: - """ - The implementation of a simple digit completion task. - The prompt is a sequence of numbers with fixed difference. The task is to complete the next N numbers. - If the max number is reached, the next number should be modulo with max number. - - For example, - - prompt = [1, 2, 3] - - N = 5 - - max_number = 6 - - the response should be [4, 5, 6, 7%6, 8%6] = [4, 5, 6, 0, 1] - - Note that the tokenizer is char-level to increase the difficulty. - """ - - def __init__(self, max_number: int, max_diff: int, max_num_in_response: int, seed=0): - """ - - Args: - max_number: the maximum number allowed in the arithmetic sequence - max_diff: the maximum diff. The actual common diff will be sampled from [0, max_diff] - max_num_in_response: the maximum number in the response - """ - super().__init__() - self.max_number = max_number - self.max_diff = max_diff - self.max_num_in_response = max_num_in_response - assert self.max_num_in_response < 10 - assert self.max_number > 0 - assert self.max_diff > 0 - self.max_number_length = len(str(max_number)) - # {num1},{num2}:{max_num_in_response},{max_number} - self._prompt_length = self.max_number_length * 2 + 4 + self.max_number_length # no negative is allowed - - self.np_rng = np.random.default_rng(seed=seed) - - def __str__(self): - return ( - f"Prompt length: {self.prompt_length}. Response length: {self.response_length}, " - f"Max number: {self.max_number}. Max diff: {self.max_diff}, " - f"Max number in response: {self.max_num_in_response}" - ) - - def get_state(self): - return {"rng": self.np_rng} - - def set_state(self, state): - assert "rng" in state, "rng must be inside state" - self.np_rng = state["rng"] - - @property - def prompt_length(self): - return self._prompt_length - - @property - def response_length(self): - # number length + comma length + [EOS] - # The actual number times 1.5 to allow 'U' - return (self.max_num_in_response * self.max_number_length + (self.max_num_in_response - 1) + 1) * 2 - - def add(self, a, b): - return (a + b) % self.max_number - - def get_all_prompts(self): - all_prompts = [] - for first_num in range(self.max_number + 1): - for diff in range(0, self.max_diff + 1): - second_num = self.add(first_num, diff) - for num_to_complete in range(self.max_num_in_response + 1): - prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}" - all_prompts.append(prompt) - return all_prompts - - def sample_str_prompts(self): - # step 1: sample initial numbers - first_num = self.np_rng.integers(self.max_number + 1) - diff = self.np_rng.integers(self.max_diff + 1) - second_num = self.add(first_num, diff) - num_to_complete = self.np_rng.integers(self.max_num_in_response + 1) - prompt = str(first_num) + "," + str(second_num) + f":{self.max_number},{num_to_complete}" - return prompt - - def sample_batch_str_prompts(self, batch_size): - str_prompts = [] - for _ in range(batch_size): - str_prompts.append(self.sample_str_prompts()) - return str_prompts - - -def compute_attention_mask(prompts, pad_token_id): - mask = np.ones_like(prompts) - mask[prompts == pad_token_id] = 0 - return mask - - -def compute_position_id_with_mask(mask): - return np.clip(np.cumsum(mask, axis=-1) - 1, a_min=0, a_max=None) - - -def generate_ground_truth_response(prompt: str): - """Generate ground truth response given a prompt.""" - num, info = prompt.split(":") - num1, num2 = num.split(",") - max_number, num_to_gen = info.split(",") - num1 = int(num1) - num2 = int(num2) - max_number = int(max_number) - num_to_gen = int(num_to_gen) - diff = (num2 - num1) % max_number - results = [] - last_num = num2 - for _ in range(num_to_gen): - curr = (last_num + diff) % max_number - results.append(str(curr)) - last_num = curr - response = ",".join(results) - return response - - -def compute_reward(prompt: str, response: str, sequence_reward=1.0): - """We compute dense reward here so that we can directly train RL without SFT""" - response_length = len(response) - ground_truth_response = generate_ground_truth_response(prompt) - per_token_reward = sequence_reward / (len(ground_truth_response) + 1) # including [EOS] - - # pad - reward = np.zeros(response_length, dtype=np.float32) # this assumes that each char is a token - # assign reward until mismatches - ground_truth_idx = 0 - for i in range(response_length): - if ground_truth_idx == len(ground_truth_response): - break - - ground_truth_response_token = ground_truth_response[ground_truth_idx] - response_token = response[i] - if ground_truth_response_token == response_token: - reward[i] = per_token_reward - ground_truth_idx += 1 - else: - # no matches - break - - return reward, {"ground_truth_response": ground_truth_response} - - -if __name__ == "__main__": - task = DigitCompletion(max_number=20, max_diff=3, max_num_in_response=5) - print(task.sample_str_prompts()) - - prompt = "7,8:20,0" - response = "" - print(compute_reward(prompt, response)) - - prompt = "7,8:20,0" - response = "E000" - print(compute_reward(prompt, response)) - - prompt = "9,10:20,2" - response = "11,12,13" - print(compute_reward(prompt, response)) diff --git a/tests/special_e2e/envs/digit_completion/tokenizer.py b/tests/special_e2e/envs/digit_completion/tokenizer.py deleted file mode 100644 index 6ff471938..000000000 --- a/tests/special_e2e/envs/digit_completion/tokenizer.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Copied from https://github.com/dariush-bahrami/character-tokenizer/blob/master/charactertokenizer/core.py - -CharacterTokenzier for Hugging Face Transformers. - -This is heavily inspired from CanineTokenizer in transformers package. -""" - -import json -import os -from pathlib import Path -from typing import Optional, Sequence - -from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer - - -class CharTokenizer(PreTrainedTokenizer): - def __init__(self, characters: Sequence[str], model_max_length: int, chat_template, **kwargs): - """Character tokenizer for Hugging Face transformers. - - Args: - characters (Sequence[str]): List of desired characters. Any character which - is not included in this list will be replaced by a special token called - [UNK] with id=6. Following are list of all of the special tokens with - their corresponding ids: - "[CLS]": 0 - "[SEP]": 1 - "[BOS]": 2 - "[MASK]": 3 - "[PAD]": 4 - "[RESERVED]": 5 - "[UNK]": 6 - an id (starting at 7) will be assigned to each character. - - model_max_length (int): Model maximum sequence length. - """ - eos_token_str = "E" - sep_token_str = "S" - pad_token_str = "P" - unk_token_str = "U" - - self.characters = characters - self.model_max_length = model_max_length - eos_token = AddedToken(eos_token_str, lstrip=False, rstrip=False) - sep_token = AddedToken(sep_token_str, lstrip=False, rstrip=False) - pad_token = AddedToken(pad_token_str, lstrip=False, rstrip=False) - unk_token = AddedToken(unk_token_str, lstrip=False, rstrip=False) - - self._vocab_str_to_int = { - sep_token_str: 0, - eos_token_str: 1, - pad_token_str: 2, - unk_token_str: 3, - **{ch: i + 4 for i, ch in enumerate(characters)}, - } - self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} - - super().__init__( - eos_token=eos_token, - sep_token=sep_token, - pad_token=pad_token, - unk_token=unk_token, - add_prefix_space=False, - model_max_length=model_max_length, - **kwargs, - ) - - self.chat_template = chat_template - - @property - def vocab_size(self) -> int: - return len(self._vocab_str_to_int) - - def get_vocab(self): - return self._vocab_str_to_int - - def _tokenize(self, text: str) -> list[str]: - return list(text) - - def _convert_token_to_id(self, token: str) -> int: - return self._vocab_str_to_int.get(token, self._vocab_str_to_int["U"]) - - def _convert_id_to_token(self, index: int) -> str: - return self._vocab_int_to_str[index] - - def convert_tokens_to_string(self, tokens): - return "".join(tokens) - - def build_inputs_with_special_tokens( - self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None - ) -> list[int]: - sep = [self.sep_token_id] - cls = [self.cls_token_id] - result = cls + token_ids_0 + sep - if token_ids_1 is not None: - result += token_ids_1 + sep - return result - - def get_special_tokens_mask( - self, - token_ids_0: list[int], - token_ids_1: Optional[list[int]] = None, - already_has_special_tokens: bool = False, - ) -> list[int]: - if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, - token_ids_1=token_ids_1, - already_has_special_tokens=True, - ) - - result = [1] + ([0] * len(token_ids_0)) + [1] - if token_ids_1 is not None: - result += ([0] * len(token_ids_1)) + [1] - return result - - def get_config(self) -> dict: - return { - "char_ords": [ord(ch) for ch in self.characters], - "model_max_length": self.model_max_length, - "chat_template": self.chat_template, - } - - @classmethod - def from_config(cls, config: dict): - cfg = {} - cfg["characters"] = [chr(i) for i in config["char_ords"]] - cfg["model_max_length"] = config["model_max_length"] - cfg["chat_template"] = config["chat_template"] - return cls(**cfg) - - def save_pretrained(self, save_directory: str | os.PathLike, **kwargs): - cfg_file = Path(save_directory) / "tokenizer_config.json" - cfg = self.get_config() - with open(cfg_file, "w") as f: - json.dump(cfg, f, indent=4) - - @classmethod - def from_pretrained(cls, save_directory: str | os.PathLike, **kwargs): - cfg_file = Path(save_directory) / "tokenizer_config.json" - with open(cfg_file) as f: - cfg = json.load(f) - return cls.from_config(cfg) diff --git a/tests/special_e2e/generation/run_gen_qwen05.sh b/tests/special_e2e/generation/run_gen_qwen05.sh deleted file mode 100755 index 61c55b157..000000000 --- a/tests/special_e2e/generation/run_gen_qwen05.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash -# Tested with 1 & 4 GPUs -set -xeuo pipefail - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} - -NGPUS_PER_NODE=${NGPUS_PER_NODE:-4} -OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_05_gen_test.parquet} -GEN_TP=${GEN_TP:-2} # Default tensor parallel size to 2 - -python3 -m verl.trainer.main_generation \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ - data.path="${HOME}/data/gsm8k/test.parquet" \ - data.prompt_key=prompt \ - data.n_samples=1 \ - data.output_path="${OUTPUT_PATH}" \ - model.path="${MODEL_ID}" \ - +model.trust_remote_code=True \ - rollout.temperature=1.0 \ - rollout.top_k=50 \ - rollout.top_p=0.7 \ - rollout.prompt_length=2048 \ - rollout.response_length=1024 \ - rollout.tensor_model_parallel_size="${GEN_TP}" \ - rollout.gpu_memory_utilization=0.8 diff --git a/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json b/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json deleted file mode 100644 index c215fa4f7..000000000 --- a/tests/special_e2e/ppo_trainer/expert_parallel/qwen2moe_minimal.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "num_hidden_layers": 2, - "max_window_layers": 2 -} \ No newline at end of file diff --git a/tests/special_e2e/ppo_trainer/run_function_reward.sh b/tests/special_e2e/ppo_trainer/run_function_reward.sh deleted file mode 100644 index 62bf410ef..000000000 --- a/tests/special_e2e/ppo_trainer/run_function_reward.sh +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -NUM_GPUS=${NUM_GPUS:-8} - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" - -TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} -VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} -MAX_PROMPT_LEN=${MAX_PROMPT_LEN:-512} -MAX_RESPONSE_LEN=${MAX_RESPONSE_LEN:-512} - -ENGINE=${ENGINE:-vllm} -ROLLOUT_MODE=${ROLLOUT_MODE:-sync} - -RETURN_RAW_CHAT="False" -if [ "$ROLLOUT_MODE" = "async" ]; then - RETURN_RAW_CHAT="True" -fi - -GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8} -ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False} -ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False} -REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True} -RM_PAD=${RM_PAD:-True} -FUSED_KERNELS=${FUSED_KERNELS:-False} -FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend -ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae} -USE_KL=${USE_KL:-False} -CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False} -ENABLE_CHUNKED_PREFILL=${ENABLE_CHUNKED_PREFILL:-True} # For vLLM VLM placeholder issue: https://github.com/vllm-project/vllm/issues/15185 -STRATEGY=${STRATEGY:-fsdp} -# LoRA config -LORA_RANK=${LORA_RANK:-0} -LORA_ALPHA=${LORA_ALPHA:-${LORA_RANK}} -LORA_TARGET=${LORA_TARGET:-"all-linear"} -LORA_EXCLUDE=${LORA_EXCLUDE:-"DONT_EXCLUDE"} -USE_SHM=${USE_SHM:-False} -LOAD_FORMAT=${LOAD_FORMAT:-dummy_dtensor} -LAYERED_SUMMON=${LAYERED_SUMMON:-False} -# Validation -VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False} -TEST_FREQ=${TEST_FREQ:--1} -# Save & Resume -RESUME_MODE=${RESUME_MODE:-disable} -SAVE_FREQ=${SAVE_FREQ:--1} -TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} - -# whether to save hf_model -SAVE_HF_MODEL=${SAVE_HF_MODEL:-False} -FSDP_SIZE=${FSDP_SIZE:--1} -SP_SIZE=${SP_SIZE:-1} - -if [ "${SAVE_HF_MODEL}" = "True" ]; then - CHECKPOINT_CONTENTS="['model','hf_model','optimizer','extra']" -else - CHECKPOINT_CONTENTS="['model','optimizer','extra']" -fi - -train_traj_micro_bsz_per_gpu=2 # b -n_resp_per_prompt=4 # g - -train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n -train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n -train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g -train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g - -reward_fn_name=null -reward_fn_file_path=null -output_file="$(pwd)/output.txt" -if [ "${CUSTOM_REWARD_FN}" = "True" ]; then - reward_fn_name="my_reward_function" - reward_fn_file_path="$(pwd)/my_reward_function.py" - rm -rf "${reward_fn_file_path}" - cat < "$reward_fn_file_path" -def ${reward_fn_name}(data_source, solution_str, ground_truth, extra_info=None): - print(f"Congratulations!!! You have called ${reward_fn_name} successfully!!!") - return 0.1 -EOF - - rm -rf "${output_file}" -fi - -exp_name="${VERL_EXP_NAME:-$(basename "${MODEL_ID,,}")-function-reward-minimal}" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator="${ADV_ESTIMATOR}" \ - data.train_files="${TRAIN_FILES}" \ - data.val_files="${VAL_FILES}" \ - data.train_batch_size="${train_prompt_bsz}" \ - data.max_prompt_length="${MAX_PROMPT_LEN}" \ - data.max_response_length="${MAX_RESPONSE_LEN}" \ - data.return_raw_chat=${RETURN_RAW_CHAT} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.use_shm=${USE_SHM} \ - actor_rollout_ref.model.lora_rank=${LORA_RANK} \ - actor_rollout_ref.model.lora_alpha=${LORA_ALPHA} \ - actor_rollout_ref.model.target_modules=${LORA_TARGET} \ - actor_rollout_ref.model.exclude_modules=${LORA_EXCLUDE} \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ - actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ - actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.actor.strategy=${STRATEGY} \ - actor_rollout_ref.actor.fsdp_config.param_offload=${ACTOR_FSDP_PARAM_OFFLOAD} \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=${ACTOR_FSDP_OPTIMIZER_OFFLOAD} \ - actor_rollout_ref.actor.fsdp_config.fsdp_size=${FSDP_SIZE} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \ - actor_rollout_ref.actor.checkpoint.save_contents=${CHECKPOINT_CONTENTS} \ - actor_rollout_ref.actor.use_kl_loss="${USE_KL}" \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name="${ENGINE}" \ - actor_rollout_ref.rollout.mode="${ROLLOUT_MODE}" \ - actor_rollout_ref.rollout.load_format=${LOAD_FORMAT} \ - actor_rollout_ref.rollout.layered_summon=${LAYERED_SUMMON} \ - actor_rollout_ref.rollout.gpu_memory_utilization="${GPU_MEMORY_UTILIZATION}" \ - actor_rollout_ref.rollout.enable_chunked_prefill="${ENABLE_CHUNKED_PREFILL}" \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.fsdp_config.param_offload="${REF_FSDP_PARAM_OFFLOAD}" \ - critic.optim.lr=1e-5 \ - critic.model.use_remove_padding="${RM_PAD}" \ - critic.model.path="${MODEL_PATH}" \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - custom_reward_function.path="${reward_fn_file_path}"\ - custom_reward_function.name="${reward_fn_name}"\ - algorithm.use_kl_in_reward="${USE_KL}" \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl-test' \ - trainer.experiment_name="${exp_name}" \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node="${NUM_GPUS}" \ - trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ - trainer.test_freq="${TEST_FREQ}" \ - trainer.save_freq="${SAVE_FREQ}" \ - trainer.resume_mode="${RESUME_MODE}" \ - trainer.total_epochs=2 \ - trainer.device=cuda \ - trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ \ - | tee "${output_file}" - -if [ "${CUSTOM_REWARD_FN}" = "True" ]; then - python3 tests/special_e2e/check_custom_rwd_fn.py --output_file="${output_file}" - check_exit_code=$? - rm -rf "${reward_fn_file_path}" - rm -rf "${output_file}" - # Return the exit code of check_custom_rwd_fn.py if it fails - if [ $check_exit_code -ne 0 ]; then - exit $check_exit_code - fi -fi diff --git a/tests/special_e2e/ppo_trainer/run_model_reward.sh b/tests/special_e2e/ppo_trainer/run_model_reward.sh deleted file mode 100644 index e7711f96d..000000000 --- a/tests/special_e2e/ppo_trainer/run_model_reward.sh +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -NUM_GPUS=${NUM_GPUS:-8} - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" - -TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} -VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} - -RM_PAD=${RM_PAD:-True} -FUSED_KERNELS=${FUSED_KERNELS:-False} -FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend -SP_SIZE=${SP_SIZE:-1} -SEQ_BALANCE=${SEQ_BALANCE:-False} -LIGER=${LIGER:-False} -# Validation -VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False} -TEST_FREQ=${TEST_FREQ:--1} -# Save & Resume -RESUME_MODE=${RESUME_MODE:-disable} -SAVE_FREQ=${SAVE_FREQ:--1} -TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} - -train_traj_micro_bsz_per_gpu=2 # b -n_resp_per_prompt=4 # g - -train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n -train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n -train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g -train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g - -train_max_token_num_per_gpu=32768 -infer_max_token_num_per_gpu=32768 - -exp_name="$(basename "${MODEL_ID,,}")-model-reward-minimal" - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=gae \ - data.train_files="${TRAIN_FILES}" \ - data.val_files="${VAL_FILES}" \ - data.train_batch_size=${train_prompt_bsz} \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.use_liger="${LIGER}" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \ - actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \ - actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.use_dynamic_bsz="${SEQ_BALANCE}" \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size="${SP_SIZE}" \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - critic.optim.lr=1e-5 \ - critic.ulysses_sequence_parallel_size="${SP_SIZE}" \ - critic.model.use_remove_padding="${RM_PAD}" \ - critic.optim.lr_warmup_steps_ratio=0.05 \ - critic.model.path="${MODEL_PATH}" \ - critic.model.enable_gradient_checkpointing=False \ - critic.use_dynamic_bsz="${SEQ_BALANCE}" \ - critic.ppo_max_token_len_per_gpu=${train_max_token_num_per_gpu} \ - critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - critic.model.fsdp_config.param_offload=False \ - critic.model.fsdp_config.optimizer_offload=False \ - reward_model.enable=True \ - reward_model.ulysses_sequence_parallel_size="${SP_SIZE}" \ - reward_model.model.path="${MODEL_PATH}" \ - reward_model.model.use_remove_padding="${RM_PAD}" \ - reward_model.model.fsdp_config.param_offload=True \ - reward_model.use_dynamic_bsz="${SEQ_BALANCE}" \ - reward_model.forward_max_token_len_per_gpu=${infer_max_token_num_per_gpu} \ - reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl-test' \ - trainer.experiment_name="${exp_name}" \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node="${NUM_GPUS}" \ - trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ - trainer.test_freq="${VAL_BEFORE_TRAIN}" \ - trainer.save_freq="${SAVE_FREQ}" \ - trainer.resume_mode="${RESUME_MODE}" \ - trainer.total_epochs=2 \ - trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ diff --git a/tests/special_e2e/ppo_trainer/run_single_gpu.sh b/tests/special_e2e/ppo_trainer/run_single_gpu.sh deleted file mode 100644 index 7e8615a24..000000000 --- a/tests/special_e2e/ppo_trainer/run_single_gpu.sh +++ /dev/null @@ -1,24 +0,0 @@ -PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=256 \ - data.max_prompt_length=512 \ - data.max_response_length=256 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - critic.optim.lr=1e-5 \ - critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - critic.ppo_micro_batch_size_per_gpu=4 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=console \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=1 \ - trainer.nnodes=1 \ - actor_rollout_ref.rollout.name=hf \ - trainer.total_training_steps=2 \ No newline at end of file diff --git a/tests/special_e2e/run_dapo.sh b/tests/special_e2e/run_dapo.sh deleted file mode 100644 index 56ff0ae05..000000000 --- a/tests/special_e2e/run_dapo.sh +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -NUM_GPUS=${NUM_GPUS:-8} - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" - -adv_estimator=grpo - -kl_coef=0.0 -use_kl_in_reward=False -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=1024 -max_response_length=2048 -enable_overlong_buffer=True -overlong_buffer_len=128 -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=True -filter_groups_metric=seq_reward -max_num_gen_batches=10 - -train_traj_micro_bsz_per_gpu=2 # b -n_resp_per_prompt=4 # g - -train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n -train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n -train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g -train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g - -gen_prompt_bsz=$((train_prompt_bsz * 4)) - -exp_name="$(basename "${MODEL_ID,,}")-dapo-minimal" - -python3 -m recipe.dapo.main_dapo \ - data.train_files="${HOME}/data/gsm8k/train.parquet" \ - data.val_files="${HOME}/data/gsm8k/test.parquet" \ - reward_model.reward_manager=dapo \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - data.train_batch_size=${train_prompt_bsz} \ - data.gen_batch_size=${gen_prompt_bsz} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - trainer.logger=console \ - trainer.project_name='verl-test' \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=${NUM_GPUS} \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_epochs=2 \ - trainer.resume_mode=disable \ - trainer.val_before_train=False \ - trainer.total_training_steps=1 $@ diff --git a/tests/special_e2e/run_genrm_remote.sh b/tests/special_e2e/run_genrm_remote.sh deleted file mode 100644 index 4819248be..000000000 --- a/tests/special_e2e/run_genrm_remote.sh +++ /dev/null @@ -1,81 +0,0 @@ -#!/usr/bin/env bash - -export no_proxy="localhost,127.0.0.1" - -set -x - -# Launch a vllm server -CUDA_VISIBLE_DEVICES=0 vllm serve verl-team/GenRM-CI-Test-1.5B \ - --served_model_name genrm-demo --host localhost --port 30000 > /dev/null & -SERVER_PID=$! - -# kill server when script exits -cleanup() { - echo "Cleaning up..." - kill $SERVER_PID 2>/dev/null || true - wait $SERVER_PID 2>/dev/null || true - echo "Cleanup done" -} -trap cleanup EXIT - -# wait for server to start -wait_for_server() { - local max_attempts=60 - local attempt=0 - local sleep_time=10 - - while [ $attempt -lt $max_attempts ]; do - if curl -s "http://localhost:30000/health" >/dev/null; then - echo "Server is up and running!" - return 0 - fi - echo "Waiting for server to start... (attempt $((attempt+1))/$max_attempts)" - sleep $sleep_time - ((attempt++)) - done - - echo "Error: Failed to start server after $max_attempts attempts" >&2 - return 1 -} - -if ! wait_for_server; then - exit 1 -fi - -CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=${HOME}/data/gsm8k/train.parquet \ - data.val_files=${HOME}/data/gsm8k/test.parquet \ - data.train_batch_size=256 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.n=4 \ - algorithm.use_kl_in_reward=False \ - reward_model.reward_manager=batch \ - custom_reward_function.path=recipe/genrm_remote/reward_function.py \ - custom_reward_function.name=compute_score_batch \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl-test' \ - trainer.experiment_name='qwen2.5-0.5b-gen-rm' \ - trainer.n_gpus_per_node=4 \ - trainer.val_before_train=False \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_epochs=10 \ - trainer.resume_mode='disable' \ - trainer.total_training_steps=1 diff --git a/tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh b/tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh deleted file mode 100644 index caa9e664c..000000000 --- a/tests/special_e2e/run_geo3k_fsdp_sgl_multiturn_w_tool.sh +++ /dev/null @@ -1,58 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project - -set -x - -huggingface-cli download Qwen/Qwen2.5-VL-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-VL-3B-Instruct - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" -FSDP_STRATEGY=${FSDP_STRATEGY:-fsdp} - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='geo3k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=64 \ - data.max_prompt_length=2048 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-VL-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.strategy=$FSDP_STRATEGY \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ - actor_rollout_ref.ref.strategy=$FSDP_STRATEGY \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='geo3k_async_rl' \ - trainer.experiment_name=qwen2.5-vl-3b_function_rm-geo3k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0619-verify-n8 \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=-1 \ - data.train_files=$HOME/data/geo3k_verl_sgl_multi_turn_preprocessed/train.parquet \ - data.val_files=$HOME/data/geo3k_verl_sgl_multi_turn_preprocessed/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/geo3k_tool_config.yaml" \ - trainer.val_before_train=False \ - trainer.total_training_steps=1 $@ \ No newline at end of file diff --git a/tests/special_e2e/run_grpo_lora_with_merge.sh b/tests/special_e2e/run_grpo_lora_with_merge.sh deleted file mode 100644 index 192d935ba..000000000 --- a/tests/special_e2e/run_grpo_lora_with_merge.sh +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env bash -# -# An e2e test script for testing the GRPO LoRA training process -# and processing the generated checkpoint using the merge_model.py script. - -set -xeuo pipefail - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -if [ ! -d "$MODEL_PATH" ]; then - echo "Downloading model to ${MODEL_PATH}..." - huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" -else - echo "Model directory ${MODEL_PATH} already exists, skip downloading." -fi - - -BATCH_SIZE=16 -EXP_NAME="qwen2.5_0.5b_grpo_lora" -# step 1. train model with grpo-lora for 1 step -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=${BATCH_SIZE} \ - data.max_prompt_length=512 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.shuffle=False \ - actor_rollout_ref.model.path=${MODEL_PATH} \ - actor_rollout_ref.model.use_shm=True \ - actor_rollout_ref.model.lora_rank=64 \ - actor_rollout_ref.model.lora_alpha=32 \ - actor_rollout_ref.actor.optim.lr=3e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${BATCH_SIZE} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.rollout.load_format=safetensors \ - actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name=${EXP_NAME} \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.total_training_steps=1 \ - trainer.save_freq=1 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 $@ - -# step 2. merge model -python3 -m verl.model_merger merge \ - --backend fsdp \ - --local_dir checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/ \ - --target_dir checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/hf - -# step 3. assert -# make sure adapter_model.safetensors exists and its size is larger than 1MB -file_path="checkpoints/verl_grpo_example_gsm8k/${EXP_NAME}/global_step_1/actor/hf/lora_adapter/adapter_model.safetensors" - -if [ ! -f "$file_path" ]; then - echo "Error: File $file_path does not exist!" - exit 1 -fi - -file_size=$(stat -c %s "$file_path") - -min_size_mb=1 -min_size=$((min_size_mb * 1024 * 1024)) # 1MB = 1048576 bytes - -if [ "$file_size" -lt "$min_size" ]; then - echo "Error: File $file_path is too small! Current size: $((file_size/1024))KB, Required: ${min_size_mb}MB" - exit 1 -fi - -echo "Check passed: File exists and size is $(($file_size/1024/1024))MB" -exit 0 diff --git a/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh b/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh deleted file mode 100644 index 729b42554..000000000 --- a/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_sf_tool.sh +++ /dev/null @@ -1,62 +0,0 @@ -# run on 8xH20 -# make sure your current working directory is the root of the project - -set -x - - -export PYTHONUNBUFFERED=1 -export RAY_DEDUP_LOGS=0 -export RUST_BACKTRACE=1 -export HYDRA_FULL_ERROR=1 - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_sf_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=128 \ - data.max_prompt_length=2048 \ - data.max_response_length=16384 \ - data.filter_overlong_prompts=False \ - data.truncation='error' \ - data.return_raw_chat=True \ - data.train_files=$HOME/data/retool_dapo/train.parquet \ - data.val_files=$HOME/data/retool_aime2024/train.parquet \ - actor_rollout_ref.model.path=Qwen/Qwen3-4B \ - actor_rollout_ref.actor.use_dynamic_bsz=True \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.use_liger=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - +actor_rollout_ref.model.enable_activation_offloading=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.actor.kl_loss_coef=0.0 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/sandbox_fusion_tool_config.yaml" \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger='["console","wandb"]' \ - trainer.project_name='retool_async_rl' \ - trainer.experiment_name='qwen3-4b_function_rm-retool-async-sgl-no-sft-n8-v2505271300' \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=100 \ - trainer.test_freq=20 \ - trainer.total_training_steps=1000 \ - trainer.total_epochs=1 $@ \ No newline at end of file diff --git a/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh b/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh deleted file mode 100644 index 76983ddad..000000000 --- a/tests/special_e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh +++ /dev/null @@ -1,58 +0,0 @@ -# run on 8xH100 -# make sure your current working directory is the root of the project - -set -x - -huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-3B-Instruct - -ulimit -n 65535 - -PROJECT_DIR="$(pwd)" -CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" -FSDP_STRATEGY=${FSDP_STRATEGY:-fsdp} - -python3 -m verl.trainer.main_ppo \ - --config-path="$CONFIG_PATH" \ - --config-name='gsm8k_multiturn_grpo' \ - algorithm.adv_estimator=grpo \ - data.train_batch_size=256 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.strategy=$FSDP_STRATEGY \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.rollout.n=8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - actor_rollout_ref.ref.strategy=$FSDP_STRATEGY \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=-1 \ - data.train_files=$HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed/train.parquet \ - data.val_files=$HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed/test.parquet \ - actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.val_before_train=False \ - trainer.total_training_steps=1 $@ diff --git a/tests/special_e2e/run_ppo_trainer_megatron.sh b/tests/special_e2e/run_ppo_trainer_megatron.sh deleted file mode 100644 index 72232d4db..000000000 --- a/tests/special_e2e/run_ppo_trainer_megatron.sh +++ /dev/null @@ -1,237 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping -export VERL_LOGGING_LEVEL=INFO -export VERL_PPO_LOGGING_LEVEL=INFO - -NUM_GPUS=${NUM_GPUS:-8} - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" - -USE_DUMMY_MODEL=${USE_DUMMY_MODEL:-False} -DUMMY_MODEL_PATH=${DUMMY_MODEL_PATH:-${HOME}/dummy_models/${MODEL_ID}} -if [ "$USE_DUMMY_MODEL" = "True" ]; then - if [ -z "${DUMMY_MODEL_CONFIG_PATH}" ]; then - echo "[ERROR] DUMMY_MODEL_CONFIG_PATH not set" - exit 1 - fi - - python scripts/init_random_model.py \ - --hf_model_path "${MODEL_PATH}" \ - --new_config_path "${DUMMY_MODEL_CONFIG_PATH}" \ - --output_path "${DUMMY_MODEL_PATH}" - - MODEL_PATH="${DUMMY_MODEL_PATH}" -fi - -TRAIN_FILES=${TRAIN_FILES:-${HOME}/data/gsm8k/train.parquet} -VAL_FILES=${VAL_FILES:-${HOME}/data/gsm8k/test.parquet} - -ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae} -# Validation -VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-False} -TEST_FREQ=${TEST_FREQ:--1} -# Save & Resume -RESUME_MODE=${RESUME_MODE:-disable} -SAVE_FREQ=${SAVE_FREQ:--1} -TOTAL_TRAIN_STEPS=${TOTAL_TRAIN_STEPS:-1} - -USE_DYNAMIC_BSZ=${USE_DYNAMIC_BSZ:-True} -ppo_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN:-2400} -forward_max_token_len_per_gpu=${FWD_MAX_TOKEN_LEN:-4800} -train_traj_micro_bsz_per_gpu=${MICRO_BSZ:-2} # b -n_resp_per_prompt=4 # g - -train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n -train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n -train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g -train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g - -MAX_PROMPT_LENGTH=${MAX_PROMPT_LENGTH:-512} -MAX_RESPONSE_LENGTH=${MAX_RESPONSE_LENGTH:-512} - -COMMON_PP=${COMMON_PP:-2} -COMMON_VPP=${COMMON_VPP:-2} -COMMON_CP=${COMMON_CP:-2} -COMMON_TP=${COMMON_TP:-2} -COMMON_EP=${COMMON_EP:-1} -COMMON_ETP=${COMMON_ETP:-null} - -TRAIN_TP=${TRAIN_TP:-$COMMON_TP} -INFER_TP=${INFER_TP:-$COMMON_TP} - -ACTOR_PP=${ACTOR_PP:-$COMMON_PP} -ACTOR_VPP=${ACTOR_VPP:-$COMMON_VPP} -ACTOR_CP=${ACTOR_CP:-$COMMON_CP} -ACTOR_TP=${ACTOR_TP:-$TRAIN_TP} -ACTOR_EP=${ACTOR_EP:-$COMMON_EP} -ACTOR_ETP=${ACTOR_ETP:-$COMMON_ETP} -ROLLOUT_TP=${ROLLOUT_TP:-$INFER_TP} -REF_PP=${REF_PP:-$COMMON_PP} -REF_VPP=${REF_VPP:-$COMMON_VPP} -REF_CP=${REF_CP:-$COMMON_CP} -REF_TP=${REF_TP:-$TRAIN_TP} -REF_EP=${REF_EP:-$COMMON_EP} -REF_ETP=${REF_ETP:-$COMMON_ETP} -CRITIC_PP=${CRITIC_PP:-$COMMON_PP} -CRITIC_VPP=${CRITIC_VPP:-$COMMON_VPP} -CRITIC_CP=${CRITIC_CP:-$COMMON_CP} -CRITIC_TP=${CRITIC_TP:-$TRAIN_TP} -CRITIC_EP=${CRITIC_EP:-$COMMON_EP} -CRITIC_ETP=${CRITIC_ETP:-$COMMON_ETP} -RM_PP=${RM_PP:-$COMMON_PP} -RM_VPP=${RM_VPP:-$COMMON_VPP} -RM_CP=${RM_CP:-$COMMON_CP} -RM_TP=${RM_TP:-$TRAIN_TP} -RM_EP=${RM_EP:-$COMMON_EP} -RM_ETP=${RM_ETP:-$COMMON_ETP} - -ALL_OFFLOAD=${ALL_OFFLOAD:-False} -COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD} -COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD} -COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD} - -ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD} -CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD} -RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD} -USE_MBRIDGE=${USE_MBRIDGE:-False} -USE_FUSED_KERNELS=${USE_FUSED_KERNELS:-False} - -LR_WARMUP_STEPS=${LR_WARMUP_STEPS:-null} - -CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra'] -SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0} -if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then - CHECKPOINT_CONTENTS=['model','optimizer','extra'] -fi - -USE_DIST_CKPT=${USE_DIST_CKPT:-False} -DIST_CKPT_PATH=${DIST_CKPT_PATH:-${HOME}/dist_ckpt/${MODEL_ID}} -if [ "$USE_DIST_CKPT" = "True" ]; then - if [ "$USE_DUMMY_MODEL" = "True" ]; then - DIST_CKPT_PATH=${HOME}/dist_ckpt_dummy/${MODEL_ID} - fi - python scripts/converter_hf_to_mcore.py \ - --hf_model_path "${MODEL_PATH}" \ - --output_path "${DIST_CKPT_PATH}" -fi - -ENGINE=${ENGINE:-"vllm"} - -exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal" - -if [ "$ENGINE" = "vllm" ]; then - MODE=${MODE:-"sync"} - ROLLOUT_MODE_ARG="actor_rollout_ref.rollout.mode=${MODE}" - if [ "$MODE" = "async" ]; then - ROLLOUT_MODE_ARG="${ROLLOUT_MODE_ARG} data.return_raw_chat=True" - fi -else - ROLLOUT_MODE_ARG="" -fi - -python3 -m verl.trainer.main_ppo --config-path=config \ - --config-name='ppo_megatron_trainer.yaml'\ - algorithm.adv_estimator="${ADV_ESTIMATOR}" \ - data.train_files="${TRAIN_FILES}" \ - data.val_files="${VAL_FILES}" \ - data.train_batch_size=${train_prompt_bsz} \ - data.max_prompt_length=${MAX_PROMPT_LENGTH} \ - data.max_response_length=${MAX_RESPONSE_LENGTH} \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.model.use_fused_kernels=${USE_FUSED_KERNELS} \ - actor_rollout_ref.actor.optim.lr_warmup_steps=$LR_WARMUP_STEPS \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ} \ - actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${ppo_max_token_len_per_gpu} \ - actor_rollout_ref.actor.megatron.use_mbridge=${USE_MBRIDGE} \ - actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$ACTOR_PP \ - actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \ - actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \ - actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \ - actor_rollout_ref.actor.megatron.expert_model_parallel_size=$ACTOR_EP \ - actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ACTOR_ETP \ - actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \ - actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \ - actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \ - actor_rollout_ref.actor.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ - actor_rollout_ref.actor.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_CONTENTS \ - actor_rollout_ref.rollout.name="${ENGINE}" ${ROLLOUT_MODE_ARG}\ - actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.megatron.use_mbridge=${USE_MBRIDGE} \ - actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$REF_PP \ - actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=$REF_VPP \ - actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \ - actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \ - actor_rollout_ref.ref.megatron.expert_model_parallel_size=$REF_EP \ - actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$REF_ETP \ - actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \ - actor_rollout_ref.ref.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ - actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - critic.optim.lr=2e-5 \ - critic.optim.lr_warmup_steps=$LR_WARMUP_STEPS \ - critic.model.path="${MODEL_PATH}" \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - critic.ppo_max_token_len_per_gpu=${forward_max_token_len_per_gpu} \ - critic.megatron.use_mbridge=${USE_MBRIDGE} \ - critic.megatron.pipeline_model_parallel_size=$CRITIC_PP \ - critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \ - critic.megatron.context_parallel_size=$CRITIC_CP \ - critic.megatron.tensor_model_parallel_size=$CRITIC_TP \ - critic.megatron.expert_model_parallel_size=$CRITIC_EP \ - critic.megatron.expert_tensor_parallel_size=$CRITIC_ETP \ - critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \ - critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \ - critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \ - critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ - critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - critic.checkpoint.save_contents=$CHECKPOINT_CONTENTS \ - reward_model.enable=True \ - reward_model.model.path="${MODEL_PATH}" \ - reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - reward_model.megatron.use_mbridge=${USE_MBRIDGE} \ - reward_model.megatron.pipeline_model_parallel_size=$RM_PP \ - reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \ - reward_model.megatron.context_parallel_size=$RM_CP \ - reward_model.megatron.tensor_model_parallel_size=$RM_TP \ - reward_model.megatron.expert_model_parallel_size=$RM_EP \ - reward_model.megatron.expert_tensor_parallel_size=$RM_ETP \ - reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \ - reward_model.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \ - reward_model.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \ - algorithm.use_kl_in_reward=False \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl-test' \ - trainer.experiment_name="${exp_name}" \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node=${NUM_GPUS} \ - trainer.val_before_train="${VAL_BEFORE_TRAIN}" \ - trainer.test_freq="${TEST_FREQ}" \ - trainer.save_freq="${SAVE_FREQ}" \ - trainer.resume_mode="${RESUME_MODE}" \ - trainer.total_epochs=2 \ - trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@ diff --git a/tests/special_e2e/run_prime.sh b/tests/special_e2e/run_prime.sh deleted file mode 100644 index ac7ecb79c..000000000 --- a/tests/special_e2e/run_prime.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -NUM_GPUS=${NUM_GPUS:-8} - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" - -TRAIN_FILES=${TRAIN_FILES:-${HOME}/data/gsm8k/train.parquet} -VAL_FILES=${VAL_FILES:-${HOME}/data/gsm8k/test.parquet} - -train_traj_micro_bsz_per_gpu=2 # b -n_resp_per_prompt=4 # g - -train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n -train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n -train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g -train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g - -exp_name="$(basename "${MODEL_ID,,}")-prime-minimal" - -python3 -m recipe.prime.main_prime \ - data.train_files="${TRAIN_FILES}" \ - data.val_files="${VAL_FILES}" \ - data.train_batch_size=${train_prompt_bsz} \ - data.max_prompt_length=512 \ - data.max_response_length=512 \ - data.filter_accuracy=True \ - data.accuracy_lower_bound=0.2 \ - data.accuracy_upper_bound=0.8 \ - data.oversample_factor=4 \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.model.enable_gradient_checkpointing=False \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.adv_estimator=rloo \ - algorithm.use_kl_in_reward=True \ - algorithm.kl_penalty=kl \ - algorithm.kl_ctrl.kl_coef=0.001 \ - reward_model.model.path="${MODEL_PATH}" \ - reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - reward_model.model.update=before \ - reward_model.model.beta_train=0.05 \ - reward_model.model.optim.lr=1e-6 \ - reward_model.model.optim.grad_clip=10.0 \ - reward_model.model.input_tokenizer=null \ - reward_model.mini_batch_size=${train_prompt_bsz} \ - reward_model.reward_manager=prime \ - trainer.val_before_train=False \ - trainer.logger=console \ - trainer.project_name='verl-test' \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=${NUM_GPUS} \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 $@ diff --git a/tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh b/tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh deleted file mode 100644 index 5dec6fe6c..000000000 --- a/tests/special_e2e/run_r1_distill_qwen_aime24_eval.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - --local-dir $HOME/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B - -python3 -m verl.trainer.main_generation \ - trainer.nnodes=1 \ - trainer.n_gpus_per_node=8 \ - data.path=$HOME/data/r1/test.parquet \ - data.prompt_key=prompt \ - data.batch_size=1024 \ - data.n_samples=1 \ - data.output_path=$HOME/data/r1/test-output-k1.parquet \ - model.path=$HOME/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - rollout.temperature=0.6 \ - rollout.top_p=0.95 \ - rollout.prompt_length=1024 \ - rollout.response_length=32768 \ - rollout.tensor_model_parallel_size=1 \ - rollout.gpu_memory_utilization=0.95 \ - rollout.max_num_batched_tokens=65536 \ - rollout.enforce_eager=False \ - rollout.free_cache_engine=True - -python3 -m recipe.r1.main_eval \ - data.path=$HOME/data/r1/test-output-k1.parquet \ - data.prompt_key=prompt \ - data.response_key=responses \ - custom_reward_function.path=recipe/r1/reward_score.py \ - custom_reward_function.name=reward_func \ No newline at end of file diff --git a/tests/special_e2e/run_spin.sh b/tests/special_e2e/run_spin.sh deleted file mode 100644 index 1b5a2af0d..000000000 --- a/tests/special_e2e/run_spin.sh +++ /dev/null @@ -1,31 +0,0 @@ -set -e -set -x -NUM_GPUS=${NUM_GPUS:-8} - -exp_name="Qwen2.5-0.5B-Instruct-spin-minimal" - -CUDA_VISIBLE_DEVICES=${VISIBLE_DEVICES} python3 -m recipe.spin.main_spin \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=1024 \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size=8 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=64 \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.logger=console \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=4 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=1 \ - +trainer.log_freq=1 \ - trainer.ref_update_freq=1 \ - trainer.total_training_steps=1 \ - trainer.total_epochs=1000 2>&1 | tee verl_demo.log \ No newline at end of file diff --git a/tests/special_e2e/run_sppo.sh b/tests/special_e2e/run_sppo.sh deleted file mode 100644 index a33131972..000000000 --- a/tests/special_e2e/run_sppo.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -# in e2e_sppo.yml, we set NUM_GPUS=8 L20 - -NUM_GPUS=${NUM_GPUS:-8} - -gsm8k_train_path=./data/math/train.parquet -gsm8k_test_path=./data/math/test.parquet -train_files="['$gsm8k_train_path']" -test_files="['$gsm8k_test_path']" - -exp_name="Qwen2.5-0.5B-Instruct-sppo-minimal" - -python3 -m recipe.sppo.main_sppo \ - data.train_files="$train_files" \ - data.val_files="$test_files" \ - data.train_batch_size=1024 \ - data.max_prompt_length=1024 \ - data.max_response_length=512 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.return_raw_chat=True \ - actor_rollout_ref.model.path="./models/Qwen2.5-0.5B-Instruct" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ - actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.use_kl_loss=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.name=sglang \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.val_before_train=False \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_training_steps=1 \ - trainer.total_epochs=2 $@ diff --git a/tests/special_e2e/run_test.sh b/tests/special_e2e/run_test.sh deleted file mode 100644 index 4b0dc5fa5..000000000 --- a/tests/special_e2e/run_test.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -set -xeuo pipefail - -# Get the configuration name and engine name from arguments -CONFIG_NAME="$1" -ENGINE="${2:-vllm}" - -# Download model if needed -huggingface-cli download Qwen/Qwen2.5-0.5B --local-dir "$HOME/models/Qwen/Qwen2.5-0.5B" - -# Run the training with the specified configuration -python3 -m verl.trainer.main_ppo \ - --config-name "$CONFIG_NAME" "$@" \ No newline at end of file diff --git a/tests/special_e2e/sft/run_sft.sh b/tests/special_e2e/sft/run_sft.sh deleted file mode 100644 index 4cd9a4790..000000000 --- a/tests/special_e2e/sft/run_sft.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.fsdp_sft_trainer"} - -NUM_GPUS=${NUM_GPUS:-8} - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" - -TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet} -VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet} - -SP_SIZE=${SP_SIZE:-1} -LIGER=${LIGER:-False} -MULTITURN=${MULTITURN:-False} -LORA_RANK=${LORA_RANK:-0} -RM_PAD=${RM_PAD:-True} - -micro_bsz=2 -NUM_GPUS=8 - -project_name="verl-test" -exp_name="$(basename "${MODEL_ID,,}")-sft-minimal" -ckpts_home=${ckpts_home:-$HOME/${project_name}/${exp_name}} - -mkdir -p "${ckpts_home}" - -torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \ - data.train_files="${TRAIN_FILES}" \ - data.val_files="${VAL_FILES}" \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - data.prompt_dict_keys=['question'] \ - data.response_dict_keys=['answer'] \ - data.multiturn.enable="${MULTITURN}" \ - data.multiturn.messages_key=messages \ - optim.lr=1e-4 \ - data.micro_batch_size_per_gpu=${micro_bsz} \ - model.strategy=fsdp \ - model.partial_pretrain="${MODEL_PATH}" \ - model.lora_rank="${LORA_RANK}" \ - model.lora_alpha=16 \ - model.target_modules=all-linear \ - model.use_liger="${LIGER}" \ - ulysses_sequence_parallel_size="${SP_SIZE}" \ - use_remove_padding="${RM_PAD}" \ - trainer.default_local_dir="${ckpts_home}" \ - trainer.project_name="${project_name}" \ - trainer.experiment_name="${exp_name}" \ - trainer.total_training_steps=1 \ - trainer.logger=console $@ - -rm -rf "${ckpts_home:?}/*" \ No newline at end of file diff --git a/tests/special_e2e/sft/test_sp_loss_match.py b/tests/special_e2e/sft/test_sp_loss_match.py deleted file mode 100644 index 4dc0cbdae..000000000 --- a/tests/special_e2e/sft/test_sp_loss_match.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.distributed -from tensordict import TensorDict -from torch.distributed.device_mesh import init_device_mesh - -from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer -from verl.utils.distributed import initialize_global_process_group - - -def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4): - """Test consistency between original forward pass and SP+rmpad forward passes. - - Args: - trainer: The FSDPSFTTrainer instance to test - total_steps: Number of steps to test (default: 4) - """ - if trainer.device_mesh.get_rank() == 0: - print("\nStarting debug comparison between original and SP+rmpad forward passes...") - print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}") - print(f"Remove padding: {trainer.use_remove_padding}\n") - - steps_remaining = total_steps - - for epoch in range(1): # Just one epoch for testing - trainer.train_sampler.set_epoch(epoch=epoch) - for data in trainer.train_dataloader: - data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda() - trainer.fsdp_model.train() - micro_batches = data.split(trainer.config.data.micro_batch_size_per_gpu) - - for idx, micro_batch in enumerate(micro_batches): - if trainer.device_mesh.get_rank() == 0: - print(f"\nProcessing micro batch {idx + 1}/{len(micro_batches)}") - - # Compute losses using both methods - # Disable SP and rmpad - trainer.use_remove_padding = False - old_sp = trainer.config.ulysses_sequence_parallel_size - trainer.config.ulysses_sequence_parallel_size = 1 - loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) - - # Do SP and rmpad - trainer.config.ulysses_sequence_parallel_size = old_sp - trainer.use_remove_padding = True - loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) - - # Collect losses across all ranks - loss_ref_all = loss_ref.clone() - loss_sp_all = loss_sp.clone() - torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG) - torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG) - - # Calculate relative difference of averaged losses - rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8) - - if trainer.device_mesh.get_rank() == 0: - print("\nComparison Results (Averaged across ranks):") - print(f"Reference Loss: {loss_ref_all.item():.6f}") - print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}") - print(f"Relative Difference: {rel_diff.item():.6f}") - - assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!" - print("Loss difference is within the acceptable range.") - - steps_remaining -= 1 - if steps_remaining == 0: - break - if steps_remaining == 0: - break - break - - if trainer.device_mesh.get_rank() == 0: - print("\nDebug comparison completed successfully.") - - -def create_trainer(config): - """Create and initialize a trainer instance with the given config. - - Args: - config: Configuration object with training parameters - - Returns: - FSDPSFTTrainer: Initialized trainer instance - """ - local_rank, rank, world_size = initialize_global_process_group() - - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) - - dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") - ) - - # build tokenizer and datasets first - from verl.trainer.fsdp_sft_trainer import create_sft_dataset - from verl.utils import hf_tokenizer - from verl.utils.fs import copy_to_local - - local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) - tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) - train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer) - val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer) - - return FSDPSFTTrainer( - config=config, - device_mesh=device_mesh, - ulysses_device_mesh=ulysses_device_mesh, - tokenizer=tokenizer, - train_dataset=train_dataset, - val_dataset=val_dataset, - ) - - -def main(config): - """Main function to run trainer tests. - - Args: - config: Configuration object with training parameters - """ - trainer = create_trainer(config) - test_trainer_forward_consistency(trainer) - - -if __name__ == "__main__": - import hydra - from omegaconf import DictConfig - - @hydra.main(config_path="../../../verl/trainer/config", config_name="sft_trainer") - def hydra_entry(cfg: DictConfig) -> None: - main(cfg) - - hydra_entry() diff --git a/tests/special_npu/run_qwen2_5_05b_dapo.sh b/tests/special_npu/run_qwen2_5_05b_dapo.sh deleted file mode 100644 index 3f3756bdf..000000000 --- a/tests/special_npu/run_qwen2_5_05b_dapo.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash -set -xeuo pipefail - -NUM_GPUS=${NUM_GPUS:-8} - -MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} -MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} -huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" - -adv_estimator=grpo - -kl_coef=0.0 -use_kl_in_reward=False -use_kl_loss=False -kl_loss_coef=0.0 - -clip_ratio_low=0.2 -clip_ratio_high=0.28 - -max_prompt_length=1024 -max_response_length=2048 -enable_overlong_buffer=True -overlong_buffer_len=128 -overlong_penalty_factor=1.0 - -loss_agg_mode="token-mean" - -enable_filter_groups=True -filter_groups_metric=seq_reward -max_num_gen_batches=10 - -train_traj_micro_bsz_per_gpu=2 # b -n_resp_per_prompt=4 # g - -train_traj_micro_bsz=$((train_traj_micro_bsz_per_gpu * NUM_GPUS)) # b * n -train_traj_mini_bsz=$((train_traj_micro_bsz * 2)) # 2 * b * n -train_prompt_mini_bsz=$((train_traj_mini_bsz * n_resp_per_prompt)) # 2 * b * n / g -train_prompt_bsz=$((train_prompt_mini_bsz * 2)) # 4 * b * n / g - -gen_prompt_bsz=$((train_prompt_bsz * 4)) - -exp_name="$(basename "${MODEL_ID,,}")-dapo-minimal" - -python3 -m recipe.dapo.main_dapo \ - data.train_files="${HOME}/data/gsm8k/train.parquet" \ - data.val_files="${HOME}/data/gsm8k/test.parquet" \ - reward_model.reward_manager=dapo \ - algorithm.adv_estimator=${adv_estimator} \ - algorithm.use_kl_in_reward=${use_kl_in_reward} \ - algorithm.kl_ctrl.kl_coef=${kl_coef} \ - actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ - actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ - actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ - actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ - data.max_prompt_length=${max_prompt_length} \ - data.max_response_length=${max_response_length} \ - reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ - reward_model.overlong_buffer.len=${overlong_buffer_len} \ - reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ - actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ - data.train_batch_size=${train_prompt_bsz} \ - data.gen_batch_size=${gen_prompt_bsz} \ - algorithm.filter_groups.enable=${enable_filter_groups} \ - algorithm.filter_groups.metric=${filter_groups_metric} \ - algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ - actor_rollout_ref.model.path="${MODEL_PATH}" \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.model.use_fused_kernels=True \ - actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ - actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \ - actor_rollout_ref.actor.entropy_checkpointing=True \ - actor_rollout_ref.ref.entropy_checkpointing=True \ - actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \ - actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ - trainer.logger=console \ - trainer.project_name='verl-test' \ - trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=${NUM_GPUS} \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.total_epochs=2 \ - trainer.resume_mode=disable \ - trainer.val_before_train=False \ - trainer.total_training_steps=1 \ - trainer.device=npu $@ diff --git a/tests/special_npu/run_qwen2_5_05b_grpo.sh b/tests/special_npu/run_qwen2_5_05b_grpo.sh deleted file mode 100644 index 466386b15..000000000 --- a/tests/special_npu/run_qwen2_5_05b_grpo.sh +++ /dev/null @@ -1,43 +0,0 @@ -set -x - - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.train_batch_size=128 \ - data.max_prompt_length=512 \ - data.max_response_length=128 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \ - actor_rollout_ref.actor.optim.lr=5e-7 \ - actor_rollout_ref.model.use_remove_padding=False \ - actor_rollout_ref.actor.ppo_mini_batch_size=64 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.kl_ctrl.kl_coef=0.001 \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_gsm8k' \ - trainer.experiment_name='qwen2_7b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=5 \ - trainer.total_epochs=1 \ - trainer.total_training_steps=2 \ - trainer.device=npu $@ diff --git a/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh b/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh deleted file mode 100644 index 1bb8fc4cd..000000000 --- a/tests/special_npu/run_qwen2_5_05b_sft_peft_sp2.sh +++ /dev/null @@ -1,29 +0,0 @@ -set -x - -mkdir -p ./save_ckpts - -torchrun --standalone --nnodes=1 --nproc_per_node=8 \ - -m verl.trainer.fsdp_sft_trainer \ - data.train_files=$HOME/data/gsm8k/train.parquet \ - data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=extra_info \ - data.response_key=extra_info \ - optim.lr=1e-4 \ - data.prompt_dict_keys=['question'] \ - +data.response_dict_keys=['answer'] \ - data.micro_batch_size_per_gpu=32 \ - model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ - trainer.default_local_dir=./save_ckpts \ - trainer.project_name=gsm8k-sft \ - trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ - trainer.logger=console \ - trainer.total_epochs=1 \ - trainer.total_training_steps=1 $@ \ - model.lora_rank=32 \ - model.lora_alpha=16 \ - model.target_modules=all-linear \ - model.strategy=fsdp \ - ulysses_sequence_parallel_size=2 \ - use_remove_padding=true - -rm -rf ./outputs ./save_ckpts diff --git a/tests/special_npu/run_qwen2_5_vl_3b_npu.sh b/tests/special_npu/run_qwen2_5_vl_3b_npu.sh deleted file mode 100644 index dc3799e99..000000000 --- a/tests/special_npu/run_qwen2_5_vl_3b_npu.sh +++ /dev/null @@ -1,52 +0,0 @@ -set -x -ENGINE=${1:-vllm} - -# Some models are optimized by vllm ascend. While in some case, e.g. rlhf training, -# the optimized model may not be suitable. In this case, set this value to 0 to disable the optimized model. -export USE_OPTIMIZED_MODEL=0 - -python3 -m verl.trainer.main_ppo \ - algorithm.adv_estimator=grpo \ - data.train_files=$HOME/data/geo3k/train.parquet \ - data.val_files=$HOME/data/geo3k/test.parquet \ - data.train_batch_size=512 \ - data.max_prompt_length=1024 \ - data.max_response_length=2048 \ - data.filter_overlong_prompts=True \ - data.truncation='error' \ - data.image_key=images \ - actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.ppo_mini_batch_size=16 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ - actor_rollout_ref.actor.use_kl_loss=True \ - actor_rollout_ref.actor.kl_loss_coef=0.01 \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.use_torch_compile=False \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=False \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.rollout.enable_chunked_prefill=False \ - actor_rollout_ref.rollout.enforce_eager=True \ - actor_rollout_ref.rollout.free_cache_engine=True \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=console \ - trainer.project_name='verl_grpo_example_geo3k' \ - trainer.experiment_name='qwen2_5_vl_3b_function_rm' \ - trainer.n_gpus_per_node=8 \ - trainer.nnodes=1 \ - trainer.save_freq=-1 \ - trainer.test_freq=-1 \ - trainer.total_epochs=1 \ - trainer.total_training_steps=1 \ - trainer.device=npu $@ \ No newline at end of file diff --git a/tests/special_sanity/check_api_docs.py b/tests/special_sanity/check_api_docs.py deleted file mode 100644 index fa31ec8c5..000000000 --- a/tests/special_sanity/check_api_docs.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Fail CI if any function or class that is publicly exported via -``__all__`` lacks a docstring. - -Usage ------ - # Check specific modules or packages - python check_docstrings.py mypkg.core mypkg.utils - - # Check an entire source tree (all top-level packages under cwd) - python check_docstrings.py -""" - -from __future__ import annotations - -import argparse -import importlib -import inspect -import pkgutil -import sys -from pathlib import Path -from types import ModuleType -from typing import Iterable - -_ALLOW_LIST = [ - "verl.third_party.vllm.LLM", - "verl.third_party.vllm.parallel_state", - "verl.utils.profiler.WorkerProfiler", - "verl.utils.profiler.WorkerProfilerExtension", - "verl.utils.profiler.log_gpu_memory_usage", - "verl.utils.profiler.log_print", - "verl.utils.profiler.mark_annotate", - "verl.utils.profiler.mark_end_range", - "verl.utils.profiler.mark_start_range", - "verl.models.mcore.qwen2_5_vl.get_vision_model_config", - "verl.models.mcore.qwen2_5_vl.get_vision_projection_config", -] - - -def iter_submodules(root: ModuleType) -> Iterable[ModuleType]: - """Yield *root* and every sub-module inside it.""" - yield root - if getattr(root, "__path__", None): # only packages have __path__ - for mod_info in pkgutil.walk_packages(root.__path__, prefix=f"{root.__name__}."): - try: - yield importlib.import_module(mod_info.name) - except Exception as exc: # noqa: BLE001 - print(f"[warn] Skipping {mod_info.name!r}: {exc}", file=sys.stderr) - - -def names_missing_doc(mod: ModuleType) -> list[str]: - """Return fully-qualified names that need docstrings.""" - missing: list[str] = [] - public = getattr(mod, "__all__", []) - for name in public: - obj = getattr(mod, name, None) - if f"{mod.__name__}.{name}" in _ALLOW_LIST: - continue - if obj is None: - # Exported but not found in the module: flag it anyway. - missing.append(f"{mod.__name__}.{name} (not found)") - continue - - if inspect.isfunction(obj) or inspect.isclass(obj): - doc = inspect.getdoc(obj) - if not doc or not doc.strip(): - missing.append(f"{mod.__name__}.{name}") - return missing - - -def check_module(qualname: str) -> list[str]: - """Import *qualname* and check it (and sub-modules).""" - try: - module = importlib.import_module(qualname) - except ModuleNotFoundError as exc: - print(f"[error] Cannot import '{qualname}': {exc}", file=sys.stderr) - return [qualname] - - missing: list[str] = [] - for submod in iter_submodules(module): - missing.extend(names_missing_doc(submod)) - return missing - - -def autodiscover_packages() -> list[str]: - """Detect top-level packages under CWD when no argument is given.""" - pkgs: list[str] = [] - for p in Path.cwd().iterdir(): - if p.is_dir() and (p / "__init__.py").exists(): - pkgs.append(p.name) - return pkgs - - -def main() -> None: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "modules", - nargs="*", - help="Fully-qualified module or package names (defaults to every top-level package found in CWD).", - ) - args = parser.parse_args() - - targets = args.modules or autodiscover_packages() - if not targets: - raise ValueError("[error] No modules specified and none detected automatically.") - - all_missing: list[str] = [] - for modname in targets: - all_missing.extend(check_module(modname)) - - if all_missing: - print("\nMissing docstrings:") - for name in sorted(all_missing): - print(f" - {name}") - raise ValueError("Missing docstrings detected. Please enhance them with docs accordingly.") - - print("✅ All exported functions/classes have docstrings.") - - -if __name__ == "__main__": - main() diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py deleted file mode 100644 index c8988db55..000000000 --- a/tests/special_sanity/check_device_api_usage.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This CI test is used for checking whether device api usage is irregular, suggest using api in `verl/utils/device.py`. -Search targets include .py files in verl/recipe and verl/verl. -Some files that must contain ".cuda", "cuda" or "nccl" keyword is pre-defined in whitelist below. -""" - -import os -from argparse import ArgumentParser -from pathlib import Path - -# directory or file path must contain keyword ".cuda" or "cuda" -CUDA_KEYWORD_CHECK_WHITELIST = [ - "verl/utils/device.py", - "recipe/prime/prime_ray_trainer.py", # appear in default device_name - "recipe/spin/spin_trainer.py", # appear in default device_name - "recipe/sppo/sppo_ray_trainer.py", # appear in default device_name - "verl/utils/profiler/nvtx_profile.py", # appear in NsightSystemsProfiler - "verl/utils/kernel/linear_cross_entropy.py", # appear in nvidia nvtx - "verl/utils/rendezvous/ray_backend.py", # appear in cupy importance - "verl/single_controller/ray/base.py", # appear in default device_name - "verl/trainer/ppo/ray_trainer.py", # appear in default device_name - "verl/utils/reward_score/sandbox_fusion/utils.py", # appear in sandbox language type - "verl/workers/reward_model/megatron/reward_model.py", # appear in default device_name -] - -# directory or file path must contain keyword "nccl" -NCCL_KEYWORD_CHECK_WHITELIST = [ - "verl/utils/device.py", - "verl/third_party/sglang/parallel_state.py", # appear in default backend -] - -SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST - -SEARCH_KEYWORDS = [".cuda", '"cuda"', '"nccl"'] - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--directory", "-d", required=True, type=str) - args = parser.parse_args() - directory_in_str = args.directory - - pathlist = Path(directory_in_str).glob("**/*.py") - for path in pathlist: - path_in_str = str(path.absolute()) - - # judge whether current path is in pre-defined search whitelist or not. - path_in_whitelist = False - - for sw in SEARCH_WHITELIST: - # for easy debugging in non-linux system - sw = sw.replace("/", os.sep) - if sw in path_in_str: - print(f"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.") - path_in_whitelist = True - break - - if path_in_whitelist: - continue - - with open(path_in_str, encoding="utf-8") as f: - file_content = f.read() - - find_invalid_device_management = False - - for sk in SEARCH_KEYWORDS: - if sk in file_content: - find_invalid_device_management = True - break - - print( - f"[CHECK] File {path_in_str} is detected for device api usage check, check result: " - f"{'success' if not find_invalid_device_management else f'failed, because detect {sk}'}." - ) - - assert not find_invalid_device_management, ( - f'file {path_in_str} contains .cuda/"cuda"/"nccl" usage, please use api in ' - f"verl/utils/device.py directly." - ) diff --git a/tests/special_sanity/check_docs_time_info.py b/tests/special_sanity/check_docs_time_info.py deleted file mode 100644 index a54d1d50a..000000000 --- a/tests/special_sanity/check_docs_time_info.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Check that every .md and .rst file under docs/ contains the substring "Last updated", -with an allow-list for exceptions. -""" - -import sys -from pathlib import Path - -# === CONFIGURATION === - -# Relative paths (to docs/) or glob patterns to skip checking -ALLOW_LIST = { - "docs/README.md", # you can list individual files - "docs/legacy/*.rst", # or glob patterns - "docs/index.rst", - "docs/start/install.rst", - "docs/start/quickstart.rst", - "docs/README_vllm0.7.md", -} - -# The folder to scan -DOCS_DIR = Path("docs") - -# === SCRIPT === - - -def is_allowed(path: Path) -> bool: - """ - Return True if `path` matches any entry in ALLOW_LIST. - """ - rel = str(path) - for pattern in ALLOW_LIST: - if Path(rel).match(pattern): - return True - return False - - -def main(): - if not DOCS_DIR.exists(): - print(f"Error: Documentation directory '{DOCS_DIR}' does not exist.", file=sys.stderr) - sys.exit(1) - - missing = [] - - # Gather all .md and .rst files under docs/ - for ext in ("*.md", "*.rst"): - for path in DOCS_DIR.rglob(ext): - if is_allowed(path): - continue - - text = path.read_text(encoding="utf-8", errors="ignore") - if "Last updated" not in text: - missing.append(path) - - # Report - if missing: - print("\nThe following files are missing the 'Last updated' string:\n") - for p in missing: - print(f" - {p}") - print(f"\nTotal missing: {len(missing)}\n", file=sys.stderr) - raise AssertionError( - "Some documentation files lack a 'Last updated' line. Please include info such as " - "'Last updated: mm/dd/yyyy' to indicate the last update time of the document." - ) - else: - print("✅ All checked files contain 'Last updated'.") - - -if __name__ == "__main__": - main() diff --git a/tests/special_sanity/check_docstrings.py b/tests/special_sanity/check_docstrings.py deleted file mode 100644 index 7c5d8ed71..000000000 --- a/tests/special_sanity/check_docstrings.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Python script to check docstrings for functions and classes in specified files. -Checks that every public function and class has proper docstring documentation. -""" - -import ast -import os -import sys - - -class DocstringChecker(ast.NodeVisitor): - """AST visitor to check for missing docstrings in functions and classes.""" - - def __init__(self, filename: str): - self.filename = filename - self.missing_docstrings: list[tuple[str, str, int]] = [] - self.current_class = None - self.function_nesting_level = 0 - - def visit_FunctionDef(self, node: ast.FunctionDef): - """Visit function definitions and check for docstrings.""" - if not node.name.startswith("_") and self.function_nesting_level == 0: - if not self._has_docstring(node): - func_name = f"{self.current_class}.{node.name}" if self.current_class else node.name - self.missing_docstrings.append((func_name, self.filename, node.lineno)) - - self.function_nesting_level += 1 - self.generic_visit(node) - self.function_nesting_level -= 1 - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): - """Visit async function definitions and check for docstrings.""" - if not node.name.startswith("_") and self.function_nesting_level == 0: - if not self._has_docstring(node): - func_name = f"{self.current_class}.{node.name}" if self.current_class else node.name - self.missing_docstrings.append((func_name, self.filename, node.lineno)) - - self.function_nesting_level += 1 - self.generic_visit(node) - self.function_nesting_level -= 1 - - def visit_ClassDef(self, node: ast.ClassDef): - """Visit class definitions and check for docstrings.""" - if not node.name.startswith("_"): - if not self._has_docstring(node): - self.missing_docstrings.append((node.name, self.filename, node.lineno)) - - old_class = self.current_class - self.current_class = node.name - self.generic_visit(node) - self.current_class = old_class - - def _has_docstring(self, node) -> bool: - """Check if a node has a docstring.""" - return ast.get_docstring(node) is not None - - -def check_file_docstrings(filepath: str) -> list[tuple[str, str, int]]: - """Check docstrings in a single file.""" - try: - with open(filepath, encoding="utf-8") as f: - content = f.read() - - tree = ast.parse(content, filename=filepath) - checker = DocstringChecker(filepath) - checker.visit(tree) - return checker.missing_docstrings - - except Exception as e: - print(f"Error processing {filepath}: {e}") - return [] - - -def main(): - """Main function to check docstrings in specified files.""" - - files_to_check = [ - "verl/trainer/ppo/ray_trainer.py", - "verl/trainer/main_ppo.py", - "verl/trainer/ppo/reward.py", - "verl/utils/reward_score/__init__.py", - "verl/trainer/ppo/core_algos.py", - "verl/experimental/agent_loop/agent_loop.py", - "verl/workers/sharding_manager/fsdp_vllm.py", - "verl/workers/sharding_manager/fsdp_ulysses.py", - ] - - script_dir = os.path.dirname(os.path.abspath(__file__)) - repo_path = os.path.dirname(os.path.dirname(script_dir)) - - if not os.path.exists(repo_path): - print(f"Repository path {repo_path} does not exist!") - sys.exit(1) - - os.chdir(repo_path) - - all_missing_docstrings = [] - - print("Checking docstrings in specified files...") - print("=" * 60) - - for file_path in files_to_check: - if not os.path.exists(file_path): - print(f"Warning: File {file_path} does not exist!") - continue - - print(f"Checking {file_path}...") - missing = check_file_docstrings(file_path) - all_missing_docstrings.extend(missing) - - if missing: - print(f" Found {len(missing)} missing docstrings") - else: - print(" All functions and classes have docstrings ✓") - - print("=" * 60) - - if all_missing_docstrings: - print(f"\nSUMMARY: Found {len(all_missing_docstrings)} functions/classes missing docstrings:") - print("-" * 60) - - by_file = {} - for name, filepath, lineno in all_missing_docstrings: - if filepath not in by_file: - by_file[filepath] = [] - by_file[filepath].append((name, lineno)) - - for filepath in sorted(by_file.keys()): - print(f"\n{filepath}:") - for name, lineno in sorted(by_file[filepath], key=lambda x: x[1]): - print(f" - {name} (line {lineno})") - - print(f"\nTotal missing docstrings: {len(all_missing_docstrings)}") - - raise Exception(f"Found {len(all_missing_docstrings)} functions/classes without proper docstrings!") - - else: - print("\n✅ All functions and classes have proper docstrings!") - - -if __name__ == "__main__": - main() diff --git a/tests/special_sanity/check_license.py b/tests/special_sanity/check_license.py deleted file mode 100644 index a02afeb3d..000000000 --- a/tests/special_sanity/check_license.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from argparse import ArgumentParser -from pathlib import Path - -license_head_bytedance = "Copyright 2024 Bytedance Ltd. and/or its affiliates" -license_head_bytedance_25 = "Copyright 2025 Bytedance Ltd. and/or its affiliates" -# Add custom license headers below -license_head_prime = "Copyright 2024 PRIME team and/or its affiliates" -license_head_individual = "Copyright 2025 Individual Contributor:" -license_head_sglang = "Copyright 2023-2024 SGLang Team" -license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates" -license_head_amazon = "Copyright 2025 Amazon.com Inc and/or its affiliates" -license_headers = [ - license_head_bytedance, - license_head_bytedance_25, - license_head_prime, - license_head_individual, - license_head_sglang, - license_head_modelbest, - license_head_amazon, -] - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--directory", "-d", required=True, type=str) - args = parser.parse_args() - directory_in_str = args.directory - - pathlist = Path(directory_in_str).glob("**/*.py") - for path in pathlist: - # because path is object not string - path_in_str = str(path.absolute()) - print(path_in_str) - with open(path_in_str, encoding="utf-8") as f: - file_content = f.read() - - has_license = False - for lh in license_headers: - if lh in file_content: - has_license = True - break - assert has_license, f"file {path_in_str} does not contain license" diff --git a/tests/special_sanity/check_pr_description.py b/tests/special_sanity/check_pr_description.py deleted file mode 100644 index 4ed4563db..000000000 --- a/tests/special_sanity/check_pr_description.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/usr/bin/env python3 -import json -import os - -# Number of lines to check -NUM_LINES = 5 - - -# Custom exception types for clear error handling -class TemplateFileError(Exception): - pass - - -class PRBodyLoadError(Exception): - pass - - -class PRDescriptionError(Exception): - pass - - -# Path to the PR template file -template_file = os.path.join(os.getenv("GITHUB_WORKSPACE", "."), ".github", "PULL_REQUEST_TEMPLATE.md") - - -def load_template(path): - """ - Load only the first NUM_LINES of the PR template file as a list of lines, - without stripping any characters. - """ - lines = [] - try: - with open(path, encoding="utf-8") as f: - for _ in range(NUM_LINES): - line = f.readline() - if not line: - break - lines.append(line.strip()) - return lines - except Exception as e: - raise TemplateFileError(f"Failed to read PR template (first {NUM_LINES} lines) at {path}: {e}") from e - - -def load_pr_body(event_path): - try: - with open(event_path, encoding="utf-8") as f: - payload = json.load(f) - return payload.get("pull_request", {}).get("body", "") or "" - except Exception as e: - raise PRBodyLoadError(f"Failed to read PR body from {event_path}: {e}") from e - - -def check_pr_description(body, template_lines): - """ - Compare the first NUM_LINES lines of the PR body to the template lines. - If they match exactly, the placeholder was not modified. - """ - pr_lines = body.splitlines(keepends=True) - pr_first = [x.strip() for x in pr_lines[:NUM_LINES]] - if pr_first == template_lines: - raise PRDescriptionError( - "It looks like you haven't updated the '### What does this PR do?' section. Please replace " - "the placeholder text with a concise description of what your PR does." - ) - else: - print(pr_first) - print(template_lines) - - -def main(): - event_path = os.getenv("GITHUB_EVENT_PATH") - if not event_path: - raise OSError("GITHUB_EVENT_PATH is not set.") - - template_lines = load_template(template_file) - pr_body = load_pr_body(event_path) - check_pr_description(pr_body, template_lines) - - print("✅ '### What does this PR do?' section has been filled out.") - - -if __name__ == "__main__": - main() diff --git a/tests/special_sanity/check_pr_title.py b/tests/special_sanity/check_pr_title.py deleted file mode 100644 index f4cbd5238..000000000 --- a/tests/special_sanity/check_pr_title.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import re - -# Get PR title from environment -pr_title = os.environ.get("PR_TITLE", "").strip() - -# Define rules -allowed_modules = ["fsdp", "megatron", "sglang", "vllm", "rollout", "trainer"] -allowed_modules += ["tests", "training_utils", "recipe", "hardware", "deployment"] -allowed_modules += ["ray", "worker", "single_controller", "misc", "docker", "ci"] -allowed_modules += ["perf", "model", "algo", "env", "tool", "ckpt", "doc", "data", "cfg"] -allowed_types = ["feat", "fix", "refactor", "chore", "test"] - -# Check for [BREAKING] prefix and extract the rest of the title -breaking_match = re.match(r"^\[BREAKING\]\s*(.+)$", pr_title, re.IGNORECASE) -if breaking_match: - core_pr_title = breaking_match.group(1).strip() - is_breaking = True -else: - core_pr_title = pr_title - is_breaking = False - -# Build dynamic regex pattern for modules (now working on core_pr_title) -re_modules_pattern = re.compile(r"^\[([a-z_,\s]+)\]", re.IGNORECASE) -re_modules = re_modules_pattern.match(core_pr_title) -if not re_modules: - print(f"❌ Invalid PR title: '{pr_title}'") - print("Expected format: [BREAKING][module] type: description") - print(f"Allowed modules: {', '.join(allowed_modules)}") - raise Exception("Invalid PR title") -else: - modules = re.findall(r"[a-z_]+", re_modules.group(1).lower()) - if not all(module in allowed_modules for module in modules): - invalid_modules = [module for module in modules if module not in allowed_modules] - print(f"❌ Invalid modules: {', '.join(invalid_modules)}") - print(f"Allowed modules: {', '.join(allowed_modules)}") - raise Exception("Invalid PR title") - -types_pattern = "|".join(re.escape(t) for t in allowed_types) -re_types_pattern = re.compile(rf"^\[[a-z_,\s]+\]\s+({types_pattern}):\s+.+$", re.IGNORECASE) -match = re_types_pattern.match(core_pr_title) - -if not match: - print(f"❌ Invalid PR title: '{pr_title}'") - print("Expected format: [BREAKING][module] type: description") - print(f"Allowed types: {', '.join(allowed_types)}") - raise Exception("Invalid PR title") - -change_type = match.group(1).lower() - -# Build the success message -breaking_info = " (BREAKING CHANGE)" if is_breaking else "" -print(f"✅ PR title is valid: {pr_title}, modules: {modules}, type: {change_type}{breaking_info}") diff --git a/tests/special_sanity/test_config_docs.py b/tests/special_sanity/test_config_docs.py deleted file mode 100644 index 2f260f10b..000000000 --- a/tests/special_sanity/test_config_docs.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from pathlib import Path - - -def validate_yaml_format(yaml_lines): - errors = [] - i = 0 - - while i < len(yaml_lines): - line = yaml_lines[i] - stripped = line.strip() - - # Skip empty lines - if stripped == "": - i += 1 - continue - - # Match YAML keys like "field:" or "field: value" - key_match = re.match(r"^(\s*)([a-zA-Z0-9_]+):", line) - if key_match: - # Check if there's a comment above - if i == 0 or not yaml_lines[i - 1].strip().startswith("#"): - errors.append(f"Missing comment above line {i + 1}: {line.strip()}") - - # Check for inline comment - if "#" in line and not stripped.startswith("#"): - comment_index = line.index("#") - colon_index = line.index(":") - if comment_index > colon_index: - errors.append(f"Inline comment found on line {i + 1}: {line.strip()}") - - # Check for blank line after this key line (unless next is a deeper indent) - if i + 1 < len(yaml_lines): - next_line = yaml_lines[i + 1] - next_stripped = next_line.strip() - - # If next is not empty and not a deeper nested line, enforce blank line - if next_stripped != "": - errors.append(f"Missing blank line after line {i + 1}: {line.strip()}") - - i += 1 - - return errors - - -def test_trainer_config_doc(): - yamls_to_inspect = [ - "verl/trainer/config/ppo_trainer.yaml", - "verl/trainer/config/actor/actor.yaml", - "verl/trainer/config/actor/dp_actor.yaml", - "verl/trainer/config/ref/ref.yaml", - "verl/trainer/config/ref/dp_ref.yaml", - "verl/trainer/config/rollout/rollout.yaml", - ] - success = True - for yaml_to_inspect in yamls_to_inspect: - yaml_path = Path(yaml_to_inspect) # path to your YAML file - with open(yaml_path) as f: - lines = f.readlines() - - validation_errors = validate_yaml_format(lines) - if validation_errors: - success = False - print("YAML documentation format check failed:") - print(f"Please read the top block of {yaml_to_inspect} to see format rules:\n") - for err in validation_errors: - print(" -", err) - - if not success: - raise Exception("Please fix documentation format.") - else: - print("YAML format check passed ✅") diff --git a/tests/special_sanity/test_import.py b/tests/special_sanity/test_import.py deleted file mode 100644 index 4f8a918fe..000000000 --- a/tests/special_sanity/test_import.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def test_import(): - import verl - - print(verl.__version__) - - -def test_single_controller_import(): - import verl.single_controller - - print(verl.single_controller.__version__) diff --git a/tests/special_sanity/type_coverage_check.py b/tests/special_sanity/type_coverage_check.py deleted file mode 100644 index dc6dc7caf..000000000 --- a/tests/special_sanity/type_coverage_check.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Custom type annotation check tool. -To inspect the type annotation for functions in the entire codebase, please run: -find verl -type f -name "*.py" | xargs -n 1 python3 tests/special_sanity/type_coverage_check.py --all-lines ---debug --target-file -""" - -import argparse -import ast -import linecache -import subprocess -from pathlib import Path - - -def get_changed_files() -> list[Path]: - result = subprocess.run( - ["git", "diff", "--name-only", "--diff-filter=AM", "origin/main...HEAD"], stdout=subprocess.PIPE, text=True - ) - return [Path(f) for f in result.stdout.splitlines() if f.endswith(".py")] - - -def get_changed_lines(file_path: Path) -> set[int]: - result = subprocess.run( - ["git", "diff", "-U0", "origin/main...HEAD", "--", str(file_path)], - stdout=subprocess.PIPE, - text=True, - ) - lines: set[int] = set() - for line in result.stdout.splitlines(): - if line.startswith("@@"): - for part in line.split(): - try: - if part.startswith("+") and "," in part: - start, count = map(int, part[1:].split(",")) - lines.update(range(start, start + count)) - elif part.startswith("+") and "," not in part: - lines.add(int(part[1:])) - except Exception: - # (vermouth1992) There are many edge cases here because + can be in the changed program - pass - return lines - - -CHECK_SUCCESS = 0 -CHECK_WARNING = 1 -CHECK_FAILURE = -1 - - -def should_check_type(arg_name: str) -> bool: - if arg_name in ("self", "cls"): - return False - if arg_name.startswith("*"): - return False - return True - - -def has_type_annotations(node: ast.AST, debug: bool = False) -> int: - if isinstance(node, ast.FunctionDef): - is_private = node.name.startswith("_") - has_ann = ( - all(arg.annotation is not None for arg in node.args.args if should_check_type(arg.arg)) - and node.returns is not None - ) - if has_ann or is_private: - return CHECK_SUCCESS - else: - if debug: - print(node, [(arg.annotation, arg.arg) for arg in node.args.args if should_check_type(arg.arg)]) - return CHECK_FAILURE - return CHECK_SUCCESS - - -def check_file( - file_path: Path, changed_lines: set[int], debug: bool = False -) -> tuple[int, int, list[tuple[Path, int, str]], list[tuple[Path, int, str]]]: - with open(file_path) as f: - source: str = f.read() - tree = ast.parse(source, filename=str(file_path)) - annotated = 0 - total = 0 - warning_lines: list[tuple[Path, int, str]] = [] - failure_lines: list[tuple[Path, int, str]] = [] - - for node in ast.walk(tree): - if hasattr(node, "lineno") and node.lineno in changed_lines: - if isinstance(node, ast.FunctionDef | ast.Assign | ast.AnnAssign): - total += 1 - result = has_type_annotations(node, debug) - if result == CHECK_SUCCESS or result == CHECK_WARNING: - annotated += 1 - if result == CHECK_WARNING: - warning_lines.append( - (file_path, node.lineno, linecache.getline(str(file_path), node.lineno).strip()) - ) - else: - source_line = linecache.getline(str(file_path), node.lineno).strip() - failure_lines.append((file_path, node.lineno, source_line)) - - return annotated, total, warning_lines, failure_lines - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--threshold", type=float, default=0.3, help="Minimum ratio of annotated lines required (0.0 - 1.0)" - ) - parser.add_argument("--target-file", type=str, default=None, help="Path to the Python source file to analyse") - parser.add_argument( - "--all-lines", - action="store_true", - help="Check all lines in the file instead of only changed lines based on git", - ) - parser.add_argument("--debug", action="store_true", help="Add debugging logs") - args = parser.parse_args() - - total_changed = 0 - total_annotated = 0 - all_warnings: list[tuple[Path, int, str]] = [] - all_failures: list[tuple[Path, int, str]] = [] - - target_files = [args.target_file] if args.target_file is not None else get_changed_files() - for fpath in target_files: - if "tests/" in str(fpath): - continue - if args.all_lines: - changed_lines = [i + 1 for i in range(len(open(fpath).readlines()))] - else: - changed_lines = get_changed_lines(fpath) - annotated, total, warning_lines, failure_lines = check_file(fpath, changed_lines, args.debug) - total_annotated += annotated - total_changed += total - all_warnings.extend(warning_lines) - all_failures.extend(failure_lines) - - ratio = (total_annotated / total_changed) if total_changed else 1.0 - - print( - f"🔍 Type coverage on {'all' if args.all_lines else 'changed'} lines: " - f"{total_annotated}/{total_changed} = {ratio:.2%}. Files inspected: {target_files}" - ) - - if all_warnings: - print("\n⚠️ Suggest Improve: Lines missing type annotations for inputs and outputs:\n") - for fname, lineno, line in all_warnings: - print(f"{fname}:{lineno}: {line}") - - if all_failures: - print("⚠️ [ERROR] Lines missing type annotations for inputs and outputs:\n") - for fname, lineno, line in all_failures: - print(f"{fname}:{lineno}: {line}") - - if ratio < args.threshold: - print( - f"Please add type annotations for inputs and outputs to meet threshold {args.threshold}. " - f"Cases exempt from checking:" - ) - print("1. Private methods.") - print("2. Args with name in ('self', 'cls'), or *args / **kwargs") - print("3. Files under tests/") - raise Exception(f"\n❌ Type coverage below threshold ({args.threshold:.0%}).") - else: - if all_warnings or all_failures: - print("") - print("✅ Type annotation coverage acceptable.\n") - - -if __name__ == "__main__": - main() diff --git a/tests/special_sanity/validate_imported_docs.py b/tests/special_sanity/validate_imported_docs.py deleted file mode 100644 index b36a407be..000000000 --- a/tests/special_sanity/validate_imported_docs.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -verify_imported_docs.py - -Assert that every function or class *explicitly imported* (via -`from import `) in a given Python file has a docstring. -""" - -from __future__ import annotations - -import argparse -import ast -import importlib -import inspect -import pathlib -import sys - - -def _parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser(description="Verify that imported functions/classes have docstrings.") - p.add_argument( - "--target-file", - default="verl/trainer/ppo/ray_trainer.py", - help="Path to the Python source file to analyse (e.g. verl/trainer/ppo/ray_trainer.py)", - ) - p.add_argument( - "--allow-list", - default=["omegaconf.open_dict"], - help="a list of third_party dependencies that do not have proper docs :(", - ) - p.add_argument( - "--project-root", - default=".", - help="Directory to prepend to PYTHONPATH so local packages resolve (default: .)", - ) - p.add_argument( - "--quiet", - action="store_true", - help="Suppress success message (still prints errors).", - ) - return p.parse_args() - - -def _import_attr(module_name: str, attr_name: str): - """Import `module_name` then return `getattr(module, attr_name)`.""" - module = importlib.import_module(module_name) - return getattr(module, attr_name) - - -def _check_file(py_file: pathlib.Path, project_root: pathlib.Path, allow_list: list[str]) -> list[str]: - """Return a list of error strings (empty == success).""" - # Ensure local packages resolve - sys.path.insert(0, str(project_root.resolve())) - - tree = ast.parse(py_file.read_text(), filename=str(py_file)) - problems: list[str] = [] - - for node in ast.walk(tree): - if not isinstance(node, ast.ImportFrom): - continue - - # Relative imports (level > 0) get the leading dots stripped - module_name = "." * node.level + (node.module or "") - for alias in node.names: - if alias.name == "*": - problems.append( - f"{py_file}:{node.lineno} - wildcard import `from {module_name} import *` cannot be verified." - ) - continue - - imported_name = alias.name - - try: - obj = _import_attr(module_name, imported_name) - except Exception: # pragma: no cover – wide net for import quirks - pass - # For some reason the module cannot be imported, skip for now - # problems.append( - # f"{py_file}:{node.lineno} - could not resolve " - # f"`{imported_name}` from `{module_name}` ({exc})" - # ) - continue - - if f"{module_name}.{imported_name}" in allow_list: - continue - if inspect.isfunction(obj) or inspect.isclass(obj): - doc = inspect.getdoc(obj) - if not (doc and doc.strip()): - kind = "class" if inspect.isclass(obj) else "function" - problems.append( - f"{py_file}:{node.lineno} - {kind} `{module_name}.{imported_name}` is missing a docstring." - ) - - return problems - - -def main() -> None: - args = _parse_args() - target_path = pathlib.Path(args.target_file).resolve() - project_root = pathlib.Path(args.project_root).resolve() - - if not target_path.is_file(): - raise Exception(f"❌ Target file not found: {target_path}") - - errors = _check_file(target_path, project_root, args.allow_list) - - if errors: - print("Docstring verification failed:\n") - print("\n".join(f" • {e}" for e in errors)) - raise Exception("❌ Docstring verification failed.") - - if not args.quiet: - print(f"✅ All explicitly imported functions/classes in {target_path} have docstrings.") - - -if __name__ == "__main__": - main() diff --git a/tests/special_sanity/validate_structure.py b/tests/special_sanity/validate_structure.py deleted file mode 100644 index a5390b15a..000000000 --- a/tests/special_sanity/validate_structure.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -#!/usr/bin/env python3 -""" -Validate that test file subfolders mirror the top-level package layout. - -Usage examples --------------- - -# Typical run (defaults: impl_root=my_project, tests_root=tests) -python check_tests_structure.py - -# Custom layout and extra allowed folders -python check_tests_structure.py \ - --impl-root verl \ - --tests-root tests \ - --allow-dirs special_e2e special_sanity special_standalone special_distributed -""" - -from __future__ import annotations - -import argparse -import sys -from pathlib import Path - - -def discover_allowed_modules(impl_root: Path, extra: list[str]) -> set[str]: - """Return the set of first-level directories that tests may live under.""" - allowed = {p.name for p in impl_root.iterdir() if p.is_dir()} - allowed.update(extra) - return allowed - - -def find_violations(tests_root: Path, allowed: set[str], allowed_files: list[str]) -> list[str]: - """Return a list of error strings for test files in the wrong place.""" - errors: list[str] = [] - for test_file in tests_root.rglob("test*.py"): - if str(test_file) in allowed_files: - continue - rel_parts = test_file.relative_to(tests_root).parts - if len(rel_parts) < 2: - errors.append(f"{test_file}: must be inside one of {sorted(allowed)} (not at tests root)") - continue - - first_folder = rel_parts[0] - if first_folder not in allowed: - errors.append( - f"{test_file}: subfolder '{first_folder}' under tests/ is not an allowed module. " - f"The valid ones are: {sorted(allowed)}" - ) - return errors - - -def main() -> None: - parser = argparse.ArgumentParser(description="Check that test files follow tests//… layout.") - parser.add_argument( - "--impl-root", - type=Path, - default="verl", - help="Implementation root (default: my_project)", - ) - parser.add_argument( - "--tests-root", - type=Path, - default="tests", - help="Root of test tree (default: tests)", - ) - parser.add_argument( - "--allow-dirs", - nargs="*", - default=["special_e2e", "special_sanity", "special_standalone", "special_distributed"], - help="Extra top-level test folders that are exempt from the rule", - ) - parser.add_argument( - "--allow-files", - nargs="*", - default=["tests/test_protocol_on_cpu.py", "tests/test_base_config_on_cpu.py"], - help="Extra top-level test folders that are exempt from the rule", - ) - args = parser.parse_args() - - if not args.impl_root.is_dir(): - raise Exception(f"Implementation root '{args.impl_root}' does not exist.") - if not args.tests_root.is_dir(): - raise Exception(f"Tests root '{args.tests_root}' does not exist.") - - allowed = discover_allowed_modules(args.impl_root, args.allow_dirs) - violations = find_violations(args.tests_root, allowed, args.allow_files) - - if violations: - print("❌ Test layout violations found:\n", file=sys.stderr) - for err in violations: - print(" -", err, file=sys.stderr) - - print( - f"\nGuideline:\n Place each test file under tests//…\n where is " - f"one of the top-level packages inside '{args.impl_root}', or is explicitly listed via --allow-dirs.\n", - file=sys.stderr, - ) - raise Exception("❌ Test layout violations found.") - - print("✅ Tests folder structure looks good.") - - -if __name__ == "__main__": - main() diff --git a/tests/special_standalone/README.md b/tests/special_standalone/README.md deleted file mode 100644 index 0e3596e1a..000000000 --- a/tests/special_standalone/README.md +++ /dev/null @@ -1 +0,0 @@ -The standalone test folder is reserved for tests that require dedicated environment (e.g. memory stress tests) diff --git a/tests/special_standalone/test_memory_buffers.py b/tests/special_standalone/test_memory_buffers.py deleted file mode 100644 index 778515347..000000000 --- a/tests/special_standalone/test_memory_buffers.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Test memory buffers -- We start with two models with the same weights -- We use Memory buffer to make one of the models and then compare the parameters -""" - -import gc - -import torch -from transformers import LlamaConfig, LlamaModel - - -def test_memory_buffers(): - llama_config = LlamaConfig( - vocab_size=256, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=2, - num_attention_heads=16, - num_key_value_heads=16, - ) - - model = LlamaModel(config=llama_config).cuda() - model_copy = LlamaModel(config=llama_config).cuda() - model_copy.load_state_dict(model.state_dict()) - - norm_factor = 1024**3 - - t_before = torch.cuda.get_device_properties(0).total_memory / norm_factor - r_before = torch.cuda.memory_reserved(0) / norm_factor - a_before = torch.cuda.memory_allocated(0) / norm_factor - - print(f"Before Total memory: {t_before} GB, reserved: {r_before} GB, allocated: {a_before} GB") - - t = torch.cuda.get_device_properties(0).total_memory / norm_factor - r = torch.cuda.memory_reserved(0) / norm_factor - a = torch.cuda.memory_allocated(0) / norm_factor - - gc.collect() - torch.cuda.empty_cache() - - print(f"After Total memory: {t} GB, reserved: {r} GB, allocated: {a} GB") - - change_ratio = (a - a_before) / a_before - assert change_ratio < 0.01, f"make sure the allocated change is less than 1%, Got {change_ratio}" - - for (name1, param1), (name2, param2) in zip(model.named_parameters(), model_copy.named_parameters(), strict=True): - assert name1 == name2 - assert torch.eq(param1.data, param2.data).all(), f"{param1.data}, {param2.data}, {name1}" - - -if __name__ == "__main__": - test_memory_buffers() diff --git a/tests/test_base_config_on_cpu.py b/tests/test_base_config_on_cpu.py deleted file mode 100644 index 9a50235c8..000000000 --- a/tests/test_base_config_on_cpu.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from verl.base_config import BaseConfig - - -@pytest.fixture -def base_config_mock(): - """Fixture to create a mock BaseConfig instance with test attributes.""" - mock_config = BaseConfig() - mock_config.test_attr = "test_value" - return mock_config - - -def test_getitem_success(base_config_mock): - """Test __getitem__ with existing attribute (happy path).""" - assert base_config_mock["test_attr"] == "test_value" - - -def test_getitem_nonexistent_attribute(base_config_mock): - """Test __getitem__ with non-existent attribute (exception path 1).""" - with pytest.raises(AttributeError): - _ = base_config_mock["nonexistent_attr"] - - -def test_getitem_invalid_key_type(base_config_mock): - """Test __getitem__ with invalid key type (exception path 2).""" - with pytest.raises(TypeError): - _ = base_config_mock[123] # type: ignore diff --git a/tests/test_protocol_on_cpu.py b/tests/test_protocol_on_cpu.py deleted file mode 100644 index 2052635c1..000000000 --- a/tests/test_protocol_on_cpu.py +++ /dev/null @@ -1,522 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random - -import numpy as np -import pytest -import torch -from tensordict import TensorDict - -from verl import DataProto -from verl.protocol import union_numpy_dict, union_tensor_dict - - -def test_union_tensor_dict(): - obs = torch.randn(100, 10) - - data1 = TensorDict({"obs": obs, "act": torch.randn(100, 3)}, batch_size=[100]) - data2 = TensorDict({"obs": obs, "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100]) - - data_with_copied_obs = TensorDict( - {"obs": obs.clone(), "next_obs": torch.randn(100, 10), "rew": torch.randn(100)}, batch_size=[100] - ) - - data = union_tensor_dict(data1, data2) - with pytest.raises(AssertionError): - data = union_tensor_dict(data1, data_with_copied_obs) - - data = np.random.random(100) - data2 = [float("nan") for _ in range(99)] - data2.append("nan") - data2 = np.array(data2, dtype=object) - data3 = np.tile(data2, (2, 1)) - a = {"a": data, "b": data2, "c": data3} - b = {"a": data, "b": data2, "c": data3} - b_ = {"a": np.random.random(100)} - union_numpy_dict(a, b) - with pytest.raises(AssertionError): - union_numpy_dict(a, b_) - - -def test_tensor_dict_constructor(): - obs = torch.randn(100, 10) - act = torch.randn(100, 10, 3) - data = DataProto.from_dict(tensors={"obs": obs, "act": act}) - - assert data.batch.batch_size == torch.Size([100]) - - with pytest.raises(AssertionError): - data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=2) - - with pytest.raises(AssertionError): - data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=3) - - -def test_tensor_dict_make_iterator(): - obs = torch.randn(100, 10) - labels = [random.choice(["abc", "cde"]) for _ in range(100)] - dataset = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}) - - data_iter_1 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1) - data_list_1 = [] - for data in data_iter_1: - data_list_1.append(data) - - data_iter_2 = dataset.make_iterator(mini_batch_size=10, epochs=2, seed=1) - data_list_2 = [] - for data in data_iter_2: - data_list_2.append(data) - - for data1, data2 in zip(data_list_1, data_list_2, strict=True): - assert isinstance(data1, DataProto) - assert isinstance(data2, DataProto) - result = torch.all(torch.eq(data1.batch["obs"], data2.batch["obs"])) - if not result.item(): - print(data1.batch["obs"]) - print(data2.batch["obs"]) - raise AssertionError() - non_tensor_result = np.all(np.equal(data1.non_tensor_batch["labels"], data2.non_tensor_batch["labels"])) - if not non_tensor_result.item(): - print(data1.non_tensor_batch["labels"]) - print(data2.non_tensor_batch["labels"]) - - -def test_reorder(): - obs = torch.tensor([1, 2, 3, 4, 5, 6]) - labels = ["a", "b", "c", "d", "e", "f"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"}) - data.reorder(torch.tensor([3, 4, 2, 0, 1, 5])) - - assert torch.all(torch.eq(data.batch["obs"], torch.tensor([4, 5, 3, 1, 2, 6]))) - assert np.all(data.non_tensor_batch["labels"] == np.array(["d", "e", "c", "a", "b", "f"])) - assert data.meta_info == {"name": "abdce"} - - -def test_chunk_concat(): - obs = torch.tensor([1, 2, 3, 4, 5, 6]) - labels = ["a", "b", "c", "d", "e", "f"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"}) - - with pytest.raises(AssertionError): - data.chunk(5) - - data_split = data.chunk(2) - assert len(data_split) == 2 - assert torch.all(torch.eq(data_split[0].batch["obs"], torch.tensor([1, 2, 3]))) - assert np.all(data_split[0].non_tensor_batch["labels"] == np.array(["a", "b", "c"])) - assert data_split[0].meta_info == {"name": "abdce"} - - assert torch.all(torch.eq(data_split[1].batch["obs"], torch.tensor([4, 5, 6]))) - assert np.all(data_split[1].non_tensor_batch["labels"] == np.array(["d", "e", "f"])) - assert data_split[1].meta_info == {"name": "abdce"} - - concat_data = DataProto.concat(data_split) - assert torch.all(torch.eq(concat_data.batch["obs"], data.batch["obs"])) - assert np.all(concat_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]) - assert concat_data.meta_info == data.meta_info - - -def test_pop(): - obs = torch.randn(100, 10) - act = torch.randn(100, 3) - dataset = DataProto.from_dict({"obs": obs, "act": act}, meta_info={"2": 2, "1": 1}) - poped_dataset = dataset.pop(batch_keys=["obs"], meta_info_keys=["2"]) - - assert poped_dataset.batch.keys() == {"obs"} - assert poped_dataset.meta_info.keys() == {"2"} - - assert dataset.batch.keys() == {"act"} - assert dataset.meta_info.keys() == {"1"} - - -def test_repeat(): - # Create a DataProto object with some batch and non-tensor data - obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - - # Test interleave=True - repeated_data_interleave = data.repeat(repeat_times=2, interleave=True) - expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [3, 4], [3, 4], [5, 6], [5, 6]]) - expected_labels_interleave = ["a", "a", "b", "b", "c", "c"] - - assert torch.all(torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave)) - assert (repeated_data_interleave.non_tensor_batch["labels"] == expected_labels_interleave).all() - assert repeated_data_interleave.meta_info == {"info": "test_info"} - - # Test interleave=False - repeated_data_no_interleave = data.repeat(repeat_times=2, interleave=False) - expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6]]) - expected_labels_no_interleave = ["a", "b", "c", "a", "b", "c"] - - assert torch.all(torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave)) - assert (repeated_data_no_interleave.non_tensor_batch["labels"] == expected_labels_no_interleave).all() - assert repeated_data_no_interleave.meta_info == {"info": "test_info"} - - -def test_dataproto_pad_unpad(): - obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - - from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto - - padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=2) - assert pad_size == 1 - - expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2]]) - expected_labels = ["a", "b", "c", "a"] - - assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs)) - assert (padded_data.non_tensor_batch["labels"] == expected_labels).all() - assert padded_data.meta_info == {"info": "test_info"} - - unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size) - assert torch.all(torch.eq(unpadd_data.batch["obs"], obs)) - assert (unpadd_data.non_tensor_batch["labels"] == labels).all() - assert unpadd_data.meta_info == {"info": "test_info"} - - padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=3) - assert pad_size == 0 - - expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - expected_labels = ["a", "b", "c"] - - assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs)) - assert (padded_data.non_tensor_batch["labels"] == expected_labels).all() - assert padded_data.meta_info == {"info": "test_info"} - - unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size) - assert torch.all(torch.eq(unpadd_data.batch["obs"], obs)) - assert (unpadd_data.non_tensor_batch["labels"] == labels).all() - assert unpadd_data.meta_info == {"info": "test_info"} - - padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7) - assert pad_size == 4 - - expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]]) - expected_labels = ["a", "b", "c", "a", "b", "c", "a"] - assert torch.all(torch.eq(padded_data.batch["obs"], expected_obs)) - assert (padded_data.non_tensor_batch["labels"] == expected_labels).all() - assert padded_data.meta_info == {"info": "test_info"} - - unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size) - assert torch.all(torch.eq(unpadd_data.batch["obs"], obs)) - assert (unpadd_data.non_tensor_batch["labels"] == labels).all() - assert unpadd_data.meta_info == {"info": "test_info"} - - -def test_dataproto_fold_unfold(): - from verl.protocol import DataProto, fold_batch_dim, unfold_batch_dim - - obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - - data1 = data.repeat(repeat_times=2, interleave=True) - - data2 = fold_batch_dim(data1, new_batch_size=3) - - torch.testing.assert_close(data2.batch["obs"], torch.tensor([[[1, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]])) - assert (data2.non_tensor_batch["labels"] == [["a", "a"], ["b", "b"], ["c", "c"]]).all() - - data2.reorder(indices=torch.tensor([1, 2, 0])) - - data3 = unfold_batch_dim(data2, batch_dims=2) - - torch.testing.assert_close(data3.batch["obs"], torch.tensor([[3, 4], [3, 4], [5, 6], [5, 6], [1, 2], [1, 2]])) - assert (data3.non_tensor_batch["labels"] == ["b", "b", "c", "c", "a", "a"]).all() - assert data3.meta_info == {"info": "test_info"} - - -def test_torch_save_data_proto(): - obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - data.save_to_disk("test_data.pt") - loaded_data = DataProto.load_from_disk("test_data.pt") - - assert torch.all(torch.eq(loaded_data.batch["obs"], data.batch["obs"])) - assert (loaded_data.non_tensor_batch["labels"] == data.non_tensor_batch["labels"]).all() - assert loaded_data.meta_info == data.meta_info - - import os - - os.remove("test_data.pt") - - -def test_len(): - obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = np.array(["a", "b", "c"], dtype=object) - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - - assert len(data) == 3 - - data = DataProto(batch=None, non_tensor_batch={"labels": labels}, meta_info={"info": "test_info"}) - - assert len(data) == 3 - - data = DataProto(batch=None, non_tensor_batch={}, meta_info={"info": "test_info"}) - - assert len(data) == 0 - - data = DataProto(batch=None, non_tensor_batch=None, meta_info={"info": "test_info"}) - - assert len(data) == 0 - - -def test_dataproto_index(): - data_len = 100 - idx_num = 10 - - obs = torch.randn(data_len, 10) - labels = [random.choice(["abc", "cde"]) for _ in range(data_len)] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}) - labels_np = np.array(labels) - - idx_np_int = np.random.randint(0, data_len, size=(idx_num,)) - result_np_int = data[idx_np_int] - assert result_np_int.batch.keys() == data.batch.keys() - assert result_np_int.non_tensor_batch.keys() == data.non_tensor_batch.keys() - assert result_np_int.batch["obs"].shape[0] == idx_num - assert result_np_int.non_tensor_batch["labels"].shape[0] == idx_num - assert np.array_equal(result_np_int.batch["obs"].cpu().numpy(), obs[idx_np_int].numpy()) - assert np.array_equal(result_np_int.non_tensor_batch["labels"], labels_np[idx_np_int]) - - idx_torch_int = torch.randint(0, data_len, size=(idx_num,)) - result_torch_int = data[idx_torch_int] - assert result_torch_int.batch.keys() == data.batch.keys() - assert result_torch_int.non_tensor_batch.keys() == data.non_tensor_batch.keys() - assert result_torch_int.batch["obs"].shape[0] == idx_num - assert result_torch_int.non_tensor_batch["labels"].shape[0] == idx_num - assert np.array_equal(result_torch_int.batch["obs"].cpu().numpy(), obs[idx_torch_int].cpu().numpy()) - assert np.array_equal(result_torch_int.non_tensor_batch["labels"], labels_np[idx_torch_int.cpu().numpy()]) - - idx_list_int = [np.random.randint(0, data_len) for _ in range(idx_num)] - result_list_int = data[idx_list_int] - assert result_list_int.batch.keys() == data.batch.keys() - assert result_list_int.non_tensor_batch.keys() == data.non_tensor_batch.keys() - assert result_list_int.batch["obs"].shape[0] == idx_num - assert result_list_int.non_tensor_batch["labels"].shape[0] == idx_num - assert np.array_equal(result_list_int.batch["obs"].cpu().numpy(), obs[idx_list_int].cpu().numpy()) - assert np.array_equal(result_list_int.non_tensor_batch["labels"], labels_np[idx_list_int]) - - idx_np_bool = np.random.randint(0, 2, size=(data_len,), dtype=bool) - result_np_bool = data[idx_np_bool] - assert result_np_bool.batch.keys() == data.batch.keys() - assert result_np_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys() - assert result_np_bool.batch["obs"].shape[0] == idx_np_bool.sum() - assert result_np_bool.non_tensor_batch["labels"].shape[0] == idx_np_bool.sum() - assert np.array_equal(result_np_bool.batch["obs"].cpu().numpy(), obs[idx_np_bool].cpu().numpy()) - assert np.array_equal(result_np_bool.non_tensor_batch["labels"], labels_np[idx_np_bool]) - - idx_torch_bool = torch.randint(0, 2, size=(data_len,), dtype=torch.bool) - result_torch_bool = data[idx_torch_bool] - assert result_torch_bool.batch.keys() == data.batch.keys() - assert result_torch_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys() - assert result_torch_bool.batch["obs"].shape[0] == idx_torch_bool.sum().item() - assert result_torch_bool.non_tensor_batch["labels"].shape[0] == idx_torch_bool.sum().item() - assert np.array_equal(result_torch_bool.batch["obs"].cpu().numpy(), obs[idx_torch_bool].cpu().numpy()) - assert np.array_equal(result_torch_bool.non_tensor_batch["labels"], labels_np[idx_torch_bool]) - - idx_list_bool = [np.random.randint(0, 2, dtype=bool) for _ in range(data_len)] - result_list_bool = data[idx_list_bool] - assert result_list_bool.batch.keys() == data.batch.keys() - assert result_list_bool.non_tensor_batch.keys() == data.non_tensor_batch.keys() - assert result_list_bool.batch["obs"].shape[0] == sum(idx_list_bool) - assert result_list_bool.non_tensor_batch["labels"].shape[0] == sum(idx_list_bool) - assert np.array_equal(result_list_bool.batch["obs"].cpu().numpy(), obs[idx_list_bool].cpu().numpy()) - assert np.array_equal(result_list_bool.non_tensor_batch["labels"], labels_np[idx_list_bool]) - - -def test_old_vs_new_from_single_dict(): - class CustomProto(DataProto): - """Uses the new, fixed from_single_dict.""" - - pass - - class OriginProto(DataProto): - """Mimics the *old* from_single_dict (always returns a DataProto).""" - - @classmethod - def from_single_dict(cls, data, meta_info=None, auto_padding=False): - tensors, non_tensors = {}, {} - for k, v in data.items(): - if torch.is_tensor(v): - tensors[k] = v - else: - non_tensors[k] = v - # always calls DataProto.from_dict, ignoring `cls` - return DataProto.from_dict( - tensors=tensors, - non_tensors=non_tensors, - meta_info=meta_info, - auto_padding=auto_padding, - ) - - sample = {"x": torch.tensor([0])} - - orig = OriginProto.from_single_dict(sample) - # old behavior: always DataProto, not a CustomOriginProto - assert type(orig) is DataProto - assert type(orig) is not OriginProto - - cust = CustomProto.from_single_dict(sample) - # new behavior: respects subclass - assert type(cust) is CustomProto - - -def test_dataproto_no_batch(): - labels = ["a", "b", "c"] - data = DataProto.from_dict(non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - selected = data.select(non_tensor_batch_keys=["labels"]) - assert (selected.non_tensor_batch["labels"] == labels).all() - pop_data = data.pop(non_tensor_batch_keys=["labels"]) - assert (pop_data.non_tensor_batch["labels"] == labels).all() - assert data.non_tensor_batch == {} - - -def test_sample_level_repeat(): - # Create a DataProto object with some batch and non-tensor data - obs = torch.tensor([[1, 2], [3, 4], [5, 6]]) - labels = ["a", "b", "c"] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"info": "test_info"}) - - # list - repeated_data_interleave = data.sample_level_repeat(repeat_times=[3, 1, 2]) - expected_obs_interleave = torch.tensor([[1, 2], [1, 2], [1, 2], [3, 4], [5, 6], [5, 6]]) - expected_labels_interleave = ["a", "a", "a", "b", "c", "c"] - - assert torch.all(torch.eq(repeated_data_interleave.batch["obs"], expected_obs_interleave)) - assert (repeated_data_interleave.non_tensor_batch["labels"] == expected_labels_interleave).all() - assert repeated_data_interleave.meta_info == {"info": "test_info"} - - # torch.tensor - repeated_data_no_interleave = data.sample_level_repeat(repeat_times=torch.tensor([1, 2, 3])) - expected_obs_no_interleave = torch.tensor([[1, 2], [3, 4], [3, 4], [5, 6], [5, 6], [5, 6]]) - expected_labels_no_interleave = ["a", "b", "b", "c", "c", "c"] - - assert torch.all(torch.eq(repeated_data_no_interleave.batch["obs"], expected_obs_no_interleave)) - assert (repeated_data_no_interleave.non_tensor_batch["labels"] == expected_labels_no_interleave).all() - assert repeated_data_no_interleave.meta_info == {"info": "test_info"} - - -def test_dataproto_unfold_column_chunks(): - obs1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) - obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]]) - - labels = ["a", "b", "c"] - data = DataProto.from_dict( - tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} - ) - ret = data.unfold_column_chunks(2, split_keys=["obs1"]) - - expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) - expect_obs2 = torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]]) - expect_labels = ["a", "a", "b", "b", "c", "c"] - assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1)) - assert torch.all(torch.eq(ret.batch["obs2"], expect_obs2)) - assert (ret.non_tensor_batch["labels"] == expect_labels).all() - assert ret.meta_info == {"name": "abc"} - - obs1 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) - obs2 = torch.tensor([[1, 2], [5, 6], [9, 10]]) - - labels = [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]] - data = DataProto.from_dict( - tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} - ) - ret = data.unfold_column_chunks(2, split_keys=["obs1", "labels"]) - - expect_obs1 = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) - expect_obs2 = torch.tensor([[1, 2], [1, 2], [5, 6], [5, 6], [9, 10], [9, 10]]) - expect_labels = [["a1"], ["a2"], ["b1"], ["b2"], ["c1"], ["c2"]] - assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1)) - assert torch.all(torch.eq(ret.batch["obs2"], expect_obs2)) - assert (ret.non_tensor_batch["labels"] == expect_labels).all() - assert ret.meta_info == {"name": "abc"} - - obs1 = torch.tensor( - [[[1, 1], [2, 2], [3, 3], [4, 4]], [[5, 5], [6, 6], [7, 7], [8, 8]], [[9, 9], [10, 10], [11, 11], [12, 12]]] - ) - obs2 = torch.tensor([[[1, 1], [2, 2]], [[5, 5], [6, 6]], [[9, 9], [10, 10]]]) - - labels = ["a", "b", "c"] - data = DataProto.from_dict( - tensors={"obs1": obs1, "obs2": obs2}, non_tensors={"labels": labels}, meta_info={"name": "abc"} - ) - ret = data.unfold_column_chunks(2, split_keys=["obs1"]) - - expect_obs1 = torch.tensor( - [ - [[1, 1], [2, 2]], - [[3, 3], [4, 4]], - [[5, 5], [6, 6]], - [[7, 7], [8, 8]], - [[9, 9], [10, 10]], - [[11, 11], [12, 12]], - ] - ) - expect_obs2 = torch.tensor( - [[[1, 1], [2, 2]], [[1, 1], [2, 2]], [[5, 5], [6, 6]], [[5, 5], [6, 6]], [[9, 9], [10, 10]], [[9, 9], [10, 10]]] - ) - expect_labels = ["a", "a", "b", "b", "c", "c"] - assert torch.all(torch.eq(ret.batch["obs1"], expect_obs1)) - assert torch.all(torch.eq(ret.batch["obs2"], expect_obs2)) - assert (ret.non_tensor_batch["labels"] == expect_labels).all() - assert ret.meta_info == {"name": "abc"} - - -def test_dataproto_chunk_after_index(): - data_len = 4 - obs = torch.randn(data_len, 4) - labels = [f"label_{i}" for i in range(data_len)] - data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abc"}) - - # Test with boolean numpy array - bool_mask = np.array([True, False, True, False]) - selected = data[bool_mask] - assert isinstance(selected.batch.batch_size, torch.Size) - assert all(isinstance(d, int) for d in selected.batch.batch_size) # int or List[int] - - # Test with integer numpy array - int_mask = np.array([0, 2]) - selected = data[int_mask] - assert isinstance(selected.batch.batch_size, torch.Size) - assert all(isinstance(d, int) for d in selected.batch.batch_size) - - # Test with boolean list - list_mask = [True, False, True, False] - selected = data[list_mask] - assert isinstance(selected.batch.batch_size, torch.Size) - assert all(isinstance(d, int) for d in selected.batch.batch_size) - - # Test with list - list_mask = [0, 2] - selected = data[list_mask] - assert isinstance(selected.batch.batch_size, torch.Size) - assert all(isinstance(d, int) for d in selected.batch.batch_size) - - # Test with torch tensor (bool) - torch_bool_mask = torch.tensor([True, False, True, False]) - selected = data[torch_bool_mask] - assert isinstance(selected.batch.batch_size, torch.Size) - assert all(isinstance(d, int) for d in selected.batch.batch_size) - - # Test with torch tensor (int) - torch_int_mask = torch.tensor([0, 2]) - selected = data[torch_int_mask] - assert isinstance(selected.batch.batch_size, torch.Size) - assert all(isinstance(d, int) for d in selected.batch.batch_size) diff --git a/tests/tools/test_base_tool_on_cpu.py b/tests/tools/test_base_tool_on_cpu.py deleted file mode 100644 index 63a2bbb37..000000000 --- a/tests/tools/test_base_tool_on_cpu.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Unit Tests for `initialize_tools_from_config` -import json -import os -from typing import Any - -import pytest -from transformers.utils import get_json_schema - -from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema -from verl.tools.utils.tool_registry import initialize_tools_from_config - - -class WeatherToolForTest(BaseTool): - def get_current_temperature(self, location: str, unit: str = "celsius"): - """Get current temperature at a location. - - Args: - location: The location to get the temperature for, in the format "City, State, Country". - unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) - - Returns: - the temperature, the location, and the unit in a dict - """ - return { - "temperature": 26.1, - "location": location, - "unit": unit, - } - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - schema = get_json_schema(self.get_current_temperature) - return OpenAIFunctionToolSchema(**schema) - - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - try: - result = self.get_current_temperature(**parameters) - return json.dumps(result), 0, {} - except Exception as e: - return str(e), 0, {} - - -class WeatherToolWithDataForTest(BaseTool): - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - schema = get_json_schema(self.get_temperature_date) - return OpenAIFunctionToolSchema(**schema) - - def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): - """Get temperature at a location and date. - - Args: - location: The location to get the temperature for, in the format "City, State, Country". - date: The date to get the temperature for, in the format "Year-Month-Day". - unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) - - Returns: - the temperature, the location, the date and the unit in a dict - """ - return { - "temperature": 25.9, - "location": location, - "date": date, - "unit": unit, - } - - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - try: - result = self.get_temperature_date(**parameters) - return json.dumps(result), 0, {} - except Exception as e: - return str(e), 0, {} - - -@pytest.fixture -def create_local_tool_config(): - tool_config = { - "tools": [ - { - "class_name": "tests.tools.test_base_tool_on_cpu.WeatherToolForTest", - "config": {"type": "native"}, - }, - { - "class_name": "tests.tools.test_base_tool_on_cpu.WeatherToolWithDataForTest", - "config": {"type": "native"}, - }, - ] - } - tool_config_path = "/tmp/tool_config.json" - with open(tool_config_path, "w") as f: - json.dump(tool_config, f) - yield tool_config_path - if os.path.exists(tool_config_path): - os.remove(tool_config_path) - - -@pytest.fixture -def create_fake_tool_config(): - tool_config = { - "tools": [ - { - "class_name": "tests.workers.rollout.fake_path.test_vllm_chat_scheduler.WeatherTool", - "config": {"type": "native"}, - }, - { - "class_name": "tests.workers.rollout.fake_path.test_vllm_chat_scheduler.WeatherToolWithData", - "config": {"type": "native"}, - }, - ] - } - tool_config_path = "/tmp/tool_config.json" - with open(tool_config_path, "w") as f: - json.dump(tool_config, f) - yield tool_config_path - if os.path.exists(tool_config_path): - os.remove(tool_config_path) - - -def test_initialize_tools_from_fake_config(create_fake_tool_config): - tool_config_path = create_fake_tool_config - - # Use pytest.raises to check if an exception is raised when calling initialize_tools_from_config. - # Since the tool configuration uses fake paths, an exception is expected during the tool initialization process. - with pytest.raises(ModuleNotFoundError): - _ = initialize_tools_from_config(tool_config_path) - - -def test_initialize_tools_from_local_config(create_local_tool_config): - """ - Test the `initialize_tools_from_config` function using a local tool configuration. - This test verifies that the function can correctly initialize tools based on a local configuration file. - - Args: - create_local_tool_config: A pytest fixture that creates a local tool configuration file - and returns its path. After the test is completed, the fixture - will clean up the configuration file. - """ - # Retrieve the path of the local tool configuration file generated by the fixture - tool_config_path = create_local_tool_config - - tools = initialize_tools_from_config(tool_config_path) - - assert len(tools) == 2 - from tests.tools.test_base_tool_on_cpu import WeatherToolForTest, WeatherToolWithDataForTest - - assert isinstance(tools[0], WeatherToolForTest) - assert isinstance(tools[1], WeatherToolWithDataForTest) - assert tools[0].config == {"type": "native"} - assert tools[1].config == {"type": "native"} diff --git a/tests/trainer/__init__.py b/tests/trainer/__init__.py deleted file mode 100644 index 6f79d474d..000000000 --- a/tests/trainer/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for the trainer module. -""" diff --git a/tests/trainer/config/__init__.py b/tests/trainer/config/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/tests/trainer/config/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml deleted file mode 100644 index fc146c934..000000000 --- a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml +++ /dev/null @@ -1,472 +0,0 @@ -data: - tokenizer: null - train_files: ~/data/rlhf/gsm8k/train.parquet - val_files: ~/data/rlhf/gsm8k/test.parquet - prompt_key: prompt - reward_fn_key: data_source - max_prompt_length: 512 - max_response_length: 512 - train_batch_size: 1024 - val_batch_size: null # DEPRECATED: Validation datasets are sent to inference engines as a whole batch, which will schedule the memory themselves - return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs - return_raw_chat: False - return_full_prompt: False - shuffle: True - filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up. - filter_overlong_prompts_workers: 1 - truncation: error - trust_remote_code: False # main_ppo will check this config to determine whether to use remote code for tokenizer - custom_cls: - path: null - name: null - sampler: - class_path: null - class_name: null - dataloader_num_workers: 8 - return_multi_modal_inputs: True - -actor_rollout_ref: - hybrid_engine: True - nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron - model: - path: ~/models/deepseek-llm-7b-chat - custom_chat_template: null - external_lib: null - override_config: - model_config: {} - moe_config: - freeze_moe_router: False - enable_gradient_checkpointing: False - gradient_checkpointing_kwargs: - ## Activation Checkpointing - activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective' - # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk - # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity - activations_checkpoint_granularity: null # 'selective' or 'full' - # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention - activations_checkpoint_num_layers: null # not used with 'selective' - trust_remote_code: False - actor: - strategy: megatron # This is for backward-compatibility - ppo_mini_batch_size: 256 - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: False - ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} - use_torch_compile: True # False to disable torch compile - # pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) - clip_ratio: 0.2 # default value if clip_ratio_low and clip_ratio_high are not specified - clip_ratio_low: 0.2 - clip_ratio_high: 0.2 - clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729 - loss_agg_mode: "token-mean" # / "seq-mean-token-sum" / "seq-mean-token-mean" - # NOTE: "token-mean" is the default behavior - entropy_coeff: 0 - use_kl_loss: False # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo - ppo_epochs: 1 - data_loader_seed: null - shuffle: False - policy_loss: # policy loss config - loss_mode: "vanilla" # Loss function mode: vanilla / clip-cov / kl-cov / gpg from https://arxiv.org/abs/2505.22617, - clip_cov_ratio: 0.0002 # Ratio of tokens to be clipped for clip-cov loss - clip_cov_lb: 1.0 # Lower bound for clip-cov loss - clip_cov_ub: 5.0 # Upper bound for clip-cov loss - kl_cov_ratio: 0.0002 # Ratio of tokens to be applied kl penalty for kl-cov loss - ppo_kl_coef: 0.1 # KL divergence penalty coefficient - optim: - optimizer: adam - lr: 1e-6 - clip_grad: 1.0 - total_training_steps: -1 # must be override by program - lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0 - lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - lr_decay_steps: null - lr_decay_style: constant # select from constant/linear/cosine/inverse_square_root - min_lr: 0.0 # minimum learning rate, default to 0.0 - weight_decay: 0.01 - weight_decay_incr_style: constant # select from constant/linear/cosine - lr_wsd_decay_style: exponential # select from constant/exponential/cosine - lr_wsd_decay_steps: null - use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler - megatron: - param_offload: False - grad_offload: False - optimizer_offload: False - tensor_model_parallel_size: 1 - expert_model_parallel_size: 1 - expert_tensor_parallel_size: null - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - context_parallel_size: 1 - sequence_parallel: True - use_distributed_optimizer: True - use_dist_checkpointing: False - dist_checkpointing_path: null - seed: 42 - override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage - use_mbridge: False - profile: # profile the actor model in `update_policy` - use_profile: False # open it when you want to profile the actor model - profile_ranks: null # list, you can specify the ranks to profile - step_start: -1 # start step in update_policy - step_end: -1 # end step - save_path: null # the path to save the profile result - load_weight: True - checkpoint: - async_save: False # save checkpoint asynchronously - # What to include in saved checkpoints - # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - save_contents: ['model', 'optimizer', 'extra'] - # For more flexibility, you can specify the contents to load from the checkpoint. - load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} - ref: - strategy: ${actor_rollout_ref.actor.strategy} - use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} - megatron: - param_offload: False - tensor_model_parallel_size: 1 - expert_model_parallel_size: 1 - expert_tensor_parallel_size: None - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - context_parallel_size: 1 - sequence_parallel: True - use_distributed_optimizer: False - use_dist_checkpointing: False - dist_checkpointing_path: null - seed: ${actor_rollout_ref.actor.megatron.seed} - override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} - use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} - profile: - use_profile: False - profile_ranks: null - step_start: -1 - step_end: -1 - save_path: null - load_weight: True - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - rollout: - name: vllm - mode: sync # sync: LLM, async: AsyncLLM - temperature: 1.0 - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1 - prompt_length: ${data.max_prompt_length} # for xperf_gpt - response_length: ${data.max_response_length} - # for vllm rollout - dtype: bfloat16 # should align with FSDP - gpu_memory_utilization: 0.5 - ignore_eos: False - enforce_eager: True - free_cache_engine: True - load_format: dummy_megatron - tensor_model_parallel_size: 1 - max_num_batched_tokens: 8192 - max_model_len: null - max_num_seqs: 1024 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - disable_log_stats: True - enable_chunked_prefill: False # could get higher throughput - # for hf rollout - do_sample: True - layer_name_map: - qkv_layer_name: qkv - gate_proj_layer_name: gate_up - # number of responses (i.e. num sample times) - n: 1 - engine_kwargs: # inference engine parameters - vllm: - swap_space: null # null means "use the engine default value" (usually 4 GB), setting it to, e.g., 32 means 32 GB - disable_mm_preprocessor_cache: False # whether to disable the preprocessor cache for multimodel models. - sglang: - attention_backend: null # null means use the engine default value, available options: flashinfer, triton, flashmla - val_kwargs: - # sampling parameters for validation - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1.0 - temperature: 0 - n: 1 - do_sample: False # default eager for validation - - # Multi-turn interaction config for tools or chat. - multi_turn: - # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well - enable: False - - # null for no limit (default max_length // 3) - max_assistant_turns: null - - # null for no tool - tool_config_path: null - - # null for no limit (default max_length // 3) - max_user_turns: null - - # max parallel call for tools in single turn - max_parallel_calls: 1 - - # max length of tool response - max_tool_response_length: 256 - - # truncate side of tool response: left, middle, right - tool_response_truncate_side: middle - - # null for no interaction - interaction_config_path: null - - # null for default callback - completion_callback: null - - # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. - # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, - # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. - use_inference_chat_template: False - - # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. - # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. - # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. - # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: - # Qwen/QwQ-32B, Qwen/Qwen3-xxB - # - disable: disable tokenization sanity check - # - strict: enable strict tokenization sanity check (default) - # - ignore_strippable: ignore strippable tokens when checking tokenization sanity - tokenization_sanity_check_mode: strict - - # Format of the multi-turn interaction. Options: hermes, llama3_json, ... - format: hermes - - # [Experimental] agent loop based rollout configs - agent: - - # Number of agent loop workers - num_workers: 8 - - custom_async_server: - path: null - name: null - - # support logging rollout prob for debugging purpose - calculate_log_probs: False - # Nsight system profiler configs - profiler: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.utils.profiler.ProfilerConfig - discrete: False - all_ranks: False - ranks: [] - -critic: - rollout_n: ${actor_rollout_ref.rollout.n} - strategy: ${actor_rollout_ref.actor.strategy} - nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron - optim: - optimizer: adam - lr: 1e-6 - clip_grad: 1.0 - total_training_steps: -1 # must be override by program - lr_warmup_init: 0.0 # initial learning rate for warmup, default to 0.0 - lr_warmup_steps: null # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - lr_decay_steps: null - lr_decay_style: linear # select from constant/linear/cosine/inverse_square_root - min_lr: 0.0 # minimum learning rate, default to 0.0 - weight_decay: 0.01 - weight_decay_incr_style: constant # select from constant/linear/cosine - lr_wsd_decay_style: exponential # select from constant/exponential/cosine - lr_wsd_decay_steps: null - use_checkpoint_opt_param_scheduler: False # use checkpoint optimizer parameter scheduler - model: - path: ~/models/deepseek-llm-7b-chat - tokenizer_path: ${actor_rollout_ref.model.path} - override_config: - model_config: {} - moe_config: - freeze_moe_router: False - external_lib: ${actor_rollout_ref.model.external_lib} - trust_remote_code: False - enable_gradient_checkpointing: False - gradient_checkpointing_kwargs: - ## Activation Checkpointing - activations_checkpoint_method: null - activations_checkpoint_granularity: null - activations_checkpoint_num_layers: null - megatron: - param_offload: False - grad_offload: False - optimizer_offload: False - tensor_model_parallel_size: 1 - expert_model_parallel_size: 1 - expert_tensor_parallel_size: null - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - context_parallel_size: 1 - sequence_parallel: True - use_distributed_optimizer: True - use_dist_checkpointing: False - dist_checkpointing_path: null - seed: ${actor_rollout_ref.actor.megatron.seed} - override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} - use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} - load_weight: True - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 - forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} - ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} - data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed} - shuffle: ${actor_rollout_ref.actor.shuffle} - cliprange_value: 0.5 - kl_ctrl: - type: fixed - kl_coef: 0.001 - loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} - checkpoint: - async_save: False # save checkpoint asynchronously - # What to include in saved checkpoints - # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - save_contents: ['model', 'optimizer', 'extra'] - load_contents: ${critic.checkpoint.save_contents} - # Nsight system profiler configs - profiler: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.utils.profiler.ProfilerConfig - discrete: False - all_ranks: False - ranks: [] -reward_model: - enable: False - strategy: ${actor_rollout_ref.actor.strategy} - nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron - megatron: - param_offload: False - tensor_model_parallel_size: 1 - expert_model_parallel_size: 1 - expert_tensor_parallel_size: null - pipeline_model_parallel_size: 1 - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - context_parallel_size: 1 - sequence_parallel: True - use_distributed_optimizer: False - use_dist_checkpointing: False - dist_checkpointing_path: null - seed: ${actor_rollout_ref.actor.megatron.seed} - override_transformer_config: {} - use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} - model: - input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - trust_remote_code: False - external_lib: ${actor_rollout_ref.model.external_lib} - load_weight: True - micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu - micro_batch_size_per_gpu: null - use_dynamic_bsz: ${critic.use_dynamic_bsz} - forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} - max_length: null - reward_manager: naive - launch_reward_fn_async: False # custom reward function executed async on CPU, during log_prob - sandbox_fusion: - url: null # faas url to run code in cloud sandbox - max_concurrent: 64 # max concurrent requests to sandbox - memory_limit_mb: 1024 # Max memory limit for each sandbox process in MB - # Nsight system profiler configs - profiler: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.utils.profiler.ProfilerConfig - discrete: False - all_ranks: False - ranks: [] - -custom_reward_function: - path: null - name: compute_score - -algorithm: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.AlgoConfig - gamma: 1.0 - lam: 1.0 - adv_estimator: gae - norm_adv_by_std_in_grpo: True - use_kl_in_reward: False - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.KLControlConfig - type: fixed - kl_coef: 0.001 - horizon: 10000 - target_kl: 0.1 - use_pf_ppo: False - pf_ppo: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.PFPPOConfig - reweight_method: pow # ["pow", "max_min", "max_random"] - weight_pow: 2.0 - -trainer: - balance_batch: True - total_epochs: 30 - total_training_steps: null - profile_steps: null # [1,2,5] or [] or null - project_name: verl_examples - experiment_name: gsm8k - logger: ['console', 'wandb'] - log_val_generations: 0 - nnodes: 1 - n_gpus_per_node: 8 - save_freq: -1 - esi_redundant_time: 0 - - # auto: find the last ckpt to resume. If can't find, start from scratch - resume_mode: auto # or disable or resume_path if resume_from_path is set - resume_from_path: null - del_local_ckpt_after_load: False - val_before_train: True - test_freq: -1 - critic_warmup: 0 - default_hdfs_dir: null - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} - max_actor_ckpt_to_keep: null - max_critic_ckpt_to_keep: null - # The timeout for ray worker group to wait for the register center to be ready - ray_wait_register_center_timeout: 300 - device: cuda - # see ppo_trainer.yaml for more details - controller_nsight_options: - trace: "cuda,nvtx,cublas,ucx" - cuda-memory-usage: "true" - cuda-graph-trace: "graph" - worker_nsight_options: - trace: "cuda,nvtx,cublas,ucx" - cuda-memory-usage: "true" - cuda-graph-trace: "graph" - capture-range: "cudaProfilerApi" - capture-range-end: null - kill: none - npu_profile: - options: - save_path: ./profiler_data - level: level1 - with_memory: False - record_shapes: False - with_npu: True - with_cpu: True - with_module: False - with_stack: False - analysis: True - -ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. - timeline_json_file: null diff --git a/tests/trainer/config/legacy_ppo_trainer.yaml b/tests/trainer/config/legacy_ppo_trainer.yaml deleted file mode 100644 index 8ba94e204..000000000 --- a/tests/trainer/config/legacy_ppo_trainer.yaml +++ /dev/null @@ -1,1111 +0,0 @@ -# Format checks enforced on CI: -# 1. Comments must appear above each field. -# 2. There must be a blank line between each field. -# 3. Inline comments (after a field on the same line) are not allowed. -# 4. Indentation level is respected for nested fields. - -# dataset config -data: - - # Tokenizer class or path. If null, it will be inferred from the model. - tokenizer: null - - # Whether to use shared memory for data loading. - use_shm: False - - # Training set parquet. Can be a list or a single file. - # The program will read all files into memory, so it can't be too large (< 100GB). - # The path can be either a local path or an HDFS path. - # For HDFS path, we provide utils to download it to DRAM and convert it to a local path. - train_files: ~/data/rlhf/gsm8k/train.parquet - - # Validation parquet. Can be a list or a single file. - val_files: ~/data/rlhf/gsm8k/test.parquet - - # The field in the dataset where the prompt is located. Default is 'prompt'. - prompt_key: prompt - - # The field used to select the reward function (if using different ones per example). - reward_fn_key: data_source - - # Maximum prompt length. All prompts will be left-padded to this length. - # An error will be reported if the length is too long. - max_prompt_length: 512 - - # Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length. - max_response_length: 512 - - # Batch size sampled for one training iteration of different RL algorithms. - train_batch_size: 1024 - - # Batch size used during validation. Can be null. - val_batch_size: null - - # Whether to return the original input_ids without adding chat template. - # This is used when the reward model's chat template differs from the policy. - # If using a model-based RM with different templates, this should be True. - return_raw_input_ids: False - - # Whether to return the original chat (prompt) without applying chat template. - return_raw_chat: False - - # Whether to return the full prompt with chat template. - return_full_prompt: False - - # Whether to shuffle the data in the dataloader. - shuffle: True - - # num dataloader workers - dataloader_num_workers: 8 - - # Whether to shuffle the validation set. - validation_shuffle: False - - # Whether to filter overlong prompts. - filter_overlong_prompts: False - - # Number of workers for filtering overlong prompts. - # For large-scale datasets, filtering can be time-consuming. - # Use multiprocessing to speed up. Default is 1. - filter_overlong_prompts_workers: 1 - - # Truncate the input_ids or prompt if they exceed max_prompt_length. - # Options: 'error', 'left', or 'right'. Default is 'error'. - truncation: error - - # The field in the multi-modal dataset where the image is located. Default is 'images'. - image_key: images - - # The field in the multi-modal dataset where the video is located. - video_key: videos - - # If the remote tokenizer has a Python file, this flag determines whether to allow using it. - trust_remote_code: False - - # Optional: specify a custom dataset class path and name if overriding default loading behavior. - custom_cls: - - # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used. - path: null - - # The name of the dataset class within the specified file. - name: null - - # Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs. - return_multi_modal_inputs: True - - # Data generation configuration for augmenting the dataset. - datagen: - - # The path to the file containing your customized data generation class. - # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset' - path: null - - # The class name of the data generation class within the specified file. - # E.g. 'MockDataGenerator' - name: null - - # settings related to data sampler - sampler: - - # the path to the module containing a curriculum class which implements the - # AbstractSampler interface - class_path: null - - # the name of the curriculum class like `MySampler` - class_name: null - -# config for actor, rollout and reference model -actor_rollout_ref: - - # Whether it's a hybrid engine, currently only supports hybrid engine - hybrid_engine: true - - # common configs for the model - model: - - # Huggingface model path. This can be either local path or HDFS path. - path: ~/models/deepseek-llm-7b-chat - - # Custom chat template for the model. - custom_chat_template: null - - # Whether to use shared memory (SHM) for accelerating the loading of model weights - use_shm: false - - # Additional Python packages to register huggingface models/tokenizers. - external_lib: null - - # Used to override model's original configurations, mainly dropout - override_config: {} - - # Enable gradient checkpointing for actor - enable_gradient_checkpointing: true - - # Enable activation offloading for actor - enable_activation_offload: false - - # Whether to remove padding tokens in inputs during training - use_remove_padding: false - - # Set to positive value to enable LoRA (e.g., 32) - lora_rank: 0 - - # LoRA scaling factor - lora_alpha: 16 - - # Target modules to apply LoRA. Options: "all-linear" (not recommended for VLMs) or - # [q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj] - target_modules: all-linear - - # Exclude modules from applying Lora. Similar usage to target_modules and Peft. - # Example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as currently vllm does not support ViT Lora. - exclude_modules: null - - # Whether to use Liger for linear layer fusion - use_liger: false - - # Whether to use custom fused kernels (e.g., FlashAttention, fused MLP) - use_fused_kernels: false - - # Options for fused kernels. If use_fused_kernels is true, this will be used. - fused_kernel_options: - - # Implementation backend for fused kernels. Options: "triton" or "torch". - impl_backend: torch - - # Whether to enable loading a remote code model - trust_remote_code: false - - # actor configs - actor: - - # fsdp, fsdp2 or megatron. fsdp backend used here. - strategy: fsdp - - # Split each sample into sub-batches of this size for PPO - ppo_mini_batch_size: 256 - - # [Deprecated] Global micro batch size - ppo_micro_batch_size: null - - # Local per-GPU micro batch size - ppo_micro_batch_size_per_gpu: null - - # Whether to automatically adjust batch size at runtime - use_dynamic_bsz: false - - # Max tokens per GPU in one PPO batch; affects gradient accumulation - # Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length} - ppo_max_token_len_per_gpu: 16384 - - # Gradient clipping for actor updates - grad_clip: 1.0 - - # PPO clip ratio - clip_ratio: 0.2 - - # Lower bound for asymmetric clipping (used in dual-clip PPO) - clip_ratio_low: 0.2 - - # Upper bound for asymmetric clipping (used in dual-clip PPO) - clip_ratio_high: 0.2 - - # policy loss config - policy_loss: - - # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 - loss_mode: "vanilla" - - # Ratio of tokens to be clipped for clip-cov loss - clip_cov_ratio: 0.0002 - - # Lower bound for clip-cov loss - clip_cov_lb: 1.0 - - # Upper bound for clip-cov loss - clip_cov_ub: 5.0 - - # Ratio of tokens to be applied kl penalty for kl-cov loss - kl_cov_ratio: 0.0002 - - # KL divergence penalty coefficient - ppo_kl_coef: 0.1 - - # Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C - clip_ratio_c: 3.0 - - # Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" - loss_agg_mode: token-mean - - # Entropy regularization coefficient in PPO loss - entropy_coeff: 0 - - # Whether to use KL loss instead of KL reward penalty. True for GRPO - use_kl_loss: false - - # Whether to use torch.compile() - use_torch_compile: true - - # KL loss coefficient when use_kl_loss is enabled. For GRPO - kl_loss_coef: 0.001 - - # Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full" - kl_loss_type: low_var_kl - - # Number of PPO epochs per batch - ppo_epochs: 1 - - # Shuffle training data across PPO epochs - shuffle: false - - # Sequence parallelism size for Ulysses-style model parallelism - ulysses_sequence_parallel_size: 1 - - # calculate entropy with chunking to reduce memory peak - entropy_from_logits_with_chunking: False - - # recompute entropy - entropy_checkpointing: False - - # checkpoint configs - checkpoint: - - # What to include in saved checkpoints - # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - save_contents: ['model', 'optimizer', 'extra'] - - # For more flexibility, you can specify the contents to load from the checkpoint. - load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents} - - # optimizer configs - optim: - - # Learning rate - lr: 1e-6 - - # Warmup steps; negative value delegates to lr_warmup_steps_ratio - lr_warmup_steps: -1 - - # Warmup steps ratio (used if lr_warmup_steps is negative) - lr_warmup_steps_ratio: 0.0 - - # Minimum LR ratio for cosine schedule - min_lr_ratio: 0.0 - - # Number of cosine cycles in LR schedule - num_cycles: 0.5 - - # LR warmup style: "constant" or "cosine" - warmup_style: constant - - # Total training steps (must be overridden at runtime) - total_training_steps: -1 - - # Weight decay - weight_decay: 0.01 - - # configs for FSDP - fsdp_config: - - # policy for wrapping the model - wrap_policy: - - # Minimum number of parameters to trigger wrapping a layer with FSDP - min_num_params: 0 - - # Whether to offload model parameters to CPU (trades speed for memory) - param_offload: false - - # Whether to offload optimizer state to CPU - optimizer_offload: false - - # Only for FSDP2: offload param/grad/optimizer during train - offload_policy: false - - # Only for FSDP2: Reshard after forward pass to reduce memory footprint - reshard_after_forward: true - - # Number of GPUs in each FSDP shard group; -1 means auto - fsdp_size: -1 - - # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather - # before the current forward computation. - forward_prefetch: False - - # Reference model config. - # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. - ref: - - # actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default - strategy: ${actor_rollout_ref.actor.strategy} - - # config for FSDP strategy - fsdp_config: - - # whether to offload parameters in FSDP - param_offload: False - - # whether to perform reshard after model forward to save memory. - # only for fsdp2, [True, False, int between 1 and fsdp_size] - reshard_after_forward: True - - # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather - # before the current forward computation. - forward_prefetch: False - - # the wrap policy for FSDP model - wrap_policy: - - # minimum number of params in a wrapped module - min_num_params: 0 - - # whether to enable torch.compile - use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile} - - # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] - # The batch size for one forward pass in the computation of log_prob. Global batch size. - log_prob_micro_batch_size: null - - # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. - log_prob_micro_batch_size_per_gpu: null - - # enable dynamic batch size (sequence packing) for log_prob computation - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - - # the max token length per GPU - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - - # sequence parallel size - ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} - - # calculate entropy with chunking to reduce memory peak - entropy_from_logits_with_chunking: False - - # recompute entropy - entropy_checkpointing: False - - # Rollout model config. - rollout: - - # actor_rollout_ref.rollout.name: hf/vllm/sglang. - name: vllm - - # sync: LLM, async: AsyncLLM - mode: sync - - # Sampling temperature for rollout. - temperature: 1.0 - - # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. - top_k: -1 - - # Top-p sampling parameter. Default 1.0. - top_p: 1 - - - # typically the same as data max prompt length - prompt_length: ${data.max_prompt_length} - - # typically the same as data max response length - response_length: ${data.max_response_length} - - # for vllm rollout - # Rollout model parameters type. Align with actor model's FSDP/Megatron type. - dtype: bfloat16 - - # Fraction of GPU memory used by vLLM/SGLang for KV cache. - gpu_memory_utilization: 0.5 - - # Whether to ignore EOS and continue generating after EOS is hit. - ignore_eos: False - - # Whether to disable CUDA graph. Default True to allow cache freeing. - enforce_eager: True - - # Whether to free engine KVCache after generation. Set enforce_eager=True when enabled. - free_cache_engine: True - - # Which loader to use for rollout model weights: dummy_dtensor, hf, megatron, etc. - # safetensors (for huge model, and set use_shm=True); dummy_dtensor: randomly init model weight - load_format: dummy_dtensor - - # for huge model, layered summon can save memory (prevent OOM) but make it slower - layered_summon: False - - # TP size for rollout. Only effective for vLLM. - tensor_model_parallel_size: 2 - - # max number of tokens in a batch - max_num_batched_tokens: 8192 - - # max length for rollout - max_model_len: null - - # max length of sequences - max_num_seqs: 1024 - - # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. - log_prob_micro_batch_size: null - - # The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. - log_prob_micro_batch_size_per_gpu: null - - # enable dynamic batch size (sequence packing) for log_prob computation - log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - - # max token length for log_prob computation - log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} - - # disable logging statistics - disable_log_stats: True - - # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. - enable_chunked_prefill: True - - # for hf rollout - # Whether to sample during training rollout. False uses greedy sampling. - do_sample: True - - # number of responses (i.e. num sample times). > 1 for grpo - n: 1 - - # Whether to wake up inference engine in multi-stage. (Wake up model weights first, then resume kv cache) - multi_stage_wake_up: false - - # Extra inference engine arguments (vllm, sglang). - engine_kwargs: - - # for vllm - vllm: - - # Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB). - swap_space: null - - # Whether to disable the preprocessor cache for multimodel models. - disable_mm_preprocessor_cache: False - - # for sglang - sglang: - - # The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default. - attention_backend: null - - # Sampling parameters used during validation. - val_kwargs: - - # sampling parameters for validation - # Top-k sampling parameter. -1 for vLLM rollout, 0 for HF rollout. - top_k: -1 - - # Top-p sampling parameter. Default 1.0. - top_p: 1.0 - - # Sampling temperature for rollout. - temperature: 0 - - # whether to repeat n times for validation - n: 1 - - # Whether to sample during training rollout. False uses greedy sampling. - do_sample: False - - # Multi-turn interaction config for tools or chat. - multi_turn: - - # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well - enable: False - - # null for no limit (default max_length // 3) - max_assistant_turns: null - - # null for no tool - tool_config_path: null - - # null for no limit (default max_length // 3) - max_user_turns: null - - # max parallel call for tools in single turn - max_parallel_calls: 1 - - # max length of tool response - max_tool_response_length: 256 - - # truncate side of tool response: left, middle, right - tool_response_truncate_side: middle - - # null for no interaction - interaction_config_path: null - - # null for default callback - completion_callback: null - - # - When set to True, the model's default chat template is used for multi-turn rollout, which typically matches production behavior. - # - When set to False, the token ids recorded for training are used instead; unlike the default chat template, these always include the model's full output, - # which may contain additional content such as reasoning content. This maintains the consistency between training and rollout, but it will lead to longer prompts. - use_inference_chat_template: False - - # Tokenization is performed turn by turn and the resulting token ids are concatenated to form the full conversation. - # To ensure this matches the result of tokenizing the entire conversation at once, a sanity check is run at the end of each multi-turn rollout to compare the two sets of token ids. - # Some models are known to produce different tokenization results when tokenizing turn by turn vs. all at once. aThis behavior has already been validated for them. - # To reduce excessive warnings, you can turn off the sanity check for these models if you are using their default chat template: - # Qwen/QwQ-32B, Qwen/Qwen3-xxB - # - disable: disable tokenization sanity check - # - strict: enable strict tokenization sanity check (default) - # - ignore_strippable: ignore strippable tokens when checking tokenization sanity - tokenization_sanity_check_mode: strict - - # Format of the multi-turn interaction. Options: hermes, llama3_json, ... - format: hermes - - # support logging rollout prob for debugging purpose - calculate_log_probs: False - - # [Experimental] agent loop based rollout configs - agent: - - # Number of agent loop workers - num_workers: 8 - - # custom async server configs - custom_async_server: - - # Path to the custom async server implementation - path: null - - # Class name of the custom async server class (e.g. AsyncvLLMServer) - name: null - - # profiler configs - profiler: - - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.utils.profiler.ProfilerConfig - - # True for each task has its own database, False for all tasks in one training step share one database. - discrete: False - - # Whether to profile all ranks. - all_ranks: False - - # The ranks that will be profiled. [] or [0,1,...] - ranks: [] - -# configs for the critic -critic: - - # Number of rollouts per update (mirrors actor rollout_n) - rollout_n: ${actor_rollout_ref.rollout.n} - - # fsdp or fsdp2 strategy used for critic model training - strategy: ${actor_rollout_ref.actor.strategy} - - # optimizer configs - optim: - - # Learning rate - lr: 1e-5 - - # Warmup steps ratio; total steps will be injected at runtime - lr_warmup_steps_ratio: 0. - - # Minimum LR ratio for cosine schedule - min_lr_ratio: null - - # LR warmup style: "constant" or "cosine" - warmup_style: constant - - # Total training steps (must be overridden at runtime) - total_training_steps: -1 - - # Weight decay - weight_decay: 0.01 - - # model config for the critic - model: - - # Path to pretrained model weights - path: ~/models/deepseek-llm-7b-chat - - # Whether to use shared memory for loading the model - use_shm: False - - # Tokenizer path (defaults to actor's model path) - tokenizer_path: ${actor_rollout_ref.model.path} - - # Hugging Face config override - override_config: { } - - # External model implementation (optional) - external_lib: ${actor_rollout_ref.model.external_lib} - - # Enable gradient checkpointing to save memory - enable_gradient_checkpointing: True - - # Offload activations to CPU to reduce GPU memory usage - enable_activation_offload: False - - # Use remove padding optimization (saves compute) - use_remove_padding: False - - # Whether to trust remote code from Hugging Face models - trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} - - # FSDP-specific config - fsdp_config: - - # Whether to offload model parameters to CPU - param_offload: False - - # Whether to offload optimizer state to CPU - optimizer_offload: False - - # Only for FSDP2: offload param/grad/optimizer during train - offload_policy: False - - # Only for FSDP2: Reshard after forward pass to reduce memory footprint - reshard_after_forward: True - - # Policy for wrapping layers with FSDP - wrap_policy: - - # Minimum number of parameters to trigger wrapping - min_num_params: 0 - - # Number of GPUs in each FSDP shard group; -1 means auto - fsdp_size: -1 - - # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather - # before the current forward computation. - forward_prefetch: False - - # Set to positive value to enable LoRA (e.g., 32) - lora_rank: 0 - - # LoRA scaling factor - lora_alpha: 16 - - # LoRA target modules: "all-linear" or list of linear projection layers - target_modules: all-linear - - # PPO mini-batch size per update - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - - # [Deprecated] Global micro batch size - ppo_micro_batch_size: null - - # Local per-GPU micro batch size - ppo_micro_batch_size_per_gpu: null - - # Forward-only batch size (global) - forward_micro_batch_size: ${critic.ppo_micro_batch_size} - - # Forward-only batch size (per GPU) - forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} - - # Whether to automatically adjust batch size at runtime - use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - - # Max tokens per GPU in one PPO batch (doubled for critic) - ppo_max_token_len_per_gpu: 32768 - - # Max token length per GPU in forward pass - forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} - - # Sequence parallelism size for Ulysses-style model parallelism - ulysses_sequence_parallel_size: 1 - - # Number of PPO epochs per batch - ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} - - # Shuffle training data across PPO epochs - shuffle: ${actor_rollout_ref.actor.shuffle} - - # Gradient clipping for critic updates - grad_clip: 1.0 - - # PPO value function clipping range - cliprange_value: 0.5 - - # Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" - loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} - - # checkpoint configs - checkpoint: - - # What to include in saved checkpoints - # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - save_contents: ['model', 'optimizer', 'extra'] - - # What to include when loading checkpoints - load_contents: ${critic.checkpoint.save_contents} - - # profiler configs - # the corresponding dataclass is verl.utils.profiler.ProfilerConfig. - profiler: - - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.utils.profiler.ProfilerConfig - - # True for each task has its own database, False for all tasks in one training step share one database. - discrete: False - - # Whether to profile all ranks. - all_ranks: False - - # The ranks that will be profiled. [] or [0,1,...] - ranks: [] - -# configs for the reward model -reward_model: - - # Whether to enable reward model. If False, we compute the reward only with the user-defined reward functions. - # In GSM8K and Math examples, we disable reward model. - # For RLHF alignment example using full_hh_rlhf, we utilize reward model to assess the responses. - # If False, the following parameters are not effective - enable: False - - # FSDP strategy: "fsdp" or "fsdp2" - strategy: ${actor_rollout_ref.actor.strategy} - - # model config for reward scoring - model: - - # Input tokenizer. If the reward model’s chat template is inconsistent with the policy, - # we need to first decode to plaintext, then apply the rm’s chat_template. - # Then score with RM. If chat_templates are consistent, it can be set to null. - input_tokenizer: ${actor_rollout_ref.model.path} - - # RM’s HDFS path or local path. Note that RM only supports AutoModelForSequenceClassification. - # Other model types need to define their own RewardModelWorker and pass it from the code. - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - - # Whether to use shared memory for loading the model - use_shm: False - - # External model implementation (optional) - external_lib: ${actor_rollout_ref.model.external_lib} - - # Use remove padding optimization (saves compute) - use_remove_padding: False - - # Whether to use fused reward kernels for speedup - use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} - - # Whether to enable loading a remote code model, default to False - trust_remote_code: False - - # FSDP-specific config - fsdp_config: - - # Policy for wrapping layers with FSDP - wrap_policy: - - # Minimum number of parameters to trigger wrapping - min_num_params: 0 - - # Whether to offload model parameters to CPU - param_offload: False - - # Only for FSDP2: Reshard after forward pass to reduce memory footprint - reshard_after_forward: True - - # Number of GPUs in each FSDP shard group; -1 means auto - fsdp_size: -1 - - # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather - # before the current forward computation. - forward_prefetch: False - - # [Deprecated] Global micro batch size - micro_batch_size: null - - # Local per-GPU micro batch size - micro_batch_size_per_gpu: null - - # Maximum sequence length to process for scoring - max_length: null - - # Sequence parallelism size for Ulysses-style model parallelism - ulysses_sequence_parallel_size: 1 - - # Whether to dynamically adjust batch size at runtime - use_dynamic_bsz: ${critic.use_dynamic_bsz} - - # Maximum number of tokens per GPU in one forward pass - forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} - - # Reward Manager. This defines the mechanism of computing rule-based reward and handling different reward sources. - # Default is naive. If all verification functions are multiprocessing-safe, - # the reward manager can be set to prime for parallel verification. - reward_manager: naive - - # Whether to launch custom reward function asynchronously during log_prob - launch_reward_fn_async: False - - # Cloud/local sandbox fusion configuration for custom reward logic - sandbox_fusion: - - # Cloud/local function URL for sandbox execution - url: null - - # Max concurrent requests allowed to sandbox - max_concurrent: 64 - - # Max memory limit for each sandbox process in MB - memory_limit_mb: 1024 - - # profiler configs - profiler: - - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.utils.profiler.ProfilerConfig - - # True for each task has its own database, False for all tasks in one training step share one database. - discrete: False - - # Whether to profile all ranks. - all_ranks: False - - # The ranks that will be profiled. [] or [0,1,...] - ranks: [] - -# custom reward function definition -custom_reward_function: - - # The path to the file containing your customized reward function. - # If not specified, pre-implemented reward functions will be used. - path: null - - # The name of the reward function within the specified file. Default is 'compute_score'. - name: compute_score - -# config for the algorithm -algorithm: - - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.AlgoConfig - - # Discount factor for future rewards - gamma: 1.0 - - # Trade-off between bias and variance in the GAE estimator - lam: 1.0 - - # Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. - adv_estimator: gae - - # Whether to normalize advantages by std (specific to GRPO) - norm_adv_by_std_in_grpo: True - - # Whether to enable in-reward KL penalty - use_kl_in_reward: False - - # How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full" - kl_penalty: kl - - # KL control configuration - kl_ctrl: - - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.KLControlConfig - - # KL control type: "fixed" or "adaptive" - type: fixed - - # Initial coefficient for KL penalty - kl_coef: 0.001 - - # Horizon value for adaptive controller (if enabled) - horizon: 10000 - - # Target KL divergence (used for adaptive controller) - target_kl: 0.1 - - # Whether to enable preference feedback PPO - use_pf_ppo: False - - # Preference feedback PPO settings - pf_ppo: - - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.PFPPOConfig - - # Method for reweighting samples: "pow", "max_min", or "max_random" - reweight_method: pow - - # Power used for weight scaling in "pow" method - weight_pow: 2.0 - -# config for the trainer -trainer: - - # Whether to balance batch sizes across distributed workers - balance_batch: True - - # Number of epochs in training - total_epochs: 30 - - # Total training steps (can be set explicitly or derived from epochs) - total_training_steps: null - - # The steps that will be profiled. null means no profiling. null or [1,2,5,...] - profile_steps: null - - # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None. - ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html - ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html - controller_nsight_options: - - # Select the API(s) to be traced. - trace: "cuda,nvtx,cublas,ucx" - - # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". - cuda-memory-usage: "true" - - # CUDA graphs will be traced as a whole - cuda-graph-trace: "graph" - - # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None. - worker_nsight_options: - - # Select the API(s) to be traced. - trace: "cuda,nvtx,cublas,ucx" - - # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false". - cuda-memory-usage: "true" - - # CUDA graphs will be traced as a whole - cuda-graph-trace: "graph" - - # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config. - capture-range: "cudaProfilerApi" - - # Specify the desired behavior when a capture range ends. - # In verl we need the orch.cuda.profiler.start/stop pair to repeats n times. - # valid values are "repeat-shutdown:n" or null. - # For normal whole step profiling, n = len(profile_steps); - # but for discrete profiling, n = len(profile_steps) * Number(subtasks). - # Or you can just leave it null and the program will use n = len(profile_steps) * 6; - capture-range-end: null - - # Send signal to the target application's process group. We let the program to exit by itself. - kill: none - - # Config for npu profiler. Must set when profile_steps is not None and torch_npu is available. - npu_profile: - - # Options for the npu profiler - options: - - # Storage path of collected data. - save_path: ./profiler_data - - # Collection level, optional values: level_none, level0, level1, level2. - level: level1 - - # Whether to enable memory analysis. - with_memory: False - - # Whether to record tensor shape. - record_shapes: False - - # Whether to record Device-side performance data. - with_npu: True - - # Whether to record Host-side performance data. - with_cpu: True - - # Whether to record Python call stack information. - with_module: False - - # Whether to record operator call stack information. - with_stack: False - - # Whether to automatically parse the data. - analysis: True - - # Project name for experiment tracking (e.g., wandb) - project_name: verl_examples - - # Experiment name for run identification in tracking tools - experiment_name: gsm8k - - # Logging backends to use: "console", "wandb", etc. - logger: [ 'console', 'wandb' ] - - # Number of generations to log during validation - log_val_generations: 0 - - # Directory for logging rollout data; no dump if null - rollout_data_dir: null - - # Directory for logging validation data; no dump if null - validation_data_dir: null - - # Number of nodes used in the training - nnodes: 1 - - # Number of GPUs per node - n_gpus_per_node: 8 - - # Save frequency (by iteration) for model checkpoints - save_freq: -1 - - # ESI refers to the elastic server instance used during training, similar to the training plan. For example, - # if you purchase 10 hours of computing power, the ESI will automatically shut down after 10 hours of training. - # To ensure a checkpoint is saved before ESI shuts down, the system will start saving a checkpoint in advance. - # The advance time is calculated as: Advance Time = Longest historical step duration + Checkpoint save duration + esi_redundant_time. - # Here, esi_redundant_time is a user-defined value that further extends the advance time for added safety. - esi_redundant_time: 0 - - # Resume mode: "auto", "disable", or "resume_path" - # "auto": resume from last checkpoint if available - # "disable": start from scratch - # "resume_path": resume from a user-defined path - resume_mode: auto - - # Path to resume training from (only used when resume_mode is "resume_path") - resume_from_path: null - - # Whether to run validation before training begins - val_before_train: True - - # Whether to run validation only - val_only: False - - # Validation frequency (in training iterations) - test_freq: -1 - - # Number of iterations to warm up the critic before updating policy - critic_warmup: 0 - - # Default path to distributed filesystem for saving checkpoints - default_hdfs_dir: null - - # Whether to delete local checkpoints after loading - del_local_ckpt_after_load: False - - # Default local directory for saving checkpoints - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} - - # Maximum number of actor checkpoints to keep - max_actor_ckpt_to_keep: null - - # Maximum number of critic checkpoints to keep - max_critic_ckpt_to_keep: null - - # Timeout (in seconds) for Ray worker to wait for registration - ray_wait_register_center_timeout: 300 - - # Device to run training on (e.g., "cuda", "cpu") - device: cuda - -# configs related to ray initialization -ray_init: - - # Number of CPUs for Ray. Use a fixed number instead of null when using SLURM. - num_cpus: null - - # Path to save Ray timeline JSON for performance profiling - timeline_json_file: null diff --git a/tests/trainer/config/test_algo_config_on_cpu.py b/tests/trainer/config/test_algo_config_on_cpu.py deleted file mode 100644 index 848a3ffe1..000000000 --- a/tests/trainer/config/test_algo_config_on_cpu.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -import torch -from omegaconf import OmegaConf - -from verl.trainer.config import AlgoConfig, KLControlConfig, PFPPOConfig -from verl.trainer.ppo.core_algos import ( - compute_gae_advantage_return, - compute_grpo_outcome_advantage, - get_adv_estimator_fn, -) -from verl.utils.config import omega_conf_to_dataclass - - -class TestAlgoConfig(unittest.TestCase): - """Test the AlgoConfig dataclass and its integration with core algorithms.""" - - def setUp(self): - """Set up test fixtures.""" - # Create a sample algorithm config as DictConfig (similar to what comes from YAML) - self.config_dict = { - "_target_": "verl.trainer.config.AlgoConfig", - "gamma": 0.99, - "lam": 0.95, - "adv_estimator": "gae", - "norm_adv_by_std_in_grpo": True, - "use_kl_in_reward": True, - "kl_penalty": "kl", - "kl_ctrl": { - "_target_": "verl.trainer.config.KLControlConfig", - "type": "adaptive", - "kl_coef": 0.002, - "horizon": 5000, - "target_kl": 0.05, - }, - "use_pf_ppo": True, - "pf_ppo": {"_target_": "verl.trainer.config.PFPPOConfig", "reweight_method": "max_min", "weight_pow": 3.0}, - } - self.omega_config = OmegaConf.create(self.config_dict) - - def test_dataclass_creation_from_dict(self): - """Test creating AlgoConfig from dictionary.""" - config = omega_conf_to_dataclass(self.config_dict) - - self.assertIsInstance(config, AlgoConfig) - self.assertEqual(config.gamma, 0.99) - self.assertEqual(config.lam, 0.95) - self.assertEqual(config.adv_estimator, "gae") - self.assertTrue(config.norm_adv_by_std_in_grpo) - self.assertTrue(config.use_kl_in_reward) - self.assertEqual(config.kl_penalty, "kl") - self.assertTrue(config.use_pf_ppo) - - def test_dataclass_creation_from_omega_config(self): - """Test creating AlgoConfig from OmegaConf DictConfig.""" - config = omega_conf_to_dataclass(self.omega_config) - - self.assertIsInstance(config, AlgoConfig) - self.assertEqual(config.gamma, 0.99) - self.assertEqual(config.lam, 0.95) - - def test_nested_configs(self): - """Test that nested configurations are properly converted.""" - config = omega_conf_to_dataclass(self.omega_config) - - # Test KL control config - self.assertIsInstance(config.kl_ctrl, KLControlConfig) - self.assertEqual(config.kl_ctrl.type, "adaptive") - self.assertEqual(config.kl_ctrl.kl_coef, 0.002) - self.assertEqual(config.kl_ctrl.horizon, 5000) - self.assertEqual(config.kl_ctrl.target_kl, 0.05) - - # Test PF PPO config - self.assertIsInstance(config.pf_ppo, PFPPOConfig) - self.assertEqual(config.pf_ppo.reweight_method, "max_min") - self.assertEqual(config.pf_ppo.weight_pow, 3.0) - - def test_default_values(self): - """Test that default values are properly set.""" - minimal_config = {"gamma": 0.8} - config = omega_conf_to_dataclass(minimal_config, AlgoConfig) - - self.assertEqual(config.gamma, 0.8) - self.assertEqual(config.lam, 1.0) # default value - self.assertEqual(config.adv_estimator, "gae") # default value - self.assertTrue(config.norm_adv_by_std_in_grpo) # default value - self.assertFalse(config.use_kl_in_reward) # default value - self.assertEqual(config.kl_penalty, "kl") # default value - self.assertFalse(config.use_pf_ppo) # default value - - def test_get_method_backward_compatibility(self): - """Test the get method for backward compatibility.""" - config = omega_conf_to_dataclass(self.omega_config) - - # Test existing attribute - self.assertEqual(config.get("gamma"), 0.99) - self.assertEqual(config.get("gamma", 1.0), 0.99) - - # Test non-existing attribute - self.assertIsNone(config.get("non_existing")) - self.assertEqual(config.get("non_existing", "default"), "default") - - def test_post_init_nested_configs(self): - """Test that __post_init__ properly initializes nested configs when None.""" - # Create config without nested configs - minimal_config = AlgoConfig(gamma=0.9) - - # Check that nested configs are initialized - self.assertIsNotNone(minimal_config.kl_ctrl) - self.assertIsInstance(minimal_config.kl_ctrl, KLControlConfig) - self.assertIsNone(minimal_config.pf_ppo) - - def test_config_init_from_yaml(self): - import os - - from hydra import compose, initialize_config_dir - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - cfg = compose(config_name="ppo_trainer") - algo_config = omega_conf_to_dataclass(cfg.algorithm) - from verl.trainer.config import AlgoConfig, PFPPOConfig - - assert isinstance(algo_config, AlgoConfig) - assert isinstance(algo_config.pf_ppo, PFPPOConfig) - - -class TestAlgoCompute(unittest.TestCase): - """Test the AlgoConfig dataclass and its integration with core algorithms.""" - - def setUp(self): - """Set up test fixtures.""" - self.algo_config = AlgoConfig( - gamma=0.99, - lam=0.95, - adv_estimator="gae", - norm_adv_by_std_in_grpo=True, - use_kl_in_reward=True, - kl_penalty="kl", - kl_ctrl=KLControlConfig(type="adaptive", kl_coef=0.002, horizon=5000, target_kl=0.05), - use_pf_ppo=True, - pf_ppo=PFPPOConfig(reweight_method="max_min", weight_pow=3.0), - ) - - def test_advantage_estimator_with_cfg(self): - """Test integration with advantage estimators from core_algos.""" - config = self.algo_config - - # Test GAE advantage estimator - adv_fn = get_adv_estimator_fn(config.adv_estimator) - self.assertIsNotNone(adv_fn) - - # Test with actual GAE computation - batch_size, seq_len = 2, 5 - token_level_rewards = torch.randn(batch_size, seq_len) - values = torch.randn(batch_size, seq_len) - response_mask = torch.ones(batch_size, seq_len) - - advantages, returns = compute_gae_advantage_return( - token_level_rewards=token_level_rewards, - values=values, - response_mask=response_mask, - gamma=config.gamma, - lam=config.lam, - ) - - self.assertEqual(advantages.shape, (batch_size, seq_len)) - self.assertEqual(returns.shape, (batch_size, seq_len)) - - def test_grpo_advantage_estimator_with_cfg(self): - """Test integration with GRPO advantage estimator.""" - grpo_config = AlgoConfig(adv_estimator="grpo", norm_adv_by_std_in_grpo=True) - - # Test GRPO advantage computation - batch_size, seq_len = 4, 3 - token_level_rewards = torch.tensor([[1.0, 0.5, 0.0], [2.0, 1.0, 0.0], [0.5, 0.2, 0.0], [1.5, 0.8, 0.0]]) - response_mask = torch.ones(batch_size, seq_len) - index = np.array([0, 0, 1, 1]) # Two groups - - advantages, returns = compute_grpo_outcome_advantage( - token_level_rewards=token_level_rewards, - response_mask=response_mask, - index=index, - norm_adv_by_std_in_grpo=grpo_config.norm_adv_by_std_in_grpo, - ) - - self.assertEqual(advantages.shape, (batch_size, seq_len)) - self.assertEqual(returns.shape, (batch_size, seq_len)) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/trainer/config/test_legacy_config_on_cpu.py b/tests/trainer/config/test_legacy_config_on_cpu.py deleted file mode 100644 index 39862aa22..000000000 --- a/tests/trainer/config/test_legacy_config_on_cpu.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import unittest - -from hydra import compose, initialize_config_dir -from hydra.core.global_hydra import GlobalHydra -from omegaconf import OmegaConf - - -class TestConfigComparison(unittest.TestCase): - """Test that current configs match their legacy counterparts exactly.""" - - def _compare_configs_recursively(self, current_config, legacy_config, path="", legacy_allow_missing=True): - """Recursively compare two OmegaConf configs and assert they are identical. - - Args: - legacy_allow_missing (bool): sometimes the legacy megatron config contains fewer keys and - we allow that to happen - """ - if isinstance(current_config, dict) and isinstance(legacy_config, dict): - current_keys = set(current_config.keys()) - legacy_keys = set(legacy_config.keys()) - - missing_in_current = legacy_keys - current_keys - missing_in_legacy = current_keys - legacy_keys - - if missing_in_current: - self.fail(f"Keys missing in current config at {path}: {missing_in_current}") - if missing_in_legacy: - # if the legacy - msg = f"Keys missing in legacy config at {path}: {missing_in_legacy}" - if legacy_allow_missing: - print(msg) - else: - self.fail(msg) - - for key in current_keys: - current_path = f"{path}.{key}" if path else key - if key in legacy_config: - self._compare_configs_recursively(current_config[key], legacy_config[key], current_path) - elif isinstance(current_config, list) and isinstance(legacy_config, list): - self.assertEqual( - len(current_config), - len(legacy_config), - f"List lengths differ at {path}: current={len(current_config)}, legacy={len(legacy_config)}", - ) - for i, (current_item, legacy_item) in enumerate(zip(current_config, legacy_config, strict=True)): - self._compare_configs_recursively(current_item, legacy_item, f"{path}[{i}]") - else: - self.assertEqual( - current_config, - legacy_config, - f"Values differ at {path}: current={current_config}, legacy={legacy_config}", - ) - - def test_ppo_trainer_config_matches_legacy(self): - """Test that ppo_trainer.yaml matches legacy_ppo_trainer.yaml exactly.""" - import os - - from hydra import compose, initialize_config_dir - from hydra.core.global_hydra import GlobalHydra - - GlobalHydra.instance().clear() - - try: - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - current_config = compose(config_name="ppo_trainer") - - legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_trainer.yaml") - current_dict = OmegaConf.to_container(current_config, resolve=True) - legacy_dict = OmegaConf.to_container(legacy_config, resolve=True) - - if "defaults" in current_dict: - del current_dict["defaults"] - - self._compare_configs_recursively(current_dict, legacy_dict) - finally: - GlobalHydra.instance().clear() - - def test_ppo_megatron_trainer_config_matches_legacy(self): - """Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly.""" - - GlobalHydra.instance().clear() - - try: - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - current_config = compose(config_name="ppo_megatron_trainer") - - legacy_config = OmegaConf.load("tests/trainer/config/legacy_ppo_megatron_trainer.yaml") - current_dict = OmegaConf.to_container(current_config, resolve=True) - legacy_dict = OmegaConf.to_container(legacy_config, resolve=True) - - if "defaults" in current_dict: - del current_dict["defaults"] - - self._compare_configs_recursively(current_dict, legacy_dict, legacy_allow_missing=True) - finally: - GlobalHydra.instance().clear() - - def test_load_component(self): - """Test that ppo_megatron_trainer.yaml matches legacy_ppo_megatron_trainer.yaml exactly.""" - - GlobalHydra.instance().clear() - configs_to_load = [ - ("verl/trainer/config/actor", "dp_actor"), - ("verl/trainer/config/actor", "megatron_actor"), - ("verl/trainer/config/ref", "dp_ref"), - ("verl/trainer/config/ref", "megatron_ref"), - ("verl/trainer/config/rollout", "rollout"), - ] - for config_dir, config_file in configs_to_load: - try: - with initialize_config_dir(config_dir=os.path.abspath(config_dir)): - compose(config_name=config_file) - finally: - GlobalHydra.instance().clear() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/trainer/ppo/__init__.py b/tests/trainer/ppo/__init__.py deleted file mode 100644 index 26d7c04fc..000000000 --- a/tests/trainer/ppo/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for the PPO trainer module. -""" diff --git a/tests/trainer/ppo/test_core_algos_on_cpu.py b/tests/trainer/ppo/test_core_algos_on_cpu.py deleted file mode 100644 index 087a0d2f1..000000000 --- a/tests/trainer/ppo/test_core_algos_on_cpu.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import random -import unittest - -import pytest -import torch - -import verl.trainer.ppo.core_algos -from verl.trainer.ppo.core_algos import compute_gae_advantage_return, get_adv_estimator_fn, register_adv_est - - -def mock_test_fn(): - pass - - -class TestRegisterAdvEst(unittest.TestCase): - def setUp(self): - """Clear the registry before each test""" - verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear() - verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = { - "gae": lambda x: x * 2, - "vtrace": lambda x: x + 1, - } - self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY - - def tearDown(self) -> None: - verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY.clear() - return super().tearDown() - - def test_register_new_function(self): - """Test registering a new function with a string name""" - - @register_adv_est("test_estimator") - def test_fn(): - pass - - self.assertIn("test_estimator", self.ADV_ESTIMATOR_REGISTRY) - self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_estimator"], test_fn) - - def test_register_with_enum(self): - """Test registering with an enum value (assuming AdvantageEstimator exists)""" - from enum import Enum - - class AdvantageEstimator(Enum): - TEST = "test_enum_estimator" - - @register_adv_est(AdvantageEstimator.TEST) - def test_fn(): - pass - - self.assertIn("test_enum_estimator", self.ADV_ESTIMATOR_REGISTRY) - self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_enum_estimator"], test_fn) - - def test_duplicate_registration_same_function(self): - """Test that registering the same function twice doesn't raise an error""" - register_adv_est("duplicate_test")(mock_test_fn) - register_adv_est("duplicate_test")(mock_test_fn) - - self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["duplicate_test"], mock_test_fn) - - def test_duplicate_registration_different_function(self): - """Test that registering different functions with same name raises ValueError""" - - @register_adv_est("conflict_test") - def test_fn1(): - pass - - with self.assertRaises(ValueError): - - @register_adv_est("conflict_test") - def test_fn2(): - pass - - def test_decorator_preserves_function(self): - """Test that the decorator returns the original function""" - - def test_fn(): - return "original" - - decorated = register_adv_est("preserve_test")(test_fn) - self.assertEqual(decorated(), "original") - - def test_multiple_registrations(self): - """Test registering multiple different functions""" - init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY) - - @register_adv_est("estimator1") - def fn1(): - pass - - @register_adv_est("estimator2") - def fn2(): - pass - - self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count) - self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator1"], fn1) - self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator2"], fn2) - - def test_get_adv_estimator_fn_valid_names(self): - """Test that valid names return the correct function from registry.""" - # Test GAE - gae_fn = get_adv_estimator_fn("gae") - assert gae_fn(5) == 10 # 5 * 2 = 10 - - # Test Vtrace - vtrace_fn = get_adv_estimator_fn("vtrace") - assert vtrace_fn(5) == 6 # 5 + 1 = 6 - - def test_get_adv_estimator_fn_invalid_name(self): - """Test that invalid names raise ValueError.""" - with pytest.raises(ValueError) as excinfo: - get_adv_estimator_fn("invalid_name") - assert "Unknown advantage estimator simply: invalid_name" in str(excinfo.value) - - def test_get_adv_estimator_fn_case_sensitive(self): - """Test that name lookup is case-sensitive.""" - with pytest.raises(ValueError): - get_adv_estimator_fn("GAE") # Different case - - -def test_multi_turn_compute_gae_advantage_return(): - """Test multi-turn GAE skip observation tokens.""" - gamma = random.uniform(0.0, 1.0) - lam = random.uniform(0.0, 1.0) - - rewards = torch.tensor([[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float) - - values1 = torch.tensor( - [ - [ - random.uniform(-100.0, 100.0), - random.random(), - 4.0, - 5.0, - 6.0, - random.uniform(-100.0, 0), - random.random(), - 7.0, - 9.0, - 0.0, - 0.0, - ] - ], - dtype=torch.float, - ) - - values2 = torch.tensor( - [ - [ - random.random(), - random.uniform(-100.0, 100.0), - 4.0, - 5.0, - 6.0, - random.random(), - random.uniform(0.0, 100.0), - 7.0, - 9.0, - 0.0, - 0.0, - ] - ], - dtype=torch.float, - ) - - response_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float) - - adv1, ret1 = compute_gae_advantage_return(rewards, values1, response_mask, gamma, lam) - adv2, ret2 = compute_gae_advantage_return(rewards, values2, response_mask, gamma, lam) - - ret1 *= response_mask - ret2 *= response_mask - assert torch.equal(adv1, adv2), f"{adv1=}, {adv2=}" - assert torch.equal(ret1, ret2), f"{ret1=}, {ret2=}" - print(f" [CORRECT] \n\n{adv1=}, \n\n{ret1=}") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/trainer/ppo/test_metric_utils_on_cpu.py b/tests/trainer/ppo/test_metric_utils_on_cpu.py deleted file mode 100644 index 50fe952c0..000000000 --- a/tests/trainer/ppo/test_metric_utils_on_cpu.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for the metric utilities in verl.trainer.ppo.metric_utils. -""" - -import unittest -from unittest.mock import MagicMock, patch - -import numpy as np -import torch - -from verl.trainer.ppo.metric_utils import ( - bootstrap_metric, - calc_maj_val, - compute_data_metrics, - compute_throughout_metrics, - compute_timing_metrics, - process_validation_metrics, -) -from verl.utils.metric import ( - reduce_metrics, -) - - -class TestReduceMetrics(unittest.TestCase): - """Tests for the reduce_metrics function.""" - - def test_reduce_metrics_basic(self): - """Test that reduce_metrics correctly computes means.""" - metrics = { - "loss": [1.0, 2.0, 3.0], - "accuracy": [0.0, 0.5, 1.0], - } - result = reduce_metrics(metrics) - - self.assertEqual(result["loss"], 2.0) - self.assertEqual(result["accuracy"], 0.5) - - def test_reduce_metrics_empty(self): - """Test that reduce_metrics handles empty lists.""" - metrics = { - "empty": [], - } - result = reduce_metrics(metrics) - - self.assertTrue(np.isnan(result["empty"])) - - def test_reduce_metrics_single_value(self): - """Test that reduce_metrics works with single values.""" - metrics = { - "single": [5.0], - } - result = reduce_metrics(metrics) - - self.assertEqual(result["single"], 5.0) - - -class TestComputeDataMetrics(unittest.TestCase): - """Tests for the compute_data_metrics function.""" - - def setUp(self): - """Set up common test data.""" - # Create a mock DataProto object - self.batch = MagicMock() - self.batch.batch = { - "token_level_scores": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), - "token_level_rewards": torch.tensor([[0.5, 1.0], [1.5, 2.0]]), - "advantages": torch.tensor([[0.1, 0.2], [0.3, 0.4]]), - "returns": torch.tensor([[1.1, 1.2], [1.3, 1.4]]), - "responses": torch.zeros((2, 2)), # 2 samples, 2 tokens each - "attention_mask": torch.tensor( - [ - [1, 1, 1, 1], # 2 prompt tokens, 2 response tokens - [1, 1, 1, 1], - ] - ), - "response_mask": torch.tensor( - [ - [1, 1], # 2 response tokens - [1, 1], - ] - ), - "values": torch.tensor([[0.9, 1.0], [1.1, 1.2]]), - } - - def test_compute_data_metrics_with_critic(self): - """Test compute_data_metrics with critic enabled.""" - metrics = compute_data_metrics(self.batch, use_critic=True) - - # Check that all expected metrics are present - self.assertIn("critic/score/mean", metrics) - self.assertIn("critic/rewards/mean", metrics) - self.assertIn("critic/advantages/mean", metrics) - self.assertIn("critic/returns/mean", metrics) - self.assertIn("critic/values/mean", metrics) - self.assertIn("critic/vf_explained_var", metrics) - self.assertIn("response_length/mean", metrics) - self.assertIn("prompt_length/mean", metrics) - - # Check some specific values - self.assertAlmostEqual(metrics["critic/score/mean"], 5.0) # Sum of token_level_scores - self.assertAlmostEqual(metrics["critic/rewards/mean"], 2.5) # Sum of token_level_rewards - - def test_compute_data_metrics_without_critic(self): - """Test compute_data_metrics with critic disabled.""" - metrics = compute_data_metrics(self.batch, use_critic=False) - - # Check that critic-specific metrics are not present - self.assertNotIn("critic/values/mean", metrics) - self.assertNotIn("critic/vf_explained_var", metrics) - - # Check that other metrics are still present - self.assertIn("critic/score/mean", metrics) - self.assertIn("critic/rewards/mean", metrics) - self.assertIn("response_length/mean", metrics) - - -class TestComputeTimingMetrics(unittest.TestCase): - """Tests for the compute_timing_metrics function.""" - - def setUp(self): - """Set up common test data.""" - # Create a mock DataProto object - self.batch = MagicMock() - self.batch.batch = { - "responses": torch.zeros((2, 3)), # 2 samples, 3 response tokens each - "attention_mask": torch.tensor( - [ - [1, 1, 1, 1, 1, 1], # 3 prompt tokens, 3 response tokens - [1, 1, 1, 1, 1, 1], - ] - ), - } - - # Mock the _compute_response_info function to return known values - self.response_info = { - "prompt_length": torch.tensor([3.0, 3.0]), - "response_length": torch.tensor([3.0, 3.0]), - "response_mask": torch.ones((2, 3)), - } - - @patch("verl.trainer.ppo.metric_utils._compute_response_info") - def test_compute_timing_metrics(self, mock_compute_response_info): - """Test compute_timing_metrics with various timing data.""" - mock_compute_response_info.return_value = self.response_info - - timing_raw = { - "gen": 0.5, # 500ms - "ref": 0.3, # 300ms - "values": 0.2, # 200ms - } - - metrics = compute_timing_metrics(self.batch, timing_raw) - - # Check raw timing metrics - self.assertEqual(metrics["timing_s/gen"], 0.5) - self.assertEqual(metrics["timing_s/ref"], 0.3) - self.assertEqual(metrics["timing_s/values"], 0.2) - - # Check per-token timing metrics - # gen uses only response tokens (6 tokens) - self.assertAlmostEqual(metrics["timing_per_token_ms/gen"], 0.5 * 1000 / 6, places=5) - - # ref and values use all tokens (12 tokens) - self.assertAlmostEqual(metrics["timing_per_token_ms/ref"], 0.3 * 1000 / 12, places=5) - self.assertAlmostEqual(metrics["timing_per_token_ms/values"], 0.2 * 1000 / 12, places=5) - - -class TestComputeThroughputMetrics(unittest.TestCase): - """Tests for the compute_throughout_metrics function.""" - - def setUp(self): - """Set up common test data.""" - # Create a mock DataProto object - self.batch = MagicMock() - self.batch.meta_info = { - "global_token_num": [100, 200, 300], # 600 tokens total - } - - def test_compute_throughout_metrics(self): - """Test compute_throughout_metrics with various timing data.""" - timing_raw = { - "step": 2.0, # 2 seconds per step - } - - # Test with 1 GPU - metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=1) - - self.assertEqual(metrics["perf/total_num_tokens"], 600) - self.assertEqual(metrics["perf/time_per_step"], 2.0) - self.assertEqual(metrics["perf/throughput"], 600 / 2.0) # 300 tokens/sec - - # Test with 2 GPUs - metrics = compute_throughout_metrics(self.batch, timing_raw, n_gpus=2) - - self.assertEqual(metrics["perf/total_num_tokens"], 600) - self.assertEqual(metrics["perf/time_per_step"], 2.0) - self.assertEqual(metrics["perf/throughput"], 600 / (2.0 * 2)) # 150 tokens/sec/GPU - - -class TestBootstrapMetric(unittest.TestCase): - """Tests for the bootstrap_metric function.""" - - def test_bootstrap_metric_basic(self): - """Test bootstrap_metric with simple data and functions.""" - data = [1, 2, 3, 4, 5] - reduce_fns = [np.mean, np.max] - - # Use a fixed seed for reproducibility - result = bootstrap_metric(data, subset_size=3, reduce_fns=reduce_fns, n_bootstrap=100, seed=42) - - # Check that we get two results (one for each reduce_fn) - self.assertEqual(len(result), 2) - - # Each result should be a tuple of (mean, std) - mean_result, max_result = result - self.assertEqual(len(mean_result), 2) - self.assertEqual(len(max_result), 2) - - # The mean of means should be close to the true mean (3.0) - self.assertAlmostEqual(mean_result[0], 3.0, delta=0.3) - - # The mean of maxes should be close to the expected value for samples of size 3 - # For samples of size 3 from [1,2,3,4,5], the expected max is around 4.0-4.5 - self.assertGreater(max_result[0], 3.5) - self.assertLess(max_result[0], 5.0) - - def test_bootstrap_metric_empty(self): - """Test bootstrap_metric with empty data.""" - with self.assertRaises(ValueError): - bootstrap_metric([], subset_size=1, reduce_fns=[np.mean]) - - -class TestCalcMajVal(unittest.TestCase): - """Tests for the calc_maj_val function.""" - - def test_calc_maj_val_basic(self): - """Test calc_maj_val with simple data.""" - data = [ - {"pred": "A", "val": 0.9}, - {"pred": "B", "val": 0.8}, - {"pred": "A", "val": 0.7}, - ] - - result = calc_maj_val(data, vote_key="pred", val_key="val") - - # "A" is the majority vote, so we should get the first "val" for "A" - self.assertEqual(result, 0.9) - - def test_calc_maj_val_tie(self): - """Test calc_maj_val with tied votes.""" - data = [ - {"pred": "A", "val": 0.9}, - {"pred": "B", "val": 0.8}, - {"pred": "B", "val": 0.7}, - {"pred": "A", "val": 0.6}, - ] - - # In case of a tie, the first key in sorted order wins - # This depends on Python's dict implementation, but for this test - # we just verify that one of the valid values is returned - result = calc_maj_val(data, vote_key="pred", val_key="val") - - self.assertTrue(result in [0.9, 0.8]) - - -class TestProcessValidationMetrics(unittest.TestCase): - """Tests for the process_validation_metrics function.""" - - def test_process_validation_metrics_basic(self): - """Test process_validation_metrics with simple data.""" - data_sources = ["source1", "source1", "source2"] - sample_inputs = ["prompt1", "prompt1", "prompt2"] - infos_dict = { - "score": [0.8, 0.9, 0.7], - } - - result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42) - - # Check the structure of the result - self.assertIn("source1", result) - self.assertIn("source2", result) - - # Check that source1 has metrics for score - self.assertIn("score", result["source1"]) - - # Check that mean@2 is present for source1/score - self.assertIn("mean@2", result["source1"]["score"]) - - # Check the value of mean@2 for source1/score - self.assertAlmostEqual(result["source1"]["score"]["mean@2"], 0.85) - - def test_process_validation_metrics_with_pred(self): - """Test process_validation_metrics with prediction data.""" - data_sources = ["source1", "source1", "source1"] - sample_inputs = ["prompt1", "prompt1", "prompt1"] - infos_dict = { - "score": [0.8, 0.9, 0.7], - "pred": ["A", "B", "A"], - } - - result = process_validation_metrics(data_sources, sample_inputs, infos_dict, seed=42) - - # Check that majority voting metrics are present - self.assertIn("maj@2/mean", result["source1"]["score"]) - - # For bootstrap with n=2, the majority vote could be either A or B - # depending on the random sampling, so we don't check the exact value - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/utils/_test_module.py b/tests/utils/_test_module.py deleted file mode 100644 index ec3d5fb65..000000000 --- a/tests/utils/_test_module.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Test module for import_utils.load_extern_type testing -class TestClass: - """A test class to be imported by load_extern_type""" - - def __init__(self, value=None): - self.value = value or "default" - - def get_value(self): - return self.value - - -TEST_CONSTANT = "test_constant_value" - - -def test_function(): - return "test_function_result" diff --git a/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py b/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py deleted file mode 100644 index 203494bd9..000000000 --- a/tests/utils/ckpt/test_esi_save_ckpt_on_cpu.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import time -from datetime import datetime, timedelta -from unittest import TestCase - -from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi - - -class TestShouldSaveCkptEsi(TestCase): - def test_no_expiration_timestamp(self): - """Test case when no expiration timestamp is set""" - os.environ.pop("MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP", None) - os.environ.pop("SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP", None) - self.assertFalse(should_save_ckpt_esi(100)) - - def test_mlp_expiration_valid(self): - """Test valid MLP expiration timestamp requiring save""" - current_time = time.time() - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 90) - self.assertTrue(should_save_ckpt_esi(30)) # max_steps_duration=30 seconds - - def test_mlp_expiration_passed(self): - """Test expired MLP timestamp""" - current_time = time.time() - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time - 10) - self.assertFalse(should_save_ckpt_esi(30)) - - def test_mlp_invalid_timestamp(self): - """Test invalid MLP timestamp format""" - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = "invalid" - self.assertFalse(should_save_ckpt_esi(30)) - - def test_mlp_expiration_not_reached(self): - """Test MLP expiration timestamp with insufficient remaining time""" - current_time = time.time() - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 200) - self.assertFalse(should_save_ckpt_esi(30)) # max_steps_duration=30 - - def test_aws_expiration_not_reached(self): - """Test AWS expiration timestamp with sufficient remaining time""" - now = datetime.now() - expiration = now + timedelta(minutes=100) # Exceeds 90-minute threshold - os.environ["SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(int(expiration.timestamp())) - self.assertFalse(should_save_ckpt_esi(30 * 60)) - - def test_redundant_time(self): - """Test redundant_time parameter effect""" - current_time = time.time() - # Total required: 60+30+30=120 seconds - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 120) - self.assertTrue(should_save_ckpt_esi(30, redundant_time=30)) - - def test_zero_max_steps_duration(self): - """Test zero max_steps_duration""" - current_time = time.time() - os.environ["MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP"] = str(current_time + 60) - self.assertFalse(should_save_ckpt_esi(0)) diff --git a/tests/utils/dataset/test_create_rl_sampler_on_cpu.py b/tests/utils/dataset/test_create_rl_sampler_on_cpu.py deleted file mode 100644 index 35bf5a3ab..000000000 --- a/tests/utils/dataset/test_create_rl_sampler_on_cpu.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2025 Amazon.com Inc and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -test create_rl_sampler -""" - -from collections.abc import Sized - -import pytest -import torch -from omegaconf import DictConfig, OmegaConf -from torch.utils.data import Dataset, RandomSampler - -from verl.experimental.dataset.sampler import AbstractCurriculumSampler -from verl.trainer.main_ppo import create_rl_sampler - - -class RandomCurriculumSampler(AbstractCurriculumSampler): - def __init__( - self, - data_source: Sized, - data_config: DictConfig, - ): - train_dataloader_generator = torch.Generator() - train_dataloader_generator.manual_seed(1) - sampler = RandomSampler(data_source=data_source) - self.sampler = sampler - - def __iter__(self): - return self.sampler.__iter__() - - def __len__(self) -> int: - return len(self.sampler) - - def update(self, batch) -> None: - return - - -class MockIncorrectSampler: - """A fake sampler class that does not adhere to the AbstractCurriculumSampler interface.""" - - def __init__(self, data_source, data_config): - pass - - -class MockChatDataset(Dataset): - def __init__(self): - self.data = [ - {"prompt": "What's your name?", "response": "My name is Assistant."}, - {"prompt": "How are you?", "response": "I'm doing well, thank you."}, - {"prompt": "What is the capital of France?", "response": "Paris."}, - { - "prompt": "Tell me a joke.", - "response": "Why did the chicken cross the road? To get to the other side!", - }, - {"prompt": "What is 2+2?", "response": "4"}, - ] - - def __getitem__(self, index): - return self.data[index] - - def __len__(self): - return len(self.data) - - -def test_create_custom_curriculum_samper(): - data_config = OmegaConf.create( - { - "dataloader_num_workers": 0, - "sampler": { - "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu", - "class_name": "RandomCurriculumSampler", - }, - } - ) - - dataset = MockChatDataset() - - # doesn't raise - create_rl_sampler(data_config, dataset) - - -def test_create_custom_curriculum_samper_wrong_class(): - data_config = OmegaConf.create( - { - "sampler": { - "class_path": "pkg://tests.utils.dataset.test_create_rl_sampler_on_cpu", - "class_name": "MockIncorrectSampler", - } - } - ) - - dataset = MockChatDataset() - - # MockIncorrectSampler is not an instance of AbstractCurriculumSampler, so raises - with pytest.raises(AssertionError): - create_rl_sampler(data_config, dataset) diff --git a/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py b/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py deleted file mode 100644 index 8028d44e5..000000000 --- a/tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Test the MultiTurnSFTDataset implementation -""" - -import os - -import pandas as pd -import torch -from transformers import AutoTokenizer - -from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset - - -def test_multiturn_sft_dataset(): - print("Starting test...") - # Create a temporary parquet file with test data - test_data = { - "messages": [ - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "2+2 equals 4."}, - {"role": "user", "content": "And what is 4+4?"}, - {"role": "assistant", "content": "4+4 equals 8."}, - ], - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Tell me a joke."}, - {"role": "assistant", "content": "Why did the chicken cross the road?"}, - {"role": "user", "content": "Why?"}, - {"role": "assistant", "content": "To get to the other side!"}, - ], - ] - } - - # Create test directory if it doesn't exist - os.makedirs("test_data", exist_ok=True) - test_file = "test_data/test.parquet" - - # Save test data to parquet - df = pd.DataFrame(test_data) - df.to_parquet(test_file) - - # Initialize tokenizer and dataset - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct") - config = {"max_length": 512, "truncation": "error", "multiturn": {"messages_key": "messages"}} - dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config) - - # Test 1: Dataset Length - assert len(dataset) == 2, f"Expected dataset length 2, got {len(dataset)}" - - # Get items for testing - item0 = dataset[0] # Math conversation - item1 = dataset[1] # Joke conversation - - # Test 2: Required Keys and Types - required_keys = ["input_ids", "attention_mask", "position_ids", "loss_mask"] - for key in required_keys: - assert key in item0, f"Missing key {key} in dataset item" - assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key}" - assert item0[key].dtype == torch.long, f"Expected torch.long for {key}, got {item0[key].dtype}" - - # Test 3: Shape Consistency - assert item0["loss_mask"].shape == item0["input_ids"].shape, "Loss mask shape doesn't match input_ids shape" - assert item0["attention_mask"].shape == item0["input_ids"].shape, ( - "Attention mask shape doesn't match input_ids shape" - ) - assert item0["position_ids"].shape == item0["input_ids"].shape, "Position IDs shape doesn't match input_ids shape" - - # Test 4: Loss Mask Pattern - Math Conversation - loss_mask0 = item0["loss_mask"] - input_ids0 = item0["input_ids"] - - # Find assistant response positions - assistant_positions0 = torch.where(loss_mask0 == 1)[0] - assert len(assistant_positions0) > 0, "No assistant positions found in loss mask" - - # Decode and verify assistant responses - assistant_text0 = tokenizer.decode(input_ids0[loss_mask0 == 1]) - print(f"Math conversation assistant text: {assistant_text0}") - assert "2+2 equals 4" in assistant_text0, "First assistant response not found" - assert "4+4 equals 8" in assistant_text0, "Second assistant response not found" - - # Test 5: Loss Mask Pattern - Joke Conversation - loss_mask1 = item1["loss_mask"] - input_ids1 = item1["input_ids"] - - # Find assistant response positions - assistant_positions1 = torch.where(loss_mask1 == 1)[0] - assert len(assistant_positions1) > 0, "No assistant positions found in loss mask" - - # Decode and verify assistant responses - assistant_text1 = tokenizer.decode(input_ids1[loss_mask1 == 1]) - print(f"Joke conversation assistant text: {assistant_text1}") - assert "chicken cross the road" in assistant_text1, "First assistant response not found" - assert "other side" in assistant_text1, "Second assistant response not found" - - # Test 6: Attention Mask Pattern - attention_mask0 = item0["attention_mask"] - sequence_length = torch.sum(attention_mask0) - assert sequence_length > 0, "No tokens marked as attended in attention mask" - assert torch.all(attention_mask0[:sequence_length] == 1), "Incorrect attention mask pattern" - if sequence_length < len(attention_mask0): - assert torch.all(attention_mask0[sequence_length:] == 0), "Padding not properly masked" - - # Test 7: Position IDs Pattern - position_ids0 = item0["position_ids"] - assert torch.equal(position_ids0[:sequence_length], torch.arange(sequence_length)), ( - "Position IDs not sequential for non-padded tokens" - ) - if sequence_length < len(position_ids0): - assert torch.all(position_ids0[sequence_length:] == 0), "Padding position IDs not zero" - - # Test 8: Verify loss mask for assistant responses - # Get the full conversation text - full_text = tokenizer.decode(input_ids0) - print(f"\nFull conversation text:\n{full_text}") - - # Get the assistant responses - assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 1]) - print(f"\nAssistant responses (from loss mask):\n{assistant_text}") - - # Verify that loss mask is set for all assistant responses - for msg in test_data["messages"][0]: # First conversation - if msg["role"] == "assistant": - # The content should appear in the masked text - assert msg["content"] in assistant_text, f"Assistant message '{msg['content']}' not found in masked text" - - # The content should NOT appear in the non-masked text - non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) - assert msg["content"] not in non_assistant_text, ( - f"Assistant message '{msg['content']}' found in non-assistant text" - ) - - # Test 9: Verify non-assistant parts have loss_mask=0 - # Get non-assistant text - non_assistant_text = tokenizer.decode(input_ids0[loss_mask0 == 0]) - print(f"\nNon-assistant text (from loss mask):\n{non_assistant_text}") - - # Verify that system and user messages are in the non-assistant text - for msg in test_data["messages"][0]: # First conversation - if msg["role"] in ["system", "user"]: - assert msg["content"] in non_assistant_text, ( - f"{msg['role'].title()} message '{msg['content']}' not found in non-assistant text" - ) - - # And verify they're NOT in the assistant text - assert msg["content"] not in assistant_text, ( - f"{msg['role'].title()} message '{msg['content']}' found in assistant text" - ) - - # Test 10: Verify padding behavior - padding_config = {"max_length": 1024, "truncation": "error", "multiturn": {"messages_key": "messages"}} - small_dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=padding_config) - padded_item = small_dataset[0] - - # Get actual sequence length (before padding) - actual_length = torch.sum(padded_item["attention_mask"]) - - # Verify padding tokens - assert torch.all(padded_item["input_ids"][actual_length:] == tokenizer.pad_token_id), ( - "Padding tokens not set correctly" - ) - assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding" - assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding" - - print("All tests passed!") - print("Starting test...") diff --git a/tests/utils/dataset/test_rl_dataset_on_cpu.py b/tests/utils/dataset/test_rl_dataset_on_cpu.py deleted file mode 100644 index 2afc3ef49..000000000 --- a/tests/utils/dataset/test_rl_dataset_on_cpu.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -import torch -from omegaconf import OmegaConf -from torch.utils.data import DataLoader - - -def get_gsm8k_data(): - # prepare test dataset - local_folder = os.path.expanduser("~/verl-data/gsm8k/") - local_path = os.path.join(local_folder, "train.parquet") - os.makedirs(local_folder, exist_ok=True) - return local_path - - -def test_rl_dataset(): - from verl.utils import hf_tokenizer - from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn - - tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct") - local_path = get_gsm8k_data() - config = OmegaConf.create( - { - "prompt_key": "prompt", - "max_prompt_length": 256, - "filter_overlong_prompts": True, - "filter_overlong_prompts_workers": 2, - } - ) - dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config) - - dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) - - a = next(iter(dataloader)) - - from verl import DataProto - - tensors = {} - non_tensors = {} - - for key, val in a.items(): - if isinstance(val, torch.Tensor): - tensors[key] = val - else: - non_tensors[key] = val - - data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) - assert "input_ids" in data_proto.batch - - data = dataset[0]["input_ids"] - output = tokenizer.batch_decode([data])[0] - print(f"type: type{output}") - print(f"\n\noutput: {output}") - - -def test_image_rl_data(): - from verl.utils import hf_processor, hf_tokenizer - from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn - - tokenizer = hf_tokenizer("Qwen/Qwen2-VL-2B-Instruct") - processor = hf_processor("Qwen/Qwen2-VL-2B-Instruct") - config = OmegaConf.create( - { - "prompt_key": "prompt", - "max_prompt_length": 1024, - "filter_overlong_prompts": True, - "filter_overlong_prompts_workers": 2, - } - ) - dataset = RLHFDataset( - data_files=os.path.expanduser("~/data/geo3k/train.parquet"), - tokenizer=tokenizer, - config=config, - processor=processor, - ) - - dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) - - a = next(iter(dataloader)) - - from verl import DataProto - - tensors = {} - non_tensors = {} - - for key, val in a.items(): - if isinstance(val, torch.Tensor): - tensors[key] = val - else: - non_tensors[key] = val - - data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) - - assert "multi_modal_data" in data_proto.non_tensor_batch, data_proto - assert "multi_modal_inputs" in data_proto.non_tensor_batch, data_proto - - data = dataset[0]["input_ids"] - output = tokenizer.batch_decode([data])[0] - print(f"type: type{output}") - print(f"\n\noutput: {output}") diff --git a/tests/utils/dataset/test_sft_dataset_on_cpu.py b/tests/utils/dataset/test_sft_dataset_on_cpu.py deleted file mode 100644 index 680fce45a..000000000 --- a/tests/utils/dataset/test_sft_dataset_on_cpu.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -from verl.utils import hf_tokenizer -from verl.utils.dataset.sft_dataset import SFTDataset - - -def get_gsm8k_data(): - # prepare test dataset - local_folder = os.path.expanduser("~/verl-data/gsm8k/") - local_path = os.path.join(local_folder, "train.parquet") - return local_path - - -def test_sft_cot_dataset(): - tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct") - local_path = get_gsm8k_data() - from omegaconf import OmegaConf - - dataset = SFTDataset( - parquet_files=local_path, - tokenizer=tokenizer, - config=OmegaConf.create( - { - "prompt_key": "prompt", - "prompt_dict_keys": ["content"], - "response_key": "extra_info", - "response_dict_keys": ["answer"], - "max_length": 512, - } - ), - ) - - data = dataset[0]["input_ids"] - output = tokenizer.batch_decode([data])[0] - assert len(output) > 1 - assert isinstance(output, str) - - -def test_sft_dataset(): - tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct") - local_path = get_gsm8k_data() - from omegaconf import OmegaConf - - dataset = SFTDataset( - parquet_files=local_path, - tokenizer=tokenizer, - config=OmegaConf.create( - { - "prompt_key": "extra_info", - "prompt_dict_keys": ["question"], - "response_key": "extra_info", - "response_dict_keys": ["answer"], - "max_length": 512, - } - ), - ) - - data = dataset[0]["input_ids"] - output = tokenizer.batch_decode([data])[0] - assert len(output) > 1 - assert isinstance(output, str) diff --git a/tests/utils/megatron/test_pipeline_parallel.py b/tests/utils/megatron/test_pipeline_parallel.py deleted file mode 100644 index 24a416987..000000000 --- a/tests/utils/megatron/test_pipeline_parallel.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -from verl.model_merger.megatron_model_merger import get_dynamic_pipeline_shards -from verl.utils.megatron.pipeline_parallel import make_batch_generator - - -def test_make_batch_generator_no_vpp(): - batches = [1, 2, 3] - vpp_size = 1 - generator = make_batch_generator(batches, vpp_size) - assert list(generator) == batches - - -def test_make_batch_generator_with_vpp(): - batches = [{"data": 1}, {"data": 2}] - vpp_size = 2 - generators = make_batch_generator(batches, vpp_size) - assert isinstance(generators, list) - assert len(generators) == vpp_size - - # Check each generator yields the original batches - for gen in generators: - assert list(gen) == batches - - -def test_make_batch_generator_empty(): - batches = [] - vpp_size = 1 - generator = make_batch_generator(batches, vpp_size) - assert list(generator) == [] - - vpp_size = 3 - generators = make_batch_generator(batches, vpp_size) - assert len(generators) == vpp_size - for gen in generators: - assert list(gen) == [] - - -@pytest.mark.parametrize( - "layer_num,pp_size,gt", - [ - (61, 8, [6, 8, 8, 8, 8, 8, 8, 7]), - (61, 7, [8, 9, 9, 9, 9, 9, 8]), - (61, 1, [61]), - (61, 0, ValueError), - (10, 16, ValueError), - ], -) -def test_get_dynamic_pipeline_shards(layer_num, pp_size, gt): - if isinstance(gt, list): - shards = get_dynamic_pipeline_shards(layer_num, pp_size) - assert len(shards) == len(gt) == pp_size, f"Expected {pp_size} shards, got {len(shards)}" - assert all([shard == gt[i] for i, shard in enumerate(shards)]), f"Expected shards {gt}, got {shards}" - elif issubclass(gt, Exception): - with pytest.raises(gt): - shards = get_dynamic_pipeline_shards(layer_num, pp_size) diff --git a/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py b/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py deleted file mode 100644 index aaa427183..000000000 --- a/tests/utils/reward_score/reward_score/test_sandbox_fusion_on_cpu.py +++ /dev/null @@ -1,692 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import multiprocessing -import os -import time -from concurrent.futures import ProcessPoolExecutor -from unittest.mock import patch - -import pytest - -# Import the function to be tested -from verl.utils.reward_score.sandbox_fusion.utils import check_correctness - -# Get SANDBOX_URL from environment variable -SANDBOX_URL = os.environ.get("SANDBOX_FUSION_URL") -# Define skip condition and reason -skip_reason = "SANDBOX_FUSION_URL environment variable not set" -skip_condition = not SANDBOX_URL - -# --- Test code (for real API calls) --- -CODE_SUCCESS = """ -import sys -data = sys.stdin.read() -if data == 'input1': - print('output1\\n', end='') -elif data == 'input2': - print('output2\\n', end='') -else: - print('unexpected input', end='') -""" - -CODE_WRONG_OUTPUT = """ -print('wrong_output\\n', end='') -""" - -CODE_COMPILE_ERROR = """ -a=b -""" - -CODE_RUNTIME_ERROR = """ -import sys -print("About to raise error", file=sys.stderr) -raise ValueError("This is a runtime error") -""" - -CODE_TIMEOUT = """ -import time -import sys -print("Sleeping...", file=sys.stderr) -time.sleep(10) # Sleep time should be longer than the timeout set in the test -print("Finished sleeping", file=sys.stderr) -""" - -# --- Test input/output data --- -INPUT_OUTPUT_VALID = {"inputs": ["input1", "input2"], "outputs": ["output1\n", "output2\n"]} - -INPUT_OUTPUT_SINGLE = {"inputs": ["input1"], "outputs": ["output1\n"]} - -INPUT_OUTPUT_MISMATCH = {"inputs": ["input1"], "outputs": ["output1\n", "output2\n"]} - -INPUT_OUTPUT_INVALID_MISSING_KEY = {"inputs": ["input1"]} - -# --- Integration test cases (calling real API) --- - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_integration_success_correct(): - """Integration test: Code is correct, output is correct""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_SUCCESS) - assert results == [True, True] - assert metadata_list[0]["status"] == "success" - assert metadata_list[0]["stdout"] == "output1\n" - assert metadata_list[1]["status"] == "success" - assert metadata_list[1]["stdout"] == "output2\n" - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_integration_success_wrong_output(): - """Integration test: Code runs successfully, but output is wrong""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_WRONG_OUTPUT) - assert results == [False, False] - assert metadata_list[0]["status"] == "wrong_answer" - assert metadata_list[0]["stdout"] == "wrong_output\n" - assert metadata_list[1]["status"] == "wrong_answer" - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_integration_compile_error(): - """Integration test: Code causes compile error""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_VALID, CODE_COMPILE_ERROR, language="cpp") - assert results == [-4, -4] - assert metadata_list[0]["status"] == "compile_error" - assert metadata_list[1]["status"] == "compile_error" - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_integration_runtime_error(): - """Integration test: Code causes runtime error""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_RUNTIME_ERROR) - assert results == [-2] - assert metadata_list[0]["status"] == "runtime_error" - # More assertions can be added based on the actual API response, e.g., exit_code, stderr - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_integration_runtime_timeout(): - """Integration test: Code causes runtime timeout""" - test_timeout = 5 # Set a timeout shorter than the sleep time in CODE_TIMEOUT - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_SINGLE, CODE_TIMEOUT, timeout=test_timeout) - assert results == [-3] - assert metadata_list[0]["status"] == "timeout" - # More assertions can be added based on the actual API response, e.g., run_status - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_integration_concurrency_high_load(): - """Integration test: High concurrency (100 cases) against real API with mixed results (success, wrong - answer, timeout)""" - concurrency_level = 100 - # Indices for different expected outcomes - wrong_answer_indices = {10, 25, 50} - timeout_indices = {5, 30, 60, 90} # Indices where we expect a timeout - - # Generate 100 input/output pairs and code - high_load_inputs = [] - high_load_outputs = [] - expected_results_map = {} # Store expected result for each index - - for i in range(concurrency_level): - if i in timeout_indices: - # Use a special input to trigger timeout in the code - high_load_inputs.append(f"input_timeout_{i}") - # Output doesn't matter for timeout, but keep it consistent - high_load_outputs.append(f"output_{i}\n") - expected_results_map[i] = -3 # Expect timeout - elif i in wrong_answer_indices: - high_load_inputs.append(f"input_{i}") - # Intentionally set wrong expected output - high_load_outputs.append(f"wrong_output_{i}\n") - expected_results_map[i] = False # Expect wrong answer - else: - high_load_inputs.append(f"input_{i}") - # Correct expected output - high_load_outputs.append(f"output_{i}\n") - expected_results_map[i] = True # Expect success - - high_load_in_outs = {"inputs": high_load_inputs, "outputs": high_load_outputs} - - # Code that handles normal inputs, and sleeps on specific "timeout" inputs - code_mixed_concurrent = """ -import sys -import time -data = sys.stdin.read() -if data.startswith('input_timeout_'): - time.sleep(20) # Sleep longer than the test timeout - print(f"output_{data.split('_')[-1]}\\n", end='') # Still print something in case it finishes early -elif data.startswith('input_'): - print(f"output_{data.split('_')[-1]}\\n", end='') -else: - print("unknown_input\\n", end='') -""" - # Set a reasonable timeout per case (must be less than the sleep time in the code) - test_timeout = 15 # Allow slightly more time due to potential API load, but less than 20s sleep - - start_time = time.time() - results, metadata_list = check_correctness( - SANDBOX_URL, - high_load_in_outs, - code_mixed_concurrent, # Use the new code - timeout=test_timeout, - ) - end_time = time.time() - duration = end_time - start_time - print( - f"\nHigh concurrency test ({concurrency_level} cases with {len(wrong_answer_indices)} wrong answers, " - f"{len(timeout_indices)} timeouts) duration: {duration:.2f} seconds" - ) - - # Verify results against the expected map - assert len(results) == concurrency_level, f"Expected {concurrency_level} results, got {len(results)}" - - correct_count = 0 - wrong_count = 0 - timeout_count = 0 - unexpected_results = [] - for i, r in enumerate(results): - expected = expected_results_map[i] - if r == expected: - if expected is True: - correct_count += 1 - elif expected is False: - wrong_count += 1 - elif expected == -3: - timeout_count += 1 - else: - unexpected_results.append((i, r, f"Expected {expected}")) - - print( - f"Correct results (True): {correct_count}/" - f"{concurrency_level - len(wrong_answer_indices) - len(timeout_indices)}" - ) - print(f"Expected wrong answers (False, correctly identified): {wrong_count}/{len(wrong_answer_indices)}") - print(f"Expected timeouts (-3, correctly identified): {timeout_count}/{len(timeout_indices)}") - - if unexpected_results: - print("Unexpected results found:") - for idx, res, expected_str in unexpected_results[:10]: # Print first 10 unexpected - print(f" Index {idx}: Got {res}, {expected_str}. Metadata: {metadata_list[idx]}") - raise AssertionError(f"Found {len(unexpected_results)} unexpected results.") - - assert correct_count == concurrency_level - len(wrong_answer_indices) - len(timeout_indices), ( - "Incorrect number of successful results" - ) - assert wrong_count == len(wrong_answer_indices), "Incorrect number of identified wrong answers" - assert timeout_count == len(timeout_indices), "Incorrect number of identified timeouts" - - # Verify metadata count and basic status of one of each type - assert len(metadata_list) == concurrency_level - # Find the first correct index - first_correct_index = next( - i for i in range(concurrency_level) if i not in wrong_answer_indices and i not in timeout_indices - ) - assert metadata_list[first_correct_index]["status"] == "success" - assert metadata_list[first_correct_index]["stdout"] == f"output_{first_correct_index}\n" - - # Check the status of the first intentionally wrong case - first_wrong_index = min(wrong_answer_indices) - assert metadata_list[first_wrong_index]["status"] == "wrong_answer" - assert metadata_list[first_wrong_index]["stdout"] == f"output_{first_wrong_index}\n" - assert metadata_list[first_wrong_index]["expected_output"] == f"wrong_output_{first_wrong_index}\n" - - # Check the status of the first intentionally timeout case - first_timeout_index = min(timeout_indices) - assert metadata_list[first_timeout_index]["status"] == "timeout" - # For timeout, stdout might be None or empty depending on when the timeout occurred - # assert metadata_list[first_timeout_index]["stdout"] is None or metadata_list[first_timeout_index]["stdout"] == "" - - -# --- Unit test cases (using mock) --- - - -@patch("verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api") -def test_unit_concurrency_order(mock_call_sandbox_api): - sandbox_url = "mock_url" - generation = "print(input())" - language = "python" - timeout = 5 - in_outs = {"inputs": ["input1", "input2", "input3"], "outputs": ["output1", "output2", "output3"]} - - def side_effect(*args, **kwargs): - stdin = kwargs.get("stdin") - if stdin == "input1": - return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, - None, - ) - elif stdin == "input2": - time.sleep(0.1) - return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output2", "return_code": 0}}, - None, - ) - elif stdin == "input3": - return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, - None, - ) - else: - return (None, "Unknown input in mock") - - mock_call_sandbox_api.side_effect = side_effect - - results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language) - - assert results == [True, True, True] - assert len(metadata_list) == 3 - assert metadata_list[0]["case_index"] == 0 - assert metadata_list[0]["status"] == "success" - assert metadata_list[1]["case_index"] == 1 - assert metadata_list[1]["status"] == "success" - assert metadata_list[2]["case_index"] == 2 - assert metadata_list[2]["status"] == "success" - assert mock_call_sandbox_api.call_count == 3 - - -@patch("verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api") -def test_unit_api_timeout_error_concurrent(mock_call_sandbox_api): - sandbox_url = "mock_url" - generation = "print(input())" - language = "python" - timeout = 5 - in_outs = {"inputs": ["input1", "input2_timeout", "input3"], "outputs": ["output1", "output2", "output3"]} - - api_error_message = "API Call Failed: Gateway Timeout (504) on attempt 3/3" - - def side_effect(*args, **kwargs): - stdin = kwargs.get("stdin") - if stdin == "input1": - return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output1", "return_code": 0}}, - None, - ) - elif stdin == "input2_timeout": - return (None, api_error_message) - elif stdin == "input3": - return ( - {"status": "Success", "run_result": {"status": "Finished", "stdout": "output3", "return_code": 0}}, - None, - ) - else: - return (None, "Unknown input in mock") - - mock_call_sandbox_api.side_effect = side_effect - - results, metadata_list = check_correctness(sandbox_url, in_outs, generation, timeout, language) - - assert results == [True, -1, True] - assert len(metadata_list) == 3 - assert metadata_list[0]["status"] == "success" - assert metadata_list[1]["status"] == "api_error" - assert metadata_list[1]["api_request_error"] == api_error_message - assert metadata_list[2]["status"] == "success" - assert mock_call_sandbox_api.call_count == 3 - - -# --- Constants for the new concurrency test --- -# Define a low global concurrency limit to test the semaphore's effect -MAX_GLOBAL_CONCURRENCY_LIMIT_TEST = 5 -# Define the number of processes used in the test -NUM_PROCESSES_TEST = 4 -# Define the number of tasks processed by check_correctness in each process (i.e., internal -# ThreadPoolExecutor's concurrency potential) -NUM_TASKS_PER_PROCESS_TEST = 3 -# Simulate API call duration to ensure calls can overlap -SIMULATED_API_CALL_DURATION_TEST = 0.2 # seconds - - -# --- Mock API call function for concurrency tracking --- -# This function will replace the real call_sandbox_api and use shared variables to track concurrency -def _mock_api_call_for_concurrency_tracking( - active_calls_counter, # multiprocessing.Value - max_calls_tracker, # multiprocessing.Value - call_lock, # multiprocessing.Lock - # Standard call_sandbox_api parameters - sandbox_fusion_url, - code, - stdin, - compile_timeout, - run_timeout, - memory_limit_mb, - language, -): - # entry_time = time.time() # For detailed logging - with call_lock: - active_calls_counter.value += 1 - if active_calls_counter.value > max_calls_tracker.value: - max_calls_tracker.value = active_calls_counter.value - # Optional debug log: - # print(f"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call Start. Active: " - # f"{active_calls_counter.value}, Max Observed: {max_calls_tracker.value}, Input: {stdin}") - - time.sleep(SIMULATED_API_CALL_DURATION_TEST) # Simulate actual work duration - - # exit_time = time.time() # For detailed logging - with call_lock: - active_calls_counter.value -= 1 - # Optional debug log: - # print(f"[PID:{os.getpid()}-TID:{threading.get_ident()}] API Call End. Active: " - # f"{active_calls_counter.value}, Input: {stdin}, Duration: {exit_time - entry_time:.2f}s") - - # Return a simulated successful API response - return { - "status": "Success", - "run_result": {"status": "Finished", "stdout": f"mock_output_for_{stdin}", "return_code": 0}, - }, None - - -# --- Worker function for ProcessPoolExecutor --- -# This function runs in each child process of ProcessPoolExecutor -def _process_pool_worker_for_concurrency_test( - sandbox_url, - in_outs, - generation, - memory_limit_mb, - language, - timeout, - mp_semaphore_for_check_correctness, - active_calls_counter, - max_calls_tracker, - call_lock, -): - # Corrected lambda to accept keyword arguments matching call_sandbox_api's usage - curried_mock_api_call = ( - lambda sandbox_fusion_url, code, stdin, compile_timeout, run_timeout, memory_limit_mb, language: ( - _mock_api_call_for_concurrency_tracking( - active_calls_counter, - max_calls_tracker, - call_lock, - sandbox_fusion_url, - code, - stdin, - compile_timeout, - run_timeout, - memory_limit_mb, - language, - ) - ) - ) - - # ---- START DEBUG PRINTS ---- - import os - - import verl.utils.reward_score.sandbox_fusion.utils - - print( - f"[Worker PID:{os.getpid()}] Original call_sandbox_api: " - f"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}", - flush=True, - ) - # ---- END DEBUG PRINTS ---- - - with patch( - "verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api", side_effect=curried_mock_api_call - ) as mock_obj: - # ---- START DEBUG PRINTS ---- - print( - f"[Worker PID:{os.getpid()}] Patched call_sandbox_api: " - f"{verl.utils.reward_score.sandbox_fusion.utils.call_sandbox_api}", - flush=True, - ) - print(f"[Worker PID:{os.getpid()}] Mock object: {mock_obj}", flush=True) - # ---- END DEBUG PRINTS ---- - results, metadata_list = check_correctness( - sandbox_fusion_url=sandbox_url, - in_outs=in_outs, - generation=generation, - timeout=timeout, - memory_limit_mb=memory_limit_mb, - language=language, - concurrent_semaphore=mp_semaphore_for_check_correctness, # Pass multiprocessing.Semaphore - ) - # print(f"Process {os.getpid()} finished check_correctness. Processed {len(results)} tasks.") - return len(results) # Return the number of processed tasks for basic validation - - -# --- The actual test case for multiprocess concurrency control --- -def test_multiprocess_global_concurrency_limit_with_semaphore(): - """ - Tests that the global concurrent_semaphore (multiprocessing.Semaphore) - correctly limits the number of concurrent calls to call_sandbox_api - across multiple processes, each potentially running multiple threads - via check_correctness's internal ThreadPoolExecutor. - """ - manager = multiprocessing.Manager() - active_calls_counter = manager.Value("i", 0) # Current active mock API calls - max_calls_tracker = manager.Value("i", 0) # Observed maximum concurrent mock API calls - call_lock = manager.Lock() # Lock to protect counters - - # Create a multiprocessing.Semaphore instance, this is the global semaphore we are testing. - # It will be passed to check_correctness and used by _process_single_case to limit calls to call_sandbox_api. - global_mp_semaphore = manager.Semaphore(MAX_GLOBAL_CONCURRENCY_LIMIT_TEST) - - mock_sandbox_url = "mock_url_for_concurrency_test" - mock_generation = "pass" # Specific code content is not important as API call is mocked - mock_memory_limit_mb = 1024 - mock_language = "python" - mock_timeout = 5 # Timeout setting, not critical for mock calls - - # Input/output data for each process - # NUM_TASKS_PER_PROCESS_TEST tasks will be handled by check_correctness's internal ThreadPoolExecutor - process_in_outs = { - "inputs": [f"task_input_{i}" for i in range(NUM_TASKS_PER_PROCESS_TEST)], - "outputs": [f"task_output_{i}" for i in range(NUM_TASKS_PER_PROCESS_TEST)], - } - - futures = [] - total_tasks_expected_to_run = NUM_PROCESSES_TEST * NUM_TASKS_PER_PROCESS_TEST - - test_start_time = time.time() - - with ProcessPoolExecutor(max_workers=NUM_PROCESSES_TEST) as executor: - for i in range(NUM_PROCESSES_TEST): - future = executor.submit( - _process_pool_worker_for_concurrency_test, # Worker function - mock_sandbox_url, - process_in_outs, - mock_generation, - mock_memory_limit_mb, - mock_language, - mock_timeout, - global_mp_semaphore, # Global semaphore to test - active_calls_counter, # Shared variables for tracking - max_calls_tracker, - call_lock, - ) - futures.append(future) - - # Wait for all processes to complete and collect results - num_tasks_processed_per_worker = [f.result() for f in futures] - test_end_time = time.time() - total_execution_time = test_end_time - test_start_time - - # Print some test statistics for debugging and validation - print("\n--- Global Concurrency Test Stats ---") - print(f"Semaphore Limit (MAX_GLOBAL_CONCURRENCY_LIMIT_TEST): {MAX_GLOBAL_CONCURRENCY_LIMIT_TEST}") - print(f"Number of Processes (NUM_PROCESSES_TEST): {NUM_PROCESSES_TEST}") - print(f"Tasks per Process (NUM_TASKS_PER_PROCESS_TEST): {NUM_TASKS_PER_PROCESS_TEST}") - print(f"Total Tasks Submitted: {total_tasks_expected_to_run}") - print(f"Simulated API Call Duration: {SIMULATED_API_CALL_DURATION_TEST}s") - print(f"Total Test Execution Time: {total_execution_time:.2f}s") - print(f"Max Concurrent Mock API Calls Observed: {max_calls_tracker.value}") - # print(f"Tasks processed per worker: {num_tasks_processed_per_worker}") - - # Verify that all submitted tasks have been processed - assert sum(num_tasks_processed_per_worker) == total_tasks_expected_to_run, ( - "Mismatch in the number of tasks processed." - ) - - # Verify that the mock API was called at least once - assert max_calls_tracker.value > 0, "The mocked API call_sandbox_api was not called." - - # Core assertion: Observed maximum concurrent calls should not exceed the semaphore's limit - assert max_calls_tracker.value <= MAX_GLOBAL_CONCURRENCY_LIMIT_TEST, ( - f"Observed concurrency ({max_calls_tracker.value}) exceeded semaphore limit " - f"({MAX_GLOBAL_CONCURRENCY_LIMIT_TEST})." - ) - - # Optional: Rough check on execution time to verify semaphore is working to limit concurrency - # Theoretical minimum execution time = (Total tasks / Concurrency limit) * Single task duration - # Actual time will be longer due to various overheads - min_expected_duration = ( - total_tasks_expected_to_run * SIMULATED_API_CALL_DURATION_TEST - ) / MAX_GLOBAL_CONCURRENCY_LIMIT_TEST - # print(f"Minimum Expected Execution Time (approx): {min_expected_duration:.2f}s") - # Allow some margin, e.g., 80% of theoretical minimum time - assert total_execution_time >= min_expected_duration * 0.8, ( - f"Total execution time ({total_execution_time:.2f}s) was unexpectedly short, suggesting the " - f"semaphore might not be effectively limiting concurrency as expected " - f"(min expected: {min_expected_duration * 0.8:.2f}s)." - ) - - -# Ensure there is no more code after this point if these were the last functions. -# If there was other code, it would follow here. -def test_unit_invalid_input_format(): - """Unit test: Invalid in_outs format passed""" - results, metadata_list = check_correctness(SANDBOX_URL, None, CODE_SUCCESS) - assert results == [-1] - assert metadata_list[0]["error"] == "Invalid input/output data" - - results, metadata_list = check_correctness(SANDBOX_URL, {}, CODE_SUCCESS) - assert results == [-1] - assert metadata_list[0]["error"] == "Invalid input/output data" - - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_INVALID_MISSING_KEY, CODE_SUCCESS) - assert results == [-1] - assert metadata_list[0]["error"] == "Invalid input/output data" - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_unit_input_output_mismatch(): - """Unit test: Mismatch between the number of inputs and outputs""" - results, metadata_list = check_correctness(SANDBOX_URL, INPUT_OUTPUT_MISMATCH, CODE_SUCCESS) - assert results == [-1] - assert len(metadata_list) == 1 - assert metadata_list[0]["error"] == "Input/output count mismatch" - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_integration_concurrency_all_timeout(): - """Integration test: High concurrency (100 cases) against real API, all causing timeout""" - concurrency_level = 100 - code_infinite_loop = """ -def knight_moves(X, Y): - MOD = 10**9 + 7 - dp = [[0] * (Y + 1) for _ in range(X + 1)] - dp[0][0] = 1 - for i in range(1, X + 1): - for j in range(1, Y + 1): - dp[i][j] = (dp[i - 1][j] + dp[i][j - 1]) % MOD - return dp[X][Y] - -def solve(): - X, Y = map(int, input().split()) - print(knight_moves(X, Y)) - -if __name__ == "__main__": - solve() - """ - - # Generate 100 simple input/output pairs (content doesn't matter) - timeout_inputs = ["324 384429" for i in range(concurrency_level)] - timeout_outputs = [f"output_{i}\n" for i in range(concurrency_level)] - timeout_in_outs = {"inputs": timeout_inputs, "outputs": timeout_outputs} - - # Set a timeout for the test cases - test_timeout = 10 # Set a timeout value - - start_time = time.time() - results, metadata_list = check_correctness(SANDBOX_URL, timeout_in_outs, code_infinite_loop, timeout=test_timeout) - end_time = time.time() - duration = end_time - start_time - print(f"\nHigh concurrency all timeout test ({concurrency_level} cases) duration: {duration:.2f} seconds") - - # Verify all results are -3 (timeout) - assert len(results) == concurrency_level, f"Expected {concurrency_level} results, got {len(results)}" - all_timed_out = all(r == -3 for r in results) - if not all_timed_out: - non_timeout_indices = [i for i, r in enumerate(results) if r != -3] - print(f"Indices that did not time out: {non_timeout_indices}") - # Print metadata for the first few non-timeout cases for debugging - for i in non_timeout_indices[:5]: - print(f"Metadata for non-timeout case {i}: {metadata_list[i]}") - assert all_timed_out, f"Not all {concurrency_level} concurrent tests resulted in timeout (-3). Results: {results}" - - # Verify metadata count and status of the first case - assert len(metadata_list) == concurrency_level - assert metadata_list[0]["status"] == "timeout" - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_fn_name_success_single_case(): - """Tests successful execution for a single test case with fn_name. - from livecodebench/code_generation_lite test 510 - """ - generation_code = """ -class Solution: - def occurrencesOfElement(self, nums: List[int], queries: List[int], x: int) -> List[int]: - positions = defaultdict(list) - for idx, num in enumerate(nums): - positions[num].append(idx) - - x_positions = positions[x] - answer = [] - for k in queries: - if k > len(x_positions): - answer.append(-1) - else: - answer.append(x_positions[k-1]) - return answer -""" - in_outs = { - "fn_name": "occurrencesOfElement", - "inputs": ["[1, 3, 1, 7]\n[1, 3, 2, 4]\n1", "[1, 2, 3]\n[10]\n5"], - "outputs": ["[0, -1, 2, -1]", "[-1]"], - } - - # Use a short timeout for fast tests - results, metadata_list = check_correctness(SANDBOX_URL, in_outs, generation_code, timeout=5) - # from verl.utils.reward_score.prime_code import apps_check_correctness - # results, metadata_list = apps_check_correctness(in_outs=in_outs, generation=generation_code, - # timeout=50000, debug=True) - - assert results == [True, True] - assert "error" not in metadata_list[0] - assert metadata_list[0].get("status") != "compilation error" - assert metadata_list[0].get("status") != "runtime error" - - -@pytest.mark.skipif(skip_condition, reason=skip_reason) -def test_none_and_empty_stdin_passed_correctly(): - """ - Tests that when stdin data is set to an empty string or None, it is still - is passed correctly to Sandbox Fusion as an empty string. - """ - echo_code = """ -import sys -print(f"You said '{sys.stdin.readline().strip()}'") -""" - in_outs = { - "inputs": [None, "", "hello"], - "outputs": ["You said ''", "You said ''", "You said 'hello'"], - } - - # Use a short timeout for fast tests - results, metadata_list = check_correctness(SANDBOX_URL, in_outs, echo_code, timeout=5) - - assert results == [True, True, True] - assert "error" not in metadata_list[0] - assert metadata_list[0].get("status") != "compilation error" - assert metadata_list[0].get("status") != "runtime error" diff --git a/tests/utils/reward_score/test_sandbox_on_cpu.py b/tests/utils/reward_score/test_sandbox_on_cpu.py deleted file mode 100644 index ff4073232..000000000 --- a/tests/utils/reward_score/test_sandbox_on_cpu.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import json -import os - -import pytest - -from verl.utils.reward_score import default_compute_score, prime_code, sandbox_fusion -from verl.utils.reward_score.prime_code import apps_check_correctness -from verl.workers.reward_manager.prime import parallel_compute_score_async - -prime_math_answers = [ - """\\begin{bmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19 \n \\end{bmatrix}""", - """\\frac{\\sqrt{505}}{7}""", - """x^2 + y^2 + 4x - 6y + 13""", -] -prime_math_gts = [ - """\\begin{pmatrix}\n -7 & 6 & -8 \\\\\n 11 & -9 & 12 \\\\\n 15 & -16 & 19\n \\end{pmatrix}""", # mat test - """\\frac{\\sqrt{505}}{7}""", # frac test - """(x + 2)^2 + (y - 3)^2 """, # symbolic test -] - -prime_code_answers = [ - """import sys -from collections import deque - -def main(): - data = sys.stdin.read().split() - it = iter(data) - - # Read start and target positions - x0, y0, x1, y1 = int(next(it)), int(next(it)), int(next(it)), int(next(it)) - - n = int(next(it)) - allowed = set() - # The total number of allowed cells is at most 10^5. - for _ in range(n): - r = int(next(it)) - a = int(next(it)) - b = int(next(it)) - for c in range(a, b + 1): - allowed.add((r, c)) - - # Directions for the king (8 neighboring cells) - directions = [(-1, -1), (-1, 0), (-1, 1), - (0, -1), (0, 1), - (1, -1), (1, 0), (1, 1)] - - start = (x0, y0) - target = (x1, y1) - - # BFS initialization - queue = deque() - queue.append((x0, y0, 0)) - # Mark the starting cell as visited by removing it from allowed set. - allowed.discard(start) - - while queue: - x, y, moves = queue.popleft() - if (x, y) == target: - print(moves) - return - for dx, dy in directions: - nx, ny = x + dx, y + dy - if (nx, ny) in allowed: - allowed.remove((nx, ny)) - queue.append((nx, ny, moves + 1)) - - print(-1) - -if __name__ == '__main__': - main() -""" -] * 2 -prime_code_gts = [ - """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"2\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # A correct sample # noqa: E501 - """{\n \"inputs\": [\n \"5 7 6 11\\n3\\n5 3 8\\n6 7 11\\n5 2 5\\n\",\n \"3 4 3 10\\n3\\n3 1 4\\n4 5 9\\n3 10 10\\n\",\n \"1 1 2 10\\n2\\n1 1 3\\n2 6 10\\n\",\n \"9 8 7 8\\n9\\n10 6 6\\n10 6 6\\n7 7 8\\n9 5 6\\n8 9 9\\n9 5 5\\n9 8 8\\n8 5 6\\n9 10 10\\n\",\n \"6 15 7 15\\n9\\n6 15 15\\n7 14 14\\n6 15 15\\n9 14 14\\n7 14 16\\n6 15 15\\n6 15 15\\n7 14 14\\n8 15 15\\n\",\n \"13 16 20 10\\n18\\n13 16 16\\n20 10 10\\n19 10 10\\n12 15 15\\n20 10 10\\n18 11 11\\n19 10 10\\n19 10 10\\n20 10 10\\n19 10 10\\n20 10 10\\n20 10 10\\n19 10 10\\n18 11 11\\n13 16 16\\n12 15 15\\n19 10 10\\n19 10 10\\n\",\n \"89 29 88 30\\n16\\n87 31 31\\n14 95 95\\n98 88 89\\n96 88 88\\n14 97 97\\n13 97 98\\n100 88 88\\n88 32 32\\n99 88 89\\n90 29 29\\n87 31 31\\n15 94 96\\n89 29 29\\n88 32 32\\n97 89 89\\n88 29 30\\n\",\n \"30 14 39 19\\n31\\n35 7 11\\n37 11 12\\n32 13 13\\n37 5 6\\n46 13 13\\n37 14 14\\n31 13 13\\n43 13 19\\n45 15 19\\n46 13 13\\n32 17 17\\n41 14 19\\n30 14 14\\n43 13 17\\n34 16 18\\n44 11 19\\n38 13 13\\n40 12 20\\n37 16 18\\n46 16 18\\n34 10 14\\n36 9 10\\n36 15 19\\n38 15 19\\n42 13 19\\n33 14 15\\n35 15 19\\n33 17 18\\n39 12 20\\n36 5 7\\n45 12 12\\n\",\n \"2 1 1 1\\n2\\n1 1 2\\n2 1 2\\n\",\n \"1 1 1 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\",\n \"1 1 1000000000 2\\n5\\n1000000000 1 10000\\n19920401 1188 5566\\n1000000000 1 10000\\n1 1 10000\\n5 100 200\\n\"\n ],\n \"outputs\": [\n \"4\\n\",\n \"6\\n\",\n \"-1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"-1\\n\",\n \"1\\n\",\n \"9\\n\",\n \"1\\n\",\n \"1\\n\",\n \"-1\\n\"\n ]\n}""", # noqa: E501 -] # A failed sample with first several in-out passed - -prime_code_scores = [1.0, 0.9] - - -def test_parallelism(): - """ - Test if process pool works properly - """ - sequences_str = [] - ground_truth = [] - data_sources = [] - while len(sequences_str) < 32: - sequences_str.extend(prime_code_answers) - ground_truth.extend(prime_code_gts) - data_sources.extend(["codecontests"] * len(prime_code_answers)) - - sequences_str.extend(prime_math_answers) - ground_truth.extend(prime_math_gts) - data_sources.extend(["numina_aops_forum"] * len(prime_math_answers)) - - scores = asyncio.run( - parallel_compute_score_async(default_compute_score, sequences_str, ground_truth, data_sources, num_processes=16) - ) - print(scores) - - -def test_prime_code(): - """ - Test PRIME code sandbox. - """ - data_source = "codecontests" - for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True): - score = default_compute_score(data_source, completion, ground_truth) - assert float(score) == score_ - - -# Use the pytest.mark.skipif decorator to skip the test -@pytest.mark.skipif(not os.environ.get("SANDBOX_FUSION_URL"), reason="SANDBOX_FUSION_URL environment variable not set") -def test_prime_code_sandbox_fusion(): - """ - Test PRIME code on sandbox fusion. Skips if SANDBOX_FUSION_URL is not set. - """ - data_source = "codecontests" - # Get the URL from the environment variable, as skipif ensures it is set at this point - sandbox_fusion_url = os.environ.get("SANDBOX_FUSION_URL") - # Removed the previous 'if not sandbox_url' check block - - for completion, ground_truth, score_ in zip(prime_code_answers, prime_code_gts, prime_code_scores, strict=True): - score = default_compute_score( - data_source, completion, ground_truth, extra_info={"sandbox_fusion_url": sandbox_fusion_url} - ) # <-- Use the URL obtained from the environment variable - assert float(score) == score_ - - -@pytest.mark.skipif(not os.environ.get("SANDBOX_FUSION_URL"), reason="SANDBOX_FUSION_URL environment variable not set") -def test_continuous_score_consistency(): - """ - Verify that continuous score calculation is consistent between prime_code and sandbox_fusion. - Uses a test case where the first 9 out of 11 sub-cases pass (expected score 0.9). - """ - completion = prime_code_answers[1] # Use the second sample - ground_truth = prime_code_gts[1] # Use the second sample (9/11 pass, first 9 pass) - expected_continuous_score = 0.9 - - # 1. Calculate score using prime_code (default) with continuous=True - prime_score, _ = sandbox_fusion.compute_score( - os.environ.get("SANDBOX_FUSION_URL"), None, completion, ground_truth, continuous=True - ) - - # 2. Calculate score using sandbox_fusion with continuous=True - # Ensure the extra_info key triggers the sandbox_fusion path in default_compute_score - fusion_score, _ = prime_code.compute_score(completion, ground_truth, continuous=True) - - # 3. Assert scores are equal (using pytest.approx for float comparison) - assert float(prime_score) == pytest.approx(expected_continuous_score) - assert float(fusion_score) == pytest.approx(expected_continuous_score) - assert float(prime_score) == pytest.approx(float(fusion_score)) - print(f"Continuous Score (Prime Code): {prime_score}") - print(f"Continuous Score (Sandbox Fusion): {fusion_score}") - - -def test_check_correctness(): - completion = prime_code_answers[0] - ground_truth = json.loads(prime_code_gts[0]) - ground_truth_single = {"inputs": ground_truth["inputs"][:1], "outputs": ground_truth["outputs"][:1]} - res, meta = apps_check_correctness(in_outs=ground_truth_single, generation=completion, timeout=5, debug=False) - print(res, meta) - - -def test_prime_math(): - data_source = "numina_aops_forum" - for completion, ground_truth in zip(prime_math_answers, prime_math_gts, strict=True): - score = default_compute_score(data_source, completion, ground_truth) - assert float(score) == 1.0 diff --git a/tests/utils/test_activation_offload.py b/tests/utils/test_activation_offload.py deleted file mode 100644 index 2393d7962..000000000 --- a/tests/utils/test_activation_offload.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import shutil -import tempfile - -import pytest -import torch -import torch.distributed -import torch.multiprocessing as mp -from torch.distributed import init_device_mesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision, ShardingStrategy -from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config - -from verl.utils.activation_offload import enable_activation_offloading -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy - - -def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy="fsdp"): - torch.cuda.set_device(rank) - torch.distributed.init_process_group( - backend="nccl", - init_method=f"file://{rendezvous_file}", - rank=rank, - world_size=world_size, - ) - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp",)) - - model_name = "Qwen/Qwen2.5-0.5B-Instruct" - config = Qwen2Config(num_hidden_layers=4) - - with torch.device("cuda"): - model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" - ) - model = model.to(device="cuda") - - # Wrap model with FSDP - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - - if strategy == "fsdp": - model = FSDP( - model, - use_orig_params=False, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - device_mesh=device_mesh, - auto_wrap_policy=get_fsdp_wrap_policy(module=model), - ) - else: - mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True - ) - fsdp_kwargs = { - "mesh": device_mesh, - "mp_policy": mp_policy, - } - apply_fsdp2(model, fsdp_kwargs, {}) - - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) - - # Create checkpoint manager - tokenizer = AutoTokenizer.from_pretrained(model_name) - checkpoint_manager = FSDPCheckpointManager( - model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer - ) - - # Generate sample input - batch_size = 2 - seq_len = 32 - vocab_size = 32000 - # First input for initial update - input_ids1 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") - attention_mask1 = torch.ones_like(input_ids1) - - # Second input for verification - input_ids2 = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") - attention_mask2 = torch.ones_like(input_ids2) - - # Step 1: Initial update and save checkpoint - outputs1 = model(input_ids=input_ids1, attention_mask=attention_mask1) - loss1 = outputs1.logits.mean() - loss1.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Save checkpoint after first update - temp_dir = tempfile.mkdtemp() - checkpoint_path = os.path.join(temp_dir, "checkpoint") - checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0) - - # Step 2: Second update and forward pass - outputs2 = model(input_ids=input_ids2, attention_mask=attention_mask2) - loss2 = outputs2.logits.mean() - loss2.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Record logits after second update - with torch.no_grad(): - logits_without_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits - - # Step 3: wrap module with activation offloading and load checkpoint - enable_activation_offloading(model, "fsdp") - checkpoint_manager.load_checkpoint(checkpoint_path) - - # Step 4: Repeat the second update with same input - outputs3 = model(input_ids=input_ids2, attention_mask=attention_mask2) - loss3 = outputs3.logits.mean() - loss3.backward() - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Record logits after loaded checkpoint and update - with torch.no_grad(): - logits_with_offloading = model(input_ids=input_ids2, attention_mask=attention_mask2).logits - - # Step 4: Verify outputs match - torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0) - print(f"Activaiton offloading for {strategy} test passed on {world_size} GPUs!") - - # Cleanup - shutil.rmtree(temp_dir) - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - -@pytest.mark.parametrize("world_size", (2, 4)) -@pytest.mark.parametrize("strategy", ("fsdp", "fsdp2")) -def test_activation_offloading(world_size, strategy, tmp_path): - rendezvous_file = str(tmp_path / "rdzv_file") - os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) - - mp.spawn( - fn=_fsdp_activation_offloading_test, - args=(world_size, rendezvous_file, strategy), - nprocs=world_size, - join=True, - ) diff --git a/tests/utils/test_config_on_cpu.py b/tests/utils/test_config_on_cpu.py deleted file mode 100644 index 42dc8e1f2..000000000 --- a/tests/utils/test_config_on_cpu.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from dataclasses import dataclass - -from omegaconf import OmegaConf - -from verl.utils import omega_conf_to_dataclass - - -@dataclass -class TestDataclass: - hidden_size: int - activation: str - - -@dataclass -class TestTrainConfig: - batch_size: int - model: TestDataclass - - -_cfg_str = """train_config: - batch_size: 32 - model: - hidden_size: 768 - activation: relu""" - - -class TestConfigOnCPU(unittest.TestCase): - """Test cases for configuration utilities on CPU. - - Test Plan: - 1. Test basic OmegaConf to dataclass conversion for simple nested structures - 2. Test nested OmegaConf to dataclass conversion for complex hierarchical configurations - 3. Verify all configuration values are correctly converted and accessible - """ - - def setUp(self): - self.config = OmegaConf.create(_cfg_str) - - def test_omega_conf_to_dataclass(self): - sub_cfg = self.config.train_config.model - cfg = omega_conf_to_dataclass(sub_cfg, TestDataclass) - self.assertEqual(cfg.hidden_size, 768) - self.assertEqual(cfg.activation, "relu") - assert isinstance(cfg, TestDataclass) - - def test_nested_omega_conf_to_dataclass(self): - cfg = omega_conf_to_dataclass(self.config.train_config, TestTrainConfig) - self.assertEqual(cfg.batch_size, 32) - self.assertEqual(cfg.model.hidden_size, 768) - self.assertEqual(cfg.model.activation, "relu") - assert isinstance(cfg, TestTrainConfig) - assert isinstance(cfg.model, TestDataclass) - - -class TestPrintCfgCommand(unittest.TestCase): - """Test suite for the print_cfg.py command-line tool.""" - - def test_command_with_override(self): - """Test that the command runs without error when overriding config values.""" - import subprocess - - # Run the command - result = subprocess.run( - ["python3", "scripts/print_cfg.py", "critic.profiler.discrete=True", "+critic.profiler.extra.any_key=val"], - capture_output=True, - text=True, - ) - - # Verify the command exited successfully - self.assertEqual(result.returncode, 0, f"Command failed with stderr: {result.stderr}") - - # Verify the output contains expected config information - self.assertIn("critic", result.stdout) - self.assertIn("profiler", result.stdout) - self.assertIn("discrete=True", result.stdout) - self.assertIn("extra={'any_key': 'val'}", result.stdout) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/utils/test_flops_counter.py b/tests/utils/test_flops_counter.py deleted file mode 100644 index 0b8889b3a..000000000 --- a/tests/utils/test_flops_counter.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import pytest - -from verl.utils.flops_counter import FlopsCounter - -VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3"} - - -class Config: - def __init__(self, config_dict): - for key, value in config_dict.items(): - setattr(self, key, value) - - -CONFIG = { - "llama": { - "config": { # llama2-7B - "model_type": "llama", - "vocab_size": 32000, - "hidden_size": 4096, - "intermediate_size": 11008, - "num_hidden_layers": 32, - "num_attention_heads": 32, - "num_key_value_heads": 32, - }, - "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + - # 12*sum(seqlen^2)*layer*head*head_dim - # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) + - # 12*(512*512+1024*1024+2048*2048)*32*4096 - # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) + - # 12*(4096*4096+4096*4096+4096*4096)*32*4096 - "expected_flops_tuple": (153555818250240 / 1e12, 575955114393600 / 1e12), - }, - "qwen2": { - "config": { # Qwen/Qwen2.5-7B-Instruct - "model_type": "qwen2", - "vocab_size": 152064, - "hidden_size": 3584, - "intermediate_size": 18944, - "num_hidden_layers": 28, - "num_attention_heads": 28, - "num_key_value_heads": 4, - }, - "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + - # 12*sum(seqlen^2)*layer*head*head_dim - # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) + - # 12*(512*512+1024*1024+2048*2048)*28*3584 - # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) + - # 12*(4096*4096+4096*4096+4096*4096)*28*3584 - "expected_flops_tuple": (170388331954176 / 1e12, 622070178250752 / 1e12), - }, - "qwen3": { - "config": { # Qwen/Qwen3-8B - "model_type": "qwen3", - "vocab_size": 151936, - "hidden_size": 4096, - "intermediate_size": 12288, - "num_hidden_layers": 36, - "num_attention_heads": 32, - "num_key_value_heads": 8, - "head_dim": 128, - }, - "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum + - # 12*sum(seqlen^2)*layer*head*head_dim - # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) + - # 12*(512*512+1024*1024+2048*2048)*36*128*32 - # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) + - # 12*(4096*4096+4096*4096+4096*4096)*36*128*32 - "expected_flops_tuple": (185867930959872 / 1e12, 692924253732864 / 1e12), - }, - "qwen3_moe": { - "config": { # Qwen/Qwen3-30B-A3B-Base - "model_type": "qwen3_moe", - "hidden_size": 2048, - "vocab_size": 151936, - "num_hidden_layers": 48, - "num_key_value_heads": 4, - "num_attention_heads": 32, - "head_dim": 128, - "moe_intermediate_size": 768, - "num_experts_per_tok": 8, - "num_experts": 128, - }, - "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 + - # hidden*num_experts))*token_sum + 12*sum(seqlen^2)*layer*head*head_dim - # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) + - # 12*(512*512+1024*1024+2048*2048)*48*128*32 - # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) + - # 12*(4096*4096+4096*4096+4096*4096)*48*128*32 - "expected_flops_tuple": (85087060230144 / 1e12, 365944098521088 / 1e12), - }, - "deepseek_v3": { - "config": { # deepseek-ai/DeepSeek-Prover-V2-671B - "model_type": "deepseek_v3", - "hidden_size": 7168, - "vocab_size": 129280, - "moe_intermediate_size": 2048, - "num_hidden_layers": 61, - "first_k_dense_replace": 3, - "num_attention_heads": 128, - "n_routed_experts": 256, - "num_experts_per_tok": 8, - "n_shared_experts": 1, - "kv_lora_rank": 512, - "qk_rope_head_dim": 64, - "v_head_dim": 128, - "intermediate_size": 18432, - "qk_nope_head_dim": 128, - "q_lora_rank": 1536, - }, - "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]), - # (1536*7168+128*192*1536+7168*(512+64)+128*(128+128)*512+128*128*7168) = 187105280 - # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) + - # 12*(512*512+1024*1024+2048*2048)*61*192*128 - # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) + - # 12*(4096*4096+4096*4096+4096*4096)*61*192*128 - "expected_flops_tuple": (906535995703296 / 1e12, 3674028304760832 / 1e12), - }, -} - - -@pytest.mark.parametrize( - "config_type", - ["llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3"], -) -def test_flops_counter(config_type: str): - test_config = CONFIG[config_type] - config = Config(test_config["config"]) - flops_counter = FlopsCounter(config) - for batch_seqlens, expected_flops in zip( - test_config["batch_seqlens_tuple"], test_config["expected_flops_tuple"], strict=True - ): - # set delta time to 1 to get the flops - counted_flops, _ = flops_counter.estimate_flops(batch_seqlens, 1) - print(f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}") - assert math.isclose(counted_flops, expected_flops), ( - f"Expect flops for {test_config['config']} is {expected_flops}, but get {counted_flops}" - ) diff --git a/tests/utils/test_fs_on_cpu.py b/tests/utils/test_fs_on_cpu.py deleted file mode 100644 index 7ae85e01a..000000000 --- a/tests/utils/test_fs_on_cpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from pathlib import Path - -import verl.utils.fs as fs - - -def test_record_and_check_directory_structure(tmp_path): - # Create test directory structure - test_dir = tmp_path / "test_dir" - test_dir.mkdir() - (test_dir / "file1.txt").write_text("test") - (test_dir / "subdir").mkdir() - (test_dir / "subdir" / "file2.txt").write_text("test") - - # Create structure record - record_file = fs._record_directory_structure(test_dir) - - # Verify record file exists - assert os.path.exists(record_file) - - # Initial check should pass - assert fs._check_directory_structure(test_dir, record_file) is True - - # Modify structure and verify check fails - (test_dir / "new_file.txt").write_text("test") - assert fs._check_directory_structure(test_dir, record_file) is False - - -def test_copy_from_hdfs_with_mocks(tmp_path, monkeypatch): - # Mock HDFS dependencies - monkeypatch.setattr(fs, "is_non_local", lambda path: True) - - # side_effect will simulate the copy by creating parent dirs + empty file - def fake_copy(src: str, dst: str, *args, **kwargs): - dst_path = Path(dst) - dst_path.parent.mkdir(parents=True, exist_ok=True) - dst_path.write_bytes(b"") # touch an empty file - - monkeypatch.setattr(fs, "copy", fake_copy) # Mock actual HDFS copy - - # Test parameters - test_cache = tmp_path / "cache" - hdfs_path = "hdfs://test/path/file.txt" - - # Test initial copy - local_path = fs.copy_to_local(hdfs_path, cache_dir=test_cache) - expected_path = os.path.join(test_cache, fs.md5_encode(hdfs_path), os.path.basename(hdfs_path)) - assert local_path == expected_path - assert os.path.exists(local_path) - - -def test_always_recopy_flag(tmp_path, monkeypatch): - # Mock HDFS dependencies - monkeypatch.setattr(fs, "is_non_local", lambda path: True) - - copy_call_count = 0 - - def fake_copy(src: str, dst: str, *args, **kwargs): - nonlocal copy_call_count - copy_call_count += 1 - dst_path = Path(dst) - dst_path.parent.mkdir(parents=True, exist_ok=True) - dst_path.write_bytes(b"") - - monkeypatch.setattr(fs, "copy", fake_copy) # Mock actual HDFS copy - - test_cache = tmp_path / "cache" - hdfs_path = "hdfs://test/path/file.txt" - - # Initial copy (always_recopy=False) - fs.copy_to_local(hdfs_path, cache_dir=test_cache) - assert copy_call_count == 1 - - # Force recopy (always_recopy=True) - fs.copy_to_local(hdfs_path, cache_dir=test_cache, always_recopy=True) - assert copy_call_count == 2 - - # Subsequent normal call (always_recopy=False) - fs.copy_to_local(hdfs_path, cache_dir=test_cache) - assert copy_call_count == 2 # Should not increment diff --git a/tests/utils/test_import_utils_on_cpu.py b/tests/utils/test_import_utils_on_cpu.py deleted file mode 100644 index 59709b876..000000000 --- a/tests/utils/test_import_utils_on_cpu.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import pytest - -from verl.utils.import_utils import load_extern_type - -# Path to the test module -TEST_MODULE_PATH = os.path.join(os.path.dirname(__file__), "_test_module.py") - - -def test_load_extern_type_class(): - """Test loading a class from an external file""" - TestClass = load_extern_type(TEST_MODULE_PATH, "TestClass") - - # Verify the class was loaded correctly - assert TestClass is not None - assert TestClass.__name__ == "TestClass" - - # Test instantiation and functionality - instance = TestClass() - assert instance.value == "default" - - # Test with a custom value - custom_instance = TestClass("custom") - assert custom_instance.get_value() == "custom" - - -def test_load_extern_type_function(): - """Test loading a function from an external file""" - test_function = load_extern_type(TEST_MODULE_PATH, "test_function") - - # Verify the function was loaded correctly - assert test_function is not None - assert callable(test_function) - - # Test function execution - result = test_function() - assert result == "test_function_result" - - -def test_load_extern_type_constant(): - """Test loading a constant from an external file""" - constant = load_extern_type(TEST_MODULE_PATH, "TEST_CONSTANT") - - # Verify the constant was loaded correctly - assert constant is not None - assert constant == "test_constant_value" - - -def test_load_extern_type_nonexistent_file(): - """Test behavior when file doesn't exist""" - with pytest.raises(FileNotFoundError): - load_extern_type("/nonexistent/path.py", "SomeType") - - -def test_load_extern_type_nonexistent_type(): - """Test behavior when type doesn't exist in the file""" - with pytest.raises(AttributeError): - load_extern_type(TEST_MODULE_PATH, "NonExistentType") - - -def test_load_extern_type_none_path(): - """Test behavior when file path is None""" - result = load_extern_type(None, "SomeType") - assert result is None - - -def test_load_extern_type_invalid_module(): - """Test behavior when module has syntax errors""" - # Create a temporary file with syntax errors - import tempfile - - with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as temp_file: - temp_file.write("This is not valid Python syntax :") - temp_path = temp_file.name - - try: - with pytest.raises(RuntimeError): - load_extern_type(temp_path, "SomeType") - finally: - # Clean up the temporary file - if os.path.exists(temp_path): - os.remove(temp_path) diff --git a/tests/utils/test_linear_cross_entropy.py b/tests/utils/test_linear_cross_entropy.py deleted file mode 100644 index 0512d1376..000000000 --- a/tests/utils/test_linear_cross_entropy.py +++ /dev/null @@ -1,361 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch - -import verl.utils.torch_functional as verl_F -from verl.utils.experimental.torch_functional import FusedLinearForPPO -from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy -from verl.utils.torch_functional import logprobs_from_logits - -compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) -fused_linear_for_ppo = FusedLinearForPPO() -fused_linear_for_ppo.compile(dynamic=True) - -MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5) - - -def run_torch_entropy( - hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none" -) -> list[torch.Tensor]: - hidden = hidden.squeeze(0).to(torch.float32) - weight = weight.transpose(0, 1).to(torch.float32) - logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size] - logits /= temperature - pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] - entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] - entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] - entropy = entropy_a - entropy_b - logprobs = torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction=reduction) # [num_tokens] - logprobs = torch.neg(logprobs) - return logprobs, entropy - - -def run_verl_original_entropy( - hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - temperature: float, -) -> list[torch.Tensor]: - hidden = hidden.squeeze(0).to(torch.float32) - weight = weight.transpose(0, 1).to(torch.float32) - logits = torch.matmul(hidden, weight) # [num_tokens, vocab_size] - logits /= temperature - # compute entropy - entropy = compute_entropy_from_logits(logits) # ((total_nnz / sp) + pad) - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - logprobs = logprobs_from_logits(logits=logits, labels=labels, inplace_backward=False) - return logprobs, entropy - - -# To be tested -def run_verl_torch_fused_entropy( - hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - temperature: float, -): - hidden = hidden.to(torch.float32) - weight = weight.to(torch.float32) - logprobs, entropy = fused_linear_for_ppo( - hidden, - weight, - labels, - temperature=temperature, - ) - return logprobs.squeeze(0), entropy.squeeze(0) - - -class TestLinearCrossEntropy: - def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None: - self.test_case_idx = test_case_idx - self.temperature = temperature - - def cleanup(self): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - import gc - - gc.collect() - torch.cuda.synchronize() - - def generate_hyper(self): - global MAX_TEST_CASES - - self.dtype = torch.bfloat16 - if self.test_case_idx == 0: - self.batch_size = 1 - self.num_tokens = 1937 - self.hidden_size = 3584 - self.vocab_size = 152064 - elif self.test_case_idx == 1: - self.batch_size = 1 - self.num_tokens = 2169 - self.hidden_size = 896 - self.vocab_size = 151936 - elif self.test_case_idx == 2: - self.batch_size = 1 - self.num_tokens = 1530 - self.hidden_size = 2048 - self.vocab_size = 32256 - elif self.test_case_idx == 3: - self.batch_size = 1 - self.num_tokens = 1388 - self.hidden_size = 4096 - self.vocab_size = 102400 - elif self.test_case_idx == 4: - self.batch_size = 1 - self.num_tokens = 8192 - self.hidden_size = 4096 - self.vocab_size = 102400 - else: - raise ValueError(f"Invalid test case index: {self.test_case_idx}") - assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5." - - def generate_forward_inputs(self): - hidden = ( - torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda") - .uniform_(-0.5, 0.5) - .requires_grad_() - ) - weight = ( - torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda") - .uniform_(-0.5, 0.5) - .requires_grad_() - ) - labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda") - return hidden, weight, labels - - def generate_backward_inputs(self): - g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5) - g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1) - return g_entropy, g_logprobs - - def verify_correctness(self, iterations=5): - self.cleanup() - self.generate_hyper() - - torch_forward_latency = list() - torch_backward_latency = list() - verl_forward_latency = list() - verl_backward_latency = list() - verl_fused_forward_latency = list() - verl_fused_backward_latency = list() - kernel_forward_latency = list() - kernel_backward_latency = list() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - for i in range(iterations): - print(f"[INFO]: Iteration {i + 1} / {iterations}...", end="\r") - hidden, weight, labels = self.generate_forward_inputs() - - start_event.record() - (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels, self.temperature) - end_event.record() - torch.cuda.synchronize() - torch_forward_latency.append(start_event.elapsed_time(end_event)) - - start_event.record() - (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature) - end_event.record() - torch.cuda.synchronize() - verl_forward_latency.append(start_event.elapsed_time(end_event)) - - start_event.record() - (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy( - hidden, weight, labels, self.temperature - ) - end_event.record() - torch.cuda.synchronize() - verl_fused_forward_latency.append(start_event.elapsed_time(end_event)) - - start_event.record() - (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature) - end_event.record() - torch.cuda.synchronize() - kernel_forward_latency.append(start_event.elapsed_time(end_event)) - - torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4) - - torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4) - - torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) - torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) - torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) - torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) - torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4) - torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4) - - # backward - g_entropy, g_logprobs = self.generate_backward_inputs() - - start_event.record() - (d_torch_hidden, d_torch_weight) = torch.autograd.grad( - (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - end_event.record() - torch.cuda.synchronize() - torch_backward_latency.append(start_event.elapsed_time(end_event)) - - start_event.record() - (d_verl_hidden, d_verl_weight) = torch.autograd.grad( - (verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - end_event.record() - torch.cuda.synchronize() - verl_backward_latency.append(start_event.elapsed_time(end_event)) - - start_event.record() - (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad( - (verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - end_event.record() - torch.cuda.synchronize() - verl_fused_backward_latency.append(start_event.elapsed_time(end_event)) - - start_event.record() - (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad( - (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - end_event.record() - torch.cuda.synchronize() - kernel_backward_latency.append(start_event.elapsed_time(end_event)) - - torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) - - torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) - torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) - - torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2) - - # remove first latency - torch_forward_latency = torch_forward_latency[1:] - torch_backward_latency = torch_backward_latency[1:] - verl_forward_latency = verl_forward_latency[1:] - verl_backward_latency = verl_backward_latency[1:] - verl_fused_forward_latency = verl_fused_forward_latency[1:] - verl_fused_backward_latency = verl_fused_backward_latency[1:] - kernel_forward_latency = kernel_forward_latency[1:] - kernel_backward_latency = kernel_backward_latency[1:] - - print("\n[INFO]: Verified forward & backward correctness.") - - print( - f"[INFO]: Forward pass: Torch implementation average time: " - f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms" - ) - print( - f"[INFO]: Backward pass: torch implementation average time: " - f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms" - ) - print( - f"[INFO]: Forward pass: VeRL implementation average time: " - f"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms" - ) - print( - f"[INFO]: Backward pass: VeRL implementation average time: " - f"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms" - ) - print( - f"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: " - f"{sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms" - ) - print( - f"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: " - f"{sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms" - ) - print( - f"[INFO]: Forward pass: Kernel implementation average time: " - f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms" - ) - print( - f"[INFO]: Backward pass: kernel implementation average time: " - f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms" - ) - - def check_storage(self, method_name, run_forward): - self.cleanup() - self.generate_hyper() - - hidden, weight, labels = self.generate_forward_inputs() - - torch.cuda.reset_peak_memory_stats() - (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature) - torch.cuda.synchronize() - torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 - print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB") - - g_entropy, g_logprobs = self.generate_backward_inputs() - - torch.cuda.reset_peak_memory_stats() - (d_torch_hidden, d_torch_weight) = torch.autograd.grad( - (entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - torch.cuda.synchronize() - torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 - print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB") - - def check_storage_all(self): - self.check_storage("Torch", run_torch_entropy) - self.check_storage("VeRL", run_verl_original_entropy) - self.check_storage("VeRL Torch Fused", run_verl_torch_fused_entropy) - self.check_storage("Kernel", linear_cross_entropy) - - -if __name__ == "__main__": - # torch.cuda.memory._record_memory_history() - - for test_case_idx in range(MAX_TEST_CASES): - print(f"[INFO] Running test case {test_case_idx}") - test = TestLinearCrossEntropy(test_case_idx) - - test.verify_correctness() - test.check_storage_all() - - # torch.cuda.memory._dump_snapshot("test_linear_cross_entropy.pkl") diff --git a/tests/utils/test_linear_cross_entropy_tp.py b/tests/utils/test_linear_cross_entropy_tp.py deleted file mode 100644 index 9c1f868a9..000000000 --- a/tests/utils/test_linear_cross_entropy_tp.py +++ /dev/null @@ -1,514 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch -import torch.distributed as dist - -try: - from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy -except ImportError: - # FIXME: remove these manually included paths - import sys - - sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))) -finally: - from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy - -import verl.utils.torch_functional as verl_F - -compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) - -MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5) -VERIFY_TORCH_SELF = os.environ.get("VERIFY_TORCH_SELF", False) -LOW_MEMORY = os.environ.get("LOW_MEMORY", False) -LOW_MEMORY_DIV_FACTOR = os.environ.get("LOW_MEMORY_DIV_FACTOR", 16) - - -def run_torch_entropy( - hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none" -) -> list[torch.Tensor]: - # [num_tokens, vocab_size] - if len(hidden.shape) > 2: - hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size] - if len(labels.shape) > 1: - labels = labels.view(-1) - logits = torch.matmul( - hidden.to(torch.float32), - weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32), - ) - logits /= temperature - pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] - entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] - entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] - entropy = entropy_a - entropy_b - logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction) # [num_tokens] - logprobs = torch.neg(logprobs) - return logprobs, entropy - - -class TorchEntropyTP(torch.autograd.Function): - """ - it is used for testing the correctness of the kernel - it is not efficient and is not recommended to use in practice - """ - - @staticmethod - def forward( - ctx, - hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - temperature: float, - dist_process_group: torch.distributed.ProcessGroup, - ): - # weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size] - ctx.original_hidden_shape = hidden.shape - if len(hidden.shape) > 2: - hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size] - if len(labels.shape) > 1: - labels = labels.view(-1) - - logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T) # [num_tokens, vocab_size] - logits /= temperature - whole_logits = torch.empty( - (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)), - dtype=logits.dtype, - device=logits.device, - ) - whole_logits_ref = [ - whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]] - for i in range(dist.get_world_size(dist_process_group)) - ] - dist.all_gather(whole_logits_ref, logits, group=dist_process_group) - - pd = torch.nn.functional.softmax(whole_logits, dim=-1) - entropy_a = torch.logsumexp(whole_logits, dim=-1) # [num_tokens] - entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens] - entropy = entropy_a - entropy_b - - logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none") - logprobs = torch.neg(logprobs) - - ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b) - ctx.dist_process_group = dist_process_group - ctx.temperature = temperature - return logprobs, entropy - - @staticmethod - def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor): - hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors - dist_process_group = ctx.dist_process_group - temperature = ctx.temperature - batch_size, hidden_size = hidden.shape - vocab_size, hidden_size = weight.shape - rank = dist.get_rank(dist_process_group) - - # Compute softmax probabilities - maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True) - exp_logits = torch.exp(whole_logits - maximum) - accumulate = exp_logits.sum(dim=-1, keepdim=True) - pd = exp_logits / accumulate - - # Gradient for entropy - # entropy = entropy_a - entropy_b - # entropy_a = log(sum(exp(logits))) - # entropy_b = sum(pd * logits) - # d_entropy_a/d_logits = pd - # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1) - # d_entropy/d_logits = d_entropy_a - d_entropy_b - # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1) - # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1)) - d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1))) - - # Gradient for logprobs - # logprobs = -cross_entropy = -log(pd[labels]) - # d_logprobs/d_logits = (pd - one_hot(labels)) - one_hot = torch.zeros_like(whole_logits) - one_hot.scatter_(1, labels.unsqueeze(1), 1) - g_logprobs = torch.neg(g_logprobs) - d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot) - # NOTE: This will lead to wrong result - # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot - - # Combine gradients - d_logits = d_logits_entropy + d_logits_logprobs - d_logits /= temperature - - # Get local slice of gradients - local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size] - - # Compute gradients for hidden and weight - d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32)) - d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32)) - d_hidden = d_hidden.view(ctx.original_hidden_shape) - - return d_hidden, d_weight, None, None, None - - -run_torch_entropy_tp = TorchEntropyTP.apply - - -class TestLinearCrossEntropy_TensorParallel: - def __init__(self): - dist.init_process_group(backend="nccl") - self.group = dist.group.WORLD - - self.local_rank = dist.get_rank(self.group) - self.world_size = dist.get_world_size(self.group) - device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(device) - print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}") - - def initialize(self, test_case_idx: int, temperature: float = 1.5): - self.test_case_idx = test_case_idx - self.temperature = temperature - - def shutdown(self): - dist.destroy_process_group() - - def cleanup(self): - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - import gc - - gc.collect() - torch.cuda.synchronize() - - def generate_hyper(self): - global LOW_MEMORY, LOW_MEMORY_DIV_FACTOR, MAX_TEST_CASES - - self.dtype = torch.bfloat16 - if self.test_case_idx == 0: - self.batch_size = 1 - self.num_tokens = 1937 - self.hidden_size = 3584 - self.vocab_size = 152064 - elif self.test_case_idx == 1: - self.batch_size = 1 - self.num_tokens = 2169 - self.hidden_size = 896 - self.vocab_size = 151936 - elif self.test_case_idx == 2: - self.batch_size = 1 - self.num_tokens = 1530 - self.hidden_size = 2048 - self.vocab_size = 32256 - elif self.test_case_idx == 3: - self.batch_size = 1 - self.num_tokens = 1388 - self.hidden_size = 4096 - self.vocab_size = 102400 - elif self.test_case_idx == 4: - self.batch_size = 1 - self.num_tokens = 8192 - self.hidden_size = 4096 - self.vocab_size = 102400 - else: - raise ValueError(f"Invalid test case index: {self.test_case_idx}") - if LOW_MEMORY: - self.vocab_size = int(self.vocab_size / LOW_MEMORY_DIV_FACTOR) - assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5." - - def generate_forward_inputs(self): - hidden = ( - torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda") - .uniform_(-0.5, 0.5) - .requires_grad_() - ) - weight = ( - torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda") - .uniform_(-0.5, 0.5) - .requires_grad_() - ) - labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda") - return hidden, weight, labels - - def generate_backward_inputs(self): - g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5) - g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1) - return g_entropy, g_logprobs - - def verify_torch_itself(self, iterations: int = 5): - self.cleanup() - self.generate_hyper() - - for i in range(iterations): - hidden, weight, labels = self.generate_forward_inputs() - - # NOTE: we need to manually synchronize hidden and labels among Process Group - dist.broadcast(hidden, src=0, group=self.group) - dist.broadcast(labels, src=0, group=self.group) - - # forward pass - # Create a tensor to hold the gathered weights from all ranks - # weight has shape [vocab_size, hidden_size] - # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size] - - # Create a single contiguous tensor to hold all gathered weights - whole_weight = torch.empty( - (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device - ) - - # Create views into the tensor for each rank's portion - whole_weight_views = [ - whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size) - ] - - # Perform all_gather operation using the views - dist.all_gather(whole_weight_views, weight, group=self.group) - - # Set requires_grad for autograd - whole_weight.requires_grad_() - - (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature) - - (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) - - torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4) - torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4) - - # backward pass - g_entropy, g_logprobs = self.generate_backward_inputs() - # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group - dist.broadcast(g_entropy, src=0, group=self.group) - dist.broadcast(g_logprobs, src=0, group=self.group) - - (single_d_hidden, single_d_weight) = torch.autograd.grad( - (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False - ) - - (tp_d_hidden, tp_d_weight) = torch.autograd.grad( - (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - # NOTE: all-reduce on hidden is conducted outside the kernel - dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group) - - torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4) - # Extract the corresponding slice from single_d_weight for comparison - # tp_d_weight has shape [vocab_size, hidden_size] - # single_d_weight has shape [vocab_size * world_size, hidden_size] - torch.testing.assert_close( - tp_d_weight, - single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size], - atol=1e-2, - rtol=1e-4, - ) - - # atol=1e-3, rtol=1e-4) - if self.local_rank == 0: - print("[PASS] torch TP correctness is verified") - - def check_torch_storage(self): - self.cleanup() - self.generate_hyper() - - hidden, weight, labels = self.generate_forward_inputs() - - # NOTE: we need to manually synchronize hidden and labels among Process Group - dist.broadcast(hidden, src=0, group=self.group) - dist.broadcast(labels, src=0, group=self.group) - - torch.cuda.reset_peak_memory_stats() - (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) - torch.cuda.synchronize() - forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 - - g_entropy, g_logprobs = self.generate_backward_inputs() - # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group - dist.broadcast(g_entropy, src=0, group=self.group) - dist.broadcast(g_logprobs, src=0, group=self.group) - - torch.cuda.reset_peak_memory_stats() - (d_tp_hidden, d_tp_weight) = torch.autograd.grad( - (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - torch.cuda.synchronize() - backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 - # NOTE: all-reduce on hidden is conducted outside the kernel - dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group) - - if self.local_rank == 0: - print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB") - print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB") - - def verify_kernel_correctness(self, iterations: int = 5): - self.cleanup() - self.generate_hyper() - - torch_forward_latency = list() - torch_backward_latency = list() - kernel_forward_latency = list() - kernel_backward_latency = list() - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - for i in range(iterations): - hidden, weight, labels = self.generate_forward_inputs() - - # NOTE: we need to manually synchronize hidden and labels among Process Group - dist.broadcast(hidden, src=0, group=self.group) - dist.broadcast(labels, src=0, group=self.group) - - start_event.record() - (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group) - end_event.record() - torch.cuda.synchronize() - torch_forward_latency.append(start_event.elapsed_time(end_event)) - - start_event.record() - (kernel_logprobs, kernel_entropy) = linear_cross_entropy( - hidden, weight, labels, self.temperature, "none", self.group - ) - end_event.record() - torch.cuda.synchronize() - kernel_forward_latency.append(start_event.elapsed_time(end_event)) - - torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2) - torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2) - - # backward pass - g_entropy, g_logprobs = self.generate_backward_inputs() - # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group - dist.broadcast(g_entropy, src=0, group=self.group) - dist.broadcast(g_logprobs, src=0, group=self.group) - - start_event.record() - (torch_d_hidden, torch_d_weight) = torch.autograd.grad( - (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - end_event.record() - torch.cuda.synchronize() - torch_backward_latency.append(start_event.elapsed_time(end_event)) - # NOTE: all-reduce on hidden is conducted outside the kernel - dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group) - - start_event.record() - (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad( - (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - end_event.record() - torch.cuda.synchronize() - kernel_backward_latency.append(start_event.elapsed_time(end_event)) - # NOTE: all-reduce on hidden is conducted outside the kernel - dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group) - - torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2) - torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2) - - # remove first latency - torch_forward_latency = torch_forward_latency[1:] - torch_backward_latency = torch_backward_latency[1:] - kernel_forward_latency = kernel_forward_latency[1:] - kernel_backward_latency = kernel_backward_latency[1:] - - if self.local_rank == 0: - print("\n[PASS]: Verified kernel forward & backward correctness.") - - print( - f"[INFO]: Forward pass: Torch implementation average time: " - f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms" - ) - print( - f"[INFO]: Backward pass: torch implementation average time: " - f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms" - ) - print( - f"[INFO]: Forward pass: Kernel implementation average time: " - f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms" - ) - print( - f"[INFO]: Backward pass: kernel implementation average time: " - f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms" - ) - - def check_kernel_storage(self): - self.cleanup() - self.generate_hyper() - - hidden, weight, labels = self.generate_forward_inputs() - - # NOTE: we need to manually synchronize hidden and labels among Process Group - dist.broadcast(hidden, src=0, group=self.group) - dist.broadcast(labels, src=0, group=self.group) - - torch.cuda.reset_peak_memory_stats() - (kernel_logprobs, kernel_entropy) = linear_cross_entropy( - hidden, weight, labels, self.temperature, "none", self.group - ) - torch.cuda.synchronize() - kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 - - g_entropy, g_logprobs = self.generate_backward_inputs() - # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group - dist.broadcast(g_entropy, src=0, group=self.group) - dist.broadcast(g_logprobs, src=0, group=self.group) - - torch.cuda.reset_peak_memory_stats() - (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad( - (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False - ) - torch.cuda.synchronize() - kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 - # NOTE: all-reduce on hidden is conducted outside the kernel - dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group) - - if self.local_rank == 0: - print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") - print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") - - -if __name__ == "__main__": - # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernels/test_linear_cross_entropy_tp.py - - # Check if running with torchrun (distributed mode) - assert int(os.environ["WORLD_SIZE"]) > 1, ( - "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to " - "execute this script." - ) - torch.manual_seed(233376 + int(os.environ.get("RANK", 0))) - - # set_backward_method(BackwardEnum._Total_Fuse_MN) - # set_backward_method(BackwardEnum._Split_Dlogits_N) - - test = TestLinearCrossEntropy_TensorParallel() - for test_case_idx in range(MAX_TEST_CASES): - print(f"[INFO] Running test case {test_case_idx}") - test.initialize(test_case_idx) - if VERIFY_TORCH_SELF: - test.verify_torch_itself() - test.check_torch_storage() - test.verify_kernel_correctness() - test.check_kernel_storage() - - test.shutdown() diff --git a/tests/utils/test_model_on_cpu.py b/tests/utils/test_model_on_cpu.py deleted file mode 100644 index 8b1416c8a..000000000 --- a/tests/utils/test_model_on_cpu.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from types import SimpleNamespace # Or use a mock object library - -import pytest - -from verl.utils.model import update_model_config - - -# Parametrize with different override scenarios -@pytest.mark.parametrize( - "override_kwargs", - [ - {"param_a": 5, "new_param": "plain_added"}, - {"param_a": 2, "nested_params": {"sub_param_x": "updated_x", "sub_param_z": True}}, - ], -) -def test_update_model_config(override_kwargs): - """ - Tests that update_model_config correctly updates attributes, - handling both plain and nested overrides via parametrization. - """ - # Create a fresh mock config object for each test case - mock_config = SimpleNamespace( - param_a=1, nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), other_param="keep_me" - ) - # Apply the updates using the parametrized override_kwargs - update_model_config(mock_config, override_kwargs) - - # Assertions to check if the config was updated correctly - if "nested_params" in override_kwargs: # Case 2: Nested override - override_nested = override_kwargs["nested_params"] - assert mock_config.nested_params.sub_param_x == override_nested["sub_param_x"], "Nested sub_param_x mismatch" - assert mock_config.nested_params.sub_param_y == 100, "Nested sub_param_y should be unchanged" - assert hasattr(mock_config.nested_params, "sub_param_z"), "Expected nested sub_param_z to be added" - assert mock_config.nested_params.sub_param_z == override_nested["sub_param_z"], "Value of sub_param_z mismatch" - else: # Case 1: Plain override (nested params untouched) - assert mock_config.nested_params.sub_param_x == "original_x", "Nested sub_param_x should be unchanged" - assert mock_config.nested_params.sub_param_y == 100, "Nested sub_param_y should be unchanged" - assert not hasattr(mock_config.nested_params, "sub_param_z"), "Nested sub_param_z should not exist" diff --git a/tests/utils/test_nvtx_profile.py b/tests/utils/test_nvtx_profile.py deleted file mode 100644 index 817d03000..000000000 --- a/tests/utils/test_nvtx_profile.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from unittest.mock import MagicMock, patch - -from verl.utils import omega_conf_to_dataclass -from verl.utils.profiler import ProfilerConfig -from verl.utils.profiler.nvtx_profile import NsightSystemsProfiler - - -class TestProfilerConfig(unittest.TestCase): - def test_config_init(self): - import os - - from hydra import compose, initialize_config_dir - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - cfg = compose(config_name="ppo_trainer") - arr = cfg.actor_rollout_ref - for config in [ - cfg.critic.profiler, - arr.profiler, - cfg.reward_model.profiler, - ]: - profiler_config = omega_conf_to_dataclass(config) - self.assertEqual(profiler_config.discrete, config.discrete) - self.assertEqual(profiler_config.all_ranks, config.all_ranks) - self.assertEqual(profiler_config.ranks, config.ranks) - assert isinstance(profiler_config, ProfilerConfig) - with self.assertRaises(AttributeError): - _ = profiler_config.non_existing_key - assert config.get("non_existing_key") == profiler_config.get("non_existing_key") - assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1) - assert config["discrete"] == profiler_config["discrete"] - from dataclasses import FrozenInstanceError - - with self.assertRaises(FrozenInstanceError): - profiler_config.discrete = False - - def test_frozen_config(self): - """Test that modifying frozen keys in ProfilerConfig raises exceptions.""" - from dataclasses import FrozenInstanceError - - from verl.utils.profiler.config import ProfilerConfig - - # Create a new ProfilerConfig instance - config = ProfilerConfig(discrete=True, all_ranks=False, ranks=[0]) - - # Test direct attribute assignment - with self.assertRaises(FrozenInstanceError): - config.discrete = False - - with self.assertRaises(FrozenInstanceError): - config.all_ranks = True - - with self.assertRaises(FrozenInstanceError): - config.ranks = [1, 2, 3] - - # Test dictionary-style assignment - with self.assertRaises(TypeError): - config["discrete"] = False - - with self.assertRaises(TypeError): - config["all_ranks"] = True - - with self.assertRaises(TypeError): - config["ranks"] = [1, 2, 3] - - config["extra"]["key"] = "value" - - -class TestNsightSystemsProfiler(unittest.TestCase): - """Test suite for NsightSystemsProfiler functionality. - - Test Plan: - 1. Initialization: Verify profiler state after creation - 2. Basic Profiling: Test start/stop functionality - 3. Discrete Mode: Test discrete profiling behavior - 4. Annotation: Test the annotate decorator in both normal and discrete modes - 5. Config Validation: Verify proper config initialization from OmegaConf - """ - - def setUp(self): - self.config = ProfilerConfig(all_ranks=True) - self.rank = 0 - self.profiler = NsightSystemsProfiler(self.rank, self.config) - - def test_initialization(self): - self.assertEqual(self.profiler.this_rank, True) - self.assertEqual(self.profiler.this_step, False) - self.assertEqual(self.profiler.discrete, False) - - def test_start_stop_profiling(self): - with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop: - # Test start - self.profiler.start() - self.assertTrue(self.profiler.this_step) - mock_start.assert_called_once() - - # Test stop - self.profiler.stop() - self.assertFalse(self.profiler.this_step) - mock_stop.assert_called_once() - - def test_discrete_profiling(self): - discrete_config = ProfilerConfig(discrete=True, all_ranks=True) - profiler = NsightSystemsProfiler(self.rank, discrete_config) - - with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop: - profiler.start() - self.assertTrue(profiler.this_step) - mock_start.assert_not_called() # Shouldn't start immediately in discrete mode - - profiler.stop() - self.assertFalse(profiler.this_step) - mock_stop.assert_not_called() # Shouldn't stop immediately in discrete mode - - def test_annotate_decorator(self): - mock_self = MagicMock() - mock_self.profiler = self.profiler - mock_self.profiler.this_step = True - - @NsightSystemsProfiler.annotate(message="test") - def test_func(self, *args, **kwargs): - return "result" - - with ( - patch("torch.cuda.profiler.start") as mock_start, - patch("torch.cuda.profiler.stop") as mock_stop, - patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range, - patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range, - ): - result = test_func(mock_self) - self.assertEqual(result, "result") - mock_start_range.assert_called_once() - mock_end_range.assert_called_once() - mock_start.assert_not_called() # Not discrete mode - mock_stop.assert_not_called() # Not discrete mode - - def test_annotate_discrete_mode(self): - discrete_config = ProfilerConfig(discrete=True, all_ranks=True) - profiler = NsightSystemsProfiler(self.rank, discrete_config) - mock_self = MagicMock() - mock_self.profiler = profiler - mock_self.profiler.this_step = True - - @NsightSystemsProfiler.annotate(message="test") - def test_func(self, *args, **kwargs): - return "result" - - with ( - patch("torch.cuda.profiler.start") as mock_start, - patch("torch.cuda.profiler.stop") as mock_stop, - patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range, - patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range, - ): - result = test_func(mock_self) - self.assertEqual(result, "result") - mock_start_range.assert_called_once() - mock_end_range.assert_called_once() - mock_start.assert_called_once() # Should start in discrete mode - mock_stop.assert_called_once() # Should stop in discrete mode - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/utils/test_rollout_trace_on_cpu.py b/tests/utils/test_rollout_trace_on_cpu.py deleted file mode 100644 index 04dfbeef8..000000000 --- a/tests/utils/test_rollout_trace_on_cpu.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys -from unittest.mock import MagicMock, patch - -import pytest - -from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op - - -@pytest.fixture(autouse=True) -def reset_rollout_trace_config_singleton(): - """Fixture to reset the RolloutTraceConfig singleton before each test.""" - RolloutTraceConfig.reset() - - -@pytest.fixture -def mock_weave_client(): - """Mocks the weave module and its client, yielding the mock client.""" - mock_weave = MagicMock() - mock_client = MagicMock() - mock_call = MagicMock() - mock_client.create_call.return_value = mock_call - mock_weave.init.return_value = mock_client - - # Also mock the call_context if it's used internally by the decorator - mock_weave.trace.context.call_context.return_value = MagicMock() - - with patch.dict(sys.modules, {"weave": mock_weave, "weave.trace.context": mock_weave.trace.context}): - yield mock_client - - -class TracedClass: - @rollout_trace_op - # @weave.op - # @mlflow.trace - async def my_method(self, a, b="default"): - return f"result: {a}, {b}" - - @rollout_trace_op - # @weave.op - # @mlflow.trace - async def middle_method(self, a, b="default"): - await self.my_method("test_a1", b="test_b1") - return f"result: {a}, {b}" - - @rollout_trace_op - # @mlflow.trace - async def my_method_with_exception(self): - raise ValueError("Test Exception") - - async def upper_method(self): - await self.my_method("test_a0", b="test_b0") - await self.middle_method("test_a2", b="test_b2") - return True - - -class UntracedClass: - @rollout_trace_op - async def my_method(self, x): - return x * 2 - - -async def test_rollout_trace_on_untraced_class(): - """Tests that the decorator works correctly when no backend is configured.""" - instance = UntracedClass() - assert await instance.my_method(10) == 20 - - -async def test_rollout_trace_with_tracer(mock_weave_client): - """Tests that the decorator calls the tracer's methods correctly.""" - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") - instance = TracedClass() - assert RolloutTraceConfig.get_client() is mock_weave_client - - result = await instance.my_method("test_a", b="test_b") - - assert result == "result: test_a, test_b" - mock_weave_client.create_call.assert_called_once() - call_kwargs = mock_weave_client.create_call.call_args.kwargs - assert call_kwargs["op"] == "TracedClass.my_method" - expected_inputs = {"a": "test_a", "b": "test_b"} - assert call_kwargs["inputs"] == expected_inputs - - mock_call = mock_weave_client.create_call.return_value - mock_weave_client.finish_call.assert_called_once_with(mock_call, output=result) - - -async def test_rollout_trace_with_exception(mock_weave_client): - """Tests that `finish` is called with the exception when one is raised.""" - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") - instance = TracedClass() - - with pytest.raises(ValueError, match="Test Exception"): - await instance.my_method_with_exception() - - mock_weave_client.create_call.assert_called_once() - mock_call = mock_weave_client.create_call.return_value - mock_weave_client.finish_call.assert_called_once() - - # Check that finish_call was called with the exception - args, kwargs = mock_weave_client.finish_call.call_args - assert args[0] == mock_call - assert "exception" in kwargs - assert isinstance(kwargs["exception"], ValueError) - - -async def test_rollout_trace_with_dummy_backend(mock_weave_client): - """Tests that the tracer is not called when the backend is 'dummy'.""" - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy") - instance = TracedClass() - - await instance.my_method("test_a") - - mock_weave_client.create_call.assert_not_called() - - -@pytest.mark.skipif( - os.environ.get("RUN_WEAVE_INTEGRATION_TESTS", "false").lower() != "true", - reason="Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.", -) -async def test_rollout_trace_with_real_weave_backend(): - """Integration test with a real weave backend.""" - - # This assumes that the weave environment (e.g., project) is configured - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave") - - instance = TracedClass() - - with rollout_trace_attr(step=1, sample_index=2, rollout_n=3): - await instance.upper_method() - - with pytest.raises(ValueError, match="Test Exception"): - await instance.my_method_with_exception() - - print("\nWeave integration test ran successfully. Check your weave project for the trace.") - - -@pytest.mark.skipif( - os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true", - reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.", -) -async def test_rollout_trace_with_real_mlflow_backend(): - """Integration test with a real mlflow backend.""" - - # This assumes that the mlflow environment (e.g., project) is configured - RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow") - - instance = TracedClass() - - with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"): - assert await instance.upper_method() - - # with pytest.raises(ValueError, match="Test Exception"): - # await instance.my_method_with_exception() - - print("\nWeave integration test ran successfully. Check your weave project for the trace.") diff --git a/tests/utils/test_seqlen_balancing.py b/tests/utils/test_seqlen_balancing.py deleted file mode 100644 index 9de777f1c..000000000 --- a/tests/utils/test_seqlen_balancing.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from verl import DataProto -from verl.utils.model import create_random_mask -from verl.utils.seqlen_balancing import ( - ceildiv, - get_reverse_idx, - prepare_dynamic_batch, - rearrange_micro_batches, - restore_dynamic_batch, -) - - -def test_seqlen_balancing(): - input_ids = torch.randint(low=0, high=10, size=(20, 100)) - - attention_mask = create_random_mask( - input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5 - ) - data = {"input_ids": input_ids, "attention_mask": attention_mask} - dataproto = DataProto.from_single_dict(data) - micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300) - batch = torch.cat(micro_batches) - micro_bsz_idx = [] - for idx in micro_bsz_idx_lst: - micro_bsz_idx.extend(idx) - reverse_idx_map = get_reverse_idx(micro_bsz_idx) - reverse_idx_map = torch.tensor(reverse_idx_map) - new_batch = batch[reverse_idx_map] - torch.testing.assert_close(new_batch, dataproto.batch) - - -def test_dynamic_batch(): - input_ids = torch.randint(low=0, high=10, size=(20, 100)) - - attention_mask = create_random_mask( - input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5 - ) - data = {"input_ids": input_ids, "attention_mask": attention_mask} - dataproto = DataProto.from_single_dict(data) - micro_batches, micro_bsz_idx_lst = prepare_dynamic_batch(dataproto, max_token_len=300) - input_ids = torch.cat([micro_batch.batch["input_ids"] for micro_batch in micro_batches], dim=0) - input_ids = restore_dynamic_batch(input_ids, micro_bsz_idx_lst) - torch.testing.assert_close(input_ids, dataproto.batch["input_ids"]) - - -def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb): - # 1) init process group & CUDA - torch.cuda.set_device(rank) - dist.init_process_group( - backend="nccl", - init_method=init_method, - world_size=world_size, - rank=rank, - ) - - # 2) build a small random batch (each rank different length to force mismatch) - torch.manual_seed(42 + rank) - input_ids = torch.randint(0, 10, (20 + rank * 5, 100), device=f"cuda:{rank}") - attention_mask = create_random_mask( - input_ids=input_ids, - max_ratio_of_left_padding=0.1, - max_ratio_of_valid_token=0.9, - min_ratio_of_valid_token=0.5, - ) - dp = {"input_ids": input_ids, "attention_mask": attention_mask} - proto = DataProto.from_single_dict(dp) - batch = proto.batch - - # 3) call rearrange_micro_batches with one of the two params under test - micros, idx_lst = rearrange_micro_batches( - batch, - max_token_len=max_token_len, - dp_group=dist.group.WORLD, - same_micro_num_in_dp=use_same_dp, - min_num_micro_batch=min_mb, - ) - - # 4) check the enforced counts - seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) - total_seqlen = seq_len_effective.sum().item() - local = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len)) - - if min_mb is not None: - expected = max(local, min_mb) - assert len(micros) == expected - if use_same_dp: - # gather all local_counts - counts = [torch.zeros(1, device=f"cuda:{rank}") for _ in range(world_size)] - counts[rank].fill_(local) - dist.all_gather(counts, counts[rank]) - expected = max(int(c.item()) for c in counts) - assert len(micros) == expected - else: - # if neither, we get the local natural count - assert len(micros) == local - - # 5) reconstruction sanity: concat→reverse_idx→orig - flat = torch.cat(micros, dim=0) - idx = [] - for sub in idx_lst: - idx.extend(sub) - inv = get_reverse_idx(idx) - inv = torch.tensor(inv, device=flat.device) - reconstructed = flat[inv] - torch.testing.assert_close(reconstructed, batch) - - dist.destroy_process_group() - - -def test_dataproto_split_uneven(): - """Test DataProto.split with uneven splits""" - # Create test data with 10 items - input_ids = torch.randint(low=0, high=10, size=(10, 5)) - attention_mask = torch.ones(10, 5) - data = {"input_ids": input_ids, "attention_mask": attention_mask} - dataproto = DataProto.from_single_dict(data) - - # Test split with size 3 (should create chunks of [3, 3, 3, 1]) - splits = dataproto.split(3) - assert len(splits) == 4 - assert len(splits[0]) == 3 - assert len(splits[1]) == 3 - assert len(splits[2]) == 3 - assert len(splits[3]) == 1 - - reconstructed = DataProto.concat(splits) - torch.testing.assert_close(reconstructed.batch["input_ids"], dataproto.batch["input_ids"]) - torch.testing.assert_close(reconstructed.batch["attention_mask"], dataproto.batch["attention_mask"]) - - # Test split with size equal to length (should create one chunk) - splits = dataproto.split(10) - assert len(splits) == 1 - assert len(splits[0]) == 10 - - # Test split with size larger than length (should create one chunk with all data) - splits = dataproto.split(15) - assert len(splits) == 1 - assert len(splits[0]) == 10 - - # Test with non-tensor batch data - import numpy as np - - data_with_non_tensor = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": np.array([f"label_{i}" for i in range(10)], dtype=object), - } - dataproto_with_non_tensor = DataProto.from_single_dict(data_with_non_tensor) - - splits = dataproto_with_non_tensor.split(3) - assert len(splits) == 4 - assert len(splits[0]) == 3 - assert len(splits[1]) == 3 - assert len(splits[2]) == 3 - assert len(splits[3]) == 1 - - # Verify non-tensor data integrity - reconstructed = DataProto.concat(splits) - np.testing.assert_array_equal( - reconstructed.non_tensor_batch["labels"], dataproto_with_non_tensor.non_tensor_batch["labels"] - ) - - -def test_seqlen_balancing_distributed_params(tmp_path): - world_size = 2 - init_file = tmp_path / "dist_init" - init_file.write_text("") # empty file - init_method = f"file://{init_file}" - - # test min_num_micro_batch only - mp.spawn( - _worker, - args=(world_size, init_method, 300, False, 4), - nprocs=world_size, - join=True, - ) - - # test same_micro_num_in_dp only - mp.spawn( - _worker, - args=(world_size, init_method, 300, True, None), - nprocs=world_size, - join=True, - ) diff --git a/tests/utils/test_temp_env_on_cpu.py b/tests/utils/test_temp_env_on_cpu.py deleted file mode 100644 index 851e4cbe4..000000000 --- a/tests/utils/test_temp_env_on_cpu.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import pytest - -from verl.utils.py_functional import temp_env_var - - -@pytest.fixture(autouse=True) -def clean_env(): - """Fixture to clean up environment variables before and after each test.""" - # Store original environment state - original_env = dict(os.environ) - - # Clean up any test variables that might exist - test_vars = ["TEST_VAR", "TEST_VAR_2", "EXISTING_VAR"] - for var in test_vars: - if var in os.environ: - del os.environ[var] - - # Yield control to the test function - yield - - # Restore original environment state after test - os.environ.clear() - os.environ.update(original_env) - - -def test_set_new_env_var(): - """Test setting a new environment variable that didn't exist before.""" - # Ensure variable doesn't exist - assert "TEST_VAR" not in os.environ - - with temp_env_var("TEST_VAR", "test_value"): - # Variable should be set inside context - assert os.environ["TEST_VAR"] == "test_value" - assert "TEST_VAR" in os.environ - - # Variable should be removed after context - assert "TEST_VAR" not in os.environ - - -def test_restore_existing_env_var(): - """Test restoring an environment variable that already existed.""" - # Set up existing variable - os.environ["EXISTING_VAR"] = "original_value" - - with temp_env_var("EXISTING_VAR", "temporary_value"): - # Variable should be temporarily changed - assert os.environ["EXISTING_VAR"] == "temporary_value" - - # Variable should be restored to original value - assert os.environ["EXISTING_VAR"] == "original_value" - - -def test_env_var_restored_on_exception(): - """Test that environment variables are restored even when exceptions occur.""" - # Set up existing variable - os.environ["EXISTING_VAR"] = "original_value" - - with pytest.raises(ValueError): - with temp_env_var("EXISTING_VAR", "temporary_value"): - # Verify variable is set - assert os.environ["EXISTING_VAR"] == "temporary_value" - # Raise exception - raise ValueError("Test exception") - - # Variable should still be restored despite exception - assert os.environ["EXISTING_VAR"] == "original_value" - - -def test_nested_context_managers(): - """Test nested temp_env_var context managers.""" - # Set up original variable - os.environ["TEST_VAR"] = "original" - - with temp_env_var("TEST_VAR", "level1"): - assert os.environ["TEST_VAR"] == "level1" - - with temp_env_var("TEST_VAR", "level2"): - assert os.environ["TEST_VAR"] == "level2" - - # Should restore to level1 - assert os.environ["TEST_VAR"] == "level1" - - # Should restore to original - assert os.environ["TEST_VAR"] == "original" - - -def test_multiple_different_vars(): - """Test setting multiple different environment variables.""" - # Set up one existing variable - os.environ["EXISTING_VAR"] = "existing_value" - - with temp_env_var("EXISTING_VAR", "modified"): - with temp_env_var("TEST_VAR", "new_value"): - assert os.environ["EXISTING_VAR"] == "modified" - assert os.environ["TEST_VAR"] == "new_value" - - # Check restoration - assert os.environ["EXISTING_VAR"] == "existing_value" - assert "TEST_VAR" not in os.environ - - -def test_empty_string_value(): - """Test setting environment variable to empty string.""" - with temp_env_var("TEST_VAR", ""): - assert os.environ["TEST_VAR"] == "" - assert "TEST_VAR" in os.environ - - # Should be removed after context - assert "TEST_VAR" not in os.environ - - -def test_overwrite_with_empty_string(): - """Test overwriting existing variable with empty string.""" - os.environ["EXISTING_VAR"] = "original" - - with temp_env_var("EXISTING_VAR", ""): - assert os.environ["EXISTING_VAR"] == "" - - # Should restore original value - assert os.environ["EXISTING_VAR"] == "original" - - -def test_context_manager_returns_none(): - """Test that context manager yields None.""" - with temp_env_var("TEST_VAR", "value") as result: - assert result is None - assert os.environ["TEST_VAR"] == "value" diff --git a/tests/utils/test_timeout_decorator_cpu.py b/tests/utils/test_timeout_decorator_cpu.py deleted file mode 100644 index 3417469db..000000000 --- a/tests/utils/test_timeout_decorator_cpu.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import multiprocessing -import sys -import threading -import time - -import pytest # Import pytest - -from verl.utils.py_functional import timeout_limit as timeout - -# --- Test Task Functions --- -TEST_TIMEOUT_SECONDS = 1.5 # Timeout duration for tests -LONG_TASK_DURATION = TEST_TIMEOUT_SECONDS + 0.5 # Duration slightly longer than timeout - - -@timeout(seconds=TEST_TIMEOUT_SECONDS) # Keep global decorator for mp tests -def quick_task(x): - """A task that completes quickly.""" - time.sleep(0.1) - return "quick_ok" - - -@timeout(seconds=TEST_TIMEOUT_SECONDS) # Keep global decorator for mp tests -def slow_task(x): - """A task that takes longer than the timeout.""" - time.sleep(LONG_TASK_DURATION) - return "slow_finished" # This return value indicates it didn't time out - - -# REMOVE global decorator here -def task_raises_value_error(): # Now truly not globally decorated - """A task that intentionally raises a ValueError.""" - raise ValueError("Specific value error from task") - - -# --- Top-level function for signal test in subprocess --- -# Keep this decorated globally for the specific subprocess test case -@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True) -def top_level_decorated_quick_task_signal(): - """A pickleable top-level function decorated with signal timeout.""" - # Assuming this calls the logic of quick_task directly for the test purpose - time.sleep(0.1) - return "quick_ok_signal_subprocess" # Different return for clarity if needed - - -# --- Top-level function for signal test in subprocess --- -# Keep this decorated globally for the specific subprocess test case -@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True) -def top_level_decorated_slow_task_signal(): - """A pickleable top-level function decorated with signal timeout.""" - time.sleep(LONG_TASK_DURATION) - return "slow_finished" - - -# --- NEW: Top-level helper function to run target in process --- -def run_target_and_put_in_queue(target_func, q): - """ - Top-level helper function to run a target function and put its result or exception into a queue. - This function is pickleable and can be used as the target for multiprocessing.Process. - """ - try: - result = target_func() - q.put(("success", result)) - except Exception as e: - q.put(("error", e)) - - -# Use a module-level fixture to set the start method on macOS -@pytest.fixture(scope="module", autouse=True) # Changed scope to module -def set_macos_start_method(): - if sys.platform == "darwin": - # Force fork method on macOS to avoid pickling issues with globally decorated functions - # when running tests via pytest discovery. - current_method = multiprocessing.get_start_method(allow_none=True) - # Only set if not already set or if set to something else (less likely in test run) - if current_method is None or current_method != "fork": - try: - multiprocessing.set_start_method("fork", force=True) - except RuntimeError: - # Might fail if context is already started, ignore in that case. - pass - - -def test_quick_task(): # Renamed from test_multiprocessing_quick_task - """Tests timeout handles a quick task correctly.""" - # Call the globally decorated function directly - result = quick_task(1) - assert result == "quick_ok" # Use pytest assert - - -def test_slow_task_timeout(): # Renamed from test_multiprocessing_slow_task_timeout - """Tests timeout correctly raises TimeoutError for a slow task.""" - # Call the globally decorated function directly within pytest.raises - with pytest.raises(TimeoutError) as excinfo: # Use pytest.raises - slow_task(1) - # Check the error message from the multiprocessing implementation - assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str(excinfo.value) # Use pytest assert - - -def test_internal_exception(): # Renamed from test_multiprocessing_internal_exception - """Tests timeout correctly propagates internal exceptions.""" - # Apply the default timeout decorator dynamically to the undecorated function - decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS)(task_raises_value_error) # Apply decorator dynamically - with pytest.raises(ValueError) as excinfo: # Use pytest.raises - decorated_task() # Call the dynamically decorated function - assert str(excinfo.value) == "Specific value error from task" # Use pytest assert - - -# --- Test the signal implementation (use_signals=True) --- -# Note: As per py_functional.py, use_signals=True currently falls back to -# multiprocessing on POSIX. These tests verify that behavior. - - -def test_signal_quick_task_main_process(): # Removed self - """Tests signal timeout handles a quick task correctly in the main process.""" - - # Apply the signal decorator dynamically - def plain_quick_task_logic(): - time.sleep(0.1) - return "quick_ok_signal" - - decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_quick_task_logic) - assert decorated_task() == "quick_ok_signal" # Use pytest assert - - -def test_signal_slow_task_main_process_timeout(): # Removed self - """Tests signal timeout correctly raises TimeoutError for a slow task in the main process.""" - - # Apply the signal decorator dynamically - def plain_slow_task_logic(): - time.sleep(LONG_TASK_DURATION) - return "slow_finished_signal" - - decorated_task = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_slow_task_logic) - with pytest.raises(TimeoutError) as excinfo: # Use pytest.raises - decorated_task() - # Check the error message (falls back to multiprocessing message on POSIX) - assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str(excinfo.value) # Use pytest assert - - -@pytest.mark.skip(reason="this test won't pass. Just to show why use_signals should not be used") -def test_signal_in_thread_does_not_timeout(): - """ - Tests that signal-based timeout does NOT work reliably in a child thread. - The TimeoutError from the signal handler is not expected to be raised. - """ - result_container = [] # Use a list to store result from thread - exception_container = [] # Use a list to store exception from thread - - @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True) - def slow_task_in_thread(): - try: - print("Thread: Starting slow task...") - time.sleep(LONG_TASK_DURATION) - print("Thread: Slow task finished.") - return "slow_finished_in_thread" - except Exception as e: - # Catch any exception within the thread's target function - print(f"Thread: Caught exception: {e}") - exception_container.append(e) - return None # Indicate failure - - def thread_target(): - try: - # Run the decorated function inside the thread - res = slow_task_in_thread() - if res is not None: - result_container.append(res) - except Exception as e: - # This might catch exceptions happening *outside* the decorated function - # but still within the thread target, though less likely here. - print(f"Thread Target: Caught exception: {e}") - exception_container.append(e) - - thread = threading.Thread(target=thread_target) - print("Main: Starting thread...") - thread.start() - # Wait longer than the timeout + task duration to ensure the thread finishes - # regardless of whether timeout worked or not. - thread.join(timeout=LONG_TASK_DURATION + 1) - - assert len(exception_container) == 1 - assert isinstance(exception_container[0], TimeoutError) - assert not result_container - - -def test_in_thread_timeout(): - result_container = [] # Use a list to store result from thread - exception_container = [] # Use a list to store exception from thread - - @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=False) - def slow_task_in_thread(): - try: - print("Thread: Starting slow task...") - time.sleep(LONG_TASK_DURATION) - print("Thread: Slow task finished.") - return "slow_finished_in_thread" - except Exception as e: - # Catch any exception within the thread's target function - print(f"Thread: Caught exception: {e}") - exception_container.append(e) - return None # Indicate failure - - def thread_target(): - try: - # Run the decorated function inside the thread - res = slow_task_in_thread() - if res is not None: - result_container.append(res) - except Exception as e: - # This might catch exceptions happening *outside* the decorated function - # but still within the thread target, though less likely here. - print(f"Thread Target: Caught exception: {e}") - exception_container.append(e) - - thread = threading.Thread(target=thread_target) - print("Main: Starting thread...") - thread.start() - # Wait longer than the timeout + task duration to ensure the thread finishes - # regardless of whether timeout worked or not. - thread.join(timeout=LONG_TASK_DURATION + 1) - - assert len(exception_container) == 1 - assert isinstance(exception_container[0], TimeoutError) - assert not result_container diff --git a/tests/utils/test_torch_functional.py b/tests/utils/test_torch_functional.py deleted file mode 100644 index 900cb5d54..000000000 --- a/tests/utils/test_torch_functional.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import pytest -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from verl.utils.torch_functional import distributed_masked_mean, distributed_mean_max_min_std, masked_mean - - -def _worker_mean(rank: int, world_size: int, rendezvous_file: str): - # 1) set GPU and init NCCL - torch.cuda.set_device(rank) - dist.init_process_group( - backend="nccl", - init_method=f"file://{rendezvous_file}", - rank=rank, - world_size=world_size, - ) - - # each rank holds tensor [rank+1] - local = torch.tensor([float(rank + 1)], device=f"cuda:{rank}") - mean, gmax, gmin, gstd = distributed_mean_max_min_std(local, True, True, True) - - values = [float(i + 1) for i in range(world_size)] - exp_mean = sum(values) / len(values) - exp_max = max(values) - exp_min = min(values) - var = sum((x - exp_mean) ** 2 for x in values) / (len(values) - 1) - exp_std = var**0.5 - - # all ranks should see the same result - assert torch.allclose(mean.cpu(), torch.tensor(exp_mean)), f"mean@{rank}" - assert torch.allclose(gmax.cpu(), torch.tensor(exp_max)), f"max@{rank}" - assert torch.allclose(gmin.cpu(), torch.tensor(exp_min)), f"min@{rank}" - assert torch.allclose(gstd.cpu(), torch.tensor(exp_std)), f"std@{rank}" - - dist.destroy_process_group() - - -@pytest.mark.parametrize( - "value,mask,gt", - [ - ([1.0, 2.0, 3.0, 4.0], [1, 0, 0, 1], 2.5), - ([1.0, 2.0, float("nan"), 4.0], [1, 0, 0, 1], 2.5), - ([1.0, 2.0, float("nan"), 4.0], [1, 0, 1, 0], float("nan")), - ], -) -def test_masked_mean(value, mask, gt): - res = masked_mean(torch.tensor(value), torch.tensor(mask)) - gt = torch.tensor(gt) - assert torch.allclose(res, gt) or (torch.isnan(res) and torch.isnan(gt)) - - -@pytest.mark.parametrize("world_size", [2, 4]) -def test_distributed_mean_max_min_std(world_size, tmp_path): - rendezvous_file = str(tmp_path / "rdzv_mean") - os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) - - mp.spawn( - fn=_worker_mean, - args=(world_size, rendezvous_file), - nprocs=world_size, - join=True, - ) - - -def _worker_mask(rank: int, world_size: int, rendezvous_file: str): - torch.cuda.set_device(rank) - dist.init_process_group( - backend="nccl", - init_method=f"file://{rendezvous_file}", - rank=rank, - world_size=world_size, - ) - - # build per‐rank tensor and mask - local_tensor = torch.tensor([rank * 2 + 1.0, rank * 2 + 2.0], device=f"cuda:{rank}") - if rank == 0: - mask = torch.tensor([1, 0], device=f"cuda:{rank}", dtype=torch.float32) - else: - mask = torch.tensor([0, 1], device=f"cuda:{rank}", dtype=torch.float32) - - gmean = distributed_masked_mean(local_tensor, mask) - - valid_values = [1.0] + [2 * i + 2.0 for i in range(1, world_size)] - expected_mean = sum(valid_values) / len(valid_values) - assert torch.allclose(gmean.cpu(), torch.tensor(expected_mean)), f"masked_mean@{rank}" - - dist.destroy_process_group() - - -@pytest.mark.parametrize("world_size", [2, 4]) -def test_distributed_masked_mean(world_size, tmp_path): - rendezvous_file = str(tmp_path / "rdzv_mask") - os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) - - mp.spawn( - fn=_worker_mask, - args=(world_size, rendezvous_file), - nprocs=world_size, - join=True, - ) diff --git a/tests/workers/reward_manager/test_registry_on_cpu.py b/tests/workers/reward_manager/test_registry_on_cpu.py deleted file mode 100644 index 9932ae891..000000000 --- a/tests/workers/reward_manager/test_registry_on_cpu.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -# Assuming REWARD_MANAGER_REGISTRY is defined somewhere in the module -from verl.workers.reward_manager.registry import REWARD_MANAGER_REGISTRY, get_reward_manager_cls, register - - -@pytest.fixture -def setup(): - """Setup test cases with a mock registry.""" - REWARD_MANAGER_REGISTRY.clear() - REWARD_MANAGER_REGISTRY.update({"manager1": "Manager1Class", "manager2": "Manager2Class"}) - return REWARD_MANAGER_REGISTRY - - -def test_get_existing_manager(setup): - """Test getting an existing reward manager class.""" - assert get_reward_manager_cls("manager1") == "Manager1Class" - assert get_reward_manager_cls("manager2") == "Manager2Class" - - -def test_get_nonexistent_manager(setup): - """Test getting a non-existent reward manager raises ValueError.""" - with pytest.raises(ValueError) as excinfo: - get_reward_manager_cls("unknown_manager") - assert "Unknown reward manager: unknown_manager" in str(excinfo.value) - - -def test_case_sensitivity(setup): - """Test that manager names are case-sensitive.""" - with pytest.raises(ValueError): - get_reward_manager_cls("MANAGER1") - with pytest.raises(ValueError): - get_reward_manager_cls("Manager1") - - -def test_empty_registry(setup): - """Test behavior when registry is empty.""" - REWARD_MANAGER_REGISTRY.clear() - with pytest.raises(ValueError) as excinfo: - get_reward_manager_cls("any_manager") - assert "Unknown reward manager: any_manager" in str(excinfo.value) - - -def test_register_new_class(setup): - """Test registering a new class with the decorator.""" - - @register("test_manager") - class TestManager: - pass - - assert "test_manager" in REWARD_MANAGER_REGISTRY - assert REWARD_MANAGER_REGISTRY["test_manager"] == TestManager - - -def test_register_different_classes_same_name(setup): - """Test that registering different classes with same name raises ValueError.""" - - @register("conflict_manager") - class Manager1: - pass - - with pytest.raises(ValueError): - - @register("conflict_manager") - class Manager2: - pass - - assert REWARD_MANAGER_REGISTRY["conflict_manager"] == Manager1 - - -def test_decorator_returns_original_class(setup): - """Test that the decorator returns the original class unchanged.""" - - @register("return_test") - class OriginalClass: - def method(setup): - return 42 - - assert OriginalClass().method() == 42 - assert REWARD_MANAGER_REGISTRY["return_test"] == OriginalClass diff --git a/tests/workers/rollout/async_rollout_utils.py b/tests/workers/rollout/async_rollout_utils.py deleted file mode 100644 index 22f20291e..000000000 --- a/tests/workers/rollout/async_rollout_utils.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import ray -from omegaconf import DictConfig - -from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup -from verl.single_controller.ray.base import create_colocated_worker_cls -from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role -from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker -from verl.workers.rollout.async_server import AsyncLLMServerManager - - -def init_async_rollout_manager(config: DictConfig) -> AsyncLLMServerManager: - # =========================== 1. Create hybrid ActorRollout workers =========================== - role_worker_mapping = { - Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker), - } - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - mapping = { - Role.ActorRollout: global_pool_id, - } - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) - resource_pool_manager.create_resource_pool() - resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} - - # create actor and rollout - resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs( - cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" - ) - resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls - - all_wg = {} - for resource_pool, class_dict in resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) - wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) - spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) - all_wg.update(spawn_wg) - actor_rollout_wg = all_wg["actor_rollout"] - actor_rollout_wg.init_model() - - # =========================== 2. Create AsyncLLMServerManager =========================== - async_rollout_manager = AsyncLLMServerManager( - config=config, - worker_group=actor_rollout_wg, - ) - - return async_rollout_manager diff --git a/tests/workers/rollout/perf/vllm_async_rollout.py b/tests/workers/rollout/perf/vllm_async_rollout.py deleted file mode 100644 index dbcd255df..000000000 --- a/tests/workers/rollout/perf/vllm_async_rollout.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Compare vLLM AsyncLLM backend: ExternalRayDistributedExecutor(remote call) vs RayDistributedExecutor(compiled graph) - -1. Prepare openai/gsm8k dataset -python3 examples/data_preprocess/gsm8k.py - -2. Run perf test -python3 tests/workers/rollout/perf/vllm_async_rollout.py >perf.log 2>&1 - -hardware: Nvidia 8*H20 -packages: -- torch==2.6.0 -- vllm==0.8.5 - -[DEBUG] backend: sync, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 21.27 secs -[DEBUG] backend: zeromq, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 23.40 secs -[DEBUG] backend: ray, n_gpus_per_node: 8, batch_size: 2048, step: 0, step_time: 25.33 secs -""" - -import os -import time - -import ray -from omegaconf import DictConfig -from torch.utils.data import SequentialSampler -from torchdata.stateful_dataloader import StatefulDataLoader - -from tests.experimental.agent_loop.agent_utils import AgentLoopManager, RayWorkerGroup, init_agent_loop_manager -from verl.protocol import DataProto -from verl.utils import hf_tokenizer -from verl.utils.dataset import RLHFDataset -from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn - - -def init_config(n_gpus_per_node) -> DictConfig: - import os - - from hydra import compose, initialize_config_dir - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - config = compose(config_name="ppo_trainer") - config.trainer.n_gpus_per_node = n_gpus_per_node - config.data.train_batch_size = 128 - config.data.return_raw_chat = True - config.actor_rollout_ref.model.path = "Qwen/Qwen2.5-7B-Instruct" - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 - config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9 - config.actor_rollout_ref.rollout.multi_turn.format = "hermes" - config.actor_rollout_ref.rollout.prompt_length = 4096 - config.actor_rollout_ref.rollout.response_length = 4096 - config.actor_rollout_ref.rollout.n = 16 - - # test sleep/wake_up with fsdp offload - config.actor_rollout_ref.actor.fsdp_config.param_offload = True - config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True - - return config - - -def initialize(config, backend) -> tuple[AgentLoopManager | RayWorkerGroup, StatefulDataLoader]: - env_vars = { - "NCCL_DEBUG": "WARN", - "VLLM_USE_V1": "1", - "VERL_VLLM_DISTRIBUTED_BACKEND": backend, - } - ray.init(runtime_env={"env_vars": env_vars}) - - # STEP 1: init async llm server - server = init_agent_loop_manager(config) - - # STEP 2: create dataloader - tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path) - dataset = RLHFDataset( - data_files=os.path.expanduser("~/data/gsm8k/train.parquet"), - tokenizer=tokenizer, - config=config.data, - ) - dataloader = StatefulDataLoader( - dataset=dataset, - batch_size=config.data.get("gen_batch_size", config.data.train_batch_size), - num_workers=config.data.get("dataloader_num_workers", 8), - drop_last=True, - collate_fn=default_collate_fn, - sampler=SequentialSampler(dataset), - ) - - return server, dataloader - - -def perf_rollout(mode, backend, n_gpus_per_node, num_steps): - config = init_config(n_gpus_per_node) - config.actor_rollout_ref.rollout.mode = mode - agent_loop_manager, dataloader = initialize(config, backend) - - for step, batch in enumerate(dataloader): - batch: DataProto = DataProto.from_single_dict(batch) - batch = batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids", "raw_prompt"], - ) - t_start = time.time() - gen_batch = agent_loop_manager.generate_sequences(batch) - t_end = time.time() - print( - f"[DEBUG] backend: {backend}, n_gpus_per_node: {n_gpus_per_node}, batch_size: {len(gen_batch)}, " - f"step: {step}, step_time: {t_end - t_start:.2f} secs" - ) - if step + 1 >= num_steps: - break - - ray.shutdown() - - -if __name__ == "__main__": - num_steps = 1 - n_gpus_per_node = 8 - - # test_cases = [("sync", "sync"), ("async", "zeromq"), ("async", "ray")] - test_cases = [("async", "zeromq"), ("async", "ray")] - for mode, backend in test_cases: - perf_rollout(mode=mode, backend=backend, n_gpus_per_node=n_gpus_per_node, num_steps=num_steps) diff --git a/tests/workers/rollout/resource/tool_configs/mcp_server.json b/tests/workers/rollout/resource/tool_configs/mcp_server.json deleted file mode 100644 index 9ed41f10b..000000000 --- a/tests/workers/rollout/resource/tool_configs/mcp_server.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "mcpServers": { - "Tavily Expert": { - "url": "https://tavily.api.tadata.com/mcp/tavily/your_expert", - "auth_token": "your_tavily_token" - } - } -} \ No newline at end of file diff --git a/tests/workers/rollout/resource/tool_configs/mcp_tool_config b/tests/workers/rollout/resource/tool_configs/mcp_tool_config deleted file mode 100644 index a9a45bd0b..000000000 --- a/tests/workers/rollout/resource/tool_configs/mcp_tool_config +++ /dev/null @@ -1,11 +0,0 @@ -tools: - - class_name: verl.tools.mcp_search_tool.MCPSearchTool - config: - rate_limit: 120 - timeout: 120 - type: mcp - mcp: - mcp_servers_config_path: ./resource/tool_configs/mcp_server.json - # optional - tool_selected_list: - - tavily_search_tool \ No newline at end of file diff --git a/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config b/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config deleted file mode 100644 index aa3f1eec5..000000000 --- a/tests/workers/rollout/resource/tool_configs/sandbox_fusion_tool_config +++ /dev/null @@ -1,17 +0,0 @@ -tools: - - class_name: "verl.tools.sandbox_fusion_tools.SandboxFusionTool" - config: - sandbox_fusion_url: "https://xxx.apigateway-cn-beijing.volceapi.com/run_code" - type: native - tool_schema: - type: "function" - function: - name: "code_interpreter" - description: "A tool for executing code." - parameters: - type: "object" - properties: - code: - type: "string" - description: "The code to execute." - required: ["code"] \ No newline at end of file diff --git a/tests/workers/rollout/resource/tool_configs/search_tool_config b/tests/workers/rollout/resource/tool_configs/search_tool_config deleted file mode 100644 index 926b6b832..000000000 --- a/tests/workers/rollout/resource/tool_configs/search_tool_config +++ /dev/null @@ -1,23 +0,0 @@ -tools: - - class_name: verl.tools.search_tool.SearchTool - config: - retrieval_service_url: http://127.0.0.1:8000/retrieve - num_workers: 120 - rate_limit: 120 - timeout: 30 - type: native - tool_schema: - type: function - function: - name: search - description: Searches the web for relevant information based on the given query. - parameters: - type: object - properties: - query_list: - type: array - item: - type: string - description: A list of fully-formed semantic queries. The tool will return search results for each query. - required: - - query_list \ No newline at end of file diff --git a/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py b/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py deleted file mode 100644 index 69223890d..000000000 --- a/tests/workers/rollout/rollout_vllm/run_fsdp_vllm.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import time - -import torch -import torch.distributed as dist -from torch.distributed.fsdp import CPUOffload, MixedPrecision -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -from vllm import SamplingParams - -from verl.third_party.vllm import LLM -from verl.utils.distributed import initialize_global_process_group - - -def main(): - assert torch.cuda.is_available(), "CUDA must be present to run FSDP vLLM example" - local_rank, rank, world_size = initialize_global_process_group() - - local_cache_path = "~/.cache/verl/rlhf" - local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = "Qwen/Qwen2-7B-Instruct" - - from verl.utils.fs import copy_to_local - - local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True) - actor_model_config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=True) - with torch.device("cuda"): - actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True) - actor_model.to(torch.bfloat16) - - max_prompt_length = 16 - response_length = 32 - preencode_prompts = [ - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - tokenizer.pad_token = tokenizer.eos_token - prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) - input_ids = prompts["input_ids"] - attention_mask = prompts["attention_mask"] - from verl.utils.torch_functional import pad_sequence_to_length - - input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True).cuda() - attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True).cuda() - - from transformers import GenerationConfig - - generation_config = GenerationConfig(do_sample=False) - actor_model.cuda() - output = actor_model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=32, - # max_length=max_length, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config=generation_config, - # renormalize_logits=True, - output_scores=False, # this is potentially very large - return_dict_in_generate=True, - use_cache=False, - ) # may OOM when use_cache = True - seq = output.sequences - response = seq[:, max_prompt_length:] - - print(f"hf response: {tokenizer.batch_decode(response)}") - - tensor_model_parallel_size = 4 - from torch.distributed.device_mesh import init_device_mesh - - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) - - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - fsdp_model = FSDP( - actor_model, - use_orig_params=True, - auto_wrap_policy=None, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - cpu_offload=CPUOffload(offload_params=False), - sync_module_states=False, - device_mesh=device_mesh, - ) - - FSDP.set_state_dict_type( - fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() - ) - - state_dict = fsdp_model.state_dict() - - sampling_params = SamplingParams( - temperature=0, top_p=1, n=1, max_tokens=response_length, logprobs=1, ignore_eos=True, detokenize=False - ) - - print(actor_model_config) - llm = LLM( - model=None, - tokenizer=tokenizer, - model_hf_config=actor_model_config, - tensor_parallel_size=tensor_model_parallel_size, - enforce_eager=True, - dtype="bfloat16", - load_format="dummy_dtensor", - gpu_memory_utilization=0.8, - trust_remote_code=True, - ) - - # Warmup iterations - for _ in range(10): - torch.cuda.synchronize() - llm.sync_model_weights(actor_weights=state_dict, load_format="dtensor") - torch.cuda.synchronize() - dist.barrier() - - start_time = time.time() - llm.sync_model_weights(actor_weights=state_dict, load_format="dtensor") - torch.cuda.synchronize() - dist.barrier() - end_time = time.time() - - # Calculate elapsed time - elapsed_time = end_time - start_time - print(f"Time taken: {elapsed_time:.6f} seconds") - - input_ids = input_ids.cuda() - attention_mask = attention_mask.cuda() - idx_list = [] - batch_size = input_ids.shape[0] - - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import _pre_process_inputs - - for i in range(batch_size): - idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) - print("start generation") - outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False) - vllm_output = outputs[0].cuda() - if torch.distributed.get_rank() == 0: - print(f"hf response: {tokenizer.batch_decode(response)}") - print(f"vllm response: {tokenizer.batch_decode(vllm_output)}") - - -if __name__ == "__main__": - main() diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py b/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py deleted file mode 100644 index 93aca6a2d..000000000 --- a/tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from typing import Any - -import numpy as np -import pytest -import ray -from omegaconf import DictConfig -from transformers.utils import get_json_schema - -from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager -from verl.protocol import DataProto -from verl.tools.base_tool import BaseTool, OpenAIFunctionToolSchema -from verl.utils import hf_tokenizer - - -@pytest.fixture -def init_config() -> DictConfig: - import os - - from hydra import compose, initialize_config_dir - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - config = compose(config_name="ppo_trainer") - model_path = "Qwen/Qwen2.5-1.5B-Instruct" - config.actor_rollout_ref.model.path = model_path - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.multi_turn.format = "hermes" - config.actor_rollout_ref.rollout.prompt_length = 4096 - config.actor_rollout_ref.rollout.response_length = 4096 - - # test sleep/wake_up with fsdp offload - config.actor_rollout_ref.actor.fsdp_config.param_offload = True - config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True - - return config - - -def test_vllm_async_rollout_without_tool_calls(init_config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - # =========================== 1. Init rollout manager =========================== - async_rollout_manager = init_async_rollout_manager(init_config) - - # test sleep and wake_up - async_rollout_manager.sleep() - async_rollout_manager.wake_up() - - # =========================== 2. Generate sequences =========================== - raw_prompts = [ - [ - { - "role": "user", - "content": "Let's play a role playing game. Your name is Alice, your favorite color is blue.", - } - ], - [{"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."}], - ] - batch = DataProto( - non_tensor_batch={ - "raw_prompt": np.array(raw_prompts), - }, - ) - result = async_rollout_manager.generate_sequences(prompts=batch) - - # check result - seq_len = result.batch["prompts"].size(1) + result.batch["responses"].size(1) - assert len(result) == 2 - assert result.batch["input_ids"].size(1) == seq_len - assert result.batch["attention_mask"].size(1) == seq_len - assert result.batch["position_ids"].size(1) == seq_len - - # check turns - num_turns = result.non_tensor_batch["__num_turns__"] - assert np.all(num_turns == 2) - - print("Test passed!") - ray.shutdown() - - -class WeatherTool(BaseTool): - def get_current_temperature(self, location: str, unit: str = "celsius"): - """Get current temperature at a location. - - Args: - location: The location to get the temperature for, in the format "City, State, Country". - unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) - - Returns: - the temperature, the location, and the unit in a dict - """ - return { - "temperature": 26.1, - "location": location, - "unit": unit, - } - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - schema = get_json_schema(self.get_current_temperature) - return OpenAIFunctionToolSchema(**schema) - - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - try: - result = self.get_current_temperature(**parameters) - return json.dumps(result), 0, {} - except Exception as e: - return str(e), 0, {} - - -class WeatherToolWithData(BaseTool): - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - schema = get_json_schema(self.get_temperature_date) - return OpenAIFunctionToolSchema(**schema) - - def get_temperature_date(self, location: str, date: str, unit: str = "celsius"): - """Get temperature at a location and date. - - Args: - location: The location to get the temperature for, in the format "City, State, Country". - date: The date to get the temperature for, in the format "Year-Month-Day". - unit: The unit to return the temperature in. Defaults to "celsius". (choices: ["celsius", "fahrenheit"]) - - Returns: - the temperature, the location, the date and the unit in a dict - """ - return { - "temperature": 25.9, - "location": location, - "date": date, - "unit": unit, - } - - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - try: - result = self.get_temperature_date(**parameters) - return json.dumps(result), 0, {} - except Exception as e: - return str(e), 0, {} - - -def test_vllm_async_rollout_with_tool_calls(init_config): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - # =========================== 1. Init rollout manager =========================== - tool_config = { - "tools": [ - { - "class_name": "tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherTool", - "config": {"type": "native"}, - }, - { - "class_name": "tests.workers.rollout.rollout_vllm.test_vllm_chat_scheduler.WeatherToolWithData", - "config": {"type": "native"}, - }, - ] - } - tool_config_path = "/tmp/tool_config.json" - with open(tool_config_path, "w") as f: - json.dump(tool_config, f) - - init_config.actor_rollout_ref.rollout.multi_turn.tool_config_path = tool_config_path - async_rollout_manager = init_async_rollout_manager(init_config) - - # =========================== 2. Generate sequences =========================== - raw_prompts = [ - [ - {"role": "user", "content": "How are you?"}, - ], - [ - {"role": "user", "content": "What's the temperature in Los Angeles now?"}, - ], - [ - { - "role": "system", - "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n\n" - "Current Date: 2024-09-30", - }, - {"role": "user", "content": "What's the temperature in San Francisco now? How about tomorrow?"}, - ], - ] - batch = DataProto( - non_tensor_batch={ - "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object), - }, - ) - result = async_rollout_manager.generate_sequences(prompts=batch) - - # Check turns - num_turns = result.non_tensor_batch["__num_turns__"] - # [user, assistant] - assert num_turns[0] == 2 - # [user, assistant, tool, assistant] - assert num_turns[1] == 4 - # [system, user, assistant, tool, tool, assistant] - assert num_turns[2] == 6 - - # Check response_mask - tokenizer = hf_tokenizer(init_config.actor_rollout_ref.model.path) - responses = result.batch["responses"] - response_mask = result.batch["response_mask"] - assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" - - # Decode responses with response_mask - for i in range(len(responses)): - valid_tokens = responses[i][response_mask[i].bool()] - response_str = tokenizer.decode(valid_tokens) - assert "" not in response_str, f"found in response: {response_str}" - assert "" not in response_str, f"found in response: {response_str}" - print(f"response: {response_str}") - - print("Test passed!") - ray.shutdown() diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py b/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py deleted file mode 100644 index 30c9ae2bc..000000000 --- a/tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gc - -import torch -import torch.distributed -import torch.distributed as dist -from omegaconf import OmegaConf -from transformers import AutoConfig, AutoTokenizer - -from verl import DataProto -from verl.utils.distributed import initialize_global_process_group -from verl.utils.model import compute_position_id_with_mask -from verl.workers.rollout.vllm_rollout.vllm_rollout_spmd import vLLMRollout - - -def test_vllm_rollout_with_yarn_position_embeddings(): - """ - Test the vLLM rollout with yarn position embeddings. - """ - - local_rank, rank, world_size = initialize_global_process_group() - config = OmegaConf.create( - { - "model_path": "OldKingMeister/Qwen2.5-1.5B-Instruct-YaRN", - "prompt_length": 35000, - "response_length": 512, - "dtype": "bfloat16", - "enforce_eager": True, - "gpu_memory_utilization": 0.4, - "enable_chunked_prefill": False, - "free_cache_engine": False, - "disable_log_stats": True, - "max_model_len": 35000 + 512, - "load_format": "auto", - "val_kwargs": { - "top_k": -1, - "top_p": 1.0, - "temperature": 0, - "n": 1, - "do_sample": False, - }, - "tensor_model_parallel_size": 4, - "trust_remote_code": True, - "calculate_log_probs": False, - "do_sample": False, - "temperature": 0.0, - "max_num_batched_tokens": 35000 + 512, - } - ) - - tokenizer = AutoTokenizer.from_pretrained(config.model_path, trust_remote_code=True, padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - model_hf_config = AutoConfig.from_pretrained(config.model_path) - - # do_sample=False for temperate=0 deterministic - input_dataproto = prepare_input_dataproto(tokenizer, config, validate=True, do_sample=False) - - vllm_rollout = vLLMRollout( - model_path=config.model_path, - config=config, - tokenizer=tokenizer, - model_hf_config=model_hf_config, - ) - # rollout - rollout_response = vllm_rollout.generate_sequences( - prompts=input_dataproto, - ) - if rank == 0: - print("VLLM Rollout Outputs:") - print(tokenizer.batch_decode(rollout_response.batch["responses"][:], skip_special_tokens=False)) - for response in rollout_response.batch["responses"]: - assert "<|im_end|>" in tokenizer.decode(response, skip_special_tokens=False), ( - "Response should contain <|im_end|> token" - ) - print("Checks passed.") - - del vllm_rollout - gc.collect() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - dist.barrier() - torch.distributed.destroy_process_group() - - -def prepare_input_dataproto(tokenizer, config, validate, do_sample=False): - base_phrase = "Roses are red, sky is blue. " * 4096 - preencode_prompts = [ - # 32810 tokens > 32768 tokens - [{"role": "user", "content": base_phrase + "Who won the Champions League in 2019?"}], - [{"role": "user", "content": base_phrase + "The founder of Apple is"}], - [{"role": "user", "content": base_phrase + "What's your name"}], - ] - formatted_prompts = [ - tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) - for conversation in preencode_prompts - ] - prompts = tokenizer(formatted_prompts, return_tensors="pt", padding="max_length", max_length=config.prompt_length) - input_dataproto = DataProto.from_dict( - { - "input_ids": prompts["input_ids"], - "attention_mask": prompts["attention_mask"], - "position_ids": compute_position_id_with_mask(prompts["attention_mask"]), - }, - meta_info={ - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, - "pad_token_id": tokenizer.pad_token_id, - "validate": validate, - "do_sample": do_sample, - "response_length": config.response_length, - "temperature": config.temperature, - }, - ) - return input_dataproto - - -if __name__ == "__main__": - test_vllm_rollout_with_yarn_position_embeddings() diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py b/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py deleted file mode 100644 index c2b8f51cb..000000000 --- a/tests/workers/rollout/rollout_vllm/test_vllm_spmd.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import pytest -import torch -from torch.distributed.fsdp import CPUOffload, MixedPrecision -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType -from transformers import AutoModelForCausalLM, AutoTokenizer -from vllm import LLM, SamplingParams - -from verl.utils.distributed import initialize_global_process_group -from verl.utils.torch_functional import pad_sequence_to_length - - -def levenshtein(s1, s2): - m, n = len(s1), len(s2) - # Initialize matrix of zeros - dp = [[0] * (n + 1) for _ in range(m + 1)] - # Initialize first column and first row of the matrix - for i in range(m + 1): - dp[i][0] = i # Deletion from s1 to empty string - for j in range(n + 1): - dp[0][j] = j # Insertion to s1 from empty string - # Compute the Levenshtein distance matrix - for i in range(1, m + 1): - for j in range(1, n + 1): - cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match - dp[i][j] = min( - dp[i - 1][j] + 1, # Deletion - dp[i][j - 1] + 1, # Insertion - dp[i - 1][j - 1] + cost, # Substitution - ) - return dp[m][n] - - -def are_lists_similar(a, b): - if len(a) != len(b): - print("The lists are of different lengths.") - return False - - total_length = 0 - total_diff = 0 - - for s1, s2 in zip(a, b, strict=True): - max_len = max(len(s1), len(s2)) - total_length += max_len - diff = levenshtein(s1, s2) - total_diff += diff - print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n") - - percentage_difference = (total_diff / total_length) * 100 - print(f"Total difference: {percentage_difference:.2f}%") - - return percentage_difference <= 15 - - -@pytest.mark.skip("https://github.com/vllm-project/vllm/issues/16993") -def test_vllm_spmd(): - assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." - local_rank, rank, world_size = initialize_global_process_group() - - # Initialize model and token - local_cache_path = "~/.cache/verl/rlhf" - local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = "Qwen/Qwen2-7B-Instruct" - from verl.utils.fs import copy_to_local - - local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left", trust_remote_code=True) - - actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True) - actor_model.to(torch.bfloat16) - - # fill rollout config - max_prompt_length = 16 - max_response_length = 32 - preencode_prompts = [ - "Who won the Champions League in 2019?", - "The founder of Apple is", - "What's your name?", - ] - tokenizer.pad_token = tokenizer.eos_token - prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) - input_ids = prompts["input_ids"] - attention_mask = prompts["attention_mask"] - - input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) - attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) - - print("start generation") - input_ids = input_ids.cuda() - attention_mask = attention_mask.cuda() - - temperature = 0 - top_p = 1 - kwargs = dict( - n=1, temperature=temperature, top_p=top_p, max_tokens=max_response_length, logprobs=1, ignore_eos=True - ) - - tensor_parallel_size = 4 - - from torch.distributed.device_mesh import init_device_mesh - - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) - - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - - fsdp_model = FSDP( - actor_model, - use_orig_params=True, - auto_wrap_policy=None, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - cpu_offload=CPUOffload(offload_params=False), - sync_module_states=False, - device_mesh=device_mesh, - ) - - FSDP.set_state_dict_type( - fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() - ) - - state_dict = fsdp_model.state_dict() - - sampling_params = SamplingParams(**kwargs) - llm = LLM( - model=local_model_path, - enable_sleep_mode=True, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend="external_launcher", - dtype="bfloat16", - enforce_eager=True, - gpu_memory_utilization=0.8, - disable_custom_all_reduce=True, - skip_tokenizer_init=False, - enable_prefix_caching=True, - trust_remote_code=True, - seed=1, - ) - - outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False) - vllm_response_tokens = [] - for output in outputs: - generated_text = output.outputs[0].text - vllm_response_tokens.append(generated_text) - - world_size = torch.distributed.get_world_size() - model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model - model.load_weights( - ((name, param.full_tensor() if world_size != 1 else param) for name, param in state_dict.items()) - ) - - outputs = llm.generate(preencode_prompts, sampling_params=sampling_params, use_tqdm=False) - verl_vllm_response_tokens = [] - for output in outputs: - generated_text = output.outputs[0].text - verl_vllm_response_tokens.append(generated_text) - - if torch.distributed.get_rank() == 0: - print(f"vllm response: {vllm_response_tokens}") - print(f"verl-vllm response: {verl_vllm_response_tokens}") - assert are_lists_similar(vllm_response_tokens, verl_vllm_response_tokens), "Strings differ more than 10%:\n" - print("Check Pass") - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - test_vllm_spmd() diff --git a/tests/workers/rollout/test_async_sglang_server.py b/tests/workers/rollout/test_async_sglang_server.py deleted file mode 100644 index 0b4e914f1..000000000 --- a/tests/workers/rollout/test_async_sglang_server.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from omegaconf import DictConfig - - -@patch.dict( - "sys.modules", - { - "verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock(SGLangRollout=MagicMock()), - }, -) -class TestAsyncSglangServer: - @pytest.fixture - def server_config(self): - return DictConfig({"rollout": {"tensor_model_parallel_size": 2}}) - - @pytest.mark.asyncio - @patch("verl.workers.rollout.sglang_rollout.async_sglang_server.ray.util.list_named_actors") - @patch("verl.workers.rollout.async_server.AsyncServerBase._start_fastapi_server", new_callable=AsyncMock) - @pytest.mark.filterwarnings("ignore:Ray state API is no longer experimental:DeprecationWarning") - async def test_init_engine(self, mock_start_fastapi_server, mock_list_actors, server_config): - mock_list_actors.return_value = [ - {"name": "test_prefixWorkerDict_1:0", "namespace": "test"}, - {"name": "test_prefixWorkerDict_1:1", "namespace": "test"}, - {"name": "test_prefixWorkerDict_0:0", "namespace": "test"}, - {"name": "test_prefixWorkerDict_0:1", "namespace": "test"}, - {"name": "test_prefixWorkerDict_1:2", "namespace": "test"}, - {"name": "test_prefixWorkerDict_1:3", "namespace": "test"}, - {"name": "test_prefixWorkerDict_0:2", "namespace": "test"}, - {"name": "test_prefixWorkerDict_0:3", "namespace": "test"}, - ] - from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer - - ActualClassToInstantiate = AsyncSglangServer - if hasattr(AsyncSglangServer, "__ray_metadata__") and hasattr( - AsyncSglangServer.__ray_metadata__, "modified_class" - ): - ActualClassToInstantiate = AsyncSglangServer.__ray_metadata__.modified_class - - def mock_get_actor_side_effect(name, namespace=None): - # Create a new mock actor for each call - actor_mock = MagicMock() - - # Support .name attribute access - actor_mock.name = name # Use 'name' here - - # Support ['name'] item access by mocking __getitem__ - def getitem_mock(key): - if key == "name": - return name # Use 'name' here - # For other keys, return a new MagicMock to mimic default behavior or raise KeyError - # Returning a MagicMock is consistent with the original error's cause for unmocked keys - return MagicMock(name=f"mock.__getitem__('{key}')") - - actor_mock.__getitem__.side_effect = getitem_mock - - return actor_mock - - # Verify instance.workers is correctly populated - with patch( - "verl.workers.rollout.sglang_rollout.async_sglang_server.ray.get_actor", - side_effect=mock_get_actor_side_effect, - ): - # Instance 1 - instance = ActualClassToInstantiate(server_config, 4, 0, "test_prefix") - await instance.init_engine() - - assert len(instance.workers) == 2 - assert instance.master_worker["name"] == "test_prefixWorkerDict_0:0" - assert instance.workers[0].name == "test_prefixWorkerDict_0:0" - assert instance.workers[1].name == "test_prefixWorkerDict_0:1" - - # Instance 2 - instance = ActualClassToInstantiate(server_config, 4, 1, "test_prefix") - await instance.init_engine() - - assert len(instance.workers) == 2 - assert instance.master_worker["name"] == "test_prefixWorkerDict_0:2" - assert instance.workers[0].name == "test_prefixWorkerDict_0:2" - assert instance.workers[1].name == "test_prefixWorkerDict_0:3" - - # Instance 3 - instance = ActualClassToInstantiate(server_config, 4, 3, "test_prefix") - await instance.init_engine() - - assert len(instance.workers) == 2 - assert instance.master_worker["name"] == "test_prefixWorkerDict_1:2" - assert instance.workers[0].name == "test_prefixWorkerDict_1:2" - assert instance.workers[1].name == "test_prefixWorkerDict_1:3" diff --git a/tests/workers/rollout/test_custom_completion_callback.py b/tests/workers/rollout/test_custom_completion_callback.py deleted file mode 100644 index c17d5272c..000000000 --- a/tests/workers/rollout/test_custom_completion_callback.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import concurrent.futures -import os -import re -import socket -import sys -import tempfile -from contextlib import asynccontextmanager -from typing import Any - -import fastapi -import numpy as np -import ray -import uvicorn -from datasets import load_dataset -from omegaconf import DictConfig -from openai.types.chat.chat_completion import ChatCompletion -from starlette.requests import Request -from starlette.responses import JSONResponse - -from tests.workers.rollout.async_rollout_utils import init_async_rollout_manager -from verl.protocol import DataProto -from verl.utils import hf_tokenizer -from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case -from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler, ToolCompletionCallback - - -def _get_free_port(): - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - -@ray.remote(num_cpus=1) -class Sandbox: - """Sandbox to execute python code. - - WARNING: This class is for testing purpose only, do not use it in production. - Please use a sandbox with strong isolation and security restrictions instead. - """ - - def __init__(self): - self.address = ray.util.get_node_ip_address() - self.port = None - self.server_ready = asyncio.Event() - asyncio.create_task(self._start_fastapi_server()) - - async def code_execution(self, request: Request): - request_json = await request.json() - code = request_json["code"] - print(f"execute code:\n{code}") - - _, temp_file = tempfile.mkstemp(suffix=".py", prefix="temp_code", dir=None, text=True) - with open(temp_file, "w") as f: - f.write(code) - - try: - process = await asyncio.create_subprocess_exec( - sys.executable, temp_file, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) - - stdout, stderr = await process.communicate() - - response = { - "status": "Success" if process.returncode == 0 else "Failed", - "run_result": { - "status": "Finished", - "stdout": stdout.decode(), - "stderr": stderr.decode(), - "return_code": process.returncode, - }, - } - return JSONResponse(content=response) - finally: - try: - os.unlink(temp_file) - except: # noqa: E722 - pass - - async def _start_fastapi_server(self): - @asynccontextmanager - async def lifespan(app: fastapi.FastAPI): - print("FastAPI startup") - self.server_ready.set() - yield - - print("FastAPI shutdown, maybe address already in use, exit process immediately.") - os._exit(-1) - - app = fastapi.FastAPI(lifespan=lifespan) - app.router.add_api_route("/run_code", self.code_execution, methods=["POST"]) - - self.port = _get_free_port() - config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") - server = uvicorn.Server(config) - await server.serve() - - async def get_server_address(self) -> str: - """Get FastAPI server address.""" - await self.server_ready.wait() - return f"{self.address}:{self.port}" - - -class CustomCompletionCallback(ToolCompletionCallback): - def __init__(self, config: DictConfig, scheduler: ChatCompletionScheduler): - super().__init__(config, scheduler) - - self.max_assistant_turns = 16 - self.answer_pattern = re.compile(r"(.*?)", re.DOTALL) - self.code_pattern = re.compile(r"\s*```python(.*?)```\s*", re.DOTALL) - - self.sandbox_fusion_url = config.reward_model.sandbox_fusion.url - self.default_timeout = 10 - self.memory_limit_mb = config.reward_model.sandbox_fusion.memory_limit_mb - # TODO: support asyncio executor - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max(32, os.cpu_count() * 5)) - - async def sandbox_code_execution(self, code: str) -> dict[str, Any]: - loop = asyncio.get_running_loop() - result_status, metadata = await loop.run_in_executor( - self.executor, - _process_single_case, - 0, # case_index, - None, # stdin_data, - None, # expected_output, - self.sandbox_fusion_url, # sandbox_fusion_url - code, # generation - self.default_timeout, # timeout - self.memory_limit_mb, # memory limit - "python", # language - ) - - return metadata - - @property - def extra_body(self): - extra = { - "include_stop_str_in_output": True, - "stop": ["", ""], - } - return extra - - async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): - role, content, finish_reason = ( - completions.choices[0].message.role, - completions.choices[0].message.content, - completions.choices[0].finish_reason, - ) - messages.append({"role": role, "content": content}) - turn = len(messages) - - # STEP 0: check if we reach max turns - if len(messages) >= self.max_assistant_turns: - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max turns, done!") - return - - # STEP 1: check if we reach max tokens - if finish_reason == "length": - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Reach max tokens, done!") - return - - # STEP 2: check if we got answer - matches = self.answer_pattern.findall(content) - if matches: - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Got answer: {matches[0]}, done!") - return - - # STEP 3: check if we got code block - matches = self.code_pattern.findall(content) - if not matches: - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] No code block found, done!") - return - - # STEP 4: execute code block in sandbox - code = matches[0].strip() - metadata = await self.sandbox_code_execution(code) - if metadata["run_status"] != "Finished": - print( - f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block execution failed: " - f"{metadata}, done!" - ) - return - - stdout, stderr = metadata["stdout"], metadata["stderr"] - messages.append({"role": "tool", "content": f"{stdout}{stderr}"}) - print(f"[id={completions.id},turn={turn},finish_reason={finish_reason}] Code block executed, continue...") - - # STEP 5: resubmit chat completions with code block output - self.scheduler.submit_chat_completions( - messages=messages, - request_id=completions.id, - info=info, - ) - - -user_prompt_template = """ -You are a helpful assistant. Let's solve math problem in following steps: -1. Write a python code first and return the code to user, the code must be in following format: - - -```python -import os - -print(...) -``` - - -The code must explictly print necessary output to stdout. Remember stop generation at immediately and -return the code. -2. User will send the python code to a external sandbox to execute and get output from stdout. -3. User will send the output in format output to you, and you should use the -output to answer the question. -The answer format must be: \\boxed{'The final answer goes here.'} - -*user question:* -{question} -""" - - -if __name__ == "__main__": - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - - # Load config - import os - - from hydra import compose, initialize_config_dir - - with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")): - config = compose(config_name="ppo_trainer") - model_path = "Qwen/Qwen2.5-1.5B-Instruct" - config.actor_rollout_ref.model.path = model_path - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.multi_turn.format = "hermes" - config.actor_rollout_ref.rollout.multi_turn.completion_callback = ( - "tests.workers.rollout.test_custom_completion_callback.CustomCompletionCallback" - ) - config.actor_rollout_ref.rollout.prompt_length = 4096 - config.actor_rollout_ref.rollout.response_length = 4096 - config.actor_rollout_ref.rollout.n = 4 - - # Init sandbox and async rollout manager - sandbox = Sandbox.options(num_cpus=1).remote() - sandbox_address = ray.get(sandbox.get_server_address.remote()) - sandbox_fusion_url = f"http://{sandbox_address}/run_code" - config.reward_model.sandbox_fusion.url = sandbox_fusion_url - async_rollout_manager = init_async_rollout_manager(config) - - # Build dataset - dataset = load_dataset("Maxwell-Jia/AIME_2024", split="train") - prompts = DataProto( - non_tensor_batch={ - "raw_prompt": np.array( - [ - [{"role": "user", "content": user_prompt_template.replace("{question}", problem)}] - for problem in dataset["Problem"] - ] - ), - }, - ) - - result = async_rollout_manager.generate_sequences(prompts=prompts) - assert len(result) == len(dataset) * config.actor_rollout_ref.rollout.n - - # Check max turns that sandbox is called - num_turns = result.non_tensor_batch["__num_turns__"] - print(f"num_turns: {num_turns}") - assert np.max(num_turns) > 2, f"max turns: {np.max(num_turns)}" - - # Check response_mask - tokenizer = hf_tokenizer(config.actor_rollout_ref.model.path) - responses = result.batch["responses"] - response_mask = result.batch["response_mask"] - assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}" - - # Decode responses with response_mask - for i in range(len(responses)): - valid_tokens = responses[i][response_mask[i].bool()] - response_str = tokenizer.decode(valid_tokens) - assert "" not in response_str, f"found in response: {response_str}" - assert "" not in response_str, f"found in response: {response_str}" - print(f"response: {response_str}") - - print("Test passed!") diff --git a/tests/workers/rollout/test_hf_rollout.py b/tests/workers/rollout/test_hf_rollout.py deleted file mode 100644 index 3eb6f4bb2..000000000 --- a/tests/workers/rollout/test_hf_rollout.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import torch -from omegaconf import OmegaConf -from torch.distributed.fsdp import CPUOffload, MixedPrecision -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api import ShardedStateDictConfig, ShardingStrategy, StateDictType -from transformers import AutoModelForCausalLM, AutoTokenizer - -from verl import DataProto -from verl.utils.distributed import initialize_global_process_group -from verl.utils.fs import copy_to_local -from verl.utils.model import compute_position_id_with_mask -from verl.workers.rollout.hf_rollout import HFRollout - -BASE_HF_ROLLOUT_CONFIG = { - "temperature": 1.0, - "top_k": -1, - "top_p": 1, - "prompt_length": 64, - "response_length": 64, - "do_sample": True, - "n": 1, - "val_kwargs": { - "top_k": -1, - "top_p": 1.0, - "temperature": 0, - "n": 1, - "do_sample": False, - }, -} - - -def prepare_input_dataproto(tokenizer, config, validate): - preencode_prompts = [ - [{"role": "user", "content": "Who won the Champions League in 2019?"}], - [{"role": "user", "content": "The founder of Apple is"}], - [{"role": "user", "content": "What's your name"}], - ] - formatted_prompts = [ - tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True) - for conversation in preencode_prompts - ] - prompts = tokenizer(formatted_prompts, return_tensors="pt", padding="max_length", max_length=config.prompt_length) - input_dataproto = DataProto.from_dict( - { - "input_ids": prompts["input_ids"], - "attention_mask": prompts["attention_mask"], - "position_ids": compute_position_id_with_mask(prompts["attention_mask"]), - }, - meta_info={ - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, - "pad_token_id": tokenizer.pad_token_id, - "validate": validate, - }, - ) - return input_dataproto - - -def prepare_fsdp_model(model, world_size): - from torch.distributed.device_mesh import init_device_mesh - - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) - - mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - - fsdp_model = FSDP( - model, - use_orig_params=True, - auto_wrap_policy=None, - device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - cpu_offload=CPUOffload(offload_params=False), - sync_module_states=False, - device_mesh=device_mesh, - ) - - FSDP.set_state_dict_type( - fsdp_model, state_dict_type=StateDictType.SHARDED_STATE_DICT, state_dict_config=ShardedStateDictConfig() - ) - return fsdp_model - - -def test_hf_rollout(n: int = 1, do_sample: bool = True, validate: bool = False): - config = OmegaConf.create(BASE_HF_ROLLOUT_CONFIG) - config.update({"n": n, "do_sample": do_sample}) - - assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." - local_rank, rank, world_size = initialize_global_process_group() - - # Initialize model and tokenizer - local_cache_path = "~/.cache/verl/rlhf" - local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = "Qwen/Qwen2-7B-Instruct" - local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left", trust_remote_code=True) - tokenizer.pad_token = tokenizer.eos_token - - # Initialize FSDP model - actor_model = AutoModelForCausalLM.from_pretrained(local_model_path, trust_remote_code=True) - actor_model.to(torch.bfloat16) - fsdp_model = prepare_fsdp_model(actor_model, world_size) - - # Initialize HFRollout and start generate - hf_rollout = HFRollout(fsdp_model, OmegaConf.create(config)) - input = prepare_input_dataproto(tokenizer, config, validate).to(torch.cuda.current_device()) - outputs = hf_rollout.generate_sequences(input) - - # check generated batch size is expected - generated_batch_size = outputs.batch.batch_size[0] - assert generated_batch_size == input.batch.batch_size[0] * config.n - - for i in range(generated_batch_size): - prompt_tokens = outputs.batch["prompts"][i] - prompt_mask = prompt_tokens != tokenizer.pad_token_id - prompt_tokens = prompt_tokens[prompt_mask] - decoded_prompt = tokenizer.decode(prompt_tokens, skip_special_tokens=False) - - response_tokens = outputs.batch["responses"][i] - response_mask = response_tokens != tokenizer.pad_token_id - response_tokens = response_tokens[response_mask] - decoded_response = tokenizer.decode(response_tokens, skip_special_tokens=False) - - attention_mask = outputs.batch["attention_mask"][i] - position_ids = outputs.batch["position_ids"][i] - prompt_length = outputs.batch["prompts"].size(1) - response_length = outputs.batch["responses"].size(1) - - assert attention_mask.size(0) == prompt_length + response_length - assert position_ids.size(0) == prompt_length + response_length - - # check response attention mask is expected - response_attention = attention_mask[prompt_length:] - eos_positions = (outputs.batch["responses"][i] == tokenizer.pad_token_id).nonzero(as_tuple=True)[0] - if len(eos_positions) > 0: - first_eos_pos = eos_positions[0].item() - assert response_attention[: first_eos_pos + 1].all(), "Response attention mask should be 1 until EOS" - if first_eos_pos + 1 < response_length: - assert not response_attention[first_eos_pos + 1 :].any(), ( - "Response attention mask should be 0 after EOS" - ) - else: - assert response_attention.all(), "Response attention mask should be all 1 if no EOS token" - - # check response position ids is expected - prompt_positions = position_ids[:prompt_length] - response_positions = position_ids[prompt_length:] - valid_response_length = min(len(response_tokens), response_length) - if valid_response_length > 0: - assert response_positions[0] == prompt_positions[-1] + 1 - for j in range(1, valid_response_length): - assert response_positions[j] == response_positions[j - 1] + 1 - - # print generated text for inspection - if torch.distributed.get_rank() == 0: - print(f"prompt: {decoded_prompt}") - print(f"response: {decoded_response}") - print("=" * 30) - - -if __name__ == "__main__": - test_hf_rollout(n=2, do_sample=True, validate=False) - # test_hf_rollout(n=1, do_sample=False, validate=True) - # test_hf_rollout(n=1, do_sample=True, validate=False) diff --git a/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py b/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py deleted file mode 100644 index 387de1618..000000000 --- a/tests/workers/rollout/test_sglang_async_rollout_mcp_tools.py +++ /dev/null @@ -1,461 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from tests/workers/rollout/test_sglang_async_rollout_sf_tools.py - - -import asyncio -from copy import deepcopy -from unittest.mock import AsyncMock, MagicMock, patch - -import numpy as np -import pytest -from tensordict import TensorDict -from transformers import AutoConfig, AutoTokenizer -from utils_sglang import get_rollout_config, prepare_inputs - -from verl.protocol import DataProto -from verl.tools.mcp_search_tool import MCPSearchTool -from verl.tools.utils.mcp_clients.McpClientManager import MCPClientManager -from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message -from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout - -DEFAULT_USER_CONTENT_PREFIX = ( - "Answer the given question. You must conduct reasoning inside and " - "first every time you get new information. After reasoning, if you find you lack " - "some knowledge, you can call a search engine by query " - "and it will return the top searched results between and " - ". You can search as many times as your want. If you find no " - "further external knowledge needed, you can directly provide the answer inside " - " and , without detailed illustrations. For example, " - " Beijing . Question: " -) -user_content = DEFAULT_USER_CONTENT_PREFIX.rstrip("\n") + "How's the weather lately?" - - -def get_search_messages(): - user_prompt = { - "role": "user", - "content": user_content, - } - - expect_turn_0_msg = { - "role": "assistant", - "content": "Let me search the web.", - "tool_calls": [ - { - "id": "10", - "type": "function", - "function": { - "name": "tavily_search_tool", - "arguments": { - "what_is_your_intent": "Search for the weather lately", - "query": "the weather in Beijing today", - "search_depth": "basic", - "time_range": "day", - "include_domains": ["google.com", "baidu.com"], - "max_results": 2, - }, - }, - } - ], - } - - expect_turn_1_msg = { - "role": "assistant", - "content": "Let me search again.", - "tool_calls": [ - { - "type": "function", - "function": { - "name": "tavily_search_tool", - "arguments": { - "what_is_your_intent": "Search for the weather lately", - "query": "the weather in Beijing tomorrow", - "search_depth": "basic", - "time_range": "day", - "include_domains": ["google.com", "baidu.com"], - "max_results": 2, - }, - }, - } - ], - } - - expect_turn_2_msg = { - "role": "assistant", - "content": "Today is sunny and tomorrow will be cloudy in Beijing.", - } - - # Mock search tool responses - tool_return_0_msg = {"role": "tool", "content": [{"type": "text", "text": "Today's weather in Beijing is sunny."}]} - tool_return_1_msg = { - "role": "tool", - "content": [{"type": "text", "text": "Tomorrow's weather in Beijing is cloudy."}], - } - - user_prompts = [user_prompt] - expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg] - tool_return_array = [tool_return_0_msg, tool_return_1_msg] - - return user_prompts, expect_turn_array, tool_return_array - - -class TestRolloutWithMCPSearchTools: - @pytest.fixture - def qwen_tokenizer(self): - local_model_path = "Qwen/Qwen2.5-0.5B" - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - # we only need this for tokenizer - @pytest.fixture - def qwen_model_config(self): - local_model_path = "Qwen/Qwen2.5-0.5B" - config = AutoConfig.from_pretrained(local_model_path) - return config - - @pytest.fixture - def search_data(self, qwen_tokenizer): - user_prompt, expect_turn_array, tool_return_array = get_search_messages() - prompts = [[message] for message in user_prompt] - preencode_turn_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) - for turn in expect_turn_array - ] - preencode_tool_return_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) - for turn in tool_return_array - ] - return prompts, preencode_turn_array, preencode_tool_return_array - - @pytest.fixture - def search_rollout_config(self): - max_prompt_length = 4096 - max_response_length = 3000 - dtype = "bfloat16" - tensor_parallel_size = 1 - tool_path = "./resource/tool_configs/mcp_tool_config" - rollout_config = get_rollout_config( - max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path - ) - return rollout_config - - @pytest.fixture - def search_data_proto(self, search_data, qwen_tokenizer): - preencode_prompts, _, _ = search_data - prompts = [ - qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in preencode_prompts - ] - input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) - prompt_dict = TensorDict( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=input_ids.shape[0], - ) - messages = np.asarray(preencode_prompts) - - tools_kwargs = np.array( - [ - { - "tavily_search_tool": { - "create_kwargs": {"ground_truth": "Today is sunny and tomorrow will be cloudy in Beijing."}, - }, - } - ], - dtype=object, - ) - index = np.array([0], dtype=object) - prompts = DataProto( - batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} - ) - return prompts - - @pytest.fixture - def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config): - """Mock the rollout instance with sampling_params initialized.""" - tool_schema = [ - { - "type": "function", - "function": { - "name": "tavily_search_tool", - "description": "A powerful web search tool...", - "parameters": { - "type": "object", - "properties": { - "what_is_your_intent": { - "type": "string", - "description": "Describe your intent for using Tavily", - }, - "query": {"type": "string", "description": "Search query"}, - "search_depth": { - "type": "string", - "description": "The depth of the search ('basic' or 'advanced')", - }, - "topic": { - "type": "string", - "description": "The category of the search ('general' or 'news')", - }, - "days": { - "type": "integer", - "description": "Number of days back to include in search results (only for " - "'news' topic)", - }, - "time_range": { - "type": "string", - "description": "Time range for results ('day', 'week', 'month', 'year', 'd', " - "'w', 'm', 'y')", - }, - "include_domains": { - "type": "array", - "description": "List of domains to specifically include in search results", - }, - "exclude_domains": { - "type": "array", - "description": "List of domains to specifically exclude from search results", - }, - "include_answer": { - "type": "boolean", - "description": "Whether to include an answer summary generated by an LLM", - }, - "include_raw_content": { - "type": "boolean", - "description": "Whether to include the cleaned and parsed HTML content of each result", - }, - "include_images": { - "type": "boolean", - "description": "Whether to include images from search results", - }, - "include_image_descriptions": { - "type": "boolean", - "description": "Whether to include descriptions with images", - }, - "max_results": { - "type": "integer", - "description": "Maximum number of results to return (5-20)", - }, - "async_search": { - "type": "boolean", - "description": "Whether to perform the search asynchronously", - }, - }, - "required": ["what_is_your_intent", "query"], - }, - "strict": False, - }, - } - ] - with ( - patch.object(MCPClientManager, "fetch_tool_schemas", return_value=tool_schema), - patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), - patch.object(SGLangRollout, "_init_sampling_params", return_value=None), - ): - rollout = SGLangRollout( - actor_module="", - config=search_rollout_config, - processing_class=qwen_tokenizer, - model_hf_config=qwen_model_config, - ) - rollout.sampling_params = { - "n": 1, - "max_new_tokens": search_rollout_config.response_length, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "repetition_penalty": 1.0, - } - return rollout - - def test_tools_registration(self, mock_rollout): - assert len(mock_rollout._tool_schemas) != 0 - assert "tavily_search_tool" in mock_rollout._tool_map.keys() - from verl.tools.mcp_search_tool import MCPSearchTool - - assert isinstance(mock_rollout._tool_map["tavily_search_tool"], MCPSearchTool) - # depend on the tokenizer - assert mock_rollout._tool_call_parser_type == "qwen25" - - def test_rollout_req_creation(self, mock_rollout, search_data_proto): - req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) - assert len(req_list) == 1 - assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING - assert len(req_list[0].tool_schemas) == 1 - - def test_over_size_case(self, mock_rollout, search_data_proto, search_data): - mock_rollout.config.multi_turn.max_assistant_turns = 1 - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] - req = MagicMock(wraps=req, spec=AsyncRolloutRequest) - req.finalize = MagicMock() - req_list = [req] - - _, expect_turn_array, _ = search_data - # here we mock a meta info with 'length'. indicate the response is truncate - mock_rollout._handle_engine_call = MagicMock() - future = asyncio.Future() - future.set_result( - { - "text": expect_turn_array[0], - "meta_info": { - "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "length", "length": 3000}, - "prompt_tokens": 132, - "completion_tokens": 100, - "cached_tokens": 0, - "e2e_latency": 2.23543, - }, - } - ) - mock_rollout._handle_engine_call.return_value = future - mock_rollout._tp_rank = 0 - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], - ) - ) - assert len(output_req_list) == 1 - output_req = output_req_list[0] - assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert output_req.reward_scores.get("tavily_search_tool") == [] - # we should only have two message, one for prompt, second for response. - assert len(output_req.messages) == 2 - assert output_req.messages[1] == Message( - role="assistant", - content=expect_turn_array[0], - tool_calls=None, - ) - - @patch.object(MCPSearchTool, "execute", new_callable=AsyncMock) - def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data): - _, expect_turn_array, tool_return_array = search_data - # Mock search tool execution to return predefined responses - mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array] - - mock_rollout.config.multi_turn.max_assistant_turns = 10 - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] - req = MagicMock(wraps=req, spec=AsyncRolloutRequest) - req.finalize = MagicMock() - req_list = [req] - - mock_rollout._handle_engine_call = MagicMock() - futures = [asyncio.Future() for i in expect_turn_array] - for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): - i.set_result( - { - "text": turn, - "meta_info": { - "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, - "prompt_tokens": len(turn), - "completion_tokens": 100, - "cached_tokens": 0, - "e2e_latency": 2.23543, - }, - } - ) - if idx < len(expect_turn_array) - 1: - assert mock_rollout._function_call_parser.has_tool_call(turn) - assert mock_rollout._function_call_parser.parse_non_stream(turn) - - mock_rollout._handle_engine_call.side_effect = futures - mock_rollout._tp_rank = 0 - - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list]) - ) - - # Verify conversation completed successfully with proper tool usage - output_req = output_req_list[0] - assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert "tavily_search_tool" in output_req.metrics - assert output_req.metrics["tavily_search_tool"][0]["status"] == "success" - assert mock_execute.await_count == 2 - assert len(output_req.messages) == 6 - # Verify tool response messages contain expected content - search_counter = 0 - for msg in output_req.messages: - if msg.role == "tool": - assert msg.content == tool_return_array[search_counter] - search_counter += 1 - assert search_counter == 2 - - @patch.object(MCPSearchTool, "execute", new_callable=AsyncMock) - def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data): - _, expect_turn_array, tool_return_array = search_data - # Mock tool execution for large batch (100 requests * 2 calls each) - mock_execute.side_effect = [ - (tool_return_array[0], 0.0, {"status": "success"}), - (tool_return_array[1], 0.0, {"status": "success"}), - ] * 100 - - mock_rollout.config.multi_turn.max_assistant_turns = 10 - base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] - - req_nums = 100 - req_list = [] - req_turns_map = {} - req_turns_counter = {} - - for i in range(req_nums): - tmp_req = deepcopy(base_req) - tmp_req.batch_data_id = i - tmp_req.request_id = i - req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest)) - - futures = [asyncio.Future() for _ in expect_turn_array] - for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): - fut.set_result( - { - "text": turn, - "meta_info": { - "id": "dummy", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, - "prompt_tokens": len(turn), - "completion_tokens": 100, - }, - } - ) - req_turns_map[i] = futures - req_turns_counter[i] = 0 - - async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs): - fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] - req_turns_counter[_req.batch_data_id] += 1 - return await fut - - with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): - mock_rollout._tp_rank = 0 - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list]) - ) - - # Verify all requests completed successfully - assert len(output_req_list) == req_nums - for out_req in output_req_list: - assert out_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert "tavily_search_tool" in out_req.metrics - for metric in out_req.metrics["tavily_search_tool"]: - assert metric["status"] == "success" - assert len(out_req.messages) == 6 - assert sum(1 for m in out_req.messages if m.role == "tool") == 2 - - assert mock_execute.await_count == 2 * req_nums diff --git a/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py b/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py deleted file mode 100644 index 47fefca8a..000000000 --- a/tests/workers/rollout/test_sglang_async_rollout_multimodal_delta.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright 2025 Amazon.com, Inc. or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import pytest - -from verl.utils.dataset.vision_utils import process_image -from verl.utils.tokenizer import hf_processor -from verl.workers.rollout.schemas import ( - AsyncRolloutRequest, - AsyncRolloutRequestStateEnum, - TokenizationSanityCheckModeEnum, -) - - -def _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False): - assert len(image_list) == len(description_list) - # Get the smallest dimensions across all images - processed_images = [] - for img_url in image_list: - img = process_image(img_url) - processed_images.append(img) - - min_width = min(img.size[0] for img in processed_images) - min_height = min(img.size[1] for img in processed_images) - min_size = (min_width, min_height) - - if resize_image: - processed_images_resized = [] - for img in processed_images: - img = img.resize(min_size) - processed_images_resized.append(img) - processed_images = processed_images_resized - - # Initial message history - system_prompt = ( - "You will be provided with an image. Describe this image and then generate a new image for the next round" - ) - messages = [ - { - "role": "system", - "content": system_prompt, - }, - { - "role": "user", - "content": [ - {"type": "text", "text": "Here is the first image provided: "}, - {"type": "image", "image": [processed_images[0]]}, - ], - }, - ] - - # Initial multi_modal_data with one image - multi_modal_data = {"image": [processed_images[0]], "video": []} - # Minimal required fields for AsyncRolloutRequest - - req = AsyncRolloutRequest( - batch_data_id=0, - request_id="test-req-1", - state=AsyncRolloutRequestStateEnum.PENDING, - messages=messages, - multi_modal_keys=["image", "video"], - multi_modal_data=multi_modal_data.copy(), - tool_schemas=[], - tools_kwargs={}, - interaction_kwargs={}, - input_ids=None, - prompt_ids=None, - response_ids=None, - attention_mask=None, - prompt_attention_mask=None, - response_attention_mask=None, - position_ids=None, - prompt_position_ids=None, - response_position_ids=None, - loss_mask=None, - prompt_loss_mask=None, - response_loss_mask=None, - reward_scores={}, - max_prompt_len=8192, - max_response_len=8192, - max_model_len=16384, - metrics={}, - use_inference_chat_template=True, - tokenization_sanity_check_mode=TokenizationSanityCheckModeEnum.STRICT, - generation_prompt_ids=None, - base_conv_wo_gen_prompt_end_pos=0, - base_conv_with_gen_prompt_end_pos=0, - processing_class=processor, - ) - - prev_generated_len = 0 - # Add First Assistant Message and first tool response message(image) - for idx, img in enumerate(processed_images): - if idx == 0: - continue - _ = req.get_generation_prompt_ids(processor) - req.add_assistant_message(processor, content=description_list[idx - 1]) - before_tool_call_len = req.input_ids.shape[-1] - req.add_tool_response_messages(processor, [{"image": [img], "text": "Here is the new image you requested: "}]) - after_tool_call_len = req.input_ids.shape[-1] - if prev_generated_len == 0: - prev_generated_len = after_tool_call_len - before_tool_call_len - else: - if resize_image: - assert after_tool_call_len - before_tool_call_len == prev_generated_len - assert req.multi_modal_data["image"] == processed_images[: idx + 1] - - _ = req.get_generation_prompt_ids(processor) - req.add_assistant_message(processor, content=description_list[-1]) - - messages = [msg.model_dump() for msg in req.messages] - tools = [tool.model_dump() for tool in req.tool_schemas] if req.tool_schemas else None - full_prompt_info = req._handle_apply_chat_template( - processor, - messages, - multi_modal_data=req.multi_modal_data, - tools=tools, - add_generation_prompt=False, - tokenize=True, - return_dict=True, - ) - full_prompt_ids = full_prompt_info["input_ids"] - assert full_prompt_ids.eq(req.input_ids).all() - - # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict - # because np.array() only keeps the keys for BatchFeature. - full_prompt_multi_modal_inputs = full_prompt_info.copy() - full_prompt_multi_modal_inputs.pop("input_ids", None) - full_prompt_multi_modal_inputs.pop("attention_mask", None) - - for key in full_prompt_multi_modal_inputs: - assert full_prompt_multi_modal_inputs[key].eq(req.multi_modal_inputs[key]).all() - - -@pytest.mark.skipif( - hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct" -) -def test_add_tool_response_messages_image_delta(): - processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") - - # From Qwen2.5-VL-3B-Instruct HF example - img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} - img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog." - # GitHub Logo - img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"} - img_2_description = "A GitHub Logo image" - # Octocat - img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"} - img_3_description = "An Octocat image" - - image_list = [img_1_url, img_2_url, img_3_url] - description_list = [img_1_description, img_2_description, img_3_description] - _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=False) - - -@pytest.mark.skipif( - hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") is None, reason="Processor not available for Qwen/Qwen2.5-VL-B-Instruct" -) -def test_add_tool_response_messages_image_delta_resize_image(): - processor = hf_processor("Qwen/Qwen2.5-VL-3B-Instruct") - - # From Qwen2.5-VL-3B-Instruct HF example - img_1_url = {"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"} - img_1_description = "A woman sits on the beach at sunset, smiling as she shares a high five with her large dog." - # GitHub Logo - img_2_url = {"image": "https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png"} - img_2_description = "A GitHub Logo image" - # Octocat - img_3_url = {"image": "https://octodex.github.com/images/orderedlistocat.png"} - img_3_description = "An Octocat image" - - image_list = [img_1_url, img_2_url, img_3_url] - description_list = [img_1_description, img_2_description, img_3_description] - _test_add_tool_response_messages_image_delta(processor, image_list, description_list, resize_image=True) diff --git a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py deleted file mode 100644 index 2400d5c78..000000000 --- a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py +++ /dev/null @@ -1,418 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from tests/workers/rollout/test_sglang_async_rollout_sf_tools.py - - -import asyncio -from copy import deepcopy -from unittest.mock import AsyncMock, MagicMock, patch - -import numpy as np -import pytest -from tensordict import TensorDict -from transformers import AutoConfig, AutoTokenizer -from utils_sglang import get_rollout_config, prepare_inputs - -from verl.protocol import DataProto -from verl.tools.schemas import ( - OpenAIFunctionParametersSchema, - OpenAIFunctionPropertySchema, - OpenAIFunctionSchema, - OpenAIFunctionToolSchema, -) -from verl.tools.search_tool import SearchTool -from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message -from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout - -DEFAULT_USER_CONTENT_PREFIX = ( - "Answer the given question. You must conduct reasoning inside and " - "first every time you get new information. After reasoning, if you find you lack " - "some knowledge, you can call a search engine by query " - "and it will return the top searched results between and " - ". You can search as many times as your want. If you find no " - "further external knowledge needed, you can directly provide the answer inside " - " and , without detailed illustrations. For example, " - " Beijing . Question: " -) -user_content = DEFAULT_USER_CONTENT_PREFIX.rstrip("\n") + "How's the weather lately?" - - -def get_search_messages(): - user_prompt = { - "role": "user", - "content": user_content, - } - - expect_turn_0_msg = { - "role": "assistant", - "content": "Let me search the web.", - "tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "today's weather"}}}], - } - - expect_turn_1_msg = { - "role": "assistant", - "content": "Let me search again.", - "tool_calls": [ - {"type": "function", "function": {"name": "search", "arguments": {"query": "tomorrow's weather"}}} - ], - } - - expect_turn_2_msg = { - "role": "assistant", - "content": "Today is sunny and tomorrow will be cloudy in Beijing.", - } - - # Mock search tool responses - tool_return_0_msg = {"role": "tool", "content": "Today's weather in Beijing is sunny."} - tool_return_1_msg = {"role": "tool", "content": "Tomorrow's weather in Beijing is cloudy."} - - user_prompts = [user_prompt] - expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg] - tool_return_array = [tool_return_0_msg, tool_return_1_msg] - - return user_prompts, expect_turn_array, tool_return_array - - -class TestRolloutWithSearchTools: - @pytest.fixture - def qwen_tokenizer(self): - local_model_path = "Qwen/Qwen2.5-0.5B" - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - # we only need this for tokenizer - @pytest.fixture - def qwen_model_config(self): - local_model_path = "Qwen/Qwen2.5-0.5B" - config = AutoConfig.from_pretrained(local_model_path) - return config - - @pytest.fixture - def search_data(self, qwen_tokenizer): - user_prompt, expect_turn_array, tool_return_array = get_search_messages() - prompts = [[message] for message in user_prompt] - preencode_turn_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) - for turn in expect_turn_array - ] - preencode_tool_return_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) - for turn in tool_return_array - ] - return prompts, preencode_turn_array, preencode_tool_return_array - - @pytest.fixture - def search_rollout_config(self): - max_prompt_length = 4096 - max_response_length = 3000 - dtype = "bfloat16" - tensor_parallel_size = 1 - tool_path = "./resource/tool_configs/search_tool_config" - rollout_config = get_rollout_config( - max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path - ) - return rollout_config - - @pytest.fixture - def search_data_proto(self, search_data, qwen_tokenizer): - preencode_prompts, _, _ = search_data - prompts = [ - qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in preencode_prompts - ] - input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) - prompt_dict = TensorDict( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=input_ids.shape[0], - ) - messages = np.asarray(preencode_prompts) - - tools_kwargs = np.array( - [ - { - "search": { - "create_kwargs": { - "ground_truth": "Today is sunny and tomorrow will be cloudy in Beijing.", - "data_source": "searchR1_nq", - }, - }, - } - ], - dtype=object, - ) - index = np.array([0], dtype=object) - prompts = DataProto( - batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} - ) - return prompts - - @pytest.fixture - def mock_rollout(self, search_rollout_config, qwen_tokenizer, qwen_model_config): - """Mock the rollout instance with sampling_params initialized.""" - with ( - patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), - patch.object(SGLangRollout, "_init_sampling_params", return_value=None), - ): - rollout = SGLangRollout( - actor_module="", - config=search_rollout_config, - processing_class=qwen_tokenizer, - model_hf_config=qwen_model_config, - ) - rollout.sampling_params = { - "n": 1, - "max_new_tokens": search_rollout_config.response_length, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "repetition_penalty": 1.0, - } - return rollout - - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_tools_registration( - self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config - ): - rollout = SGLangRollout( - actor_module="", - config=search_rollout_config, - processing_class=qwen_tokenizer, - model_hf_config=qwen_model_config, - ) - assert len(rollout._tool_schemas) == 1 - assert "search" in rollout._tool_map.keys() - from verl.tools.search_tool import SearchTool - - assert isinstance(rollout._tool_map["search"], SearchTool) - # depend on the tokenizer - assert rollout._tool_call_parser_type == "qwen25" - - @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) - def test_rollout_req_creation( - self, - mock_env, - mock_engine, - mock_sampling, - search_rollout_config, - qwen_tokenizer, - qwen_model_config, - search_data_proto, - ): - rollout = SGLangRollout( - actor_module="", - config=search_rollout_config, - processing_class=qwen_tokenizer, - model_hf_config=qwen_model_config, - ) - req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) - assert len(req_list) == 1 - assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING - assert len(req_list[0].tool_schemas) == 1 - print(type(req_list[0].tool_schemas[0])) - assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema( - type="function", - function=OpenAIFunctionSchema( - name="search", - description="Searches the web for relevant information based on the given query.", - parameters=OpenAIFunctionParametersSchema( - type="object", - properties={ - "query_list": OpenAIFunctionPropertySchema( - type="array", - description="A list of fully-formed semantic queries. The tool will return search " - "results for each query.", - items={"type": "string"}, - ) - }, - required=["query_list"], - ), - strict=False, - ), - ) - - def test_over_size_case(self, mock_rollout, search_data_proto, search_data): - mock_rollout.config.multi_turn.max_assistant_turns = 1 - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] - req = MagicMock(wraps=req, spec=AsyncRolloutRequest) - req.finalize = MagicMock() - req_list = [req] - - _, expect_turn_array, _ = search_data - mock_rollout._handle_engine_call = MagicMock() - future = asyncio.Future() - future.set_result( - { - "text": expect_turn_array[0], - "meta_info": { - "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "length", "length": 3000}, - "prompt_tokens": 132, - "completion_tokens": 100, - "cached_tokens": 0, - "e2e_latency": 2.23543, - }, - } - ) - mock_rollout._handle_engine_call.return_value = future - mock_rollout._tp_rank = 0 - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], - ) - ) - assert len(output_req_list) == 1 - output_req = output_req_list[0] - assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert output_req.reward_scores.get("search") == [] - assert len(output_req.messages) == 2 - assert output_req.messages[1] == Message( - role="assistant", - content=expect_turn_array[0], - tool_calls=None, - ) - - @patch.object(SearchTool, "execute", new_callable=AsyncMock) - def test_tool_call_basic_case(self, mock_execute, mock_rollout, search_data_proto, search_data): - _, expect_turn_array, tool_return_array = search_data - - # Mock search tool execution to return predefined responses - mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array] - - mock_rollout.config.multi_turn.max_assistant_turns = 10 - mock_rollout._tool_map["search"].retrieval_service_url = "mock://dummy" - - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] - req = MagicMock(wraps=req, spec=AsyncRolloutRequest) - req.finalize = MagicMock() - req_list = [req] - - mock_rollout._handle_engine_call = MagicMock() - futures = [asyncio.Future() for i in expect_turn_array] - for idx, (i, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): - i.set_result( - { - "text": turn, - "meta_info": { - "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, - "prompt_tokens": len(turn), - "completion_tokens": 100, - "cached_tokens": 0, - "e2e_latency": 2.23543, - }, - } - ) - if idx < len(expect_turn_array) - 1: - assert mock_rollout._function_call_parser.has_tool_call(turn) - assert mock_rollout._function_call_parser.parse_non_stream(turn) - - mock_rollout._handle_engine_call.side_effect = futures - mock_rollout._tp_rank = 0 - - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather(*[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list]) - ) - - # Verify conversation completed successfully with proper tool usage - output_req = output_req_list[0] - assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert "search" in output_req.metrics - assert output_req.metrics["search"][0]["status"] == "success" - assert mock_execute.await_count == 2 - assert len(output_req.messages) == 6 # user + 3*assistant + 2*tool_call - # Verify tool response messages contain expected content - search_counter = 0 - for msg in output_req.messages: - if msg.role == "tool": - assert msg.content == tool_return_array[search_counter] - search_counter += 1 - assert search_counter == 2 - - @patch.object(SearchTool, "execute", new_callable=AsyncMock) - def test_tool_call_batch_case(self, mock_execute, mock_rollout, search_data_proto, search_data): - _, expect_turn_array, tool_return_array = search_data - - # Mock tool execution for large batch (100 requests * 2 calls each) - mock_execute.side_effect = [ - (tool_return_array[0], 0.0, {"status": "success"}), - (tool_return_array[1], 0.0, {"status": "success"}), - ] * 100 - - mock_rollout.config.multi_turn.max_assistant_turns = 10 - mock_rollout._tool_map["search"].retrieval_service_url = "mock://dummy" - - base_req = mock_rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] - - req_nums = 100 - req_list = [] - req_turns_map = {} - req_turns_counter = {} - - for i in range(req_nums): - tmp_req = deepcopy(base_req) - tmp_req.batch_data_id = i - tmp_req.request_id = i - req_list.append(MagicMock(wraps=tmp_req, spec=AsyncRolloutRequest)) - - futures = [asyncio.Future() for _ in expect_turn_array] - for idx, (fut, turn) in enumerate(zip(futures, expect_turn_array, strict=True)): - fut.set_result( - { - "text": turn, - "meta_info": { - "id": "dummy", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, - "prompt_tokens": len(turn), - "completion_tokens": 100, - }, - } - ) - req_turns_map[i] = futures - req_turns_counter[i] = 0 - - async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_kwargs): - fut = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] - req_turns_counter[_req.batch_data_id] += 1 - return await fut - - with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): - mock_rollout._tp_rank = 0 - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather(*[mock_rollout._async_rollout_a_request(r, True, False) for r in req_list]) - ) - - # Verify all requests completed successfully - assert len(output_req_list) == req_nums - for out_req in output_req_list: - assert out_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert "search" in out_req.metrics - for metric in out_req.metrics["search"]: - assert metric["status"] == "success" - assert len(out_req.messages) == 6 # user + 3 assistant + 2 tool - assert sum(1 for m in out_req.messages if m.role == "tool") == 2 - - assert mock_execute.await_count == 2 * req_nums diff --git a/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py b/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py deleted file mode 100644 index 3f30929c2..000000000 --- a/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py +++ /dev/null @@ -1,659 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# noqa -import asyncio -import time -from copy import deepcopy -from functools import wraps -from unittest.mock import MagicMock, patch - -import numpy as np -import pytest -import ray -from tensordict import TensorDict -from torch.testing._internal.common_distributed import MultiProcessTestCase -from transformers import AutoConfig, AutoTokenizer -from utils_sglang import ( - get_rollout_config, - prepare_inputs, -) - -from verl.protocol import DataProto -from verl.tools.sandbox_fusion_tools import TokenBucketWorker -from verl.tools.schemas import ( - OpenAIFunctionParametersSchema, - OpenAIFunctionPropertySchema, - OpenAIFunctionSchema, - OpenAIFunctionToolSchema, -) -from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message -from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout - -sandbox_url = "" - - -def get_sandbox_fusion_messages(): - user_prompt = { - "role": "user", - "content": """ - Solve the following problem step by step. You now have the ability to selectively - write executable Python code to enhance your reasoning process. \n\n**user question:**\nThere - are 152 students at Dala High School. Assume the following: \n- 100 students take a Math class \n- 94 - students take a Science class \n- 57 students take an English class \n- 73 students take a Math class - and a Science class \n- 24 students take a Math class and an English class \n- 27 students take a Science - class and an English class \n- 22 students take a Math class and a Science class and an English class\n \nHow - many students take neither a Math class nor a Science class nor an Eglish class?\n\nRemember to place the final - answer in the last part using the format: \n\n\boxed{'The final answer goes here.'}\n - """, - } - expect_turn_0_msg = { - "role": "assistant", - "content": """ - Okay, so I need to find out how many students at Dala High School are not taking any of the three classes: Math, - Science, or English. The total number of students is 152. Let me see... I remember this is a problem about sets - and maybe using the principle of inclusion-exclusion. Let me recall how that works.\n\nFirst, the inclusion-exclusion - principle for three sets says that the total number of students taking at least one of the classes is equal to the - sum of the numbers in each individual class, minus the sum of the numbers in each pair of classes, plus the number in - all three classes. Then, subtract that total from the overall number of students to get those not taking any of the - three. \n\nLet me write that down step by step. Let M be the set of students taking Math, S for Science, and E for English. - Then:\n\nTotal in at least one class = |M ∪ S ∪ E| = |M| + |S| + |E| - |M ∩ S| - |M ∩ E| - |S ∩ E| + |M ∩ S ∩ E|\n\nGiven the - numbers:\n\n|M| = 100\n\n|S| = 94\n\n|E| = 57\n\n|M ∩ S| = 73\n\n|M ∩ E| = 24\n\n|S ∩ E| = 27\n\n|M ∩ S ∩ E| = 22\n\nSo plugging - these into the formula:\n\nTotal = 100 + 94 + 57 - 73 - 24 - 27 + 22\n\nLet me compute that step by step using code to ensure - accuracy.\n - """, - "tool_calls": [ - { - "type": "function", - "function": { - "name": "code_interpreter", - "arguments": { - "code": """M = 100\\nS = 94\\nE = 57\\nM_S = 73\\nM_E = 24\\nS_E = 27\\nM_S_E = - 22\\n\\ntotal_in_any = M + S + E - M_S - M_E - S_E + M_S_E\\nstudents_neither = 152 - total_in_any\\nprint(students_neither)""", - }, - }, - } - ], - } - - tool_return_0_msg = { - "role": "tool", - "content": """3""", - } - - expect_turn_1_msg = { - "role": "assistant", - "content": """ - The code returns 3, indicating that 3 students take none of the classes. To verify, I'll break down each - category to ensure the total matches:\n\n- Only Math and Science: 73 - 22 = 51\n- Only Math and English: - 24 - 22 = 2\n- Only Science and English: 27 - 22 = 5\n- Only Math: 100 - 51 - 2 - 22 = 25\n- Only Science: - 94 - 51 - 5 - 22 = 16\n- Only English: 57 - 2 - 5 - 22 = 28\n\nSumming all categories:\n - """, - "tool_calls": [ - { - "type": "function", - "function": { - "name": "code_interpreter", - "arguments": { - "code": """only_M_S = 73 - 22\\nonly_M_E = 24 - 22\\nonly_S_E = 27 - 22\\n\\nonly_M = 100 - only_M_S - - only_M_E - 22\\nonly_S = 94 - only_M_S - only_S_E - 22\\nonly_E = 57 - only_M_E - only_S_E - 22\\n\\ntotal_verify - = only_M + only_S + only_E + only_M_S + only_M_E + only_S_E + 22\\nprint(total_verify)""", - }, - }, - } - ], - } - - tool_return_1_msg = { - "role": "tool", - "content": """149""", - } - expect_turn_2_msg = { - "role": "assistant", - "content": """ - The verification total is 149, so students not taking any classes are 152 - 149 = 3, confirming the initial - result.\n\n\n\\boxed{3}\n - """, - } - - user_prompts = [user_prompt] - expect_turn_array = [expect_turn_0_msg, expect_turn_1_msg, expect_turn_2_msg] - tool_return_array = [tool_return_0_msg, tool_return_1_msg] - - return user_prompts, expect_turn_array, tool_return_array - - -def skip_if_valid_sandbox(url): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - if url == "" or url is None: - pytest.skip("No valid sandbox url provided") - - return wrapper - - return decorator - - -class TestRolloutWithTools: - @pytest.fixture - def qwen_tokenizer(self): - local_model_path = "Qwen/Qwen2.5-0.5B" - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - # we only need this for tokenizer - @pytest.fixture - def qwen_model_config(self): - local_model_path = "Qwen/Qwen2.5-0.5B" - config = AutoConfig.from_pretrained(local_model_path) - return config - - @pytest.fixture - def sandbox_fusion_data(self, qwen_tokenizer): - user_prompt, expect_turn_array, tool_return_array = get_sandbox_fusion_messages() - prompts = [[message] for message in user_prompt] - preencode_turn_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=False) - for turn in expect_turn_array - ] - preencode_tool_return_array = [ - qwen_tokenizer.apply_chat_template([turn], tokenize=False, add_generation_prompt=True) - for turn in tool_return_array - ] - return prompts, preencode_turn_array, preencode_tool_return_array - - @pytest.fixture - def sandbox_fusion_rollout_config(self): - max_prompt_length = 1024 - max_response_length = 1024 - dtype = "bfloat16" - tensor_parallel_size = 1 - tool_path = "./resource/tool_configs/sandbox_fusion_tool_config" - rollout_config = get_rollout_config( - max_response_length, max_prompt_length, dtype, tensor_parallel_size, tool_path - ) - return rollout_config - - @pytest.fixture - def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer): - preencode_prompts, _, _ = sandbox_fusion_data - prompts = [ - qwen_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in preencode_prompts - ] - input_ids, attention_mask, position_ids = prepare_inputs(qwen_tokenizer, prompts, 1000) - prompt_dict = TensorDict( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=input_ids.shape[0], - ) - messages = np.asarray(preencode_prompts) - tools_kwargs = np.array( - [ - { - "code_interpreter": { - "create_kwargs": {"ground_truth": "test-solution-str"}, - }, - } - ], - dtype=object, - ) - index = np.array([0], dtype=object) - prompts = DataProto( - batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index} - ) - return prompts - - @pytest.fixture - def mock_rollout(self, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config): - """Mock the rollout instance""" - with patch.object(SGLangRollout, "_init_distributed_env", return_value=None), patch.object( - SGLangRollout, "_init_inference_engine", return_value=None - ), patch.object(SGLangRollout, "_init_sampling_params", return_value=None): - rollout = SGLangRollout( - actor_module="", - config=sandbox_fusion_rollout_config, - processing_class=qwen_tokenizer, - model_hf_config=qwen_model_config, - ) - # set default sampling_params - rollout.sampling_params = { - "n": 1, - "max_new_tokens": sandbox_fusion_rollout_config.response_length, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "repetition_penalty": 1.0, - } - return rollout - - def test_tools_registration(self, mock_rollout): - """Test tool registration functionality""" - assert len(mock_rollout._tool_schemas) == 1 - assert "code_interpreter" in mock_rollout._tool_map.keys() - from verl.tools.sandbox_fusion_tools import SandboxFusionTool - - assert isinstance(mock_rollout._tool_map["code_interpreter"], SandboxFusionTool) - assert mock_rollout._tool_call_parser_type == "qwen25" - - def test_rollout_req_creation(self, mock_rollout, sandbox_data_proto): - """Test request creation functionality""" - req_list = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1) - assert len(req_list) == 1 - assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING - assert len(req_list[0].tool_schemas) == 1 - print(type(req_list[0].tool_schemas[0])) - assert req_list[0].tool_schemas[0] == OpenAIFunctionToolSchema( - type="function", - function=OpenAIFunctionSchema( - name="code_interpreter", - description="A tool for executing code.", - parameters=OpenAIFunctionParametersSchema( - type="object", - properties={ - "code": OpenAIFunctionPropertySchema( - type="string", - description="The code to execute.", - enum=None, - ) - }, - required=["code"], - ), - strict=False, - ), - ) - - def test_over_size_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): - """Test over-size response truncation case""" - mock_rollout.config.multi_turn.max_assistant_turns = 1 - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] - req = MagicMock(wraps=req, spec=AsyncRolloutRequest) - req.finalize = MagicMock() - req_list = [req] - - _, expect_turn_array, tool_return_array = sandbox_fusion_data - # here we mock a meta info with 'length'. indicate the response is truncate - mock_rollout._handle_engine_call = MagicMock() - future = asyncio.Future() - future.set_result( - { - "text": expect_turn_array[0], - "meta_info": { - "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "length", "length": 1024}, - "prompt_tokens": 132, - "completion_tokens": 100, - "cached_tokens": 0, - "e2e_latency": 9.9304039478302, - }, - } - ) - mock_rollout._handle_engine_call.return_value = future - mock_rollout._tp_rank = 0 - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], - ) - ) - assert len(output_req_list) == 1 - output_req = output_req_list[0] - assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - assert output_req.reward_scores.get("code_interpreter") == [] - # we should only have two message, one for prompt, second for response. - assert len(output_req.messages) == 2 - assert output_req.messages[1] == Message( - role="assistant", - content=expect_turn_array[0], - tool_calls=None, - ) - - @skip_if_valid_sandbox(sandbox_url) - def test_tool_call_basic_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): - """Test basic tool call case""" - mock_rollout.config.multi_turn.max_assistant_turns = 10 - mock_rollout._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] - req = MagicMock(wraps=req, spec=AsyncRolloutRequest) - req.finalize = MagicMock() - req_list = [req] - _, expect_turn_array, tool_return_array = sandbox_fusion_data - # here we mock a meta info with 'length'. indicate the response is truncate - mock_rollout._handle_engine_call = MagicMock() - futures = [asyncio.Future() for i in expect_turn_array] - for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)): - i.set_result( - { - "text": turn, - "meta_info": { - "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, - "prompt_tokens": len(turn), - "completion_tokens": 100, - "cached_tokens": 0, - "e2e_latency": 9.9304039478302, - }, - } - ) - if idx < len(expect_turn_array) - 1: - assert mock_rollout._function_call_parser.has_tool_call(turn) - assert mock_rollout._function_call_parser.parse_non_stream(turn) - - mock_rollout._handle_engine_call.side_effect = futures - mock_rollout._tp_rank = 0 - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], - ) - ) - assert len(output_req_list) == 1 - output_req = output_req_list[0] - assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - # here we verify whether the code sandbox is executed correctly - assert output_req.metrics == {"code_interpreter": ["3", "149"]} - assert mock_rollout._handle_engine_call.call_count == 3 - assert len(output_req.messages) == 6 # user + 3*assistant + 2*tool_call - code_counter = 0 - for msg in output_req.messages: - if msg.role == "tool": - code_counter += 1 - assert msg.content == tool_return_array[code_counter] - assert code_counter == 2 - - @skip_if_valid_sandbox(sandbox_url) - def test_tool_call_batch_case(self, mock_rollout, sandbox_data_proto, sandbox_fusion_data): - """Test batch tool call case""" - mock_rollout.config.multi_turn.max_assistant_turns = 10 - mock_rollout._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url - req = mock_rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] - req_nums = 100 - req_list = [] - req_turns_counter = {} - # this map should a Map[id:List[Futures]] - req_turns_map = {} - _, expect_turn_array, tool_return_array = sandbox_fusion_data - for i in range(req_nums): - _temp_req = deepcopy(req) - _temp_req.batch_data_id = i - _temp_req.request_id = i - req_list.append(MagicMock(wraps=_temp_req, spec=AsyncRolloutRequest)) - futures = [asyncio.Future() for i in expect_turn_array] - for idx, (i, turn) in enumerate(zip(futures, expect_turn_array)): - i.set_result( - { - "text": turn, - "meta_info": { - "id": "d1188d81cba840359df5b352b344bc8e", - "finish_reason": {"type": "tool_calls" if idx < len(expect_turn_array) - 1 else "stop"}, - "prompt_tokens": len(turn), - "completion_tokens": 100, - "cached_tokens": 0, - "e2e_latency": 9.9304039478302, - }, - } - ) - if idx < len(expect_turn_array) - 1: - assert mock_rollout._function_call_parser.has_tool_call(turn) - assert mock_rollout._function_call_parser.parse_non_stream(turn) - req_turns_map[_temp_req.batch_data_id] = futures - req_turns_counter[_temp_req.batch_data_id] = 0 - - async def hacked_handle_engine_call( - self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs - ): - result = req_turns_map[_req.batch_data_id][req_turns_counter[_req.batch_data_id]] - req_turns_counter[_req.batch_data_id] += 1 - re = await result - return re - - with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): - mock_rollout._tp_rank = 0 - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather( - *[mock_rollout._async_rollout_a_request(req, True, False) for req in req_list], - ) - ) - assert len(output_req_list) == req_nums - # FIGUER out how to count this - # assert rollout._handle_engine_call.call_count == 3 * req_nums - for output_req in output_req_list: - assert output_req.state == AsyncRolloutRequestStateEnum.COMPLETED - # here we verify whether the code sandbox is executed correctly - assert output_req.metrics == {"code_interpreter": ["3", "149"]} - assert len(output_req.messages) == 6 # user + 3*assistant + 2*tool_call - code_counter = 0 - for msg in output_req.messages: - if msg.role == "tool": - code_counter += 1 - assert code_counter == 2 - - def test_sampling_params_functionality(self, mock_rollout): - """Test sampling_params functionality""" - # test basic copy functionality - copied_params = mock_rollout.sampling_params.copy() - assert copied_params == mock_rollout.sampling_params - assert copied_params is not mock_rollout.sampling_params - - # test parameter update - copied_params.update({"temperature": 0.8, "top_p": 0.9}) - assert copied_params["temperature"] == 0.8 - assert copied_params["top_p"] == 0.9 - - # ensure original parameters are not modified - assert "temperature" not in mock_rollout.sampling_params - assert "top_p" not in mock_rollout.sampling_params - - -class RayMultiProcessTestCase(MultiProcessTestCase): - def setUp(self): - super().setUp() - ray.init(ignore_reinit_error=True) - print("init_single cluster") - self._spawn_processes() - - def tearDown(self): - print("tearDown_single cluster") - ray.shutdown() - - -@ray.remote -class TestActor: - def __init__(self, rank, world_size): - self._world_size = world_size - self._rank = rank - self.rank_list = [] - self.time_list = [] - - def record_rank(self, rank): - self.rank_list.append(rank) - - def get_rank(self): - return self._rank - - def ping(self): - return True - - def record_execution_time(self, time): - self.time_list.append(time) - - def get_time(self, timeout): - import time - - now = time.time() - while time.time() - now < timeout: - # for start and end time - if len(self.time_list) == self._world_size * 2: - self.time_list.sort() - return self.time_list[-1] - self.time_list[0] - else: - time.sleep(1) - continue - return False - - def verify_rank(self): - import time - - now = time.time() - while time.time() - now < 10: - if len(self.rank_list) == self._world_size: - print(self.rank_list) - self.rank_list.sort() - for i in range(self._world_size): - if self.rank_list[i] != i: - return False - return True - else: - time.sleep(1) - continue - return False - - -class TestRayGlobalActorCase(RayMultiProcessTestCase): - @property - def world_size(self) -> int: - # for DP = 8 - return 2 - - def test_basic_multi_process_init(self): - ray.init("auto", namespace="test", ignore_reinit_error=True) - handle = TestActor.remote(self.rank, self.world_size) - re = ray.get(handle.get_rank.remote()) - assert re == self.rank, f"rank not match: {re} != {self.rank}" - - # def test_global_actor(self): - # ray.init("auto",namespace="test",ignore_reinit_error=True) - # handle = TestActor.options(get_if_exists=True,name="test-actor").remote(self.rank,self.world_size) - # handle.record_rank.remote(self.rank) - # # since test actor's concurrency is 1, we need to wait for all processes to finish - # time.sleep(5) - # assert ray.get(handle.ping.remote()) == True # make sure actor handle is valid - # if self.rank == 0: - # assert ray.get(handle.verify_rank.remote()) == True - # else: - # # get_actor use weak_ref, so we need to make sure the actor is not garbage collected - # time.sleep(10) - - -class TestSingleNodeRateLimiterCase(RayMultiProcessTestCase): - @property - def world_size(self) -> int: - return 1 - - def test_rate_limiter(self): - ray.init("auto", namespace="test", ignore_reinit_error=True) - from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool - - # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=3) - exec_worker = init_execution_pool( - num_workers=10, enable_global_rate_limit=True, rate_limit=3, mode=PoolMode.ThreadMode - ) - center = TestActor.options(get_if_exists=True, name="test-actor").remote(self.rank, self.world_size) - ray.get(exec_worker.ping.remote()) - - def fn(i): - import time - - time.sleep(3) - return i - - start = time.time() - tasks = [exec_worker.execute.remote(fn, i) for i in range(6)] - loop = asyncio.get_event_loop() - results = loop.run_until_complete(asyncio.gather(*tasks)) - end = time.time() - duration = end - start - center.record_execution_time.remote(start) - center.record_execution_time.remote(end) - print(f"Total time: {duration:.2f} seconds for rank: {self.rank}") - - assert results == list(range(6)) - # we have 6 task with rate limit of 3, therefore we need at least 2 round: 3*2=6 seconds - assert duration > 6 - assert duration < 10 - - def test_rotten_execution(self): - ray.init("auto", namespace="test", ignore_reinit_error=True) - from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool - - # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6) - exec_worker = init_execution_pool( - num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode - ) - ray.get(exec_worker.ping.remote()) - - def fn(i): - if i == 10: - raise Exception("test") - else: - return i - - tasks = [exec_worker.execute.remote(fn, i) for i in range(20)] - loop = asyncio.get_event_loop() - results = loop.run_until_complete(asyncio.gather(*tasks)) - expect_result = [None] + list(range(10)) + list(range(11, 20)) - sorted_data = sorted(results, key=lambda x: (x is not None, x)) - assert sorted_data == expect_result, f"results: {results}, expect_result: {expect_result}" - rate_limiter = TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote() - rate = ray.get(rate_limiter.get_current_count.remote()) - assert rate == 0, f"rate: {rate}" - - -class TestMultiNodeRateLimiterCase(RayMultiProcessTestCase): - @property - def world_size(self) -> int: - return 2 - - def test_rate_limiter(self): - ray.init("auto", namespace="test", ignore_reinit_error=True) - from verl.tools.sandbox_fusion_tools import PoolMode, init_execution_pool - - # exec_worker = ExecutionWorker.options(max_concurrency=10).remote(enable_global_rate_limit=True, rate_limit=6) - exec_worker = init_execution_pool( - num_workers=10, enable_global_rate_limit=True, rate_limit=6, mode=PoolMode.ThreadMode - ) - center = TestActor.options(get_if_exists=True, name="test-actor").remote(self.rank, self.world_size) - ray.get(exec_worker.ping.remote()) - - def fn(i): - import time - - time.sleep(2) - return i - - start = time.time() - tasks = [exec_worker.execute.remote(fn, i) for i in range(6)] - loop = asyncio.get_event_loop() - results = loop.run_until_complete(asyncio.gather(*tasks)) - end = time.time() - duration = end - start - center.record_execution_time.remote(start) - center.record_execution_time.remote(end) - print(f"Total time: {duration:.2f} seconds for rank: {self.rank}") - assert results == list(range(6)) - time.sleep(5) - if self.rank == 0: - total_cost = ray.get(center.get_time.remote(10)) - print(f"for total cost: {total_cost}") - # # we have 6 task each node * 2node = 12 task, each task take 2 second. - # with rate limit of 6, - # therefore we need at least 2 round: 12/6*2=4 seconds - assert total_cost > 4, total_cost - else: - time.sleep(10) diff --git a/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py b/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py deleted file mode 100644 index 3ccde1852..000000000 --- a/tests/workers/rollout/test_sglang_async_rollout_w_interaction.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -usage: torchrun --standalone --nnodes=1 \ - --nproc_per_node=2 $(which pytest) \ - -s test_sglang_async_rollout_w_interaction.py -""" - -import numpy as np -import torch -from tensordict import TensorDict -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision, ShardingStrategy -from utils_sglang import ( - are_lists_similar, - clean_torchelastic_env, - generate_hf_output, - get_rollout_config, - initialize_global_process_group, - load_tokenizer_and_model, - prepare_inputs, -) - -from verl import DataProto -from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout -from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager - - -def test_async_sglang_rollout_w_interaction(): - assert torch.cuda.device_count() >= 2 - initialize_global_process_group() - clean_torchelastic_env() - - max_prompt_length = 32 - max_response_length = 16 - dtype = "bfloat16" - tensor_parallel_size = 2 - local_model_path = "Qwen/Qwen2.5-0.5B" - - tokenizer, actor_model = load_tokenizer_and_model(local_model_path) - - preencode_prompts = [ - [{"role": "user", "content": prompt, "tool_calls": None}] - for prompt in [ - "Who won the Champions League in 2019?", - "The founder of Apple is", - "What's the best way to learn python?", - ] - ] - interaction_kwargs = [ - {"name": "gsm8k", "query": "Who won the Champions League in 2019?", "ground_truth": "Real Madrid"}, - {"name": "gsm8k", "query": "The founder of Apple is", "ground_truth": "Steve Jobs"}, - {"name": "gsm8k", "query": "What's the best way to learn python?", "ground_truth": "Learn python from scratch"}, - ] - prompts = [ - tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in preencode_prompts - ] - input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) - - hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) - - fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",)) - inference_device_mesh_cpu = init_device_mesh( - "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp") - ) - - fsdp_model = FSDP( - actor_model, - use_orig_params=True, - device_id=fsdp_device_mesh["fsdp"].get_local_rank(), - mixed_precision=MixedPrecision(param_dtype=getattr(torch, dtype)), - sharding_strategy=ShardingStrategy.FULL_SHARD, - device_mesh=fsdp_device_mesh, - ) - - # Create a temporary interaction config file for testing - import tempfile - - from omegaconf import OmegaConf - - interaction_config = { - "interaction": [ - {"name": "gsm8k", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}} - ] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(interaction_config, f.name) - interaction_config_path = f.name - - rollout_config = get_rollout_config( - max_response_length, max_prompt_length, dtype, tensor_parallel_size, None, interaction_config_path - ) - rollout = SGLangRollout( - actor_module=local_model_path, - config=rollout_config, - processing_class=tokenizer, - model_hf_config=actor_model.config, - ) - - rollout_sharding_manager = FSDPSGLangShardingManager( - module=fsdp_model, - inference_engine=rollout._engine, - model_config=actor_model.config, - rollout_config=rollout_config, - full_params=True, - device_mesh=inference_device_mesh_cpu, - ) - - with rollout_sharding_manager: - prompt_dict = TensorDict( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=input_ids.shape[0], - ) - print(f"preprocessed {input_ids.shape=}") - - messages = np.asarray(preencode_prompts) - prompts = DataProto( - batch=prompt_dict, - non_tensor_batch={"raw_prompt": messages, "interaction_kwargs": np.asarray(interaction_kwargs)}, - ) - - prompts.meta_info.update( - { - "eos_token_id": tokenizer.eos_token_id, - "pad_token_id": tokenizer.pad_token_id, - } - ) - - prompts = rollout_sharding_manager.preprocess_data(prompts) - # log_gpu_memory_usage("Before generating sequences", logger=None) - output = rollout.generate_sequences(prompts=prompts) - print(f"generated {output.batch['responses'].shape=}") - # log_gpu_memory_usage("After generating sequences", logger=None) - output = rollout_sharding_manager.postprocess_data(output) - print(f"postprocessed {output.batch['responses'].shape=}") - sglang_output = output.to("cpu") - - sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch["responses"]) - - print(f"hf response: {hf_response_tokens}") - print(f"sglang response: {sglang_response_tokens}") - assert are_lists_similar(hf_response_tokens, sglang_response_tokens) - print("SGLang w interaction Test Passed!") - - # Clean up temporary config file - import os - - os.unlink(interaction_config_path) - - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - test_async_sglang_rollout_w_interaction() diff --git a/tests/workers/rollout/test_sglang_async_rollout_w_tools.py b/tests/workers/rollout/test_sglang_async_rollout_w_tools.py deleted file mode 100644 index 20faab851..000000000 --- a/tests/workers/rollout/test_sglang_async_rollout_w_tools.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -usage: torchrun --standalone --nnodes=1 \ - --nproc_per_node=2 $(which pytest) \ - -s test_sglang_async_rollout_w_tools.py -""" - -import numpy as np -import torch -from tensordict import TensorDict -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import MixedPrecision, ShardingStrategy -from utils_sglang import ( - are_lists_similar, - clean_torchelastic_env, - generate_hf_output, - get_rollout_config, - initialize_global_process_group, - load_tokenizer_and_model, - prepare_inputs, -) - -from verl import DataProto -from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout -from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager - - -def test_async_sglang_rollout_w_tool(): - assert torch.cuda.device_count() >= 2 - initialize_global_process_group() - clean_torchelastic_env() - - max_prompt_length = 32 - max_response_length = 16 - dtype = "bfloat16" - tensor_parallel_size = 2 - local_model_path = "Qwen/Qwen2.5-0.5B" - - tokenizer, actor_model = load_tokenizer_and_model(local_model_path) - - preencode_prompts = [ - [{"role": "user", "content": prompt, "tool_calls": None}] - for prompt in [ - "Who won the Champions League in 2019?", - "The founder of Apple is", - "What's the best way to learn python?", - ] - ] - prompts = [ - tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in preencode_prompts - ] - input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) - - hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) - - fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",)) - inference_device_mesh_cpu = init_device_mesh( - "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp") - ) - - fsdp_model = FSDP( - actor_model, - use_orig_params=True, - device_id=fsdp_device_mesh["fsdp"].get_local_rank(), - mixed_precision=MixedPrecision(param_dtype=getattr(torch, dtype)), - sharding_strategy=ShardingStrategy.FULL_SHARD, - device_mesh=fsdp_device_mesh, - ) - - rollout_config = get_rollout_config( - max_response_length, - max_prompt_length, - dtype, - tensor_parallel_size, - "./resource/tool_configs/sandbox_fusion_tool_config", - ) - rollout = SGLangRollout( - actor_module=local_model_path, - config=rollout_config, - processing_class=tokenizer, - model_hf_config=actor_model.config, - ) - - rollout_sharding_manager = FSDPSGLangShardingManager( - module=fsdp_model, - inference_engine=rollout._engine, - model_config=actor_model.config, - rollout_config=rollout_config, - full_params=True, - device_mesh=inference_device_mesh_cpu, - ) - - with rollout_sharding_manager: - prompt_dict = TensorDict( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=input_ids.shape[0], - ) - print(f"preprocessed {input_ids.shape=}") - - messages = np.asarray(preencode_prompts) - prompts = DataProto( - batch=prompt_dict, - non_tensor_batch={ - "raw_prompt": messages, - "tools_kwargs": np.array([{}] * input_ids.shape[0], dtype=object), - }, - ) - - prompts.meta_info.update( - { - "eos_token_id": tokenizer.eos_token_id, - "pad_token_id": tokenizer.pad_token_id, - } - ) - - prompts = rollout_sharding_manager.preprocess_data(prompts) - # log_gpu_memory_usage("Before generating sequences", logger=None) - output = rollout.generate_sequences(prompts=prompts) - print(f"generated {output.batch['responses'].shape=}") - # log_gpu_memory_usage("After generating sequences", logger=None) - output = rollout_sharding_manager.postprocess_data(output) - print(f"postprocessed {output.batch['responses'].shape=}") - sglang_output = output.to("cpu") - - sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch["responses"]) - - print(f"hf response: {hf_response_tokens}") - print(f"sglang response: {sglang_response_tokens}") - assert are_lists_similar(hf_response_tokens, sglang_response_tokens) - print("SGLang w tool Test Passed!") - - torch.distributed.barrier() - torch.distributed.destroy_process_group() - - -if __name__ == "__main__": - test_async_sglang_rollout_w_tool() diff --git a/tests/workers/rollout/test_sglang_multi_interaction.py b/tests/workers/rollout/test_sglang_multi_interaction.py deleted file mode 100644 index 465470fbd..000000000 --- a/tests/workers/rollout/test_sglang_multi_interaction.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -Test for multi-interaction support in SGLangRollout. -usage: torchrun --standalone --nnodes=1 \ - --nproc_per_node=2 $(which pytest) \ - -s test_sglang_multi_interaction.py -""" - -import os -import tempfile -from unittest.mock import MagicMock, patch - -import torch -import torch.distributed as dist -from omegaconf import DictConfig, OmegaConf -from transformers import AutoTokenizer - -from verl.interactions.base import BaseInteraction -from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout - - -class MockInteraction(BaseInteraction): - """Mock interaction for testing.""" - - def __init__(self, config): - super().__init__(config) - self.started_instances = set() - - async def start_interaction(self, instance_id=None, **kwargs): - if instance_id is None: - instance_id = "mock_instance" - self.started_instances.add(instance_id) - return instance_id - - async def generate_response(self, instance_id, messages, **kwargs): - return False, f"Mock response from {self.name}", 1.0, {} - - -def create_mock_config_with_multi_interactions(): - """Create a mock configuration with multiple interactions.""" - # Create temporary interaction config file - interaction_config = { - "interaction": [ - { - "name": "mock_agent1", - "class_name": "tests.workers.rollout.test_sglang_multi_interaction.MockInteraction", - "config": {"param1": "value1"}, - }, - { - "name": "mock_agent2", - "class_name": "tests.workers.rollout.test_sglang_multi_interaction.MockInteraction", - "config": {"param2": "value2"}, - }, - ] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(interaction_config, f.name) - interaction_config_path = f.name - - # Create mock SGLangRollout config - config = DictConfig( - { - "multi_turn": { - "interaction_config_path": interaction_config_path, - "tool_config_path": None, - "enable": True, - "max_assistant_turns": 5, - "max_user_turns": 3, - "use_inference_chat_template": True, - "tokenization_sanity_check_mode": "off", - }, - "prompt_length": 32, - "response_length": 16, - "max_model_len": 512, - "dtype": "bfloat16", - "gpu_memory_utilization": 0.8, - "load_format": "dummy", - "enforce_eager": True, - "free_cache_engine": False, - "calculate_log_probs": False, - "tensor_model_parallel_size": 1, - "n": 1, - "val_kwargs": {"top_k": 1, "top_p": 1.0, "temperature": 0.0}, - } - ) - - return config, interaction_config_path - - -def setup_distributed(): - """Initialize distributed environment if not already initialized.""" - if not dist.is_initialized(): - dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo") - - -class TestSGLangMultiInteraction: - def test_initialize_multiple_interactions(self): - """Test that SGLangRollout can initialize multiple interactions.""" - setup_distributed() - config, temp_config_path = create_mock_config_with_multi_interactions() - - try: - # Mock SGLang engine and initialization methods like the reference test - with ( - patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), - patch.object(SGLangRollout, "_init_sampling_params", return_value=None), - ): - # Create a real tokenizer like the reference test - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - - # Mock model config - mock_model_config = MagicMock() - mock_model_config.max_position_embeddings = 2048 - # since this is a mock, we can set any rope scaling config - # to test the rope_scaling logic at the same time of this test - mock_model_config.rope_scaling = { - "factor": 4.0, - "original_max_position_embeddings": 32768, - "type": "yarn", - } - - # Create SGLangRollout instance - rollout = SGLangRollout( - actor_module="mock_model", - config=config, - processing_class=tokenizer, - model_hf_config=mock_model_config, - port=None, - trust_remote_code=False, - device_mesh=None, - ) - - # Check that interactions were initialized - assert len(rollout.interaction_map) == 2 - assert "mock_agent1" in rollout.interaction_map - assert "mock_agent2" in rollout.interaction_map - - # Use class name comparison instead of isinstance for multi-process compatibility - assert rollout.interaction_map["mock_agent1"].__class__.__name__ == "MockInteraction" - assert rollout.interaction_map["mock_agent2"].__class__.__name__ == "MockInteraction" - - # Also check that they are instances of BaseInteraction (which should work across processes) - assert isinstance(rollout.interaction_map["mock_agent1"], BaseInteraction) - assert isinstance(rollout.interaction_map["mock_agent2"], BaseInteraction) - - # Check that names were set correctly - assert rollout.interaction_map["mock_agent1"].name == "mock_agent1" - assert rollout.interaction_map["mock_agent2"].name == "mock_agent2" - - finally: - os.unlink(temp_config_path) - - def test_interaction_selection_by_name(self): - """Test that interactions are selected by name from interaction_kwargs.""" - setup_distributed() - config, temp_config_path = create_mock_config_with_multi_interactions() - - try: - with ( - patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), - patch.object(SGLangRollout, "_init_sampling_params", return_value=None), - ): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - - mock_model_config = MagicMock() - mock_model_config.max_position_embeddings = 2048 - mock_model_config.rope_scaling = { - "factor": 4.0, - "original_max_position_embeddings": 32768, - "type": "yarn", - } - - rollout = SGLangRollout( - actor_module="mock_model", - config=config, - processing_class=tokenizer, - model_hf_config=mock_model_config, - port=None, - trust_remote_code=False, - device_mesh=None, - ) - - # Test interaction selection logic - from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message - - # Create a mock request with specific interaction name - req = AsyncRolloutRequest( - request_id="test_req", - state=AsyncRolloutRequestStateEnum.INTERACTING, - messages=[Message(role="user", content="test message")], - interaction_kwargs={"name": "mock_agent2", "test_param": "value"}, - input_ids=None, - prompt_ids=None, - response_ids=None, - attention_mask=None, - prompt_attention_mask=None, - response_attention_mask=None, - position_ids=None, - prompt_position_ids=None, - response_position_ids=None, - loss_mask=None, - prompt_loss_mask=None, - response_loss_mask=None, - reward_scores={}, - max_prompt_len=32, - max_response_len=16, - max_model_len=512, - use_inference_chat_template=True, - tokenization_sanity_check_mode="disable", - processing_class=tokenizer, - ) - - # Test that the correct interaction is selected - interaction_name = req.interaction_kwargs.get("name", "gsm8k") - assert interaction_name == "mock_agent2" - assert interaction_name in rollout.interaction_map - - selected_interaction = rollout.interaction_map[interaction_name] - assert selected_interaction.name == "mock_agent2" - - finally: - os.unlink(temp_config_path) - - def test_fallback_to_default_interaction(self): - """Test fallback to default interaction when name is not specified.""" - setup_distributed() - # Create config with gsm8k interaction - interaction_config = { - "interaction": [ - { - "name": "gsm8k", - "class_name": "tests.workers.rollout.test_sglang_multi_interaction.MockInteraction", - "config": {}, - } - ] - } - - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - OmegaConf.save(interaction_config, f.name) - interaction_config_path = f.name - - config = DictConfig( - { - "multi_turn": { - "interaction_config_path": interaction_config_path, - "tool_config_path": None, - "enable": True, - "max_assistant_turns": 5, - "max_user_turns": 3, - "use_inference_chat_template": True, - "tokenization_sanity_check_mode": "disable", - }, - "prompt_length": 32, - "response_length": 16, - "max_model_len": 512, - "dtype": "bfloat16", - "gpu_memory_utilization": 0.8, - "load_format": "dummy", - "enforce_eager": True, - "free_cache_engine": False, - "calculate_log_probs": False, - "tensor_model_parallel_size": 1, - "n": 1, - "val_kwargs": {"top_k": 1, "top_p": 1.0, "temperature": 0.0}, - } - ) - - try: - with ( - patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), - patch.object(SGLangRollout, "_init_sampling_params", return_value=None), - ): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - - mock_model_config = MagicMock() - mock_model_config.max_position_embeddings = 2048 - mock_model_config.rope_scaling = { - "factor": 4.0, - "original_max_position_embeddings": 32768, - "type": "yarn", - } - - rollout = SGLangRollout( - actor_module="mock_model", - config=config, - processing_class=tokenizer, - model_hf_config=mock_model_config, - port=None, - trust_remote_code=False, - device_mesh=None, - ) - - # Test that default interaction name works - interaction_kwargs_without_name = {"test_param": "value"} - default_name = interaction_kwargs_without_name.get("name", "gsm8k") - assert default_name == "gsm8k" - assert default_name in rollout.interaction_map - - finally: - os.unlink(interaction_config_path) - - def test_error_on_missing_interaction(self): - """Test that error is raised when requested interaction is not found.""" - setup_distributed() - config, temp_config_path = create_mock_config_with_multi_interactions() - - try: - with ( - patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), - patch.object(SGLangRollout, "_init_sampling_params", return_value=None), - ): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - - mock_model_config = MagicMock() - mock_model_config.max_position_embeddings = 2048 - mock_model_config.rope_scaling = { - "factor": 4.0, - "original_max_position_embeddings": 32768, - "type": "yarn", - } - - rollout = SGLangRollout( - actor_module="mock_model", - config=config, - processing_class=tokenizer, - model_hf_config=mock_model_config, - port=None, - trust_remote_code=False, - device_mesh=None, - ) - - # Test error when requesting non-existent interaction - non_existent_name = "non_existent_interaction" - assert non_existent_name not in rollout.interaction_map - - # This should raise ValueError in actual usage - available_interactions = list(rollout.interaction_map.keys()) - assert "mock_agent1" in available_interactions - assert "mock_agent2" in available_interactions - assert non_existent_name not in available_interactions - - finally: - os.unlink(temp_config_path) - - def test_backward_compatibility_no_interaction_config(self): - """Test backward compatibility when no interaction config is provided.""" - setup_distributed() - # Create config without interaction config - config = DictConfig( - { - "multi_turn": { - "interaction_config_path": None, - "tool_config_path": None, - "enable": True, - "max_assistant_turns": 5, - "max_user_turns": 3, - "use_inference_chat_template": True, - "tokenization_sanity_check_mode": "disable", - }, - "prompt_length": 32, - "response_length": 16, - "max_model_len": 512, - "dtype": "bfloat16", - "gpu_memory_utilization": 0.8, - "load_format": "dummy", - "enforce_eager": True, - "free_cache_engine": False, - "calculate_log_probs": False, - "tensor_model_parallel_size": 1, - "n": 1, - "val_kwargs": {"top_k": 1, "top_p": 1.0, "temperature": 0.0}, - } - ) - - with ( - patch.object(SGLangRollout, "_init_distributed_env", return_value=None), - patch.object(SGLangRollout, "_init_inference_engine", return_value=None), - patch.object(SGLangRollout, "_init_sampling_params", return_value=None), - ): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B", padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - - mock_model_config = MagicMock() - mock_model_config.max_position_embeddings = 2048 - mock_model_config.rope_scaling = { - "factor": 4.0, - "original_max_position_embeddings": 32768, - "type": "yarn", - } - - rollout = SGLangRollout( - actor_module="mock_model", - config=config, - processing_class=tokenizer, - model_hf_config=mock_model_config, - port=None, - trust_remote_code=False, - device_mesh=None, - ) - - # Check that no interactions were initialized - assert len(rollout.interaction_map) == 0 diff --git a/tests/workers/rollout/test_sglang_rollout_sharding_manager.py b/tests/workers/rollout/test_sglang_rollout_sharding_manager.py deleted file mode 100644 index 0d3c7b5da..000000000 --- a/tests/workers/rollout/test_sglang_rollout_sharding_manager.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import torch - -from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets - -_TENSOR_1MB = torch.zeros(512, 512) -_BYTES_1MB = 1 << 20 - - -@pytest.mark.parametrize( - "named_tensors, bucket_size_mb, gt_groups", - [ - ( - [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)], - 0.5 * _BYTES_1MB, - [["a"], ["b"]], - ), - ( - [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)], - 1 * _BYTES_1MB, - [["a"], ["b"]], - ), - ( - [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)], - 1.5 * _BYTES_1MB, - [["a"], ["b"]], - ), - ( - [("a", _TENSOR_1MB), ("b", _TENSOR_1MB)], - 2 * _BYTES_1MB, - [["a", "b"]], - ), - ], -) -def test_get_named_tensor_buckets(named_tensors, bucket_size_mb, gt_groups: list[list[str]]): - named_tensors_iter = iter(named_tensors) - groups = list(get_named_tensor_buckets(named_tensors_iter, bucket_size_mb)) - assert len(groups) == len(gt_groups) - for group, gt_group in zip(groups, gt_groups, strict=True): - assert len(group) == len(gt_group) - for (name, _), (gt_name) in zip(group, gt_group, strict=True): - assert name == gt_name diff --git a/tests/workers/rollout/test_sglang_spmd.py b/tests/workers/rollout/test_sglang_spmd.py deleted file mode 100644 index 0995e2f64..000000000 --- a/tests/workers/rollout/test_sglang_spmd.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -usage: torchrun --standalone --nnodes=1 \ - --nproc_per_node=2 $(which pytest) \ - -s test_sglang_async_spmd.py -""" - -import asyncio - -import torch -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.utils import broadcast_pyobj -from torch.distributed.device_mesh import init_device_mesh -from utils_sglang import ( - are_lists_similar, - clean_torchelastic_env, - generate_hf_output, - initialize_global_process_group, - load_tokenizer_and_model, - prepare_inputs, -) - - -def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - token_ids = prompt_token_ids[non_pad_index:].tolist() - return token_ids - - -def test_sglang_spmd(): - assert torch.cuda.device_count() >= 2 - initialize_global_process_group(spmd=True) - clean_torchelastic_env() - - max_prompt_length = 16 - max_response_length = 16 - - local_model_path = "Qwen/Qwen2.5-0.5B" - tokenizer, actor_model = load_tokenizer_and_model(local_model_path) - - preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"] - input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length) - - hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) - - tensor_parallel_size = 2 - inference_device_mesh_cpu = init_device_mesh( - "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"] - ) - tp_rank = inference_device_mesh_cpu["tp"].get_local_rank() - - if tp_rank == 0: - llm = Engine( - model_path=local_model_path, - dtype="bfloat16", - mem_fraction_static=0.5, - enable_memory_saver=True, - tp_size=inference_device_mesh_cpu["tp"].size(), - attention_backend="fa3", - ) - - input_ids = input_ids.cuda() - idx_list = [] - - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - for i in range(input_ids.shape[0]): - idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) - - sampling_params = dict( - n=1, - temperature=0, - top_p=1, - top_k=-1, - max_new_tokens=max_response_length, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - skip_special_tokens=True, - spaces_between_special_tokens=True, - ignore_eos=False, - ) - - loop = asyncio.get_event_loop() - outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params)) - else: - outputs = None - - [outputs] = broadcast_pyobj( - [outputs], - rank=inference_device_mesh_cpu["tp"].get_local_rank(), - src=inference_device_mesh_cpu["tp"].mesh[0].item(), - dist_group=inference_device_mesh_cpu["tp"].get_group(), - force_cpu_device=False, - ) - - sglang_response_tokens = [output["text"] for output in outputs] - - print(f"sglang response: {sglang_response_tokens}") - assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n" - print("SPMD Test Passed!") - - torch.distributed.barrier() - torch.distributed.destroy_process_group() diff --git a/tests/workers/rollout/utils_sglang.py b/tests/workers/rollout/utils_sglang.py deleted file mode 100644 index d16b09feb..000000000 --- a/tests/workers/rollout/utils_sglang.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from datetime import timedelta - -import torch -from omegaconf import OmegaConf -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from verl.utils.model import compute_position_id_with_mask -from verl.utils.torch_functional import pad_sequence_to_length - - -# ====================== utils ====================== -def levenshtein(s1, s2): - m, n = len(s1), len(s2) - dp = [[0] * (n + 1) for _ in range(m + 1)] - for i in range(m + 1): - dp[i][0] = i - for j in range(n + 1): - dp[0][j] = j - for i in range(1, m + 1): - for j in range(1, n + 1): - cost = 0 if s1[i - 1] == s2[j - 1] else 1 - dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) - return dp[m][n] - - -def are_lists_similar(a, b, threshold=10): - if len(a) != len(b): - print("The lists are of different lengths.") - return False - total_length = 0 - total_diff = 0 - for s1, s2 in zip(a, b, strict=True): - max_len = max(len(s1), len(s2)) - total_length += max_len - total_diff += levenshtein(s1, s2) - percentage_difference = (total_diff / total_length) * 100 - print(f"Total difference: {percentage_difference:.2f}%") - return percentage_difference <= threshold - - -def initialize_global_process_group(timeout_second=36000, spmd=False): - import torch.distributed - - if not torch.distributed.is_initialized(): # Check if already initialized - print("Initializing process group...") - torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second)) - else: - print("Process group already initialized.") - - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - torch.cuda.set_device(local_rank) - - CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "") - if not CUDA_VISIBLE_DEVICES: - if spmd: - # CUDA_VISIBLE_DEVICES = ','.join(str(i) for i in range(tensor_parallel_size)) - CUDA_VISIBLE_DEVICES = ",".join(str(i) for i in range(world_size)) - else: - CUDA_VISIBLE_DEVICES = str(local_rank) - os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES - print(f"CUDA_VISIBLE_DEVICES is not set, set to {CUDA_VISIBLE_DEVICES}") - - return local_rank, rank, world_size - - -def clean_torchelastic_env(): - for k in ["TORCHELASTIC_USE_AGENT_STORE"]: - if k in os.environ: - del os.environ[k] - - -def load_tokenizer_and_model(local_model_path, dtype="bfloat16"): - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype=getattr(torch, dtype), device_map="cuda") - return tokenizer, model - - -def prepare_inputs(tokenizer, prompts, max_prompt_length): - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - tokenized = tokenizer(prompts, return_tensors="pt", padding=True) - input_ids = pad_sequence_to_length(tokenized["input_ids"], max_prompt_length, pad_token_id, left_pad=True) - attention_mask = pad_sequence_to_length( - tokenized["attention_mask"], max_prompt_length, pad_token_id=0, left_pad=True - ) - position_ids = compute_position_id_with_mask(attention_mask) - position_ids = pad_sequence_to_length(position_ids, max_prompt_length, pad_token_id=0, left_pad=True) - return input_ids, attention_mask, position_ids - - -def generate_hf_output(model, input_ids, attention_mask, tokenizer, max_response_length): - generation_config = GenerationConfig(do_sample=False) - output = model.generate( - input_ids=input_ids.cuda(), - attention_mask=attention_mask.cuda(), - max_new_tokens=max_response_length, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config=generation_config, - output_scores=False, - return_dict_in_generate=True, - use_cache=False, - ) - seq = output.sequences - response = seq[:, input_ids.shape[1] :] - return tokenizer.batch_decode(response) - - -def get_rollout_config( - max_response_length, - max_prompt_length, - dtype, - tensor_parallel_size, - tool_config_path=None, - interaction_config_path=None, -): - sampling_params = dict( - n=1, - temperature=0, - top_p=1, - top_k=-1, - max_new_tokens=max_response_length, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - skip_special_tokens=True, - spaces_between_special_tokens=True, - ignore_eos=False, - ) - - rollout_config = OmegaConf.create( - { - "name": "sglang", - "mode": "sync", - "load_format": "dummy_dtensor", - "enforce_eager": False, - "free_cache_engine": True, - "dtype": dtype, - "gpu_memory_utilization": 0.5, - "ignore_eos": False, - "max_num_batched_tokens": 8192, - "prompt_length": max_prompt_length, - "response_length": max_response_length, - "tensor_model_parallel_size": tensor_parallel_size, - # set to 128MB only for testing - "update_weights_bucket_megabytes": 128, - "multi_turn": { - "max_assistant_turns": 4, - "max_user_turns": 4, - "enable": True, - "tool_config_path": tool_config_path, - "interaction_config_path": interaction_config_path, - "use_inference_chat_template": False, - "tokenization_sanity_check_mode": "strict", - }, - "max_model_len": None, - **sampling_params, - } - ) - - return rollout_config diff --git a/verl b/verl new file mode 160000 index 000000000..f332fc814 --- /dev/null +++ b/verl @@ -0,0 +1 @@ +Subproject commit f332fc814718b9ea7968f6d264211460d4e90fff diff --git a/verl/__init__.py b/verl/__init__.py deleted file mode 100644 index 6dbdd333f..000000000 --- a/verl/__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -import logging -import os -from importlib.metadata import PackageNotFoundError -from importlib.metadata import version as get_version - -from packaging.version import parse as parse_version - -from .protocol import DataProto -from .utils.device import is_npu_available -from .utils.logging_utils import set_basic_config - - -version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) - -with open(os.path.join(version_folder, "version/version")) as f: - __version__ = f.read().strip() - - -set_basic_config(level=logging.WARNING) - - -__all__ = ["DataProto", "__version__"] - -if os.getenv("VERL_USE_MODELSCOPE", "False").lower() == "true": - if importlib.util.find_spec("modelscope") is None: - raise ImportError("You are using the modelscope hub, please install modelscope by `pip install modelscope -U`") - # Patch hub to download models from modelscope to speed up. - from modelscope.utils.hf_util import patch_hub - - patch_hub() - -if is_npu_available: - from .models.transformers import npu_patch as npu_patch - - package_name = "transformers" - required_version_spec = "4.52.4" - try: - installed_version = get_version(package_name) - installed = parse_version(installed_version) - required = parse_version(required_version_spec) - - if installed < required: - raise ValueError( - f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is " - f"{installed}." - ) - except PackageNotFoundError as e: - raise ImportError( - f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}" - ) from e diff --git a/verl/base_config.py b/verl/base_config.py deleted file mode 100644 index 0cd117bb6..000000000 --- a/verl/base_config.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -from dataclasses import ( - dataclass, - field, - fields, # Import the fields function to inspect dataclass fields -) -from typing import Any - - -# BaseConfig class inherits from collections.abc.Mapping, which means it can act like a dictionary -@dataclass -class BaseConfig(collections.abc.Mapping): - """The BaseConfig provides omegaconf DictConfig-like interface for a dataclass config. - - The BaseConfig class implements the Mapping Abstract Base Class. - This allows instances of this class to be used like dictionaries. - """ - - extra: dict[str, Any] = field(default_factory=dict) - - def __setattr__(self, name: str, value): - # if the field already exists (i.e. was set in __init__) - # and is in our frozen list, block assignment - if hasattr(self, "_frozen_fields") and name in self._frozen_fields and name in self.__dict__: - from dataclasses import FrozenInstanceError - - raise FrozenInstanceError(f"Field '{name}' is frozen and cannot be modified") - # otherwise do the normal thing - super().__setattr__(name, value) - - def get(self, key: str, default: Any = None) -> Any: - """Get the value associated with the given key. If the key does not exist, return the default value. - - Args: - key (str): The attribute name to retrieve. - default (Any, optional): The value to return if the attribute does not exist. Defaults to None. - - Returns: - Any: The value of the attribute or the default value. - """ - try: - return getattr(self, key) - except AttributeError: - return default - - def __getitem__(self, key: str): - """Implement the [] operator for the class. Allows accessing attributes like dictionary items. - - Args: - key (str): The attribute name to retrieve. - - Returns: - Any: The value of the attribute. - - Raises: - AttributeError: If the attribute does not exist. - TypeError: If the key type is not string - """ - return getattr(self, key) - - def __iter__(self): - """Implement the iterator protocol. Allows iterating over the attribute names of the instance. - - Yields: - str: The name of each field in the dataclass. - """ - for f in fields(self): - yield f.name - - def __len__(self): - """ - Return the number of fields in the dataclass. - - Returns: - int: The number of fields in the dataclass. - """ - return len(fields(self)) diff --git a/verl/experimental/__init__.py b/verl/experimental/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/experimental/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/agent_loop/__init__.py b/verl/experimental/agent_loop/__init__.py deleted file mode 100644 index a39171db7..000000000 --- a/verl/experimental/agent_loop/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .agent_loop import AgentLoopBase, AgentLoopManager -from .single_turn_agent_loop import SingleTurnAgentLoop -from .tool_agent_loop import ToolAgentLoop - -_ = [SingleTurnAgentLoop, ToolAgentLoop] - -__all__ = ["AgentLoopBase", "AgentLoopManager"] diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py deleted file mode 100644 index 480f6593d..000000000 --- a/verl/experimental/agent_loop/agent_loop.py +++ /dev/null @@ -1,538 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import heapq -import logging -import os -import random -from abc import ABC, abstractmethod -from typing import Any - -import hydra -import numpy as np -import ray -import torch -from cachetools import LRUCache -from omegaconf import DictConfig, OmegaConf -from pydantic import BaseModel -from tensordict import TensorDict -from transformers import AutoTokenizer - -from verl.protocol import DataProto -from verl.single_controller.ray.base import RayWorkerGroup -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_to_local -from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op -from verl.workers.rollout.async_server import async_server_class - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class AsyncLLMServerManager: - """ - A class to manage multiple OpenAI compatible LLM servers. This class provides - - Load balance: least requests load balancing - - Sticky session: send multi-turn chat completions to same server for automatic prefix caching - """ - - def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000): - """Initialize the AsyncLLMServerManager. - - Args: - config (DictConfig): YAML config. - server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. - max_cache_size (int, optional): max cache size for request_id to server mapping. Defaults to 10000. - """ - self.config = config - self.server_handles = server_handles - random.shuffle(self.server_handles) - - # Least requests load balancing - self.weighted_serveres = [[0, (hash(server), server)] for server in server_handles] - heapq.heapify(self.weighted_serveres) - - # LRU cache to map request_id to server - self.request_id_to_server = LRUCache(maxsize=max_cache_size) - - def _choose_server(self, request_id: str) -> ray.actor.ActorHandle: - # TODO: implement server pressure awareness load balancing - if request_id in self.request_id_to_server: - return self.request_id_to_server[request_id] - - server = self.weighted_serveres[0][1][1] - self.weighted_serveres[0][0] += 1 - heapq.heapreplace(self.weighted_serveres, self.weighted_serveres[0]) - self.request_id_to_server[request_id] = server - return server - - @rollout_trace_op - async def generate( - self, - request_id, - *, - prompt_ids: list[int], - sampling_params: dict[str, Any], - ) -> list[int]: - """Generate tokens from prompt ids. - - Args: - request_id (str): request id for sticky session. - prompt_ids (List[int]): List of prompt token ids. - sampling_params (Dict[str, Any]): Sampling parameters for the chat completion. - - Returns: - List[int]: List of generated token ids. - """ - server = self._choose_server(request_id) - output = await server.generate.remote( - request_id=request_id, - prompt_ids=prompt_ids, - sampling_params=sampling_params, - ) - return output - - -class AgentLoopMetrics(BaseModel): - """Agent loop performance metrics.""" - - generate_sequences: float = 0.0 - tool_calls: float = 0.0 - - -class AgentLoopOutput(BaseModel): - """Agent loop output.""" - - prompt_ids: list[int] - response_ids: list[int] - response_mask: list[int] - num_turns: int = 0 - metrics: AgentLoopMetrics - - -# make hydra.utils.instantiate happy -class _DummyConfig: - def __init__(self, config: DictConfig) -> None: - self.config = config - - -class AgentLoopBase(ABC): - """An agent loop takes a input message, chat with OpenAI compatible LLM server and interact with various - environments.""" - - _class_initialized = False - - def __init__( - self, trainer_config: _DummyConfig, server_manager: AsyncLLMServerManager, tokenizer: AutoTokenizer, **kwargs - ): - """Initialize agent loop, each sample will have its own loop instance. - - Args: - trainer_config (_DummyConfig): trainer config. - server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager. - tokenizer (AutoTokenizer): Tokenizer for tokenize messages. - """ - self.init_class(trainer_config.config, tokenizer, **kwargs) - self.config = trainer_config.config - self.server_manager = server_manager - self.tokenizer = tokenizer - self.loop = asyncio.get_running_loop() - - @classmethod - def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, **kwargs): - """This is used to do heavy initialization work that should shared across all instances. It's only called once. - - Args: - config (DictConfig): trainer config. - tokenizer (AutoTokenizer): Tokenizer for tokenize messages. - **kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`. - """ - if cls._class_initialized: - return - cls._class_initialized = True - - @abstractmethod - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: - """Run agent loop to interact with LLM server and environment. - - Args: - messages (List[Dict[str, Any]]): Input messages. - sampling_params (Dict[str, Any]): LLM sampling params. - - Returns: - AgentLoopOutput: Agent loop output. - """ - raise NotImplementedError - - -"""Agent loop registry: key is agent_name, value is a dict of agent loop config -used by hydra.utils.instantiate to initialize agent loop instance. - -https://hydra.cc/docs/advanced/instantiate_objects/overview/ -""" -_agent_loop_registry: dict[str, dict] = {} - - -def register(agent_name: str): - """Register agent loop class.""" - - def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]: - fqdn = f"{subclass.__module__}.{subclass.__qualname__}" - _agent_loop_registry[agent_name] = {"_target_": fqdn} - return subclass - - return decorator - - -@ray.remote -class AgentLoopWorker: - """Agent loop worker takes a batch of messages and run each message in an agent loop.""" - - def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle]): - """Initialize agent loop manager. - - Args: - config (DictConfig): YAML config. - server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. - """ - self.config = config - self.server_manager = AsyncLLMServerManager(config, server_handles) - - model_path = config.actor_rollout_ref.model.path - self.model_name = "/".join(model_path.split("/")[-2:]) - local_path = copy_to_local(config.actor_rollout_ref.model.path) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) - - agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path - if agent_loop_config_path: - agent_loop_configs = OmegaConf.load(agent_loop_config_path) - for agent_loop_config in agent_loop_configs: - _agent_loop_registry[agent_loop_config.name] = agent_loop_config - - trace_config = config.trainer.get("rollout_trace", {}) - trace_config = self.config.actor_rollout_ref.rollout.get("trace", {}) - RolloutTraceConfig.init( - self.config.trainer.project_name, - self.config.trainer.experiment_name, - trace_config.get("backend"), - trace_config.get("token2text", False), - ) - - async def generate_sequences(self, batch: DataProto) -> DataProto: - """Generate sequences from agent loop. - - Args: - batch (DataProto): Input batch. - - Returns: - DataProto: Output batch. - - prompts: [bsz, prompt_length], prompt token ids from dataset. - - responses: [bsz, response_length], output token ids include response tokens - from LLM generation and observation tokens from tool_calls. - - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens. - - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens - and response tokens. - - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens. - - position_ids: [bsz, prompt_length + response_length], incremental position ids. - - For multi-turn conversations: - responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->| - response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0| - """ - config = self.config.actor_rollout_ref.rollout - sampling_params = dict( - temperature=config.temperature, - top_p=config.top_p, - repetition_penalty=1.0, - ) - - # override sampling params for validation - if batch.meta_info.get("validate", False): - sampling_params["top_p"] = config.val_kwargs.top_p - sampling_params["temperature"] = config.val_kwargs.temperature - - # by default, we assume it's a single turn agent - if "agent_name" not in batch.non_tensor_batch: - batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object) - - tasks = [] - agent_names = batch.non_tensor_batch["agent_name"] - raw_prompts = batch.non_tensor_batch["raw_prompt"] - if "index" in batch.non_tensor_batch: - index = batch.non_tensor_batch["index"] - else: - index = np.arange(len(raw_prompts)) - - trajectory_info = await get_trajectory_info( - batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False) - ) - - for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True): - tasks.append( - asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory)) - ) - outputs = await asyncio.gather(*tasks) - - output = self._postprocess(outputs) - return output - - async def _run_agent_loop( - self, - agent_name: str, - messages: list[dict[str, Any]], - sampling_params: dict[str, Any], - trajectory: dict[str, Any], - ) -> AgentLoopOutput: - with rollout_trace_attr( - step=trajectory["step"], - sample_index=trajectory["sample_index"], - rollout_n=trajectory["rollout_n"], - validate=trajectory["validate"], - name="agent_loop", - ): - assert agent_name in _agent_loop_registry, ( - f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" - ) - - agent_loop_config = _agent_loop_registry[agent_name] - agent_loop = hydra.utils.instantiate( - config=agent_loop_config, - trainer_config=_DummyConfig(config=self.config), - server_manager=self.server_manager, - tokenizer=self.tokenizer, - ) - output = await agent_loop.run(messages, sampling_params) - return output - - def _postprocess(self, inputs: list[AgentLoopOutput]) -> DataProto: - # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py - # prompts: left pad - # responses: right pad - # input_ids: prompt + response - # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - - # prompts - self.tokenizer.padding_side = "left" - outputs = self.tokenizer.pad( - [{"input_ids": input.prompt_ids} for input in inputs], - padding="max_length", - max_length=self.config.actor_rollout_ref.rollout.prompt_length, - return_tensors="pt", - return_attention_mask=True, - ) - prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"] - - # responses - self.tokenizer.padding_side = "right" - outputs = self.tokenizer.pad( - [{"input_ids": input.response_ids} for input in inputs], - padding="max_length", - max_length=self.config.actor_rollout_ref.rollout.response_length, - return_tensors="pt", - return_attention_mask=True, - ) - response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"] - - # response_mask - outputs = self.tokenizer.pad( - [{"input_ids": input.response_mask} for input in inputs], - padding="max_length", - max_length=self.config.actor_rollout_ref.rollout.response_length, - return_tensors="pt", - return_attention_mask=False, - ) - response_mask = outputs["input_ids"] - assert response_ids.shape == response_mask.shape, ( - f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}" - ) - response_mask = response_mask * response_attention_mask - - input_ids = torch.cat([prompt_ids, response_ids], dim=1) - attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1) - position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask - - batch = TensorDict( - { - "prompts": prompt_ids, # [bsz, prompt_length] - "responses": response_ids, # [bsz, response_length] - "response_mask": response_mask, # [bsz, response_length] - "input_ids": input_ids, # [bsz, prompt_length + response_length] - "attention_mask": attention_mask, # [bsz, prompt_length + response_length] - "position_ids": position_ids, # [bsz, prompt_length + response_length] - }, - batch_size=len(input_ids), - ) - - num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32) - metrics = [input.metrics.model_dump() for input in inputs] - return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics}) - - -async def get_trajectory_info(step, index, validate): - """Get trajectory info. - - Args: - step (int): global steps in the trainer. - index (list): form datastore extra_info.index column. - validate (bool): whether is a validate step. - - Returns: - list: trajectory. - """ - trajectory_info = [] - rollout_n = 0 - for i in range(len(index)): - if i > 0 and index[i - 1] == index[i]: - rollout_n += 1 - else: - rollout_n = 0 - trajectory_info.append({"step": step, "sample_index": index[i], "rollout_n": rollout_n, "validate": validate}) - return trajectory_info - - -class AgentLoopManager: - """Agent loop manager that manages a group of agent loop workers.""" - - def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): - """Initialize agent loop manager. - - Args: - config (DictConfig): trainer config. - worker_group (RayWorkerGroup): ActorRolloutRef worker group. - """ - self.config = config - self.worker_group = worker_group - - self._initialize_llm_servers() - self._init_agent_loop_workers() - - # Initially we're in sleep mode. - self.sleep() - - def _initialize_llm_servers(self): - self.rollout_tp_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size - self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size - - register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") - workers_info = ray.get(register_center.get_worker_info.remote()) - assert len(workers_info) == self.worker_group.world_size - - self.async_llm_servers = [None] * self.rollout_dp_size - self.server_addresses = [None] * self.rollout_dp_size - - if self.config.actor_rollout_ref.rollout.agent.custom_async_server: - server_class = async_server_class( - rollout_backend=self.config.actor_rollout_ref.rollout.name, - rollout_backend_module=self.config.actor_rollout_ref.rollout.agent.custom_async_server.path, - rollout_backend_class=self.config.actor_rollout_ref.rollout.agent.custom_async_server.name, - ) - else: - server_class = async_server_class(rollout_backend=self.config.actor_rollout_ref.rollout.name) - - # Start all server instances, restart if address already in use. - unready_dp_ranks = set(range(self.rollout_dp_size)) - while len(unready_dp_ranks) > 0: - servers = { - rollout_dp_rank: server_class.options( - # make sure AsyncvLLMServer colocates with its corresponding workers - scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( - node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], - soft=False, - ), - name=f"async_llm_server_{rollout_dp_rank}", - ).remote(self.config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) - for rollout_dp_rank in unready_dp_ranks - } - - for rollout_dp_rank, server in servers.items(): - try: - address = ray.get(server.get_server_address.remote()) - self.server_addresses[rollout_dp_rank] = address - self.async_llm_servers[rollout_dp_rank] = server - unready_dp_ranks.remove(rollout_dp_rank) - except Exception: - ray.kill(server) - print(f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...") - - # All server instances are ready, init AsyncLLM engine. - ray.get([server.init_engine.remote() for server in self.async_llm_servers]) - - def _init_agent_loop_workers(self): - self.agent_loop_workers = [] - for i in range(self.config.actor_rollout_ref.rollout.agent.num_workers): - self.agent_loop_workers.append( - AgentLoopWorker.options( - name=f"agent_loop_worker_{i}", - ).remote(self.config, self.async_llm_servers) - ) - - def generate_sequences(self, prompts: DataProto) -> DataProto: - """Split input batch and dispatch to agent loop workers. - - Args: - prompts (DataProto): Input batch. - - Returns: - DataProto: Output batch. - """ - if self.config.actor_rollout_ref.rollout.free_cache_engine: - self.wake_up() - chunkes = prompts.chunk(len(self.agent_loop_workers)) - outputs = ray.get( - [ - worker.generate_sequences.remote(chunk) - for worker, chunk in zip(self.agent_loop_workers, chunkes, strict=True) - ] - ) - output = DataProto.concat(outputs) - if self.config.actor_rollout_ref.rollout.free_cache_engine: - self.sleep() - - # calculate performance metrics - metrics = [output.meta_info["metrics"] for output in outputs] # List[List[Dict[str, str]]] - timing = self._performance_metrics(metrics, output) - - output.meta_info = {"timing": timing} - return output - - def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]: - timing = {} - t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk]) - t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk]) - timing["agent_loop/generate_sequences/min"] = t_generate_sequences.min() - timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max() - timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean() - timing["agent_loop/tool_calls/min"] = t_tool_calls.min() - timing["agent_loop/tool_calls/max"] = t_tool_calls.max() - timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean() - - # batch sequence generation is bounded by the slowest sample - slowest = np.argmax(t_generate_sequences + t_tool_calls) - attention_mask = output.batch["attention_mask"][slowest] - prompt_length = output.batch["prompts"].shape[1] - timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest] - timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest] - timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item() - timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item() - - return timing - - def wake_up(self): - """Wake up all rollout server instances.""" - ray.get([server.wake_up.remote() for server in self.async_llm_servers]) - - def sleep(self): - """Sleep all rollout server instances.""" - ray.get([server.sleep.remote() for server in self.async_llm_servers]) diff --git a/verl/experimental/agent_loop/single_turn_agent_loop.py b/verl/experimental/agent_loop/single_turn_agent_loop.py deleted file mode 100644 index 411388e73..000000000 --- a/verl/experimental/agent_loop/single_turn_agent_loop.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import os -from typing import Any -from uuid import uuid4 - -from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register -from verl.utils.profiler import simple_timer - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -@register("single_turn_agent") -class SingleTurnAgentLoop(AgentLoopBase): - """Naive agent loop that only do single turn chat completion.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length - self.response_length = self.config.actor_rollout_ref.rollout.response_length - - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: - metrics = {} - request_id = uuid4().hex - prompt_ids = await self.loop.run_in_executor( - None, lambda: self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) - ) - - with simple_timer("generate_sequences", metrics): - response_ids = await self.server_manager.generate( - request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params - ) - response_mask = [1] * len(response_ids) - - output = AgentLoopOutput( - prompt_ids=prompt_ids, - response_ids=response_ids[: self.response_length], - response_mask=response_mask[: self.response_length], - num_turns=2, - metrics=metrics, - ) - return output diff --git a/verl/experimental/agent_loop/tool_agent_loop.py b/verl/experimental/agent_loop/tool_agent_loop.py deleted file mode 100644 index 3437c0be5..000000000 --- a/verl/experimental/agent_loop/tool_agent_loop.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import json -import logging -import os -from typing import Any -from uuid import uuid4 - -from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register -from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser -from verl.tools.utils.tool_registry import initialize_tools_from_config -from verl.utils.profiler import simple_timer -from verl.utils.rollout_trace import rollout_trace_op - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -@register("tool_agent") -class ToolAgentLoop(AgentLoopBase): - @classmethod - def init_class(cls, config, tokenizer, **kwargs): - if cls._class_initialized: - return - cls._class_initialized = True - print("Performing class-level ToolAgentLoop initialization") - - # Initialize tools from config file - cls.tokenizer = tokenizer - cls.max_user_turns = config.actor_rollout_ref.rollout.multi_turn.max_user_turns - cls.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns - cls.max_parallel_calls = config.actor_rollout_ref.rollout.multi_turn.max_parallel_calls - cls.max_tool_response_length = config.actor_rollout_ref.rollout.multi_turn.max_tool_response_length - cls.tool_response_truncate_side = config.actor_rollout_ref.rollout.multi_turn.tool_response_truncate_side - tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path - tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] - cls.tools = {tool.name: tool for tool in tool_list} - cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] - cls.tool_parser = ToolParser.get_tool_parser(config.actor_rollout_ref.rollout.multi_turn.format, cls.tokenizer) - print(f"Initialized tools: {cls.tools}") - - cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length - cls.response_length = config.actor_rollout_ref.rollout.response_length - cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True) - - @rollout_trace_op - async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput: - metrics = {} - request_id = uuid4().hex - prompt_ids = await self.loop.run_in_executor( - None, - lambda: self.tokenizer.apply_chat_template( - messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True - ), - ) - response_mask = [] - - user_turns, assistant_turns = 0, 0 - while True: - with simple_timer("generate_sequences", metrics): - response_ids = await self.server_manager.generate( - request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params - ) - prompt_ids += response_ids - response_mask += [1] * len(response_ids) - assistant_turns += 1 - - # reach max response length - if len(response_mask) >= self.response_length: - break - - # reach max assistant turns - if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns: - break - - # reach max user turns - if self.max_user_turns and user_turns >= self.max_user_turns: - break - - # no tool calls - _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids) - if not tool_calls: - break - - # call tools - tasks = [] - for tool_call in tool_calls[: self.max_parallel_calls]: - tasks.append(self._call_tool(tool_call)) - with simple_timer("tool_calls", metrics): - tool_responses = await asyncio.gather(*tasks) - if any(isinstance(item, Exception) for item in tool_responses): - break - - # append tool_response_ids - tool_response_ids = await self.loop.run_in_executor( - None, - lambda messages=tool_responses: self.tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True - ), - ) - tool_response_ids = tool_response_ids[len(self.system_prompt) :] - - # NOTE: last turn should not be user turn, or the EOS token reward - # can't be propagated to previous token in GAE. - if len(response_mask) + len(tool_response_ids) >= self.response_length: - break - - prompt_ids += tool_response_ids - response_mask += [0] * len(tool_response_ids) - user_turns += 1 - - response_ids = prompt_ids[-len(response_mask) :] - prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)] - - output = AgentLoopOutput( - prompt_ids=prompt_ids, - response_ids=response_ids[: self.response_length], - response_mask=response_mask[: self.response_length], - num_turns=user_turns + assistant_turns + 1, - metrics=metrics, - ) - return output - - async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]: - """Call tool and return tool response.""" - tool, instance_id = None, None - try: - # TODO: append malformed tool_call to the prompt: invalid function name or arguments - tool_name = tool_call.name - tool_args = json.loads(tool_call.arguments) - tool = self.tools[tool_name] - - instance_id = await tool.create() - tool_response, _, _ = await tool.execute(instance_id, tool_args) - except Exception as e: - logger.exception(f"Error when executing tool: {e}") - return e - finally: - if tool and instance_id: - await tool.release(instance_id) - - if len(tool_response) > self.max_tool_response_length: - if self.tool_response_truncate_side == "left": - tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)" - elif self.tool_response_truncate_side == "right": - tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :] - else: - length = self.max_tool_response_length // 2 - tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:] - - return { - "role": "tool", - "content": tool_response, - } diff --git a/verl/experimental/agent_loop/tool_parser.py b/verl/experimental/agent_loop/tool_parser.py deleted file mode 100644 index 5b4de4a8e..000000000 --- a/verl/experimental/agent_loop/tool_parser.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import json -import logging -import os -from abc import ABC, abstractmethod - -import regex as re -from pydantic import BaseModel - -from verl.utils.rollout_trace import rollout_trace_op - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class FunctionCall(BaseModel): - arguments: str - """ - The arguments to call the function with, as generated by the model in JSON - format. Note that the model does not always generate valid JSON, and may - hallucinate parameters not defined by your function schema. Validate the - arguments in your code before calling your function. - """ - - name: str - """The name of the function to call.""" - - -class ToolParser(ABC): - _registry: dict[str, type["ToolParser"]] = {} - - def __init__(self, tokenizer) -> None: - self.tokenizer = tokenizer - - @abstractmethod - async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]: - """Extract tool calls from the responses. - - Args: - responses_ids (List[int]): The ids of the responses. - - Returns: - Tuple[str, List[FunctionCall]]: Content and extracted tool calls. - """ - raise NotImplementedError - - @classmethod - def get_tool_parser(cls, name: str, tokenizer): - if name not in cls._registry: - raise ValueError(f"Unknown tool parser: {name}") - return cls._registry[name](tokenizer) - - @classmethod - def register(cls, name: str): - def decorator(subclass: type[ToolParser]) -> type[ToolParser]: - cls._registry[name] = subclass - return subclass - - return decorator - - -@ToolParser.register("hermes") -class HermesToolParser(ToolParser): - """Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py""" - - def __init__(self, tokenizer) -> None: - super().__init__(tokenizer) - - self.tool_call_start_token: str = "" - self.tool_call_end_token: str = "" - self.tool_call_regex = re.compile(r"(.*?)", re.DOTALL) - - @rollout_trace_op - async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]: - loop = asyncio.get_running_loop() - text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids) - if self.tool_call_start_token not in text or self.tool_call_end_token not in text: - return text, [] - - matches = self.tool_call_regex.findall(text) - function_calls = [] - for match in matches: - try: - function_call = json.loads(match) - name, arguments = function_call["name"], function_call["arguments"] - function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False))) - except Exception as e: - logger.error(f"Failed to decode tool call: {e}") - - # remaing text exclude tool call tokens - content = self.tool_call_regex.sub("", text) - - return content, function_calls diff --git a/verl/experimental/dataset/__init__.py b/verl/experimental/dataset/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/experimental/dataset/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/dataset/sampler.py b/verl/experimental/dataset/sampler.py deleted file mode 100644 index b7b15b422..000000000 --- a/verl/experimental/dataset/sampler.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2025 Amazon.com Inc and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from abc import abstractmethod -from collections.abc import Sized - -from omegaconf import DictConfig -from torch.utils.data import Sampler - -from verl import DataProto - - -class AbstractSampler(Sampler[int]): - """Abstract interface for custom samplers.""" - - @abstractmethod - def __init__( - self, - data_source: Sized, - data_config: DictConfig, - ): - pass - - -class AbstractCurriculumSampler(AbstractSampler): - """Experimental interface for curriculum learning samplers.""" - - @abstractmethod - def update(self, batch: DataProto) -> None: - pass diff --git a/verl/experimental/dynamic_dataset/__init__.py b/verl/experimental/dynamic_dataset/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/experimental/dynamic_dataset/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/experimental/dynamic_dataset/dynamicgen_dataset.py b/verl/experimental/dynamic_dataset/dynamicgen_dataset.py deleted file mode 100644 index a9532aa03..000000000 --- a/verl/experimental/dynamic_dataset/dynamicgen_dataset.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2025 Amazon.com Inc and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Dataset class that enables dynamic data generation strategies between iterations of training. -This class extends RLHFDataset and uses an AbstractDataGen instance to generate data. - -This is especially useful in settings where proposer model generates new tasks based -on rollout data. -""" - -import logging -from abc import ABC, abstractmethod -from typing import Optional - -import datasets -from omegaconf import DictConfig -from torch.utils.data import Dataset -from transformers import PreTrainedTokenizer, ProcessorMixin - -from verl import DataProto -from verl.utils.dataset import RLHFDataset -from verl.utils.import_utils import load_extern_type - -logger = logging.getLogger(__name__) - - -class AbstractDataGenerator(ABC): - def __init__(self, config: DictConfig): - self.config = config - - @abstractmethod - def generate(self, dataset: Dataset) -> datasets.Dataset: - """ - Generate method must be implemented by subclasses. - Args: - dataset: The dataset to generate from. - Returns: - Processed data or result as implemented by the subclass. - """ - pass - - -class MockDataGenerator(AbstractDataGenerator): - """ - A noop data gen class that only reappends the first datapoint. - This class is useful as a placeholder and testing. - """ - - def __init__(self, config: DictConfig = None): - super().__init__(config) - - def generate(self, dataset: Dataset) -> datasets.Dataset: - print("MockDataGenerator: No operation performed on the dataset.") - return dataset.dataframe.select([0]) - - -class DynamicGenDataset(RLHFDataset): - """ - A dataset class that uses a data generation strategy to process data. - This class extends RLHFDataset and uses an AbstractDataGen instance to generate data. - """ - - def __init__( - self, - data_files: str | list[str], - tokenizer: PreTrainedTokenizer, - config: DictConfig, - processor: Optional[ProcessorMixin] = None, - ): - super().__init__(data_files, tokenizer, config, processor) - self.datagen: AbstractDataGenerator = config.datagen - assert "datagen" in config and config.datagen.get("path", None) is not None, ( - f"datagen path is not set in config: {config}" - ) - # Dynamically load the custom datagen class - datagen_cls = load_extern_type(config.datagen.path, config.datagen.name) - - # Verify that the custom datagen class inherits from AbstractDataGenerator - abs_cls = AbstractDataGenerator - if not issubclass(datagen_cls, abs_cls): - raise TypeError( - f"The custom datagen class '{config.datagen.name}' from '{config.datagen.path}'" - + " must inherit from {abs_cls}" - ) - - self.data_generator = datagen_cls(config.datagen) - self.on_batch_end() - - def append_dataframe(self, new_dataframe: datasets.Dataset): - new_dataframe = self.maybe_filter_out_long_prompts(new_dataframe) - self.dataframe = datasets.concatenate_datasets([self.dataframe, new_dataframe]) - - logger.info(f"new dataset len: {len(self.dataframe)}") - - def on_batch_end(self, batch: DataProto) -> None: - """ - Generate data using the provided data generation strategy. - Note: This method is intended to change the dataset after each training batch. - """ - new_data = self.data_generator.generate(self) - self.append_dataframe(new_data) diff --git a/verl/interactions/__init__.py b/verl/interactions/__init__.py deleted file mode 100644 index b6db0fcef..000000000 --- a/verl/interactions/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/interactions/base.py b/verl/interactions/base.py deleted file mode 100644 index 7c5d200ab..000000000 --- a/verl/interactions/base.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Optional -from uuid import uuid4 - - -class BaseInteraction: - def __init__(self, config: dict[str, Any]): - self.config = config - self.name: str = config.get("name", "interaction_agent") # More general agent default role name - - async def start_interaction(self, instance_id: Optional[str] = None, **kwargs) -> str: - """Create a tool instance. - - Args: - instance_id: The instance id of the tool. - - Returns: - The instance id of the tool. - """ - if instance_id is None: - return str(uuid4()) - else: - return instance_id - - async def generate_response( - self, instance_id: str, messages: list[dict[str, Any]], **kwargs - ) -> tuple[bool, str, float, dict[str, Any]]: # More clear response generation method - """ - Generates a response for the current turn of interaction. - Returns a tuple containing: - - should_terminate_sequence (bool): True if the interaction sequence should end. - - response_content (str): The textual content of the response. - - current_turn_score (float): The score for this specific turn/response. - - additional_data (dict): Any extra information or metadata. - """ - should_terminate_sequence: bool = False # if True, end rollout - response_content: str = "Your current result seems acceptable." - current_turn_score: float = 0.8 - additional_data: dict[str, Any] = {} - return should_terminate_sequence, response_content, current_turn_score, additional_data - - async def calculate_score(self) -> float: # More clear score calculation method - """ - Calculates a score for the interaction, - potentially considering aspects like partial exposure & in-context task switching. - should be invoke at turn-level - """ - # ...implement the logic to calculate turn-level score... - score = 0.0 - return score - - async def finalize_interaction(self) -> None: # More clear interaction end and resource release method - """ - Finalizes the interaction session and releases any associated state or resources. - Simulates: release state - """ - # ...implement the logic to release state... - pass diff --git a/verl/interactions/gsm8k_interaction.py b/verl/interactions/gsm8k_interaction.py deleted file mode 100644 index 365cbb935..000000000 --- a/verl/interactions/gsm8k_interaction.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -from typing import Any, Optional -from uuid import uuid4 - -from verl.utils.reward_score import gsm8k - -from .base import BaseInteraction - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class Gsm8kInteraction(BaseInteraction): - """A demo interaction for calculating the reward of gsm8k. - - - `start_interaction`: start a interaction instance for a trajectory. - - `generate_response`: generate the response of the user. - - `calculate_score`: calculate the score of the interaction. - - `finalize_interaction`: finalize the interaction instance. - """ - - def __init__(self, config: dict): - super().__init__(config) - self._instance_dict = {} - - async def start_interaction( - self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs - ) -> str: - if instance_id is None: - instance_id = str(uuid4()) - self._instance_dict[instance_id] = { - "response": "", - "ground_truth": ground_truth, - "reward": 0.0, - } - return instance_id - - async def generate_response( - self, instance_id: str, messages: list[dict[str, Any]], **kwargs - ) -> tuple[bool, str, float, dict]: - content = "" - for i in range(len(messages) - 1, -1, -1): - item = messages[i] - if item.get("role") == "user": - content = item.get("content") - break - - if content and content.startswith("#### "): - self._instance_dict[instance_id]["response"] = content - else: - self._instance_dict[instance_id]["response"] = "#### " + (content or "") - - reward = await self.calculate_score(instance_id) - if reward == 1.0: - response = "Your response is correct!" - should_terminate_sequence = True - else: - response = "Your response is incorrect! You need to reflect on your answer and try again." - should_terminate_sequence = False - - return should_terminate_sequence, response, reward, {} - - async def calculate_score(self, instance_id: str, **kwargs) -> float: - return gsm8k.compute_score( - self._instance_dict[instance_id]["response"], - self._instance_dict[instance_id]["ground_truth"], - method="flexible", - format_score=0.0, - score=1.0, - ) - - async def finalize_interaction(self, instance_id: str, **kwargs) -> None: - del self._instance_dict[instance_id] diff --git a/verl/interactions/utils/__init__.py b/verl/interactions/utils/__init__.py deleted file mode 100644 index c4b932b1a..000000000 --- a/verl/interactions/utils/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/interactions/utils/interaction_registry.py b/verl/interactions/utils/interaction_registry.py deleted file mode 100644 index df747af11..000000000 --- a/verl/interactions/utils/interaction_registry.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib.util -import logging -import os -import sys - -from omegaconf import OmegaConf - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -def get_interaction_class(cls_name): - """Dynamically import and return the interaction class.""" - module_name, class_name = cls_name.rsplit(".", 1) - if module_name not in sys.modules: - spec = importlib.util.find_spec(module_name) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - else: - module = sys.modules[module_name] - - interaction_cls = getattr(module, class_name) - return interaction_cls - - -def initialize_interactions_from_config(interaction_config_file): - """Initialize interactions from configuration file. - - Args: - interaction_config_file: Path to the interaction configuration file. - - Returns: - dict: A dictionary mapping interaction names to BaseInteraction instances. - """ - interaction_config = OmegaConf.load(interaction_config_file) - interaction_map = {} - - for interaction_item in interaction_config.interaction: - cls_name = interaction_item.class_name - interaction_cls = get_interaction_class(cls_name) - - # Extract config and name - config = OmegaConf.to_container(interaction_item.config, resolve=True) - - # Get the interaction name - either from config or derive from class name - name = interaction_item.get("name", None) - if name is None: - # If no name is specified, use the class name as default - class_simple_name = cls_name.split(".")[-1] - # Remove "Interaction" suffix if present, otherwise use full class name - if class_simple_name.endswith("Interaction"): - name = class_simple_name[:-11].lower() # Remove "Interaction" (11 chars) - else: - name = class_simple_name.lower() - - # Check for duplicate names - if name in interaction_map: - raise ValueError(f"Duplicate interaction name '{name}' found. Each interaction must have a unique name.") - - # Inject the name into the config - config["name"] = name - - # Create the interaction instance - interaction = interaction_cls(config=config) - interaction_map[name] = interaction - - logger.info(f"Initialized interaction '{name}' with class '{cls_name}'") - - return interaction_map diff --git a/verl/model_merger/__init__.py b/verl/model_merger/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/model_merger/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/model_merger/__main__.py b/verl/model_merger/__main__.py deleted file mode 100644 index f3ab5b9c2..000000000 --- a/verl/model_merger/__main__.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module is used to merge huggingface model and test verl checkpoints from FSDP and Megatron backends. - -To merge FSDP checkpoints: -```sh -python -m verl.model_merger merge \ - --backend fsdp \ - --local_dir checkpoints/verl_fsdp_gsm8k_examples/qwen2_5_0b5_fsdp_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -To merge Megatron checkpoints: -```sh -python -m verl.model_merger merge \ - --backend megatron \ - --tie-word-embedding \ - --local_dir checkpoints/verl_megatron_gsm8k_examples/qwen2_5_0b5_megatron_saveload/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - -or use distribtued merge for large models like dpskv3 671B - -```sh -torchrun --nproc_per_node 1 --nnodes 8 --node_rank ${RANK} -m verl.model_merger merge\ - --backend megatron \ - --local_dir ./checkpoints/global_step_1/actor \ - --target_dir /path/to/merged_hf_model -``` - - -For more details, please refer to documentation: -https://verl.readthedocs.io/en/latest/advance/checkpoint.html#convert-fsdp-and-megatron-checkpoints-to-huggingface-format-model -""" - -from .base_model_merger import generate_config_from_args, parse_args - - -def main(): - args = parse_args() - config = generate_config_from_args(args) - print(f"config: {config}") - - if config.backend == "fsdp": - from .fsdp_model_merger import FSDPModelMerger - - merger = FSDPModelMerger(config) - elif config.backend == "megatron": - from .megatron_model_merger import MegatronModelMerger - - merger = MegatronModelMerger(config) - else: - raise NotImplementedError(f"Unknown backend: {config.backend}") - - merger.merge_and_save() - merger.cleanup() - - -if __name__ == "__main__": - main() diff --git a/verl/model_merger/base_model_merger.py b/verl/model_merger/base_model_merger.py deleted file mode 100644 index 73ddeb0e1..000000000 --- a/verl/model_merger/base_model_merger.py +++ /dev/null @@ -1,325 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from typing import Optional - -import torch -from accelerate import init_empty_weights -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForTokenClassification, - AutoModelForVision2Seq, - GenerationConfig, -) - -from verl.utils import hf_processor, hf_tokenizer - - -def parse_args(): - parser = argparse.ArgumentParser(description="verl model merger") - subparsers = parser.add_subparsers(dest="operation", required=True, help="Specify 'merge' or 'test' operation.") - - base_op_parser = argparse.ArgumentParser(add_help=False) - base_op_parser.add_argument( - "--backend", type=str, required=True, choices=["fsdp", "megatron"], help="The backend of the model" - ) - base_op_parser.add_argument("--local_dir", type=str, default=None, help="Path to the saved model checkpoints.") - base_op_parser.add_argument( - "--tie-word-embedding", - action="store_true", - help="Whether to tie word embedding weights (currently only Megatron supported)", - ) - base_op_parser.add_argument("--trust-remote-code", action="store_true", help="Whether to trust remote code") - base_op_parser.add_argument( - "--is-value-model", - action="store_true", - help="Whether the model is a value model (currently only Megatron supported)", - ) - base_op_parser.add_argument( - "--use_cpu_initialization", - action="store_true", - help="Whether to use CPU initialization for the model. This is useful for large models that cannot " - "fit into GPU memory during initialization.", - ) - - merge_parser = subparsers.add_parser("merge", parents=[base_op_parser], help="Merge model checkpoints and save.") - merge_parser.add_argument( - "--target_dir", default="tmp", type=str, help="Directory to save the merged huggingface model" - ) - merge_parser.add_argument( - "--hf_upload_path", default=None, type=str, help="Hugging Face repository ID to upload the model" - ) - merge_parser.add_argument( - "--private", action="store_true", help="Whether to upload the model to a private Hugging Face repository" - ) - - test_parser = subparsers.add_parser( - "test", parents=[base_op_parser], help="Test merged model against a reference Hugging Face model" - ) - test_parser.add_argument( - "--test_hf_dir", type=str, required=True, help="Path to the reference Hugging Face model directory for testing" - ) - - args = parser.parse_args() - return args - - -@dataclass -class ModelMergerConfig: - operation: str # 'merge' or 'test' - backend: str - target_dir: Optional[str] = "tmp" - hf_upload_path: Optional[str] = None - private: bool = False - test_hf_dir: Optional[str] = None - tie_word_embedding: bool = False - trust_remote_code: bool = False - is_value_model: bool = False - local_dir: Optional[str] = None - hf_model_config_path: Optional[str] = None - hf_upload: bool = field(init=False) - use_cpu_initialization: bool = False - - def __post_init__(self): - self.hf_upload = self.operation == "merge" and bool(self.hf_upload_path) - if self.operation == "test": - self.target_dir = None - self.hf_upload_path = None - self.private = False - - -def generate_config_from_args(args: argparse.Namespace) -> ModelMergerConfig: - common_config_args = { - "operation": args.operation, - "backend": args.backend, - "tie_word_embedding": args.tie_word_embedding, - "trust_remote_code": args.trust_remote_code, - "is_value_model": args.is_value_model, - "local_dir": args.local_dir, - "hf_model_config_path": os.path.join(args.local_dir, "huggingface"), - "use_cpu_initialization": args.use_cpu_initialization, - } - - if args.operation == "merge": - config = ModelMergerConfig( - **common_config_args, - target_dir=args.target_dir, - hf_upload_path=args.hf_upload_path, - private=args.private, - test_hf_dir=None, - ) - os.makedirs(config.target_dir, exist_ok=True) - elif args.operation == "test": - config = ModelMergerConfig( - **common_config_args, - test_hf_dir=args.test_hf_dir, - # the following args are not used by test operation - target_dir=None, - hf_upload_path=None, - private=False, - ) - else: - raise NotImplementedError(f"Unknown operation: {args.operation}") - return config - - -class BaseModelMerger(ABC): - """ - Abstract base class for merging distributed model checkpoints into HuggingFace format. - - This class provides common functionality for converting model checkpoints from different - distributed training backends (FSDP, Megatron) into standard HuggingFace format that - can be easily loaded and used for inference or further training. - - The merger supports two main operations: - - merge: Convert and save checkpoints to HuggingFace format - - test: Validate merged checkpoints against a reference model - - Args: - config (ModelMergerConfig): Configuration object containing paths, backend type, - and operation parameters. - - Attributes: - config (ModelMergerConfig): The configuration object passed during initialization. - hf_model_config_path (str): Path to the HuggingFace model configuration files. - model_config (PretrainedConfig): Loaded HuggingFace model configuration. - """ - - def __init__(self, config: ModelMergerConfig): - self.config = config - self.hf_model_config_path = config.hf_model_config_path - self.model_config = AutoConfig.from_pretrained( - self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code - ) - - def get_transformers_auto_model_class(self): - if "ForTokenClassification" in self.model_config.architectures[0]: - return AutoModelForTokenClassification - elif "ForCausalLM" in self.model_config.architectures[0]: - return AutoModelForCausalLM - elif "ForConditionalGeneration" in self.model_config.architectures[0]: - return AutoModelForVision2Seq - - raise NotImplementedError(f"Unknown architecture {self.model_config.architectures}") - - def patch_model_generation_config(self, model): - """ - The generation_config created from model config may be different to the pretrained model, - this may lead to error when generating: https://github.com/volcengine/verl/issues/1246 - - This function patch the generation_config created from model config to the pretrained model. - """ - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained(self.hf_model_config_path) - except OSError: - print( - f"Warning: Generation config file not found in {self.hf_model_config_path}, using a " - f"generation config created from the model config." - ) - return model - - def save_lora_adapter(self, state_dict: dict[str, torch.Tensor]): - """ - Save lora adapter to safetensors. - - Returns: - lora_path: str, the path to the lora adapter. None if no lora adapter found. - - Note: - This function change the 'state_dict' in place. - """ - lora_params_names = [name for name in state_dict.keys() if "lora_" in name] - - if len(lora_params_names) == 0: - return None - - import json - from typing import OrderedDict - - import peft - from safetensors.torch import save_file - - lora_params = OrderedDict() - target_modules = set() - lora_key = None - - for name in lora_params_names: - lora_key = name.replace(".default.weight", ".weight") - target_modules.add(lora_key.split(".")[-3]) - lora_params[lora_key] = state_dict.pop(name) - - lora_rank = min(lora_params[lora_key].shape[0], lora_params[lora_key].shape[1]) - peft_dict = { - "r": lora_rank, - "lora_alpha": 0, # lora_alpha is not set. An error should be raised to inform the user to set it manually. - "target_modules": list(target_modules), - } - peft_config = peft.LoraConfig(**peft_dict).to_dict() - peft_config["task_type"] = peft_config["task_type"].value if peft_config["task_type"] else None - peft_config["peft_type"] = peft_config["peft_type"].value if peft_config["peft_type"] else None - peft_config["target_modules"] = list(peft_config["target_modules"]) - - lora_path = os.path.join(self.config.target_dir, "lora_adapter") - os.makedirs(lora_path, exist_ok=True) - with open(os.path.join(lora_path, "adapter_config.json"), "w", encoding="utf-8") as f: - json.dump(peft_config, f, ensure_ascii=False, indent=4) - save_file(lora_params, os.path.join(lora_path, "adapter_model.safetensors")) - - for name in list(state_dict.keys()): - key = ( - name.replace("base_model.model.", "") - .replace(".base_layer.weight", ".weight") - .replace(".base_layer.bias", ".bias") - ) - state_dict[key] = state_dict.pop(name) - - return lora_path - - def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - with init_empty_weights(): - model = auto_model_class.from_config( - self.model_config, torch_dtype=torch.bfloat16, trust_remote_code=self.config.trust_remote_code - ) - model.to_empty(device="cpu") - model = self.patch_model_generation_config(model) - - lora_path = self.save_lora_adapter(state_dict) - if lora_path: - print(f"Saving lora adapter to {lora_path}") - - print(f"Saving model to {self.config.target_dir}") - model.save_pretrained(self.config.target_dir, state_dict=state_dict) - del state_dict - del model - - processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) - tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) - if processor is not None: - print(f"Saving processor to {self.config.target_dir}") - processor.save_pretrained(self.config.target_dir) - if tokenizer is not None: - print(f"Saving tokenizer to {self.config.target_dir}") - tokenizer.save_pretrained(self.config.target_dir) - - def upload_to_huggingface(self): - import requests - from huggingface_hub import HfApi - from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError - - api = HfApi() - try: - # Attempt to create repository - api.create_repo(repo_id=self.config.hf_upload_path, private=self.config.private, exist_ok=True) - except HfHubHTTPError as e: - # Handle authentication/API errors - if e.response.status_code == 401: - raise PermissionError( - "Hugging Face authentication failed. Verify your token is valid and has write permissions." - ) from e - elif e.response.status_code == 404: - raise RepositoryNotFoundError(f"Repository path not found: {self.config.hf_upload_path}") from e - else: - raise ConnectionError(f"Failed to create repository ({e.response.status_code}): {e}") from e - except requests.exceptions.ConnectionError as e: - raise ConnectionError("Network connection failed. Check your internet connection.") from e - - try: - # Attempt folder upload - api.upload_folder(folder_path=self.config.target_dir, repo_id=self.config.hf_upload_path, repo_type="model") - except HfHubHTTPError as e: - if e.response.status_code == 401: - raise PermissionError("Authentication failed during upload. Token may have expired.") from e - else: - raise RuntimeError(f"Upload failed ({e.response.status_code}): {e}") from e - except requests.exceptions.ConnectionError as e: - raise ConnectionError("Network interruption during upload. Try again with stable connection.") from e - except OSError as e: - raise FileNotFoundError(f"Local folder error: {self.config.target_dir} - {str(e)}") from e - except Exception as e: - raise RuntimeError(f"Unexpected error during upload: {str(e)}") from e - - @abstractmethod - def merge_and_save(self): - raise NotImplementedError("Subclasses should implement this method") - - @abstractmethod - def cleanup(self): - raise NotImplementedError("Subclasses should implement this method to clean up resources if needed") diff --git a/verl/model_merger/fsdp_model_merger.py b/verl/model_merger/fsdp_model_merger.py deleted file mode 100644 index 7853b2b79..000000000 --- a/verl/model_merger/fsdp_model_merger.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path - -import numpy as np -import torch -from torch.distributed._tensor import Placement, Shard - -try: - # for torch 2.5+ - from torch.distributed.tensor import DTensor -except ImportError: - from torch.distributed._tensor import DTensor - -from tqdm import tqdm - -from .base_model_merger import BaseModelMerger - - -class FSDPModelMerger(BaseModelMerger): - """ - Model merger for FSDP (Fully Sharded Data Parallel) checkpoints. - - This class handles the conversion of FSDP distributed checkpoints into HuggingFace format. - FSDP shards model parameters across multiple processes, and this merger reconstructs - the full model by loading and concatenating the sharded parameters from all ranks. - - The merger supports various FSDP configurations including: - - Pure FSDP (single dimension sharding) - - FSDP + DDP (data parallel + fully sharded data parallel) - - DTensor-based sharding with custom device meshes - - Key features: - - Automatic detection of world size from checkpoint filenames - - Support for DTensor and non-DTensor checkpoints - - Parallel loading of checkpoint shards for efficiency - - Validation against reference HuggingFace models - - Example: - To merge FSDP checkpoints: - ```python - config = ModelMergerConfig( - operation="merge", - backend="fsdp", - local_dir="path/to/fsdp/checkpoints", - target_dir="path/to/output" - ) - merger = FSDPModelMerger(config) - merger.merge_and_save() - ``` - """ - - def _get_world_size(self) -> int: - """_summary_ - From FSDP json config file, extract the world size. - - Returns: - int: world size - """ - config_path = Path(self.config.local_dir) / "fsdp_config.json" - if not config_path.exists(): - raise FileNotFoundError(f"Config file {config_path} does not exist.") - - with open(config_path) as f: - config = json.load(f) - - # Extract world size from the config - world_size = config.get("world_size", None) - if world_size is None: - raise ValueError("World size not found in the config file.") - - return world_size - - def _load_rank_zero_state_dict(self, world_size: int) -> dict: - return torch.load( - Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_0.pt", - map_location="cpu", - weights_only=False, - ) - - def _extract_device_mesh_info(self, state_dict: dict, world_size: int) -> tuple[np.ndarray, tuple[str, ...]]: - """ - Retrieves sharding information (device_mesh, mesh_dim_names) from a DTensor in the state_dict. - If no DTensor is found, infers a simple FSDP mesh based on world_size. - """ - pivot_key = sorted(list(state_dict.keys()))[0] - weight = state_dict[pivot_key] - - if isinstance(weight, DTensor): - # get sharding info - device_mesh = weight.device_mesh - mesh = device_mesh.mesh - mesh_dim_names = device_mesh.mesh_dim_names - else: - # for non-DTensor - mesh = np.array([world_size], dtype=np.int64) - mesh_dim_names = ("fsdp",) - - return mesh, mesh_dim_names - - def _calculate_shard_configuration( - self, mesh: np.ndarray, mesh_dim_names: tuple[str, ...] - ) -> tuple[int, tuple[int, ...]]: - """Calculates the total number of shards and the shape of the device mesh.""" - assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}" - - if "tp" in mesh_dim_names: - # TODO: "tp" is not supported yet due to the above assert - total_shards = mesh.shape[-1] * mesh.shape[-2] - mesh_shape = (mesh.shape[-2], mesh.shape[-1]) - else: - total_shards = mesh.shape[-1] - mesh_shape = (mesh.shape[-1],) - - return total_shards, mesh_shape - - def _merge_by_placement(self, tensors: list[torch.Tensor], placement: Placement) -> torch.Tensor: - """Merges a list of tensors based on their DTensor placement""" - if placement.is_replicate(): - return tensors[0] - elif placement.is_partial(): - raise NotImplementedError("Partial placement is not supported yet") - elif placement.is_shard(): - return torch.cat(tensors, dim=placement.dim).contiguous() - - raise NotImplementedError(f"Unsupported placement: {placement}") - - def _load_and_merge_state_dicts( - self, world_size: int, total_shards: int, mesh_shape: tuple[int, ...], mesh_dim_names: tuple[str, ...] - ) -> dict[str, torch.Tensor]: - model_state_dict_lst = [None] * total_shards - - def process_one_shard(rank: int, model_state_dict_lst: list): - model_path = Path(self.config.local_dir) / f"model_world_size_{world_size}_rank_{rank}.pt" - state_dict = torch.load(model_path, map_location="cpu", weights_only=False) - model_state_dict_lst[rank] = state_dict - return state_dict - - with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor: - futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(total_shards)] - for future in tqdm(futures, desc=f"Loading {total_shards} FSDP shards", total=total_shards): - future.result() - - # Merge state dicts from all shards - state_dict = {} - param_placements: dict[str, list] = {} - - for key in set(model_state_dict_lst[0].keys()): - state_dict[key] = [] - for model_state_shard in model_state_dict_lst: - # add tensor shard in order of rank to state_dict[key] - tensor = model_state_shard.pop(key) - if isinstance(tensor, DTensor): - state_dict[key].append(tensor._local_tensor.bfloat16()) - - placements = tuple(tensor.placements) - # replicated placement at dp dimension can be discarded - if mesh_dim_names[0] in ("dp", "ddp"): - placements = placements[1:] - - if key not in param_placements: - param_placements[key] = placements - else: - assert param_placements[key] == placements - else: - state_dict[key].append(tensor.bfloat16()) - - del model_state_dict_lst - - # Merge tensors - for key in sorted(state_dict): - if not isinstance(state_dict[key], list): - print(f"No need to merge key {key}") - continue - if key in param_placements: - # merge shards - placements: tuple[Shard] = param_placements[key] - if len(mesh_shape) == 1: - # 1-D list, FSDP without TP - assert len(placements) == 1 - shards = state_dict[key] - state_dict[key] = self._merge_by_placement(shards, placements[0]) - else: - # 2-D list, FSDP + TP - raise NotImplementedError("FSDP + TP is not supported yet") - else: - state_dict[key] = torch.cat(state_dict[key], dim=0) - - return state_dict - - def merge_and_save(self): - world_size = self._get_world_size() - rank_zero_state_dict = self._load_rank_zero_state_dict(world_size) - - mesh, mesh_dim_names = self._extract_device_mesh_info(rank_zero_state_dict, world_size) - print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}") - - total_shards, mesh_shape = self._calculate_shard_configuration(mesh, mesh_dim_names) - print(f"Processing model shards with {total_shards} {mesh_shape} in total") - - merged_state_dict = self._load_and_merge_state_dicts(world_size, total_shards, mesh_shape, mesh_dim_names) - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._validate_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): - auto_model_class = self.get_transformers_auto_model_class() - - hf_model = auto_model_class.from_pretrained(self.config.test_hf_dir, torch_dtype=torch.bfloat16) - hf_state_dict = hf_model.state_dict() - del hf_model - - hf_model_keys = set(hf_state_dict.keys()) - collected_keys = set(state_dict.keys()) - - missing_keys = hf_model_keys - collected_keys - assert len(missing_keys) == 0, f"Missing keys in collected state dict: {list(sorted(missing_keys))}" - - extra_keys = collected_keys - hf_model_keys - assert len(extra_keys) == 0, f"Extra keys in collected state dict: {list(sorted(extra_keys))}" - - for key in hf_model_keys: - hf_shape = hf_state_dict[key].shape - collected_shape = state_dict[key].shape - assert hf_shape == collected_shape, ( - f"Shape mismatch for key '{key}': original {hf_shape} vs collected {collected_shape}" - ) - - hf_dtype = hf_state_dict[key].dtype - collected_dtype = state_dict[key].dtype - assert hf_dtype == collected_dtype, ( - f"Dtype mismatch for key '{key}': original {hf_dtype} vs collected {collected_dtype}" - ) - - torch.testing.assert_close(hf_state_dict[key], state_dict[key], atol=1e-6, rtol=1e-6) - - print("FSDP checks passed: The merged state_dict matches the hf model saved by FSDPCheckpointManager.") - - def cleanup(self): - """Cleanup temporary files if needed.""" - # FSDP merger does not create temporary files, so no cleanup is needed. - pass diff --git a/verl/model_merger/megatron_model_merger.py b/verl/model_merger/megatron_model_merger.py deleted file mode 100644 index 5be281681..000000000 --- a/verl/model_merger/megatron_model_merger.py +++ /dev/null @@ -1,537 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -import warnings -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Callable, ContextManager - -import numpy as np -import torch -import torch.distributed as dist -from accelerate import init_empty_weights -from megatron.core import mpu -from megatron.core.models.gpt.gpt_model import ModelType -from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from safetensors.torch import load_file -from transformers import ( - AutoConfig, - PretrainedConfig, -) - -from verl.models.mcore import hf_to_mcore_config -from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device -from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing -from verl.utils.megatron_utils import get_model -from verl.utils.tokenizer import hf_processor, hf_tokenizer - -from .base_model_merger import BaseModelMerger, ModelMergerConfig - - -@contextmanager -def noop_context() -> Any: - yield - - -def get_dynamic_pipeline_shards(layer_num: int, pp_size: int) -> list[int]: - """Calculate the pipeline sharding configuration for Megatron-LM. - - Args: - layer_num: Total number of layers in the model. - pp_size: Number of pipeline parallel ranks. - - Returns: - layer number of each pp rank. Make the sharding of the pipeline as uniform as possible. - """ - if layer_num < pp_size: - raise ValueError(f"layer_num {layer_num} must be greater than pp_size {pp_size}.") - - if pp_size < 1: - raise ValueError(f"pp_size must be at least 1, got {pp_size}.") - if pp_size == 1: - return [layer_num] - - if pp_size == 2: - return [ - layer_num // 2, - layer_num - layer_num // 2, - ] - - middle_size = pp_size - 2 - shards_strategy = [] - for middle_layer_num in range(layer_num): - first_last_layer_num = layer_num - middle_layer_num * middle_size - first_layer_num = first_last_layer_num // 2 - last_layer_num = first_last_layer_num - first_last_layer_num // 2 - if 0 < first_layer_num <= middle_layer_num and 0 < last_layer_num <= middle_layer_num: - shards_strategy.append( - ( - [first_layer_num] + [middle_layer_num] * middle_size + [last_layer_num], - abs(first_layer_num - middle_layer_num), - ) - ) - - # sort by diff of layer_num, to make it as uniform as possible - res = sorted(shards_strategy, key=lambda x: x[1])[0][0] - assert sum(res) == layer_num, f"sum(res)={sum(res)} != layer_num={layer_num}, pp_size={pp_size}" - return res - - -class MegatronModelMerger(BaseModelMerger): - """ - Model merger for Megatron-LM distributed checkpoints. - - This class handles the conversion of Megatron-LM distributed checkpoints into HuggingFace format. - Megatron-LM uses tensor parallelism, pipeline parallelism, and data parallelism to distribute - large language models across multiple GPUs. This merger reconstructs the full model by - loading distributed checkpoints and applying the necessary transformations. - - Key features: - - Support for tensor parallel, pipeline parallel, and data parallel configurations - - Automatic parameter name mapping from Megatron to HuggingFace conventions - - Handling of QKV and gate-up tensor splitting/merging - - Support for tied word embeddings and value models - - Integration with Megatron's distributed checkpointing system - - The merger handles various model architectures and configurations: - - Standard transformer models (GPT-style) - - Models with tied word embeddings - - Value models for reinforcement learning - - Multi-layer attention (MLA) architectures - - Mixture of Experts (MoE) models - - Args: - config (ModelMergerConfig): Configuration object with Megatron-specific settings - including tie_word_embedding and is_value_model flags. - - Example: - To merge Megatron checkpoints: - ```python - config = ModelMergerConfig( - operation="merge", - backend="megatron", - local_dir="path/to/megatron/checkpoints", - target_dir="path/to/output", - tie_word_embedding=True - ) - merger = MegatronModelMerger(config) - merger.merge_and_save() - ``` - """ - - def __init__(self, config: ModelMergerConfig): - super().__init__(config) - # Currently we use only 1 rank to merge the dist_ckpt, we will move to multi-process save shortly afterwards - if "WORLD_SIZE" not in os.environ: - os.environ["RANK"] = "0" - os.environ["LOCAL_RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - torch.distributed.init_process_group(get_nccl_backend()) - - self.rank = torch.distributed.get_rank() - self.world_size = torch.distributed.get_world_size() - local_rank = os.environ.get("LOCAL_RANK", 0) - get_torch_device().set_device(f"{get_device_name()}:{local_rank}") - - mpu.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=self.world_size, - virtual_pipeline_model_parallel_size=None, - context_parallel_size=1, - expert_model_parallel_size=1, - ) - model_parallel_cuda_manual_seed(0) - self.hf_config = AutoConfig.from_pretrained( - self.config.hf_model_config_path, trust_remote_code=self.config.trust_remote_code - ) - print(self.hf_config, flush=True) - - self.params_mapping = { - # megatron core gpt model name, huggingface model name - # NOTICE: It's a little bit tricky, when 2 keys have the same prefix, we need to make sure the - # longer key within the containing relationship is processed first. - "embedding.word_embeddings": "model.embed_tokens", - # input layer norm for dpskv3 - "input_layernorm.weight": "input_layernorm.weight", - "input_layernorm.bias": "input_layernorm.bias", - # attn - "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", - "self_attention.linear_qkv.layer_norm_bias": "input_layernorm.bias", - "self_attention.linear_qkv": "self_attn.qkv_proj", - "self_attention.q_layernorm": "self_attn.q_norm", - "self_attention.k_layernorm": "self_attn.k_norm", - "self_attention.linear_proj": "self_attn.o_proj", - # mla - "self_attention.linear_q_proj": "self_attn.q_proj", - "self_attention.linear_q_down_proj": "self_attn.q_a_proj", - "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", - "self_attention.linear_q_up_proj": "self_attn.q_b_proj", - "self_attention.linear_kv_down_proj": "self_attn.kv_a_proj_with_mqa", - "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", - "self_attention.linear_kv_up_proj": "self_attn.kv_b_proj", - # mlp - "pre_mlp_layernorm": "post_attention_layernorm", - "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", - "mlp.linear_fc1.layer_norm_bias": "post_attention_layernorm.bias", - "mlp.linear_fc1": "mlp.gate_up_proj", - "mlp.linear_fc2": "mlp.down_proj", - # moe - "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", - "mlp.router": "mlp.gate", - "mlp.shared_experts.linear_fc1": "mlp.shared_experts.gate_up_proj", - "mlp.shared_experts.linear_fc2": "mlp.shared_experts.down_proj", - "linear_fc1": "gate_up_proj", - "linear_fc2": "down_proj", - # output - "final_layernorm": "norm", - "output_layer": "lm_head", - } - - if "Qwen2MoeForCausalLM" in self.hf_config.architectures: - self.params_mapping["mlp.shared_experts.linear_fc1"] = "mlp.shared_expert.gate_up_proj" - self.params_mapping["mlp.shared_experts.linear_fc2"] = "mlp.shared_expert.down_proj" - self.params_mapping["mlp.shared_experts.gate_weight"] = "mlp.shared_expert_gate.weight" - - def _load_state_dicts(self, model_ckpt_path: str) -> dict[str, Any]: - """_summary_ - Use Megatron dist_checkpointing to load the model state dicts from the checkpoint directory. - - Args: - model_ckpt_path (str): Path to the model checkpoint directory. - - Returns: - State dict containing the model parameters. - """ - - # init hf config - self.pipeline_shards = get_dynamic_pipeline_shards(self.hf_config.num_hidden_layers, self.world_size) - print(f"Pipeline shards: {self.pipeline_shards}, total layers: {sum(self.pipeline_shards)}") - - tf_config = hf_to_mcore_config( - self.hf_config, - torch.bfloat16, - num_layers_in_first_pipeline_stage=self.pipeline_shards[0] if len(self.pipeline_shards) > 1 else None, - num_layers_in_last_pipeline_stage=self.pipeline_shards[-1] if len(self.pipeline_shards) > 2 else None, - ) - tf_config.use_cpu_initialization = self.config.use_cpu_initialization - tie_word_embeddings = getattr(self.hf_config, "tie_word_embeddings", False) - - # init megatron model - def megatron_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=tie_word_embeddings, - value=False, - ) - return parallel_model - - context: Callable[..., ContextManager] = ( - init_empty_weights if self.config.use_cpu_initialization else noop_context - ) - with context(): - whole_model = get_model( - model_provider_func=megatron_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - transformer_config=tf_config, - ) - - if self.config.use_cpu_initialization: - # convert meta device to empty tensor so it can use `copy_` function - whole_model[0].module = whole_model[0].module.to_empty(device="cpu") - - # load state dicts - sharded_state_dict = {} - for vpp_rank, model in enumerate(whole_model): - key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" - mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) - sharded_state_dict[key] = model.sharded_state_dict() - model_state_dict = load_dist_checkpointing(sharded_state_dict, model_ckpt_path) - model_state_dict_list = [] - for vpp_rank, model in enumerate(whole_model): - key = f"model{vpp_rank}" if len(whole_model) > 1 else "model" - mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) - model_state_dict_list.append(model_state_dict[key]) - - return model_state_dict_list - - def _check_megatron_state_key(self, key: str) -> bool: - """ - Checks if the key is a valid Megatron state key. - - Now the model merger only supports keys that start with "decoder/embedding/output_layer" in TransformerLayer. - Shall not use key starts with "model." - """ - if key.startswith("model."): - raise ValueError( - f"Invalid key {key} in Megatron state_dict. Expected keys to start with " - f"'decoder/embedding/output_layer' in TransformerLayer." - ) - - skip_checking_keys = ["embedding.word_embeddings", "output_layer"] - for skip_key in skip_checking_keys: - if skip_key in key: - print(f"skip checking key {key}") - return - - # Exclude extra state keys - if not key.startswith("decoder"): - raise ValueError( - f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder' in TransformerLayer." - ) - - def _split_tensors( - self, key: str, tensor: torch.Tensor, config: PretrainedConfig, is_value_model: bool = False - ) -> list[torch.Tensor]: - """ - Splits a tensor into multiple tensors based on the name. - This is used to handle qkv and gate_up tensors. - """ - if "linear_fc1.weight" in key: - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - gate, up = tensor.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - return [gate, up] - elif "self_attention.linear_qkv." in key and "layer_norm" not in key: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. - q_lst, k_lst, v_lst = [], [], [] - assert config.num_attention_heads % config.num_key_value_heads == 0 - num_q_per_kv = config.num_attention_heads // config.num_key_value_heads - assert tensor.shape[0] % (num_q_per_kv + 2) == 0, ( - f"Tensor shape {tensor.shape} is not divisible by {num_q_per_kv + 2}" - ) - kv_size = tensor.shape[0] // (num_q_per_kv + 2) - split_size = [kv_size * num_q_per_kv, kv_size, kv_size] - - num_query_groups_per_partition = config.num_key_value_heads - for chunk in tensor.chunk(num_query_groups_per_partition): - split_size = [ - kv_size * num_q_per_kv // num_query_groups_per_partition, - kv_size // num_query_groups_per_partition, - kv_size // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - - return [torch.cat(q_lst, dim=0), torch.cat(k_lst, dim=0), torch.cat(v_lst, dim=0)] - else: - return [tensor] - - def _merge_state_dicts(self, model_state_dict_list: list[dict[str, Any]]) -> dict[str, torch.Tensor]: - state_dict = {} - layers_cum = 0 - if self.world_size > 1: - pipeline_cumsum = np.cumsum(self.pipeline_shards) - layers_cum = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] - - print(f"{layers_cum=}") - for model_state_dict in model_state_dict_list: - layers_handled = 0 - keys = model_state_dict.keys() - for key in keys: - if "extra_state" in key: - continue - if self.config.tie_word_embedding and ("output_layer" in key): - print("skip lm_head and reward_head loading because of tie_word_embeddings") - continue - - self._check_megatron_state_key(key) - hf_name = self._replace_name(key, self.params_mapping) - assert hf_name is not None, f"Failed to convert layer name [{key}] from megatron to huggingface." - if "model.layers." in hf_name: - local_layer_no = int(hf_name.split(".")[2]) - layers_handled = max(local_layer_no, layers_handled) - global_layer_no = local_layer_no + layers_cum - new_key_list = hf_name.split(".") - new_key_list[2] = str(global_layer_no) - hf_name = ".".join(new_key_list) - else: - warnings.warn(f"hf_name {hf_name} will not be fixed with layer number", stacklevel=2) - - if "mlp.experts." in hf_name and ".weight" in hf_name: - name_prefix, expert_id = hf_name.split(".weight") - for proj in ["gate_up", "down"]: - if f"{proj}_proj" in hf_name: - hf_name = hf_name.replace( - f"mlp.experts.{proj}_proj.weight{expert_id}", - f"mlp.experts.{expert_id}.{proj}_proj.weight", - ) - - tensor = model_state_dict[key] - split_tensor = self._split_tensors( - key, tensor, self.hf_config, is_value_model=self.config.is_value_model - ) - - if len(split_tensor) == 1: - state_dict[hf_name] = split_tensor[0] - elif len(split_tensor) == 3: - # split qkv - for n, d in zip(["q", "k", "v"], split_tensor, strict=True): - state_dict[hf_name.replace("qkv", n)] = d - elif len(split_tensor) == 2: - # split gate up - state_dict[hf_name.replace("gate_up", "gate")] = split_tensor[0] - state_dict[hf_name.replace("gate_up", "up")] = split_tensor[1] - shape_info = ( - split_tensor.shape if isinstance(split_tensor, torch.Tensor) else [t.shape for t in split_tensor] - ) - print(f"converted {key} to {hf_name} with shape {shape_info}") - - layers_cum += layers_handled + 1 # zero based - - return state_dict - - def save_hf_model_and_tokenizer(self, merged_state_dict): - if self.world_size == 1: - return super().save_hf_model_and_tokenizer(merged_state_dict) - - from safetensors.torch import save_file - - layer_num = self.hf_config.num_hidden_layers - - # FIXME: make configurable - saves_per_layer = 1 if layer_num < 30 else 2 - saves_total = saves_per_layer * layer_num - saves_indexes = {} - - # calculate the layer start index and key chunks - layer_this_rank = self.pipeline_shards[self.rank] - pipeline_cumsum = np.cumsum(self.pipeline_shards) - layer_start = 0 if self.rank == 0 else pipeline_cumsum[self.rank - 1] - keys = list(merged_state_dict.keys()) - keys_chunk = np.array_split(np.array(keys), layer_this_rank * saves_per_layer) - numel = 0 - - assert len(keys_chunk) == layer_this_rank * saves_per_layer, ( - f"Expected {len(keys_chunk)} chunks, but got {layer_this_rank * saves_per_layer} for rank {self.rank}." - ) - - # save to model shards manually - target_dir = Path(self.config.target_dir) - for i, keys in enumerate(keys_chunk): - sd_to_save = {k: merged_state_dict[k] for k in keys} - numel += sum([sd_to_save[i].numel() for i in sd_to_save]) - save_idx = layer_start * saves_per_layer + i - save_path = target_dir / f"model-{save_idx + 1:05d}-of-{saves_total:05d}.safetensors" - - save_file(sd_to_save, save_path) - for k in keys: - saves_indexes[k] = str(save_path.name) - - tensor = torch.tensor([numel]).to(get_device_name()) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - numel = tensor.cpu().item() - - all_save_indexes = [{} for _ in range(self.world_size)] - dist.all_gather_object(all_save_indexes, saves_indexes) - saves_indexes = {k: v for i in all_save_indexes for k, v in i.items()} - if self.rank == 0: - with open(target_dir / "model.safetensors.index.json", "w") as f: - json.dump( - { - "metadata": { - "total_size": numel, - }, - "weight_map": saves_indexes, - }, - f, - indent=4, - ) - print(f"model saved to {target_dir} with {numel=}") - - self.model_config.save_pretrained(self.config.target_dir) - - processor = hf_processor(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) - tokenizer = hf_tokenizer(self.hf_model_config_path, trust_remote_code=self.config.trust_remote_code) - if processor is not None: - print(f"Saving processor to {self.config.target_dir}") - processor.save_pretrained(self.config.target_dir) - if tokenizer is not None: - print(f"Saving tokenizer to {self.config.target_dir}") - tokenizer.save_pretrained(self.config.target_dir) - - def merge_and_save(self): - from verl.utils.megatron_utils import get_dist_checkpoint_path - - model_ckpt_path = get_dist_checkpoint_path(self.config.local_dir) - - model_state_dict = self._load_state_dicts(model_ckpt_path) - merged_state_dict = self._merge_state_dicts(model_state_dict) - del model_state_dict - - if self.config.operation == "test": - if not self.config.test_hf_dir: - raise ValueError("test_hf_dir must be provided for test operation") - self._validate_state_dict(merged_state_dict) - elif self.config.operation == "merge": - self.save_hf_model_and_tokenizer(merged_state_dict) - if self.config.hf_upload: - self.upload_to_huggingface() - else: - raise ValueError(f"Unknown operation: {self.config.operation}") - - def _validate_state_dict(self, state_dict: dict[str, torch.Tensor]): - """ - Compares the merged Megatron state_dict against a reference safetensors model. - Applies necessary name mappings from Megatron to Hugging Face conventions using _replace_name. - """ - ref_state_dict = load_file(Path(self.config.test_hf_dir) / "model.safetensors") - - for name, loaded_weight in state_dict.items(): - # name = self._replace_name(original_name, self.params_mapping) - if not name or name.endswith(".bias") and name not in ref_state_dict: - continue - if "rotary_emb.inv_freq" in name: - continue - if "lm_head.weight" in name: - if self.config.is_value_model or self.config.tie_word_embedding: - continue - if name not in ref_state_dict: - raise RuntimeError(f"key: {name} not exist in state_dict") - param = ref_state_dict[name] - assert loaded_weight.dtype == param.dtype - torch.testing.assert_close(loaded_weight.to("cpu"), param, atol=1e-2, rtol=5e-2) - - def _replace_name(self, megatron_name: str, name_mapping: dict[str, str]) -> str: - for m_name, v_name in name_mapping.items(): - if m_name not in megatron_name: - continue - - megatron_name = megatron_name.replace("decoder", "model") - param_name = megatron_name.replace(m_name, v_name) - - return param_name - - return None # Return None if no mapping found - - def cleanup(self): - torch.distributed.destroy_process_group() diff --git a/verl/models/README.md b/verl/models/README.md deleted file mode 100644 index 677b92f38..000000000 --- a/verl/models/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# Models -Common modelzoo such as huggingface/transformers stuggles when using Pytorch native model parallelism. Following the design principle of vLLM, we keep a simple, parallelizable, highly-optimized with packed inputs in verl. -## Adding a New Huggingface Model -### Step 1: Copy the model file from HF to verl -- Add a new file under verl/models/hf -- Copy ONLY the model file from huggingface/transformers/models to verl/models/hf - -### Step 2: Modify the model file to use packed inputs -- Remove all the code related to inference (kv cache) -- Modify the inputs to include only - - input_ids (total_nnz,) - - cu_seqlens (total_nnz + 1,) - - max_seqlen_in_batch: int -- Note that this requires using flash attention with causal mask. - -### Step 2.5: Add tests -- Add a test to compare this version and the huggingface version -- Following the infrastructure and add tests to tests/models/hf - -### Step 3: Add a function to apply tensor parallelism -- Please follow - - https://pytorch.org/docs/stable/distributed.tensor.parallel.html - - https://pytorch.org/tutorials/intermediate/TP_tutorial.html -- General comments - - Tensor Parallelism in native Pytorch is NOT auto-parallelism. The way it works is to specify how model parameters and input/output reshards using configs. These configs are then registered as hooks to perform input/output resharding before/after model forward. - -### Step 4: Add a function to apply data parallelism -- Please use FSDP2 APIs -- See demo here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L413 - -### Step 5: Add a function to apply pipeline parallelism -- Comes in Pytorch 2.4 -- Currently only in alpha in nightly version -- Check torchtitan for more details - diff --git a/verl/models/__init__.py b/verl/models/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/models/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/llama/__init__.py b/verl/models/llama/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/models/llama/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/llama/megatron/__init__.py b/verl/models/llama/megatron/__init__.py deleted file mode 100644 index fc851ea43..000000000 --- a/verl/models/llama/megatron/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .modeling_llama_megatron import ( - ParallelLlamaForCausalLM, - # rmpad with megatron - ParallelLlamaForCausalLMRmPad, - # rmpad with megatron and pipeline parallelism - ParallelLlamaForCausalLMRmPadPP, - ParallelLlamaForValueRmPad, - ParallelLlamaForValueRmPadPP, - # original model with megatron - ParallelLlamaModel, -) - -__all__ = [ - "ParallelLlamaForCausalLM", - "ParallelLlamaForCausalLMRmPad", - "ParallelLlamaForCausalLMRmPadPP", - "ParallelLlamaForValueRmPad", - "ParallelLlamaForValueRmPadPP", - "ParallelLlamaModel", -] diff --git a/verl/models/llama/megatron/checkpoint_utils/__init__.py b/verl/models/llama/megatron/checkpoint_utils/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/models/llama/megatron/checkpoint_utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader.py deleted file mode 100644 index dafecfdf0..000000000 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_llama( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False -): - """Load merged state_dict to sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def fetch_params(module): - for param in module.parameters(): - torch.distributed.fetch( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( - f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " - f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" - ) - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.model.layers) == num_layers_per_model - - def _fetch_tensor(tensor, name) -> torch.Tensor: - """fetch tensor""" - nonlocal state_dict - if tensor is not None: - tensor.data.copy_(state_dict[name]) - - def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """fetch tensor in tp shards""" - nonlocal state_dict - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - if tensor is not None: - tensor.data.copy_(tensor_chunk[tp_rank]) - else: - print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """fetch tensor in tp shards""" - nonlocal state_dict - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - if tensor is not None: - tensor.data.copy_(tensor_chunk[tp_rank]) - else: - print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """fetch gate_up tensor in tp shards""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if gate_name in state_dict and up_name in state_dict: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - if tensor is not None: - tensor.data.copy_(tensor_chunk[tp_rank]) - else: - print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: - """fetch tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) - - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - if tensor is not None: - tensor.data.copy_(tensor_chunk[tp_rank]) - - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - num_layer_per_pp = config.num_hidden_layers // pp_size - vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - - layer_list = [] - if vpp_size is not None: - for vpp_rank in range(vpp_size): - num_layer_vpp_chunk = num_layer_per_pp // vpp_size - num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( - mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk - ) - layer_list.extend(list(range(offset, offset + num_layer_this_model))) - else: - num_layer_this_model = num_layer_per_pp - offset = pp_rank * num_layer_per_pp - layer_list.extend(list(range(offset, offset + num_layer_this_model))) - - for layer in layer_list: - print_rank_0(f"loading layer #{layer}...") - layer_name = f"model.layers.{layer}" - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[dst_layer_idx] - - _fetch_tensor( - sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - _fetch_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - - _fetch_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - - _fetch_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _fetch_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _fetch_tp_shard_tensor( - sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _fetch_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - ) - - print_rank_0("loading lm_head...") - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.lm_head.weight - - if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _fetch_tensor(lm_head_weight, "lm_head.weight") - print_rank_0("load lm_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _fetch_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - else: - _fetch_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - else: - _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") - - dist.barrier() - get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py b/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py deleted file mode 100644 index 2f65bc6b1..000000000 --- a/verl/models/llama/megatron/checkpoint_utils/llama_loader_depracated.py +++ /dev/null @@ -1,458 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - print(f"get megatron data parallel size: {mpu.get_data_parallel_world_size()}") - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_llama( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False -): - """Load merged state_dict to sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def broadcast_params(module): - for param in module.parameters(): - torch.distributed.broadcast( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( - f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size " - f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" - ) - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.model.layers) == num_layers_per_model - - def _broadcast_tensor(tensor, name) -> torch.Tensor: - """broadcast tensor from rank0 across mp_group""" - nonlocal state_dict - nonlocal mp_group - if torch.distributed.get_rank() == 0: - if name in state_dict: - weight = state_dict[name] - tensor_shape = weight.shape - else: - tensor_shape = None - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not in state_dict, skip load") - return - - if tensor is None: - tensor = torch.empty( - tensor_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - if torch.distributed.get_rank() == 0: - tensor.data.copy_(weight) - dist.broadcast(tensor, src=0, group=mp_group) - - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " - f"{tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( - torch.cat([q_part, k_part, v_part], dim=0) - ) - - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( - torch.cat([q_part, k_part, v_part], dim=0) - ) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - for layer in range(config.num_hidden_layers): - print_rank_0(f"loading layer #{layer}...") - layer_name = f"model.layers.{layer}" - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[dst_layer_idx] - - _broadcast_tensor( - sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - - _broadcast_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - ) - - print_rank_0("loading lm_head...") - lm_head_weight = None - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.lm_head.weight - - if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "lm_head.weight") - print_rank_0("load lm_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - else: - _broadcast_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - else: - _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") - dist.barrier() - # Broadcast weights inside data parallel groups - for wrapped_model in wrapped_models: - broadcast_params(wrapped_model) - - get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py b/verl/models/llama/megatron/checkpoint_utils/llama_saver.py deleted file mode 100644 index 595efcde3..000000000 --- a/verl/models/llama/megatron/checkpoint_utils/llama_saver.py +++ /dev/null @@ -1,442 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist -from megatron.core import mpu -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from megatron.core.transformer.module import Float16Module -from torch.nn.parallel import DistributedDataParallel as torchDDP - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.logger import print_rank_0 -from verl.utils.megatron_utils import unwrap_model - - -def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): - """given TP,DP,PP rank to get the global rank.""" - - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( - f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" - ) - # We only support TP-DP-PP grouping, for correctness when resharding - return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def merge_megatron_ckpt_llama(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - """Merge sharded parameters of a Megatron module into a merged checkpoint. - - Args: - wrapped_models (list of megatron.core.distributed.DistributedDataParallel): - The local DDP wrapped megatron modules. - config (str or None): - HF config for model - dtype: model params type - is_value_model: if model is value model - tie_word_embeddings: tie_word_embeddings, not used in llama, only to keep same interface with qwen2 - Returns: - state_dict (dict): - The merged state_dict in rank 0, and an empty dictionary in other ranks. - """ - start_time = time.time() - - def _get_gpt_model(model): - return model - - dp_rank = mpu.get_data_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if dist.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].model.layers), num_layers_per_model - ) - ) - - state_dict = dict() - - def _get_cpu_tensor(tensor: torch.Tensor): - if tensor is None: - return None - if tensor.device == torch.device("cpu"): - return tensor.detach().clone() - return tensor.detach().cpu() - - def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: - """broadcast tensor across mp_group""" - nonlocal state_dict - nonlocal mp_group - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - if torch.distributed.get_rank() == src_rank: - if tensor is None: - weight = None - tensor_shape = None - else: - weight = tensor - tensor_shape = weight.shape - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not exist, skip collect") - return - - if weight is None: - weight = torch.empty( - tensor_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - dist.broadcast(weight, src=src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - state_dict[name] = _get_cpu_tensor(weight) - - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=concat_dim) - if mutate_func is not None: - full_tensor = mutate_func(full_tensor) - state_dict[name] = full_tensor - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) - state_dict[up_name] = torch.cat(up_weight_list, dim=0) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - q_weight_list = [] - k_weight_list = [] - v_weight_list = [] - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp : total_size] - q_weight_list.append(q_part) - k_weight_list.append(k_part) - v_weight_list.append(v_part) - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp : total_size] - q_weight_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_weight_list.append(k_part) - v_weight_list.append(v_part) - - state_dict[q_name] = torch.cat(q_weight_list, dim=0) - state_dict[k_name] = torch.cat(k_weight_list, dim=0) - state_dict[v_name] = torch.cat(v_weight_list, dim=0) - - # empty cache before collecting weights - get_torch_device().empty_cache() - # Embeddings - # ------------------- - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("collecting embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - _broadcast_tp_shard_tensor( - gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, - "model.embed_tokens.weight", - src_pp_rank=0, - ) - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - for layer in range(config.num_hidden_layers): - print_rank_0(f"collecting layer #{layer}...") - layer_name = f"model.layers.{layer}" - src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[src_layer_idx] - - _broadcast_tensor( - sync_layer.input_layernorm.weight, - f"{layer_name}.input_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight, - f"{layer_name}.self_attn.o_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - _broadcast_tensor( - sync_layer.post_attention_layernorm.weight, - f"{layer_name}.post_attention_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.down_proj.weight, - f"{layer_name}.mlp.down_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - # Final Layernorm - # ------------------- - print_rank_0("collecting final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - src_pp_rank=pp_size - 1, - ) - - print_rank_0("collecting lm_head...") - - if is_value_model: - if pp_rank == pp_size - 1: - print(f"gpt_model_module.lm_head.weight: {gpt_model_module.lm_head.weight.shape}") - _broadcast_tensor( - gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - _broadcast_tensor( - gpt_model_module.reward_head.weight - if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None - else None, - "reward_head.weight", - src_pp_rank=pp_size - 1, - ) - - else: - _broadcast_tp_shard_tensor( - getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - - dist.barrier() - - get_torch_device().empty_cache() - if torch.distributed.get_rank() == 0: - if dtype not in [torch.float16, torch.bfloat16, torch.float32]: - print(f'Unknown/unsupported dtype to save: {dtype}"') - exit(1) - for k, v in state_dict.items(): - if dtype != v.dtype: - state_dict[k] = v.to(dtype) - - print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") - return state_dict diff --git a/verl/models/llama/megatron/layers/__init__.py b/verl/models/llama/megatron/layers/__init__.py deleted file mode 100644 index 352bc5608..000000000 --- a/verl/models/llama/megatron/layers/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .parallel_attention import ParallelLlamaAttention -from .parallel_decoder import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad -from .parallel_linear import ( - LinearForLastLayer, - MergedColumnParallelLinear, - QKVParallelLinear, -) -from .parallel_mlp import ParallelLlamaMLP -from .parallel_rmsnorm import ParallelLlamaRMSNorm - -__all__ = [ - "LinearForLastLayer", - "MergedColumnParallelLinear", - "QKVParallelLinear", - "ParallelLlamaAttention", - "ParallelLlamaDecoderLayer", - "ParallelLlamaDecoderLayerRmPad", - "ParallelLlamaMLP", - "ParallelLlamaRMSNorm", -] diff --git a/verl/models/llama/megatron/layers/parallel_attention.py b/verl/models/llama/megatron/layers/parallel_attention.py deleted file mode 100644 index e8aacbdb7..000000000 --- a/verl/models/llama/megatron/layers/parallel_attention.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Optional - -import torch -import torch.nn.functional as F -from einops import rearrange -from flash_attn.layers.rotary import apply_rotary_emb -from megatron.core import ModelParallelConfig, tensor_parallel -from megatron.core import parallel_state as mpu -from torch import nn -from transformers import LlamaConfig -from transformers.utils import is_flash_attn_2_available - -from verl.models.llama.megatron.layers.parallel_linear import QKVParallelLinear -from verl.utils.megatron import tensor_parallel as tp_utils - - -class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -class LlamaLlama3ScalingRotaryEmbedding(LlamaRotaryEmbedding): - def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device=None): - super().__init__(dim, max_position_embeddings, base, device) - - self.factor = config.rope_scaling["factor"] # `8` in the original implementation - self.high_freq_factor = config.rope_scaling["high_freq_factor"] # `1` in the original implementation - self.low_freq_factor = config.rope_scaling["low_freq_factor"] # `4` in the original implementation - self.old_context_len = config.rope_scaling[ - "original_max_position_embeddings" - ] # `8192` in the original implementation - - low_freq_wavelen = self.old_context_len / self.low_freq_factor - high_freq_wavelen = self.old_context_len / self.high_freq_factor - - wavelen = 2 * math.pi / self.inv_freq - # wavelen < high_freq_wavelen: do nothing; wavelen > low_freq_wavelen: divide by factor - inv_freq_llama = torch.where(wavelen > low_freq_wavelen, self.inv_freq / self.factor, self.inv_freq) - # otherwise: interpolate between the two, using a smooth factor - smooth_factor = (self.old_context_len / wavelen - self.low_freq_factor) / ( - self.high_freq_factor - self.low_freq_factor - ) - smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama - is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) - inv_freq = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class ParallelLlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config = config - self.megatron_config = megatron_config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - - # assign values after tp - tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, ( - f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" - ) - assert self.num_key_value_heads % tp_size == 0, ( - f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" - f"{self.num_key_value_heads}, tp_size={tp_size}" - ) - - self.num_heads_per_tp = self.num_heads // tp_size - self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size - self.hidden_size_per_tp = self.hidden_size // tp_size - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads})." - ) - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() - - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - assert row_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) - tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) - - # [self.q_size, self.k_size, self.v_size] - self.qkv_proj = QKVParallelLinear( - input_size=self.hidden_size, - num_heads=self.num_heads, - num_key_value_heads=self.num_key_value_heads, - head_dim=self.head_dim, - bias=config.attention_bias, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - self.q_size = self.num_heads_per_tp * self.head_dim - self.k_size = self.num_key_value_heads_per_tp * self.head_dim - self.v_size = self.num_key_value_heads_per_tp * self.head_dim - - self.o_proj = tensor_parallel.RowParallelLinear( - input_size=self.num_heads * self.head_dim, - output_size=self.hidden_size, - bias=config.attention_bias, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs, - ) - - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - rope_type_key = "type" if "type" in self.config.rope_scaling else "rope_type" - scaling_type = self.config.rope_scaling[rope_type_key] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "llama3": - self.rotary_emb = LlamaLlama3ScalingRotaryEmbedding( - self.head_dim, - self.config, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) - - query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " - f"but is {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " - f"but is {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) - attn_output = self.o_proj(attn_output)[0] - return attn_output - - -""" -Remove padding Attention -- Using Flash-attn 2 -- Compatible with sequence parallel -""" - - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - -def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): - batch_size = position_ids.shape[0] - - q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) - k = pad_input(k, indices, batch_size, sequence_length) - cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - - q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) - k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) - - return q_embed, k_embed - - -# use flash-attn rotary embeddings with rmpad -# cos/sin shoudl be: (seq_length, rotary_dim / 2) -def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): - q_embed = apply_rotary_emb( - q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - k_embed = apply_rotary_emb( - k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - return q_embed, k_embed - - -class ParallelLlamaAttentionRmPad(ParallelLlamaAttention): - def forward( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen_in_batch: int = None, - ): - total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel - - if self.megatron_config.sequence_parallel: - total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() - - qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split( - [self.q_size, self.k_size, self.v_size], dim=-1 - ) # (total_nnz, 1, hidden_size) - - if self.megatron_config.sequence_parallel: - sequence_parallel_pad = total_nnz - cu_seqlens[-1] - total_nnz = cu_seqlens[-1] # total_nnz before sp padding - query_states = query_states[:total_nnz] - key_states = key_states[:total_nnz] - value_states = value_states[:total_nnz] - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dime x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) - key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - - cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) - cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash( - query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch - ) - # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, - # position_ids, indices, - - # TODO: llama does not have dropout in the config?? - # It is recommended to use dropout with FA according to the docs - # when training. - dropout_rate = 0.0 # if not self.training else self.attn_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_in_batch, - max_seqlen_k=max_seqlen_in_batch, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) - - attn_output_unpad = attn_output_unpad.to(input_dtype) - attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() - - # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled - # Here we need to repad - if self.megatron_config.sequence_parallel: - attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) - - attn_output_unpad = self.o_proj(attn_output_unpad)[0] - return attn_output_unpad diff --git a/verl/models/llama/megatron/layers/parallel_decoder.py b/verl/models/llama/megatron/layers/parallel_decoder.py deleted file mode 100644 index f46e9457c..000000000 --- a/verl/models/llama/megatron/layers/parallel_decoder.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -from megatron.core import ModelParallelConfig -from torch import nn -from transformers import LlamaConfig - -from verl.utils.megatron_utils import TransformerConfig, convert_config - -from .parallel_attention import ParallelLlamaAttention, ParallelLlamaAttentionRmPad -from .parallel_mlp import ParallelLlamaMLP -from .parallel_rmsnorm import ParallelLlamaRMSNorm - - -class ParallelLlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.self_attn = ParallelLlamaAttention(config=config, megatron_config=megatron_config) - - self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) - self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) - self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Note: sequence parallel is hidden inside ColumnParallelLinear - # reduce scatter is hidden inside RowParallelLinear - - # Self Attention - hidden_states = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - # TODO: add sequence parallel operator reduce_scatter here - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - # TODO: add sequence parallel operator all_gather here - - hidden_states = self.mlp(hidden_states) - - # TODO: add sequence parallel operator reduce_scatter here - - hidden_states = residual + hidden_states - - outputs = hidden_states - - return outputs - - -class ParallelLlamaDecoderLayerRmPad(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, layer_idx: int): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.self_attn = ParallelLlamaAttentionRmPad(config=config, megatron_config=megatron_config) - - self.mlp = ParallelLlamaMLP(config, megatron_config=megatron_config) - self.input_layernorm = ParallelLlamaRMSNorm(config, megatron_config) - self.post_attention_layernorm = ParallelLlamaRMSNorm(config, megatron_config) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states # (total_nnz // sp, 1, hidden_size) - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) - # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) - hidden_states = self.self_attn( - hidden_states=hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - # shape changes same as attn - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = hidden_states - - return outputs diff --git a/verl/models/llama/megatron/layers/parallel_linear.py b/verl/models/llama/megatron/layers/parallel_linear.py deleted file mode 100644 index 043726c46..000000000 --- a/verl/models/llama/megatron/layers/parallel_linear.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py - -import torch -from megatron.core import tensor_parallel - - -class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): - def __init__( - self, - input_size, - num_heads, - num_key_value_heads, - head_dim, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs, - ): - # Keep input parameters, and already restrict the head numbers - self.input_size = input_size - self.q_output_size = num_heads * head_dim - self.kv_output_size = num_key_value_heads * head_dim - self.head_dim = head_dim - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - input_size = self.input_size - output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim - - super().__init__( - input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs, - ) - - -class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): - def __init__( - self, - input_size, - gate_ouput_size, - up_output_size, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs, - ): - # Keep input parameters, and already restrict the head numbers - self.input_size = input_size - self.output_size = gate_ouput_size + up_output_size - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - super().__init__( - input_size=self.input_size, - output_size=self.output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs, - ) - - -class LinearForLastLayer(torch.nn.Linear): - def __init__( - self, - input_size, - output_size, - *, - config, - bias=True, - ): - super().__init__(in_features=input_size, out_features=output_size, bias=bias) - self.sequence_parallel = config.sequence_parallel - if self.sequence_parallel: - self.weight.sequence_parallel = True - - def forward( - self, - input_, - weight=None, - runtime_gather_output=None, - ): - logits = super().forward(input_) - logits = logits.float() - if self.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits, None diff --git a/verl/models/llama/megatron/layers/parallel_mlp.py b/verl/models/llama/megatron/layers/parallel_mlp.py deleted file mode 100644 index 583a317eb..000000000 --- a/verl/models/llama/megatron/layers/parallel_mlp.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from megatron.core import ModelParallelConfig, tensor_parallel -from megatron.core import parallel_state as mpu -from torch import nn -from transformers.activations import ACT2FN - -from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear -from verl.utils.megatron import tensor_parallel as tp_utils - - -class ParallelLlamaMLP(nn.Module): - def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() - - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - assert row_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) - tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) - - tp_size = mpu.get_tensor_model_parallel_world_size() - - self.gate_up_proj = MergedColumnParallelLinear( - input_size=self.hidden_size, - gate_ouput_size=self.intermediate_size, - up_output_size=self.intermediate_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - self.gate_size = self.intermediate_size // tp_size - - self.down_proj = tensor_parallel.RowParallelLinear( - input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=False, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs, - ) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - gate_up = self.gate_up_proj(x)[0] - gate, up = gate_up.split(self.gate_size, dim=-1) - return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/verl/models/llama/megatron/layers/parallel_rmsnorm.py b/verl/models/llama/megatron/layers/parallel_rmsnorm.py deleted file mode 100644 index bc2e9ae36..000000000 --- a/verl/models/llama/megatron/layers/parallel_rmsnorm.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numbers - -import torch -from apex.normalization.fused_layer_norm import fused_rms_norm_affine -from megatron.core import ModelParallelConfig -from torch import nn -from transformers import LlamaConfig - -from verl.utils.megatron import sequence_parallel as sp_utils - - -class ParallelLlamaRMSNorm(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - if isinstance(config.hidden_size, numbers.Integral): - normalized_shape = (config.hidden_size,) - self.normalized_shape = torch.Size(normalized_shape) - self.weight = nn.Parameter(torch.ones(self.normalized_shape)) - self.variance_epsilon = config.rms_norm_eps - - if megatron_config.sequence_parallel: - sp_utils.mark_parameter_as_sequence_parallel(self.weight) - - def forward(self, hidden_states): - return fused_rms_norm_affine( - input=hidden_states, - weight=self.weight, - normalized_shape=self.normalized_shape, - eps=self.variance_epsilon, - memory_efficient=True, - ) diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py deleted file mode 100644 index ed5022e0c..000000000 --- a/verl/models/llama/megatron/modeling_llama_megatron.py +++ /dev/null @@ -1,688 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch LLaMA model with Megatron-style acceleration.""" - -from typing import Optional - -import torch -import torch.utils.checkpoint -from megatron.core import ModelParallelConfig, mpu, tensor_parallel -from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import CausalLMOutputWithPast - -from verl.utils.megatron import sequence_parallel as sp_utils -from verl.utils.megatron import tensor_parallel as tp_utils -from verl.utils.megatron_utils import TransformerConfig, convert_config - -from .layers import ParallelLlamaDecoderLayer, ParallelLlamaDecoderLayerRmPad, ParallelLlamaRMSNorm - -""" -TODO: -1. Add weight initialization. Here we need to be careful on TP weight init. -2. Add sequence parallel -3. Load checkpoint from meta LLama pretrained checkpoint -""" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class ParallelLlamaModel(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - - self.layers = nn.ModuleList( - [ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] - ) - self.norm = ParallelLlamaRMSNorm(config, megatron_config) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (batch_size, seq_length) - attention_mask: attention_mask. shape (batch_size, seq_length) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - batch_size, seq_length = input_ids.shape - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) - - hidden_states = inputs_embeds - - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelLlamaForCausalLM(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.model = ParallelLlamaModel(config, megatron_config=megatron_config) - self.vocab_size = config.vocab_size - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - hidden_states = outputs - logits = self.lm_head(hidden_states)[0] - - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - logits = logits.float() - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - -class ParallelLlamaModelRmPad(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - self.megatron_config = megatron_config - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - - self.layers = nn.ModuleList( - [ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] - ) - self.norm = ParallelLlamaRMSNorm(config, megatron_config) - - def forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (1, totol_nnz) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) - - # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) - inputs_embeds = inputs_embeds.transpose(0, 1) - if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - - hidden_states = inputs_embeds - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelLlamaForCausalLMRmPad(nn.Module): - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.megatron_config = megatron_config - self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config) - self.vocab_size = config.vocab_size - self._init_head(config) - - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def _forward_head(self, hidden_states): - # all_gather from sequence parallel region is performed inside lm_head - logits = self.lm_head(hidden_states)[0] - logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) - return logits - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - batch_size, sequence_length = input_ids.shape - - # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( - input_ids.unsqueeze(dim=-1), attention_mask - ) # (total_nnz, 1) - - # pad input_ids to multiple of tp for all tp ranks - # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap - if self.megatron_config.sequence_parallel: - input_ids = sp_utils.pad_to_sequence_parallel(input_ids) - - input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) - - outputs = self.model( - input_ids=input_ids, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = outputs - - logits = self._forward_head(hidden_states) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension - # add removed padding back - logits = pad_input( - logits, indices, batch_size, seqlen=sequence_length - ) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad): - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) - # lm_head is effectively the same as sequence parallel - sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) - - def _forward_head(self, hidden_states): - logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) - logits = logits.float() - if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids, attention_mask, position_ids) - output.logits = torch.squeeze(output.logits, dim=-1) - return output - - -""" -Support pipeline parallelism -""" - - -class ParallelLlamaModelRmPadPP(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - This model definition supports pipeline parallelism. To support pp and vpp, - - This model only contains layer in this pp stage and vpp chunk - - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.pre_process = pre_process - self.post_process = post_process - self.megatron_config = megatron_config - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - if pre_process: - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - else: - self.embed_tokens = None - - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = megatron_config.pipeline_model_parallel_size - self.num_layer_per_pp = config.num_hidden_layers // pp_size - vpp_size = megatron_config.virtual_pipeline_model_parallel_size - vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() - - if vpp_size is not None: - self.layers = nn.ModuleList() - self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size - self.num_layer_this_model = self.num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) - else: - self.num_layer_this_model = self.num_layer_per_pp - offset = pp_rank * self.num_layer_per_pp - - self.layers = nn.ModuleList() - for i in range(self.num_layer_this_model): - layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i) - self.layers.add_module(f"{i}", layer) - - if post_process: - self.norm = ParallelLlamaRMSNorm(config, megatron_config) - else: - self.norm = None - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (1, totol_nnz) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - if self.pre_process: - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) - - # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron - # so need to deal with it by handle here: - # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) - inputs_embeds = inputs_embeds.transpose(0, 1) - if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - - hidden_states = inputs_embeds - else: - # self.hidden_states should be passed by Megatron - hidden_states = self.input_tensor - - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = layer_outputs - - if self.post_process: - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelLlamaForCausalLMRmPadPP(nn.Module): - def __init__( - self, - config: LlamaConfig, - megatron_config: ModelParallelConfig, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - ): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.megatron_config = megatron_config - self.model = ParallelLlamaModelRmPadPP( - config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process - ) - assert share_embeddings_and_output_weights is False, ( - "Llama Model not supports sharing embedding and output weights" - ) - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.vocab_size = config.vocab_size - self.pre_process = pre_process - self.post_process = post_process - if post_process: - self._init_head(config) - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - assert len(input_tensor) == 1 - self.model.set_input_tensor(input_tensor[0]) - - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def _forward_head(self, hidden_states): - # all_gather from sequence parallel region is performed inside lm_head - # logits shape before forward_head hidden_states.shape: [4, 32, 4096] - logits = self.lm_head(hidden_states)[0] - # logits shape after forward_head logits.shape: [8, 32, 8] - logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - return logits - - def forward( - self, - # original input - *, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - - # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. - # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model - batch_size, sequence_length = input_ids.shape - # remove padding here - input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( - input_ids.unsqueeze(dim=-1), attention_mask - ) # (total_nnz, 1) - - # pad input_ids to multiple of tp for all tp ranks - # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap - if self.megatron_config.sequence_parallel: - input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) - - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) - - outputs = self.model( - input_ids=input_ids_rmpad, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - if self.post_process: - hidden_states = outputs - # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) - logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input( - logits, indices, batch_size, seqlen=sequence_length - ) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - else: - return outputs - - -class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP): - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) - # lm_head is effectively the same as sequence parallel - sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) - - def _forward_head(self, hidden_states): - logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) - logits = logits.float() - if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits - - def forward( - self, - *, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) - if self.post_process: - output.logits = torch.squeeze(output.logits, dim=-1) - return output - else: - return output diff --git a/verl/models/mcore/__init__.py b/verl/models/mcore/__init__.py deleted file mode 100644 index 29d053177..000000000 --- a/verl/models/mcore/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .registry import ( - get_mcore_forward_fn, - get_mcore_forward_fused_fn, - get_mcore_weight_converter, - hf_to_mcore_config, - init_mcore_model, -) - -__all__ = [ - "hf_to_mcore_config", - "init_mcore_model", - "get_mcore_forward_fn", - "get_mcore_weight_converter", - "get_mcore_forward_fused_fn", -] diff --git a/verl/models/mcore/config_converter.py b/verl/models/mcore/config_converter.py deleted file mode 100644 index 597afcdd1..000000000 --- a/verl/models/mcore/config_converter.py +++ /dev/null @@ -1,360 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# convert huggingface config to mcore transformer config - - -import torch -import torch.nn.functional as F -from megatron.core import parallel_state as mpu -from megatron.core.transformer import MLATransformerConfig, TransformerConfig -from transformers import PretrainedConfig - - -def _get_base_transformer_config( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> dict: - """ - Create a base TransformerConfig with common parameters across different model architectures. - TODO: (ycl) use dataclass or converter config? - - Args: - hf_config: HuggingFace model configuration - dtype: Data type for the model - override_transformer_config_kwargs: Additional parameters to override defaults - - Returns: - TransformerConfig with common parameters - """ - - # Common parallel state parameters - overlap_p2p_comm = ( - mpu.get_virtual_pipeline_model_parallel_world_size() is not None - and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 - ) - batch_p2p_comm = False - - # Base configuration with common parameters - base_config = { - # Model architecture parameters - "num_layers": hf_config.num_hidden_layers, - "hidden_size": hf_config.hidden_size, - "num_attention_heads": hf_config.num_attention_heads, - "num_query_groups": hf_config.num_key_value_heads, - "ffn_hidden_size": hf_config.intermediate_size, - "attention_dropout": hf_config.attention_dropout, - "hidden_dropout": getattr(hf_config, "hidden_dropout", 0.0), - "kv_channels": getattr(hf_config, "head_dim", None), - "layernorm_epsilon": hf_config.rms_norm_eps, - "add_bias_linear": True, - # Activation and normalization - "activation_func": F.silu, - "normalization": "RMSNorm", - "gated_linear_unit": True, - # Data types - "pipeline_dtype": dtype, - "params_dtype": dtype, - "bf16": dtype is torch.bfloat16, - # Parallel configuration - "tensor_model_parallel_size": mpu.get_tensor_model_parallel_world_size(), - "pipeline_model_parallel_size": mpu.get_pipeline_model_parallel_world_size(), - "expert_model_parallel_size": mpu.get_expert_model_parallel_world_size(), - "expert_tensor_parallel_size": mpu.get_expert_tensor_parallel_world_size(), - "virtual_pipeline_model_parallel_size": mpu.get_virtual_pipeline_model_parallel_world_size(), - "context_parallel_size": mpu.get_context_parallel_world_size(), - "overlap_p2p_comm": overlap_p2p_comm, - "batch_p2p_comm": batch_p2p_comm, - "sequence_parallel": mpu.get_tensor_model_parallel_world_size() > 1, - # Common settings - "variable_seq_lengths": True, - "masked_softmax_fusion": True, - "moe_token_dispatcher_type": "alltoall", - } - - # Update with any provided overrides - # override_transformer_config_kwargs as kwargs shall never be none - base_config.update(override_transformer_config_kwargs) - - return base_config - - -def _get_mla_transformer_config( - hf_config: PretrainedConfig, mla_rope_config: dict, dtype: torch.dtype, **override_transformer_config_kwargs -) -> dict: - """ - Create a MLATransformerConfig with common parameters across different model architectures. - This is specifically for MLA models like DeepseekV3. - - Args: - hf_config: HuggingFace model configuration - mla_rope_config: MLA specific RoPE configuration - dtype: Data type for the model - override_transformer_config_kwargs: Additional parameters to override defaults - - Returns: - MLATransformerConfig with common parameters - """ - base_config = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, **override_transformer_config_kwargs) - mla_config = { - # MLA specific parameters - "q_lora_rank": hf_config.q_lora_rank, - "kv_lora_rank": hf_config.kv_lora_rank, - "qk_head_dim": hf_config.qk_nope_head_dim, - "qk_pos_emb_head_dim": hf_config.qk_rope_head_dim, - "v_head_dim": hf_config.v_head_dim, - "rotary_base": hf_config.rope_theta, - "rotary_scaling_factor": mla_rope_config["factor"], - "rope_type": mla_rope_config["type"], - "max_position_embeddings": mla_rope_config["original_max_position_embeddings"], - "beta_fast": mla_rope_config["beta_fast"], - "beta_slow": mla_rope_config["beta_slow"], - "mscale": mla_rope_config["mscale"], - "mscale_all_dim": mla_rope_config["mscale_all_dim"], - } - - base_config.update(mla_config) - return base_config - - -def hf_to_mcore_config_dense( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - # for LlamaForCausalLM or Qwen2ForCausalLM - qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) - qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False - - args: dict = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - use_cpu_initialization=False, - add_bias_linear=False, - add_qkv_bias=qkv_bias, - qk_layernorm=qk_layernorm, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - print(f"Overridden TF init config: {args}") - return TransformerConfig(**args) - - -def hf_to_mcore_config_qwen2moe( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - args: dict = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - use_cpu_initialization=False, - add_bias_linear=False, - layernorm_epsilon=hf_config.rms_norm_eps, - # MoE specific - moe_ffn_hidden_size=hf_config.moe_intermediate_size, - moe_router_bias_update_rate=0.001, - moe_router_topk=hf_config.num_experts_per_tok, - num_moe_experts=hf_config.num_experts, - moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, - moe_aux_loss_coeff=hf_config.router_aux_loss_coef, - # moe_aux_loss_coeff=0.0, - moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL - moe_shared_expert_overlap=True, - moe_grouped_gemm=True, - moe_router_score_function="softmax", - # Other optimizations - persist_layer_norm=True, - bias_activation_fusion=True, - bias_dropout_fusion=True, - # Qwen specific - moe_router_pre_softmax=True, - add_qkv_bias=True, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - print(f"Overridden TF init config: {args}") - return TransformerConfig(**args) - - -def hf_to_mcore_config_mixtral( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - args: dict = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - use_cpu_initialization=False, - add_bias_linear=False, - layernorm_epsilon=hf_config.rms_norm_eps, - # MoE specific - num_moe_experts=hf_config.num_local_experts, - moe_aux_loss_coeff=hf_config.router_aux_loss_coef, - moe_router_topk=hf_config.num_experts_per_tok, - moe_router_pre_softmax=True, - moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL - moe_router_score_function="softmax", - moe_shared_expert_intermediate_size=None, # mixtral has no shared expert - moe_shared_expert_overlap=False, # mixtral has no shared expert - moe_ffn_hidden_size=hf_config.intermediate_size, - moe_router_bias_update_rate=0.001, - # moe_permute_fusion=True, # need TE 2.1+ - moe_grouped_gemm=True, - # Other optimizations - persist_layer_norm=True, - apply_rope_fusion=True, - bias_activation_fusion=True, - bias_dropout_fusion=True, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - print(f"Overridden TF init config: {args}") - return TransformerConfig(**args) - - -def hf_to_mcore_config_qwen3moe( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - args: dict = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - use_cpu_initialization=False, - add_bias_linear=False, - layernorm_epsilon=hf_config.rms_norm_eps, - # MoE specific - moe_ffn_hidden_size=hf_config.moe_intermediate_size, - moe_router_bias_update_rate=0.001, - moe_router_topk=hf_config.num_experts_per_tok, - num_moe_experts=hf_config.num_experts, - moe_aux_loss_coeff=hf_config.router_aux_loss_coef, - # moe_aux_loss_coeff=0.0, - moe_router_load_balancing_type="none", # turn off aux_loss as it hurts perf in RL - moe_grouped_gemm=True, - moe_router_score_function="softmax", - # Other optimizations - persist_layer_norm=True, - bias_activation_fusion=True, - bias_dropout_fusion=True, - # Qwen specific - moe_router_pre_softmax=False, - qk_layernorm=True, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - print(f"Overridden TF init config: {args}") - return TransformerConfig(**args) - - -def hf_to_mcore_config_dpskv3( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> MLATransformerConfig: - # DeepseekV3ForCausalLM - from megatron.core.transformer.enums import AttnBackend - - from .patch_v012 import apply_patch - - apply_patch() - - mla_rope_config = { - "beta_fast": 32, - "beta_slow": 1, - "factor": 1, - "mscale": 1.0, - "mscale_all_dim": 1.0, - "original_max_position_embeddings": 4096, - "type": "rope", - } - if "rope_scaling" in hf_config and hf_config.rope_scaling is not None: - mla_rope_config.update(hf_config.rope_scaling) - moe_layer_freq = [1] * hf_config.num_hidden_layers - for i in range(min(hf_config.first_k_dense_replace, hf_config.num_hidden_layers)): - moe_layer_freq[i] = 0 - - # disable MTP and quantization for now - if "num_nextn_predict_layers" in hf_config: - assert hf_config.num_nextn_predict_layers == 0, ( - "MTP is not supported for now, please modify the config.json to set num_nextn_predict_layers to 0" - ) - assert "quantization_config" not in hf_config or not hf_config.quantization_config, ( - "quantization is not supported for now, please modify the config.json to remove quantization_config" - ) - - args: dict = _get_mla_transformer_config( - hf_config=hf_config, - mla_rope_config=mla_rope_config, - dtype=dtype, - # Additional parameters - use_cpu_initialization=False, - add_bias_linear=False, - attention_backend=AttnBackend.fused, - qk_layernorm=True, - # Standard MoE parameters - moe_ffn_hidden_size=hf_config.moe_intermediate_size, - moe_token_dispatcher_type="alltoall", - moe_router_bias_update_rate=0.001, - moe_router_enable_expert_bias=True, - moe_router_topk=hf_config.num_experts_per_tok, - num_moe_experts=hf_config.n_routed_experts, - moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts, - moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001), - moe_router_load_balancing_type="seq_aux_loss", - moe_shared_expert_overlap=True, - # moe_permute_fusion=True, # need TE 2.1+ - moe_grouped_gemm=True, - moe_router_score_function="sigmoid", - moe_router_pre_softmax=True, - moe_router_topk_scaling_factor=hf_config.routed_scaling_factor, - moe_layer_freq=moe_layer_freq, - # mcore 0.12 moe - moe_router_dtype="fp64", - disable_bf16_reduced_precision_matmul=True, - # Other optimizations - # deallocate_pipeline_outputs=True, - # gradient_accumulation_fusion=True, - persist_layer_norm=True, - bias_activation_fusion=True, - bias_dropout_fusion=True, - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - transformer_config: MLATransformerConfig = MLATransformerConfig(**args) - print(f"Overridden MLA TF init config: {transformer_config}") - # MTP - if "num_nextn_predict_layers" in hf_config: - transformer_config.mtp_num_layers = hf_config.num_nextn_predict_layers - transformer_config.mtp_loss_scaling_factor = 0.1 - - return transformer_config - - -def hf_to_mcore_config_qwen2_5_vl( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - # Qwen2_5_VLForConditionalGeneration - - args = _get_base_transformer_config( - hf_config=hf_config, - dtype=dtype, - add_bias_linear=False, - # qwen specific - add_qkv_bias=True, - mrope_section=hf_config.rope_scaling["mrope_section"], - ) - # override_transformer_config_kwargs as kwargs shall never be none - args.update(override_transformer_config_kwargs) - print(f"Overridden TF init config: {args}") - return TransformerConfig(**args) - - -def hf_to_mcore_config_llama4( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - # Llama4ForConditionalGeneration - raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet") diff --git a/verl/models/mcore/loader.py b/verl/models/mcore/loader.py deleted file mode 100644 index 659b4baa2..000000000 --- a/verl/models/mcore/loader.py +++ /dev/null @@ -1,492 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - -from .saver import _megatron_calc_global_rank - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False): - """Load merged state_dict to sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def broadcast_params(module): - for param in module.parameters(): - torch.distributed.broadcast( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - cp_rank = mpu.get_context_parallel_rank() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank) - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == src_rank: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.decoder.layers) == num_layers_per_model - - def _broadcast_tensor(tensor, name) -> torch.Tensor: - """broadcast tensor from rank0 across mp_group""" - nonlocal state_dict - nonlocal mp_group - if torch.distributed.get_rank() == src_rank: - if name in state_dict: - weight = state_dict[name] - tensor_shape = weight.shape - else: - tensor_shape = None - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not in state_dict, skip load") - return - - if tensor is None: - tensor = torch.empty( - tensor_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - if torch.distributed.get_rank() == src_rank: - tensor.data.copy_(weight) - dist.broadcast(tensor, src=src_rank, group=mp_group) - - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == src_rank: - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == src_rank: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=src_rank, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == src_rank: - if name in state_dict: - full_weight = state_dict[name] - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == src_rank: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=src_rank, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == src_rank: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape " - f"{tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == src_rank: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=src_rank, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == src_rank: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - - if config.num_key_value_heads >= tp_size: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - sizes = [total_size * tp_size] - if not bias: - sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - num_query_groups_per_partition = models[0].config.num_query_groups // tp_size - new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] - q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0) - k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0) - v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0) - total_size_per_head = total_size // num_query_groups_per_partition - for j in range(num_query_groups_per_partition): - new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) - ) - - else: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - sizes = [total_size * tp_size] - if not bias: - sizes.append(config.hidden_size) - new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv_this_tp = new_weight_qkv[i * total_size : (i + 1) * total_size] - q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0) - k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0) - v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0) - total_size_per_head = total_size // config.num_attention_heads - for j in range(config.num_attention_heads): - new_weight_qkv_this_tp[j * total_size_per_head : (j + 1) * total_size_per_head].copy_( - torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0) - ) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == src_rank: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=src_rank, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - for layer in range(config.num_hidden_layers): - layer_name = f"model.layers.{layer}" - print_rank_0(f"loading layer #{layer}, with layer_name model.layers.{layer}...") - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.decoder.layers[dst_layer_idx] - - _broadcast_tensor( - sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - if f"{layer_name}.self_attn.q_norm.weight" in state_dict: - _broadcast_tensor( - sync_layer.self_attention.q_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_norm.weight", - ) - _broadcast_tensor( - sync_layer.self_attention.k_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.k_norm.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - if f"{layer_name}.self_attn.q_proj.bias" in state_dict: - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - bias=True, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - _broadcast_tensor( - sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.decoder.final_layernorm, "weight", None), - "model.norm.weight", - ) - - print_rank_0("loading lm_head...") - lm_head_weight = None - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.output_layer.weight - - if is_value_model: - # if torch.distributed.get_rank() == src_rank: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "lm_head.weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - else: - _broadcast_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - # else: - - # _broadcast_tensor(lm_head_weight, "lm_head.weight") - - else: - _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") - dist.barrier() - # Broadcast weights inside data parallel groups - for wrapped_model in wrapped_models: - broadcast_params(wrapped_model) - pass - get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/mcore/mbridge.py b/verl/models/mcore/mbridge.py deleted file mode 100644 index 35c32d697..000000000 --- a/verl/models/mcore/mbridge.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -try: - from mbridge import AutoBridge - from mbridge.utils.post_creation_callbacks import freeze_moe_router, make_value_model -except ImportError: - print("mbridge package not found. Please install mbridge with `pip install verl[mcore]` or `pip install mbridge`") - raise - -__all__ = ["AutoBridge", "make_value_model", "freeze_moe_router"] diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py deleted file mode 100644 index e70e11f4e..000000000 --- a/verl/models/mcore/model_forward.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from verl.utils.megatron_utils import unwrap_model - -from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding - - -def gptmodel_forward( - model, - input_ids, - attention_mask, - position_ids, - sequence_parallel, - value_model=False, - pack_seqs=True, - logits_processor=None, - logits_processor_args: dict = None, - **kwargs, -): - """Default forward pass for GPT models with optional sequence packing.""" - pre_process = unwrap_model(model).pre_process - post_process = unwrap_model(model).post_process - if pack_seqs: - batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) - input_ids_rmpad = input_ids_rmpad.contiguous() - output_orig = model( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids, - packed_seq_params=packed_seq_params, - ) - if post_process and logits_processor is not None: - args = { - k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] - for k, v in logits_processor_args.items() - } - output_dict = logits_processor(output_orig, **args) - output = { - k: postprocess_packed_seqs( - v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process - ) - for k, v in output_dict.items() - } - else: - output = postprocess_packed_seqs( - output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process - ) - else: - assert logits_processor is None, "logits_processor is not supported for non-packed sequence" - batch_size, sequence_length = attention_mask.shape - new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( - input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process - ) - output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids) - output = recover_left_padding( - output, new_attention_mask, attention_mask, sequence_length, post_process=post_process - ) - if value_model and post_process: - output = output[..., 0] - return output - - -def gptmodel_forward_qwen2_5_vl( - model, - input_ids, - attention_mask, - position_ids, - sequence_parallel, - value_model=False, - pack_seqs=True, - multi_modal_inputs=None, - logits_processor=None, - logits_processor_args: dict = None, - **kwargs, -): - from megatron.core import parallel_state as mpu - - assert mpu.get_context_parallel_world_size() == 1, "qwen2_5_vl's context parallel is not accurate yet" - pre_process = unwrap_model(model).pre_process - post_process = unwrap_model(model).post_process - pixel_values = ( - multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None - ) - image_grid_thw = ( - multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None - ) - if pack_seqs: - batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) - input_ids_rmpad = input_ids_rmpad.contiguous() - output_orig = model( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids, - packed_seq_params=packed_seq_params, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - ) - - if post_process and logits_processor is not None: - args = { - k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0] - for k, v in logits_processor_args.items() - } - output_dict = logits_processor(output_orig, **args) - output = { - k: postprocess_packed_seqs( - v, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process - ) - for k, v in output_dict.items() - } - else: - output = postprocess_packed_seqs( - output_orig, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process - ) - else: - batch_size, sequence_length = attention_mask.shape - new_input_ids, new_attention_mask, new_position_ids = remove_left_padding( - input_ids, attention_mask, position_ids, sequence_parallel, pre_process=pre_process - ) - output = model( - input_ids=new_input_ids, - position_ids=new_position_ids, - attention_mask=new_attention_mask, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - ) - output = recover_left_padding( - output, new_attention_mask, attention_mask, sequence_length, post_process=post_process - ) - if value_model and post_process: - output = output[..., 0] - return output diff --git a/verl/models/mcore/model_forward_fused.py b/verl/models/mcore/model_forward_fused.py deleted file mode 100644 index fc55ef1b0..000000000 --- a/verl/models/mcore/model_forward_fused.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import OrderedDict -from typing import Optional - -import torch -from megatron.core import parallel_state -from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk -from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.gpt.gpt_model import GPTModel -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region -from torch import Tensor - -from verl.models.mcore.util import preprocess_packed_seqs -from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy -from verl.utils.megatron_utils import unwrap_model -from verl.utils.model import CausalLMOutputForPPO - -from .qwen2_5_vl.model import Qwen2_5VLModel -from .util import postprocess_packed_seqs_for_dict_output - - -def patch_fused_forward(model: torch.nn.Module): - model = unwrap_model(model) - if isinstance(model, GPTModel): - model = model - elif isinstance(model, Qwen2_5VLModel): - if not hasattr(model, "language_model"): - # the qwen2.5vl model might only have vision_model - return - model = model.language_model - else: - raise ValueError("Model is not a GPTModel or Qwen2_5VLModel") - model.forward_backup = model.forward - model.forward = _fused_GPTModel_forward.__get__(model, model.__class__) - return - - -def unpatch_fused_forward(model: torch.nn.Module): - model = unwrap_model(model) - if isinstance(model, GPTModel): - model = model - elif isinstance(model, Qwen2_5VLModel): - model = model.language_model - else: - raise ValueError("Model is not a GPTModel or Qwen2_5VLModel") - model.forward = model.forward_backup - return - - -def fused_forward_gptmodel( - model: GPTModel, - input_ids: Tensor, - position_ids: Tensor, - attention_mask: Tensor, - labels: Tensor, - labels_mask: Tensor, - **kwargs, -): - pre_process: bool = unwrap_model(model).pre_process - post_process: bool = unwrap_model(model).post_process - - batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process) - input_ids_rmpad = input_ids_rmpad.contiguous() - labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) - labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) - labels_rmpad = labels_rmpad.contiguous() - labels_mask_rmpad = labels_mask_rmpad.contiguous() - - output_orig: CausalLMOutputForPPO = model( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids, - labels=labels_rmpad, - packed_seq_params=packed_seq_params, - ) - - if post_process: - # output_orig is in type of CausalLMOutputForPPO - output = postprocess_packed_seqs_for_dict_output( - labels_mask_rmpad, - output_orig, - packed_seq_params, - attention_mask, - batch_size, - seq_len, - post_process=post_process, - ) - else: - output = output_orig - return output - - -def fused_forward_qwen2_5_vl( - model: Qwen2_5VLModel, - input_ids: Tensor, - position_ids: Tensor, - attention_mask: Tensor, - labels: Tensor, - labels_mask: Tensor, - multi_modal_inputs=None, - **kwargs, -): - # pre_process = unwrap_model(model).pre_process - post_process = unwrap_model(model).post_process - - pixel_values = ( - multi_modal_inputs["pixel_values"].to(input_ids.device) if "pixel_values" in multi_modal_inputs else None - ) - image_grid_thw = ( - multi_modal_inputs["image_grid_thw"].to(input_ids.device) if "image_grid_thw" in multi_modal_inputs else None - ) - - batch_size, seq_len = attention_mask.shape[:2] - input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) - labels_rmpad, _ = preprocess_packed_seqs(labels, attention_mask, pre_process=True) - labels_mask_rmpad, _ = preprocess_packed_seqs(labels_mask, attention_mask, pre_process=True) - labels_rmpad = labels_rmpad.contiguous() - labels_mask_rmpad = labels_mask_rmpad.contiguous() - input_ids_rmpad = input_ids_rmpad.contiguous() - output_orig: CausalLMOutputForPPO = model( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids, - packed_seq_params=packed_seq_params, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - labels=labels, - ) - if post_process: - # output_orig is in type of CausalLMOutputForPPO - output = postprocess_packed_seqs_for_dict_output( - labels_mask_rmpad, - output_orig, - packed_seq_params, - attention_mask, - batch_size, - seq_len, - post_process=post_process, - ) - else: - output = output_orig - return output - - -def _fused_GPTModel_forward( - self, - input_ids: Tensor, - position_ids: Tensor, - attention_mask: Tensor, - decoder_input: Tensor = None, - labels: Tensor = None, - inference_context: BaseInferenceContext = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - runtime_gather_output: Optional[bool] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - loss_mask: Optional[Tensor] = None, - temperature: float = 1.0, -) -> CausalLMOutputForPPO: - """ - Forward pass for GPT models with fused kernel support. - - Patch https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py - """ - - # If decoder_input is provided (not None), then input_ids and position_ids are ignored. - # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. - - # Decoder embedding. - if decoder_input is not None: - pass - elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) - else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - decoder_input = None - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - rotary_pos_emb = None - rotary_pos_cos = None - rotary_pos_sin = None - if self.position_embedding_type == "rope" and not self.config.multi_latent_attention: - if not self.training and self.config.flash_decode and inference_context: - assert inference_context.is_static_batching(), "GPTModel currently only supports static inference batching." - # Flash decoding uses precomputed cos and sin for RoPE - rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault( - inference_context.max_sequence_length, - self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length), - ) - else: - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_context, self.decoder, decoder_input, self.config, packed_seq_params - ) - rotary_pos_emb = self.rotary_pos_emb( - rotary_seq_len, - packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == "thd", - ) - elif self.position_embedding_type == "mrope" and not self.config.multi_latent_attention: - if self.training or not self.config.flash_decode: - rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) - else: - # Flash decoding uses precomputed cos and sin for RoPE - raise NotImplementedError( - "Flash decoding uses precomputed cos and sin for RoPE, not implmented in MultimodalRotaryEmbedding yet." - ) - - if ( - (self.config.enable_cuda_graph or self.config.flash_decode) - and rotary_pos_cos is not None - and inference_context - and inference_context.is_static_batching() - and not self.training - ): - sequence_len_offset = torch.tensor( - [inference_context.sequence_len_offset] * inference_context.current_batch_size, - dtype=torch.int32, - device=rotary_pos_cos.device, # Co-locate this with the rotary tensors - ) - else: - sequence_len_offset = None - - # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the - # reference held by this caller function, enabling early garbage collection for - # skip inference - - # Run decoder. - hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=attention_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - **(extra_block_kwargs or {}), - ) - - # Process inference output. - if inference_context and not inference_context.is_static_batching(): - hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1) - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - - if self.mtp_process: - hidden_states = self.mtp( - input_ids=input_ids, - position_ids=position_ids, - labels=labels, - loss_mask=loss_mask, - hidden_states=hidden_states, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - embedding=self.embedding, - output_layer=self.output_layer, - output_weight=output_weight, - runtime_gather_output=runtime_gather_output, - compute_language_model_loss=self.compute_language_model_loss, - **(extra_block_kwargs or {}), - ) - - if not self.post_process: - return hidden_states - - output = CausalLMOutputForPPO( - loss=None, - logits=None, - past_key_values=None, - hidden_states=hidden_states, - attentions=None, - ) - - if self.config.sequence_parallel: - hidden_states = gather_from_sequence_parallel_region(hidden_states) - logprobs, entropy = linear_cross_entropy( - hidden_states, - self.output_layer.weight, - labels, - temperature, - "none", - parallel_state.get_tensor_model_parallel_group(), - ) - - if has_config_logger_enabled(self.config): - payload = OrderedDict( - { - "input_ids": input_ids, - "position_ids": position_ids, - "attention_mask": attention_mask, - "decoder_input": decoder_input, - "logprobs": logprobs, - "entropy": entropy, - } - ) - log_config_to_disk(self.config, payload, prefix="input_and_logits") - - output.entropy = entropy - output.log_probs = logprobs - - return output diff --git a/verl/models/mcore/model_initializer.py b/verl/models/mcore/model_initializer.py deleted file mode 100644 index 4c01b124b..000000000 --- a/verl/models/mcore/model_initializer.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# use mcore transformer config to initialize the model -from abc import ABC, abstractmethod - -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec -from megatron.core.models.gpt.gpt_model import GPTModel - -from .config_converter import PretrainedConfig, TransformerConfig - - -class BaseModelInitializer(ABC): - """Base class for model initializers.""" - - def __init__(self, tfconfig: TransformerConfig, hf_config: PretrainedConfig): - self.tfconfig = tfconfig - self.hf_config = hf_config - - @abstractmethod - def get_transformer_layer_spec(self): - """Get the transformer layer specification. - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_layer_specs.py""" - pass - - def get_rope_scaling_args(self) -> dict: - """Get rope scaling args.""" - rope_scaling_args = {} - if "rope_scaling" in self.hf_config: - if self.hf_config.rope_scaling is not None: - # assert self.hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" - rope_scaling_args["seq_len_interpolation_factor"] = self.hf_config.rope_scaling["factor"] - return rope_scaling_args - - def initialize( - self, - pre_process: bool = True, - post_process: bool = True, - share_embeddings_and_output_weights: bool = False, - value: bool = False, - **extra_kwargs, - ) -> GPTModel: - """Initialize a GPT model with the given configuration. - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/models/gpt/gpt_model.py - - Args: - pre_process (bool): include embedding layer. - post_process (bool): including an output layer. - share_embeddings_and_output_weights (bool): input embeddings and output logit weights are shared. - value (bool): add an extra linear layer for classification or regression. - - Returns: - GPTModel: An initialized GPT model instance - """ - transformer_layer_spec = self.get_transformer_layer_spec() - rope_scaling_args = self.get_rope_scaling_args() - mtp_block_spec = extra_kwargs.get("mtp_block_spec", None) - model = GPTModel( - config=self.tfconfig, - transformer_layer_spec=transformer_layer_spec, - vocab_size=self.hf_config.vocab_size, - max_sequence_length=self.hf_config.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - position_embedding_type="rope", - rotary_base=self.hf_config.rope_theta, - **rope_scaling_args, - mtp_block_spec=mtp_block_spec, - ) - - if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - - model.output_layer = LinearForLastLayer( - input_size=self.tfconfig.hidden_size, output_size=1, config=self.tfconfig - ) - - return model - - -class DenseModel(BaseModelInitializer): - """Initializer for dense models like Llama and Qwen2.""" - - def get_transformer_layer_spec(self): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - return get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) - - -class Qwen2MoEModel(BaseModelInitializer): - """Initializer for Qwen2 MoE models.""" - - def get_transformer_layer_spec(self): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) - - # Patch layer spec for shared experts - for i in range(len(transformer_layer_spec.layer_specs)): - transformer_layer_spec.layer_specs[i].submodules.mlp.submodules.shared_experts.params["gate"] = True - - return transformer_layer_spec - - def initialize(self, **kwargs): - # Qwen default freeze_moe_router: true - model = super().initialize(**kwargs) - freeze_moe_router = kwargs.get("freeze_moe_router", True) - if freeze_moe_router: - for layer in model.decoder.layers: - layer.mlp.router.weight.requires_grad = False - return model - - -class MixtralModel(BaseModelInitializer): - """Initializer for Mixtral models.""" - - def get_transformer_layer_spec(self): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) - return transformer_layer_spec - - def initialize(self, **kwargs): - model = super().initialize(**kwargs) - freeze_moe_router = kwargs.get("freeze_moe_router", False) - if freeze_moe_router: - for layer in model.decoder.layers: - layer.mlp.router.weight.requires_grad = False - return model - - -class Qwen3MoEModel(BaseModelInitializer): - """Initializer for Qwen3 MoE models.""" - - def get_transformer_layer_spec(self): - assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) - return transformer_layer_spec - - def initialize(self, **kwargs): - # Qwen default freeze_moe_router: true - model = super().initialize(**kwargs) - freeze_moe_router = kwargs.get("freeze_moe_router", True) - if freeze_moe_router: - for layer in model.decoder.layers: - layer.mlp.router.weight.requires_grad = False - return model - - -class DeepseekV3Model(BaseModelInitializer): - """Initializer for DeepseekV3 models.""" - - def get_transformer_layer_spec(self): - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) - return transformer_layer_spec - - def get_rope_scaling_args(self) -> dict: - """Get rope scaling args.""" - rope_scaling_args = {} - return rope_scaling_args - - def initialize( - self, - **kwargs, - ): - freeze_moe_router = kwargs.get("freeze_moe_router", True) - if freeze_moe_router: - self.tfconfig.moe_router_load_balancing_type = "none" - # MTP - if self.tfconfig.mtp_num_layers is not None: - transformer_layer_spec = self.get_transformer_layer_spec() - mtp_block_spec = get_gpt_mtp_block_spec(self.tfconfig, transformer_layer_spec, use_transformer_engine=True) - kwargs["mtp_block_spec"] = mtp_block_spec - - model = super().initialize(**kwargs) - if freeze_moe_router: - for layer in model.decoder.layers: - if hasattr(layer.mlp, "router"): - layer.mlp.router.weight.requires_grad = False - return model - - -class Qwen25VLModel(BaseModelInitializer): - """Initializer for Qwen2.5 VL models.""" - - def get_transformer_layer_spec(self): - transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True) - return transformer_layer_spec - - def initialize( - self, - pre_process=None, - post_process=None, - share_embeddings_and_output_weights=False, - value=False, - **extra_kwargs, - ): - tfconfig = self.tfconfig - hf_config = self.hf_config - # Qwen2_5_VLForConditionalGeneration - from copy import deepcopy - - transformer_layer_spec = self.get_transformer_layer_spec() - - from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear - from megatron.core.models.gpt.moe_module_specs import MLPSubmodules - from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec - - from .qwen2_5_vl import Qwen2_5VLModel, get_vision_model_config, get_vision_projection_config - - vision_transformer_config = get_vision_model_config(deepcopy(tfconfig)) - vision_transformer_config.pipeline_model_parallel_size = 1 - vision_transformer_config.first_pipeline_num_layers = None - - vision_projection_config = get_vision_projection_config( - deepcopy(tfconfig), - vision_transformer_config.hidden_size, - spatial_merge_size=hf_config.vision_config.spatial_merge_size, - ) - vision_projection_layer_spec = MLPSubmodules( - linear_fc1=TEColumnParallelLinear, - linear_fc2=TERowParallelLinear, - ) - vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() - - qwen25_vl_model = Qwen2_5VLModel( - language_transformer_config=tfconfig, - language_transformer_layer_spec=transformer_layer_spec, - language_vocab_size=hf_config.vocab_size, - language_max_sequence_length=hf_config.max_position_embeddings, - vision_transformer_config=vision_transformer_config, - vision_transformer_layer_spec=vision_transformer_layer_spec, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_layer_spec, - vision_projection_type="mlp", - language_rotary_base=hf_config.rope_theta, - pre_process=pre_process, - post_process=post_process, - add_decoder=True, - add_encoder=True, - parallel_output=True, - language_share_embeddings_and_output_weights=share_embeddings_and_output_weights, - ) - - if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - - qwen25_vl_model.language_model.output_layer = LinearForLastLayer( - input_size=tfconfig.hidden_size, output_size=1, config=tfconfig - ) - - return qwen25_vl_model diff --git a/verl/models/mcore/patch_v012.py b/verl/models/mcore/patch_v012.py deleted file mode 100644 index d54a3eb34..000000000 --- a/verl/models/mcore/patch_v012.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# there is some bug in mcore 0.12, so we need to patch it -# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None - - -def apply_patch(): - import torch - from megatron.core import parallel_state, tensor_parallel - from megatron.core.transformer.multi_latent_attention import ( - MLASelfAttention, - apply_rotary_pos_emb, - deprecate_inference_params, - gather_from_sequence_parallel_region, - gather_from_tensor_model_parallel_region, - scatter_to_sequence_parallel_region, - ) - - def patch_get_query_key_value_tensors( - self, - hidden_states, - key_value_states=None, - position_ids=None, - packed_seq_params=None, - inference_context=None, - *, - inference_params=None, - ): - """ - Derives `query`, `key` and `value` tensors from `hidden_states`. - """ - # s = sequence length, b = batch size, h = hidden size, n = num attention heads - # Attention heads [s, b, n*h] - assert hidden_states.ndim == 3, f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" - - inference_context = deprecate_inference_params(inference_context, inference_params) - - # ========================================= - # Prepare RoPE and seqlen related params - # ========================================= - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_context, None, hidden_states, self.config, packed_seq_params - ) - - # rotary_pos_emb:[s, b, 1, 64] - mscale = 1.0 - if self.config.rope_type == "rope": - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) - else: - rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len) - - # ========================================= - # QKV down projection and layernorm - # ========================================= - if self.config.q_lora_rank is not None: - # if linear_q_down_proj is ColumnParallelLinear: - # q_compressed: [s, b, q_lora_rank / TP] - # elif linear_q_down_proj is Linear: - # q_compressed: [s / TP, b, q_lora_rank] - q_compressed, _ = self.linear_q_down_proj(hidden_states) - - # When output is sharded (ColumnParallelLinear), two things are needed to be - # identical to a normal Linear. - # 1. Manually gather output to restore output dim q_lora_rank; - # 2. Scatter sequence back to s / TP if sequence-parallel since it was - # gathered by ColumnParallelLinear. - if q_compressed.size(-1) != self.config.q_lora_rank: - q_compressed = gather_from_tensor_model_parallel_region(q_compressed) - if self.config.sequence_parallel: - q_compressed = scatter_to_sequence_parallel_region(q_compressed) - - q_compressed = self.q_layernorm(q_compressed) - else: - q_compressed = hidden_states - - # if linear_kv_down_proj is ColumnParallelLinear: - # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP] - # elif linear_kv_down_proj is Linear: - # kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)] - kv_combined, _ = self.linear_kv_down_proj(hidden_states) - if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim: - # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)] - kv_combined = gather_from_tensor_model_parallel_region(kv_combined) - # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim] - kv_compressed, k_pos_emb = torch.split( - kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 - ) - if self.config.sequence_parallel: - # kv_compressed:[s / TP, b, kv_lora_rank] - kv_compressed = scatter_to_sequence_parallel_region(kv_compressed) - else: - # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim] - kv_compressed, k_pos_emb = torch.split( - kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1 - ) - if parallel_state.get_tensor_model_parallel_world_size() > 1: - # k_pos_emb: [s, b, qk_pos_emb_head_dim] - k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb) - - kv_compressed = self.kv_layernorm(kv_compressed) - - # ========================================= - # QKV up projection and RoPE apply - # ========================================= - def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb): - if self.config.q_lora_rank is not None: - q, _ = self.linear_q_up_proj(q_compressed) - else: - # hidden_states:[s, b, 2048], q: [s, b, n * 192] - q, _ = self.linear_q_proj(q_compressed) - - q_len, bsz, _ = q.size() - - # q: [s, b, n, 192] - q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim) - - # kv: [s, b, 2048] - kv, _ = self.linear_kv_up_proj(kv_compressed) - - # kv: [s, b, n, 256] - kv = kv.view( - q_len, - bsz, - self.num_attention_heads_per_partition, - self.config.qk_head_dim + self.config.v_head_dim, - ) - - if inference_context is not None: - # add offset to the sequence start for inference - sequence_start = inference_context.sequence_len_offset - sequence_end = sequence_start + q_len - rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end] - else: - # Shorten rotary_pos_emb to the sequence length when inference_params - # is not provided. This makes sure we can run forward directly with - # any sequence length. During training, the sequence length is always - # the full rotary_pos_emb length. - rotary_pos_emb = rotary_pos_emb[0:q_len] - - # [s, b, 64] -> [s, b, 1, 64] - k_pos_emb = torch.unsqueeze(k_pos_emb, 2) - - # q: [s, b, n, 128], q_pos_emb: [s, b, n, 64] - q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1) - - # k_no_pe: [s, b, n, 128], value: [s, b, n, 128] - k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1) - - if packed_seq_params is not None: - cu_seqlens_q = packed_seq_params.cu_seqlens_q - cu_seqlens_kv = packed_seq_params.cu_seqlens_kv - q_pos_emb = q_pos_emb.squeeze(1) - k_pos_emb = k_pos_emb.squeeze(1) - q_no_pe = q_no_pe.squeeze(1) - k_no_pe = k_no_pe.squeeze(1) - value = value.squeeze(1) - else: - cu_seqlens_q = cu_seqlens_kv = None - - # q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64] - q_pos_emb = apply_rotary_pos_emb( - q_pos_emb, - rotary_pos_emb, - config=self.config, - cu_seqlens=cu_seqlens_q, - mscale=mscale, - ) - k_pos_emb = apply_rotary_pos_emb( - k_pos_emb, - rotary_pos_emb, - config=self.config, - cu_seqlens=cu_seqlens_kv, - mscale=mscale, - ) - - # query: [s, b, n, 192] - query = torch.cat([q_no_pe, q_pos_emb], dim=-1) - if packed_seq_params is not None: - k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1) - key = torch.cat([k_no_pe, k_pos_emb], dim=-1) - else: - # key: [s, b, n, 192] - k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1) - key = torch.cat([k_no_pe, k_pos_emb], dim=-1) - - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - return query, key, value - - if self.recompute_up_proj: - self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput() - query, key, value = self.qkv_up_checkpoint.checkpoint( - qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb - ) - else: - query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb) - - return query, key, value - - MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors diff --git a/verl/models/mcore/qwen2_5_vl/__init__.py b/verl/models/mcore/qwen2_5_vl/__init__.py deleted file mode 100644 index 8842d0249..000000000 --- a/verl/models/mcore/qwen2_5_vl/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from .model import Qwen2_5VLModel -from .vision_config import get_vision_model_config, get_vision_projection_config - -__all__ = ["Qwen2_5VLModel", "get_vision_model_config", "get_vision_projection_config"] diff --git a/verl/models/mcore/qwen2_5_vl/attention.py b/verl/models/mcore/qwen2_5_vl/attention.py deleted file mode 100644 index 91a27cc3e..000000000 --- a/verl/models/mcore/qwen2_5_vl/attention.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from megatron.core.transformer.attention import * - -from .rope_utils import apply_rotary_pos_emb_absolute - - -class Qwen2_5VLSelfAttention(SelfAttention): - """ - Overrides the SelfAttention class, the difference is that qwen2_5_vl uses apply_rotary_pos_emb_absolute - instead of apply_rotary_pos_emb - """ - - def forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - key_value_states: Optional[Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, - rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, - rotary_pos_cos: Optional[Tensor] = None, - rotary_pos_sin: Optional[Tensor] = None, - attention_bias: Optional[Tensor] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[int] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Perform a forward pass through the attention module. - - Args: - hidden_states (Tensor): Hidden states. - attention_mask (Tensor): Attention mask. - key_value_states (Optional[Tensor]): Key/value states (for cross attention). - inference_context (Optional[BaseInferenceContext]): Inference context that manages - KV cache. - rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary - embedding tensor(s). - rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. - rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. - attention_bias (Optional[Tensor]): Attention bias. - packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. - sequence_len_offset (Optional[int]): Sequence length offset used for - inference CUDA graphs. - - Return: - (Tuple[Tensor, Tensor]) Attention output and bias. - - """ - - inference_context = deprecate_inference_params(inference_context, inference_params) - - if inference_context and inference_context.is_dynamic_batching(): - assert flash_decode_and_prefill_kernel is not None, ( - "Internal use only: install package `nvidia_chunked_flash_attn`." - ) - - # hidden_states: [sq, b, h] - if self.config.flash_decode and not self.training and inference_context is not None: - rotary_pos_emb = None - else: - assert rotary_pos_cos is None and rotary_pos_sin is None - - # For self attention we just duplicate the rotary_pos_emb if it isn't already - if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = (rotary_pos_emb,) * 2 - - # ===================== - # Query, Key, and Value - # ===================== - # Get the query, key and value tensors based on the type of attention - - # self or cross attn. - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - - # =================================================== - # Adjust key, value, and rotary_pos_emb for inference - # =================================================== - - # This branch only runs in the decode phase of flash decoding and returns after the linear - # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. - if ( - self.config.flash_decode - and inference_context is not None - and inference_context.is_decode_only() - and not self.training - and rotary_pos_cos is not None - ): - assert self.layer_number in inference_context.key_value_memory_dict - assert inference_context.sequence_len_offset is not None - inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] - output = self.flash_decode( - sequence_len_offset=sequence_len_offset, - query_layer=query, - key_layer=key, - value_layer=value, - inference_key_memory=inference_key_memory, - inference_value_memory=inference_value_memory, - rotary_cos=rotary_pos_cos, - rotary_sin=rotary_pos_sin, - ) - out = output.transpose(0, 1).contiguous() - context_layer = out.view(out.size(0), out.size(1), -1) - output, bias = self.linear_proj(context_layer) - return output, bias - - query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( - inference_context, - query, - key, - value, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - ) - - if packed_seq_params is not None: - query = query.squeeze(1) - key = key.squeeze(1) - value = value.squeeze(1) - - # ================================================ - # relative positional embedding (rotary embedding) - # ================================================ - if rotary_pos_emb is not None and not self.config.flash_decode: - q_pos_emb, k_pos_emb = rotary_pos_emb - - if packed_seq_params is not None: - if packed_seq_params.cu_seqlens_q_padded is not None: - cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded - else: - cu_seqlens_q = packed_seq_params.cu_seqlens_q - if packed_seq_params.cu_seqlens_kv_padded is not None: - cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded - else: - cu_seqlens_kv = packed_seq_params.cu_seqlens_kv - else: - cu_seqlens_q = cu_seqlens_kv = None - - if q_pos_emb is not None: - # TODO VIJAY: simplify - if inference_context is None or inference_context.is_static_batching(): - query = apply_rotary_pos_emb_absolute(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q) - else: - query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q) - if k_pos_emb is not None: - key = apply_rotary_pos_emb_absolute(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) - - # TODO, can apply positional embedding to value_layer so it has - # absolute positional embedding. - # otherwise, only relative positional embedding takes effect - # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) - - # ================================== - # core attention computation - # ================================== - - if self.checkpoint_core_attention and self.training: - core_attn_out = self._checkpointed_attention_forward( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - else: - if inference_context is None or inference_context.is_static_batching(): - # Static batching attention kernel. - core_attn_out = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - - else: - # Dynamic batching attention kernel. - q, k, v = (query, key, value) - cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() - cu_kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths() - - core_attn_out = self.flash_decode_and_prefill( - q, k, v, max_seqlen_q, max_seqlen_k, cu_query_lengths, cu_kv_lengths - ) - core_attn_out = core_attn_out.squeeze(0).unsqueeze(1) - core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)") - - if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": - # reshape to same output shape as unpacked case - # (t, np, hn) -> (t, b=1, h=np*hn) - # t is the pack size = sum (sq_i) - # note that batch is a dummy dimension in the packed case - core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) - - # ================= - # Output. [sq, b, h] - # ================= - - output, bias = self.linear_proj(core_attn_out) - - return output, bias diff --git a/verl/models/mcore/qwen2_5_vl/model.py b/verl/models/mcore/qwen2_5_vl/model.py deleted file mode 100644 index 74e4406c3..000000000 --- a/verl/models/mcore/qwen2_5_vl/model.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -import torch -from megatron.core import InferenceParams, tensor_parallel -from megatron.core.models.gpt.gpt_model import GPTModel - -# from .transformer_config import Qwen2VLTransformerConfig -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig - -from .attention import Qwen2_5VLSelfAttention -from .vision_model import Qwen2_5VisionModel - - -# Note: This is under development and may be missing features. -class Qwen2_5VLModel(MegatronModule): - """Qwen2.5VL multi-modal model. - - Args: - language_transformer_config (TransformerConfig): Transformer config for the language model. - language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the - language model. - language_vocab_size (int): Language model vocabulary size. - language_max_sequence_length (int): Language model maximum sequence length. This is used for - positional embedding. - vision_transformer_config (TransformerConfig): Transformer config for the vision model. - vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers of the - vision model. - vision_projection_config (TransformerConfig): Config for the projection from vision model outputs to - language model inputs. - vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision - projection. - vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP. - parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks. This - is typically True for training and False for inference. - language_rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings - in the language model. Defaults to 1.0. - pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). - Defaults to True. - post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline - parallelism). Defaults to True. - add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. - When we use pipelining, the encoder - will live on only a subset of the pipeline stages (specifically, only the first stage). - add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. - When we use pipelining, the decoder - will live on only a subset of the pipeline stages (specifically, every stage after the first one). - img_h (int): The height of each image that the ViT will see. - img_w (int): The width of each image that the ViT will see. - patch_dim (int): The size of each patch side. - img_embedding_idx (int): Index in the language_embeddings tensor where image_embeddings should be - inserted. Defaults to 0. - """ - - def __init__( - self, - language_transformer_config: TransformerConfig, - language_transformer_layer_spec: ModuleSpec, - language_vocab_size: int, - language_max_sequence_length: int, - vision_transformer_config: TransformerConfig, - vision_transformer_layer_spec: ModuleSpec, - vision_projection_config: TransformerConfig, - vision_projection_layer_spec: ModuleSpec, - vision_projection_type: str = "mlp", - parallel_output: bool = True, - language_rotary_percent: float = 1.0, - pre_process: bool = True, - post_process: bool = True, - add_encoder: bool = True, - add_decoder: bool = True, - language_rotary_base: int = 10000, - fp16_lm_cross_entropy: bool = False, - language_share_embeddings_and_output_weights: bool = False, - image_token_id: int = 151655, - video_token_id: int = 151656, - ) -> None: - super().__init__(config=language_transformer_config) - - # patch self_attention to use qwen2_5_vl attention - vision_transformer_layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention - for layer_spec in language_transformer_layer_spec.layer_specs: - layer_spec.submodules.self_attention.module = Qwen2_5VLSelfAttention - - logging.getLogger(__name__).warning("Qwen2VL model is under development and may be missing features.") - - self.pre_process = pre_process - self.post_process = post_process - self.add_encoder = add_encoder - self.add_decoder = add_decoder - - self.encoder_hidden_state = None - self.vision_model = None - self.vision_projection = None - self.language_model = None - self.image_token_id = image_token_id - self.video_token_id = video_token_id - - self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size - - # This attribute is needed to check if an all-reduce is required - # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. - self.share_embeddings_and_output_weights = False - if self.pre_process: - self.vision_model = Qwen2_5VisionModel( - vision_transformer_config, - vision_transformer_layer_spec, - vision_projection_config, - vision_projection_layer_spec, - projection_type=vision_projection_type, - pre_process=True, - post_process=True, - ) - - self.language_model = GPTModel( - config=language_transformer_config, - transformer_layer_spec=language_transformer_layer_spec, - vocab_size=language_vocab_size, - max_sequence_length=language_max_sequence_length, - parallel_output=parallel_output, - position_embedding_type="mrope", - rotary_percent=language_rotary_percent, - pre_process=self.pre_process, - post_process=self.post_process, - rotary_base=language_rotary_base, - fp16_lm_cross_entropy=fp16_lm_cross_entropy, - share_embeddings_and_output_weights=language_share_embeddings_and_output_weights, - scatter_embedding_sequence_parallel=False, - ) - - self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights - - def shared_embedding_or_output_weight(self): - """This is a convenience method to surface the language model's word embeddings, which is - necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" - if self.add_decoder: - return self.language_model.shared_embedding_or_output_weight() - return None - - def set_input_tensor(self, input_tensor) -> None: - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen2VL" - - if self.pre_process: - self.encoder_hidden_state = input_tensor[0] - else: - self.language_model.set_input_tensor(input_tensor[0]) - - def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool): - """Freeze model modules. - - Make specific modules non-trainable by setting requires_grad to False for the module's parameters. - - Args: - freeze_language_model (bool): Freeze the language model module. - freeze_vision_model (bool): Freeze the vision model module. - freeze_vision_projection (bool): Freeze the vision projection module. - """ - modules = [] - if freeze_language_model and self.language_model is not None: - modules.append(self.language_model) - if freeze_vision_model and self.vision_model is not None: - modules.append(self.vision_model) - if freeze_vision_projection and self.vision_projection is not None: - modules.append(self.vision_projection) - - for module in modules: - for param in module.parameters(): - param.requires_grad = False - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor = None, - labels: torch.Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - pixel_values: torch.Tensor = None, - pixel_values_videos: torch.Tensor = None, - image_grid_thw: torch.Tensor = None, - video_grid_thw: torch.Tensor = None, - ) -> torch.Tensor: - """Forward function of the Qwen2VL model. - - Args: - image_data (torch.Tensor): input image of shape [total_thw_size, n_features]. - input_ids (torch.Tensor): input text ids [batch, text_seq_len]. - position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. - attention_mask (torch.Tensor): attention mask for the language model [batch, 1, combined_seq_len, - combined_seq_len]. - labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. - inference_params (InferenceParams): Inference-time parameters including KV cache. - - video_start_index: - 0 -- all video - len(video_seq) -- all image - others -- mixture - *_input_mask: should not be None in the first PP stage - Returns: - output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape - [b, s, vocab_size]. - """ - video_start_index = 0 - vision_grid_thw = None - vision_data = None - if image_grid_thw is not None: - image_mask = input_ids == self.image_token_id - vision_grid_thw = image_grid_thw - vision_data = pixel_values - video_start_index = image_mask.sum().item() - if video_grid_thw is not None: - video_mask = input_ids == self.video_token_id - vision_grid_thw = torch.cat([vision_grid_thw, video_grid_thw], dim=0) - vision_data = torch.cat([vision_data, pixel_values_videos], dim=0) - video_start_index = image_mask.sum().item() + video_mask.sum().item() - use_inference_kv_cache = ( - inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict - ) - use_inference_kv_cache = ( - inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict - ) - if use_inference_kv_cache: - raise NotImplementedError() - - if self.pre_process: - vision_embeds = None - if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: - vision_embeds = self.vision_model( - vision_data=vision_data, # If None, vision model should use intermediate outputs (EPP > 1) - grid_thw=vision_grid_thw, # should provided in each EPP stage - ) - - # If running inference, the language model KV cache will be updated for image token positions. - # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later. - if inference_params is not None: - raise NotImplementedError() - # inference_params.key_value_memory_dict["image_tokens_count"] = ( - # vision_embeddings.shape[0] - # ) - - # If running inference, we can skip image token computation if they were computed already earlier - # for this sample. - if use_inference_kv_cache: - language_embeddings: torch.Tensor = self.language_model.embedding( - input_ids=input_ids, - position_ids=None, # NOTE: disable - ) # [text_seq_len, b, h_language] - # NOTE: why not cat here? is it the combined embeddings useless? - combined_embeddings = language_embeddings - elif vision_embeds is not None: - if video_start_index == 0: - image_embeds = None - video_embeds = vision_embeds - elif video_start_index == vision_embeds.shape[0]: - image_embeds = vision_embeds - video_embeds = None - elif 0 < video_start_index < vision_embeds.shape[0]: - image_embeds = vision_embeds[:video_start_index] - video_embeds = vision_embeds[video_start_index:] - else: - raise ValueError( - f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got " - f"{video_start_index}" - ) - - combined_embeddings = self.language_model.embedding( - input_ids=input_ids, - position_ids=None, # NOTE: disable - ) # [text_seq_len, b, h_language] - - if image_embeds is not None or video_embeds is not None: - combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() - if image_embeds is not None: - image_mask = (input_ids == self.image_token_id).contiguous() - if image_mask.sum() > 0: - combined_embeddings = combined_embeddings.clone() - combined_embeddings[image_mask] = image_embeds.to( - dtype=combined_embeddings.dtype, device=combined_embeddings.device - ) - if video_embeds is not None: - video_mask = (input_ids == self.video_token_id).contiguous() - if video_mask.sum() > 0: - combined_embeddings = combined_embeddings.clone() - combined_embeddings[video_mask] = video_embeds.to( - dtype=combined_embeddings.dtype, device=combined_embeddings.device - ) - combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() - - else: - combined_embeddings = self.language_model.embedding( - input_ids=input_ids, - position_ids=None, # NOTE: disable - ) # [text_seq_len, b, h_language] - if self.config.sequence_parallel: - combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) - combined_embeddings = combined_embeddings.contiguous() - else: - combined_embeddings = None - from .rope_utils import get_rope_index - - position_ids, _ = get_rope_index( - input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask - ) - - output = self.language_model( - input_ids=None, - position_ids=position_ids, # None in encoder - attention_mask=attention_mask, # None in encoder - decoder_input=combined_embeddings, # only not None in the first decoder PP stage - labels=labels, # only not None in the last decoder PP stage - # inference_params=inference_params, # currently always None - packed_seq_params=packed_seq_params, # currently always None - **(extra_block_kwargs or {}), - ) - - return output diff --git a/verl/models/mcore/qwen2_5_vl/rope_utils.py b/verl/models/mcore/qwen2_5_vl/rope_utils.py deleted file mode 100644 index fadc74daa..000000000 --- a/verl/models/mcore/qwen2_5_vl/rope_utils.py +++ /dev/null @@ -1,266 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from __future__ import annotations - -import logging -from typing import Optional - -import torch -from megatron.core.models.common.embeddings.rope_utils import * -from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd -from torch import Tensor - -logger = logging.getLogger(__name__) - - -# Slightly modified from Qwen2VLForConditionalGeneration.get_rope_index -def get_rope_index( - input_ids: Optional[torch.LongTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, -): - """ - Calculate the 3D rope index based on image and video's temporal, height and width in LLM. - - Explanation: - - Each embedding sequence contains vision embedding and text embedding or just contains text embedding. - - For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. - - Examples: - - input_ids: [T T T T T], here T is for text. - temporal position_ids: [0, 1, 2, 3, 4] - height position_ids: [0, 1, 2, 3, 4] - width position_ids: [0, 1, 2, 3, 4] - - For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part - and 1D rotary position embedding for text part. - - Examples: - - Temporal (Time): 3 patches, representing different segments of the video in time. - Height: 2 patches, dividing each frame vertically. - Width: 2 patches, dividing each frame horizontally. - We also have some important parameters: - fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each - second. - tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal - tokens" are conceptually packed into a one-second interval of the video. - In this case, we have 25 tokens per second. So each second of the video will be - represented with 25 separate time points. It essentially defines the temporal - granularity. - temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. - interval: The step size for the temporal position IDs, calculated as tokens_per_second * - temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be - have a difference of 50 in the temporal position IDs. - input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. - vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] - vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] - vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] - text temporal position_ids: [101, 102, 103, 104, 105] - text height position_ids: [101, 102, 103, 104, 105] - text width position_ids: [101, 102, 103, 104, 105] - Here we calculate the text start position_ids as the max vision position_ids plus 1. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. - video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): - The temporal, height and width of feature shape of each video in LLM. - second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): - The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - Returns: - position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) - mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) - """ - spatial_merge_size = 2 - tokens_per_second = 2 - image_token_id = 151655 - video_token_id = 151656 - vision_start_token_id = 151652 - mrope_position_deltas = [] - if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): - total_input_ids = input_ids - if attention_mask is None: - attention_mask = torch.ones_like(total_input_ids) - position_ids = torch.ones( - 3, - input_ids.shape[0], - input_ids.shape[1], - dtype=input_ids.dtype, - device=input_ids.device, - ) - image_index, video_index = 0, 0 - attention_mask = attention_mask.to(total_input_ids.device) - for i, input_ids in enumerate(total_input_ids): - input_ids = input_ids[attention_mask[i] == 1] - image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - second_per_grid_t = 0 - image_index += 1 - remain_images -= 1 - ed = ed_image - - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - if second_per_grid_ts is not None: - second_per_grid_t = second_per_grid_ts[video_index] - else: - second_per_grid_t = 1.0 - video_index += 1 - remain_videos -= 1 - ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - range_tensor = torch.arange(llm_grid_t).view(-1, 1) - expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) - - time_tensor = expanded_range * second_per_grid_t * tokens_per_second - - time_tensor_long = time_tensor.long() - t_index = time_tensor_long.flatten() - - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) - mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) - return position_ids, mrope_position_deltas - else: - if attention_mask is not None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] - else: - position_ids = ( - torch.arange(input_ids.shape[1], device=input_ids.device) - .view(1, 1, -1) - .expand(3, input_ids.shape[0], -1) - ) - mrope_position_deltas = torch.zeros( - [input_ids.shape[0], 1], - device=input_ids.device, - dtype=input_ids.dtype, - ) - - return position_ids, mrope_position_deltas - - -def apply_rotary_pos_emb_thd_absolute( - t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False -) -> Tensor: - """A baseline implementation of applying RoPE for `thd` format. - - Args: - t (Tensor): Input tensor T is of shape [t, h, d] - cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, - with shape [b + 1] and dtype torch.int32. - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] - - Returns: - Tensor: Shape [t, h, d]. The input tensor after applying RoPE. - """ - return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) - - -def apply_rotary_pos_emb_absolute( - t: Tensor, - freqs: Tensor, - config: TransformerConfig, - cu_seqlens: Optional[Tensor] = None, -): - """ - Reroute to the appropriate apply_rotary_pos_emb function depending on - bshd (conventional) / thd (packed seq) format - - In Qwen2-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim] - """ - - if config.apply_rope_fusion: - if cu_seqlens is None: - # NOTE: TE backends do not support mRoPE in bshd format when bs > 1 - if freqs.shape[1] > 1: - return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) - else: - return fused_apply_rotary_pos_emb(t, freqs) - else: - # NOTE: as expected, thd format can use bshd - return fused_apply_rotary_pos_emb(t[:, None], freqs).squeeze(1) - else: - if cu_seqlens is None: - return _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) - else: - return apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) diff --git a/verl/models/mcore/qwen2_5_vl/vision_config.py b/verl/models/mcore/qwen2_5_vl/vision_config.py deleted file mode 100644 index 0631c90f6..000000000 --- a/verl/models/mcore/qwen2_5_vl/vision_config.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from megatron.core import parallel_state -from megatron.core.transformer import TransformerConfig - - -def get_vision_model_config(config: TransformerConfig) -> TransformerConfig: - # Given a Transformer Config from decoder, build vision encoder config - # diff: out_hidden_size & intermediate_size - - # mlp: hidden_size -> intermediate_size -> embed_dim, silu - # NOTE: here we provide a workaround to solve the wrong layer amount when VPP of decoder is on - if config.num_layers in [28, 36]: - config.ffn_hidden_size = 3420 - else: - config.ffn_hidden_size = 3456 - - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: - config.num_layers = 32 * parallel_state.get_virtual_pipeline_model_parallel_world_size() # depth - else: - config.num_layers = 32 # depth - config.num_attention_heads = 16 # num_heads - config.add_bias_linear = True # all nn.Linear has bias (MLP, attn) - config.add_qkv_bias = True # qkv_proj in attn has bias - config.hidden_size = 1280 # hidden_size - config.hidden_dropout = 0.0 - config.attention_dropout = 0.0 - - # config.gated_linear_unit = False # no gated - # config.activation_func = quick_gelu # hidden_act - config.kv_channels = config.hidden_size // config.num_attention_heads - config.num_query_groups = config.num_attention_heads # no GQA - config.layernorm_zero_centered_gamma = False # False - config.apply_query_key_layer_scaling = False # factor=math.sqrt(head_dim) - config.bias_activation_fusion = False # no swiglu, set false - config.bias_dropout_fusion = False # no dropout, set false - config.attention_softmax_in_fp32 = True # use True - # config.normalization = 'LayerNorm' # use RMSNorm - config.seq_length = 1 - - config.tp_comm_overlap = False - config.sequence_parallel = False - config.temporal_patch_size = 2 - config.patch_size = 14 - config.in_channels = 3 - config.spatial_merge_size = 2 - - config.fullatt_block_indexes = [7, 15, 23, 31] - config._qwen2_5_vl_window_size = 112 - return config - - -def get_vision_projection_config( - config: TransformerConfig, embed_dim: int, spatial_merge_size: int -) -> TransformerConfig: - # merger: - # context_dim = hidden_size * merge_size**2 - # out_hidden_size = hidden_size - # context_dim -> context_dim -> out_hidden_size - # MLP: - # input_size -> ffn_hidden_size -> hidden_size - # spec: LN -> Linear(bias=True) -> GELU -> Linear(bias=True) - config.gated_linear_unit = False - config.bias_activation_fusion = False - config.add_bias_linear = True - config.ffn_hidden_size = embed_dim * (spatial_merge_size**2) - config.activation_func = torch.nn.functional.gelu - config.tp_comm_overlap = False - config.sequence_parallel = False - return config diff --git a/verl/models/mcore/qwen2_5_vl/vision_model.py b/verl/models/mcore/qwen2_5_vl/vision_model.py deleted file mode 100644 index 06b4fd328..000000000 --- a/verl/models/mcore/qwen2_5_vl/vision_model.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -from megatron.core import InferenceParams -from megatron.core.models.common.vision_module.vision_module import VisionModule -from megatron.core.models.vision.multimodal_projector import MultimodalProjector -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer.enums import ModelType -from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_config import TransformerConfig -from torch import nn -from torch.nn import functional as F - -from .vision_transformer_block import Qwen2_5VisionTransformerBlock as TransformerBlock - - -# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class PatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 14, - temporal_patch_size: int = 2, - in_channels: int = 3, - embed_dim: int = 1152, - ) -> None: - super().__init__() - self.patch_size = patch_size - self.temporal_patch_size = temporal_patch_size - self.in_channels = in_channels - self.embed_dim = embed_dim - - kernel_size = [temporal_patch_size, patch_size, patch_size] - self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype - hidden_states = hidden_states.view( - -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ) - hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) - return hidden_states - - -# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py -class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: - super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - freqs = torch.outer(seq, self.inv_freq) - return freqs.float() - - -class Qwen2_5VisionModel(VisionModule): - """Qwen2.5 ViT vision model. - - Args: - transformer_config (TransformerConfig): Transformer config. - transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers. - ln_pre_impl (ModuleSpec or type): Specifies the layer norm type to use for ln_pre. - add_class_token (bool, optional): Include a class token. Defaults to True. - class_token_len (int): Class token length. Defaults to 1 but 8 may be faster. - patch_dim (int): Image patch size. - img_h (int): Input image height. - img_w (int): Input image width. - """ - - def __init__( - self, - transformer_config: TransformerConfig, - transformer_layer_spec: ModuleSpec, - projection_config: TransformerConfig, - projection_layer_spec: ModuleSpec, - projection_type: str = "mlp", - pre_process: bool = True, - post_process: bool = False, - ) -> None: - super().__init__(config=transformer_config) - - self.spatial_merge_size = transformer_config.spatial_merge_size - - embed_dim = transformer_config.hidden_size - num_heads = transformer_config.num_attention_heads - temporal_patch_size = transformer_config.temporal_patch_size - patch_size = transformer_config.patch_size - in_channels = transformer_config.in_channels - - self.patch_size = transformer_config.patch_size - self.fullatt_block_indexes = transformer_config.fullatt_block_indexes - self.window_size = transformer_config._qwen2_5_vl_window_size - self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size - - self.max_sequence_length = transformer_config.seq_length - self.patch_embed = PatchEmbed( - patch_size=patch_size, - temporal_patch_size=temporal_patch_size, - in_channels=in_channels, - embed_dim=embed_dim, - ) - - head_dim = embed_dim // num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) - - self.model_type = ModelType.encoder_or_decoder - self.pre_process = pre_process - self.post_process = post_process - - # Transformer layers. - # TODO: Follow-up changes will make pre and post_process configurable. They are needed for supporting - # pipeline parallelism. - # NOTE: a final layer norm and/or linear layer present in some implementations are omitted here. - self.decoder = TransformerBlock( - config=transformer_config, - spec=transformer_layer_spec, - pre_process=self.pre_process, - post_process=self.post_process, - post_layer_norm=True, - ) - - self.merge_hidden_size = projection_config.ffn_hidden_size - self.square_merge_size = self.merge_hidden_size // embed_dim - - if self.post_process: - self.projection = MultimodalProjector( - projection_config, projection_layer_spec, projection_type, projection_config.ffn_hidden_size - ) - else: - self.projection = None - - self.input_tensor = None - - def set_input_tensor(self, input_tensor: torch.Tensor) -> None: - """Sets input tensor to the model. - - Args: - input_tensor (Tensor): Sets the input tensor for the model. - """ - if self.pre_process: # always True - self.input_tensor = input_tensor - else: - raise NotImplementedError() - - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - hpos_ids = hpos_ids.permute(0, 2, 1, 3) - hpos_ids = hpos_ids.flatten() - - wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ) - wpos_ids = wpos_ids.permute(0, 2, 1, 3) - wpos_ids = wpos_ids.flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) - pos_ids = torch.cat(pos_ids, dim=0).to(grid_thw.device) - max_grid_size = grid_thw[:, 1:].max() - rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).to(grid_thw.device) - rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) - return rotary_pos_emb - - def get_window_index(self, grid_thw): - window_index: list = [] - cu_window_seqlens: list = [0] - window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size - - for grid_t, grid_h, grid_w in grid_thw: - llm_grid_h, llm_grid_w = ( - grid_h // self.spatial_merge_size, - grid_w // self.spatial_merge_size, - ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) - pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size - pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size - num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size - num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) - index_padded = index_padded.reshape( - grid_t, - num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size, - ) - index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( - grid_t, - num_windows_h * num_windows_w, - vit_merger_window_size, - vit_merger_window_size, - ) - seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) - index_padded = index_padded.reshape(-1) - index_new = index_padded[index_padded != -100] - window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) - window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() - window_index = torch.cat(window_index, dim=0) - - return window_index, cu_window_seqlens - - def forward( - self, - vision_data: Optional[torch.Tensor], - grid_thw: torch.Tensor, - inference_params: Optional[InferenceParams] = None, - extra_block_kwargs: dict = None, - ) -> torch.Tensor: - """Forward function of the Qwen2 Vision Model. This function passes the input tensors - through the embedding layer and then the transformer. - - Args: - x (torch.Tensor): input image/video data of shape [n_tokens, n_dims] - grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame - packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend - - Returns: - x (torch.Tensor): output after final transformer block of shape [b, s, h]. - """ - assert grid_thw is not None - assert self.input_tensor is None - assert inference_params is None - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - vision_data = self.patch_embed(vision_data) - window_index, cu_window_seqlens = self.get_window_index(grid_thw) - cu_window_seqlens = torch.tensor( - cu_window_seqlens, - device=vision_data.device, - dtype=torch.int32, - ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - - seq_len, _ = vision_data.size() - vision_data = vision_data.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - vision_data = vision_data[window_index, :, :] - vision_data = vision_data.reshape(seq_len, 1, -1) - - rotary_pos_emb = self.rot_pos_emb(grid_thw) - rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) - rotary_pos_emb = rotary_pos_emb[window_index, :, :] - rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2) - - hidden_states = self.decoder( - hidden_states=vision_data, - attention_mask=None, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - packed_seq_params=self.build_packed_seq_params(None, cu_window_seqlens), - packed_seq_params_full=self.build_packed_seq_params(grid_thw), - fullatt_block_indexes=self.fullatt_block_indexes, - **(extra_block_kwargs or {}), - ) - - hidden_states = self.projection(hidden_states.view(-1, self.merge_hidden_size)) - reverse_indices = torch.argsort(window_index) - return hidden_states[reverse_indices, :] - - def build_packed_seq_params( - self, - grid_thw: Optional[torch.Tensor], - cu_seqlens: Optional[torch.Tensor] = None, - ) -> PackedSeqParams: - # NOTE: each frame is a sequence (rather than each grid) - if grid_thw is not None: - seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]) - cu_seqlens = seqlens.cumsum(dim=0) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int() - else: - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - - max_seqlen_q = seqlens.max() - return PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - qkv_format="thd", - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_q, - ) diff --git a/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py b/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py deleted file mode 100644 index 8f765a0ff..000000000 --- a/verl/models/mcore/qwen2_5_vl/vision_transformer_block.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2024 Alibaba PAI Team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from megatron.core.transformer.transformer_block import * - - -class Qwen2_5VisionTransformerBlock(TransformerBlock): - def _checkpointed_forward( - self, - hidden_states: Tensor, - attention_mask: Tensor, - context: Tensor, - context_mask: Tensor, - rotary_pos_emb: Tensor, - attention_bias: Tensor, - packed_seq_params: PackedSeqParams, - packed_seq_params_full: PackedSeqParams, - fullatt_block_indexes, - ): - """Forward method with activation checkpointing.""" - - def custom(start: int, end: int): - def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb): - for index in range(start, end): - if index in fullatt_block_indexes: - packed_seq_params_now = packed_seq_params_full - else: - packed_seq_params_now = packed_seq_params - layer = self._get_layer(index) - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - inference_context=None, - packed_seq_params=packed_seq_params_now, - ) - return hidden_states, context - - return custom_forward - - def checkpoint_handler(forward_func): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - if self.config.fp8: - return te_checkpoint( - forward_func, - self.config.distribute_saved_activations, - tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - ) - else: - return tensor_parallel.checkpoint( - forward_func, - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - ) - - if self.config.recompute_method == "uniform": - # Uniformly divide the total number of Transformer layers and checkpoint - # the input activation of each divided chunk. - # A method to further reduce memory usage reducing checkpoints. - layer_idx = 0 - while layer_idx < self.num_layers_per_pipeline_rank: - hidden_states, context = checkpoint_handler( - custom(layer_idx, layer_idx + self.config.recompute_num_layers) - ) - - layer_idx += self.config.recompute_num_layers - - elif self.config.recompute_method == "block": - # Checkpoint the input activation of only a set number of individual - # Transformer layers and skip the rest. - # A method fully use the device memory removing redundant re-computation. - recompute_skip_num_layers = 0 - for layer_idx in range(self.num_layers_per_pipeline_rank): - # Skip recomputation when input grad computation is not needed. - # Need to have at least one input tensor with gradient computation - # for re-enterant autograd engine. - if self.config.fp8 and not hidden_states.requires_grad: - recompute_skip_num_layers += 1 - if ( - layer_idx >= recompute_skip_num_layers - and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers - ): - hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) - else: - hidden_states, context = custom(layer_idx, layer_idx + 1)( - hidden_states, attention_mask, context, context_mask, rotary_pos_emb - ) - else: - raise ValueError("Invalid activation recompute method.") - - return hidden_states - - def forward( - self, - hidden_states: Union[Tensor, WrappedTensor], - attention_mask: Optional[Tensor], - context: Optional[Tensor] = None, - context_mask: Optional[Tensor] = None, - rotary_pos_emb: Optional[Tensor] = None, - rotary_pos_cos: Optional[Tensor] = None, - rotary_pos_sin: Optional[Tensor] = None, - attention_bias: Optional[Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[Tensor] = None, - packed_seq_params_full: PackedSeqParams = None, - fullatt_block_indexes=None, - *, - inference_params: Optional[BaseInferenceContext] = None, - ): - """ - Perform the forward pass through the transformer block. - - This method handles the core computation of the transformer, including - self-attention, optional cross-attention, and feed-forward operations. - - Args: - hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] - where s is the sequence length, b is the batch size, and h is the hidden size. - Can be passed as a WrappedTensor during inference to avoid an obsolete - reference in the calling function. - attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking - self-attention. - context (Tensor, optional): Context tensor for cross-attention. - context_mask (Tensor, optional): Mask for cross-attention context - rotary_pos_emb (Tensor, optional): Rotary positional embeddings. - attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable - to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. - Used as an alternative to apply attention mask for TE cuDNN attention. - inference_context (BaseInferenceContext, optional): Parameters for inference-time - optimizations. - packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence - processing. - - Returns: - Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape - [s, b, h], and optionally the updated context tensor if cross-attention is used. - """ - - inference_context = deprecate_inference_params(inference_context, inference_params) - - # Delete the obsolete reference to the initial input tensor if necessary - if isinstance(hidden_states, WrappedTensor): - hidden_states = hidden_states.unwrap() - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Update the inference parameters with the current batch size in case it is variable - if inference_context and not self.training: - inference_context.current_batch_size = hidden_states.size(1) - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - if self.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), - # otherwise do nothing extra at the outer level - # if we are using other fp8 recipes, then the context manager enter&exit are free - # we can wrap fp8_context within the for loop over layers, so that we can fine-grained - # control which layer will be fp8 or bf16 - use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed - use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed - outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() - - with rng_context, outer_fp8_context: - # Forward pass. - if self.config.recompute_granularity == "full" and self.training: - hidden_states = self._checkpointed_forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - packed_seq_params_full=packed_seq_params_full, - fullatt_block_indexes=fullatt_block_indexes, - ) - else: - for l_no, layer in enumerate(self.layers): - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - 1) if use_inner_fp8_context else nullcontext() - ) - if l_no in fullatt_block_indexes: - packed_seq_params_now = packed_seq_params_full - else: - packed_seq_params_now = packed_seq_params - with self.offload_context, inner_fp8_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - inference_context=inference_context, - packed_seq_params=packed_seq_params_now, - sequence_len_offset=sequence_len_offset, - ) - - if ( - torch.is_grad_enabled() - and self.config.cpu_offloading - and self.group_prefetch_offload_commit_async is not None - ): - hidden_states = self.group_prefetch_offload_commit_async(hidden_states) - - # Final layer norm. - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - # TENorm produces a "viewed" tensor. This will result in schedule.py's - # deallocate_output_tensor() throwing an error, so a viewless tensor is - # created to prevent this. - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - return hidden_states diff --git a/verl/models/mcore/readme.md b/verl/models/mcore/readme.md deleted file mode 100644 index 606dcf189..000000000 --- a/verl/models/mcore/readme.md +++ /dev/null @@ -1,99 +0,0 @@ -# verl Megatron-Core Models -The earlier versions of verl use `Megatron-LM` 0.4 and workaround huggingface model classes. To better use the latest features and speedup of modern Megatron, we are migrating to `Megatron-Core`(mcore), and use the recommended `GPTModel` class for all language models. With mcore `GPTModel`, we can use the latest features like `context parallel`, `expert parallel`, `dist_checkpointing`, etc. and we can update mcore with little effort in the future for new features. - -The migration has been successful with the help of the mcore team and the community. What we have done is: -1. update `Megatron` version to `0.11.0` -2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel` -3. support sequence packing/thd format. -4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`. -5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion script from huggingface to mcore `dist_checkpointing` format. - -We are working on the following features: -- support `Qwen2MoeForCausalLM` -- support `MixtralForCausalLM` -- support `DeepseekV3ForCausalLM` -- support `expert parallel` - -Features we invite the community to contribute: -- better scripts for offline weights conversion from huggingface to mcore `dist_checkpointing` format. - - conversion of large models with multiple GPUs - - conversion of large models with single GPU -- refactor the `megatron_checkpoint_manager.py` by `dist_checkpointing` format. -- support llama4 -- support qwen2.5-vl - -To track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033). - -## How things work now -To engage the community in contributing, here are the key steps in our mcore integration process and features under development. - -The huggingface `transformers` is the de facto standard of model zoo while mcore is good at computation efficiency. The main challenge is conversion between the two. -main steps: -1. modelling the huggingface model with mcore `GPTModel` - - a. convert the huggingface config to mcore `TransformerConfig` - - b. init the mcore `GPTModel` with the converted config - - c. load the huggingface model weights to the `GPTModel` -2. online weight conversion from mcore to huggingface (due to the rollout engine `vLLM` is using huggingface format) - - a. bridge the gap between mcore and huggingface weights format and name mapping - - b. online resharding the mcore weights to rollout engine - - this part is very complicated with multiple parallel strategies composition between mcore and rollout engine -3. support the mcore features in verl - - a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel` - - b. support recompute and other mcore speed up features - -4. checkpointing - - a. support recovering the verl training. - - b. support exporting the mcore checkpoint to huggingface format, for downstream inference. - -### Modelling the huggingface model with mcore `GPTModel` -The first step is to convert huggingface config to mcore `TransformerConfig` and init the mcore `GPTModel` with the converted config. See code in `verl/models/mcore/config_converter.py` and `verl/verl/models/mcore/models/model_initializer.py`. The corresponding model forward code is in `verl/verl/models/mcore/models/model_forward.py`. - -There are two ways of loading the huggingface model weights to the `GPTModel` -1. Runtime loading - - every rank loads the entire huggingface model weights and then shard and convert to mcore weights. - - speed is slow and memory consumption is high. - - this way is deprecated and will not support new models. -2. Offline loading - - use offline script to convert the huggingface model weights to mcore weights and save with mcore `dist_checkpointing` format. - - online loading and sharding is automatically done by mcore `dist_checkpointing` format. The speed is fast and memory consumption is low. - - the offline script is in `verl/scripts/converter_hf_to_mcore.py`. - -### online weight conversion from mcore to huggingface -See function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details. - -It should be refatored for extensibility and better performance. - -### support the mcore features in verl -Most of the features of `GPTModel` is out-of-the-box supported in verl through changing the `TransformerConfig`, except those about parallel strategies, such as `expert parallel`. -Features about parallel strategies should be supported with changes about the online weights conversion(especially the resharding part) and verl work dispatching. - -### checkpointing -The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`. And the script to convert checkpoint to huggingface format is in `verl/scripts/model_merger`. - -The existing checkpoint format simply saves every rank's weights and optimizer states. It should be refactored by `dist_checkpointing` format. - - -## How to support new models -1. make sure the model is supported by vLLM -2. modelling the huggingface model with mcore `GPTModel` (The [Pai-Megatron-Path](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference) - - a. convert the huggingface config to mcore `TransformerConfig` - - b. init the mcore `GPTModel` with the converted config - - c. load the huggingface model weights to the `GPTModel` - - d. for VLM the interface might be different, it is ok to add a new model class with GPTModel as its module. -3. offline weights conversion from huggingface to mcore `dist_checkpointing` format -4. support online weights conversion from mcore to huggingface - - it is recommended to initialize a vLLM model with the converted mcore weights, and then test if the generating sequence is correct. - - -## How to scale up to larger models like deepseek-v3 or other 100B+ models -The greatest challenge for scaling up to larger models is the memory consumption. - -The necessary features under development for scaling up are -1. Training engine part - - expert parallel -2. Rollout engine part - - pipeline parallel - - expert parallel - - more efficient and general weight resharding and loading -3. Offline weights conversion - - support weights larger than single GPU memory diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py deleted file mode 100644 index 23f01e8b7..000000000 --- a/verl/models/mcore/registry.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Registry module for model architecture components. -""" - -from enum import Enum -from typing import Callable - -import torch -import torch.nn as nn - -from .config_converter import ( - PretrainedConfig, - TransformerConfig, - hf_to_mcore_config_dense, - hf_to_mcore_config_dpskv3, - hf_to_mcore_config_llama4, - hf_to_mcore_config_mixtral, - hf_to_mcore_config_qwen2_5_vl, - hf_to_mcore_config_qwen2moe, - hf_to_mcore_config_qwen3moe, -) -from .model_forward import ( - gptmodel_forward, - gptmodel_forward_qwen2_5_vl, -) -from .model_forward_fused import ( - fused_forward_gptmodel, - fused_forward_qwen2_5_vl, -) -from .model_initializer import ( - BaseModelInitializer, - DeepseekV3Model, - DenseModel, - MixtralModel, - Qwen2MoEModel, - Qwen3MoEModel, - Qwen25VLModel, -) -from .weight_converter import ( - McoreToHFWeightConverterDense, - McoreToHFWeightConverterDpskv3, - McoreToHFWeightConverterMixtral, - McoreToHFWeightConverterQwen2_5_VL, - McoreToHFWeightConverterQwen2Moe, - McoreToHFWeightConverterQwen3Moe, -) - - -class SupportedModel(Enum): - LLAMA = "LlamaForCausalLM" # tested - QWEN2 = "Qwen2ForCausalLM" # tested - QWEN2_MOE = "Qwen2MoeForCausalLM" # pending - DEEPSEEK_V3 = "DeepseekV3ForCausalLM" # not tested - MIXTRAL = "MixtralForCausalLM" # tested - QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration" # not supported - LLAMA4 = "Llama4ForConditionalGeneration" # not tested - QWEN3 = "Qwen3ForCausalLM" # tested - QWEN3_MOE = "Qwen3MoeForCausalLM" # not tested - - -# Registry for model configuration converters -MODEL_CONFIG_CONVERTER_REGISTRY: dict[SupportedModel, Callable[[PretrainedConfig, torch.dtype], TransformerConfig]] = { - SupportedModel.LLAMA: hf_to_mcore_config_dense, - SupportedModel.QWEN2: hf_to_mcore_config_dense, - SupportedModel.QWEN2_MOE: hf_to_mcore_config_qwen2moe, - SupportedModel.DEEPSEEK_V3: hf_to_mcore_config_dpskv3, - SupportedModel.MIXTRAL: hf_to_mcore_config_mixtral, - SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, - SupportedModel.LLAMA4: hf_to_mcore_config_llama4, - SupportedModel.QWEN3: hf_to_mcore_config_dense, - SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe, - SupportedModel.QWEN2_5_VL: hf_to_mcore_config_qwen2_5_vl, -} - -# Registry for model initializers -MODEL_INITIALIZER_REGISTRY: dict[SupportedModel, type[BaseModelInitializer]] = { - SupportedModel.LLAMA: DenseModel, - SupportedModel.QWEN2: DenseModel, - SupportedModel.QWEN2_MOE: Qwen2MoEModel, - SupportedModel.MIXTRAL: MixtralModel, - SupportedModel.DEEPSEEK_V3: DeepseekV3Model, - SupportedModel.QWEN2_5_VL: Qwen25VLModel, - SupportedModel.LLAMA4: DenseModel, - SupportedModel.QWEN3: DenseModel, - SupportedModel.QWEN3_MOE: Qwen3MoEModel, - SupportedModel.QWEN2_5_VL: Qwen25VLModel, -} - -# Registry for model forward functions -MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = { - SupportedModel.LLAMA: gptmodel_forward, - SupportedModel.QWEN2: gptmodel_forward, - SupportedModel.QWEN2_MOE: gptmodel_forward, - SupportedModel.MIXTRAL: gptmodel_forward, - SupportedModel.DEEPSEEK_V3: gptmodel_forward, - SupportedModel.QWEN2_5_VL: gptmodel_forward, - SupportedModel.LLAMA4: gptmodel_forward, - SupportedModel.QWEN3: gptmodel_forward, - SupportedModel.QWEN3_MOE: gptmodel_forward, - SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl, - SupportedModel.DEEPSEEK_V3: gptmodel_forward, -} - -# Registry for model forward functions -MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = { - SupportedModel.LLAMA: fused_forward_gptmodel, - SupportedModel.QWEN2: fused_forward_gptmodel, - SupportedModel.QWEN2_MOE: fused_forward_gptmodel, - SupportedModel.MIXTRAL: fused_forward_gptmodel, - SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel, - SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl, - SupportedModel.LLAMA4: fused_forward_gptmodel, - SupportedModel.QWEN3: fused_forward_gptmodel, - SupportedModel.QWEN3_MOE: fused_forward_gptmodel, - SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl, - SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel, -} - -# Registry for model weight converters -MODEL_WEIGHT_CONVERTER_REGISTRY: dict[SupportedModel, type] = { - SupportedModel.LLAMA: McoreToHFWeightConverterDense, - SupportedModel.QWEN2: McoreToHFWeightConverterDense, - SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe, - SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral, - SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3, - SupportedModel.QWEN3: McoreToHFWeightConverterDense, - SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe, - SupportedModel.QWEN2_5_VL: McoreToHFWeightConverterQwen2_5_VL, -} - - -def get_supported_model(model_type: str) -> SupportedModel: - try: - return SupportedModel(model_type) - except ValueError as err: - supported_models = [e.value for e in SupportedModel] - raise NotImplementedError( - f"Model Type: {model_type} not supported. Supported models: {supported_models}" - ) from err - - -def hf_to_mcore_config( - hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs -) -> TransformerConfig: - """Convert huggingface PretrainedConfig to mcore TransformerConfig. - - Args: - hf_config: The huggingface PretrainedConfig. - dtype: The dtype of the model. - **override_transformer_config_kwargs: The kwargs to override the transformer config. - - Returns: - The mcore TransformerConfig. - """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs) - - -def init_mcore_model( - tfconfig: TransformerConfig, - hf_config: PretrainedConfig, - pre_process: bool = True, - post_process: bool = None, - *, - share_embeddings_and_output_weights: bool = False, - value: bool = False, - **extra_kwargs, # may be used for vlm and moe -) -> nn.Module: - """ - Initialize a Mcore model. - - Args: - tfconfig: The transformer config. - hf_config: The HuggingFace config. - pre_process: Optional pre-processing function. - post_process: Optional post-processing function. - share_embeddings_and_output_weights: Whether to share embeddings and output weights. - value: Whether to use value. - **extra_kwargs: Additional keyword arguments. - - Returns: - The initialized model. - """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - initializer_cls = MODEL_INITIALIZER_REGISTRY[model] - initializer = initializer_cls(tfconfig, hf_config) - return initializer.initialize( - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - value=value, - **extra_kwargs, - ) - - -def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable: - """ - Get the forward function for given model architecture. - """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - return MODEL_FORWARD_REGISTRY[model] - - -def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable: - """ - Get the forward function for given model architecture. - """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - return MODEL_FORWARD_FUSED_REGISTRY[model] - - -def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable: - """ - Get the weight converter for given model architecture. - """ - assert len(hf_config.architectures) == 1, "Only one architecture is supported for now" - model = get_supported_model(hf_config.architectures[0]) - tfconfig = hf_to_mcore_config(hf_config, dtype) - return MODEL_WEIGHT_CONVERTER_REGISTRY[model](hf_config, tfconfig) diff --git a/verl/models/mcore/saver.py b/verl/models/mcore/saver.py deleted file mode 100644 index 2a954b241..000000000 --- a/verl/models/mcore/saver.py +++ /dev/null @@ -1,497 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist -from megatron.core import mpu -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from megatron.core.transformer.module import Float16Module -from torch.nn.parallel import DistributedDataParallel as torchDDP - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.logger import print_rank_0 -from verl.utils.megatron_utils import unwrap_model - - -def _megatron_calc_global_rank( - tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0, cp_rank: int = 0, ep_rank: int = 0 -): - """Calculate global rank with support for CP/EP parallelism""" - - # Get parallel sizes for each dimension - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - cp_size = mpu.get_context_parallel_world_size() - # ep_size = mpu.get_expert_model_parallel_world_size() - - # Verify total GPU count matches (must be consistent with parallel_state.py) - total_size = tp_size * dp_size * pp_size * cp_size - assert total_size == torch.distributed.get_world_size(), ( - f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}" - ) - - # Core calculation logic (corresponds to RankGenerator order parameter) - # Assumes default order is "tp-cp-ep-dp-pp" - return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - """Merge sharded parameters of a Megatron module into a merged checkpoint. - - Args: - wrapped_models (list of megatron.core.distributed.DistributedDataParallel): - The local DDP wrapped megatron modules. - config (str or None): - HF config for model - dtype: model params type - is_value_model: if model is value model - tie_word_embeddings: tie_word_embeddings - Returns: - state_dict (dict): - The merged state_dict in rank 0, and an empty dictionary in other ranks. - """ - start_time = time.time() - - def _get_gpt_model(model): - return model - - dp_rank = mpu.get_data_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - cp_rank = mpu.get_context_parallel_rank() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if dist.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].decoder.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].decoder.layers), num_layers_per_model - ) - ) - - state_dict = dict() - - def _get_cpu_tensor(tensor: torch.Tensor): - if tensor is None: - return None - if tensor.device == torch.device("cpu"): - return tensor.detach().clone() - return tensor.detach().cpu() - - def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: - """broadcast tensor across mp_group""" - nonlocal state_dict - nonlocal mp_group - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - - if torch.distributed.get_rank() == src_rank: - if tensor is None: - weight = None - tensor_shape = None - else: - weight = tensor - tensor_shape = weight.shape - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not exist, skip collect") - return - - if weight is None: - weight = torch.empty( - tensor_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - dist.broadcast(weight, src=src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - state_dict[name] = _get_cpu_tensor(weight) - - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - # tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=concat_dim) - if mutate_func is not None: - full_tensor = mutate_func(full_tensor) - state_dict[name] = full_tensor - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - # tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) - state_dict[up_name] = torch.cat(up_weight_list, dim=0) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - # tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - q_weight_list = [] - k_weight_list = [] - v_weight_list = [] - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - - if config.num_key_value_heads >= tp_size: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] - q_weight_list.append(q_part) - k_weight_list.append(k_part) - v_weight_list.append(v_part) - else: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] - q_weight_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_weight_list.append(k_part) - v_weight_list.append(v_part) - - state_dict[q_name] = torch.cat(q_weight_list, dim=0) - state_dict[k_name] = torch.cat(k_weight_list, dim=0) - state_dict[v_name] = torch.cat(v_weight_list, dim=0) - - # empty cache before collecting weights - get_torch_device().empty_cache() - # Embeddings - # ------------------- - if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks - # Embeddings - # ------------------- - print_rank_0("collecting embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - _broadcast_tp_shard_tensor( - gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None, - "model.embed_tokens.weight", - src_pp_rank=0, - ) - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - for layer in range(config.num_hidden_layers): - print_rank_0(f"collecting layer #{layer}...") - layer_name = f"model.layers.{layer}" - src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) - sync_layer = gpt_model_module.decoder.layers[src_layer_idx] - - _broadcast_tensor( - sync_layer.self_attention.linear_qkv.layer_norm_weight, - f"{layer_name}.input_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - if gpt_model_module.config.qk_layernorm: - _broadcast_tensor( - sync_layer.self_attention.q_layernorm.weight, - f"{layer_name}.self_attn.q_norm.weight", - src_pp_rank=src_pp_rank, - ) - _broadcast_tensor( - sync_layer.self_attention.k_layernorm.weight, - f"{layer_name}.self_attn.k_norm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.weight, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - src_pp_rank=src_pp_rank, - ) - - if gpt_model_module.config.add_qkv_bias: - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attention.linear_qkv.bias, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attention.linear_proj.weight, - f"{layer_name}.self_attn.o_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - _broadcast_tensor( - sync_layer.mlp.linear_fc1.layer_norm_weight, - f"{layer_name}.post_attention_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.linear_fc1.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.linear_fc2.weight, - f"{layer_name}.mlp.down_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - # Final Layernorm - # ------------------- - print_rank_0("collecting final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.decoder.final_layernorm, "weight", None), - "model.norm.weight", - src_pp_rank=pp_size - 1, - ) - - if tie_word_embeddings: - print_rank_0("tie word embedding skip load lm_head...") - else: - print_rank_0("collecting lm_head...") - - if is_value_model: - lm_head_weight = None - if pp_rank == pp_size - 1: - lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None) - _broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1) - - else: - _broadcast_tp_shard_tensor( - getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - - dist.barrier() - get_torch_device().empty_cache() - if torch.distributed.get_rank() == 0: - for k, v in state_dict.items(): - if dtype != v.dtype: - state_dict[k] = v.to(dtype) - - print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") - return state_dict - - -def merge_megatron_ckpt_gptmodel_qwen_moe( - wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False -): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented") - - -def merge_megatron_ckpt_gptmodel_qwen2_5_vl( - wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False -): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen2_5_vl is not implemented") - - -def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented") - - -def merge_megatron_ckpt_gptmodel_mixtral( - wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False -): - raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented") diff --git a/verl/models/mcore/util.py b/verl/models/mcore/util.py deleted file mode 100644 index c1ef7a211..000000000 --- a/verl/models/mcore/util.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from megatron.core import parallel_state as mpu -from megatron.core.packed_seq_params import PackedSeqParams - -from verl.utils.model import CausalLMOutputForPPO - - -def preprocess_packed_seqs( - input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True -) -> tuple[torch.Tensor, PackedSeqParams]: - """ - Preprocess packed sequences - CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 - gets second and second last chunks, and so on), this is for load balancing with causal masking. - See https://github.com/NVIDIA/TransformerEngine/issues/1368 - """ - batch_size = input_ids.shape[0] - - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - tp_size = mpu.get_tensor_model_parallel_world_size() - cp_size = mpu.get_context_parallel_world_size() - cp_rank = mpu.get_context_parallel_rank() - align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size - - pad_size = (align_size - seqlens_in_batch % align_size) % align_size - seqlens_in_batch_padded = seqlens_in_batch + pad_size - cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) - cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) - cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) - cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) - max_seqlen_in_batch = seqlens_in_batch_padded.max().item() - - shape = list(input_ids.shape[1:]) - shape[0] = seqlens_in_batch_padded.sum().item() // cp_size - if pre_process: - input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) - for i in range(batch_size): - if cp_size <= 1: - seqlen = seqlens_in_batch[i] - input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]] - continue - seqlen = seqlens_in_batch_padded[i] // cp_size - half_seqlen = seqlen // 2 - start_idx = cu_seqlens_padded[i] // cp_size - # split to 2 chunks - d = input_ids[i, attention_mask[i]] - input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ - half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) - ] - - remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1) - remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank - remain_end = min(remain_end, d.shape[0]) - remain_len = remain_end - remain_start - if remain_len > 0: - input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ - remain_start:remain_end - ] - - packed_seq_params = PackedSeqParams( - qkv_format="thd", - cu_seqlens_q=cu_seqlens_padded, - max_seqlen_q=max_seqlen_in_batch, - cu_seqlens_kv=cu_seqlens_padded, - max_seqlen_kv=max_seqlen_in_batch, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded, - ) - if pre_process: - return input_ids_rmpad.unsqueeze(0), packed_seq_params - else: - return input_ids, packed_seq_params - - -def postprocess_packed_seqs( - output: torch.Tensor, - packed_seq_params: PackedSeqParams, - attention_mask: torch.Tensor, - batch_size: int, - seq_len: int, - post_process: bool = True, -) -> torch.Tensor: - """ - Postprocess packed sequences - """ - if not post_process: - return output - shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim - output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) - - cp_size = mpu.get_context_parallel_world_size() - # all gather output across context parallel group - if cp_size > 1: - # output shape: [1, packed_len, hidden_dim] - # need to gather across cp group and concatenate in sequence dimension - output_list = [torch.empty_like(output) for _ in range(cp_size)] - torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) - output_list[mpu.get_context_parallel_rank()] = output - else: - output_list = [output] - for i in range(batch_size): - if cp_size <= 1: - s = attention_mask[i].sum().item() - output_new[i, attention_mask[i]] = output[0][ - packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s - ] - continue - s_len_padded_chunk = ( - packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i] - ) // cp_size - half_seqlen = s_len_padded_chunk // 2 - s_len = attention_mask[i].sum().item() - s_len_padded = s_len_padded_chunk * cp_size - tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) - for j in range(cp_size): - o = output_list[j][0] - # split to 2 chunks - packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size - o0, o1 = ( - o[packed_start_idx : packed_start_idx + half_seqlen], - o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], - ) - tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 - tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 - output_new[i, attention_mask[i]] = tmp[:s_len] - - return output_new - - -def remove_left_padding( - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: torch.Tensor, - sequence_parallel: bool = False, - pre_process: bool = True, -): - """ - Remove left padding from input_ids, attention_mask and position_ids - return new_input_ids, new_attention_mask, new_position_ids - """ - assert attention_mask.ndim == 2 - assert position_ids.ndim == 2 - cp_size = mpu.get_context_parallel_world_size() - assert cp_size == 1, "Context parallel size without seq_pack is not supported" - batch_size = input_ids.shape[0] - shape = list(input_ids.shape) # batch_size, seq_len,... - seq_lens = attention_mask.sum(dim=1) - seq_len = seq_lens.max().item() - if sequence_parallel: - sp_world_size = mpu.get_tensor_model_parallel_world_size() - pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size - seq_len = seq_len + pad_size - shape[1] = seq_len - if pre_process: - new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape) - new_attention_mask = torch.zeros( - dtype=attention_mask.dtype, device=attention_mask.device, size=(batch_size, seq_len) - ) - new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len)) - for i in range(batch_size): - if pre_process: - new_input_ids[i, : seq_lens[i]] = input_ids[i, attention_mask[i]] - new_attention_mask[i, : seq_lens[i]] = attention_mask[i, attention_mask[i]] - new_position_ids[i, : seq_lens[i]] = position_ids[i, attention_mask[i]] - if pre_process: - return new_input_ids, new_attention_mask, new_position_ids - else: - return input_ids, new_attention_mask, new_position_ids - - -def recover_left_padding( - result, - attention_mask: torch.Tensor, - original_attention_mask: torch.Tensor, - origin_seqlen: int, - post_process: bool = True, -): - """ - Recover left padding from result - return result - """ - if not post_process: - return result - shape = list(result.shape) - batch_size = shape[0] - shape[1] = origin_seqlen - new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape) - for i in range(batch_size): - new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]] - return new_result - - -def postprocess_packed_seqs_for_dict_output( - labels_mask: torch.Tensor, - output: CausalLMOutputForPPO, - packed_seq_params: PackedSeqParams, - attention_mask: torch.Tensor, - batch_size: int, - seq_len: int, - post_process: bool = True, -) -> dict[str, torch.Tensor]: - """_summary_ - For fused kernels, the output is a dictionary with keys like 'log_probs', 'entropy', etc. - This function post-processes each tensor in the output dictionary. - Args: - output (CausalLMOutputForPPO): _description_ - packed_seq_params (PackedSeqParams): _description_ - attention_mask (torch.Tensor): _description_ - batch_size (int): _description_ - seq_len (int): _description_ - post_process (bool, optional): _description_. Defaults to True. - Returns: - CausalLMOutputForPPO: _description_ - """ - ret = {} - output.entropy = output.entropy.view(1, -1) - output.log_probs = output.log_probs.view(1, -1) - output.log_probs = output.log_probs.masked_fill(~labels_mask, 0.0) - ret["entropy"] = postprocess_packed_seqs( - output.entropy, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process - ) - ret["log_probs"] = postprocess_packed_seqs( - output.log_probs, packed_seq_params, attention_mask, batch_size, seq_len, post_process=post_process - ) - return ret diff --git a/verl/models/mcore/weight_converter.py b/verl/models/mcore/weight_converter.py deleted file mode 100644 index 791513f32..000000000 --- a/verl/models/mcore/weight_converter.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# online convert mcore weight to pure huggingface weight, no any fusion -# including format conversion and name mapping -# not including resharding -import torch -from megatron.core.transformer import TransformerConfig -from transformers import PretrainedConfig - - -class McoreToHFWeightConverterBase: - def __init__(self, hf_config: PretrainedConfig, mcore_config: TransformerConfig): - self.hf_config = hf_config - self.mcore_config = mcore_config - - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> torch.Tensor: - raise NotImplementedError - - -class McoreToHFWeightConverterDense(McoreToHFWeightConverterBase): - def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # 'decoder.layers.0.self_attention.linear_proj.weight' - # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight' - # 'decoder.layers.0.self_attention.linear_qkv.weight' - # 'decoder.layers.0.self_attention.linear_qkv.bias' - layer_number = name.split(".")[2] - convert_names = [] - if "self_attention.linear_qkv.bias" in name or "self_attention.linear_qkv.weight" in name: - param_type = name.split(".")[-1] - assert param_type == "bias" or param_type == "weight" - convert_names.append(f"model.layers.{layer_number}.self_attn.q_proj.{param_type}") - convert_names.append(f"model.layers.{layer_number}.self_attn.k_proj.{param_type}") - convert_names.append(f"model.layers.{layer_number}.self_attn.v_proj.{param_type}") - assert len(params) == 3 - elif "self_attention.linear_proj.weight" in name: - convert_names.append(f"model.layers.{layer_number}.self_attn.o_proj.weight") - assert len(params) == 1 - elif "self_attention.linear_qkv.layer_norm_weight" in name: - convert_names.append(f"model.layers.{layer_number}.input_layernorm.weight") - assert len(params) == 1 - elif "self_attention.q_layernorm.weight" in name: - convert_names.append(f"model.layers.{layer_number}.self_attn.q_norm.weight") - assert len(params) == 1 - elif "self_attention.k_layernorm.weight" in name: - convert_names.append(f"model.layers.{layer_number}.self_attn.k_norm.weight") - assert len(params) == 1 - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - return convert_names, params - - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' - # 'decoder.layers.0.mlp.linear_fc1.weight' - # 'decoder.layers.0.mlp.linear_fc2.weight' - layer_number = name.split(".")[2] - convert_names = [] - if "mlp.linear_fc1.weight" in name: - # split gate_proj and up_proj - convert_names.append(f"model.layers.{layer_number}.mlp.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.up_proj.weight") - assert len(params) == 2 - elif "mlp.linear_fc1.layer_norm_weight" in name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") - assert len(params) == 1 - elif "mlp.linear_fc2.weight" in name: - convert_names.append(f"model.layers.{layer_number}.mlp.down_proj.weight") - assert len(params) == 1 - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - return convert_names, params - - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - direct_name_mapping = { - "embedding.word_embeddings.weight": "model.embed_tokens.weight", - "decoder.final_layernorm.weight": "model.norm.weight", - "output_layer.weight": "lm_head.weight", - } - if name in direct_name_mapping: - return [direct_name_mapping[name]], [params_one_group[0]] - - if "self_attention" in name: - return self._convert_attention_param(name, params_one_group) - elif "mlp" in name: - return self._convert_mlp_param(name, params_one_group) - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - - -class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense): - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # 'decoder.layers.0.pre_mlp_layernorm.weight', - # 'decoder.layers.0.mlp.router.weight', - # 'decoder.layers.0.mlp.shared_experts.gate_weight', - # 'decoder.layers.0.mlp.shared_experts.linear_fc1.weight', - # 'decoder.layers.0.mlp.shared_experts.linear_fc2.weight' - # moe1 - # 'decoder.layers.0.mlp.experts.linear_fc1.weight0', - # 'decoder.layers.0.mlp.experts.linear_fc1.weight1', - # 'decoder.layers.0.mlp.experts.linear_fc1.weight2', - # 'decoder.layers.0.mlp.experts.linear_fc1.weight3', - # moe2 - # 'decoder.layers.0.mlp.experts.linear_fc2.weight0', - # 'decoder.layers.0.mlp.experts.linear_fc2.weight1', - layer_number = name.split(".")[2] - convert_names = [] - if "pre_mlp_layernorm" in name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") - assert len(params) == 1 - elif "mlp.router.weight" in name: - convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") - assert len(params) == 1 - elif "shared_experts.gate_weight" in name: - convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert_gate.weight") - assert len(params) == 1 - elif "shared_experts.linear_fc1.weight" in name: # split gate_proj and up_proj - convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.up_proj.weight") - assert len(params) == 2 - elif "shared_experts.linear_fc2.weight" in name: - convert_names.append(f"model.layers.{layer_number}.mlp.shared_expert.down_proj.weight") - assert len(params) == 1 - elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj - expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") - assert len(params) == 2 - elif "mlp.experts.linear_fc2" in name: - expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") - assert len(params) == 1 - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - return convert_names, params - - -class McoreToHFWeightConverterQwen2_5_VL(McoreToHFWeightConverterDense): - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - direct_name_mapping = { - "language_model.embedding.word_embeddings.weight": "model.embed_tokens.weight", - "language_model.decoder.final_layernorm.weight": "model.norm.weight", - "language_model.output_layer.weight": "lm_head.weight", - "vision_model.patch_embed.proj.weight": "visual.patch_embed.proj.weight", - "vision_model.decoder.final_layernorm.weight": "visual.merger.ln_q.weight", - "vision_model.projection.encoder.linear_fc1.weight": "visual.merger.mlp.0.weight", - "vision_model.projection.encoder.linear_fc1.bias": "visual.merger.mlp.0.bias", - "vision_model.projection.encoder.linear_fc2.weight": "visual.merger.mlp.2.weight", - "vision_model.projection.encoder.linear_fc2.bias": "visual.merger.mlp.2.bias", - } - if name in direct_name_mapping: - return [direct_name_mapping[name]], [params_one_group[0]] - - if "self_attention" in name: - return self._convert_attention_param(name, params_one_group) - elif "mlp" in name: - return self._convert_mlp_param(name, params_one_group) - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - - def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - model_type, _, _, layer_number = name.split(".")[:4] - - convert_names = [] - if model_type == "language_model": - name_map_after_layer = { - "self_attention.linear_qkv.bias": [ - "self_attn.q_proj.bias", - "self_attn.k_proj.bias", - "self_attn.v_proj.bias", - ], - "self_attention.linear_qkv.weight": [ - "self_attn.q_proj.weight", - "self_attn.k_proj.weight", - "self_attn.v_proj.weight", - ], - "self_attention.linear_proj.weight": "self_attn.o_proj.weight", - "self_attention.linear_qkv.layer_norm_weight": "input_layernorm.weight", - } - name_after_layer = ".".join(name.split(".")[-3:]) - mapped_name = name_map_after_layer.get(name_after_layer) - if isinstance(mapped_name, list): - assert len(params) == len(mapped_name) - for one in mapped_name: - convert_names.append(f"model.layers.{layer_number}.{one}") - else: - assert len(params) == 1 - convert_names.append(f"model.layers.{layer_number}.{mapped_name}") - elif model_type == "vision_model": - name_map_after_layer = { - "self_attention.linear_proj.weight": "attn.proj.weight", - "self_attention.linear_proj.bias": "attn.proj.bias", - "self_attention.linear_qkv.layer_norm_weight": "norm1.weight", - } - name_after_layer = ".".join(name.split(".")[-3:]) - mapped_name = name_map_after_layer.get(name_after_layer, None) - if mapped_name is None: - assert "linear_qkv" in name_after_layer - assert len(params) == 3 - new_param = torch.cat(params, dim=0) - params = [new_param] - if "bias" in name_after_layer: - convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.bias") - else: - convert_names.append(f"visual.blocks.{layer_number}.attn.qkv.weight") - else: - assert len(params) == 1 - convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") - else: - raise NotImplementedError(f"Unsupported model type: {model_type}") - return convert_names, params - - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - model_type, _, _, layer_number = name.split(".")[:4] - - convert_names = [] - if model_type == "language_model": - name_map_after_layer = { - "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], - "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], - "mlp.linear_fc2.weight": "mlp.down_proj.weight", - "mlp.linear_fc2.bias": "mlp.down_proj.bias", - "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", - } - name_after_layer = ".".join(name.split(".")[-3:]) - mapped_name = name_map_after_layer.get(name_after_layer) - if isinstance(mapped_name, list): - assert len(params) == len(mapped_name) - for one in mapped_name: - convert_names.append(f"model.layers.{layer_number}.{one}") - else: - assert len(params) == 1 - convert_names.append(f"model.layers.{layer_number}.{mapped_name}") - - elif model_type == "vision_model": - name_map_after_layer = { - "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], - "mlp.linear_fc1.bias": ["mlp.gate_proj.bias", "mlp.up_proj.bias"], - "mlp.linear_fc2.weight": "mlp.down_proj.weight", - "mlp.linear_fc2.bias": "mlp.down_proj.bias", - "mlp.linear_fc1.layer_norm_weight": "norm2.weight", - } - name_after_layer = ".".join(name.split(".")[-3:]) - mapped_name = name_map_after_layer.get(name_after_layer) - if isinstance(mapped_name, list): - assert len(params) == len(mapped_name) - for one in mapped_name: - convert_names.append(f"visual.blocks.{layer_number}.{one}") - else: - assert len(params) == 1 - convert_names.append(f"visual.blocks.{layer_number}.{mapped_name}") - else: - raise NotImplementedError(f"Unsupported model type: {model_type}") - return convert_names, params - - -class McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase): - def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # mcore - # 'decoder.layers.0.input_layernorm.weight' - # 'decoder.layers.0.self_attention.linear_proj.weight' - # 'decoder.layers.0.self_attention.linear_q_proj.weight' - # 'decoder.layers.0.self_attention.linear_kv_down_proj.weight' - # 'decoder.layers.0.self_attention.linear_kv_up_proj.layer_norm_weight' - # 'decoder.layers.0.self_attention.linear_kv_up_proj.weight' - # 'decoder.layers.0.self_attention.linear_q_down_proj.weight' - # 'decoder.layers.0.self_attention.linear_q_up_proj.weight' - # 'decoder.layers.0.self_attention.linear_q_up_proj.layer_norm_weight' - # hf - # 'model.layers.0.input_layernorm.weight' - # 'model.layers.0.self_attn.o_proj.weight' - # 'model.layers.0.self_attn.q_proj.weight' - # 'model.layers.0.self_attn.kv_a_proj_with_mqa.weight' - # 'model.layers.0.self_attn.kv_a_layernorm.weight' - # 'model.layers.0.self_attn.kv_b_proj.weight' - # 'model.layers.0.self_attn.q_a_proj.weight' - # 'model.layers.0.self_attn.q_b_proj.weight' - # 'model.layers.0.self_attn.q_a_layernorm.weight' - name_map_after_layer = { - "input_layernorm.weight": "input_layernorm.weight", - "self_attention.linear_proj.weight": "self_attn.o_proj.weight", - "self_attention.linear_q_proj.weight": "self_attn.q_proj.weight", - "self_attention.linear_kv_down_proj.weight": "self_attn.kv_a_proj_with_mqa.weight", - "self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight", - "self_attention.linear_kv_up_proj.weight": "self_attn.kv_b_proj.weight", - "self_attention.linear_q_down_proj.weight": "self_attn.q_a_proj.weight", - "self_attention.linear_q_up_proj.weight": "self_attn.q_b_proj.weight", - "self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight", - } - assert len(params) == 1 - convert_names = [] - layer_number = name.split(".")[2] - name_after_layer = name.split(f".{layer_number}.")[1] - convert_names.append(f"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}") - return convert_names, params - - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # mcore dense - # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight' - # 'decoder.layers.0.mlp.linear_fc2.weight' - # 'decoder.layers.0.mlp.linear_fc1.weight' - # --- - # 'decoder.layers.1.mlp.shared_experts.linear_fc1.weight' - # --- - # 'decoder.layers.1.mlp.shared_experts.linear_fc2.weight' - # hf dense - # 'model.layers.0.post_attention_layernorm.weight' - # 'model.layers.0.mlp.down_proj.weight' - # 'model.layers.0.mlp.gate_proj.weight' - # 'model.layers.0.mlp.up_proj.weight' - # 'model.layers.1.mlp.shared_experts.gate_proj.weight' - # 'model.layers.1.mlp.shared_experts.up_proj.weight' - # 'model.layers.1.mlp.shared_experts.down_proj.weight' - - # mcore moe - # 'decoder.layers.1.pre_mlp_layernorm.weight' - # 'decoder.layers.1.mlp.router.weight' - # 'decoder.layers.1.mlp.router.expert_bias' - # 'decoder.layers.1.mlp.experts.linear_fc1.weight0' - # --- - # 'decoder.layers.1.mlp.experts.linear_fc2.weight0' - # hf moe - # 'model.layers.1.post_attention_layernorm.weight' - # 'model.layers.1.mlp.gate.weight' - # 'model.layers.1.mlp.gate.e_score_correction_bias' - # 'model.layers.1.mlp.experts.0.gate_proj.weight' - # 'model.layers.1.mlp.experts.0.up_proj.weight' - # 'model.layers.1.mlp.experts.0.down_proj.weight' - - name_map_after_layer = { - "mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight", - "mlp.linear_fc2.weight": "mlp.down_proj.weight", - "mlp.shared_experts.linear_fc2.weight": "mlp.shared_experts.down_proj.weight", - "mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"], - "mlp.shared_experts.linear_fc1.weight": [ - "mlp.shared_experts.gate_proj.weight", - "mlp.shared_experts.up_proj.weight", - ], - "pre_mlp_layernorm.weight": "post_attention_layernorm.weight", - "mlp.router.weight": "mlp.gate.weight", - "mlp.router.expert_bias": "mlp.gate.e_score_correction_bias", - } - convert_names = [] - layer_number = name.split(".")[2] - name_after_layer = name.split(f".{layer_number}.")[1] - if name_after_layer in name_map_after_layer: - mapped_name = name_map_after_layer[name_after_layer] - if isinstance(mapped_name, list): - assert len(params) == len(mapped_name) - for one in mapped_name: - convert_names.append(f"model.layers.{layer_number}.{one}") - else: - assert len(params) == 1 - convert_names.append(f"model.layers.{layer_number}.{mapped_name}") - else: - if "mlp.experts.linear_fc1.weight" in name: - expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") - assert len(params) == 2 - elif "mlp.experts.linear_fc2.weight" in name: - expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") - assert len(params) == 1 - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - - return convert_names, params - - def _convert_mtp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - assert self.mcore_config.mtp_num_layers == 1, "only support one mtp layer for now" - assert self.mcore_config.num_layers == 61, "only support 61 layers for now" - direct_name_mapping = { - "mtp.layers.0.enorm.weight": "model.layers.61.enorm.weight", - "mtp.layers.0.hnorm.weight": "model.layers.61.hnorm.weight", - "mtp.layers.0.eh_proj.weight": "model.layers.61.eh_proj.weight", - "mtp.layers.0.final_layernorm.weight": "model.layers.61.shared_head.norm.weight", - } - if name in direct_name_mapping: - return [direct_name_mapping[name]], [params[0]] - assert "mtp.layers.0.transformer_layer" in name, "only support transformer layer for now" - # use proxy name to convert - proxy_name = name.replace("mtp.layers.0.transformer_layer", "decoder.layers.61") - if "self_attention" in proxy_name or "input_layernorm.weight" in proxy_name: - convert_names, params = self._convert_attention_param(proxy_name, params) - elif "mlp" in proxy_name: - convert_names, params = self._convert_mlp_param(proxy_name, params) - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - return convert_names, params - - def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - direct_name_mapping = { - "embedding.word_embeddings.weight": "model.embed_tokens.weight", - "decoder.final_layernorm.weight": "model.norm.weight", - "output_layer.weight": "lm_head.weight", - } - if name in direct_name_mapping: - return [direct_name_mapping[name]], [params_one_group[0]] - if "mtp" in name: - return self._convert_mtp_param(name, params_one_group) - elif "self_attention" in name or "input_layernorm.weight" in name: - return self._convert_attention_param(name, params_one_group) - elif "mlp" in name: - return self._convert_mlp_param(name, params_one_group) - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - - -class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense): - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # decoder.layers.0.mlp.router.weight - # decoder.layers.0.mlp.experts.linear_fc1.weight0 - weight7 - # decoder.layers.0.mlp.experts.linear_fc2.weight0 - weight7 - - layer_number = name.split(".")[2] - convert_names = [] - if "pre_mlp_layernorm" in name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") - elif "mlp.router.weight" in name: - convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.gate.weight") - elif "mlp.experts.linear_fc1.weight" in name: - expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w1.weight") - convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w3.weight") - elif "mlp.experts.linear_fc2.weight" in name: - expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.block_sparse_moe.experts.{expert_id}.w2.weight") - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - return convert_names, params - - -class McoreToHFWeightConverterQwen3Moe(McoreToHFWeightConverterDense): - def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]: - # qwen3 moe no share expert - - # 'decoder.layers.0.pre_mlp_layernorm.weight', - # 'decoder.layers.0.mlp.router.weight', - # moe1 - # 'decoder.layers.0.mlp.experts.linear_fc1.weight0', - # 'decoder.layers.0.mlp.experts.linear_fc1.weight1', - # 'decoder.layers.0.mlp.experts.linear_fc1.weight2', - # 'decoder.layers.0.mlp.experts.linear_fc1.weight3', - # moe2 - # 'decoder.layers.0.mlp.experts.linear_fc2.weight0', - # 'decoder.layers.0.mlp.experts.linear_fc2.weight1', - layer_number = name.split(".")[2] - convert_names = [] - if "pre_mlp_layernorm" in name: - convert_names.append(f"model.layers.{layer_number}.post_attention_layernorm.weight") - assert len(params) == 1 - elif "mlp.router.weight" in name: - convert_names.append(f"model.layers.{layer_number}.mlp.gate.weight") - assert len(params) == 1 - elif "mlp.experts.linear_fc1" in name: # split gate_proj and up_proj - expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight") - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight") - assert len(params) == 2 - elif "mlp.experts.linear_fc2" in name: - expert_id = name.split("weight")[-1] - convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight") - assert len(params) == 1 - else: - raise NotImplementedError(f"Unsupported parameter name: {name}") - return convert_names, params diff --git a/verl/models/qwen2/__init__.py b/verl/models/qwen2/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/models/qwen2/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/qwen2/megatron/__init__.py b/verl/models/qwen2/megatron/__init__.py deleted file mode 100644 index 57e33ee9e..000000000 --- a/verl/models/qwen2/megatron/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .modeling_qwen2_megatron import ( - ParallelQwen2ForCausalLM, - # rmpad with megatron - ParallelQwen2ForCausalLMRmPad, - # rmpad with megatron and pipeline parallelism - ParallelQwen2ForCausalLMRmPadPP, - ParallelQwen2ForValueRmPad, - ParallelQwen2ForValueRmPadPP, - # original model with megatron - ParallelQwen2Model, -) - -__all__ = [ - "ParallelQwen2ForCausalLM", - "ParallelQwen2ForCausalLMRmPad", - "ParallelQwen2ForCausalLMRmPadPP", - "ParallelQwen2ForValueRmPad", - "ParallelQwen2ForValueRmPadPP", - "ParallelQwen2Model", -] diff --git a/verl/models/qwen2/megatron/checkpoint_utils/__init__.py b/verl/models/qwen2/megatron/checkpoint_utils/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/models/qwen2/megatron/checkpoint_utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py deleted file mode 100644 index 3168635c7..000000000 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_qwen2( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False -): - """Load merged state_dict to sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def fetch_params(module): - for param in module.parameters(): - torch.distributed.fetch( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( - f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " - f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" - ) - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.model.layers) == num_layers_per_model - - def _fetch_tensor(tensor, name) -> torch.Tensor: - """fetch tensor""" - nonlocal state_dict - if tensor is not None: - tensor = tensor.data.copy_(state_dict[name], non_blocking=True) - - def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """fetch tensor in tp shards""" - nonlocal state_dict - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - if tensor is not None: - tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) - else: - print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """fetch tensor in tp shards""" - nonlocal state_dict - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - if tensor is not None: - tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) - else: - print(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """fetch gate_up tensor in tp shards""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - if gate_name in state_dict and up_name in state_dict: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - if tensor is not None: - tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) - else: - print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading") - - def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: - """fetch tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - if not bias: - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) - - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - if not bias: - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0)) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - if tensor is not None: - tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True) - - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - num_layer_per_pp = config.num_hidden_layers // pp_size - vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - - layer_list = [] - if vpp_size is not None: - for vpp_rank in range(vpp_size): - num_layer_vpp_chunk = num_layer_per_pp // vpp_size - num_layer_this_model = num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + ( - mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk - ) - layer_list.extend(list(range(offset, offset + num_layer_this_model))) - else: - num_layer_this_model = num_layer_per_pp - offset = pp_rank * num_layer_per_pp - layer_list.extend(list(range(offset, offset + num_layer_this_model))) - - for layer in layer_list: - print(f"{torch.distributed.get_rank()} loading layer #{layer}...") - layer_name = f"model.layers.{layer}" - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - print( - f"{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, " - f"layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}" - ) - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[dst_layer_idx] - - _fetch_tensor( - sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - _fetch_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - - _fetch_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - bias=True, - ) - - _fetch_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - - _fetch_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _fetch_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _fetch_tp_shard_tensor( - sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _fetch_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - ) - - if tie_word_embeddings: - print_rank_0("tie_word_embeddings skip load lm_head") - else: - print_rank_0("loading lm_head...") - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.lm_head.weight - - if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _fetch_tensor(lm_head_weight, "lm_head.weight") - print_rank_0("load lm_head from value_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _fetch_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - else: - _fetch_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - - else: - _fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight") - - dist.barrier() - get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py deleted file mode 100644 index 770e36533..000000000 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_loader_depracated.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_id, get_torch_device - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def load_state_dict_to_megatron_qwen2( - state_dict, wrapped_models, config, params_dtype, is_value_model=False, tie_word_embeddings=False -): - """Load merged state_dict to sharded Megatron module in training.""" - from megatron.core import DistributedDataParallel as LocalDDP - from megatron.core import mpu - from megatron.core.transformer.module import Float16Module - from torch.nn.parallel import DistributedDataParallel as torchDDP - - from verl.utils.logger import print_rank_0 - from verl.utils.megatron_utils import unwrap_model - - start_time = time.time() - - def _get_gpt_model(model): - return model - - def broadcast_params(module): - for param in module.parameters(): - torch.distributed.broadcast( - param.data, src=mpu.get_data_parallel_src_rank(), group=mpu.get_data_parallel_group() - ) - - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if torch.distributed.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, ( - f"num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: " - f"{virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}" - ) - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - gpt_model_module = _get_gpt_model(models[i]) - assert len(gpt_model_module.model.layers) == num_layers_per_model - - def _broadcast_tensor(tensor, name) -> torch.Tensor: - """broadcast tensor from rank0 across mp_group""" - nonlocal state_dict - nonlocal mp_group - if torch.distributed.get_rank() == 0: - if name in state_dict: - weight = state_dict[name] - tensor_shape = weight.shape - else: - tensor_shape = None - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not in state_dict, skip load") - return - - if tensor is None: - tensor = torch.empty( - tensor_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - if torch.distributed.get_rank() == 0: - tensor.data.copy_(weight) - dist.broadcast(tensor, src=0, group=mp_group) - - def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - if name in state_dict: - full_weight = state_dict[name] - if mutate_func is not None: - full_weight = mutate_func(full_weight) - tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - gate_weight = state_dict[gate_name] - up_weight = state_dict[up_name] - new_gate_up_weight = torch.empty( - config.intermediate_size * 2, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - for i in range(tp_size): - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_tp = gate_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - up_weight_tp = up_weight[i * intermediate_size_tp : (i + 1) * intermediate_size_tp] - new_gate_up_weight[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)].copy_( - torch.cat([gate_weight_tp, up_weight_tp], dim=0) - ) - - tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape " - f"{tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - - if torch.distributed.get_rank() == 0: - assert q_name in state_dict and k_name in state_dict and v_name in state_dict - full_weight_q = state_dict[q_name] - full_weight_k = state_dict[k_name] - full_weight_v = state_dict[v_name] - - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - if not bias: - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - k_part = full_weight_k[i * kv_size_tp : (i + 1) * kv_size_tp] - v_part = full_weight_v[i * kv_size_tp : (i + 1) * kv_size_tp] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( - torch.cat([q_part, k_part, v_part], dim=0) - ) - - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - if not bias: - new_weight_qkv = torch.empty( - total_size * tp_size, config.hidden_size, dtype=params_dtype, device=get_device_id() - ) - else: - new_weight_qkv = torch.empty(total_size * tp_size, dtype=params_dtype, device=get_device_id()) - for i in range(tp_size): - q_part = full_weight_q[i * q_size_tp : (i + 1) * q_size_tp] - start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head - end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head - k_part = full_weight_k[start_idx:end_idx] - v_part = full_weight_v[start_idx:end_idx] - new_weight_qkv[i * total_size : (i + 1) * total_size].copy_( - torch.cat([q_part, k_part, v_part], dim=0) - ) - - tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0) - chunk_shape = tensor_chunk[0].shape - else: - chunk_shape = None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=0, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading") - return - - if tensor is None: - sync_tensor = torch.empty( - chunk_shape, - dtype=params_dtype, - device=get_device_id(), - requires_grad=False, - ) - else: - assert tensor.shape == chunk_shape, ( - f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}" - ) - sync_tensor = torch.empty_like(tensor, device=get_device_id(), requires_grad=False) - - for i in range(tp_size): - if torch.distributed.get_rank() == 0: - sync_tensor.data.copy_(tensor_chunk[i]) - dist.broadcast(sync_tensor, src=0, group=mp_group) - if (i == tp_rank) and (tensor is not None): - tensor.data.copy_(sync_tensor) - - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("loading embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - embed_tokens_weight = None - if pp_rank == 0: - embed_tokens_weight = gpt_model_module.model.embed_tokens.weight - _broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight") - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - - for layer in range(config.num_hidden_layers): - print_rank_0(f"loading layer #{layer}...") - layer_name = f"model.layers.{layer}" - dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[dst_layer_idx] - - _broadcast_tensor( - sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.input_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - bias=True, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.self_attn.o_proj.weight", - chunk_dim=1, - ) - - _broadcast_tensor( - sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.post_attention_layernorm.weight", - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None, - f"{layer_name}.mlp.down_proj.weight", - chunk_dim=1, - ) - # Final Layernorm - # ------------------- - print_rank_0("loading final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - ) - - if tie_word_embeddings: - print_rank_0("tie_word_embeddings skip load lm_head") - else: - print_rank_0("loading lm_head...") - lm_head_weight = None - if pp_rank + 1 == pp_size: - lm_head_weight = gpt_model_module.lm_head.weight - - if is_value_model: - if "lm_head.weight" in state_dict and state_dict["lm_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "lm_head.weight") - print_rank_0("load lm_head from value_head weight") - elif "reward_head.weight" in state_dict and state_dict["reward_head.weight"].shape[0] == 1: - _broadcast_tensor(lm_head_weight, "reward_head.weight") - print_rank_0("load lm_head from value_head weight") - else: - _broadcast_tensor(None, "lm_head.weight") - print_rank_0("fail to match lm_head in value_model") - - else: - _broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight") - - dist.barrier() - # Broadcast weights inside data parallel groups - for wrapped_model in wrapped_models: - broadcast_params(wrapped_model) - - get_torch_device().empty_cache() - print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s") diff --git a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py b/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py deleted file mode 100644 index 737f73b4c..000000000 --- a/verl/models/qwen2/megatron/checkpoint_utils/qwen2_saver.py +++ /dev/null @@ -1,448 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import torch -import torch.distributed as dist -from megatron.core import mpu -from megatron.core.distributed import DistributedDataParallel as LocalDDP -from megatron.core.transformer.module import Float16Module -from torch.nn.parallel import DistributedDataParallel as torchDDP - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.logger import print_rank_0 -from verl.utils.megatron_utils import unwrap_model - - -def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0): - """given TP,DP,PP rank to get the global rank.""" - - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - assert tp_size * dp_size * pp_size == torch.distributed.get_world_size(), ( - f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}" - ) - # We only support TP-DP-PP grouping, for correctness when resharding - return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank - - -def _megatron_calc_layer_map(config): - """Calculate the mapping of global layer_idx to local layer_idx - Returns: - layer_map (Dict: int -> tuple(int, int, int)): - mapping from the global layer index to - a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model) - """ - from megatron.core import mpu - - pp_size = mpu.get_pipeline_model_parallel_world_size() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - - layer_map = dict() - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - for pp_rank_idx in range(pp_size): - for virtual_pp_rank_idx in range(virtual_pp_size): - layer_offset = ( - virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) + pp_rank_idx * num_layers_per_model - ) - for layer_idx in range(num_layers_per_model): - layer_map[layer_offset + layer_idx] = ( - pp_rank_idx, - virtual_pp_rank_idx, - layer_idx, - ) - return layer_map - - -def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False): - """Merge sharded parameters of a Megatron module into a merged checkpoint. - - Args: - wrapped_models (list of megatron.core.distributed.DistributedDataParallel): - The local DDP wrapped megatron modules. - config (str or None): - HF config for model - dtype: model params type - is_value_model: if model is value model - tie_word_embeddings: tie_word_embeddings - Returns: - state_dict (dict): - The merged state_dict in rank 0, and an empty dictionary in other ranks. - """ - start_time = time.time() - - def _get_gpt_model(model): - return model - - dp_rank = mpu.get_data_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - pp_rank = mpu.get_pipeline_model_parallel_rank() - virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1 - mp_group = mpu.get_model_parallel_group() - - if dist.get_rank() == 0: - assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0" - assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0" - assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0" - - if not isinstance(wrapped_models, list | tuple): - wrapped_models = list(wrapped_models) - - assert len(wrapped_models) == virtual_pp_size - num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size - assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers - - models = [None] * len(wrapped_models) - - for i, wrapped_model in enumerate(wrapped_models): - models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module)) - assert len(models[i].model.layers) == num_layers_per_model, ( - "len model layers {} not equal to num_layers_per_model {}".format( - len(models[i].model.layers), num_layers_per_model - ) - ) - - state_dict = dict() - - def _get_cpu_tensor(tensor: torch.Tensor): - if tensor is None: - return None - if tensor.device == torch.device("cpu"): - return tensor.detach().clone() - return tensor.detach().cpu() - - def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor: - """broadcast tensor across mp_group""" - nonlocal state_dict - nonlocal mp_group - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - if torch.distributed.get_rank() == src_rank: - if tensor is None: - weight = None - tensor_shape = None - else: - weight = tensor - tensor_shape = weight.shape - else: - weight = None - tensor_shape = None - - obj_list = [tensor_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - tensor_shape = obj_list[0] - - if tensor_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tensor:[{name}] not exist, skip collect") - return - - if weight is None: - weight = torch.empty( - tensor_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - dist.broadcast(weight, src=src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - state_dict[name] = _get_cpu_tensor(weight) - - def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=concat_dim) - if mutate_func is not None: - full_tensor = mutate_func(full_tensor) - state_dict[name] = full_tensor - - def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor: - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - state_dict[gate_name] = torch.cat(gate_weight_list, dim=0) - state_dict[up_name] = torch.cat(up_weight_list, dim=0) - - def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank): - """broadcast tensor in tp shards across mp_group""" - nonlocal state_dict - nonlocal mp_group - tp_size = mpu.get_tensor_model_parallel_world_size() - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank) - - chunk_shape = tensor.shape if torch.distributed.get_rank() == src_rank else None - - obj_list = [chunk_shape] - dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group) - chunk_shape = obj_list[0] - if chunk_shape is None: - # all or none ranks in the mp_group should reach here - print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting") - return - - buffer_tensor = torch.empty( - chunk_shape, - dtype=dtype, - device=get_device_id(), - requires_grad=False, - ) - - chunk_tensors = [None] * tp_size - - for i in range(tp_size): - cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank) - sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor - dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group) - - if torch.distributed.get_rank() == 0: - chunk_tensors[i] = _get_cpu_tensor(sync_tensor) - - if torch.distributed.get_rank() == 0: - full_tensor = torch.concat(chunk_tensors, dim=0) - q_weight_list = [] - k_weight_list = [] - v_weight_list = [] - hidden_size_per_head = config.hidden_size // config.num_attention_heads - - if config.num_key_value_heads >= tp_size: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp : total_size] - q_weight_list.append(q_part) - k_weight_list.append(k_part) - v_weight_list.append(v_part) - else: - q_size_tp = config.hidden_size // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_part = qkv_part[:q_size_tp] - k_part = qkv_part[q_size_tp : q_size_tp + kv_size_tp] - v_part = qkv_part[q_size_tp + kv_size_tp : total_size] - q_weight_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_weight_list.append(k_part) - v_weight_list.append(v_part) - - state_dict[q_name] = torch.cat(q_weight_list, dim=0) - state_dict[k_name] = torch.cat(k_weight_list, dim=0) - state_dict[v_name] = torch.cat(v_weight_list, dim=0) - - # empty cache before collecting weights - get_torch_device().empty_cache() - # Embeddings - # ------------------- - if dp_rank == 0: - # Embeddings - # ------------------- - print_rank_0("collecting embeddings...") - gpt_model_module = _get_gpt_model(models[0]) - _broadcast_tp_shard_tensor( - gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None, - "model.embed_tokens.weight", - src_pp_rank=0, - ) - - # Transformer layers - # ------------------- - layer_map = _megatron_calc_layer_map(config) - for layer in range(config.num_hidden_layers): - print_rank_0(f"collecting layer #{layer}...") - layer_name = f"model.layers.{layer}" - src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer] - - gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank]) - sync_layer = gpt_model_module.model.layers[src_layer_idx] - - _broadcast_tensor( - sync_layer.input_layernorm.weight, - f"{layer_name}.input_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.weight, - f"{layer_name}.self_attn.q_proj.weight", - f"{layer_name}.self_attn.k_proj.weight", - f"{layer_name}.self_attn.v_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_qkv( - sync_layer.self_attn.qkv_proj.bias, - f"{layer_name}.self_attn.q_proj.bias", - f"{layer_name}.self_attn.k_proj.bias", - f"{layer_name}.self_attn.v_proj.bias", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.self_attn.o_proj.weight, - f"{layer_name}.self_attn.o_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - _broadcast_tensor( - sync_layer.post_attention_layernorm.weight, - f"{layer_name}.post_attention_layernorm.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor_gate_up( - sync_layer.mlp.gate_up_proj.weight, - f"{layer_name}.mlp.gate_proj.weight", - f"{layer_name}.mlp.up_proj.weight", - src_pp_rank=src_pp_rank, - ) - - _broadcast_tp_shard_tensor( - sync_layer.mlp.down_proj.weight, - f"{layer_name}.mlp.down_proj.weight", - concat_dim=1, - src_pp_rank=src_pp_rank, - ) - - # Final Layernorm - # ------------------- - print_rank_0("collecting final layernorm...") - gpt_model_module = _get_gpt_model(models[-1]) - _broadcast_tensor( - getattr(gpt_model_module.model.norm, "weight", None), - "model.norm.weight", - src_pp_rank=pp_size - 1, - ) - - if tie_word_embeddings: - print_rank_0("tie word embedding skip load lm_head...") - else: - print_rank_0("collecting lm_head...") - - if is_value_model: - _broadcast_tensor( - gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - _broadcast_tensor( - gpt_model_module.reward_head.weight - if pp_rank == pp_size - 1 and getattr(gpt_model_module, "reward_weight", None) is not None - else None, - "reward_head.weight", - src_pp_rank=pp_size - 1, - ) - - else: - _broadcast_tp_shard_tensor( - getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None, - "lm_head.weight", - src_pp_rank=pp_size - 1, - ) - - dist.barrier() - - get_torch_device().empty_cache() - if torch.distributed.get_rank() == 0: - for k, v in state_dict.items(): - if dtype != v.dtype: - state_dict[k] = v.to(dtype) - - print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s") - return state_dict diff --git a/verl/models/qwen2/megatron/layers/__init__.py b/verl/models/qwen2/megatron/layers/__init__.py deleted file mode 100644 index 263ea596f..000000000 --- a/verl/models/qwen2/megatron/layers/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .parallel_attention import ParallelQwen2Attention -from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad -from .parallel_mlp import ParallelQwen2MLP -from .parallel_rmsnorm import ParallelQwen2RMSNorm - -__all__ = [ - "ParallelQwen2Attention", - "ParallelQwen2DecoderLayer", - "ParallelQwen2DecoderLayerRmPad", - "ParallelQwen2MLP", - "ParallelQwen2RMSNorm", -] diff --git a/verl/models/qwen2/megatron/layers/parallel_attention.py b/verl/models/qwen2/megatron/layers/parallel_attention.py deleted file mode 100644 index 702c429c2..000000000 --- a/verl/models/qwen2/megatron/layers/parallel_attention.py +++ /dev/null @@ -1,399 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Optional - -import torch.nn.functional as F -from einops import rearrange -from transformers.utils import is_flash_attn_2_available - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -import torch -from flash_attn.layers.rotary import apply_rotary_emb -from megatron.core import ModelParallelConfig, tensor_parallel -from megatron.core import parallel_state as mpu -from torch import nn -from transformers import Qwen2Config - -from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear -from verl.utils.megatron import tensor_parallel as tp_utils - - -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding): - """Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): - """Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class ParallelQwen2Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config = config - self.megatron_config = megatron_config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - - # assign values after tp - tp_size = mpu.get_tensor_model_parallel_world_size() - assert self.num_heads % tp_size == 0, ( - f"num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}" - ) - assert self.num_key_value_heads % tp_size == 0, ( - f"num_key_value_heads must be divisible by tp_size. Got num_key_value_heads=" - f"{self.num_key_value_heads}, tp_size={tp_size}" - ) - - self.num_heads_per_tp = self.num_heads // tp_size - self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size - self.hidden_size_per_tp = self.hidden_size // tp_size - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and " - f"`num_heads`: {self.num_heads})." - ) - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() - - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - assert row_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) - tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) - - # [self.q_size, self.k_size, self.v_size] - self.qkv_proj = QKVParallelLinear( - input_size=self.hidden_size, - num_heads=self.num_heads, - num_key_value_heads=self.num_key_value_heads, - head_dim=self.head_dim, - # bias=config.attention_bias, - bias=True, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - self.q_size = self.num_heads_per_tp * self.head_dim - self.k_size = self.num_key_value_heads_per_tp * self.head_dim - self.v_size = self.num_key_value_heads_per_tp * self.head_dim - - self.o_proj = tensor_parallel.RowParallelLinear( - input_size=self.num_heads * self.head_dim, - output_size=self.hidden_size, - # bias=config.attention_bias, - bias=False, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs, - ) - - self._init_rope() - - def _init_rope(self): - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) - - query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, " - f"but is {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, " - f"but is {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp) - attn_output = self.o_proj(attn_output)[0] - return attn_output - - -""" -Remove padding Attention -- Using Flash-attn 2 -- Compatible with sequence parallel -""" - - -def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length): - batch_size = position_ids.shape[0] - - q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim) - k = pad_input(k, indices, batch_size, sequence_length) - cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - - q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices) - k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices) - - return q_embed, k_embed - - -# use flash-attn rotary embeddings with rmpad -# cos/sin shoudl be: (seq_length, rotary_dim / 2) -def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen): - q_embed = apply_rotary_emb( - q, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - k_embed = apply_rotary_emb( - k, cos, sin, interleaved=False, inplace=False, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - return q_embed, k_embed - - -class ParallelQwen2AttentionRmPad(ParallelQwen2Attention): - def forward( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: torch.Tensor = None, - max_seqlen_in_batch: int = None, - ): - total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel - - if self.megatron_config.sequence_parallel: - total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size() - - qkv = self.qkv_proj(hidden_states)[0] - query_states, key_states, value_states = qkv.split( - [self.q_size, self.k_size, self.v_size], dim=-1 - ) # (total_nnz, 1, hidden_size) - - if self.megatron_config.sequence_parallel: - sequence_parallel_pad = total_nnz - cu_seqlens[-1] - total_nnz = cu_seqlens[-1] # total_nnz before sp padding - query_states = query_states[:total_nnz] - key_states = key_states[:total_nnz] - value_states = value_states[:total_nnz] - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dime x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim) - key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim) - - cos, sin = self.rotary_emb(value_states, seq_len=sequence_length) - cos, sin = cos[:, : cos.shape[1] // 2], sin[:, : sin.shape[1] // 2] # flash attn only needs half - query_states, key_states = apply_rotary_pos_emb_rmpad_flash( - query_states, key_states, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen_in_batch - ) - # query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, - # position_ids, indices, - - # It is recommended to use dropout with FA according to the docs - # when training. - dropout_rate = 0.0 # if not self.training else self.attn_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Qwen2RMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(torch.float16) - key_states = key_states.to(torch.float16) - value_states = value_states.to(torch.float16) - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_in_batch, - max_seqlen_k=max_seqlen_in_batch, - dropout_p=dropout_rate, - softmax_scale=None, - causal=True, - ) - - attn_output_unpad = attn_output_unpad.to(input_dtype) - attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous() - - # sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled - # Here we need to repad - if self.megatron_config.sequence_parallel: - attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad)) - - attn_output_unpad = self.o_proj(attn_output_unpad)[0] - return attn_output_unpad diff --git a/verl/models/qwen2/megatron/layers/parallel_decoder.py b/verl/models/qwen2/megatron/layers/parallel_decoder.py deleted file mode 100644 index 3c8a2a6ee..000000000 --- a/verl/models/qwen2/megatron/layers/parallel_decoder.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -from megatron.core import ModelParallelConfig -from torch import nn -from transformers import Qwen2Config - -from verl.utils.megatron_utils import TransformerConfig, convert_config - -from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad -from .parallel_mlp import ParallelQwen2MLP -from .parallel_rmsnorm import ParallelQwen2RMSNorm - - -class ParallelQwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config) - - self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) - self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) - self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Note: sequence parallel is hidden inside ColumnParallelLinear - # reduce scatter is hidden inside RowParallelLinear - - # Self Attention - hidden_states = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - # TODO: add sequence parallel operator reduce_scatter here - - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - # TODO: add sequence parallel operator all_gather here - - hidden_states = self.mlp(hidden_states) - - # TODO: add sequence parallel operator reduce_scatter here - - hidden_states = residual + hidden_states - - outputs = hidden_states - - return outputs - - -class ParallelQwen2DecoderLayerRmPad(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.hidden_size = config.hidden_size - self.layer_idx = layer_idx - self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config) - - self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config) - self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config) - self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states # (total_nnz // sp, 1, hidden_size) - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - # (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size) - # -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size) - hidden_states = self.self_attn( - hidden_states=hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = residual + hidden_states - - # Fully Connected - # shape changes same as attn - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = hidden_states - - return outputs diff --git a/verl/models/qwen2/megatron/layers/parallel_linear.py b/verl/models/qwen2/megatron/layers/parallel_linear.py deleted file mode 100644 index e6d4a09f4..000000000 --- a/verl/models/qwen2/megatron/layers/parallel_linear.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The vLLM team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py - - -from megatron.core import tensor_parallel - - -class QKVParallelLinear(tensor_parallel.ColumnParallelLinear): - def __init__( - self, - input_size, - num_heads, - num_key_value_heads, - head_dim, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs, - ): - # Keep input parameters, and already restrict the head numbers - self.input_size = input_size - self.q_output_size = num_heads * head_dim - self.kv_output_size = num_key_value_heads * head_dim - self.head_dim = head_dim - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - input_size = self.input_size - output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim - - super().__init__( - input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs, - ) - - -class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear): - def __init__( - self, - input_size, - gate_ouput_size, - up_output_size, - *, - bias=True, - gather_output=True, - skip_bias_add=False, - **kwargs, - ): - # Keep input parameters, and already restrict the head numbers - self.input_size = input_size - self.output_size = gate_ouput_size + up_output_size - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - super().__init__( - input_size=self.input_size, - output_size=self.output_size, - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - **kwargs, - ) diff --git a/verl/models/qwen2/megatron/layers/parallel_mlp.py b/verl/models/qwen2/megatron/layers/parallel_mlp.py deleted file mode 100644 index 672908a21..000000000 --- a/verl/models/qwen2/megatron/layers/parallel_mlp.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from megatron.core import ModelParallelConfig, tensor_parallel -from megatron.core import parallel_state as mpu -from torch import nn -from transformers.activations import ACT2FN - -from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear -from verl.utils.megatron import tensor_parallel as tp_utils - - -class ParallelQwen2MLP(nn.Module): - def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None: - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - # The weight is only [hidden_size, intermediate_size // model_parallel_world_size] - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear() - - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - assert row_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(row_kwargs, megatron_config) - tp_utils.update_kwargs_with_config(column_kwargs, megatron_config) - - tp_size = mpu.get_tensor_model_parallel_world_size() - - self.gate_up_proj = MergedColumnParallelLinear( - input_size=self.hidden_size, - gate_ouput_size=self.intermediate_size, - up_output_size=self.intermediate_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - self.gate_size = self.intermediate_size // tp_size - - self.down_proj = tensor_parallel.RowParallelLinear( - input_size=self.intermediate_size, - output_size=self.hidden_size, - bias=False, - input_is_parallel=True, - skip_bias_add=False, - **row_kwargs, - ) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - gate_up = self.gate_up_proj(x)[0] - gate, up = gate_up.split(self.gate_size, dim=-1) - return self.down_proj(self.act_fn(gate) * up)[0] diff --git a/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py b/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py deleted file mode 100644 index 2f4c90dd4..000000000 --- a/verl/models/qwen2/megatron/layers/parallel_rmsnorm.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numbers - -import torch -from apex.normalization.fused_layer_norm import fused_rms_norm_affine -from megatron.core import ModelParallelConfig -from torch import nn -from transformers import Qwen2Config - -from verl.utils.megatron import sequence_parallel as sp_utils - - -class ParallelQwen2RMSNorm(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - if isinstance(config.hidden_size, numbers.Integral): - normalized_shape = (config.hidden_size,) - self.normalized_shape = torch.Size(normalized_shape) - self.weight = nn.Parameter(torch.ones(self.normalized_shape)) - self.variance_epsilon = config.rms_norm_eps - - if megatron_config.sequence_parallel: - sp_utils.mark_parameter_as_sequence_parallel(self.weight) - - def forward(self, hidden_states): - return fused_rms_norm_affine( - input=hidden_states, - weight=self.weight, - normalized_shape=self.normalized_shape, - eps=self.variance_epsilon, - memory_efficient=True, - ) diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py deleted file mode 100644 index 92e81be8d..000000000 --- a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ /dev/null @@ -1,737 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Qwen2 model.""" - -from typing import Optional - -import torch -import torch.utils.checkpoint -from megatron.core import ModelParallelConfig, mpu, parallel_state, tensor_parallel -from torch import nn -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.qwen2.configuration_qwen2 import Qwen2Config -from transformers.models.qwen2.modeling_qwen2 import CausalLMOutputWithPast - -from verl.utils.device import get_device_name -from verl.utils.megatron import sequence_parallel as sp_utils -from verl.utils.megatron import tensor_parallel as tp_utils -from verl.utils.megatron_utils import TransformerConfig, convert_config - -from .layers import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad, ParallelQwen2RMSNorm - -""" -TODO: -1. Add weight initialization. Here we need to be careful on TP weight init. -2. Add sequence parallel -3. Load checkpoint from Qwen2 pretrained checkpoint -""" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class ParallelQwen2Model(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - - self.layers = nn.ModuleList( - [ParallelQwen2DecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)] - ) - self.norm = ParallelQwen2RMSNorm(config, megatron_config) - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (batch_size, seq_length) - attention_mask: attention_mask. shape (batch_size, seq_length) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - batch_size, seq_length = input_ids.shape - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds) - - hidden_states = inputs_embeds - - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelQwen2ForCausalLM(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.model = ParallelQwen2Model(config, megatron_config=megatron_config) - self.vocab_size = config.vocab_size - - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - ) - - hidden_states = outputs - logits = self.lm_head(hidden_states)[0] - - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - logits = logits.float() - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - -class ParallelQwen2ModelRmPad(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - self.megatron_config = megatron_config - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - - self.layers = nn.ModuleList( - [ParallelQwen2DecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)] - ) - self.norm = ParallelQwen2RMSNorm(config, megatron_config) - - def forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (1, totol_nnz) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) - - # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) - inputs_embeds = inputs_embeds.transpose(0, 1) - if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - - hidden_states = inputs_embeds - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = layer_outputs - - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelQwen2ForCausalLMRmPad(nn.Module): - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.megatron_config = megatron_config - self.model = ParallelQwen2ModelRmPad(config, megatron_config=megatron_config) - self.vocab_size = config.vocab_size - self._init_head(config) - - def _init_head(self, config: Qwen2Config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - **column_kwargs, - ) - - def _forward_head(self, hidden_states): - # all_gather from sequence parallel region is performed inside lm_head - logits = self.lm_head(hidden_states)[0] - logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size) - return logits - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - batch_size, sequence_length = input_ids.shape - - # remove padding here - input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( - input_ids.unsqueeze(dim=-1), attention_mask - ) # (total_nnz, 1) - - # pad input_ids to multiple of tp for all tp ranks - # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap - if self.megatron_config.sequence_parallel: - input_ids = sp_utils.pad_to_sequence_parallel(input_ids) - - input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad) - - outputs = self.model( - input_ids=input_ids, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = outputs - - logits = self._forward_head(hidden_states) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension - # add removed padding back - logits = pad_input( - logits, indices, batch_size, seqlen=sequence_length - ) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - - -class ParallelQwen2ForValueRmPad(ParallelQwen2ForCausalLMRmPad): - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) - # lm_head is effectively the same as sequence parallel - sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) - - def _forward_head(self, hidden_states): - logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) - logits = logits.float() - if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids, attention_mask, position_ids) - output.logits = torch.squeeze(output.logits, dim=-1) - return output - - -""" -Support pipeline parallelism -""" - - -class ParallelQwen2ModelRmPadPP(nn.Module): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - This model definition supports pipeline parallelism. To support pp and vpp, - - This model only contains layer in this pp stage and vpp chunk - - When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp. - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, pre_process, post_process): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.pre_process = pre_process - self.post_process = post_process - self.megatron_config = megatron_config - embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding() - if megatron_config is not None: - assert embedding_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config) - if pre_process: - self.embed_tokens = tensor_parallel.VocabParallelEmbedding( - num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, **embedding_kwargs - ) - else: - self.embed_tokens = None - - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = megatron_config.pipeline_model_parallel_size - self.num_layer_per_pp = config.num_hidden_layers // pp_size - vpp_size = megatron_config.virtual_pipeline_model_parallel_size - vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank() - - if vpp_size is not None: - self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size - self.num_layer_this_model = self.num_layer_vpp_chunk - offset = vpp_rank * (config.num_hidden_layers // vpp_size) + (pp_rank * self.num_layer_vpp_chunk) - else: - self.num_layer_this_model = self.num_layer_per_pp - offset = pp_rank * self.num_layer_per_pp - - self.layers = nn.ModuleList() - for i in range(self.num_layer_this_model): - layer = ParallelQwen2DecoderLayerRmPad(config, megatron_config, layer_idx=i + offset) - self.layers.add_module(f"{i}", layer) - - if post_process: - self.norm = ParallelQwen2RMSNorm(config, megatron_config) - else: - self.norm = None - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward( - self, - input_ids: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - sequence_length: int = None, - indices: torch.Tensor = None, - cu_seqlens: int = None, - max_seqlen_in_batch: int = None, - ) -> tuple | BaseModelOutputWithPast: - """ - - Args: - input_ids: input ids. shape (1, totol_nnz) - position_ids: position ids. shape (batch_size, seq_length) - - Returns: - - """ - if self.pre_process: - inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size) - - # vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron - # so need to deal with it by handle here: - # (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size) - inputs_embeds = inputs_embeds.transpose(0, 1) - if self.megatron_config.sequence_parallel: - inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds) - - hidden_states = inputs_embeds - else: - # self.hidden_states should be passed by Megatron - hidden_states = self.input_tensor - - for idx, decoder_layer in enumerate(self.layers): - layer_outputs = decoder_layer( - hidden_states, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - hidden_states = layer_outputs - - if self.post_process: - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class ParallelQwen2ForCausalLMRmPadPP(nn.Module): - def __init__( - self, - config: Qwen2Config, - megatron_config: ModelParallelConfig, - pre_process, - post_process, - share_embeddings_and_output_weights, - ): - super().__init__() - self.config: TransformerConfig = convert_config(config, megatron_config) - self.megatron_config = megatron_config - self.model = ParallelQwen2ModelRmPadPP( - config, megatron_config=megatron_config, pre_process=pre_process, post_process=post_process - ) - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.vocab_size = config.vocab_size - self.pre_process = pre_process - self.post_process = post_process - if post_process: - self._init_head(config) - if pre_process or post_process: - self.setup_embeddings_and_output_layer() - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - assert len(input_tensor) == 1 - self.model.set_input_tensor(input_tensor[0]) - - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = tensor_parallel.ColumnParallelLinear( - input_size=config.hidden_size, - output_size=config.vocab_size, - bias=False, - gather_output=False, - skip_bias_add=False, - skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, - **column_kwargs, - ) - - def setup_embeddings_and_output_layer(self) -> None: - """Sets up embedding layer in first stage and output layer in last stage. - - This function initalizes word embeddings in the final stage when we are - using pipeline parallelism and sharing word embeddings, and sets up param - attributes on the embedding and output layers. - """ - # Set `is_embedding_or_output_parameter` attribute. - if self.pre_process: - self.model.embed_tokens.weight.is_embedding_or_output_parameter = True - if self.post_process and self.lm_head.weight is not None: - self.lm_head.weight.is_embedding_or_output_parameter = True - - if not self.share_embeddings_and_output_weights: - return - - if parallel_state.get_pipeline_model_parallel_world_size() == 1: - # Zero out wgrad if sharing embeddings between two layers on same - # pipeline stage to make sure grad accumulation into main_grad is - # correct and does not include garbage values (e.g., from torch.empty). - self.shared_embedding_or_output_weight().zero_out_wgrad = True - return - - if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: - self.shared_embedding_or_output_weight().shared_embedding = True - - if self.post_process and not self.pre_process: - assert not parallel_state.is_pipeline_first_stage() - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. - self.lm_head.weight.data.fill_(0) - self.lm_head.weight.shared = True - self.lm_head.weight.shared_embedding = True - - if torch.distributed.is_initialized() and parallel_state.is_rank_in_embedding_group(): - weight = self.shared_embedding_or_output_weight() - weight.data = weight.data.to(get_device_name()) - torch.distributed.all_reduce(weight.data, group=parallel_state.get_embedding_group()) - - def shared_embedding_or_output_weight(self) -> torch.Tensor: - if self.pre_process: - return self.model.embed_tokens.weight - elif self.post_process: - return self.lm_head.weight - return None - - def _forward_head(self, hidden_states): - # all_gather from sequence parallel region is performed inside lm_head - # print(f'logits shape before forward_head: {hidden_states.shape}, vocab_size = ' - # f'{self.config.vocab_size}') # [4, 32, 4096] - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - logits = self.lm_head(hidden_states, weight=output_weight)[0] - # print(f'logits shape after forward_head: {logits.shape}') # [8, 32, 8] - logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp) - return logits - - def forward( - self, - # original input - *, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - ```""" - - # Note that input_ids, attention_mask and position_ids should be passed to every pp layer. - # In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model - batch_size, sequence_length = input_ids.shape - # remove padding here - input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input( - input_ids.unsqueeze(dim=-1), attention_mask - ) # (total_nnz, 1) - - # pad input_ids to multiple of tp for all tp ranks - # TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap - if self.megatron_config.sequence_parallel: - input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad) - - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad) - - outputs = self.model( - input_ids=input_ids_rmpad, - position_ids=position_ids, - sequence_length=sequence_length, - indices=indices, - cu_seqlens=cu_seqlens, - max_seqlen_in_batch=max_seqlen_in_batch, - ) - - if self.post_process: - hidden_states = outputs - logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input( - logits, indices, batch_size, seqlen=sequence_length - ) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( - loss=None, - logits=logits, - past_key_values=None, - hidden_states=None, - attentions=None, - ) - else: - return outputs - - -class ParallelQwen2ForValueRmPadPP(ParallelQwen2ForCausalLMRmPadPP): - def _init_head(self, config): - column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear() - if self.megatron_config is not None: - assert column_kwargs.get("config", False), "must have ModelParallelConfig" - tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config) - self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False) - # lm_head is effectively the same as sequence parallel - sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight) - - def _forward_head(self, hidden_states): - logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1) - logits = logits.float() - if self.megatron_config.sequence_parallel: - logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False) - return logits - - def forward( - self, - *, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - ) -> tuple | CausalLMOutputWithPast: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) - if self.post_process: - output.logits = torch.squeeze(output.logits, dim=-1) - return output - else: - return output diff --git a/verl/models/registry.py b/verl/models/registry.py deleted file mode 100644 index 829b9e20c..000000000 --- a/verl/models/registry.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import importlib -from typing import Optional - -import torch.nn as nn - -# Supported models in Megatron-LM -# Architecture -> (module, class). -_MODELS = { - "LlamaForCausalLM": ( - "llama", - ("ParallelLlamaForCausalLMRmPadPP", "ParallelLlamaForValueRmPadPP", "ParallelLlamaForCausalLMRmPad"), - ), - "Qwen2ForCausalLM": ( - "qwen2", - ("ParallelQwen2ForCausalLMRmPadPP", "ParallelQwen2ForValueRmPadPP", "ParallelQwen2ForCausalLMRmPad"), - ), - "MistralForCausalLM": ( - "mistral", - ("ParallelMistralForCausalLMRmPadPP", "ParallelMistralForValueRmPadPP", "ParallelMistralForCausalLMRmPad"), - ), -} - - -# return model class -class ModelRegistry: - @staticmethod - def load_model_cls(model_arch: str, value=False) -> Optional[type[nn.Module]]: - if model_arch not in _MODELS: - return None - - megatron = "megatron" - - module_name, model_cls_name = _MODELS[model_arch] - if not value: # actor/ref - model_cls_name = model_cls_name[0] - elif value: # critic/rm - model_cls_name = model_cls_name[1] - - module = importlib.import_module(f"verl.models.{module_name}.{megatron}.modeling_{module_name}_megatron") - return getattr(module, model_cls_name, None) - - @staticmethod - def get_supported_archs() -> list[str]: - return list(_MODELS.keys()) diff --git a/verl/models/transformers/__init__.py b/verl/models/transformers/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/models/transformers/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/models/transformers/dense_common.py b/verl/models/transformers/dense_common.py deleted file mode 100644 index 56fe293f5..000000000 --- a/verl/models/transformers/dense_common.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional, Union - -import torch -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast - - -@dataclass -class CausalLMOutputForPPO(CausalLMOutputWithPast): - log_probs: Optional[torch.FloatTensor] = None - entropy: Optional[torch.FloatTensor] = None - - -def forward_base_model( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, -) -> CausalLMOutputWithPast: - r""" - Copy paste LLaMa's forward - https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py - - This function should be generic enough for all pure text models. - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - return outputs - - -def forward_with_torch_backend( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: int | torch.Tensor = 0, - temperature: float = 1.0, - **loss_kwargs, -) -> tuple | CausalLMOutputForPPO: - from verl.utils.experimental.torch_functional import FusedLinearForPPO - - outputs = forward_base_model( - self, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - if not return_dict: - raise NotImplementedError("forward_with_torch_backend has to return_dict") - - # Loss calculations - if labels is not None: - rolled_labels = torch.roll(labels, shifts=-1, dims=-1) - elif input_ids is not None: - rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) - else: - raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") - - fused_linear_for_ppo = FusedLinearForPPO() - log_probs, entropy = fused_linear_for_ppo.forward( - hidden_states=hidden_states, - vocab_weights=self.lm_head.weight, - input_ids=rolled_labels, - temperature=temperature, - ) - - return CausalLMOutputForPPO( - log_probs=log_probs, - entropy=entropy, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def forward_with_triton_backend( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: int | torch.Tensor = 0, - temperature: float = 1.0, - **loss_kwargs, -) -> tuple | CausalLMOutputForPPO: - from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy - - outputs = forward_base_model( - self, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - if not return_dict: - raise NotImplementedError("forward_with_triton_backend has to return_dict") - - # Loss calculations - if labels is not None: - rolled_labels = torch.roll(labels, shifts=-1, dims=-1) - elif input_ids is not None: - rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) - else: - raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") - - log_probs, entropy = linear_cross_entropy( - hidden_states, - self.lm_head.weight, - rolled_labels, - temperature, - "none", - ) - - return CausalLMOutputForPPO( - log_probs=log_probs, - entropy=entropy, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/verl/models/transformers/kimi_vl.py b/verl/models/transformers/kimi_vl.py deleted file mode 100644 index edd79364b..000000000 --- a/verl/models/transformers/kimi_vl.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -import torch.nn.functional as F -from transformers.cache_utils import Cache -from transformers.modeling_flash_attention_utils import _flash_attention_forward - -from verl.utils.ulysses import ( - gather_heads_scatter_seq, - gather_seq_scatter_heads, - get_ulysses_sequence_parallel_world_size, - validate_ulysses_config, -) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def _ulysses_flash_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) - - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - # patch - ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - if ulysses_sp_size > 1: - validate_ulysses_config(self.num_heads, ulysses_sp_size) - - num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads - k_pe = repeat_kv(k_pe, ulysses_sp_size) # to keep heads=1 after a2a - k_nope = repeat_kv(k_nope, num_key_value_groups) - value_states = repeat_kv(value_states, num_key_value_groups) - q = gather_seq_scatter_heads(q, seq_dim=2, head_dim=1) - k_pe = gather_seq_scatter_heads(k_pe, seq_dim=2, head_dim=1) - k_nope = gather_seq_scatter_heads(k_nope, seq_dim=2, head_dim=1) - value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) - # (batch_size, num_head / sp_size, seq_length, head_size) - full_q_len = q.size(2) # full_q_len = seq_length - - else: - full_q_len = q_len - - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - cos, sin = self.rotary_emb(value_states, seq_len=full_q_len) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads // ulysses_sp_size, full_q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - - if self.q_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - full_q_len, - dropout=dropout_rate, - sliding_window=None, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - position_ids=position_ids, # important: pass position ids - softmax_scale=self.softmax_scale, - ) - - if ulysses_sp_size > 1: - attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) - - if self.q_head_dim != self.v_head_dim: - attn_output = attn_output[:, :, :, : self.v_head_dim] - - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim).contiguous() - attn_output = self.o_proj(attn_output) - - return attn_output, None, None diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py deleted file mode 100644 index 687ceab71..000000000 --- a/verl/models/transformers/llama.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -from typing import Callable, Optional - -import torch - -if sys.version_info >= (3, 11): - pass -else: - pass - -from transformers.cache_utils import Cache -from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb -from transformers.utils import logging - -from verl.utils.ulysses import ( - gather_heads_scatter_seq, - gather_seq_scatter_heads, - get_ulysses_sequence_parallel_world_size, - validate_ulysses_config, -) - -logger = logging.get_logger(__name__) - - -def llama_flash_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """ - Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. - - NOTE: This function is used for transformers versions in the range [4.45.0, 4.47.1]. - """ - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - # trade off: repeat first and then all to all - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - - ########## AlltoAll for Ulysses ########## - ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - - if ulysses_sp_size > 1: - validate_ulysses_config(self.num_heads, ulysses_sp_size) - - # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) - query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) - key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) - value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) - - full_q_len = query_states.size(2) # full seq length - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to " - f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " - f"input in {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - full_q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() - ########## AlltoAll for Ulysses ########## - if ulysses_sp_size > 1: - attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def llama_attn_forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """ - Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. - - NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. - """ - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - from transformers.models.llama.modeling_llama import eager_attention_forward - - bsz, q_len, _ = hidden_states.shape - - query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - ########## AlltoAll for Ulysses ########## - ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - - if ulysses_sp_size > 1: - validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) - - query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) - key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) - value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) - - full_q_len = query_states.size(2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() - ########## AlltoAll for Ulysses ########## - if ulysses_sp_size > 1: - attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py deleted file mode 100644 index d6be65a77..000000000 --- a/verl/models/transformers/monkey_patch.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Apply monkey-patch function to models -""" - -import importlib.metadata -import sys -from functools import lru_cache -from typing import Optional - -import torch -from packaging import version -from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.modeling_utils import PreTrainedModel - -from verl.utils.import_utils import is_trl_available -from verl.utils.ulysses import ( - gather_heads_scatter_seq, - gather_seq_scatter_heads, - get_ulysses_sequence_parallel_group, - get_ulysses_sequence_parallel_world_size, - slice_input_tensor, -) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch, - seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim) - """ - batch, slen, num_key_value_heads, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim) - return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim) - - -def _ulysses_flash_attention_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - *args, - position_ids: Optional[torch.Tensor] = None, - **kwargs, -): - """Insert all-to-all before and after flash attention. - DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509 - - Args: - query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim) - key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim) - value_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim) - position_ids (torch.Tensor, optional): (batch_size, seqlen/sp_size) - - Returns: - torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim) - """ - ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - - ########## AlltoAll for Ulysses ########## - if ulysses_sp_size > 1: - assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism" - - # NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k, - # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA. - # For example: - # - nheads_k=4, sp=8, repeats=2 - # - nheads_k=8, sp=8, repeats=1 - # - nheads_k=16, sp=8, repeats=1 - repeats = max(ulysses_sp_size // key_states.size(2), 1) - key_states = repeat_kv(key_states, repeats) - value_states = repeat_kv(value_states, repeats) - - # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim) - query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2) - key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2) - value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2) - - # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate - # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly. - # https://github.com/huggingface/transformers/pull/33932 - - # (bsz, seq_len/n) -> (bsz, seq_len) - position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)] - torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group()) - position_ids = torch.concat(position_ids_list, dim=-1) - - # (bsz, seq_len, n_head/n, head_dim) - attn_output = _flash_attention_forward( - query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs - ) - - ########## AlltoAll for Ulysses ########## - if ulysses_sp_size > 1: - # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim) - attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) - - return attn_output - - -def patch_vlm_for_ulysses_input_slicing(model_class: type): - """ - Applies a monkey patch to the forward method of a given model class - to enable Ulysses sequence parallelism input slicing. - """ - - def _create_ulysses_wrapped_decoder_forward(original_forward): - def ulysses_wrapped_decoder_forward(self, *args, **kwargs): - inputs_embeds = kwargs.get("inputs_embeds") - call_kwargs = kwargs.copy() - - current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - - slice_now = ( - inputs_embeds is not None - and current_ulysses_sp_size > 1 - and getattr(self, "_needs_initial_slice", True) - ) - if slice_now: - call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False) - self._needs_initial_slice = False - try: - return original_forward(self, *args, **call_kwargs) - finally: - if slice_now: - self._needs_initial_slice = True - - return ulysses_wrapped_decoder_forward - - original_forward = model_class.forward - wrapped_forward = _create_ulysses_wrapped_decoder_forward(original_forward) - model_class.forward = wrapped_forward - print(f"Monkey patch {model_class.__name__}.forward for Ulysses SP input slicing.") - - -def patch_forward_with_backends( - model: PreTrainedModel, - use_fused_kernels: bool = False, - fused_kernels_backend: str = None, -): - """ - Choose the forward function based on the model and backend. - Args: - model (PreTrainedModel): The model to apply the monkey patch. - use_fused_kernels (bool): Whether to use fused kernels. - fused_kernels_backend (str): The backend to use for fused kernels. - """ - if not use_fused_kernels or fused_kernels_backend not in ["triton", "torch"]: - print( - f"Skipping monkey patch for {model.__class__.__name__} as use_fused_kernels is " - f"{use_fused_kernels} or fused_kernels_backend is {fused_kernels_backend}" - ) - return - - forward_with_torch_backend_function = model.__class__.forward - forward_with_triton_backend_function = model.__class__.forward - if model.config.model_type == "qwen2_5_vl": - from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend - - forward_with_torch_backend_function = forward_with_torch_backend - forward_with_triton_backend_function = forward_with_triton_backend - elif model.config.model_type == "qwen2_vl": - from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend - - forward_with_torch_backend_function = forward_with_torch_backend - forward_with_triton_backend_function = forward_with_triton_backend - else: - from verl.models.transformers.dense_common import forward_with_torch_backend, forward_with_triton_backend - - forward_with_torch_backend_function = forward_with_torch_backend - forward_with_triton_backend_function = forward_with_triton_backend - - if fused_kernels_backend == "triton": - model.__class__.forward = forward_with_triton_backend_function - print(f"Using Triton backend for fused kernels in {model.__class__.__name__}") - elif fused_kernels_backend == "torch": - model.__class__.forward = forward_with_torch_backend_function - print(f"Using Torch backend for fused kernels in {model.__class__.__name__}") - else: - raise ValueError(f"Unsupported fused_kernels_backend: {fused_kernels_backend}. Choose 'triton' or 'torch'.") - - -def apply_monkey_patch( - model: PreTrainedModel, - ulysses_sp_size: int = 1, - use_remove_padding: bool = True, - use_fused_kernels: bool = False, - fused_kernels_backend: str = None, -): - """ - Apply monkey patch to the models for ulysses sequence parallel and fused kernel. - - In the end of this function forward function of the model is patched for fused kernel. - If the model is not supported with fused kernel, please return after patch. - """ - - """Replace _flash_attention_forward to _ulysses_flash_attention_forward""" - module = sys.modules[model.__module__] - - try: - num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads - except AttributeError: - num_attention_heads, num_key_value_heads = ( - model.config.text_config.num_attention_heads, - model.config.text_config.num_key_value_heads, - ) - - assert num_attention_heads % ulysses_sp_size == 0, ( - f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}" - ) - assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, ( - f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size " - f"{ulysses_sp_size}or vise versa. Upon ulysses_sp_size % num_key_value_heads == 0," - f"kv heads are repeated to ensure correctness." - ) - - if is_trl_available(): - from trl import AutoModelForCausalLMWithValueHead # type: ignore - - def state_dict(self, *args, **kwargs): - return torch.nn.Module.state_dict(self, *args, **kwargs) - - AutoModelForCausalLMWithValueHead.state_dict = state_dict - print("Monkey patch state_dict in AutoModelForCausalLMWithValueHead. ") - - # TODO: VLM models only, unify monkey patch to LLM models. - if model.config.model_type == "qwen2_5_vl": - if is_transformers_version_in_range(min_version="4.53.0"): - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention - - # TODO: Support transformers 4.53 - raise ValueError("Transformers 4.53 is not supported") - else: - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention, - ) - - if use_remove_padding or ulysses_sp_size > 1: - from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward - - Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward - print("Monkey patch FlashAttention2.forward in Qwen2.5VL") - - if ulysses_sp_size > 1: - if is_transformers_version_in_range(min_version="4.52.0"): - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel - - patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel) - else: - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel - - patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel) - - elif model.config.model_type == "qwen2_vl": - if is_transformers_version_in_range(min_version="4.53.0"): - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention - - # TODO: Support transformers 4.53 - raise ValueError("Transformers 4.53 is not supported") - else: - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention - - if use_remove_padding or ulysses_sp_size > 1: - from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward - - Qwen2VLAttention.forward = ulysses_flash_attn_forward - print("Monkey patch FlashAttention2.forward in Qwen2VL") - - if ulysses_sp_size > 1: - if is_transformers_version_in_range(min_version="4.52.0"): - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel - - patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel) - else: - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel - - patch_vlm_for_ulysses_input_slicing(Qwen2VLModel) - - elif model.config.model_type == "kimi_vl": - if use_remove_padding or ulysses_sp_size > 1: - # TODO: Changes need to be made when transformers are adapted. - from verl.models.transformers.kimi_vl import _ulysses_flash_attn_forward - - module.DeepseekV3FlashAttention2.forward = _ulysses_flash_attn_forward - print("Monkey patch FlashAttention2.forward in KimiVL") - - if ulysses_sp_size > 1: - patch_vlm_for_ulysses_input_slicing(module.DeepseekV3ForCausalLM) - - if use_fused_kernels: - print("Not support fused kernels for KimiVL") - - return - - # transformers<=4.47.1 - if use_remove_padding or ulysses_sp_size > 1: - if hasattr(module, "_flash_attention_forward"): - module._flash_attention_forward = _ulysses_flash_attention_forward - print(f"Monkey patch _flash_attention_forward in {model.__module__}") - else: - # transformers>=4.48.0 - from transformers.integrations import flash_attention - - flash_attention._flash_attention_forward = _ulysses_flash_attention_forward - print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}") - - patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend) - - -@lru_cache -def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool: - try: - # Get the installed version of the transformers library - transformers_version_str = importlib.metadata.version("transformers") - except importlib.metadata.PackageNotFoundError as e: - raise ModuleNotFoundError("The `transformers` package is not installed.") from e - - transformers_version = version.parse(transformers_version_str) - - lower_bound_check = True - if min_version is not None: - lower_bound_check = version.parse(min_version) <= transformers_version - - upper_bound_check = True - if max_version is not None: - upper_bound_check = transformers_version <= version.parse(max_version) - - return lower_bound_check and upper_bound_check diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py deleted file mode 100644 index e6bb37368..000000000 --- a/verl/models/transformers/npu_patch.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Copyright 2025 The Qwen Team and The HuggingFace Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch_npu -from torch_npu import npu_rotary_mul as apply_rotary_emb -from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm - - -# This patch takes effect when using apply_rotary_pos_emb_flashatt on qwen2_5_vl and will be removed in -# subsequent versions -# https://github.com/huggingface/transformers/pull/38491 -def apply_rotary_pos_emb_flashatt_npu( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - cos = cos.chunk(2, dim=-1)[0].contiguous() - sin = sin.chunk(2, dim=-1)[0].contiguous() - cos = cos.repeat(1, 2) - sin = sin.repeat(1, 2) - q_embed = apply_rotary_emb( - q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() - ).type_as(q) - k_embed = apply_rotary_emb( - k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() - ).type_as(k) - return q_embed, k_embed - - -# This api can improve performance on ASCEND NPU -def rms_norm_forward(self, x): - return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0] - - -Qwen2RMSNorm.forward = rms_norm_forward -modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu diff --git a/verl/models/transformers/qwen2.py b/verl/models/transformers/qwen2.py deleted file mode 100644 index e55fb26d5..000000000 --- a/verl/models/transformers/qwen2.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Optional - -import torch -from transformers.cache_utils import Cache -from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv -from transformers.utils import logging - -from verl.utils.ulysses import ( - gather_heads_scatter_seq, - gather_seq_scatter_heads, - get_ulysses_sequence_parallel_world_size, - validate_ulysses_config, -) - -logger = logging.get_logger(__name__) - - -def qwen2_flash_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 -): - """ - Adapted from transformers 4.47.1 to support Ulysses sequence parallelism. - - NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1. - """ - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - ########## AlltoAll for Ulysses ########## - ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - - if ulysses_sp_size > 1: - validate_ulysses_config(self.num_heads, ulysses_sp_size) - - # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) - query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) - key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) - value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) - - full_q_len = query_states.size(2) # full seq length - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to " - f"the fact you have upcasted embedding or layer norm layers in float32. We will cast back the " - f"input in {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - full_q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - # use full_q_len to reshape - attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() - ########## AlltoAll for Ulysses ########## - if ulysses_sp_size > 1: - attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def qwen2_attn_forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """ - Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0. - - NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0. - """ - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - - bsz, q_len, _ = hidden_states.shape - hidden_shape = (bsz, q_len, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - ########## AlltoAll for Ulysses ########## - ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - - if ulysses_sp_size > 1: - validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size) - - # (bsz, n_head, seq_len/n, head_dim) -> (bsz, n_head/n, seq_len, head_dim) - query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) - key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) - value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) - - full_q_len = query_states.size(2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - sliding_window = None - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - - from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. " - "Falling back to eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=sliding_window, # main diff with Llama - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous() - ########## AlltoAll for Ulysses ########## - if ulysses_sp_size > 1: - # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim) - attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights diff --git a/verl/models/transformers/qwen2_5_vl.py b/verl/models/transformers/qwen2_5_vl.py deleted file mode 100644 index 51d9753fb..000000000 --- a/verl/models/transformers/qwen2_5_vl.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional - -import torch -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLCausalLMOutputWithPast, - Qwen2_5_VLForConditionalGeneration, -) - - -@dataclass -class Qwen2_5_VLCausalLMOutputForPPO(Qwen2_5_VLCausalLMOutputWithPast): - log_probs: Optional[torch.FloatTensor] = None - entropy: Optional[torch.FloatTensor] = None - - -def forward_base_model( - self: Qwen2_5_VLForConditionalGeneration, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, -) -> tuple | Qwen2_5_VLCausalLMOutputWithPast: - r""" - Copy paste Qwen2_5_VL's forward - https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_5_vl.py - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, " - f"features {n_image_features}" - ) - - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) - - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, " - f"features {n_video_features}" - ) - - mask = input_ids == self.config.video_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded.to(inputs_embeds.device) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: - position_ids, rope_deltas = self.get_rope_index( - input_ids, - image_grid_thw, - video_grid_thw, - second_per_grid_ts, - attention_mask, - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 - position_ids = torch.arange(seq_length, device=inputs_embeds.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - return outputs - - -def forward_with_torch_backend( - self: Qwen2_5_VLForConditionalGeneration, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - temperature: float = 1.0, - **loss_kwargs, -) -> tuple | Qwen2_5_VLCausalLMOutputForPPO: - from verl.utils.experimental.torch_functional import FusedLinearForPPO - - outputs = forward_base_model( - self, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - rope_deltas=rope_deltas, - cache_position=cache_position, - second_per_grid_ts=second_per_grid_ts, - ) - - hidden_states = outputs[0] - - if not return_dict: - raise NotImplementedError("forward_with_torch_backend has to return_dict") - - # Loss calculations - if labels is not None: - rolled_labels = torch.roll(labels, shifts=-1, dims=-1) - elif input_ids is not None: - rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) - else: - raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") - - fused_linear_for_ppo = FusedLinearForPPO() - log_probs, entropy = fused_linear_for_ppo.forward( - hidden_states=hidden_states, - vocab_weights=self.lm_head.weight, - input_ids=rolled_labels, - temperature=temperature, - ) - - return Qwen2_5_VLCausalLMOutputForPPO( - log_probs=log_probs, - entropy=entropy, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=rope_deltas, - ) - - -def forward_with_triton_backend( - self: Qwen2_5_VLForConditionalGeneration, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - temperature: float = 1.0, - **loss_kwargs, -) -> tuple | Qwen2_5_VLCausalLMOutputForPPO: - from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy - - outputs = forward_base_model( - self, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - rope_deltas=rope_deltas, - cache_position=cache_position, - second_per_grid_ts=second_per_grid_ts, - ) - - hidden_states = outputs[0] - - if not return_dict: - raise NotImplementedError("forward_with_triton_backend has to return_dict") - - # Loss calculations - if labels is not None: - rolled_labels = torch.roll(labels, shifts=-1, dims=-1) - elif input_ids is not None: - rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) - else: - raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") - - log_probs, entropy = linear_cross_entropy( - hidden_states, - self.lm_head.weight, - rolled_labels, - temperature, - "none", - ) - - return Qwen2_5_VLCausalLMOutputForPPO( - log_probs=log_probs, - entropy=entropy, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=rope_deltas, - ) diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py deleted file mode 100644 index 358b00b6b..000000000 --- a/verl/models/transformers/qwen2_vl.py +++ /dev/null @@ -1,559 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import os -from dataclasses import dataclass -from typing import Optional - -import torch -from transformers.modeling_flash_attention_utils import _flash_attention_forward -from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLCausalLMOutputWithPast, - Qwen2VLForConditionalGeneration, -) -from transformers.utils import is_flash_attn_greater_or_equal - -from verl.utils.ulysses import ( - gather_heads_scatter_seq, - gather_seq_scatter_heads, - get_ulysses_sequence_parallel_world_size, - validate_ulysses_config, -) - -try: - from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func - - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) -except ImportError: - flash_attn_varlen_func = None - - -def get_rope_index( - processor, - input_ids: torch.Tensor, - image_grid_thw: Optional[torch.Tensor] = None, - video_grid_thw: Optional[torch.Tensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence. - The batch dim has been removed and the input_ids should be a 1D tensor representing a single example. - https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546 - """ - spatial_merge_size = processor.image_processor.merge_size - tokens_per_second = 2 - image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>") - video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>") - vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>") - if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - - position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) # (3, seqlen) - image_index, video_index = 0, 0 - input_ids = input_ids[attention_mask == 1] - image_nums, video_nums = 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == video_token_id).sum() - input_tokens = input_ids.tolist() - llm_pos_ids_list: list = [] - st = 0 - remain_images, remain_videos = image_nums, video_nums - for _ in range(image_nums + video_nums): - if image_token_id in input_tokens and remain_images > 0: - ed_image = input_tokens.index(image_token_id, st) - else: - ed_image = len(input_tokens) + 1 - if video_token_id in input_tokens and remain_videos > 0: - ed_video = input_tokens.index(video_token_id, st) - else: - ed_video = len(input_tokens) + 1 - if ed_image < ed_video: - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) - second_per_grid_t = 0 - image_index += 1 - remain_images -= 1 - ed = ed_image - else: - t, h, w = ( - video_grid_thw[video_index][0], - video_grid_thw[video_index][1], - video_grid_thw[video_index][2], - ) - second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0 - - video_index += 1 - remain_videos -= 1 - ed = ed_video - - llm_grid_t, llm_grid_h, llm_grid_w = ( - t.item(), - h.item() // spatial_merge_size, - w.item() // spatial_merge_size, - ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w) - t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device) - else: - if attention_mask is not None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device) - else: - position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).view(1, -1).expand(3, -1) - - return position_ids - - -def prepare_fa2_from_position_ids( - query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor -): - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) - position_ids = position_ids.flatten() - indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) - cu_seqlens = torch.cat( - ( - indices_q[position_ids == 0], - torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), - ) - ) - max_length = cu_seqlens.diff().max() # use cu_seqlens to infer max_length for qwen2vl mrope - return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length)) - - -def flash_attention_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - is_causal: bool = True, - position_ids: Optional[torch.Tensor] = None, - sliding_window: Optional[int] = None, - use_top_left_mask: bool = False, - deterministic: Optional[bool] = None, - **kwargs, -): - """ - Patches flash attention forward to handle 3D position ids in mrope. (3, batch_size, seq_length) - """ - causal = is_causal if not use_top_left_mask else is_causal and query_length != 1 - - # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). - use_sliding_windows = ( - _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window - ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - - if is_flash_attn_greater_or_equal("2.4.1"): - if deterministic is None: - deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - flash_kwargs["deterministic"] = deterministic - - if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all(): - batch_size = query_states.size(0) - query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( - query_states, key_states, value_states, position_ids[0] - ) # remove channel dimension - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=kwargs.pop("dropout", 0.0), - softmax_scale=kwargs.pop("softmax_scale", None), - causal=causal, - **flash_kwargs, - ) - attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) - else: - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_length, - is_causal=is_causal, - sliding_window=sliding_window, - use_top_left_mask=use_top_left_mask, - deterministic=deterministic, - **kwargs, - ) # do not pass position_ids to old flash_attention_forward - - return attn_output - - -def ulysses_flash_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, -) -> tuple[torch.Tensor, None, None]: - from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv - - bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size - query_states = self.q_proj(hidden_states) # (batch_size, seq_length / sp_size, num_heads * head_size) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - - if ulysses_sp_size > 1: - validate_ulysses_config(self.num_heads, ulysses_sp_size) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) - key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) - value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) - # (batch_size, num_head / sp_size, seq_length, head_size) - full_q_len = query_states.size(2) # full_q_len = seq_length - else: - full_q_len = q_len - - # Because the input can be padded, the absolute sequence length depends on the max position id. - if position_embeddings is None: - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] - ) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - full_q_len, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - position_ids=position_ids, # important: pass position ids - ) # (batch_size, seq_length, num_head / sp_size, head_size) - if ulysses_sp_size > 1: - attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, None, None - - -@dataclass -class Qwen2VLCausalLMOutputForPPO(Qwen2VLCausalLMOutputWithPast): - log_probs: Optional[torch.FloatTensor] = None - entropy: Optional[torch.FloatTensor] = None - - -def forward_base_model( - self: Qwen2VLForConditionalGeneration, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, -) -> tuple | Qwen2VLCausalLMOutputWithPast: - r""" - Copy paste Qwen2VL's forward - https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/qwen2_vl.py - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, " - f"features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, " - f"features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: - position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 - position_ids = torch.arange(seq_length, device=inputs_embeds.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) - position_ids = position_ids.add(delta) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) - - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - return outputs - - -def forward_with_torch_backend( - self: Qwen2VLForConditionalGeneration, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - temperature: float = 1.0, - **loss_kwargs, -) -> tuple | Qwen2VLCausalLMOutputForPPO: - from verl.utils.experimental.torch_functional import FusedLinearForPPO - - outputs = forward_base_model( - self, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - rope_deltas=rope_deltas, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - if not return_dict: - raise NotImplementedError("forward_with_torch_backend has to return_dict") - - # Loss calculations - if labels is not None: - rolled_labels = torch.roll(labels, shifts=-1, dims=-1) - elif input_ids is not None: - rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) - else: - raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.") - - fused_linear_for_ppo = FusedLinearForPPO() - log_probs, entropy = fused_linear_for_ppo.forward( - hidden_states=hidden_states, - vocab_weights=self.lm_head.weight, - input_ids=rolled_labels, - temperature=temperature, - ) - - return Qwen2VLCausalLMOutputForPPO( - log_probs=log_probs, - entropy=entropy, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=rope_deltas, - ) - - -def forward_with_triton_backend( - self: Qwen2VLForConditionalGeneration, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - temperature: float = 1.0, - **loss_kwargs, -) -> tuple | Qwen2VLCausalLMOutputForPPO: - from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy - - outputs = forward_base_model( - self, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - rope_deltas=rope_deltas, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - if not return_dict: - raise NotImplementedError("forward_with_triton_backend has to return_dict") - - # Loss calculations - if labels is not None: - rolled_labels = torch.roll(labels, shifts=-1, dims=-1) - elif input_ids is not None: - rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1) - else: - raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.") - - log_probs, entropy = linear_cross_entropy( - hidden_states, - self.lm_head.weight, - rolled_labels, - temperature, - "none", - ) - - return Qwen2VLCausalLMOutputForPPO( - log_probs=log_probs, - entropy=entropy, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=rope_deltas, - ) diff --git a/verl/models/weight_loader_registry.py b/verl/models/weight_loader_registry.py deleted file mode 100644 index 8aa3bc71f..000000000 --- a/verl/models/weight_loader_registry.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -def get_weight_loader(arch: str): - from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel - - _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY = { - "LlamaForCausalLM": load_state_dict_to_megatron_gptmodel, - "Qwen2ForCausalLM": load_state_dict_to_megatron_gptmodel, - } - - if arch in _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY: - return _MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY[arch] - raise ValueError( - f"Model architectures {arch} loader are not supported for now. Supported architectures: " - f"{_MODEL_WEIGHT_MEGATRON_LOADER_REGISTRY.keys()}" - ) - - -def get_weight_saver(arch: str): - from verl.models.mcore.saver import ( - merge_megatron_ckpt_gptmodel, - merge_megatron_ckpt_gptmodel_dpskv3, - merge_megatron_ckpt_gptmodel_mixtral, - merge_megatron_ckpt_gptmodel_qwen2_5_vl, - merge_megatron_ckpt_gptmodel_qwen_moe, - ) - - _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = { - "LlamaForCausalLM": merge_megatron_ckpt_gptmodel, - "Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel, - "MixtralForCausalLM": merge_megatron_ckpt_gptmodel_mixtral, - "Qwen2MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, - "Qwen2_5_VLForConditionalGeneration": merge_megatron_ckpt_gptmodel_qwen2_5_vl, - "DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3, - "Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel, - "Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe, - } - if arch in _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY: - return _MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY[arch] - raise ValueError( - f"Model architectures {arch} saver are not supported for now. Supported architectures: " - f"{_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY.keys()}" - ) diff --git a/verl/protocol.py b/verl/protocol.py deleted file mode 100644 index 39979f848..000000000 --- a/verl/protocol.py +++ /dev/null @@ -1,964 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Implement base data transfer protocol between any two functions, modules. -We can subclass Protocol to define more detailed batch info with specific keys -""" - -import contextlib -import copy -import logging -import os -import pickle -from dataclasses import dataclass, field -from typing import Callable, Optional - -import numpy as np -import pandas as pd -import ray -import tensordict -import torch -import torch.distributed -from packaging import version -from tensordict import TensorDict -from torch.utils.data import DataLoader - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.py_functional import union_two_dict -from verl.utils.torch_functional import allgather_dict_tensors - -__all__ = ["DataProto", "union_tensor_dict"] - -with contextlib.suppress(Exception): - tensordict.set_lazy_legacy(False).set() - - -class _DataProtoConfigMeta(type): - _config = {} - - auto_padding_key = "_verl_auto_padding" - - @property - def auto_padding(cls): - enabled_by_env = os.getenv("VERL_AUTO_PADDING", "FALSE").upper() in ["TRUE", "1"] - return enabled_by_env or cls._config.get(cls.auto_padding_key, False) - - @auto_padding.setter - def auto_padding(cls, enabled: bool): - assert isinstance(enabled, bool), f"enabled must be a boolean, got {enabled} as {type(enabled)}" - cls._config[cls.auto_padding_key] = enabled - - -class DataProtoConfig(metaclass=_DataProtoConfigMeta): - pass - - -_padding_size_key = "_padding_size_key_x123d" - - -def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int): - """Pad a DataProto to size divisible by size_divisor - - Args: - size_divisor (int): size divisor - - Returns: - data: (DataProto): the padded DataProto - pad_size (int) - """ - assert isinstance(data, DataProto), "data must be a DataProto" - if len(data) % size_divisor != 0: - pad_size = size_divisor - len(data) % size_divisor - padding_protos = [] - remaining_pad = pad_size - while remaining_pad > 0: - take_size = min(remaining_pad, len(data)) - padding_protos.append(data[:take_size]) - remaining_pad -= take_size - data_padded = DataProto.concat([data] + padding_protos) - else: - if len(data) == 0: - logging.warning("padding a DataProto with no item, no changed made") - pad_size = 0 - data_padded = data - return data_padded, pad_size - - -def unpad_dataproto(data: "DataProto", pad_size): - """Unpad the data proto with pad_size. i.e. `data[:-pad_size]`""" - if pad_size != 0: - data = data[:-pad_size] - return data - - -def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict: - """Union two tensordicts.""" - assert tensor_dict1.batch_size == tensor_dict2.batch_size, ( - f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}" - ) - for key in tensor_dict2.keys(): - if key not in tensor_dict1.keys(): - tensor_dict1[key] = tensor_dict2[key] - else: - assert tensor_dict1[key].equal(tensor_dict2[key]), ( - f"{key} in tensor_dict1 and tensor_dict2 are not the same object" - ) - - return tensor_dict1 - - -def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str, np.ndarray]) -> dict[str, np.ndarray]: - for key, val in tensor_dict2.items(): - if key in tensor_dict1: - assert isinstance(tensor_dict2[key], np.ndarray) - assert isinstance(tensor_dict1[key], np.ndarray) - # to properly deal with nan and object type - assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), ( - f"{key} in tensor_dict1 and tensor_dict2 are not the same object" - ) - tensor_dict1[key] = val - - return tensor_dict1 - - -def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): - if len(list_of_dict) == 0: - return {} - keys = list_of_dict[0].keys() - output = {key: [] for key in keys} - for data in list_of_dict: - for key, item in data.items(): - assert key in output - output[key].append(item) - return output - - -def fold_batch_dim(data: "DataProto", new_batch_size): - """ - Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx] - """ - batch_size = data.batch.batch_size[0] - - assert batch_size % new_batch_size == 0 - - tensor: TensorDict = data.batch - non_tensor = data.non_tensor_batch - - tensor = tensor.view(new_batch_size, -1) - tensor.auto_batch_size_(batch_dims=1) - - for key, val in non_tensor.items(): - non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:])) - - return type(data)(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info) - - -def unfold_batch_dim(data: "DataProto", batch_dims=2): - """ - Unfold the first n dims as new batch dim - """ - tensor: TensorDict = data.batch - non_tensor = data.non_tensor_batch - tensor.auto_batch_size_(batch_dims=batch_dims) - tensor = tensor.view(-1) - - batch_size = tensor.batch_size[0] - - non_tensor_new = {} - - for key, val in non_tensor.items(): - non_tensor_new[key] = np.reshape(val, newshape=(batch_size, *val.shape[batch_dims:])) - - return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info) - - -def collate_fn(x: list["DataProtoItem"]): - batch = [] - non_tensor_batch = [] - for data in x: - batch.append(data.batch) - non_tensor_batch.append(data.non_tensor_batch) - batch = torch.stack(batch).contiguous() - non_tensor_batch = list_of_dict_to_dict_of_list(non_tensor_batch) - for key, val in non_tensor_batch.items(): - non_tensor_batch[key] = np.array(val, dtype=object) - return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) - - -@dataclass -class DataProtoItem: - # TODO(zhangchi.usc1992) add consistency check - batch: TensorDict = None - non_tensor_batch: dict = field(default_factory=dict) - meta_info: dict = field(default_factory=dict) - - -@dataclass -class DataProto: - """ - A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. - It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. - TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the - same batch size should be put inside batch. - """ - - batch: TensorDict = None - non_tensor_batch: dict = field(default_factory=dict) - meta_info: dict = field(default_factory=dict) - - def __post_init__(self): - # perform necessary checking - self.check_consistency() - - def __len__(self): - if self.batch is not None: - return self.batch.batch_size[0] - elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0: - random_key = list(self.non_tensor_batch.keys())[0] - return self.non_tensor_batch[random_key].shape[0] - else: - return 0 - - def __getitem__(self, item): - """ - Enhanced indexing for DataProto objects. - - Args: - item: Can be one of: - - int: A single index - - slice: A slice object (start:stop:step) - - list: A list of indices - - numpy.ndarray: An array of indices - - torch.Tensor: A tensor of indices - - Returns: - DataProto: For all indexing types except single integers - DataProtoItem: Only for single integer indices - """ - # Case 1: Slice object - use the slice method - if isinstance(item, slice): - return self.slice(item.start, item.stop, item.step) - - # Case 2: List, numpy array, or torch tensor - use sel_idxs - elif isinstance(item, list | np.ndarray | torch.Tensor): - return self.select_idxs(item) - - # Case 3: Single integer - return DataProtoItem for backward compatibility - elif isinstance(item, int | np.integer): - tensor_data = self.batch[item] if self.batch is not None else None - non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()} - return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info) - - # # Case 4: Unsupported type - else: - raise TypeError(f"Indexing with {type(item)} is not supported") - - def __getstate__(self): - import io - - buffer = io.BytesIO() - if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: - self.batch = self.batch.contiguous() - self.batch = self.batch.consolidate() - torch.save(self.batch, buffer) - buffer_bytes = buffer.getvalue() - return buffer_bytes, self.non_tensor_batch, self.meta_info - - def __setstate__(self, data): - import io - - batch_deserialized_bytes, non_tensor_batch, meta_info = data - batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) - batch = torch.load( - batch_deserialized, - weights_only=False, - map_location="cpu" if not get_torch_device().is_available() else None, - ) - self.batch = batch - self.non_tensor_batch = non_tensor_batch - self.meta_info = meta_info - - def save_to_disk(self, filepath): - with open(filepath, "wb") as f: - pickle.dump(self, f) - - @staticmethod - def load_from_disk(filepath) -> "DataProto": - with open(filepath, "rb") as f: - data = pickle.load(f) - return data - - def print_size(self, prefix=""): - size_of_tensordict = 0 - if self.batch is not None: - for _, tensor in self.batch.items(): - size_of_tensordict += tensor.element_size() * tensor.numel() - size_of_numpy_array = 0 - for _, numpy_array in self.non_tensor_batch.items(): - size_of_numpy_array += numpy_array.nbytes - - size_of_numpy_array /= 1024**3 - size_of_tensordict /= 1024**3 - - message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB" - - if prefix: - message = f"{prefix}, " + message - print(message) - - def check_consistency(self): - """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch - We expose this function as a public one so that user can call themselves directly - """ - if self.batch is not None: - assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1" - - if self.non_tensor_batch is not None: - for key, val in self.non_tensor_batch.items(): - assert isinstance(val, np.ndarray) - - if self.batch is not None and self.non_tensor_batch is not None and len(self.non_tensor_batch) != 0: - # TODO: we can actually lift this restriction if needed - assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty." - - batch_size = self.batch.batch_size[0] - for key, val in self.non_tensor_batch.items(): - assert isinstance(val, np.ndarray), ( - f"data in the non_tensor_batch must be a numpy.array with dtype=object, but for " - f"{key=}, got {type(val)=}" - ) - assert val.shape[0] == batch_size, ( - f"key {key} length {len(val)} is not equal to batch size {batch_size}" - ) - - @classmethod - def from_single_dict(cls, data: dict[str, torch.Tensor | np.ndarray], meta_info=None, auto_padding=False): - """Create a DataProto from a dict of tensors and non_tensors""" - tensors = {} - non_tensors = {} - - for key, val in data.items(): - if isinstance(val, torch.Tensor): - tensors[key] = val - elif isinstance(val, np.ndarray): - non_tensors[key] = val - else: - raise ValueError(f"Unsupported type in data {type(val)}") - - return cls.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info, auto_padding=auto_padding) - - @classmethod - def from_dict( - cls, - tensors: Optional[dict[str, torch.Tensor]] = None, - non_tensors=None, - meta_info=None, - num_batch_dims=1, - auto_padding=False, - ): - """Create a DataProto from a dict of tensors. This assumes that - 1. All the tensor in tensors have the same dim0 - 2. Only dim0 is the batch dim - """ - - assert num_batch_dims > 0, "num_batch_dims must be greater than zero" - if non_tensors is not None: - assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None." - - if tensors is None: - tensors = {} - if meta_info is None: - meta_info = {} - if non_tensors is None: - non_tensors = {} - - assert isinstance(non_tensors, dict) - - # get and check batch size - batch_size = None - pivot_key = None - for key, tensor in tensors.items(): - if batch_size is None: - batch_size = tensor.shape[:num_batch_dims] - pivot_key = key - else: - current_batch = tensor.shape[:num_batch_dims] - assert batch_size == current_batch, ( - f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. " - f"Got {pivot_key} has {batch_size}, {key} has {current_batch}" - ) - - for key, val in non_tensors.items(): - if not isinstance(val, np.ndarray): - non_tensors[key] = np.array(val, dtype=object) - - tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None - if auto_padding: - meta_info[DataProtoConfig.auto_padding_key] = True - return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) - - def to(self, device) -> "DataProto": - """move the batch to device - - Args: - device (torch.device, str): torch device - - Returns: - DataProto: the current DataProto - - """ - if self.batch is not None: - self.batch = self.batch.to(device) - return self - - def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> "DataProto": - """Select a subset of the DataProto via batch_keys and meta_info_keys - - Args: - batch_keys (list, optional): a list of strings indicating the keys in batch to select - meta_info_keys (list, optional): a list of keys indicating the meta info to select - - Returns: - DataProto: the DataProto with the selected batch_keys and meta_info_keys - """ - # TODO (zhangchi.usc1992) whether to copy - if batch_keys is not None: - batch_keys = tuple(batch_keys) - sub_batch = self.batch.select(*batch_keys) - else: - sub_batch = self.batch - - if non_tensor_batch_keys is not None: - non_tensor_batch = {key: val for key, val in self.non_tensor_batch.items() if key in non_tensor_batch_keys} - else: - non_tensor_batch = self.non_tensor_batch - - if deepcopy: - non_tensor_batch = copy.deepcopy(non_tensor_batch) - - if meta_info_keys is not None: - sub_meta_info = {key: val for key, val in self.meta_info.items() if key in meta_info_keys} - else: - sub_meta_info = self.meta_info - - if deepcopy: - sub_meta_info = copy.deepcopy(sub_meta_info) - - return type(self)(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) - - def select_idxs(self, idxs): - """ - Select specific indices from the DataProto. - - Args: - idxs (torch.Tensor or numpy.ndarray or list): Indices to select - - Returns: - DataProto: A new DataProto containing only the selected indices - """ - if isinstance(idxs, list): - idxs = torch.tensor(idxs) - if idxs.dtype != torch.bool: - idxs = idxs.type(torch.int32) - - if isinstance(idxs, np.ndarray): - idxs_np = idxs - idxs_torch = torch.from_numpy(idxs) - else: # torch.Tensor - idxs_torch = idxs - idxs_np = idxs.detach().cpu().numpy() - - batch_size = int(idxs_np.sum()) if idxs_np.dtype == bool else idxs_np.shape[0] - - if self.batch is not None: - # Use TensorDict's built-in indexing capabilities - selected_batch = TensorDict( - source={key: tensor[idxs_torch] for key, tensor in self.batch.items()}, - batch_size=(batch_size,), - device=self.batch.device, - ) - else: - selected_batch = None - - selected_non_tensor = {} - for key, val in self.non_tensor_batch.items(): - selected_non_tensor[key] = val[idxs_np] - - return type(self)(batch=selected_batch, non_tensor_batch=selected_non_tensor, meta_info=self.meta_info) - - def slice(self, start=None, end=None, step=None): - """ - Slice the DataProto and return a new DataProto object. - This is an improved version of direct slicing which returns a DataProtoItem. - - Args: - start (int, optional): Start index. Defaults to None (start from beginning). - end (int, optional): End index (exclusive). Defaults to None (go to end). - step (int, optional): Step size. Defaults to None (step=1). - - Returns: - DataProto: A new DataProto containing the sliced data - - Examples: - # Using the slice method directly - sliced_data = data_proto.slice(10, 20) - - # Using enhanced indexing (returns DataProto) - sliced_data = data_proto[10:20] - sliced_data = data_proto[::2] # Every other element - - # Using list indexing (returns DataProto) - indices = [1, 5, 10] - selected_data = data_proto[indices] - - # Single index still returns DataProtoItem - single_item = data_proto[5] - """ - # Create a slice object - slice_obj = slice(start, end, step) - - # Handle the batch data - if self.batch is not None: - # Use TensorDict's built-in slicing capabilities - sliced_batch = self.batch[slice_obj] - else: - sliced_batch = None - - # Handle the non-tensor batch data - sliced_non_tensor = {} - for key, val in self.non_tensor_batch.items(): - sliced_non_tensor[key] = val[slice_obj] - - # Return a new DataProto object - return type(self)(batch=sliced_batch, non_tensor_batch=sliced_non_tensor, meta_info=self.meta_info) - - def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> "DataProto": - """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys` - - Args: - batch_keys (list, optional): a list of strings indicating the keys in batch to pop - meta_info_keys (list, optional): a list of keys indicating the meta info to pop - - Returns: - DataProto: the DataProto with the poped batch_keys and meta_info_keys - """ - if batch_keys is None: - batch_keys = [] - if meta_info_keys is None: - meta_info_keys = [] - if non_tensor_batch_keys is None: - non_tensor_batch_keys = [] - - tensors = {} - # tensor batch - for key in batch_keys: - assert key in self.batch.keys() - tensors[key] = self.batch.pop(key) - non_tensors = {} - # non tensor batch - for key in non_tensor_batch_keys: - assert key in self.non_tensor_batch.keys() - non_tensors[key] = self.non_tensor_batch.pop(key) - meta_info = {} - for key in meta_info_keys: - assert key in self.meta_info.keys() - meta_info[key] = self.meta_info.pop(key) - return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) - - def rename(self, old_keys=None, new_keys=None) -> "DataProto": - """ - Note that this function only rename the key in the batch - """ - - def validate_input(keys): - if keys is not None: - if isinstance(keys, str): - keys = [keys] - elif isinstance(keys, list): - pass - else: - raise TypeError(f"keys must be a list or a string, but got {type(keys)}") - return keys - - old_keys = validate_input(old_keys) - new_keys = validate_input(new_keys) - - if len(new_keys) != len(old_keys): - raise ValueError( - f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}" - ) - - self.batch.rename_key_(tuple(old_keys), tuple(new_keys)) - - return self - - def union(self, other: "DataProto") -> "DataProto": - """Union with another DataProto. Union batch and meta_info separately. - Throw an error if - - - there are conflict keys in batch and they are not equal - - the batch size of two data batch is not the same - - there are conflict keys in meta_info and they are not the same. - - Args: - other (DataProto): another DataProto to union - - Returns: - DataProto: the DataProto after union - """ - self.batch = union_tensor_dict(self.batch, other.batch) - self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) - self.meta_info = union_two_dict(self.meta_info, other.meta_info) - return self - - def make_iterator(self, mini_batch_size, epochs, seed=None, dataloader_kwargs=None): - r"""Make an iterator from the DataProto. This is built upon that TensorDict can be used as a normal Pytorch - dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details. - - - Args: - mini_batch_size (int): mini-batch size when iterating the dataset. We require that - ``batch.batch_size[0] % mini_batch_size == 0``. - epochs (int): number of epochs when iterating the dataset. - dataloader_kwargs (Any): internally, it returns a DataLoader over the batch. The - dataloader_kwargs is the kwargs passed to the DataLoader. - - Returns: - Iterator: an iterator that yields a mini-batch data at a time. The total number of iteration - steps is ``self.batch.batch_size * epochs // mini_batch_size`` - """ - assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" - # we can directly create a dataloader from TensorDict - if dataloader_kwargs is None: - dataloader_kwargs = {} - - if seed is not None: - generator = torch.Generator() - generator.manual_seed(seed) - else: - generator = None - - assert isinstance(dataloader_kwargs, dict) - train_dataloader = DataLoader( - dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs - ) - - def get_data(): - for _ in range(epochs): - for d in train_dataloader: - d.meta_info = self.meta_info - yield d - - return iter(get_data()) - - def is_padding_enabled(self): - """ - Check if padding is enabled for the DataProto. - Returns: - bool: True if padding is enabled, False otherwise. - """ - dataproto_specific_padding = self.meta_info.get(DataProtoConfig.auto_padding_key, False) - return dataproto_specific_padding or DataProtoConfig.auto_padding - - def padding(self, padding_size, padding_candidate=""): - """Pad the DataProto by concating with padding_candidate.repeat(padding_size) - - Args: - padding_size (int): the number of repeated padding_candidate - padding_candidate: the item to be repeated and appended to the DataProto, only supporting ["first", "last"] - """ - if padding_size == 0: - return - padding_candidate = self.select_idxs([0 if padding_candidate == "first" else len(self) - 1]) - padding_part = padding_candidate.repeat(padding_size) - padded_dp = DataProto.concat([self, padding_part]) - self.batch = padded_dp.batch - self.non_tensor_batch = padded_dp.non_tensor_batch - - def chunk(self, chunks: int) -> list["DataProto"]: - """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. - - Args: - chunks (int): the number of chunks to split on dim=0 - - Returns: - List[DataProto]: a list of DataProto after splitting - """ - if not self.is_padding_enabled(): - assert len(self) % chunks == 0, ( - f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." - ) - - bsz_in_batch = None - if self.batch is not None: - batch_lst = self.batch.chunk(chunks=chunks, dim=0) - bsz_in_batch = np.array([batch.batch_size[0] for batch in batch_lst]) - chunk_indices = np.cumsum(bsz_in_batch)[:-1] - else: - batch_lst = [None for _ in range(chunks)] - - non_tensor_batch_lst = [{} for _ in range(chunks)] - for key, val in self.non_tensor_batch.items(): - assert isinstance(val, np.ndarray) - if bsz_in_batch is not None: - non_tensor_lst = np.array_split(val, chunk_indices.tolist()) - else: - non_tensor_lst = np.array_split(val, chunks) - assert len(non_tensor_lst) == chunks - for i in range(chunks): - non_tensor_batch_lst[i][key] = non_tensor_lst[i] - - output = [] - for i in range(chunks): - output.append( - type(self)(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info) - ) - - return output - - def split(self, split_size: int) -> list["DataProto"]: - """Split the batch among dim=0 into chunks. The meta_info is passed to each DataProto after split. - - Args: - split_size (int): the size of each split - - Returns: - List[DataProto]: a list of DataProto after splitting - """ - return [self[i : i + split_size] for i in range(0, len(self), split_size)] - - @staticmethod - def concat(data: list["DataProto"]) -> "DataProto": - """Concat a list of DataProto. The batch is concatenated among dim=0. - The meta_info is assumed to be identical and will use the first one. - - Args: - data (List[DataProto]): list of DataProto - - Returns: - DataProto: concatenated DataProto - """ - batch_lst = [] - for batch in data: - batch_lst.append(batch.batch) - new_batch = torch.cat(batch_lst, dim=0) if batch_lst[0] is not None else None - - non_tensor_batch = list_of_dict_to_dict_of_list(list_of_dict=[d.non_tensor_batch for d in data]) - for key, val in non_tensor_batch.items(): - non_tensor_batch[key] = np.concatenate(val, axis=0) - - cls = type(data[0]) if len(data) > 0 else DataProto - return cls(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) - - def reorder(self, indices): - """ - Note that this operation is in-place - """ - indices_np = indices.detach().numpy() - self.batch = self.batch[indices] - self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()} - - def repeat(self, repeat_times=2, interleave=True): - """ - Repeat the batch data a specified number of times. - - Args: - repeat_times (int): Number of times to repeat the data. - interleave (bool): Whether to interleave the repeated data. - - Returns: - DataProto: A new DataProto with repeated data. - """ - if self.batch is not None: - if interleave: - # Interleave the data - repeated_tensors = { - key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() - } - else: - # Stack the data - repeated_tensors = { - key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) - for key, tensor in self.batch.items() - } - - repeated_batch = TensorDict( - source=repeated_tensors, - batch_size=(self.batch.batch_size[0] * repeat_times,), - ) - else: - repeated_batch = None - - repeated_non_tensor_batch = {} - for key, val in self.non_tensor_batch.items(): - if interleave: - repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) - else: - repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) - - return type(self)( - batch=repeated_batch, - non_tensor_batch=repeated_non_tensor_batch, - meta_info=self.meta_info, - ) - - def unfold_column_chunks(self, n_split: int, split_keys: Optional[list[str]] = None): - """Split along the second dim into `n_split`, unfold it to the first dim (batch dim) - Useful in passing grouped tensors that doesn't want to be shuffled in dataset. - keys not in split_keys are repeated to match the shape - Note that if the `split_keys` is not provided, it will repeat all the keys in the second dim. - """ - if self.batch is not None: - unfolded_batch = {} - for key in self.batch.keys(): - if key in split_keys if split_keys is not None else False: - shape = list(self.batch[key].shape) - shape[0] = self.batch[key].shape[0] * n_split - shape[1] = self.batch[key].shape[1] // n_split - unfolded_batch[key] = self.batch[key].reshape(*shape) - else: - unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0) - # locate the `unfolded_batch` as a TensorDict on the same device as the original batch - unfolded_batch = TensorDict( - source=unfolded_batch, batch_size=(self.batch.batch_size[0] * n_split,), device=self.batch.device - ) - else: - unfolded_batch = None - - repeated_non_tensor_batch = {} - for key, val in self.non_tensor_batch.items(): - if key in split_keys: - shape = list(val.shape) - shape[0] = val.shape[0] * n_split - shape[1] = val.shape[1] // n_split - repeated_non_tensor_batch[key] = val.reshape(*shape) - else: - repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0) - - return type(self)( - batch=unfolded_batch, - non_tensor_batch=repeated_non_tensor_batch, - meta_info=self.meta_info, - ) - - def sample_level_repeat(self, repeat_times): - """ - Repeat each row of the batch data a specified number of times. - - Args: - repeat_times (torch.tensor, list, tuple, ndarray): Number of times to repeat the data. - - Returns: - DataProto: A new DataProto with repeated data. - """ - if isinstance(repeat_times, tuple): - repeat_times = list(repeat_times) - elif isinstance(repeat_times, torch.Tensor): - assert len(repeat_times.shape) == 1 - repeat_times = repeat_times.tolist() - elif isinstance(repeat_times, np.ndarray): - assert len(repeat_times.shape) == 1 - repeat_times = repeat_times.tolist() - else: - assert isinstance(repeat_times, list), ( - f"repeat_times type must be in [list, torch.Tensor, np.ndarray, tuple], got {type(repeat_times)}" - ) - repeat_times = torch.tensor(repeat_times) - - if self.batch is not None: - # Interleave the data - repeated_tensors = { - key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() - } - - repeated_batch = TensorDict( - source=repeated_tensors, - batch_size=(repeat_times.sum().item(),), - device=self.batch.device, - ) - else: - repeated_batch = None - - repeated_non_tensor_batch = {} - for key, val in self.non_tensor_batch.items(): - repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) - - return type(self)( - batch=repeated_batch, - non_tensor_batch=repeated_non_tensor_batch, - meta_info=self.meta_info, - ) - - -@dataclass -class DataProtoFuture: - """ - DataProtoFuture aims to eliminate actual data fetching on driver. By doing so, the driver doesn't have to wait - for data so that asynchronous execution becomes possible. - DataProtoFuture contains a list of futures from another WorkerGroup of size world_size. - - collect_fn is a Callable that reduces the list of futures to a DataProto - - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size - and then select - - Potential issue: we can optimize dispatch_fn(collect_fn) such that only needed data is fetched on destination - - DataProtoFuture only supports directly passing from the output of a method to another input. You can't perform any - operation on the DataProtoFuture in driver. - """ - - collect_fn: Callable - futures: list[ray.ObjectRef] - dispatch_fn: Callable = None - - @staticmethod - def concat(data: list[ray.ObjectRef]) -> "DataProtoFuture": - output = DataProtoFuture(collect_fn=DataProto.concat, futures=data) - return output - - def chunk(self, chunks: int) -> list["DataProtoFuture"]: - from functools import partial - - arg_future_lst = [] - for i in range(chunks): - # note that we can't directly pass i and chunks - def dispatch_fn(x, i, chunks): - return x.chunk(chunks=chunks)[i] - - arg_future = DataProtoFuture( - collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures - ) - arg_future_lst.append(arg_future) - return arg_future_lst - - def get(self): - output = ray.get(self.futures) # dp_size. - for o in output: - assert isinstance(o, DataProto) - output = self.collect_fn(output) # select dp, concat - if self.dispatch_fn is not None: - output = self.dispatch_fn(output) # split in batch dim, select using dp - return output - - -def all_gather_data_proto(data: DataProto, process_group): - # Note that this is an inplace operator just like torch.distributed.all_gather - group_size = torch.distributed.get_world_size(group=process_group) - assert isinstance(data, DataProto) - prev_device = data.batch.device - data.batch = data.batch.to(get_device_id()) - data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0) - data.batch = data.batch.to(prev_device) - # all gather non_tensor_batch - all_non_tensor_batch = [None for _ in range(group_size)] - torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=process_group) - data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch} diff --git a/verl/single_controller/__init__.py b/verl/single_controller/__init__.py deleted file mode 100644 index ad6c42a80..000000000 --- a/verl/single_controller/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os - -from . import base -from .base import * - -version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) - -# Note(haibin.lin): single_controller.__version__ is deprecated -with open(os.path.join(os.path.join(version_folder, os.pardir), "version/version")) as f: - __version__ = f.read().strip() - - -__all__ = base.__all__ diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py deleted file mode 100644 index c22ca34ec..000000000 --- a/verl/single_controller/base/decorator.py +++ /dev/null @@ -1,705 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from functools import wraps -from types import FunctionType - -import torch # Added by Reasoning360 - -from verl.protocol import DataProtoFuture, _padding_size_key -from verl.utils.py_functional import DynamicEnum - -# here we add a magic number of avoid user-defined function already have this attribute -MAGIC_ATTR = "attrs_3141562937" - - -class Dispatch(DynamicEnum): - """Enum class defining different dispatch modes for distributed computation. - - Each mode represents a specific strategy for distributing data across - different ranks in a distributed system. The modes are used to control - how data is partitioned and processed across different worker groups. - """ - - _registry = {} - _next_value = 0 - - -def init_predefined_dispatch_mode(): - Dispatch.register("RANK_ZERO") - Dispatch.register("ONE_TO_ALL") - Dispatch.register("ALL_TO_ALL") - Dispatch.register("MEGATRON_COMPUTE") - Dispatch.register("MEGATRON_PP_AS_DP") - Dispatch.register("MEGATRON_PP_ONLY") - Dispatch.register("MEGATRON_COMPUTE_PROTO") - Dispatch.register("MEGATRON_PP_AS_DP_PROTO") - Dispatch.register("DP_COMPUTE") - Dispatch.register("DP_COMPUTE_PROTO") - Dispatch.register("DP_COMPUTE_PROTO_WITH_FUNC") - Dispatch.register("DP_COMPUTE_METRIC") - Dispatch.register("MEGATRON_PP_DUMMY_PROTO") - # This is a special dispatch mode for vllm ExternalRayDistributedExecutor - Dispatch.register("DIRECT_ROLLOUT_METHOD") - - -class Execute(DynamicEnum): - """Enum class defining different execution modes for distributed computation. - - These modes control how a function should be executed across different ranks - in a distributed system. - """ - - _registry = {} - _next_value = 0 - - -def init_predefined_execute_mode(): - Execute.register("ALL") - Execute.register("RANK_ZERO") - - -# Initialize the two Dynamic Enum Classes -init_predefined_dispatch_mode() -init_predefined_execute_mode() - - -def _split_args_kwargs_data_proto(chunks, *args, **kwargs): - from verl.protocol import DataProto, DataProtoFuture - - splitted_args = [] - for arg in args: - assert isinstance(arg, DataProto | DataProtoFuture) - splitted_args.append(arg.chunk(chunks=chunks)) - - splitted_kwargs = {} - for key, val in kwargs.items(): - assert isinstance(val, DataProto | DataProtoFuture) - splitted_kwargs[key] = val.chunk(chunks=chunks) - - return splitted_args, splitted_kwargs - - -def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs): - from verl.protocol import DataProto, DataProtoFuture - - splitted_args = [] - splitted_kwargs = {} - - data_proto_len = None - padding_size = None - for arg in args: - assert isinstance(arg, (DataProto, DataProtoFuture)) - if isinstance(arg, DataProto) and arg.is_padding_enabled(): - # for padding, we only support DataProto with same length - if data_proto_len is None: - data_proto_len = len(arg) - padding_size = (chunks - (data_proto_len % chunks)) if (data_proto_len % chunks > 0) else 0 - splitted_kwargs[_padding_size_key] = padding_size - else: - assert data_proto_len == len(arg), f"expecting all arg share same length of {data_proto_len}, but got {len(arg)}" - data_proto_len = len(arg) - arg.padding(padding_size=padding_size) - - splitted_args.append(arg.chunk(chunks=chunks)) - - for key, val in kwargs.items(): - assert isinstance(val, (DataProto, DataProtoFuture)) - if isinstance(val, DataProto) and val.is_padding_enabled(): - # for padding, we only support DataProto with same length - if data_proto_len is None: - data_proto_len = len(val) - padding_size = chunks - (data_proto_len % chunks) - splitted_kwargs[_padding_size_key] = padding_size - else: - assert data_proto_len == len(val), f"expecting all arg share same length of {data_proto_len}, but got {len(val)}" - data_proto_len = len(val) - splitted_kwargs[key] = val.chunk(chunks=chunks) - - return splitted_args, splitted_kwargs - - -def dispatch_one_to_all(worker_group, *args, **kwargs): - args = tuple([arg] * worker_group.world_size for arg in args) - kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()} - return args, kwargs - - -def dummy_direct_rollout_call(worker_group, *args, **kwargs): - raise NotImplementedError("Direct rollout call is forbidden.") - - -def dispatch_all_to_all(worker_group, *args, **kwargs): - return args, kwargs - - -def collect_all_to_all(worker_group, output): - return output - - -def dispatch_megatron_compute(worker_group, *args, **kwargs): - """ - User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp - """ - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - - assert isinstance(worker_group, MegatronWorkerGroup), ( - f"worker_group must be MegatronWorkerGroup, Got {type(worker_group)}" - ) - - # ray put all the args in advance to avoid duplicate serialization cost - import ray - - args = [[ray.put(dp_arg) for dp_arg in arg] for arg in args] - kwargs = {k: [ray.put(dp_v) for dp_v in v] for k, v in kwargs.items()} - - all_args = [] - for arg in args: - assert isinstance(arg, tuple | list) and len(arg) == worker_group.dp_size - transformed_args = [] - for i in range(worker_group.world_size): - local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank - transformed_args.append(arg[local_dp_rank]) - all_args.append(transformed_args) - all_args = tuple(all_args) - - all_kwargs = {} - for k, v in kwargs.items(): - assert isinstance(v, tuple | list) and len(v) == worker_group.dp_size - transformed_v = [] - for i in range(worker_group.world_size): - local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank - transformed_v.append(v[local_dp_rank]) - all_kwargs[k] = transformed_v - return all_args, all_kwargs - - -def collect_megatron_compute(worker_group, output): - """ - Only collect the data from the tp=0 and pp=last and every dp ranks - """ - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - - assert isinstance(worker_group, MegatronWorkerGroup) - output_in_dp = [] - pp_size = worker_group.get_megatron_global_info().pp_size - for global_rank in range(worker_group.world_size): - local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) - if local_rank_info.tp_rank == 0 and local_rank_info.pp_rank == pp_size - 1 and local_rank_info.cp_rank == 0: - output_in_dp.append(output[global_rank]) - return output_in_dp - - -def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): - """ - All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank - """ - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - - assert isinstance(worker_group, MegatronWorkerGroup) - - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) - return dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs) - - -def _concat_data_proto_or_future(output: list): - import ray - - from verl.protocol import DataProto, DataProtoFuture - - # make sure all the elements in output has the same type - for o in output: - assert type(o) is type(output[0]) - - o = output[0] - - if isinstance(o, DataProto): - return DataProto.concat(output) - elif isinstance(o, ray.ObjectRef): - return DataProtoFuture.concat(output) - else: - raise NotImplementedError - - -def collect_megatron_compute_data_proto(worker_group, output): - """ - Each output must be a DataProto. We concat the dim=0 of output - """ - import ray - - from verl.protocol import DataProto - - output = collect_megatron_compute(worker_group, output) - for o in output: - assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" - - return _concat_data_proto_or_future(output) - - -def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): - """ - treat pp as dp. - """ - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - - assert isinstance(worker_group, MegatronWorkerGroup) - - pp_size = worker_group.pp_size - dp_size = worker_group.dp_size - cp_size = worker_group.cp_size - pp_dp_cp_size = pp_size * dp_size * cp_size - - all_args = [] - for arg in args: - assert isinstance(arg, list | tuple) and len(arg) == pp_dp_cp_size - transformed_args = [] - for i in range(worker_group.world_size): - local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank - local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank - local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank - # compute the rank in arg. Note that the order is dp then cp then pp - # Also note that the outputs within a pp group will be firstly allgathered, then only the - # output of pp0 will be collected. - # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order: - # dispatch: pp_allgther: collect: - # dp 0 1 2 3 dp 0 1 2 3 - # pp +---------+ pp +-------------+ - # 0 | A C E G | 0 | AB CD EF GH | ABCDEFGH - # 1 | B D F H | 1 | AB CD EF GH | - # +---------+ +-------------+ - dp_cp_rank = local_cp_rank * dp_size + local_dp_rank - arg_rank = dp_cp_rank * pp_size + local_pp_rank - - transformed_args.append(arg[arg_rank]) - all_args.append(transformed_args) - all_args = tuple(all_args) - - all_kwargs = {} - for k, v in kwargs.items(): - assert isinstance(v, list | tuple) and len(v) == pp_dp_cp_size, f"expect len(v)=={pp_dp_cp_size}, got {len(v)}" - transformed_v = [] - for i in range(worker_group.world_size): - local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank - local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank - local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank - # compute the rank in arg. Note that the order is dp then cp then pp - dp_cp_rank = local_cp_rank * dp_size + local_dp_rank - arg_rank = dp_cp_rank * pp_size + local_pp_rank - transformed_v.append(v[arg_rank]) - all_kwargs[k] = transformed_v - return all_args, all_kwargs - - -def collect_megatron_pp_as_dp(worker_group, output): - """ - treat pp as dp. Only collect data on tp=0 - """ - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - - assert isinstance(worker_group, MegatronWorkerGroup) - output_in_dp = [] - for global_rank in range(worker_group.world_size): - local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) - if local_rank_info.tp_rank == 0: - output_in_dp.append(output[global_rank]) - return output_in_dp - - -def collect_megatron_pp_only(worker_group, output): - """ - Only collect output of megatron pp. This is useful when examine weight names as they are identical in tp/dp - """ - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - - assert isinstance(worker_group, MegatronWorkerGroup) - output_in_pp = [] - for global_rank in range(worker_group.world_size): - local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank) - if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0: - output_in_pp.append(output[global_rank]) - return output_in_pp - - -def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - - assert isinstance(worker_group, MegatronWorkerGroup) - - pp_dp_cp_size = worker_group.dp_size * worker_group.pp_size * worker_group.cp_size - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(pp_dp_cp_size, *args, **kwargs) - ret = dispatch_megatron_pp_as_dp(worker_group, *splitted_args, **splitted_kwargs) - return ret - - -def collect_megatron_pp_as_dp_data_proto(worker_group, output): - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - - assert isinstance(worker_group, MegatronWorkerGroup) - - output = collect_megatron_pp_as_dp(worker_group, output) - return _concat_data_proto_or_future(output) - - -def dispatch_dp_compute(worker_group, *args, **kwargs): - from verl.single_controller.base.worker_group import WorkerGroup - - assert isinstance(worker_group, WorkerGroup) - for arg in args: - assert isinstance(arg, tuple | list) and len(arg) == worker_group.world_size - for k, v in kwargs.items(): - assert isinstance(v, tuple | list) and len(v) == worker_group.world_size - return args, kwargs - - -def collect_dp_compute(worker_group, output): - from verl.single_controller.base.worker_group import WorkerGroup - - assert isinstance(worker_group, WorkerGroup) - assert len(output) == worker_group.world_size - return output - - -def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): - from verl.single_controller.base.worker_group import WorkerGroup - - assert isinstance(worker_group, WorkerGroup) - # Note: enable auto padding for dp compute DatapProto - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto_with_auto_padding( - worker_group.world_size, - *args, - **kwargs, - ) - return splitted_args, splitted_kwargs - - -def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): - from verl.single_controller.base.worker_group import WorkerGroup - - assert isinstance(worker_group, WorkerGroup) - assert isinstance(args[0], FunctionType) # NOTE: The first one args is a function! - - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args[1:], **kwargs) - splitted_args_with_func = [[args[0]] * worker_group.world_size] + splitted_args - return splitted_args_with_func, splitted_kwargs - - -def collect_dp_compute_data_proto(worker_group, output): - import ray - - from verl.protocol import DataProto - - for o in output: - assert isinstance(o, DataProto | ray.ObjectRef), f"expecting {o} to be DataProto, but got {type(o)}" - - output = collect_dp_compute(worker_group, output) - return _concat_data_proto_or_future(output) - - -#### Added by Reasoning360 -MAGIC_PREFIX = "__verl_dummy_tensor_" -def _materialize_dummy_data_proto(arg): - from verl.protocol import DataProto - from tensordict import TensorDict - import numpy as np - - if not isinstance(arg, DataProto): - return arg - - # This is not a dummy data proto - if not arg.meta_info.get(f"{MAGIC_PREFIX}is_dummy", False): - return arg - arg.meta_info.pop(f"{MAGIC_PREFIX}is_dummy") - - new_batch = {} - new_non_tensor_batch = {} - batch_size = None - for k, v in arg.batch.items(): - assert f"{MAGIC_PREFIX}batch_{k}_shape" in arg.meta_info - shape = arg.meta_info[f"{MAGIC_PREFIX}batch_{k}_shape"] - new_batch[k] = torch.zeros(shape, dtype=v.dtype, device=v.device) - arg.meta_info.pop(f"{MAGIC_PREFIX}batch_{k}_shape") - batch_size = batch_size or shape[0] - assert batch_size == shape[0], f"{batch_size=}, {shape=}" - for k, v in arg.non_tensor_batch.items(): - assert f"{MAGIC_PREFIX}non_tensor_batch_{k}_shape" in arg.meta_info - shape = arg.meta_info[f"{MAGIC_PREFIX}non_tensor_batch_{k}_shape"] - new_non_tensor_batch[k] = np.zeros(shape, dtype=v.dtype) - arg.meta_info.pop(f"{MAGIC_PREFIX}non_tensor_batch_{k}_shape") - assert batch_size == shape[0], f"{batch_size=}, {shape=}" - return DataProto( - batch=TensorDict(new_batch, batch_size=batch_size), - non_tensor_batch=new_non_tensor_batch, - meta_info=arg.meta_info, - ) - - -def _make_dummy_data_proto(arg): - from verl.protocol import DataProto - import numpy as np - from tensordict import TensorDict - - if not isinstance(arg, DataProto): - return arg - - new_batch = TensorDict({}, batch_size=[1]) - new_non_tensor_batch = {} - meta_info = arg.meta_info.copy() - - empty_shape = [1] - for k, v in arg.batch.items(): - shape = v.shape - # empty_shape = [0] + list(shape[1:]) - new_batch[k] = torch.zeros(empty_shape, dtype=v.dtype, device=v.device) - meta_info[f"{MAGIC_PREFIX}batch_{k}_shape"] = shape - - for k, v in arg.non_tensor_batch.items(): - shape = v.shape - # empty_shape = [0] + list(shape[1:]) - new_non_tensor_batch[k] = np.zeros(empty_shape, dtype=v.dtype) - meta_info[f"{MAGIC_PREFIX}non_tensor_batch_{k}_shape"] = shape - meta_info[f"{MAGIC_PREFIX}is_dummy"] = True - return DataProto(batch=new_batch, non_tensor_batch=new_non_tensor_batch, meta_info=meta_info) - - -def dispatch_megatron_pp_dummy_data_proto(worker_group, *args, **kwargs): - """ - NOTE: added by Reasoning360. It reads from a special keyword argument `verl_pp_send_rank: Sequence[int]` - It handles other arguments the same as `dispatch_megatron_compute_data_proto`, but the DataProto args are different that: - For Data Parallel Group (DP), the dispatch pattern is the same as `dispatch_megatron_compute_data_proto`. - For Pipeline Parallel Group (PP), only workers with a PP rank within `verl_pp_send_rank` will be dispatched. Other workers - wil receive an empty DataProto, with meta_info pairs of `batch_{key}_shape: value.shape for key, value in arg.batch`. - NOTE: this function cannot handle DataProtoFuture now. - TODO: broadcast within TP ranks after receiving, then TP ranks > 0 will also receive dummy data. - """ - from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - from verl.protocol import DataProto, DataProtoFuture - - assert isinstance(worker_group, MegatronWorkerGroup) - - # Extract the special keyword argument for PP send ranks - verl_pp_send_rank = kwargs.pop("verl_pp_send_rank", None) - if verl_pp_send_rank is None: - verl_pp_send_rank = (0, worker_group.pp_size - 1) - - # First, split the DataProto arguments by dp_size like in megatron_compute_data_proto - splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) - - # Now apply the megatron compute dispatch pattern - all_args, all_kwargs = dispatch_megatron_compute(worker_group, *splitted_args, **splitted_kwargs) - - # For each worker, check if it should receive data or empty DataProto - for rank in range(worker_group.world_size): - local_rank_info = worker_group.get_megatron_rank_info(rank=rank) - pp_rank = local_rank_info.pp_rank - tp_rank = local_rank_info.tp_rank - - # If this worker's PP rank is not in the send list, replace with empty DataProto - if pp_rank not in verl_pp_send_rank or tp_rank != 0: - # Create empty DataProto with shape information from original args - for arg_idx, arg in enumerate(all_args): - if isinstance(arg[rank], (DataProto, DataProtoFuture)): - # Get the original DataProto to extract shape information - original_arg = arg[rank] - if original_arg is not None and isinstance(original_arg, DataProto): - all_args[arg_idx][rank] = _make_dummy_data_proto(original_arg) - - # Handle kwargs similarly - for key, val_list in all_kwargs.items(): - if isinstance(val_list[rank], (DataProto, DataProtoFuture)): - original_val = val_list[rank] - if original_val is not None and isinstance(original_val, DataProto): - all_kwargs[key][rank] = _make_dummy_data_proto(original_val) - - return all_args, all_kwargs - - -# Global registry for dispatch mode. -DISPATCH_MODE_FN_REGISTRY = { - Dispatch.ONE_TO_ALL: { - "dispatch_fn": dispatch_one_to_all, - "collect_fn": collect_all_to_all, - }, - Dispatch.ALL_TO_ALL: { - "dispatch_fn": dispatch_all_to_all, - "collect_fn": collect_all_to_all, - }, - Dispatch.MEGATRON_COMPUTE: { - "dispatch_fn": dispatch_megatron_compute, - "collect_fn": collect_megatron_compute, - }, - Dispatch.MEGATRON_PP_AS_DP: { - "dispatch_fn": dispatch_megatron_pp_as_dp, - "collect_fn": collect_megatron_pp_as_dp, - }, - Dispatch.MEGATRON_PP_ONLY: {"dispatch_fn": dispatch_one_to_all, "collect_fn": collect_megatron_pp_only}, - Dispatch.MEGATRON_COMPUTE_PROTO: { - "dispatch_fn": dispatch_megatron_compute_data_proto, - "collect_fn": collect_megatron_compute_data_proto, - }, - Dispatch.MEGATRON_PP_AS_DP_PROTO: { - "dispatch_fn": dispatch_megatron_pp_as_dp_data_proto, - "collect_fn": collect_megatron_pp_as_dp_data_proto, - }, - Dispatch.DP_COMPUTE: {"dispatch_fn": dispatch_dp_compute, "collect_fn": collect_dp_compute}, - Dispatch.DP_COMPUTE_PROTO: { - "dispatch_fn": dispatch_dp_compute_data_proto, - "collect_fn": collect_dp_compute_data_proto, - }, - Dispatch.DP_COMPUTE_PROTO_WITH_FUNC: { - "dispatch_fn": dispatch_dp_compute_data_proto_with_func, - "collect_fn": collect_dp_compute_data_proto, - }, - Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute}, - Dispatch.DIRECT_ROLLOUT_METHOD: { - "dispatch_fn": dummy_direct_rollout_call, - "collect_fn": dummy_direct_rollout_call, - }, - # Added by Reasoning360 - Dispatch.MEGATRON_PP_DUMMY_PROTO: { - "dispatch_fn": dispatch_megatron_pp_dummy_data_proto, - "collect_fn": collect_megatron_compute_data_proto, - }, -} - - -def get_predefined_dispatch_fn(dispatch_mode): - return DISPATCH_MODE_FN_REGISTRY[dispatch_mode] - - -def register_dispatch_mode(dispatch_mode_name, dispatch_fn, collect_fn): - """ - Register a new dispatch mode. - """ - dispatch_mode = Dispatch.register(dispatch_mode_name) - _check_dispatch_mode(dispatch_mode) - assert dispatch_mode not in DISPATCH_MODE_FN_REGISTRY, f"dispatch_mode_name {dispatch_mode_name} already exists" - DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn} - - -def update_dispatch_mode(dispatch_mode, dispatch_fn, collect_fn): - """ - Update the dispatch mode. - """ - _check_dispatch_mode(dispatch_mode) - assert dispatch_mode in DISPATCH_MODE_FN_REGISTRY, f"dispatch_mode {dispatch_mode} not found" - DISPATCH_MODE_FN_REGISTRY[dispatch_mode] = {"dispatch_fn": dispatch_fn, "collect_fn": collect_fn} - - -def get_predefined_execute_fn(execute_mode): - """ - Note that here we only asks execute_all and execute_rank_zero to be implemented - Leave the choice of how these two functions handle argument 'blocking' to users - """ - predefined_execute_mode_fn = { - Execute.ALL: {"execute_fn_name": "execute_all"}, - Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"}, - } - return predefined_execute_mode_fn[execute_mode] - - -def _check_dispatch_mode(dispatch_mode): - assert isinstance(dispatch_mode, Dispatch | dict), ( - f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}" - ) - if isinstance(dispatch_mode, dict): - necessary_keys = ["dispatch_fn", "collect_fn"] - for key in necessary_keys: - assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary" - - -def _check_execute_mode(execute_mode): - assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}" - - -def _materialize_futures(*args, **kwargs): - new_args = [] - for arg in args: - if isinstance(arg, DataProtoFuture): - arg = arg.get() - # add more type to materialize - new_args.append(arg) - for k, v in kwargs.items(): - if isinstance(v, DataProtoFuture): - kwargs[k] = v.get() - - new_args = tuple(new_args) - return new_args, kwargs - - -def _materialize_dummy(*args, **kwargs): - from verl.protocol import DataProto - - new_args = [] - for arg in args: - if isinstance(arg, DataProto): - arg = _materialize_dummy_data_proto(arg) - new_args.append(arg) - for k in kwargs: - if isinstance(kwargs[k], DataProto): - kwargs[k] = _materialize_dummy_data_proto(kwargs[k]) - return tuple(new_args), kwargs - - -def register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.ALL, blocking=True, materialize_futures=True): - """Register a function with distributed execution configuration. - - This decorator registers a function with specific dispatch and execution modes - for distributed computation. It handles both synchronous and asynchronous - functions, and optionally materializes futures before execution. - - Args: - dispatch_mode: - Dispatch mode for computation distribution. Default: Dispatch.ALL_TO_ALL. - execute_mode: - Execute mode for computation distribution. Default: Execute.ALL. - blocking: - Whether the execution should be blocking. Defaults to True. - materialize_futures: - Whether to materialize the data before dispatching. Defaults to True. - - Returns: - A decorator that wraps the original function with distributed execution - configuration. - """ - _check_dispatch_mode(dispatch_mode=dispatch_mode) - _check_execute_mode(execute_mode=execute_mode) - - # Added by Reasoning360 - materialize_dummy = dispatch_mode == Dispatch.MEGATRON_PP_DUMMY_PROTO - - def decorator(func): - @wraps(func) - def inner(*args, **kwargs): - if materialize_futures: - args, kwargs = _materialize_futures(*args, **kwargs) - if materialize_dummy: - args, kwargs = _materialize_dummy(*args, **kwargs) - return func(*args, **kwargs) - - @wraps(func) - async def async_inner(*args, **kwargs): - if materialize_futures: - args, kwargs = _materialize_futures(*args, **kwargs) - # Added by Reasoning360 - if materialize_dummy: - args, kwargs = _materialize_dummy(*args, **kwargs) - return await func(*args, **kwargs) - - wrapper = async_inner if inspect.iscoroutinefunction(func) else inner - attrs = {"dispatch_mode": dispatch_mode, "execute_mode": execute_mode, "blocking": blocking} - setattr(wrapper, MAGIC_ATTR, attrs) - return wrapper - - return decorator diff --git a/verl/single_controller/base/megatron/__init__.py b/verl/single_controller/base/megatron/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/single_controller/base/megatron/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/single_controller/base/megatron/worker.py b/verl/single_controller/base/megatron/worker.py deleted file mode 100644 index baf6eb839..000000000 --- a/verl/single_controller/base/megatron/worker.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from verl.single_controller.base.worker import DistGlobalInfo, DistRankInfo, Worker - - -class MegatronWorker(Worker): - def __init__(self, cuda_visible_devices=None) -> None: - super().__init__(cuda_visible_devices) - - def get_megatron_global_info(self): - from megatron.core import parallel_state as mpu - - tp_size = mpu.get_tensor_model_parallel_world_size() - dp_size = mpu.get_data_parallel_world_size() - pp_size = mpu.get_pipeline_model_parallel_world_size() - cp_size = mpu.get_context_parallel_world_size() - info = DistGlobalInfo(tp_size=tp_size, dp_size=dp_size, pp_size=pp_size, cp_size=cp_size) - return info - - def get_megatron_rank_info(self): - from megatron.core import parallel_state as mpu - - tp_rank = mpu.get_tensor_model_parallel_rank() - dp_rank = mpu.get_data_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() - cp_rank = mpu.get_context_parallel_rank() - info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank) - return info - - def _init_hf_config_and_tf_config( - self, - model_path, - tokenizer_or_path, - dtype, - override_model_config, - override_transformer_config, - trust_remote_code=False, - use_mbridge=False, - ): - from transformers import AutoConfig - - from verl.models.mcore import hf_to_mcore_config - from verl.utils import hf_processor, hf_tokenizer - from verl.utils.fs import copy_to_local - from verl.utils.model import update_model_config - - # Step 1: initialize the tokenizer - self.local_path = copy_to_local(model_path) - if tokenizer_or_path is None: - self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code) - self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code) - elif isinstance(tokenizer_or_path, str): - self.tokenizer = hf_tokenizer(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) - self.processor = hf_processor(copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code) - else: - self.tokenizer = tokenizer_or_path - self.processor = tokenizer_or_path - - if self.config.model.get("custom_chat_template", None) is not None: - if self.processor is not None: - self.processor.chat_template = self.config.model.custom_chat_template - else: - self.tokenizer.chat_template = self.config.model.custom_chat_template - - # Step 2: get the hf - hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code) - - # Step 3: override the hf config - override_config_kwargs = { - "bos_token_id": self.tokenizer.bos_token_id, - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - } - override_config_kwargs.update(override_model_config.get("model_config", {})) - self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False) - update_model_config(hf_config, override_config_kwargs=override_config_kwargs) - self.architectures = getattr(hf_config, "architectures", None) - if self.rank == 0: - print(f"Model config after override: {hf_config}") - tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config) - - def add_optimization_config_to_tf_config(tf_config): - # add optimization config to tf_config, e.g. checkpointing - if self.config.model.get("enable_gradient_checkpointing", False): - gradient_checkpointing_cfg = dict(self.config.model.get("gradient_checkpointing_kwargs", dict())) - tf_config.recompute_method = gradient_checkpointing_cfg.get("activations_checkpoint_method", "full") - tf_config.recompute_granularity = gradient_checkpointing_cfg.get( - "activations_checkpoint_granularity", "full" - ) - tf_config.recompute_num_layers = gradient_checkpointing_cfg.get("activations_checkpoint_num_layers", -1) - if megatron_config := self.config.get("megatron", {}): - if extra := megatron_config.get("extra", {}): - for k, v in extra.items(): - setattr(tf_config, k, v) - - add_optimization_config_to_tf_config(tf_config) - if use_mbridge: - from verl.models.mcore.mbridge import AutoBridge - - bridge = AutoBridge.from_config(hf_config) - bridge.set_extra_args(**override_transformer_config) - tf_config = bridge.config - self.bridge = bridge - else: - self.bridge = None - - print(f"TF config: {tf_config}") - self.hf_config = hf_config - self.tf_config = tf_config diff --git a/verl/single_controller/base/megatron/worker_group.py b/verl/single_controller/base/megatron/worker_group.py deleted file mode 100644 index b9beb844c..000000000 --- a/verl/single_controller/base/megatron/worker_group.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from verl.single_controller.base import ResourcePool, WorkerGroup - -from .worker import DistGlobalInfo, DistRankInfo - - -class MegatronWorkerGroup(WorkerGroup): - def __init__(self, resource_pool: ResourcePool, **kwargs): - super().__init__(resource_pool=resource_pool, **kwargs) - self._megatron_rank_info = None - self._megatron_global_info: DistGlobalInfo = None - - def init_megatron(self, default_megatron_kwargs: dict = None): - raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten") - - def get_megatron_rank_info(self, rank: int) -> DistRankInfo: - assert 0 <= rank < self.world_size, f"rank must be from [0, world_size), Got {rank}" - return self._megatron_rank_info[rank] - - @property - def tp_size(self): - assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" - return self._megatron_global_info.tp_size - - @property - def dp_size(self): - assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" - return self._megatron_global_info.dp_size - - @property - def pp_size(self): - assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" - return self._megatron_global_info.pp_size - - @property - def cp_size(self): - assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized" - return self._megatron_global_info.cp_size - - def get_megatron_global_info(self): - return self._megatron_global_info diff --git a/verl/single_controller/base/register_center/__init__.py b/verl/single_controller/base/register_center/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/single_controller/base/register_center/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/single_controller/base/register_center/ray.py b/verl/single_controller/base/register_center/ray.py deleted file mode 100644 index ac071cde5..000000000 --- a/verl/single_controller/base/register_center/ray.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import ray - - -@ray.remote -class WorkerGroupRegisterCenter: - def __init__(self, rank_zero_info): - self.rank_zero_info = rank_zero_info - # rank -> node_id - self.workers_info: dict[int, str] = {} - - def get_rank_zero_info(self): - return self.rank_zero_info - - def set_worker_info(self, rank, node_id) -> None: - self.workers_info[rank] = node_id - - def get_worker_info(self) -> dict[int, str]: - return self.workers_info - - -def create_worker_group_register_center(name, info): - return WorkerGroupRegisterCenter.options(name=name).remote(info) diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py deleted file mode 100644 index 2606a3ef3..000000000 --- a/verl/single_controller/base/worker.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -the class for Worker -""" - -import os -import socket -from dataclasses import dataclass - -import ray - -from verl.utils.device import get_torch_device, get_visible_devices_keyword - -from .decorator import Dispatch, Execute, register - - -@dataclass -class DistRankInfo: - tp_rank: int - dp_rank: int - pp_rank: int - cp_rank: int - - -@dataclass -class DistGlobalInfo: - tp_size: int - dp_size: int - pp_size: int - cp_size: int - - -class WorkerHelper: - @staticmethod - def _get_node_ip(): - if os.getenv("WG_BACKEND", None) == "ray": - return ray.util.get_node_ip_address() - else: - raise NotImplementedError("WG_BACKEND now just support ray mode.") - - @staticmethod - def _get_free_port(): - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - def get_availale_master_addr_port(self): - return self._get_node_ip().strip("[]"), str(self._get_free_port()) - - -# we assume that in each WorkerGroup, there is a Master Worker -class Worker(WorkerHelper): - """A distributed worker that handles initialization and configuration for distributed training. - - This class manages worker initialization, configuration, and provides methods for executing - distributed operations. It handles communication settings, device configuration, and worker - metadata management. - """ - - fused_worker_attr_name = "fused_worker_dict" - - def __new__(cls, *args, **kwargs): - """Create a new Worker instance with proper initialization based on environment settings.""" - instance = super().__new__(cls) - - # note that here we use int to distinguish - disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0)) - if disable_worker_init: - return instance - - rank = os.environ.get("RANK", None) - worker_group_prefix = os.environ.get("WG_PREFIX", None) - - # when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init - if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__: - instance._configure_before_init(f"{worker_group_prefix}_register_center", int(rank)) - - return instance - - def _configure_before_init(self, register_center_name: str, rank: int): - """Configure worker settings before initialization. - - Args: - register_center_name (str): - Name of the register center Ray actor for worker coordination - rank (int): - Rank of the worker in the distributed setup - """ - assert isinstance(rank, int), f"rank must be int, instead of {type(rank)}" - - if rank == 0: - master_addr, master_port = self.get_availale_master_addr_port() - rank_zero_info = { - "MASTER_ADDR": master_addr, - "MASTER_PORT": master_port, - } - - if os.getenv("WG_BACKEND", None) == "ray": - from verl.single_controller.base.register_center.ray import create_worker_group_register_center - - self.register_center = create_worker_group_register_center( - name=register_center_name, info=rank_zero_info - ) - - os.environ.update(rank_zero_info) - else: - self.register_center = ray.get_actor(register_center_name) - - # set worker info for node affinity scheduling - ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id())) - - @classmethod - def env_keys(cls): - """The keys of the environment variables that are used to configure the Worker.""" - return [ - "WORLD_SIZE", - "RANK", - "LOCAL_WORLD_SIZE", - "LOCAL_RANK", - "MASTER_ADDR", - "MASTER_PORT", - get_visible_devices_keyword().upper(), - ] - - def __init__(self, cuda_visible_devices=None) -> None: - """Initialize the worker with environment settings and device configuration. - - Args: - cuda_visible_devices (str, optional): - CUDA visible devices configuration. Defaults to None. - """ - # construct a meta from environment variable. Note that the import must be inside the class because - # it is executed remotely - import os - - self._setup_env_cuda_visible_devices() - - world_size = int(os.environ["WORLD_SIZE"]) - rank = int(os.environ["RANK"]) - self._rank = rank - self._world_size = world_size - - master_addr = os.environ["MASTER_ADDR"] - master_port = os.environ["MASTER_PORT"] - - local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - - store = { - "_world_size": world_size, - "_rank": rank, - "_local_world_size": local_world_size, - "_local_rank": local_rank, - "_master_addr": master_addr, - "_master_port": master_port, - } - if cuda_visible_devices is not None: - store[f"_{get_visible_devices_keyword()}".lower()] = cuda_visible_devices - - self._configure_with_store(store=store) - - self.fused_worker_dict = {} - - def get_fused_worker_by_name(self, worker_name: str): - """Get a fused worker by its name. - - Args: - worker_name (str): - Name of the worker to retrieve - """ - return self.fused_worker_dict.get(worker_name, None) - - def _setup_env_cuda_visible_devices(self): - from verl.utils.ray_utils import ray_noset_visible_devices - - is_ray_noset_visible_devices = ray_noset_visible_devices() - - # Prevent use of clashing `{CUDA/HIP/ROCR}_VISIBLE_DEVICES`` - rocr_val = os.environ.get("ROCR_VISIBLE_DEVICES", None) - hip_val = os.environ.get("HIP_VISIBLE_DEVICES", None) - cuda_val = os.environ.get("CUDA_VISIBLE_DEVICES", None) - if hip_val: - # Switch the use of HIP_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES for consistency. - # Make sure that the HIP_VISIBLE_DEVICES is set to the same value as CUDA_VISIBLE_DEVICES - # at this point. - val = os.environ.pop("HIP_VISIBLE_DEVICES") - hip_val = None - if cuda_val: - assert val == cuda_val, ( - f"Please use the same HIP_VISIBLE_DEVICES or CUDA_VISIBLE_DEVICES, inconsistant values " - f"found: {val} and {cuda_val}." - ) - else: - cuda_val = val - os.environ["CUDA_VISIBLE_DEVICES"] = val - # os.environ["HIP_VISIBLE_DEVICES"] = val - - if rocr_val: - # You must take care if both HIP/CUDA and ROCR env vars are set as they have - # different meanings. Both env vars accept either a list of ints or a - # list of UUIDs. The ROCR env var is processed first which then reduces - # the number of GPUs that HIP can select from. - # https://github.com/pytorch/pytorch/pull/144026 - # To avoid the complexity of this, we simply gives out error if both are set - # (Also to keep consistency with ray's practice with 2.45.0). - # Otherwise, we will set ROCR_VISIBLE_DEVICES to CUDA_VISIBLE_DEVICES - # and remove ROCR_VISIBLE_DEVICES. - if cuda_val: - raise ValueError("Please don't set ROCR_VISIBLE_DEVICES when HIP/CUDA_VISIBLE_DEVICES is set.") - - cuda_val = os.environ.pop("ROCR_VISIBLE_DEVICES") - os.environ["CUDA_VISIBLE_DEVICES"] = cuda_val - rocr_val = None - - if is_ray_noset_visible_devices: - # NOTE: Ray will automatically set the *_VISIBLE_DEVICES - # environment variable for each actor, unless - # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set, - # so we need to set local rank when the flag is set. - local_rank = os.environ.get("RAY_LOCAL_RANK") - os.environ["LOCAL_RANK"] = local_rank - get_torch_device().set_device(int(local_rank)) - - def _configure_with_store(self, store: dict): - """ - This function should only be called inside by WorkerGroup - """ - store_env_dict = {f"_{key.lower()}": store.get(f"_{key.lower()}", None) for key in type(self).env_keys()} - self.__dict__.update(store_env_dict) # this is hacky - # print(f"__dict__: {self.__dict__}") - for key in type(self).env_keys(): - val = self.__dict__.get(f"_{key.lower()}", None) - if val is not None: - # print(f"set {key} to {val}") - os.environ[key] = str(val) - os.environ["REDIS_STORE_SERVER_HOST"] = ( - str(self._master_addr).replace("[", "").replace("]", "") if self._master_addr else "" - ) - - def get_master_addr_port(self): - """Get the master address and port for distributed communication.""" - return self._master_addr, self._master_port - - def get_cuda_visible_devices(self): - """Get the CUDA visible devices configuration.""" - import os - - visible_devices = os.environ.get(get_visible_devices_keyword().upper(), "not set") - return visible_devices - - @property - def world_size(self): - """Get the total number of workers in the distributed setup.""" - return self._world_size - - @property - def rank(self): - """Get the rank of this worker in the distributed setup.""" - return self._rank - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO_WITH_FUNC) - def execute_with_func_generator(self, func, *args, **kwargs): - """Execute a function with function generator dispatch mode. - - Args: - func: - Function to execute - *args: - Positional arguments for the function - **kwargs: - Keyword arguments for the function - """ - ret_proto = func(self, *args, **kwargs) - return ret_proto - - @register(dispatch_mode=Dispatch.ALL_TO_ALL, execute_mode=Execute.RANK_ZERO) - def execute_func_rank_zero(self, func, *args, **kwargs): - """Execute a function in rank zero execution mode. - - Args: - func: - Function to execute - *args: - Positional arguments for the function - **kwargs: - Keyword arguments for the function - """ - result = func(*args, **kwargs) - return result diff --git a/verl/single_controller/base/worker_group.py b/verl/single_controller/base/worker_group.py deleted file mode 100644 index cb86ab4f5..000000000 --- a/verl/single_controller/base/worker_group.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -the class of WorkerGroup -""" - -import logging -import signal -import threading -import time -from typing import Any, Callable - -from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn - - -class ResourcePool: - """ - Manages a pool of resources across multiple nodes, tracking process counts and GPU allocations. - The class provides methods to calculate world size, local world sizes, and local ranks - across all nodes in the pool. - """ - - def __init__(self, process_on_nodes=None, max_colocate_count: int = 10, n_gpus_per_node=8) -> None: - """Initialize the ResourcePool with node processes and GPU configuration. - - Args: - process_on_nodes (List[int], optional): List of process counts per node. Defaults to empty list. - max_colocate_count (int, optional): Maximum number of processes that can be colocated. Defaults to 10. - n_gpus_per_node (int, optional): Number of GPUs available per node. Defaults to 8. - """ - if process_on_nodes is None: - process_on_nodes = [] - self._store = process_on_nodes - self.max_colocate_count = max_colocate_count - self.n_gpus_per_node = n_gpus_per_node # this is left for future huawei GPU that contains 16 GPUs per node - - def add_node(self, process_count): - self._store.append(process_count) - - @property - def world_size(self): - """Total number of processes across all nodes in the pool.""" - return sum(self._store) - - def __call__(self) -> Any: - return self._store - - @property - def store(self): - return self._store - - def local_world_size_list(self) -> list[int]: - """Returns a flat list where each process has its local world size.""" - nested_local_world_size_list = [ - [local_world_size for _ in range(local_world_size)] for local_world_size in self._store - ] - return [item for row in nested_local_world_size_list for item in row] - - def local_rank_list(self) -> list[int]: - """Returns a flat list of local ranks for all processes across all nodes.""" - nested_local_rank_list = [[i for i in range(local_world_size)] for local_world_size in self._store] - return [item for row in nested_local_rank_list for item in row] - - -class ClassWithInitArgs: - """ - Wrapper class that stores constructor arguments for deferred instantiation. - This class is particularly useful for remote class instantiation where - the actual construction needs to happen at a different time or location. - """ - - def __init__(self, cls, *args, **kwargs) -> None: - """Initialize the ClassWithInitArgs instance. - - Args: - cls: The class to be instantiated later - *args: Positional arguments for the class constructor - **kwargs: Keyword arguments for the class constructor - """ - self.cls = cls - self.args = args - self.kwargs = kwargs - - self.fused_worker_used = False - - def __call__(self) -> Any: - """Instantiate the stored class with the stored arguments.""" - return self.cls(*self.args, **self.kwargs) - - -def check_workers_alive(workers: list, is_alive: Callable, gap_time: float = 1) -> None: - """Continuously monitors worker processes and raises SIGABRT if any worker dies. - - Args: - workers (List): - List of worker objects to monitor - is_alive (Callable): - Function to check if a worker is alive - gap_time (float): - Time interval between checks - """ - import time - - while True: - for worker in workers: - if not is_alive(worker): - logging.warning(f"worker {worker} is not alive sending signal to main thread") - signal.raise_signal(signal.SIGABRT) - time.sleep(gap_time) - - -class WorkerGroup: - """ - Base class for managing a group of workers in a distributed system. - The class provides methods for worker management, aliveness checking, and method binding. - """ - - fused_worker_execute_fn_name = "_fuw_execute" - - def __init__(self, resource_pool: ResourcePool, **kwargs) -> None: - self._is_init_with_detached_workers = resource_pool is None - - self.fused_worker_used = False - - if resource_pool is not None: - # handle the case when WorkGroup is attached to an existing one - self._procecss_dispatch_config = resource_pool() - else: - self._procecss_dispatch_config = None - - self._workers = [] - self._worker_names = [] - - self._master_addr = None - self._master_port = None - - self._checker_thread: threading.Thread = None - - def _is_worker_alive(self, worker): - """Check if a worker is alive. Must be implemented by derived classes.""" - raise NotImplementedError("WorkerGroup._is_worker_alive called, should be implemented in derived class.") - - def _block_until_all_workers_alive(self) -> None: - """Blocks until all workers in the group are alive.""" - while True: - all_state = [self._is_worker_alive(worker) for worker in self._workers] - if False in all_state: - time.sleep(1) - else: - break - - def start_worker_aliveness_check(self, every_n_seconds=1) -> None: - """Starts a background thread to monitor worker aliveness. - - Args: - every_n_seconds (int): Interval between aliveness checks - """ - # before starting checking worker aliveness, make sure all workers are already alive - self._block_until_all_workers_alive() - - self._checker_thread = threading.Thread( - target=check_workers_alive, args=(self._workers, self._is_worker_alive, every_n_seconds) - ) - self._checker_thread.start() - - @property - def world_size(self): - """Number of workers in the group.""" - return len(self._workers) - - def _bind_worker_method(self, user_defined_cls, func_generator): - """Binds worker methods to the WorkerGroup based on registered attributes. - - Args: - user_defined_cls (type): The class containing methods to bind - func_generator (Callable): Function that generates the bound method - - Returns: - List[str]: List of method names that were successfully bound - """ - method_names = [] - for method_name in dir(user_defined_cls): - try: - method = getattr(user_defined_cls, method_name) - assert callable(method), f"{method_name} in {user_defined_cls} is not callable" - except Exception: - # if it is a property, it will fail because Class doesn't have instance property - continue - - if hasattr(method, MAGIC_ATTR): - # this method is decorated by register - attribute = getattr(method, MAGIC_ATTR) - assert isinstance(attribute, dict), f"attribute must be a dictionary. Got {type(attribute)}" - assert "dispatch_mode" in attribute, "attribute must contain dispatch_mode in its key" - - dispatch_mode = attribute["dispatch_mode"] - execute_mode = attribute["execute_mode"] - blocking = attribute["blocking"] - - # get dispatch fn - if isinstance(dispatch_mode, Dispatch): - # get default dispatch fn - fn = get_predefined_dispatch_fn(dispatch_mode=dispatch_mode) - dispatch_fn = fn["dispatch_fn"] - collect_fn = fn["collect_fn"] - else: - assert isinstance(dispatch_mode, dict) - assert "dispatch_fn" in dispatch_mode - assert "collect_fn" in dispatch_mode - dispatch_fn = dispatch_mode["dispatch_fn"] - collect_fn = dispatch_mode["collect_fn"] - - # get execute_fn_name - execute_mode = get_predefined_execute_fn(execute_mode=execute_mode) - wg_execute_fn_name = execute_mode["execute_fn_name"] - - # get execute_fn from string - try: - execute_fn = getattr(self, wg_execute_fn_name) - assert callable(execute_fn), "execute_fn must be callable" - except Exception: - print(f"execute_fn {wg_execute_fn_name} is invalid") - raise - - # bind a new method to the RayWorkerGroup - func = func_generator( - self, - method_name, - dispatch_fn=dispatch_fn, - collect_fn=collect_fn, - execute_fn=execute_fn, - blocking=blocking, - ) - - try: - setattr(self, method_name, func) - method_names.append(method_name) - except Exception as e: - raise ValueError(f"Fail to set method_name {method_name}") from e - - return method_names diff --git a/verl/single_controller/ray/__init__.py b/verl/single_controller/ray/__init__.py deleted file mode 100644 index d2a5d6d3c..000000000 --- a/verl/single_controller/ray/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .base import ( - RayClassWithInitArgs, - RayResourcePool, - RayWorkerGroup, - create_colocated_worker_cls, - create_colocated_worker_cls_fused, -) - -__all__ = [ - "RayClassWithInitArgs", - "RayResourcePool", - "RayWorkerGroup", - "create_colocated_worker_cls", - "create_colocated_worker_cls_fused", -] diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py deleted file mode 100644 index 6c9495d61..000000000 --- a/verl/single_controller/ray/base.py +++ /dev/null @@ -1,894 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import logging -import time -from copy import deepcopy -from typing import Any, Optional - -import ray -from ray.experimental.state.api import get_actor -from ray.util import list_named_actors -from ray.util.placement_group import PlacementGroup, placement_group -from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy - -from verl.protocol import DataProto, _padding_size_key -from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup -from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch -from verl.utils.py_functional import temp_env_var - -__all__ = ["Worker"] - - -def get_random_string(length: int) -> str: - import random - import string - - letters_digits = string.ascii_letters + string.digits - return "".join(random.choice(letters_digits) for _ in range(length)) - - -def func_generator(self, method_name, dispatch_fn, collect_fn, execute_fn, blocking): - class Functor: - def __call__(this, *args, **kwargs): - args, kwargs = dispatch_fn(self, *args, **kwargs) - padding_count = kwargs.pop(_padding_size_key, 0) - output = execute_fn(method_name, *args, **kwargs) - if blocking: - output = ray.get(output) - output = collect_fn(self, output) - if padding_count > 0: - if isinstance(output, DataProto): - indices = [i for i in range(len(output))][:-padding_count] - output = output.select_idxs(indices) - elif isinstance(output, list): - output = output[:-padding_count] - return output - - # use class type to pass the method_name to get a better observability - return type(method_name, (Functor,), {})() - - -def sort_placement_group_by_node_ip(pgs: list[PlacementGroup]) -> list[PlacementGroup]: - """ - Sort the placement groups by node ip, all bundles in a single placement group should be on the same node. - - FSDPCheckpointManager saves sharded model states and optimizer states in local storage, which requires RANK - to be consistent across nodes when resume from checkpoint. - - With this function, if there's only one resource pool and there's no node change, RANK should be consistent - across nodes in multiple ray jobs, even if the whole ray cluster is restarted. - """ - node_ip = {node["NodeID"]: node["NodeManagerAddress"] for node in ray.nodes()} - pg_ip = {} - for pg in pgs: - specs = ray._private.state.state.placement_group_table(pg.id) - # all bunles should be on the same node - node_id = specs["bundles_to_node_id"][0] - pg_ip[pg.id] = node_ip[node_id] - return sorted(pgs, key=lambda pg: pg_ip[pg.id]) - - -class RayResourcePool(ResourcePool): - def __init__( - self, - process_on_nodes: Optional[list[int]] = None, - use_gpu: bool = True, - name_prefix: str = None, - max_colocate_count: int = 10, - detached=False, - accelerator_type: Optional[str] = None, - ) -> None: - super().__init__(process_on_nodes, max_colocate_count) - self.use_gpu = use_gpu - # print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}") - self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix - self.pgs = None - self.detached = detached - self.accelerator_type = accelerator_type - - def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"): - if self.pgs is not None: - return self.pgs - - pg_name_prefix = ( - name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" - ) - # print(f"pg_name_prefix = {pg_name_prefix}") - if device_name == "npu": - device_name = "NPU" - elif device_name == "cuda": - device_name = "GPU" - - bundle = {"CPU": self.max_colocate_count} - if self.use_gpu: - bundle[device_name] = 1 - if self.accelerator_type is not None: - bundle[self.accelerator_type] = 1e-4 - pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store] - - lifetime = "detached" if self.detached else None - - pgs = [ - placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) - for idx, bundles in enumerate(pg_scheme) - ] - - ray.get([pg.ready() for pg in pgs]) - - self.pgs = pgs - return pgs - - -def extract_pg_from_exist( - resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool -) -> list: - src_pgs = [ - pg - for role_name, resource_pool in resource_pools.items() - for pg in resource_pool.get_placement_groups() - if role_name in src_role_names - ] - - sorted_src_pgs = sorted(src_pgs, key=lambda pg: pg.bundle_count, reverse=True) - sorted_process_on_nodes = sorted([(val, idx) for idx, val in enumerate(resource_pool.store)], reverse=True) - - unsorted_pgs: list[tuple[int, PlacementGroup]] = [] - searching_idx = 0 - for request_process, original_idx in sorted_process_on_nodes: - assert searching_idx < len(sorted_src_pgs), f"no enough nodes for request: searching {searching_idx} th node" - assert request_process <= sorted_src_pgs[searching_idx].bundle_count, ( - f"requesting {request_process} processes, bundle count cannot satisfy" - ) - unsorted_pgs.append((original_idx, sorted_src_pgs[searching_idx])) - searching_idx += 1 - - return [pg for _, pg in sorted(unsorted_pgs)] - - -def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool: - assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not" - assert rp1.max_colocate_count == rp2.max_colocate_count, "Both RayResourcePool must has the same max_colocate_count" - assert rp1.n_gpus_per_node == rp2.n_gpus_per_node, "Both RayResourcePool must has the same n_gpus_per_node" - assert rp1.detached == rp2.detached, "Detached ResourcePool cannot be merged with non-detached ResourcePool" - - new_store = rp1.store + rp2.store - - merged = type(rp1)(new_store, rp1.use_gpu, f"{rp1.name_prefix}_{rp2.name_prefix}") - merged.pgs = rp1.get_placement_groups() + rp2.get_placement_groups() - - return merged - - -class RayClassWithInitArgs(ClassWithInitArgs): - """A wrapper class for Ray actors with initialization arguments. - - This class extends ClassWithInitArgs to provide additional functionality for - configuring and creating Ray actors with specific resource requirements and - scheduling strategies. - """ - - def __init__(self, cls, *args, **kwargs) -> None: - # self._options = kwargs.pop('options', dict()) - super().__init__(cls, *args, **kwargs) - self._options = {} - self._additional_resource = {} - - def set_additional_resource(self, additional_resource): - """Set additional resource requirements for the actor. - - Args: - additional_resource: Dictionary specifying additional resource requirements - """ - self._additional_resource = additional_resource - - def update_options(self, options: dict): - """Update the Ray actor creation options. - - Args: - options: Dictionary of options to update - """ - self._options.update(options) - - def __call__( - self, - placement_group, - placement_group_bundle_idx, - use_gpu: bool = True, - num_gpus=1, - sharing_with=None, - device_name="cuda", - ) -> Any: - """Create and return a Ray actor with the configured options. - - Args: - placement_group: Ray placement group for scheduling - placement_group_bundle_idx: Index of the bundle in the placement group - use_gpu: Whether to use GPU resources - num_gpus: Number of GPUs to allocate - sharing_with: Actor to share resources with - device_name: Device for training - - Returns: - A Ray actor handle with the configured options - """ - if sharing_with is not None: - target_node_id = ray.get(sharing_with.get_node_id.remote()) - visible_devices = ray.get(sharing_with.get_cuda_visible_devices.remote()) - options = {"scheduling_strategy": NodeAffinitySchedulingStrategy(node_id=target_node_id, soft=False)} - return self.cls.options(**options).remote(*self.args, cuda_visible_devices=visible_devices, **self.kwargs) - - options = { - "scheduling_strategy": PlacementGroupSchedulingStrategy( - placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_idx - ) - } - options.update(self._options) - - if use_gpu and device_name == "cuda": - options["num_gpus"] = num_gpus - if use_gpu and device_name == "npu": - options["resources"] = {"NPU": num_gpus} - - if len(self._additional_resource) > 1: - for k, v in self._additional_resource.items(): - options[k] = v - - # print("cls:", self.cls) - # print("args: ", self.args) - # print("kwargs: ", self.kwargs) - return self.cls.options(**options).remote(*self.args, **self.kwargs) - - -class RayWorkerGroup(WorkerGroup): - """A group of Ray workers that can be managed collectively. - - This class extends WorkerGroup to provide Ray-specific functionality for - creating and managing groups of Ray actors with specific resource requirements - and scheduling strategies. - """ - - def __init__( - self, - resource_pool: RayResourcePool = None, - ray_cls_with_init: RayClassWithInitArgs = None, - bin_pack: bool = True, - name_prefix: str = None, - detached=False, - worker_names=None, - worker_handles: list[ray.actor.ActorHandle] = None, - ray_wait_register_center_timeout: int = 300, - device_name="cuda", - **kwargs, - ) -> None: - """Initialize a RayWorkerGroup. - - Args: - resource_pool: Resource pool for worker allocation - ray_cls_with_init: Class with initialization arguments for workers - bin_pack: Whether to use strict bin packing for resource allocation - name_prefix: Prefix for worker names - detached: Whether workers should be detached - worker_names: Names of existing workers to attach to - ray_wait_register_center_timeout: Timeout for waiting on register center - **kwargs: Additional keyword arguments - """ - super().__init__(resource_pool=resource_pool, **kwargs) - self.ray_cls_with_init = ray_cls_with_init - self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix - self._ray_wait_register_center_timeout = ray_wait_register_center_timeout - # Whether the WorkerGroup is a Colocate WorkerGroup created by FusedWorker. - self.fused_worker_used = ray_cls_with_init.fused_worker_used - # if a WorkerGroup is spawned from Colocate WorkerGroup, this indicates which sub-class is binded to - # this WorkerGroup. - self.sub_cls_name = "" - self.device_name = device_name - self.profile_steps = kwargs.get("profile_steps", None) - self.worker_nsight_options = kwargs.get("worker_nsight_options", None) - if self.worker_nsight_options is not None and self.worker_nsight_options["capture-range-end"] is None: - self.worker_nsight_options["capture-range-end"] = f"repeat-shutdown:{6 * len(self.profile_steps)}" - - if worker_names is not None and (not self.fused_worker_used): - assert self._is_init_with_detached_workers - self._worker_names = worker_names - - if self._is_init_with_detached_workers: - self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles) - else: - self._init_with_resource_pool( - resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, bin_pack=bin_pack, detached=detached - ) - - if ray_cls_with_init is not None: - self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) - - self.wg_dict = None - self.method_names = [] - - def _is_worker_alive(self, worker: ray.actor.ActorHandle): - """Check if a worker actor is still alive. - - Args: - worker: Ray actor handle to check - - Returns: - bool: True if the worker is alive, False otherwise - """ - worker_state_dict = get_actor(worker._actor_id.hex()) - return worker_state_dict.get("state", "undefined") == "ALIVE" if worker_state_dict is not None else False - - def _init_with_detached_workers(self, worker_names, worker_handles): - # ray.get_actor holds a weak reference to the actor, which causes actors garbage collected unexpectedly - # if we only hold spawn RayWorkerGroup. By passing actor handle explicitly, spawn RayWorkerGroup have - # strong reference to these actors. - # https://github.com/ray-project/ray/pull/45699 - workers = worker_handles if worker_handles else [ray.get_actor(name=name) for name in worker_names] - self._workers = workers - self._world_size = len(worker_names) - - def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached): - """Initialize the worker group by creating new workers from a resource pool. - - Args: - resource_pool: Resource pool for worker allocation - ray_cls_with_init: Class with initialization arguments for workers - bin_pack: Whether to use strict bin packing for resource allocation - detached: Whether workers should be detached - """ - use_gpu = resource_pool.use_gpu - - strategy = "PACK" - if bin_pack: - strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) - world_size = resource_pool.world_size - self._world_size = world_size - # cia.add_kwarg("_world_size", world_size) - num_gpus = 1 / resource_pool.max_colocate_count - - rank = -1 - local_world_size = resource_pool.store[0] - for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)): - assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " - for local_rank in range(local_world_size): - rank += 1 - - # we pass in environment variable at option so that Worker can use environment variable to set - env_vars = { - "WORLD_SIZE": str(world_size), - "RANK": str(rank), - "WG_PREFIX": self.name_prefix, - "WG_BACKEND": "ray", - "RAY_LOCAL_WORLD_SIZE": str(local_world_size), - "RAY_LOCAL_RANK": str(local_rank), - } - if rank != 0: - env_vars["MASTER_ADDR"] = self._master_addr - env_vars["MASTER_PORT"] = self._master_port - - import re - - cia_name = type(ray_cls_with_init.cls).__name__ - match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" - cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" - name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5 - - if self.profile_steps and self.device_name == "cuda": - ray_cls_with_init.update_options( - { - "runtime_env": { - "env_vars": env_vars, - "nsight": self.worker_nsight_options, - }, - "name": name, - } - ) - else: - ray_cls_with_init.update_options({"runtime_env": {"env_vars": env_vars}, "name": name}) - - if detached: - ray_cls_with_init.update_options({"lifetime": "detached"}) - - # create a worker - worker = ray_cls_with_init( - placement_group=pg, - placement_group_bundle_idx=local_rank, - use_gpu=use_gpu, - num_gpus=num_gpus, - device_name=self.device_name, - ) - self._workers.append(worker) - self._worker_names.append(name) - - if rank == 0: - register_center_actor = None - actor_name = f"{self.name_prefix}_register_center" - start_time = time.time() - - while time.time() - start_time < self._ray_wait_register_center_timeout: - if actor_name in list_named_actors(): - register_center_actor = ray.get_actor(actor_name) - break - - elapsed = int(time.time() - start_time) - if elapsed % 30 == 0: - logging.warning( - "Waiting for register center actor %s to be ready. Elapsed time: %s seconds out of " - "%s seconds.", - actor_name, - elapsed, - self._ray_wait_register_center_timeout, - ) - time.sleep(1) - - if register_center_actor is None: - raise TimeoutError( - f"Failed to get register_center_actor {actor_name} " - f"in {list_named_actors(all_namespaces=True)} " - f"for {self._ray_wait_register_center_timeout} seconds. " - "Ensure that any lingering Ray resources from previous " - "runs are cleaned up (e.g., by restarting the Ray cluster), " - "or adjust the waiting time by modifying the config " - "`trainer.ray_wait_register_center_timeout`." - ) - - rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) - self._master_addr, self._master_port = rank_zero_info["MASTER_ADDR"], rank_zero_info["MASTER_PORT"] - # print(f"rank_zero_info: {rank_zero_info}") - # print(f"master_addr: {self._master_addr}, master_port: {self._master_port}") - - @property - def worker_names(self): - return self._worker_names - - @classmethod - def from_detached( - cls, - name_prefix=None, - worker_names=None, - worker_handles=None, - ray_cls_with_init=None, - **kwargs, - ): - """Create a worker group from existing detached workers. - - Args: - name_prefix: Prefix for worker names - worker_names: Names of existing workers to attach to - ray_cls_with_init: Class with initialization arguments for workers - - Returns: - A new RayWorkerGroup instance - """ - worker_group = cls( - resource_pool=None, - ray_cls_with_init=ray_cls_with_init, - name_prefix=name_prefix, - worker_names=worker_names, - worker_handles=worker_handles, - **kwargs, - ) - return worker_group - - def spawn(self, prefix_set): - """Spawn to a dictionary of worker groups, each with a subset of method with prefix. - - Args: - prefix_set: Set of prefixes to create worker groups for - - Returns: - Dictionary of worker groups keyed by prefix - """ - if self.fused_worker_used: - return self.spawn_fused(prefix_set) - - def _rebind_actor_methods(worker_group, actor_name): - prefix: str = actor_name + "_" - for method_name in dir(worker_group): - if method_name.startswith(prefix): - original_method_name = method_name.removeprefix(prefix) - method = getattr(worker_group, method_name) - setattr(worker_group, original_method_name, method) - - new_worker_group_dict = {} - for prefix in prefix_set: - new_worker_group = self.from_detached( - name_prefix=self.name_prefix, - worker_names=self._worker_names, - worker_handles=self._workers, - ray_cls_with_init=self.ray_cls_with_init, - profile_steps=self.profile_steps, - worker_nsight_options=self.worker_nsight_options, - ) - - _rebind_actor_methods(new_worker_group, prefix) - new_worker_group_dict[prefix] = new_worker_group - return new_worker_group_dict - - def spawn_fused(self, prefix_set): - """Create a dictionary of worker groups for fused workers. - - Args: - prefix_set: Set of prefixes to create worker groups for - - Returns: - Dictionary of worker groups keyed by prefix - """ - wg_dict = dict() - for key in prefix_set: - new_wg = deepcopy(self) - new_wg._bind_worker_method(self.ray_cls_with_init.cls.raw_cls_dict[key], func_generator) - new_wg.sub_cls_name = key - wg_dict[key] = new_wg - return wg_dict - - def fuse(self, prefix_set): - """Fuse multiple worker groups into the current worker group. - - Args: - prefix_set: Set of prefixes to fuse into the worker group - """ - if self.wg_dict is None: - self.wg_dict = self.spawn(prefix_set) - for role_name, role_wg in self.wg_dict.items(): - setattr(self, role_name, role_wg) - self.method_names = self._bind_worker_method(self.ray_cls_with_init.cls, func_generator) - - def _execute_remote_single_worker(self, worker, method_name: str, *args, **kwargs): - """Execute a method on a single worker remotely. - - Args: - worker: The worker actor handle - method_name: Name of the method to execute - *args: Positional arguments for the method - **kwargs: Keyword arguments for the method - - Returns: - Remote object reference to the method execution - """ - if self.fused_worker_used and method_name not in self.method_names: - remote_call = getattr(worker, self.fused_worker_execute_fn_name) - return remote_call.remote(f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs) - # fused worker not used - remote_call = getattr(worker, method_name) - return remote_call.remote(*args, **kwargs) - - def execute_rank_zero_sync(self, method_name: str, *args, **kwargs): - """Execute a method on rank zero worker synchronously. - - Args: - method_name: Name of the method to execute - *args: Positional arguments for the method - **kwargs: Keyword arguments for the method - - Returns: - Result of the method execution - """ - return ray.get(self.execute_rank_zero_async(method_name, *args, **kwargs)) - - def execute_rank_zero_async(self, method_name: str, *args, **kwargs): - """Execute a method on rank zero worker asynchronously. - - Args: - method_name: Name of the method to execute - *args: Positional arguments for the method - **kwargs: Keyword arguments for the method - - Returns: - Remote object reference to the method execution - """ - return self._execute_remote_single_worker(self._workers[0], method_name, *args, **kwargs) - - def execute_rank_zero(self, method_name: str, *args, **kwargs): - """Alias for execute_rank_zero_async. - - Args: - method_name: Name of the method to execute - *args: Positional arguments for the method - **kwargs: Keyword arguments for the method - - Returns: - Remote object reference to the method execution - """ - return self.execute_rank_zero_async(method_name, *args, **kwargs) - - def execute_all(self, method_name: str, *args, **kwargs): - """Alias for execute_all_async. - - Args: - method_name: Name of the method to execute - *args: Positional arguments for the method - **kwargs: Keyword arguments for the method - - Returns: - List of remote object references to the method executions - """ - return self.execute_all_async(method_name, *args, **kwargs) - - def execute_all_sync(self, method_name: str, *args, **kwargs): - """Execute a method on all workers synchronously. - - Args: - method_name: Name of the method to execute - *args: Positional arguments for the method - **kwargs: Keyword arguments for the method - - Returns: - List of results from all workers - """ - return ray.get(self.execute_all_async(method_name, *args, **kwargs)) - - def execute_all_async(self, method_name: str, *args, **kwargs): - """Execute a method on all workers asynchronously. - - Args: - method_name: Name of the method to execute - *args: Positional arguments for the method - **kwargs: Keyword arguments for the method - - Returns: - List of remote object references to the method executions - """ - # Here, we assume that if all arguments in args and kwargs are lists, - # and their lengths match len(self._workers), we'll distribute each - # element in these lists to the corresponding worker - # print(f"execute_all_async: method {method_name}({args}, {kwargs})") - length = len(self._workers) - if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): - if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()): - # print(f"splitting args and kwargs into {length} shards") - result = [] - for i in range(length): - sliced_args = tuple(arg[i] for arg in args) - sliced_kwargs = {k: v[i] for k, v in kwargs.items()} - result.append( - self._execute_remote_single_worker(self._workers[i], method_name, *sliced_args, **sliced_kwargs) - ) - return result - - return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers] - - @property - def master_address(self): - return self._master_addr - - @property - def master_port(self): - return self._master_port - - @property - def workers(self): - return self._workers - - @property - def world_size(self): - return self._world_size - - -""" -Utilities that enables creating workers inside the same ray.Actor, -with code written in separate ray.Actors. -""" - - -# deprecated, switching to FusedWorker -def _bind_workers_method_to_parent(cls, key, user_defined_cls): - """ - Binds the methods of each worker to the WorkerDict. - Note that we only bind public methods that are decorated by register - """ - - for method_name in dir(user_defined_cls): - try: - method = getattr(user_defined_cls, method_name) - assert callable(method), f"{method_name} in {user_defined_cls} is not callable" - except Exception: - # if it is a property, it will fail because Class doesn't have instance property - continue - - if hasattr(method, MAGIC_ATTR): - - def generate_function(name, key=key): - def func(self, *args, **kwargs): - # dispatch to the actual worker - return getattr(self.worker_dict[key], name)(*args, **kwargs) - - async def async_func(self, *args, **kwargs): - # dispatch to the actual worker - return await getattr(self.worker_dict[key], name)(*args, **kwargs) - - wrapper = async_func if inspect.iscoroutinefunction(method) else func # noqa: B023 - - return wrapper - - func = generate_function(method_name) - # pass MAGIC_ATTR for outer worker group - attrs = getattr(method, MAGIC_ATTR) - setattr(func, MAGIC_ATTR, attrs) - try: - # bind direct rollout method to class without prefix - if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key: - assert not hasattr(cls, method_name), ( - f"conflict direct rollout method {method_name} with role {key}" - ) - setattr(cls, method_name, func) - print(f"bind role {key} method {method_name} to class {cls}") - else: - method_name_with_prefix = key + "_" + method_name - setattr(cls, method_name_with_prefix, func) - except Exception as e: - raise ValueError(f"Fail to set method_name {method_name}") from e - - -def _unwrap_ray_remote(cls): - if hasattr(cls, "__ray_actor_class__"): - cls = cls.__ray_actor_class__ - return cls - - -def _determine_fsdp_megatron_base_class(mros: list): - """ - - megatron: base class should be MegatronWorker - - fsdp: base class should be Worker - """ - for cls in mros[0]: - if cls.__name__ == "MegatronWorker": - return cls - if cls.__name__ == "Worker": - return cls - raise ValueError(f"Cannot determine base class for {mros}") - - -# deprecated, switching to FusedWorker -def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): - """ - This function should return a class instance that delegates the calls to every - cls in cls_dict - """ - cls_dict = {} - init_args_dict = {} - worker_cls = _determine_fsdp_megatron_base_class( - [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()] - ) - assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker" - print(f"colocated worker base class {worker_cls}") - - for key, cls in class_dict.items(): - cls_dict[key] = cls.cls - init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs} - - assert cls_dict.keys() == init_args_dict.keys() - - # TODO: create a class with customizable name - class WorkerDict(worker_cls): - def __init__(self): - super().__init__() - self.worker_dict = {} - for key, user_defined_cls in cls_dict.items(): - user_defined_cls = _unwrap_ray_remote(user_defined_cls) - # directly instantiate the class without remote - # in worker class, e.g. - # when DISABLE_WORKER_INIT == 1 it will return immediately - with temp_env_var("DISABLE_WORKER_INIT", "1"): - self.worker_dict[key] = user_defined_cls( - *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}) - ) - - # now monkey-patch the methods from inner class to WorkerDict - for key, user_defined_cls in cls_dict.items(): - user_defined_cls = _unwrap_ray_remote(user_defined_cls) - _bind_workers_method_to_parent(WorkerDict, key, user_defined_cls) - - remote_cls = ray.remote(WorkerDict) - remote_cls = RayClassWithInitArgs(cls=remote_cls) - return remote_cls - - -FusedWorkerCLSName = "FusedWorker" - - -def create_colocated_worker_raw_cls(class_dict: dict[str, RayClassWithInitArgs]): - """ - This function returns a FusedWorker class. - - `FusedWorker.{class_name}` -> FusedClass - Use `class_name` as a param to directly access the underlying class. - - `FusedWorker._fuw_execute("{class_name}_fwmn_{method_name}", *args, **kwargs)` - First param must be "{class_name}_fwmn_{method_name}" in order to access `method_name` - of underlying class `{class_name}`. - - `FusedWorker.fused_worker_dict` -> {"class_name": FusedClass} - Stores all underlying classes. - - `FusedClass.fused_worker_dict` -> {"class_name": FusedClass} - The same as `FusedWorker.fused_worker_dict`, enables underlying class to access other - underlying classes. - """ - raw_cls_dict = {cls_name: _unwrap_ray_remote(cia.cls) for cls_name, cia in class_dict.items()} - init_args_dict = {cls_name: cia.args for cls_name, cia in class_dict.items()} - init_kwargs_dict = {cls_name: cia.kwargs for cls_name, cia in class_dict.items()} - cls_names = list(class_dict.keys()) - - # FusedWorker_Actor_Critic - class_name_renamed = "_".join([FusedWorkerCLSName] + cls_names) - - class FusedWorker(Worker): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.cls_names = cls_names - self.raw_cls_dict = raw_cls_dict - self.init_args_dict = init_args_dict - self.init_kwargs_dict = init_kwargs_dict - - for cls_name, udc, ud_args, ud_kwargs in zip( - self.cls_names, - self.raw_cls_dict.values(), - self.init_args_dict.values(), - self.init_kwargs_dict.values(), - strict=True, - ): - with temp_env_var("DISABLE_WORKER_INIT", "1"): - udc._get_ray_actor_cls_name = lambda x, name_renamed=class_name_renamed: name_renamed - udc._get_ray_method_prefix = lambda x, name_prefixed=cls_name: f"{name_prefixed}_" - # cls_name = "actor", "critic", udc = ActorWorker, CriticWorker - self.fused_worker_dict[cls_name] = udc(*ud_args, **ud_kwargs) - setattr(self, cls_name, self.fused_worker_dict[cls_name]) - - # injecting fused_worker to each sub worker so they can be aware of existence of each other - for _, worker in self.fused_worker_dict.items(): - setattr(worker, Worker.fused_worker_attr_name, self.fused_worker_dict) - - def _fuw_execute(self, method_name: str, *args, **kwargs): - # for fused_worker, method_name is in a form of "{cls_name}_fwmn_{method_name}" - # where fwmn stands "fused worker method name" - names = method_name.split("_fwmn_") - cls_name = names[0] - method_name = names[1] - - assert cls_name in self.fused_worker_dict, ( - f"calling {cls_name}'s {method_name}, but {cls_name} not in fused_worker_dict" - ) - udc_method = getattr(self.fused_worker_dict[cls_name], method_name) - return udc_method(*args, **kwargs) - - renamed_fused_worker_cls = type(class_name_renamed, (FusedWorker,), {}) - renamed_fused_worker_cls.is_fused_worker = True - renamed_fused_worker_cls.raw_cls_dict = raw_cls_dict - - return renamed_fused_worker_cls - - -def create_colocated_worker_cls_fused(class_dict: dict[str, RayClassWithInitArgs]): - """ - This function returns a RayClassWithInitArgs instance of FusedWorker, which is an replacement - of `create_colocated_worker_cls`. WorkerGroup constructed using this class will be a colocated - WorkerGroup, which will be referenced as `ColocateWorkerGroup` below. - - `ColocateWorkerGroup.spawn(prefix_set)` - returns a dict of WorkerGroup {"class_name": WorkerGroup}, WorkerGroup in this dict will - have methods of underlying class `class_name` attached. - - `ColocateWorkerGroup.fuse(prefix_set)` - After executing this function, `ColocateWorkerGroup.{class_name}` will return WorkerGroup - with methods of underlying class `class_name` attached. - """ - raw_colocated_worker_cls = create_colocated_worker_raw_cls(class_dict) - - remote_cls = ray.remote(raw_colocated_worker_cls) - cia = RayClassWithInitArgs(cls=remote_cls) - cia.fused_worker_used = True - - return cia diff --git a/verl/single_controller/ray/megatron.py b/verl/single_controller/ray/megatron.py deleted file mode 100644 index b46fe44a1..000000000 --- a/verl/single_controller/ray/megatron.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import ray - -from verl.single_controller.base.megatron.worker import DistGlobalInfo, DistRankInfo -from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup - -from .base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup - - -# NOTE(sgm): for open-source megatron-core -class NVMegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): - """ - MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup - so that the dispatcher can use it to dispatch data. - """ - - def __init__(self, resource_pool: RayResourcePool, ray_cls_with_init: RayClassWithInitArgs, **kwargs): - """ - Initialize the NVMegatronRayWorkerGroup. - - Args: - resource_pool (RayResourcePool): The resource pool containing worker resources - ray_cls_with_init (RayClassWithInitArgs): The Ray class with initialization arguments - **kwargs: Additional keyword arguments to pass to the parent class - """ - super().__init__(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, **kwargs) - self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") - self._megatron_global_info: DistGlobalInfo = ray.get( - self.execute_rank_zero_async(method_name="get_megatron_global_info") - ) - - -class MegatronRayWorkerGroup(RayWorkerGroup, MegatronWorkerGroup): - """ - MegatronWorkerGroup will query each worker of its megatron rank info and store it inside the WorkerGroup - so that the dispatcher can use it to dispatch data. - """ - - def __init__( - self, - resource_pool: RayResourcePool, - ray_cls_with_init: RayClassWithInitArgs, - default_megatron_kwargs: dict = None, - **kwargs, - ): - super().__init__( - resource_pool=resource_pool, - ray_cls_with_init=ray_cls_with_init, - default_megatron_kwargs=default_megatron_kwargs, - **kwargs, - ) - self.init_megatron(default_megatron_kwargs=default_megatron_kwargs) - self._megatron_rank_info: DistRankInfo = self.execute_all_sync(method_name="get_megatron_rank_info") - self._megatron_global_info: DistGlobalInfo = ray.get( - self.execute_rank_zero_async(method_name="get_megatron_global_info") - ) - - def init_megatron(self, default_megatron_kwargs: Optional[dict] = None): - # after super, we will call init of each worker - if not self._is_init_with_detached_workers: - # only init_megatron if the WorkerGroup is created from scratch - self.execute_all_sync(method_name="init_megatron", default_megatron_kwargs=default_megatron_kwargs) diff --git a/verl/third_party/__init__.py b/verl/third_party/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/third_party/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/third_party/sglang/__init__.py b/verl/third_party/sglang/__init__.py deleted file mode 100644 index 15593caaf..000000000 --- a/verl/third_party/sglang/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/third_party/sglang/parallel_state.py b/verl/third_party/sglang/parallel_state.py deleted file mode 100644 index cdec743d1..000000000 --- a/verl/third_party/sglang/parallel_state.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023 The SGlang team. -# Adapted from -# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -"""Model and data parallel groups.""" - -import os -from typing import Optional - -import sglang.srt.distributed.parallel_state as ps -import torch -import torch.distributed -from sglang.srt.distributed.parallel_state import ( - get_pp_group, - get_world_group, - init_distributed_environment, - init_model_parallel_group, -) - -""" -This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. -- We assume the Megatron tp+dp+pp world is already established before calling this function. - -""" - -# Device mesh for using DTensor -_DEVICE_MESH = None - -# Tensor model parallel group that the current rank belongs to. -_TP = None -# Pipeline model parallel group that the current rank belongs to. -_PP = None - - -# This method is for initializing the ParallelGroup when using HybridEngine -# NOTE(linjunrong): this function is for megatron -def initialize_parallel_state( - distributed_init_method: str = "env://", - backend: str = "nccl", - tensor_model_parallel_size: int = 1, - num_tp_per_train_tp: int = 1, - pipeline_model_parallel_size: int = 1, -): - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. - rank = int(os.getenv("RANK", "-1")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - - # Use the world_size set by TORCHRUN - world_size = int(os.getenv("WORLD_SIZE", "-1")) - assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" - init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) - if torch.distributed.get_world_size() > 1: - # NOTE: build a separate inference group with infer tp & micro dp - initialize_model_parallel_for_sglang( - tensor_model_parallel_size=tensor_model_parallel_size, - num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, - ) - else: - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) - - -# NOTE(linjunrong): After init SGLang rollout using class EngineFragment, user should always remember to call -# this function to sync the _TP, _PP define at the beginning of this file. Otherwise, only the conterparts -# inside sglang.srt.distributed are init as ProcessGroup, the symbols defined in this file remain as None. -# It could be weird to maintain two _TP and _PP, I follow the same way to maintain an extra ones for -# verl itself as how it was done in verl.third_party.vllm.parallel_state. Note that the process is a little -# bit different -def ensure_model_parallel_initialized( - tensor_model_parallel_size: int, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, -) -> None: - """Helper to initialize model parallel groups if they are not initialized, - or ensure tensor-parallel and pipeline-parallel sizes are equal to expected - values if the model parallel groups are initialized. - """ - # get the backend of _DEVICE_WORLD_GROUP - backend = backend or torch.distributed.get_backend(get_world_group().device_group) - if not model_parallel_is_initialized(): - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) - return - - assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( - f"tensor parallel group already initialized, but of unexpected size: " - f"{get_tensor_model_parallel_world_size()=} vs. {tensor_model_parallel_size=}" - ) - pp_world_size = get_pp_group().world_size - assert pp_world_size == pipeline_model_parallel_size, ( - f"pipeline parallel group already initialized, but of unexpected size: {pp_world_size=} vs. " - f"{pipeline_model_parallel_size=}" - ) - - -# TODO(sgm): deviate from the v0.5.4, not pp now -# NOTE(linjunrong): the SGLang version using _TP instead of ps._TP -def model_parallel_is_initialized(): - """Check if tensor and pipeline parallel groups are initialized.""" - return _TP is not None - # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) - - -def initialize_model_parallel_for_sglang( - tensor_model_parallel_size: int, - num_tensor_model_parallel_groups_per_train_tp: int = 1, - pipeline_model_parallel_size: int = 1, -) -> None: - pass - - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - - assert isinstance(tensor_model_parallel_size, int) - - # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group - # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group - - # Build the tensor model-parallel groups. - assert ps._TP is None, "tensor model parallel group is already initialized" - - global _TP - - world_size: int = torch.distributed.get_world_size() - - backend = torch.distributed.get_backend() - - num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size - - if num_tensor_model_parallel_groups_per_train_tp == 1: - # if tensor_model_parallel_size == train_tensor_parallel_size: - # using the same tp group as Megatron/vllm - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - group_ranks.append(ranks) - _TP = init_model_parallel_group( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - backend=backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine - else: - # initialize a micro_dp group and a tp group - # assume training tp=4, infer tp=2, then, weight is partitioned as - # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference - - # Build the inference tp groups - # train_tp = train_tensor_parallel_size - train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size - # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): - start = train_tp * i - end = train_tp * (i + 1) - for j in range(num_tensor_model_parallel_groups_per_train_tp): - ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) - for i in range(len(ranks)): - ranks[i] += j - group_ranks.append(ranks) - _TP = init_model_parallel_group( - group_ranks=group_ranks, - local_rank=get_world_group().local_rank, - backend=backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - - # Build the pipeline model-parallel groups. - # global _PIPELINE_MODEL_PARALLEL_GROUP - # global _PIPELINE_GLOBAL_RANKS - # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") - - # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() - # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() - - # TODO: init using device mesh (not support hybrid engine now) - # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - global _PP - assert _PP is None, "pipeline model parallel group is already initialized" - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) - # pipeline parallel does not need custom allreduce - _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) - ps._PP = _PP # for verl - - -def initialize_model_parallel( - tensor_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, -) -> None: - """ - NOTE: This method is a hack from the open-sourced version without - asertion of world_size = tp * pp - - Initialize model parallel groups. - - Arguments: - tensor_model_parallel_size: number of GPUs used for tensor model - parallelism. - pipeline_model_parallel_size: number of GPUs used for pipeline model - parallelism. - - Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. The present function will - create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: - 4 tensor model-parallel groups: - [g0, g1], [g2, g3], [g4, g5], [g6, g7] - 2 pipeline model-parallel groups: - [g0, g2, g4, g6], [g1, g3, g5, g7] - Note that for efficiency, the caller should make sure adjacent ranks - are on the same DGX box. For example if we are using 2 DGX-1 boxes - with a total of 16 GPUs, rank 0 to 7 belong to the first box and - ranks 8 to 15 belong to the second box. - """ - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) - - # NOTE(sgm) we don't assert world_size == tp * pp - # DP is not managed by vllm but by the VeRL WorkerGroup - # if (world_size != - # tensor_model_parallel_size * pipeline_model_parallel_size): - # raise RuntimeError( - # f"world_size ({world_size}) is not equal to " - # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " - # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") - - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - - global _TP - assert _TP is None, "tensor model parallel group is already initialized" - group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) - group_ranks.append(ranks) - - # message queue broadcaster is only used in tensor model parallel group - if ps._TP is not None: - _TP = ps._TP - else: - _TP = init_model_parallel_group( - group_ranks, - get_world_group().local_rank, - backend, - use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer - use_message_queue_broadcaster=True, - ) - ps._TP = _TP - - # TODO: init using device mesh (not support hybrid engine now) - # Build the pipeline model-parallel groups. - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - global _PP - assert _PP is None, "pipeline model parallel group is already initialized" - group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group_ranks.append(ranks) - # pipeline parallel does not need custom allreduce - if ps._TP is not None: - _PP = ps._TP - else: - _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) - ps._PP = _PP - - -""" -Device mesh utilities -""" - - -def get_device_mesh(): - assert _DEVICE_MESH is not None, "device mesh is not initialized" - return _DEVICE_MESH - - -""" -Tensor model parallel utilities -""" - - -# NOTE(linjunrong): In the vllm version parallel_state.py. verl created its own _TP and _PP as verl want to use -# the process group for some extra purpose. Under the hood, there is no difference between them and the original -# one in vllm.distributed.parallel_state. However, the implementation need to hack the init process of inference -# engine, as we do not maintain another SGLang here, I just use the original _TP and _PP directly. -def get_tensor_model_parallel_group(): - """Get the tensor model parallel group the caller rank belongs to.""" - - assert _TP is not None, "tensor model parallel group is not initialized" - return _TP.device_group - - -def get_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) - - -def get_tensor_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) - - -def get_tensor_model_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank - in the tensor model parallel group.""" - global_rank = torch.distributed.get_rank() - local_world_size = get_tensor_model_parallel_world_size() - return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py deleted file mode 100644 index 76fe51b3c..000000000 --- a/verl/third_party/vllm/__init__.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from importlib.metadata import PackageNotFoundError, version - -from packaging import version as vs - -from verl.utils.import_utils import is_sglang_available - - -def get_version(pkg): - try: - return version(pkg) - except PackageNotFoundError: - return None - - -package_name = "vllm" -package_version = get_version(package_name) -vllm_version = None - -if package_version is None: - if not is_sglang_available(): - raise ValueError( - f"vllm version {package_version} not supported and SGLang also not Found. Currently supported " - f"vllm versions are 0.7.0+" - ) -elif vs.parse(package_version) >= vs.parse("0.7.0"): - vllm_version = package_version - from vllm import LLM - from vllm.distributed import parallel_state -else: - if vs.parse(package_version) in [vs.parse("0.5.4"), vs.parse("0.6.3")]: - raise ValueError( - f"vLLM version {package_version} support has been removed. vLLM 0.5.4 and 0.6.3 are no longer " - f"supported. Please use vLLM 0.7.0 or later." - ) - if not is_sglang_available(): - raise ValueError( - f"vllm version {package_version} not supported and SGLang also not Found. Currently supported " - f"vllm versions are 0.7.0+" - ) - -__all__ = ["LLM", "parallel_state"] diff --git a/verl/tools/__init__.py b/verl/tools/__init__.py deleted file mode 100644 index c4b932b1a..000000000 --- a/verl/tools/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/tools/base_tool.py b/verl/tools/base_tool.py deleted file mode 100644 index 9a1189d20..000000000 --- a/verl/tools/base_tool.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from typing import Any, Optional -from uuid import uuid4 - -from verl.utils.rollout_trace import rollout_trace_op - -from .schemas import OpenAIFunctionToolSchema - - -class BaseTool: - """Base class for tools. - - A tool should support the following methods: - - - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. - - `create`: create a tool instance for a trajectory. - - `execute`: execute the tool. - - `calc_reward`: calculate the reward respect to tool state. - - `release`: release the tool instance. - """ - - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - self.config = config - self.tool_schema = tool_schema or self.get_openai_tool_schema() - assert self.tool_schema is not None, "Tool schema is not set!" - self.name = self.tool_schema.function.name - print(json.dumps(self.tool_schema.model_dump(exclude_unset=True, exclude_none=True), indent=2)) - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - return self.tool_schema - - async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: - """Create a tool instance. - - Args: - instance_id: The instance id of the tool. - - Returns: - The instance id of the tool. - """ - if instance_id is None: - return str(uuid4()) - else: - return instance_id - - @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - """Execute the tool. - - Args: - instance_id: The instance id of the tool. - parameters: The json string of the parameters of the tool. - - Returns: tool_response, tool_reward_score, tool_metrics - tool_response: The response str of the tool. - tool_reward_score: The step reward score of the tool. - tool_metrics: The metrics of the tool. - """ - return "Updated the tool state.", 0.0, {} - - async def calc_reward(self, instance_id: str, **kwargs) -> float: - """Calculate the reward of the tool. - - Args: - instance_id: The instance id of the tool. - - Returns: - The reward of the tool. - """ - return 0.0 - - async def release(self, instance_id: str, **kwargs) -> None: - """Release the tool instance. - - Args: - instance_id: The instance id of the tool. - """ - pass diff --git a/verl/tools/geo3k_tool.py b/verl/tools/geo3k_tool.py deleted file mode 100644 index 6ffd6fb2c..000000000 --- a/verl/tools/geo3k_tool.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Copyright Amazon.com, Inc. or its affiliates. -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -from typing import Any, Optional -from uuid import uuid4 - -from verl.utils.reward_score import geo3k -from verl.utils.rollout_trace import rollout_trace_op - -from .base_tool import BaseTool -from .schemas import OpenAIFunctionToolSchema - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class Geo3kTool(BaseTool): - """A demo tool for calculating the reward of geo3k. - - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. - - `create`: create a tool instance for a trajectory. - - `execute`: execute the tool. - - `calc_reward`: calculate the reward respect to tool state. - - `release`: release the tool instance. - """ - - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - """ - _tool_schema = OpenAIFunctionToolSchema.model_validate({ - "type": "function", - "function": { - "name": "calc_geo3k_reward", - "description": "A tool for calculating the reward of geo3k", - "parameters": { - "type": "object", - "properties": { - "answer": { - "type": "string", - "description": "The answer to the question, enclosed in \\boxed{}", - }, - }, - "required": ["answer"], - }, - } - }) - """ - super().__init__(config, tool_schema) - self._instance_dict = {} - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - return self.tool_schema - - async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: - if instance_id is None: - instance_id = str(uuid4()) - self._instance_dict[instance_id] = { - "response": "", - "ground_truth": ground_truth, - "reward": 0.0, - } - return instance_id, None - - @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - answer = parameters.get("answer", "") - if not isinstance(answer, str): - answer = str(answer) - self._instance_dict[instance_id]["response"] = answer - reward = await self.calc_reward(instance_id) - # penalty for non improved answer submission - tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 - # update the reward - self._instance_dict[instance_id]["reward"] = reward - return f"Current parsed {answer=} {reward=}", tool_reward, {} - - async def calc_reward(self, instance_id: str, **kwargs) -> float: - return geo3k.compute_score( - self._instance_dict[instance_id]["response"], - self._instance_dict[instance_id]["ground_truth"], - use_boxed=False, - format_score=0.0, - ) - - async def release(self, instance_id: str, **kwargs) -> None: - del self._instance_dict[instance_id] diff --git a/verl/tools/gsm8k_tool.py b/verl/tools/gsm8k_tool.py deleted file mode 100644 index f6d89134d..000000000 --- a/verl/tools/gsm8k_tool.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -from typing import Any, Optional -from uuid import uuid4 - -from verl.utils.reward_score import gsm8k -from verl.utils.rollout_trace import rollout_trace_op - -from .base_tool import BaseTool -from .schemas import OpenAIFunctionToolSchema - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class Gsm8kTool(BaseTool): - """A demo tool for calculating the reward of gsm8k. - - - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. - - `create`: create a tool instance for a trajectory. - - `execute`: execute the tool. - - `calc_reward`: calculate the reward respect to tool state. - - `release`: release the tool instance. - """ - - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - """ - _tool_schema = OpenAIFunctionToolSchema.model_validate({ - "type": "function", - "function": { - "name": "calc_gsm8k_reward", - "description": "A tool for calculating the reward of gsm8k", - "parameters": { - "type": "object", - "properties": { - "answer": { - "type": "string", - "description": "The answer to the question", - }, - }, - "required": ["answer"], - }, - } - }) - """ - super().__init__(config, tool_schema) - self._instance_dict = {} - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - return self.tool_schema - - async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: - if instance_id is None: - instance_id = str(uuid4()) - self._instance_dict[instance_id] = { - "response": "", - "ground_truth": ground_truth, - "reward": 0.0, - } - return instance_id - - @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - answer = parameters.get("answer", "") - if not isinstance(answer, str): - answer = str(answer) - - if answer.startswith("#### "): - self._instance_dict[instance_id]["response"] = answer - else: - self._instance_dict[instance_id]["response"] = "#### " + answer - - reward = await self.calc_reward(instance_id) - # penalty for non improved answer submission - tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 - # update the reward - self._instance_dict[instance_id]["reward"] = reward - - return f"Current parsed {answer=} {reward=}", tool_reward, {} - - async def calc_reward(self, instance_id: str, **kwargs) -> float: - return gsm8k.compute_score( - self._instance_dict[instance_id]["response"], - self._instance_dict[instance_id]["ground_truth"], - method="flexible", - format_score=0.0, - score=1.0, - ) - - async def release(self, instance_id: str, **kwargs) -> None: - del self._instance_dict[instance_id] diff --git a/verl/tools/mcp_base_tool.py b/verl/tools/mcp_base_tool.py deleted file mode 100644 index dacd18ebe..000000000 --- a/verl/tools/mcp_base_tool.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -import os -from typing import Any, Optional -from uuid import uuid4 - -from fastmcp.exceptions import ClientError - -from verl.tools.utils.mcp_clients.McpClientManager import ClientManager -from verl.utils.rollout_trace import rollout_trace_op - -from .base_tool import BaseTool -from .schemas import OpenAIFunctionToolSchema - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class MCPBaseTool(BaseTool): - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - super().__init__(config, tool_schema) - self._instance_dict = {} - self.timeout = config.get("timeout", 30) - - # TODO(hechanghao): create a global client manager to manage the rate limit, client and pool - logger.info(f"Initialized MCPBaseTool with config: {config}") - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - """Return the OpenAI tool schema.""" - return self.tool_schema - - async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: - """Create a tool instance. - - Args: - instance_id: The instance id of the tool. - - Returns: - The instance id of the tool. - """ - if instance_id is None: - instance_id = str(uuid4()) - self._instance_dict[instance_id] = { - "response": "", - "reward": [], - } - return instance_id - - async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]: - err_msg = "" - try: - call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout) - except ClientError as e: - err_msg = f"\n Tool call failed: {e}" - except ConnectionError as e: - err_msg = f"\n Connection failed: {e}" - except Exception as e: - err_msg = f"\n An unexpected error occurred: {e}" - - logger.debug(f"Tool result for instance {instance_id} with tool {self.name}: {call_tool_result.content}") - result, metadata = self._parse_tool_result(call_tool_result.content) - metadata["api_request_error"] += err_msg - return result, metadata - - @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - if self.name == "" or self.name is None or parameters is None: - error_msg = "Error: 'parameters' is missing or empty." - logger.error(f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}") - return json.dumps({"result": error_msg}), 0.0, {} - - try: - result_text, metadata = await self._call_tool(instance_id, parameters) - - # Store results in instance dictionary - self._instance_dict[instance_id]["reward"].append(result_text.strip()) - - # Convert metadata to metrics - metrics = { - "query_count": metadata.get("query_count", 0), - "status": metadata.get("status", "unknown"), - "total_results": metadata.get("total_results", 0), - "api_request_error": metadata.get("api_request_error"), - } - - return result_text, 0.0, metrics - - except Exception as e: - error_result = json.dumps({"result": f"Tool execution failed: {e}"}) - logger.error(f"[MCPBaseTool] Execution failed: {e}") - return error_result, 0.0, {"error": str(e)} - - async def calc_reward(self, instance_id: str, **kwargs) -> str: - return self._instance_dict[instance_id]["reward"] - - async def release(self, instance_id: str, **kwargs) -> None: - if instance_id in self._instance_dict: - del self._instance_dict[instance_id] - - def _parse_tool_result(self, content: list) -> tuple[str, dict]: - tools_content = [part.text for part in filter(lambda x: x.type == "text", content)] - return " ".join(tools_content), {} diff --git a/verl/tools/mcp_search_tool.py b/verl/tools/mcp_search_tool.py deleted file mode 100644 index ac823719b..000000000 --- a/verl/tools/mcp_search_tool.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -import os -import re - -from verl.tools.mcp_base_tool import MCPBaseTool - -from .schemas import OpenAIFunctionToolSchema - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class MCPSearchTool(MCPBaseTool): - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - super().__init__(config, tool_schema) - - def _parse_tool_result(self, content: list) -> tuple[str, dict]: - res = "" - res_cnt = 0 - query_list = [] - metadata = { - "api_request_error": "", - "status": "unknown", - "total_results": 0, - } - try: - for part in content: - if part.type != "text": - continue - text = part.text.replace("'", '"') - query_match = re.search(r'query"\s*:\s*"([^"]+)"', text) - query = query_match.group(1) if query_match else "" - query_list.append(query) - - title_matches = re.findall(r'"title"\s*:', text) - title_count = len(title_matches) - - results_match = re.search(r'"results"\s*:\s*(\[.*?\])', text, re.DOTALL) - results_content = results_match.group(1) if results_match else "" - - res += results_content - res_cnt += title_count - except json.JSONDecodeError: - err_msg = "json parse error." - logger.error(err_msg) - metadata["api_request_error"] = err_msg - metadata["status"] = "error" - - # update metadata - metadata["status"] = "success" - metadata["queries"] = query_list - metadata["query_count"] = len(query_list) - metadata["total_results"] = res_cnt - return res, metadata diff --git a/verl/tools/sandbox_fusion_tools.py b/verl/tools/sandbox_fusion_tools.py deleted file mode 100644 index c3a2748d9..000000000 --- a/verl/tools/sandbox_fusion_tools.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import threading -from contextlib import ExitStack -from enum import Enum -from typing import Any, Callable, Optional, TypeVar -from uuid import uuid4 - -import ray - -from verl.tools.base_tool import BaseTool -from verl.utils.reward_score.sandbox_fusion.utils import _process_single_case -from verl.utils.rollout_trace import rollout_trace_op - -from .schemas import OpenAIFunctionToolSchema - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -T = TypeVar("T") - - -class PoolMode(Enum): - ThreadMode = 1 - ProcessMode = 2 - - -@ray.remote(concurrency_groups={"acquire": 1, "release": 10}) -class TokenBucketWorker: - def __init__(self, rate_limit: int): - self.rate_limit = rate_limit - # this only used for observalability - self.current_count = 0 - self._semaphore = threading.Semaphore(rate_limit) - - @ray.method(concurrency_group="acquire") - def acquire(self): - self._semaphore.acquire() - self.current_count += 1 - - @ray.method(concurrency_group="release") - def release(self): - self._semaphore.release() - self.current_count -= 1 - - def get_current_count(self): - return self.current_count - - -class ExecutionWorker: - def __init__(self, enable_global_rate_limit=True, rate_limit=10): - self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None - - def _init_rate_limit(self, rate_limit): - # TODO validation for rate_limit - # A Singleton Rate Limitor - return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) - - def ping(self): - return True - - def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: - with ExitStack() as stack: - stack.callback(self.rate_limit_worker.release.remote) - ray.get(self.rate_limit_worker.acquire.remote()) - try: - return fn(*fn_args, **fn_kwargs) - except Exception as e: - # TODO we should make this available to the tool caller - logger.warning(f"Error when executing code: {e}") - - -def init_execution_pool( - num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode -): - if mode == PoolMode.ThreadMode: - return ( - ray.remote(ExecutionWorker) - .options(max_concurrency=num_workers) - .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) - ) - else: - raise NotImplementedError("Process mode is not implemented yet") - # return ray.util.multiprocessing.Pool(processes=num_workers) - - -class SandboxFusionTool(BaseTool): - """A tool for executing the code using sanbox fusion image. - - - `to_openai_function_tool_schema`: return the tool schema in OpenAI format. - - `create`: create a tool instance for a trajectory. - - `execute`: execute the tool. - - `calc_reward`: calculate the reward respect to tool state. - - `release`: release the tool instance. - """ - - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - """ - _tool_schema = OpenAIFunctionToolSchema.model_validate({ - "type": "function", - "function": { - "name": "code_interpreter", - "description": "A tool for execute code", - "parameters": { - "type": "object", - "properties": { - "code": { - "type": "string", - "description": "code needs to be execute and grad", - }, - }, - "required": ["code"], - }, - } - }) - """ - super().__init__(config, tool_schema) - self._instance_dict = {} - # TODO: better documentation for the config - self.num_workers = config.get("num_workers", 10) - self.rate_limit = config.get("rate_limit", 10) - self.default_timeout = config.get("default_timeout", 30) - self.default_language = config.get("default_language", "python") - self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) - self.execution_pool = init_execution_pool( - num_workers=self.num_workers, - enable_global_rate_limit=self.enable_global_rate_limit, - rate_limit=self.rate_limit, - mode=PoolMode.ThreadMode, - ) - self.sandbox_fusion_url = config.get("sandbox_fusion_url", "") - self.memory_limit_mb = config.get("memory_limit_mb", 1024) - if self.sandbox_fusion_url == "": - raise ValueError("sandbox_fusion_url is not set") - log_msg = f"Init SandboxFusionTool with config: {config}" - logger.info(log_msg) - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - return self.tool_schema - - async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: - if instance_id is None: - instance_id = str(uuid4()) - self._instance_dict[instance_id] = { - "response": "", - "ground_truth": ground_truth, - "reward": [], - } - return instance_id - - @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - code = parameters.get("code", "") - timeout = parameters.get("timeout", self.default_timeout) - language = parameters.get("language", self.default_language) - if not isinstance(code, str): - code = str(code) - - result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language) - # sandbox has no score or metrics, use Nones - return result, None, None - - def execute_code(self, instance_id, code, timeout=30, language="python"): - result_status, metadata = _process_single_case( - 0, None, None, self.sandbox_fusion_url, code, timeout, self.memory_limit_mb, language - ) - # we should always expect this since we don't have correct answer - if metadata["run_status"] == "Finished": - actual_output = metadata["stdout"] + metadata["stderr"] - logger.debug(f"actual_output from sandbox fusion: {actual_output},{instance_id}") - return actual_output - else: - return "no stdout here" - - async def calc_reward(self, instance_id: str, **kwargs) -> str: - return self._instance_dict[instance_id]["reward"] - - async def release(self, instance_id: str, **kwargs) -> None: - del self._instance_dict[instance_id] diff --git a/verl/tools/schemas.py b/verl/tools/schemas.py deleted file mode 100644 index c0c65a30e..000000000 --- a/verl/tools/schemas.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import json -from typing import Any, Literal - -from pydantic import BaseModel - - -class OpenAIFunctionPropertySchema(BaseModel): - """The schema of a parameter in OpenAI format.""" - - type: str - description: str | None = None - enum: list[str] | None = None - - -class OpenAIFunctionParametersSchema(BaseModel): - """The schema of parameters in OpenAI format.""" - - type: str - properties: dict[str, OpenAIFunctionPropertySchema] - required: list[str] - - -class OpenAIFunctionSchema(BaseModel): - """The schema of a function in OpenAI format.""" - - name: str - description: str - parameters: OpenAIFunctionParametersSchema - strict: bool = False - - -class OpenAIFunctionToolSchema(BaseModel): - """The schema of a tool in OpenAI format.""" - - type: str - function: OpenAIFunctionSchema - - -class OpenAIFunctionParsedSchema(BaseModel): - """The parsed schema of a tool in OpenAI format.""" - - name: str - arguments: str # JSON string - - -class OpenAIFunctionCallSchema(BaseModel): - """The parsed schema of a tool in OpenAI format.""" - - name: str - arguments: dict[str, Any] - - @staticmethod - def from_openai_function_parsed_schema( - parsed_schema: OpenAIFunctionParsedSchema, - ) -> tuple["OpenAIFunctionCallSchema", bool]: - has_decode_error = False - try: - arguments = json.loads(parsed_schema.arguments) - except json.JSONDecodeError: - arguments = {} - has_decode_error = True - # If the arguments is not a dict, it means the arguments is not a valid JSON string - if not isinstance(arguments, dict): - arguments = {} - has_decode_error = True - - return OpenAIFunctionCallSchema(name=parsed_schema.name, arguments=arguments), has_decode_error - - -class OpenAIFunctionToolCall(BaseModel): - """The tool call in OpenAI format.""" - - id: str - type: Literal["function"] = "function" - function: OpenAIFunctionCallSchema diff --git a/verl/tools/search_tool.py b/verl/tools/search_tool.py deleted file mode 100644 index 3cc6cda53..000000000 --- a/verl/tools/search_tool.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -import os -import threading -from contextlib import ExitStack -from enum import Enum -from typing import Any, Callable, Optional, TypeVar -from uuid import uuid4 - -import ray -import ray.actor - -from verl.tools.utils.search_r1_like_utils import perform_single_search_batch -from verl.utils.rollout_trace import rollout_trace_op - -from .base_tool import BaseTool -from .schemas import OpenAIFunctionToolSchema - -logger = logging.getLogger(__name__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -T = TypeVar("T") - - -# Adapted from verl/tools/sandbox_fusion_tools.py -class PoolMode(Enum): - """Execution pool mode enumeration.""" - - ThreadMode = 1 - ProcessMode = 2 - - -@ray.remote(concurrency_groups={"acquire": 1, "release": 10}) -class TokenBucketWorker: - """Ray actor for rate limiting using token bucket algorithm.""" - - def __init__(self, rate_limit: int): - self.rate_limit = rate_limit - self.current_count = 0 # For observability - self._semaphore = threading.Semaphore(rate_limit) - - @ray.method(concurrency_group="acquire") - def acquire(self): - """Acquire a token from the bucket.""" - self._semaphore.acquire() - self.current_count += 1 - - @ray.method(concurrency_group="release") - def release(self): - """Release a token back to the bucket.""" - self._semaphore.release() - self.current_count -= 1 - - def get_current_count(self): - """Get current number of acquired tokens.""" - return self.current_count - - -class SearchExecutionWorker: - """Worker for executing search operations with optional rate limiting.""" - - def __init__(self, enable_global_rate_limit=True, rate_limit=10): - self.rate_limit_worker = self._init_rate_limit(rate_limit) if enable_global_rate_limit else None - - def _init_rate_limit(self, rate_limit): - """Initialize singleton rate limiter.""" - return TokenBucketWorker.options(name="rate-limiter", get_if_exists=True).remote(rate_limit) - - def ping(self): - """Health check method.""" - return True - - def execute(self, fn: Callable[..., T], *fn_args, **fn_kwargs) -> T: - """Execute function with optional rate limiting.""" - if self.rate_limit_worker: - with ExitStack() as stack: - stack.callback(self.rate_limit_worker.release.remote) - ray.get(self.rate_limit_worker.acquire.remote()) - try: - return fn(*fn_args, **fn_kwargs) - except Exception as e: - # TODO we should make this available to the tool caller - logger.warning(f"Error when executing search: {e}") - else: - return fn(*fn_args, **fn_kwargs) - - -def init_search_execution_pool( - num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode -): - """Initialize search execution pool.""" - if mode == PoolMode.ThreadMode: - return ( - ray.remote(SearchExecutionWorker) - .options(max_concurrency=num_workers) - .remote(enable_global_rate_limit=enable_global_rate_limit, rate_limit=rate_limit) - ) - else: - raise NotImplementedError("Process mode is not implemented yet") - - -class SearchTool(BaseTool): - """Search tool for retrieving information using external retrieval services. - - This tool provides search functionality with rate limiting and concurrent execution - support through Ray. It integrates with external retrieval services to perform - semantic search operations. - - Methods: - get_openai_tool_schema: Return the tool schema in OpenAI format - create: Create a tool instance for a trajectory - execute: Execute the search tool - calc_reward: Calculate the reward with respect to tool state - release: Release the tool instance - """ - - def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): - """Initialize SearchTool with configuration and schema. - - Args: - config: Configuration dictionary containing tool settings - tool_schema: OpenAI function tool schema definition - - Example tool_schema: - { - "type": "function", - "function": { - "name": "search", - "description": "Searches for relevant information based on queries.", - "parameters": { - "type": "object", - "properties": { - "query_list": { - "type": "array", - "items": {"type": "string"}, - "description": "List of search queries" - } - }, - "required": ["query_list"] - } - } - } - """ - super().__init__(config, tool_schema) - self._instance_dict = {} - - # Worker and rate limiting configuration - self.num_workers = config.get("num_workers", 120) - self.rate_limit = config.get("rate_limit", 120) - self.timeout = config.get("timeout", 30) - - self.enable_global_rate_limit = config.get("enable_global_rate_limit", True) - self.execution_pool = init_search_execution_pool( - num_workers=self.num_workers, - enable_global_rate_limit=self.enable_global_rate_limit, - rate_limit=self.rate_limit, - mode=PoolMode.ThreadMode, - ) - - # Retrieval service configuration - self.retrieval_service_url = config.get("retrieval_service_url") - assert self.retrieval_service_url, "Configuration must include 'retrieval_service_url'" - self.topk = config.get("topk", 3) - if self.retrieval_service_url == "": - raise ValueError("retrieval_service_url is not set") - - logger.info(f"Initialized SearchTool with config: {config}") - - def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: - """Return the OpenAI tool schema.""" - return self.tool_schema - - async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: - """Create a tool instance. - - Args: - instance_id: The instance id of the tool. - - Returns: - The instance id of the tool. - """ - if instance_id is None: - instance_id = str(uuid4()) - self._instance_dict[instance_id] = { - "response": "", - "reward": [], - } - return instance_id - - def execute_search(self, instance_id: str, query_list: list, retrieval_service_url: str, topk: int, timeout: int): - """Execute search operation using retrieval service. - - Args: - instance_id: Tool instance ID - query_list: List of search queries - retrieval_service_url: URL of the retrieval service - topk: Number of top results to return - timeout: Request timeout in seconds - - Returns: - Tuple of (result_text, metadata) - """ - result_text, metadata = perform_single_search_batch( - retrieval_service_url=retrieval_service_url, - query_list=query_list, - topk=topk, - concurrent_semaphore=None, # Ray handles concurrency control - timeout=timeout, - ) - logger.debug(f"Search result for instance {instance_id}: {result_text}") - return result_text, metadata - - @rollout_trace_op - async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]: - """Execute the search tool. - - Args: - instance_id: The instance ID of the tool - parameters: Tool parameters containing query_list and optional timeout - - Returns: tool_response, tool_reward_score, tool_metrics - tool_response: The response str of the tool. - tool_reward_score: The step reward score of the tool. - tool_metrics: The metrics of the tool. - """ - timeout = self.timeout - query_list_from_params = parameters.get("query_list") - - if not query_list_from_params or not isinstance(query_list_from_params, list): - error_msg = "Error: 'query_list' is missing, empty, or not a list in parameters." - logger.error(f"[SearchTool] {error_msg} Received parameters: {parameters}") - return json.dumps({"result": error_msg}), 0.0, {} - - # Execute search using Ray execution pool - try: - result_text, metadata = await self.execution_pool.execute.remote( - self.execute_search, instance_id, query_list_from_params, self.retrieval_service_url, self.topk, timeout - ) - - # Store results in instance dictionary - self._instance_dict[instance_id]["reward"].append(result_text.strip()) - - # Convert metadata to metrics - metrics = { - "query_count": metadata.get("query_count", 0), - "status": metadata.get("status", "unknown"), - "total_results": metadata.get("total_results", 0), - "api_request_error": metadata.get("api_request_error"), - } - - return result_text, 0.0, metrics - - except Exception as e: - error_result = json.dumps({"result": f"Search execution failed: {e}"}) - logger.error(f"[SearchTool] Execution failed: {e}") - return error_result, 0.0, {"error": str(e)} - - async def calc_reward(self, instance_id: str, **kwargs) -> str: - return self._instance_dict[instance_id]["reward"] - - async def release(self, instance_id: str, **kwargs) -> None: - if instance_id in self._instance_dict: - del self._instance_dict[instance_id] diff --git a/verl/tools/utils/__init__.py b/verl/tools/utils/__init__.py deleted file mode 100644 index c4b932b1a..000000000 --- a/verl/tools/utils/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/tools/utils/mcp_clients/McpClientManager.py b/verl/tools/utils/mcp_clients/McpClientManager.py deleted file mode 100644 index ee5fe3119..000000000 --- a/verl/tools/utils/mcp_clients/McpClientManager.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import json -import logging -from typing import Any - -from fastmcp import Client -from fastmcp.client.transports import SSETransport - -from verl.tools.utils.mcp_clients.utils import TokenBucket, mcp2openai - -logger = logging.getLogger(__name__) - - -class MCPClientManager: - rootServerName = "mcpServers" - initialized = False - clients = [] - tool_client_mapping = {} - rate_limiter = None - - async def initialize(self, config_path, rate_limit: float = 10.0): - if self.initialized: - return - """Initialize the MCP Client Manager and start all clients""" - result = self._load_config(config_path) - servers = result[self.rootServerName] - exclude_sse_servers = {self.rootServerName: {}} - for server_name in servers.keys(): - server = servers[server_name] - if "auth_token" in server: - transport = SSETransport(url=server["url"], headers={"Authorization": f"Bearer {server['auth_token']}"}) - client = Client(transport) - self.clients.append(client) - else: - exclude_sse_servers[self.rootServerName][server_name] = server - - if exclude_sse_servers[self.rootServerName]: - self.clients.append(Client(exclude_sse_servers)) - - # Initialize rate limiter - self.rate_limiter = TokenBucket(rate_limit) - self.initialized = True - - async def call_tool(self, tool_name, parameters, timeout): - # Apply rate limiting - while not self.rate_limiter.acquire(): - await asyncio.sleep(0.1) - - client = self.get_client_with_tool_name(tool_name) - async with client: - return await client.call_tool_mcp(tool_name, parameters) - - async def fetch_tool_schemas(self, tool_selected_list: list[str]) -> list[dict]: - tool_schemas = [] - for client in self.clients: - async with client: - tools = await client.list_tools_mcp() - for tool in tools.tools: - if not tool_selected_list: - self.tool_client_mapping[tool.name] = client - tool_schemas.append(mcp2openai(tool)) - elif tool.name in tool_selected_list: - self.tool_client_mapping[tool.name] = client - tool_schemas.append(mcp2openai(tool)) - - return tool_schemas - - def get_client_with_tool_name(self, tool_name: str): - return self.tool_client_mapping[tool_name] - - def _load_config(self, file: str) -> dict[str, Any]: - try: - with open(file) as f: - return json.load(f) - except FileNotFoundError: - logger.warning(f'the "{file}" file was not found') - except Exception: - logger.error(f'there was an error reading the "{file}" file') - - return {} - - -ClientManager = MCPClientManager() diff --git a/verl/tools/utils/mcp_clients/utils.py b/verl/tools/utils/mcp_clients/utils.py deleted file mode 100644 index 22a5f6353..000000000 --- a/verl/tools/utils/mcp_clients/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import threading -import time - -from mcp import Tool - -logger = logging.getLogger(__file__) - - -class TokenBucket: - def __init__(self, rate_limit: float): - self.rate_limit = rate_limit # tokens per second - self.tokens = rate_limit - self.last_update = time.time() - self.lock = threading.Lock() - - def acquire(self) -> bool: - with self.lock: - now = time.time() - # Add new tokens based on time elapsed - new_tokens = (now - self.last_update) * self.rate_limit - self.tokens = min(self.rate_limit, self.tokens + new_tokens) - self.last_update = now - - if self.tokens >= 1: - self.tokens -= 1 - return True - return False - - -def mcp2openai(mcp_tool: Tool) -> dict: - """Convert a MCP Tool to an OpenAI ChatCompletionTool.""" - openai_format = { - "type": "function", - "function": { - "name": mcp_tool.name, - "description": mcp_tool.description, - "parameters": mcp_tool.inputSchema, - "strict": False, - }, - } - if not openai_format["function"]["parameters"].get("required", None): - openai_format["function"]["parameters"]["required"] = [] - return openai_format diff --git a/verl/tools/utils/search_r1_like_utils.py b/verl/tools/utils/search_r1_like_utils.py deleted file mode 100644 index 23669e44c..000000000 --- a/verl/tools/utils/search_r1_like_utils.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -import threading -import time -import traceback -import uuid -from typing import Any, Optional - -import requests - -DEFAULT_TIMEOUT = 30 # Default search request timeout -MAX_RETRIES = 10 -INITIAL_RETRY_DELAY = 1 -API_TIMEOUT = 10 - -logger = logging.getLogger(__name__) - - -def call_search_api( - retrieval_service_url: str, - query_list: list[str], - topk: int = 3, - return_scores: bool = True, - timeout: int = DEFAULT_TIMEOUT, -) -> tuple[Optional[dict[str, Any]], Optional[str]]: - """ - Calls the remote search API to perform retrieval with retry logic for various errors, - using increasing delay between retries. Logs internal calls with a unique ID. - - Args: - retrieval_service_url: The URL of the retrieval service API. - query_list: List of search queries. - topk: Number of top results to return. - return_scores: Whether to return scores. - timeout: Request timeout in seconds. - - Returns: - A tuple (response_json, error_message). - If successful, response_json is the API's returned JSON object, error_message is None. - If failed after retries, response_json is None, error_message contains the error information. - """ - request_id = str(uuid.uuid4()) - log_prefix = f"[Search Request ID: {request_id}] " - - payload = {"queries": query_list, "topk": topk, "return_scores": return_scores} - - headers = {"Content-Type": "application/json", "Accept": "application/json"} - - last_error = None - - for attempt in range(MAX_RETRIES): - try: - logger.info( - f"{log_prefix}Attempt {attempt + 1}/{MAX_RETRIES}: Calling search API at {retrieval_service_url}" - ) - response = requests.post( - retrieval_service_url, - headers=headers, - json=payload, - timeout=timeout, - ) - - # Check for Gateway Timeout (504) and other server errors for retrying - if response.status_code in [500, 502, 503, 504]: - last_error = ( - f"{log_prefix}API Request Error: Server Error ({response.status_code}) on attempt " - f"{attempt + 1}/{MAX_RETRIES}" - ) - logger.warning(last_error) - if attempt < MAX_RETRIES - 1: - delay = INITIAL_RETRY_DELAY * (attempt + 1) - logger.info(f"{log_prefix}Retrying after {delay} seconds...") - time.sleep(delay) - continue - - # Check for other HTTP errors (e.g., 4xx) - response.raise_for_status() - - # If successful (status code 2xx) - logger.info(f"{log_prefix}Search API call successful on attempt {attempt + 1}") - return response.json(), None - - except requests.exceptions.ConnectionError as e: - last_error = f"{log_prefix}Connection Error: {e}" - logger.warning(last_error) - if attempt < MAX_RETRIES - 1: - delay = INITIAL_RETRY_DELAY * (attempt + 1) - logger.info(f"{log_prefix}Retrying after {delay} seconds...") - time.sleep(delay) - continue - except requests.exceptions.Timeout as e: - last_error = f"{log_prefix}Timeout Error: {e}" - logger.warning(last_error) - if attempt < MAX_RETRIES - 1: - delay = INITIAL_RETRY_DELAY * (attempt + 1) - logger.info(f"{log_prefix}Retrying after {delay} seconds...") - time.sleep(delay) - continue - except requests.exceptions.RequestException as e: - last_error = f"{log_prefix}API Request Error: {e}" - break # Exit retry loop on other request errors - except json.JSONDecodeError as e: - raw_response_text = response.text if "response" in locals() else "N/A" - last_error = f"{log_prefix}API Response JSON Decode Error: {e}, Response: {raw_response_text[:200]}" - break # Exit retry loop on JSON decode errors - except Exception as e: - last_error = f"{log_prefix}Unexpected Error: {e}" - break # Exit retry loop on other unexpected errors - - # If loop finishes without returning success, return the last recorded error - logger.error(f"{log_prefix}Search API call failed. Last error: {last_error}") - return None, last_error.replace(log_prefix, "API Call Failed: ") if last_error else "API Call Failed after retries" - - -def _passages2string(retrieval_result): - """Convert retrieval results to formatted string.""" - format_reference = "" - for idx, doc_item in enumerate(retrieval_result): - content = doc_item["document"]["contents"] - title = content.split("\n")[0] - text = "\n".join(content.split("\n")[1:]) - format_reference += f"Doc {idx + 1} (Title: {title})\n{text}\n\n" - return format_reference.strip() - - -def perform_single_search_batch( - retrieval_service_url: str, - query_list: list[str], - topk: int = 3, - concurrent_semaphore: Optional[threading.Semaphore] = None, - timeout: int = DEFAULT_TIMEOUT, -) -> tuple[str, dict[str, Any]]: - """ - Performs a single batch search for multiple queries (original search tool behavior). - - Args: - retrieval_service_url: The URL of the retrieval service API. - query_list: List of search queries. - topk: Number of top results to return. - concurrent_semaphore: Optional semaphore for concurrency control. - timeout: Request timeout in seconds. - - Returns: - A tuple (result_text, metadata). - result_text: The search result JSON string. - metadata: Metadata dictionary for the batch search. - """ - logger.info(f"Starting batch search for {len(query_list)} queries.") - - api_response = None - error_msg = None - - try: - if concurrent_semaphore: - with concurrent_semaphore: - api_response, error_msg = call_search_api( - retrieval_service_url=retrieval_service_url, - query_list=query_list, - topk=topk, - return_scores=True, - timeout=timeout, - ) - else: - api_response, error_msg = call_search_api( - retrieval_service_url=retrieval_service_url, - query_list=query_list, - topk=topk, - return_scores=True, - timeout=timeout, - ) - except Exception as e: - error_msg = f"API Request Exception during batch search: {e}" - logger.error(f"Batch search: {error_msg}") - traceback.print_exc() - - metadata = { - "query_count": len(query_list), - "queries": query_list, - "api_request_error": error_msg, - "api_response": None, - "status": "unknown", - "total_results": 0, - "formatted_result": None, - } - - result_text = json.dumps({"result": "Search request failed or timed out after retries."}) - - if error_msg: - metadata["status"] = "api_error" - result_text = json.dumps({"result": f"Search error: {error_msg}"}) - logger.error(f"Batch search: API error occurred: {error_msg}") - elif api_response: - logger.debug(f"Batch search: API Response: {api_response}") - metadata["api_response"] = api_response - - try: - raw_results = api_response.get("result", []) - if raw_results: - pretty_results = [] - total_results = 0 - - for retrieval in raw_results: - formatted = _passages2string(retrieval) - pretty_results.append(formatted) - total_results += len(retrieval) if isinstance(retrieval, list) else 1 - - final_result = "\n---\n".join(pretty_results) - result_text = json.dumps({"result": final_result}) - metadata["status"] = "success" - metadata["total_results"] = total_results - metadata["formatted_result"] = final_result - logger.info(f"Batch search: Successful, got {total_results} total results") - else: - result_text = json.dumps({"result": "No search results found."}) - metadata["status"] = "no_results" - metadata["total_results"] = 0 - logger.info("Batch search: No results found") - except Exception as e: - error_msg = f"Error processing search results: {e}" - result_text = json.dumps({"result": error_msg}) - metadata["status"] = "processing_error" - logger.error(f"Batch search: {error_msg}") - else: - metadata["status"] = "unknown_api_state" - result_text = json.dumps({"result": "Unknown API state (no response and no error message)."}) - logger.error("Batch search: Unknown API state.") - - return result_text, metadata diff --git a/verl/tools/utils/tool_registry.py b/verl/tools/utils/tool_registry.py deleted file mode 100644 index 5c14d1016..000000000 --- a/verl/tools/utils/tool_registry.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import importlib -import logging -import os -import sys -from enum import Enum - -from omegaconf import OmegaConf - -from verl.tools.schemas import OpenAIFunctionToolSchema - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class ToolType(Enum): - NATIVE = "native" - MCP = "mcp" - - -async def initialize_mcp_tool(tool_cls, tool_config) -> list: - from verl.tools.utils.mcp_clients.McpClientManager import ClientManager - - tool_list = [] - mcp_servers_config_path = tool_config.mcp.mcp_servers_config_path - tool_selected_list = tool_config.mcp.tool_selected_list if "tool_selected_list" in tool_config.mcp else None - await ClientManager.initialize(mcp_servers_config_path, tool_config.config.rate_limit) - # Wait for MCP client to be ready - max_retries = 10 - retry_interval = 2 # seconds - for i in range(max_retries): - tool_schemas = await ClientManager.fetch_tool_schemas(tool_selected_list) - if tool_schemas: - break - if i < max_retries - 1: - logger.debug(f"Waiting for MCP client to be ready, attempt {i + 1}/{max_retries}") - await asyncio.sleep(retry_interval) - else: - raise RuntimeError("Failed to initialize MCP tools after maximum retries") - # mcp registry - assert len(tool_schemas), "mcp tool is empty" - for tool_schema_dict in tool_schemas: - logger.debug(f"tool_schema_dict: {tool_schema_dict}") - tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) - tool = tool_cls( - config=OmegaConf.to_container(tool_config.config, resolve=True), - tool_schema=tool_schema, - ) - tool_list.append(tool) - return tool_list - - -def get_tool_class(cls_name): - module_name, class_name = cls_name.rsplit(".", 1) - if module_name not in sys.modules: - spec = importlib.util.find_spec(module_name) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - else: - module = sys.modules[module_name] - - tool_cls = getattr(module, class_name) - return tool_cls - - -def initialize_tools_from_config(tools_config_file): - tools_config = OmegaConf.load(tools_config_file) - tool_list = [] - for tool_config in tools_config.tools: - cls_name = tool_config.class_name - tool_type = ToolType(tool_config.config.type) - tool_cls = get_tool_class(cls_name) - - match tool_type: - case ToolType.NATIVE: - if tool_config.get("tool_schema", None) is None: - tool_schema = None - else: - tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) - tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) - tool = tool_cls( - config=OmegaConf.to_container(tool_config.config, resolve=True), - tool_schema=tool_schema, - ) - tool_list.append(tool) - case ToolType.MCP: - loop = asyncio.get_event_loop() - mcp_tools = loop.run_until_complete(initialize_mcp_tool(tool_cls, tool_config)) - tool_list.extend(mcp_tools) - case _: - raise NotImplementedError - return tool_list diff --git a/verl/trainer/__init__.py b/verl/trainer/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/trainer/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/trainer/config/__init__.py b/verl/trainer/config/__init__.py deleted file mode 100644 index f4cc9b8e2..000000000 --- a/verl/trainer/config/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .algorithm import AlgoConfig, FilterGroupsConfig, KLControlConfig, PFPPOConfig - -__all__ = [ - "AlgoConfig", - "FilterGroupsConfig", - "KLControlConfig", - "PFPPOConfig", -] diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml deleted file mode 100644 index 1d715e919..000000000 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ /dev/null @@ -1,372 +0,0 @@ -# This reference configration yaml is automatically generated via 'scripts/generate_trainer_config.sh' -# in which it invokes 'python3 scripts/print_cfg.py --cfg job' to flatten the 'verl/trainer/config/ppo_trainer.yaml' config fields into a single file. -# Do not modify this file directly. -# The file is usually only for reference and never used. - -actor_rollout_ref: - actor: - strategy: fsdp - ppo_mini_batch_size: 256 - ppo_micro_batch_size: null - ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: false - ppo_max_token_len_per_gpu: 16384 - clip_ratio: 0.2 - clip_ratio_low: 0.2 - clip_ratio_high: 0.2 - policy_loss: - loss_mode: vanilla - clip_cov_ratio: 0.0002 - clip_cov_lb: 1.0 - clip_cov_ub: 5.0 - kl_cov_ratio: 0.0002 - ppo_kl_coef: 0.1 - clip_ratio_c: 3.0 - loss_agg_mode: token-mean - entropy_coeff: 0 - use_kl_loss: false - use_torch_compile: true - kl_loss_coef: 0.001 - kl_loss_type: low_var_kl - ppo_epochs: 1 - shuffle: false - checkpoint: - save_contents: - - model - - optimizer - - extra - load_contents: ${.save_contents} - optim: - lr: 1.0e-06 - lr_warmup_steps_ratio: 0.0 - total_training_steps: -1 - weight_decay: 0.01 - lr_warmup_steps: -1 - min_lr_ratio: 0.0 - num_cycles: 0.5 - warmup_style: constant - grad_clip: 1.0 - ulysses_sequence_parallel_size: 1 - entropy_from_logits_with_chunking: false - entropy_checkpointing: false - fsdp_config: - wrap_policy: - min_num_params: 0 - param_offload: false - optimizer_offload: false - offload_policy: false - reshard_after_forward: true - fsdp_size: -1 - forward_prefetch: false - ref: - strategy: ${actor_rollout_ref.actor.strategy} - use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} - log_prob_micro_batch_size: null - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} - log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} - fsdp_config: - param_offload: false - reshard_after_forward: true - forward_prefetch: false - wrap_policy: - min_num_params: 0 - ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} - entropy_from_logits_with_chunking: false - entropy_checkpointing: false - rollout: - name: vllm - mode: sync - temperature: 1.0 - top_k: -1 - top_p: 1 - prompt_length: ${oc.select:data.max_prompt_length,512} - response_length: ${oc.select:data.max_response_length,512} - dtype: bfloat16 - gpu_memory_utilization: 0.5 - ignore_eos: false - enforce_eager: true - free_cache_engine: true - tensor_model_parallel_size: 2 - max_num_batched_tokens: 8192 - max_model_len: null - max_num_seqs: 1024 - log_prob_micro_batch_size: null - log_prob_micro_batch_size_per_gpu: null - log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} - log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} - disable_log_stats: true - do_sample: true - 'n': 1 - multi_stage_wake_up: false - engine_kwargs: - vllm: - swap_space: null - disable_mm_preprocessor_cache: false - sglang: - attention_backend: null - val_kwargs: - top_k: -1 - top_p: 1.0 - temperature: 0 - 'n': 1 - do_sample: false - multi_turn: - enable: false - max_assistant_turns: null - tool_config_path: null - max_user_turns: null - max_parallel_calls: 1 - max_tool_response_length: 256 - tool_response_truncate_side: middle - interaction_config_path: null - completion_callback: null - use_inference_chat_template: false - tokenization_sanity_check_mode: strict - format: hermes - calculate_log_probs: false - agent: - num_workers: 8 - agent_loop_config_path: null - custom_async_server: - path: null - name: null - update_weights_bucket_megabytes: 512 - trace: - backend: null - token2text: false - enable_chunked_prefill: true - load_format: dummy_dtensor - layered_summon: false - hybrid_engine: true - model: - path: ~/models/deepseek-llm-7b-chat - custom_chat_template: null - use_shm: false - external_lib: null - override_config: {} - enable_gradient_checkpointing: true - enable_activation_offload: false - use_remove_padding: false - lora_rank: 0 - lora_alpha: 16 - target_modules: all-linear - exclude_modules: null - use_liger: false - use_fused_kernels: false - fused_kernel_options: - impl_backend: torch - trust_remote_code: false - profiler: - _target_: verl.utils.profiler.ProfilerConfig - discrete: false - all_ranks: false - ranks: [] -trainer: - npu_profile: - options: - save_path: ./profiler_data - level: level1 - with_memory: false - record_shapes: false - with_npu: true - with_cpu: true - with_module: false - with_stack: false - analysis: true - balance_batch: true - total_epochs: 30 - total_training_steps: null - profile_steps: null - controller_nsight_options: - trace: cuda,nvtx,cublas,ucx - cuda-memory-usage: 'true' - cuda-graph-trace: graph - worker_nsight_options: - trace: cuda,nvtx,cublas,ucx - cuda-memory-usage: 'true' - cuda-graph-trace: graph - capture-range: cudaProfilerApi - capture-range-end: null - kill: none - project_name: verl_examples - experiment_name: gsm8k - logger: - - console - - wandb - log_val_generations: 0 - rollout_data_dir: null - validation_data_dir: null - nnodes: 1 - n_gpus_per_node: 8 - save_freq: -1 - esi_redundant_time: 0 - resume_mode: auto - resume_from_path: null - val_before_train: true - val_only: false - test_freq: -1 - critic_warmup: 0 - default_hdfs_dir: null - del_local_ckpt_after_load: false - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} - max_actor_ckpt_to_keep: null - max_critic_ckpt_to_keep: null - ray_wait_register_center_timeout: 300 - device: cuda -data: - tokenizer: null - use_shm: false - train_files: ~/data/rlhf/gsm8k/train.parquet - val_files: ~/data/rlhf/gsm8k/test.parquet - prompt_key: prompt - reward_fn_key: data_source - max_prompt_length: 512 - max_response_length: 512 - train_batch_size: 1024 - val_batch_size: null - return_raw_input_ids: false - return_raw_chat: false - return_full_prompt: false - shuffle: true - dataloader_num_workers: 8 - validation_shuffle: false - filter_overlong_prompts: false - filter_overlong_prompts_workers: 1 - truncation: error - image_key: images - video_key: videos - trust_remote_code: false - custom_cls: - path: null - name: null - return_multi_modal_inputs: true - sampler: - class_path: null - class_name: null - datagen: - path: null - name: null -critic: - rollout_n: ${actor_rollout_ref.rollout.n} - strategy: fsdp - optim: - lr_warmup_steps_ratio: 0.0 - total_training_steps: -1 - weight_decay: 0.01 - lr: 1.0e-05 - min_lr_ratio: null - warmup_style: constant - model: - path: ~/models/deepseek-llm-7b-chat - tokenizer_path: ${actor_rollout_ref.model.path} - override_config: {} - external_lib: ${actor_rollout_ref.model.external_lib} - enable_gradient_checkpointing: true - trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} - use_shm: false - enable_activation_offload: false - use_remove_padding: false - fsdp_config: - param_offload: false - optimizer_offload: false - offload_policy: false - reshard_after_forward: true - wrap_policy: - min_num_params: 0 - fsdp_size: -1 - forward_prefetch: false - lora_rank: 0 - lora_alpha: 16 - target_modules: all-linear - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: null - ppo_micro_batch_size_per_gpu: null - use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - ppo_max_token_len_per_gpu: 32768 - forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} - ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} - shuffle: ${actor_rollout_ref.actor.shuffle} - cliprange_value: 0.5 - loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} - checkpoint: - save_contents: - - model - - optimizer - - extra - load_contents: ${.save_contents} - profiler: - _target_: verl.utils.profiler.ProfilerConfig - discrete: false - all_ranks: false - ranks: [] - forward_micro_batch_size: ${critic.ppo_micro_batch_size} - forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} - ulysses_sequence_parallel_size: 1 - grad_clip: 1.0 -reward_model: - enable: false - strategy: fsdp - model: - input_tokenizer: ${actor_rollout_ref.model.path} - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - external_lib: ${actor_rollout_ref.model.external_lib} - trust_remote_code: false - use_shm: false - use_remove_padding: false - use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} - fsdp_config: - wrap_policy: - min_num_params: 0 - param_offload: false - reshard_after_forward: true - fsdp_size: -1 - forward_prefetch: false - micro_batch_size: null - micro_batch_size_per_gpu: null - max_length: null - use_dynamic_bsz: ${critic.use_dynamic_bsz} - forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} - reward_manager: naive - launch_reward_fn_async: false - sandbox_fusion: - url: null - max_concurrent: 64 - memory_limit_mb: 1024 - profiler: - _target_: verl.utils.profiler.ProfilerConfig - discrete: false - all_ranks: false - ranks: [] - ulysses_sequence_parallel_size: 1 -custom_reward_function: - path: null - name: compute_score -algorithm: - _target_: verl.trainer.config.AlgoConfig - gamma: 1.0 - lam: 1.0 - adv_estimator: gae - norm_adv_by_std_in_grpo: true - use_kl_in_reward: false - kl_penalty: kl - kl_ctrl: - _target_: verl.trainer.config.KLControlConfig - type: fixed - kl_coef: 0.001 - horizon: 10000 - target_kl: 0.1 - filter_groups: - enable: false - metric: null - max_num_gen_batches: 0 - horizon: 10000 - target_kl: 0.1 - use_pf_ppo: false - pf_ppo: - _target_: verl.trainer.config.PFPPOConfig - reweight_method: pow - weight_pow: 2.0 -ray_init: - num_cpus: null - timeline_json_file: null diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml deleted file mode 100644 index d5402d870..000000000 --- a/verl/trainer/config/actor/actor.yaml +++ /dev/null @@ -1,111 +0,0 @@ -# Format checks enforced on CI: -# 1. Comments must appear above each field. -# 2. There must be a blank line between each field. -# 3. Inline comments (after a field on the same line) are not allowed. -# 4. Indentation level is respected for nested fields. - -# the abstract actor configs -# fsdp, fsdp2 or megatron. must be set. -strategy: ??? - -# Split each sample into sub-batches of this size for PPO -ppo_mini_batch_size: 256 - -# [Deprecated] Global micro batch size -ppo_micro_batch_size: null - -# Local per-GPU micro batch size -ppo_micro_batch_size_per_gpu: null - -# Whether to automatically adjust batch size at runtime -# oc.select: the default val for ref.log_prob_use_dynamic_bsz -use_dynamic_bsz: false - -# Max tokens per GPU in one PPO batch; affects gradient accumulation -# Typically it should be: n * ${data.max_prompt_length} + ${data.max_response_length} -# oc.select: the default val for ref.log_prob_max_token_len_per_gpu -ppo_max_token_len_per_gpu: 16384 - -# PPO clip ratio -clip_ratio: 0.2 - -# Lower bound for asymmetric clipping (used in dual-clip PPO) -clip_ratio_low: 0.2 - -# Upper bound for asymmetric clipping (used in dual-clip PPO) -clip_ratio_high: 0.2 - -# policy loss config -policy_loss: - - # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 - loss_mode: "vanilla" - - # Ratio of tokens to be clipped for clip-cov loss - clip_cov_ratio: 0.0002 - - # Lower bound for clip-cov loss - clip_cov_lb: 1.0 - - # Upper bound for clip-cov loss - clip_cov_ub: 5.0 - - # Ratio of tokens to be applied kl penalty for kl-cov loss - kl_cov_ratio: 0.0002 - - # KL divergence penalty coefficient - ppo_kl_coef: 0.1 - -# Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C -clip_ratio_c: 3.0 - -# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" -loss_agg_mode: token-mean - -# Entropy regularization coefficient in PPO loss -entropy_coeff: 0 - -# Whether to use KL loss instead of KL reward penalty. True for GRPO -use_kl_loss: false - -# Whether to use torch.compile() -# oc.select: the default val for ref.use_torch_compile -use_torch_compile: true - -# KL loss coefficient when use_kl_loss is enabled. For GRPO -kl_loss_coef: 0.001 - -# Type of KL divergence loss. Options: "kl"(k1), "abs", "mse"(k2), "low_var_kl"(k3), "full" -kl_loss_type: low_var_kl - -# Number of PPO epochs per batch -ppo_epochs: 1 - -# Shuffle training data across PPO epochs -shuffle: false - -# checkpoint configs -checkpoint: - - # What to include in saved checkpoints - # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - save_contents: ['model', 'optimizer', 'extra'] - - # For more flexibility, you can specify the contents to load from the checkpoint. - # .xxx refers to the local variable xxx from the same level of hierarchy similar to python pkg - load_contents: ${.save_contents} - -# optimizer configs -optim: - - # Learning rate - lr: 1e-6 - - # Warmup steps ratio (used if lr_warmup_steps is negative) - lr_warmup_steps_ratio: 0.0 - - # Total training steps (must be overridden at runtime) - total_training_steps: -1 - - # Weight decay - weight_decay: 0.01 diff --git a/verl/trainer/config/actor/dp_actor.yaml b/verl/trainer/config/actor/dp_actor.yaml deleted file mode 100644 index f298c3cfa..000000000 --- a/verl/trainer/config/actor/dp_actor.yaml +++ /dev/null @@ -1,73 +0,0 @@ -# Format checks enforced on CI: -# 1. Comments must appear above each field. -# 2. There must be a blank line between each field. -# 3. Inline comments (after a field on the same line) are not allowed. -# 4. Indentation level is respected for nested fields. - -# defaults specify the default config from each component -defaults: - - # dp actor config, inheriting from trainer/config/actor/actor.yaml - - actor - - # load the reference default config, then apply the fields in the current yaml - - _self_ - -# TODO(haibin.lin): switch to fsdp2 -strategy: fsdp - -# Gradient clipping for actor updates, specific to the strategy. -grad_clip: 1.0 - -# Sequence parallelism size for Ulysses-style model parallelism -# oc.select: the default val for ref.ulysses_sequence_parallel_size -ulysses_sequence_parallel_size: 1 - -# calculate entropy with chunking to reduce memory peak -entropy_from_logits_with_chunking: False - -# recompute entropy -entropy_checkpointing: False - -# optimizer configs -optim: - - # Warmup steps; negative value delegates to lr_warmup_steps_ratio - lr_warmup_steps: -1 - - # Minimum LR ratio for cosine schedule - min_lr_ratio: 0.0 - - # Number of cosine cycles in LR schedule - num_cycles: 0.5 - - # LR warmup style: "constant" or "cosine" - warmup_style: constant - -# configs for FSDP -fsdp_config: - - # policy for wrapping the model - wrap_policy: - - # Minimum number of parameters to trigger wrapping a layer with FSDP - min_num_params: 0 - - # Whether to offload model parameters to CPU (trades speed for memory) - param_offload: false - - # Whether to offload optimizer state to CPU - optimizer_offload: false - - # Only for FSDP2: offload param/grad/optimizer during train - offload_policy: false - - # Only for FSDP2: Reshard after forward pass to reduce memory footprint - reshard_after_forward: true - - # Number of GPUs in each FSDP shard group; -1 means auto - fsdp_size: -1 - - # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather - # before the current forward computation. - forward_prefetch: False diff --git a/verl/trainer/config/actor/megatron_actor.yaml b/verl/trainer/config/actor/megatron_actor.yaml deleted file mode 100644 index 6492eab2f..000000000 --- a/verl/trainer/config/actor/megatron_actor.yaml +++ /dev/null @@ -1,104 +0,0 @@ -# megatron actor config, inheriting from trainer/config/actor/actor.yaml -defaults: - - actor - # load the reference default config, then apply the fields in the current yaml - - _self_ - -strategy: megatron - -data_loader_seed: null - -load_weight: True - -checkpoint: - - async_save: False - -optim: - - optimizer: adam - - clip_grad: 1.0 - - # initial learning rate for warmup, default to 0.0 - lr_warmup_init: 0.0 - - # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps: null - - lr_decay_steps: null - - # select from constant/linear/cosine/inverse_square_root - lr_decay_style: constant - - # minimum learning rate, default to 0.0 - min_lr: 0.0 - - # select from constant/linear/cosine - weight_decay_incr_style: constant - - # select from constant/exponential/cosine - lr_wsd_decay_style: exponential - - lr_wsd_decay_steps: null - - # use checkpoint optimizer parameter scheduler - use_checkpoint_opt_param_scheduler: False - -megatron: - - param_offload: False - - grad_offload: False - - optimizer_offload: False - - tensor_model_parallel_size: 1 - - expert_model_parallel_size: 1 - - expert_tensor_parallel_size: null - - pipeline_model_parallel_size: 1 - - virtual_pipeline_model_parallel_size: null - - context_parallel_size: 1 - - sequence_parallel: True - - use_distributed_optimizer: True - - use_dist_checkpointing: False - - dist_checkpointing_path: null - - # oc.select: default val for ref.megatron.seed - seed: 42 - - # Allow to override Distributed Data Parallel (DDP) config - override_ddp_config: {} - - # additional transformer config like: num_layers_in_first(/last)_pipeline_stage - # oc.select: default val for ref.megatron.override_transformer_config - override_transformer_config: {} - - # oc.select: default val for ref.megatron.use_mbridge - use_mbridge: False - -# profile the actor model in `update_policy` -profile: - # turn it on when you want to profile the actor model - use_profile: False - - # list, you can specify the ranks to profile - profile_ranks: null - - # start step in update_policy - step_start: -1 - - # end step - step_end: -1 - - # the path to save the profile result - save_path: null diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py deleted file mode 100644 index 5bc6cf943..000000000 --- a/verl/trainer/config/algorithm.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field -from typing import Optional - -from verl.base_config import BaseConfig - - -@dataclass -class KLControlConfig(BaseConfig): - """Configuration for KL control. - - The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. - - Args: - type (str): Type of KL control. Can be "fixed" or "adaptive". - kl_coef (float): Initial coefficient for KL penalty. - horizon (int): Horizon value for adaptive controller. - target_kl (float): Target KL divergence for adaptive controller. - """ - - _frozen_fields = ["type", "kl_coef", "horizon", "target_kl"] - type: str = "fixed" - kl_coef: float = 0.001 - horizon: int = 10000 - target_kl: float = 0.1 - - -@dataclass -class PFPPOConfig(BaseConfig): - """Configuration for preference feedback PPO. - - The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. - - Args: - reweight_method (str): Method for reweighting samples. Can be "pow", "max_min", or "max_random". - weight_pow (float): Power used for weight scaling in "pow" method. - """ - - _frozen_fields = ["reweight_method", "weight_pow"] - reweight_method: str = "pow" - weight_pow: float = 2.0 - - -@dataclass -class FilterGroupsConfig(BaseConfig): - """Configuration for filter groups (used in DAPO and Entropy). - - The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. - - Args: - enable (bool): Whether to enable filter groups. - metric (Optional[str]): Metric to use for filtering: "acc", "score", "seq_reward", "seq_final_reward", etc. - max_num_gen_batches (int): Non-positive values mean no upper limit. - """ - - _frozen_fields = ["enable", "metric", "max_num_gen_batches"] - - enable: bool = False - metric: Optional[str] = None - max_num_gen_batches: int = 0 - - -@dataclass -class AlgoConfig(BaseConfig): - """Configuration for the algorithm. - - The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. - - Args: - gamma (float): Discount factor for future rewards. - lam (float): Trade-off between bias and variance in the GAE estimator. - adv_estimator (str): Advantage estimator type: "gae", "grpo", "reinforce_plus_plus", etc. - norm_adv_by_std_in_grpo (bool): Whether to normalize advantages by std (specific to GRPO). - use_kl_in_reward (bool): Whether to enable in-reward KL penalty. - kl_penalty (str): How to estimate KL divergence: "kl", "abs", "mse", "low_var_kl", or "full". - kl_ctrl (KLControlConfig): KL control configuration. - use_pf_ppo (bool): Whether to enable preference feedback PPO. - pf_ppo (Optional[PFPPOConfig]): Preference feedback PPO settings. - filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy - """ - - _frozen_fields = [ - "gamma", - "lam", - "adv_estimator", - "norm_adv_by_std_in_grpo", - "use_kl_in_reward", - "kl_penalty", - "use_pf_ppo", - ] - - gamma: float = 1.0 - lam: float = 1.0 - adv_estimator: str = "gae" - norm_adv_by_std_in_grpo: bool = True - use_kl_in_reward: bool = False - kl_penalty: str = "kl" - kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig) - use_pf_ppo: bool = False - pf_ppo: Optional[PFPPOConfig] = None - filter_groups: Optional[FilterGroupsConfig] = None diff --git a/verl/trainer/config/critic/critic.yaml b/verl/trainer/config/critic/critic.yaml deleted file mode 100644 index a02fca231..000000000 --- a/verl/trainer/config/critic/critic.yaml +++ /dev/null @@ -1,94 +0,0 @@ -# Number of rollouts per update (mirrors actor rollout_n) -rollout_n: ${actor_rollout_ref.rollout.n} - -# fsdp or fsdp2 strategy used for critic model training -strategy: ??? - -# optimizer configs -optim: - - # Warmup steps ratio; total steps will be injected at runtime - lr_warmup_steps_ratio: 0.0 - - # Total training steps (must be overridden at runtime) - total_training_steps: -1 - - # Weight decay - weight_decay: 0.01 - -# model config for the critic -model: - - # Path to pretrained model weights - path: ~/models/deepseek-llm-7b-chat - - # Tokenizer path (defaults to actor's model path) - tokenizer_path: ${actor_rollout_ref.model.path} - - # Hugging Face config override - override_config: {} - - # External model implementation (optional) - external_lib: ${actor_rollout_ref.model.external_lib} - - # Enable gradient checkpointing to save memory - enable_gradient_checkpointing: True - - # Whether to trust remote code from Hugging Face models - trust_remote_code: ${actor_rollout_ref.model.trust_remote_code} - -# PPO mini-batch size per update -ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - -# [Deprecated] Global micro batch size -ppo_micro_batch_size: null - -# Local per-GPU micro batch size -ppo_micro_batch_size_per_gpu: null - -# Whether to automatically adjust batch size at runtime -use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - -# Max tokens per GPU in one PPO batch (doubled for critic) -ppo_max_token_len_per_gpu: 32768 - -# Max token length per GPU in forward pass -forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} - -# Number of PPO epochs per batch -ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} - -# Shuffle training data across PPO epochs -shuffle: ${actor_rollout_ref.actor.shuffle} - -# PPO value function clipping range -cliprange_value: 0.5 - -# Loss aggregation mode: "token-mean", "seq-mean-token-sum", or "seq-mean-token-mean" -loss_agg_mode: ${actor_rollout_ref.actor.loss_agg_mode} - -# checkpoint configs -checkpoint: - - # What to include in saved checkpoints - # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space - save_contents: ['model', 'optimizer', 'extra'] - - # What to include when loading checkpoints - load_contents: ${.save_contents} - -# profiler configs -# the corresponding dataclass is verl.utils.profiler.ProfilerConfig. -profiler: - - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.utils.profiler.ProfilerConfig - - # True for each task has its own database, False for all tasks in one training step share one database. - discrete: False - - # Whether to profile all ranks. - all_ranks: False - - # The ranks that will be profiled. [] or [0,1,...] - ranks: [] \ No newline at end of file diff --git a/verl/trainer/config/critic/megatron_critic.yaml b/verl/trainer/config/critic/megatron_critic.yaml deleted file mode 100644 index 7db811edc..000000000 --- a/verl/trainer/config/critic/megatron_critic.yaml +++ /dev/null @@ -1,138 +0,0 @@ -# defaults specify the default config from each component -defaults: - - # dp actor config, inheriting from trainer/config/critic/critic.yaml - - critic - - # load the reference default config, then apply the fields in the current yaml - - _self_ - -strategy: megatron - -# seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron -nccl_timeout: 600 - -# optimizer configs -optim: - - # select optimizer, default is Adam - optimizer: adam - - # Learning rate - lr: 1e-6 - - # Clip gradients norm - clip_grad: 1.0 - - # initial learning rate for warmup, default to 0.0 - lr_warmup_init: 0.0 - - # Prioritized. None, 0 or Negative values mean delegating to lr_warmup_steps_ratio. - lr_warmup_steps: null - - lr_decay_steps: null - - # select from constant/linear/cosine/inverse_square_root - lr_decay_style: linear - - # minimum learning rate, default to 0.0 - min_lr: 0.0 - - # select from constant/linear/cosine - weight_decay_incr_style: constant - - # select from constant/exponential/cosine - lr_wsd_decay_style: exponential - - # number of steps for weight std decay - lr_wsd_decay_steps: null - - # use checkpoint optimizer parameter scheduler - use_checkpoint_opt_param_scheduler: False - -# model config for the critic -model: - - # override default empty mapping - override_config: - model_config: {} - moe_config: - freeze_moe_router: False - - # Enable gradient checkpointing to save memory - enable_gradient_checkpointing: False - - # Activation Checkpointing settings - gradient_checkpointing_kwargs: - activations_checkpoint_method: null - activations_checkpoint_granularity: null - activations_checkpoint_num_layers: null - -# megatron-specific parallelism settings -megatron: - - # Whether to offload model parameters to CPU - param_offload: False - - # Whether to offload gradients to CPU - grad_offload: False - - # Whether to offload optimizer state to CPU - optimizer_offload: False - - # size of tensor model parallel group - tensor_model_parallel_size: 1 - - # size of expert model parallel group - expert_model_parallel_size: 1 - - # size of expert tensor parallel group - expert_tensor_parallel_size: null - - # size of pipeline model parallel group - pipeline_model_parallel_size: 1 - - # size of virtual pipeline model parallel group - virtual_pipeline_model_parallel_size: null - - # size of context parallel group - context_parallel_size: 1 - - # Whether to use sequence parallelism - sequence_parallel: True - - # Whether to use distributed optimizer - use_distributed_optimizer: True - - # Whether to use distributed checkpointing - use_dist_checkpointing: False - - # Path for distributed checkpointing - dist_checkpointing_path: null - - # Random seed for Megatron - seed: ${actor_rollout_ref.actor.megatron.seed} - - # Allow to override Distributed Data Parallel (DDP) config - override_ddp_config: ${actor_rollout_ref.actor.megatron.override_ddp_config} - - # Transformer config overrides for Megatron - override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config} - - # Whether to use mBridge communications - use_mbridge: ${actor_rollout_ref.actor.megatron.use_mbridge} - -# Whether to load initial weights -load_weight: True - -# seed for data loader -data_loader_seed: ${actor_rollout_ref.actor.data_loader_seed} - -# KL control settings -kl_ctrl: - type: fixed - kl_coef: 0.001 - -# Asynchronous checkpoint saving -checkpoint: - async_save: False \ No newline at end of file diff --git a/verl/trainer/config/data/legacy_data.yaml b/verl/trainer/config/data/legacy_data.yaml deleted file mode 100644 index 9a5ce8f0d..000000000 --- a/verl/trainer/config/data/legacy_data.yaml +++ /dev/null @@ -1,109 +0,0 @@ -# Tokenizer class or path. If null, it will be inferred from the model. -tokenizer: null - -# Whether to use shared memory for data loading. -use_shm: False - -# Training set parquet. Can be a list or a single file. -# The program will read all files into memory, so it can't be too large (< 100GB). -# The path can be either a local path or an HDFS path. -# For HDFS path, we provide utils to download it to DRAM and convert it to a local path. -train_files: ~/data/rlhf/gsm8k/train.parquet - -# Validation parquet. Can be a list or a single file. -val_files: ~/data/rlhf/gsm8k/test.parquet - -# The field in the dataset where the prompt is located. Default is 'prompt'. -prompt_key: prompt - -# The field used to select the reward function (if using different ones per example). -reward_fn_key: data_source - -# Maximum prompt length. All prompts will be left-padded to this length. -# An error will be reported if the length is too long. -# oc.select: default val for rollout.prompt_length -max_prompt_length: 512 - -# Maximum response length. Rollout in RL algorithms (e.g. PPO) generates up to this length. -# oc.select: default val for rollout.response_length -max_response_length: 512 - -# Batch size sampled for one training iteration of different RL algorithms. -train_batch_size: 1024 - -# Batch size used during validation. Can be null. -val_batch_size: null - -# Whether to return the original input_ids without adding chat template. -# This is used when the reward model's chat template differs from the policy. -# If using a model-based RM with different templates, this should be True. -return_raw_input_ids: False - -# Whether to return the original chat (prompt) without applying chat template. -return_raw_chat: False - -# Whether to return the full prompt with chat template. -return_full_prompt: False - -# Whether to shuffle the data in the dataloader. -shuffle: True - -# num dataloader workers -dataloader_num_workers: 8 - -# Whether to shuffle the validation set. -validation_shuffle: False - -# Whether to filter overlong prompts. -filter_overlong_prompts: False - -# Number of workers for filtering overlong prompts. -# For large-scale datasets, filtering can be time-consuming. -# Use multiprocessing to speed up. Default is 1. -filter_overlong_prompts_workers: 1 - -# Truncate the input_ids or prompt if they exceed max_prompt_length. -# Options: 'error', 'left', 'right', 'middle'. Default is 'error'. -truncation: error - -# The field in the multi-modal dataset where the image is located. Default is 'images'. -image_key: images - -# The field in the multi-modal dataset where the video is located. -video_key: videos - -# If the remote tokenizer has a Python file, this flag determines whether to allow using it. -trust_remote_code: False - -# Optional: specify a custom dataset class path and name if overriding default loading behavior. -custom_cls: - - # The path to the file containing your customized dataset class. If not specified, pre-implemented dataset will be used. - path: null - - # The name of the dataset class within the specified file. - name: null - -# Whether to return multi-modal inputs in the dataset. Set to False if rollout generates new multi-modal inputs. -return_multi_modal_inputs: True - -# settings related to data sampler -sampler: - - # the path to the module containing a curriculum class which implements the - # AbstractSampler interface - class_path: null - - # the name of the curriculum class like `MySampler` - class_name: null - -# Data generation configuration for augmenting the dataset. -datagen: - - # The path to the file containing your customized data generation class. - # E.g. 'pkg://verl.experimental.dynamic_dataset.dynamicgen_dataset' - path: null - - # The class name of the data generation class within the specified file. - # E.g. 'MockDataGenerator' - name: null \ No newline at end of file diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml deleted file mode 100644 index edafae297..000000000 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ /dev/null @@ -1,149 +0,0 @@ -# specify the default per-component configs -defaults: - - # @.: - # actor_rollout_ref.actor: trainer/config/actor/megatron_actor.yaml - - actor@actor_rollout_ref.actor: megatron_actor - # trainer.npu_profile: trainer/config/npu_profile/npu_profile.yaml - - npu_profile@trainer.npu_profile: npu_profile - # data: trainer/config/data/legacy_data.yaml - - data@data: legacy_data - # load the reference default config, then apply the fields in the current yaml - # Reference model config. - # Reference model will be enabled when actor.use_kl_loss or/and algorithm.use_kl_in_reward is/are True. - - ref@actor_rollout_ref.ref: megatron_ref - # Rollout model config. - - rollout@actor_rollout_ref.rollout: rollout - # Critic model config. - - critic@critic: megatron_critic - # Reward model config. - - reward_model@reward_model: megatron_reward_model - - _self_ - -actor_rollout_ref: - hybrid_engine: True - - nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron - - model: - - path: ~/models/deepseek-llm-7b-chat - - custom_chat_template: null - - external_lib: null - - override_config: - - model_config: {} - - moe_config: - - freeze_moe_router: False - - enable_gradient_checkpointing: False - - gradient_checkpointing_kwargs: - - ## Activation Checkpointing - activations_checkpoint_method: null # 'uniform', 'block'; not used with 'selective' - - # 'uniform' divides the total number of transformer layers and checkpoints the input activation of each chunk - # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity - activations_checkpoint_granularity: null # 'selective' or 'full' - - # 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention - activations_checkpoint_num_layers: null # not used with 'selective' - - use_fused_kernels: False # Whether to use custom fused kernels (PostProcessing, for memory efficiency) - - trust_remote_code: False - - rollout: - # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. - enable_chunked_prefill: False - - load_format: dummy_megatron - - tensor_model_parallel_size: 1 - - layer_name_map: - qkv_layer_name: qkv - gate_proj_layer_name: gate_up - - profiler: - _target_: verl.utils.profiler.ProfilerConfig - discrete: False - all_ranks: False - ranks: [] - -custom_reward_function: - path: null - name: compute_score - -algorithm: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.AlgoConfig - gamma: 1.0 - lam: 1.0 - adv_estimator: gae - norm_adv_by_std_in_grpo: True - use_kl_in_reward: False - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.KLControlConfig - type: fixed - kl_coef: 0.001 - horizon: 10000 - target_kl: 0.1 - use_pf_ppo: False - pf_ppo: - # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs in the entrypoint - _target_: verl.trainer.config.PFPPOConfig - reweight_method: pow # ["pow", "max_min", "max_random"] - weight_pow: 2.0 - -trainer: - balance_batch: True - total_epochs: 30 - total_training_steps: null - profile_steps: null # [1,2,5] or [] or null - project_name: verl_examples - experiment_name: gsm8k - logger: ['console', 'wandb'] - log_val_generations: 0 - nnodes: 1 - n_gpus_per_node: 8 - save_freq: -1 - esi_redundant_time: 0 - - # auto: find the last ckpt to resume. If can't find, start from scratch - resume_mode: auto # or disable or resume_path if resume_from_path is set - resume_from_path: null - del_local_ckpt_after_load: False - val_before_train: True - test_freq: -1 - critic_warmup: 0 - default_hdfs_dir: null - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} - max_actor_ckpt_to_keep: null - max_critic_ckpt_to_keep: null - # The timeout for ray worker group to wait for the register center to be ready - ray_wait_register_center_timeout: 300 - device: cuda - # see ppo_trainer.yaml for more details - controller_nsight_options: - trace: "cuda,nvtx,cublas,ucx" - cuda-memory-usage: "true" - cuda-graph-trace: "graph" - worker_nsight_options: - trace: "cuda,nvtx,cublas,ucx" - cuda-memory-usage: "true" - cuda-graph-trace: "graph" - capture-range: "cudaProfilerApi" - capture-range-end: null - kill: none -ray_init: - num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then. - timeline_json_file: null diff --git a/verl/trainer/config/ppo_megatron_trainer_edited.yaml b/verl/trainer/config/ppo_megatron_trainer_edited.yaml deleted file mode 100644 index a3573408f..000000000 --- a/verl/trainer/config/ppo_megatron_trainer_edited.yaml +++ /dev/null @@ -1,164 +0,0 @@ -data: - tokenizer: null - shuffle: False - train_files: ~/data/rlhf/gsm8k/train.parquet - val_files: ~/data/rlhf/gsm8k/test.parquet - prompt_key: prompt - max_prompt_length: 512 - max_response_length: 512 - train_batch_size: 1024 - val_batch_size: 1312 - return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs - return_raw_chat: False - -actor_rollout_ref: - hybrid_engine: True - model: - path: ~/models/deepseek-llm-7b-chat - external_lib: null - override_config: {} - enable_gradient_checkpointing: False - actor: - strategy: megatron # This is for backward-compatibility - ppo_mini_batch_size: 32 - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: 4 - use_dynamic_bsz: True - clip_ratio: 0.2 - entropy_coeff: 0.001 - ppo_epochs: 1 - shuffle: True - optim: - lr: 1e-6 - clip_grad: 1.0 - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - megatron: - tensor_model_parallel_size: 8 - pipeline_model_parallel_size: 1 - num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. - sequence_parallel: True - seed: 1 - load_weight: True - ref: - megatron: - tensor_model_parallel_size: 8 - pipeline_model_parallel_size: 1 - num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. - sequence_parallel: True - seed: 1 - load_weight: True - param_offload: False - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: 4 - micro_batch_size: 4 - micro_batch_size_per_gpu: 4 - rollout: - name: vllm - temperature: 1.0 - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1 - prompt_length: ${data.max_prompt_length} # for xperf_gpt - response_length: ${data.max_response_length} - # for vllm rollout - dtype: bfloat16 # should align with FSDP - gpu_memory_utilization: 0.5 - ignore_eos: False - enforce_eager: True - free_cache_engine: True - load_format: dummy_megatron - tensor_model_parallel_size: 8 - max_num_batched_tokens: 16384 - max_num_seqs: 1024 - log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu - log_prob_micro_batch_size_per_gpu: 4 - disable_log_stats: True - enable_chunked_prefill: False # could get higher throughput - # for hf rollout - do_sample: True - layer_name_map: - qkv_layer_name: qkv - gate_proj_layer_name: gate_up - # number of responses (i.e. num sample times) - n: 1 - -critic: - strategy: megatron - optim: - lr: 1e-5 - clip_grad: 1.0 - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - model: - path: ~/models/deepseek-llm-7b-chat - tokenizer_path: ${actor_rollout_ref.model.path} - override_config: {} - external_lib: ${actor_rollout_ref.model.external_lib} - enable_gradient_checkpointing: False - megatron: - tensor_model_parallel_size: 8 - pipeline_model_parallel_size: 1 - num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. - sequence_parallel: True - seed: 1 - load_weight: True - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu - ppo_micro_batch_size_per_gpu: 4 - use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} - ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} - shuffle: ${actor_rollout_ref.actor.shuffle} - cliprange_value: 0.5 - kl_ctrl: - type: fixed - kl_coef: 0.001 - -reward_model: - enable: False - strategy: megatron - megatron: - tensor_model_parallel_size: 8 - pipeline_model_parallel_size: 1 - num_layers_per_virtual_pipeline_stage: null # vpp will hang. need debug. - sequence_parallel: True - seed: 1 - model: - input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - external_lib: ${actor_rollout_ref.model.external_lib} - load_weight: True - param_offload: False - micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu - micro_batch_size_per_gpu: null - use_dynamic_bsz: ${critic.use_dynamic_bsz} - max_length: null - -algorithm: - gamma: 1.0 - lam: 1.0 - adv_estimator: gae - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - -trainer: - total_epochs: 30 - total_training_steps: null - project_name: verl_examples - experiment_name: gsm8k - logger: ["console", "wandb"] - nnodes: 1 - n_gpus_per_node: 8 - save_freq: -1 - test_freq: 2 - critic_warmup: 0 - resume_mode: disable - val_generations_to_log_to_wandb: 0 - remove_previous_ckpt_in_save: False - default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} diff --git a/verl/trainer/config/ref/megatron_ref.yaml b/verl/trainer/config/ref/megatron_ref.yaml deleted file mode 100644 index 6a75d68e3..000000000 --- a/verl/trainer/config/ref/megatron_ref.yaml +++ /dev/null @@ -1,51 +0,0 @@ -# megatron ref config, inheriting from trainer/config/ref/ref.yaml -defaults: - - ref - # load the reference default config, then apply the fields in the current yaml - - _self_ - -strategy: megatron - -megatron: - - param_offload: False - - tensor_model_parallel_size: 1 - - expert_model_parallel_size: 1 - - expert_tensor_parallel_size: None - - pipeline_model_parallel_size: 1 - - virtual_pipeline_model_parallel_size: null # change VPP interface for parallelism tests - - context_parallel_size: 1 - - sequence_parallel: True - - use_distributed_optimizer: False - - use_dist_checkpointing: False - - dist_checkpointing_path: null - - seed: ${oc.select:actor_rollout_ref.actor.megatron.seed,42} - - override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}} - - use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False} - -profile: - - use_profile: False - - profile_ranks: null - - step_start: -1 - - step_end: -1 - - save_path: null - -load_weight: True \ No newline at end of file diff --git a/verl/trainer/config/ref/ref.yaml b/verl/trainer/config/ref/ref.yaml deleted file mode 100644 index 7d9157b3e..000000000 --- a/verl/trainer/config/ref/ref.yaml +++ /dev/null @@ -1,21 +0,0 @@ -# actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default -strategy: ${actor_rollout_ref.actor.strategy} - -# whether to enable torch.compile -# same as actor_rollout_ref.actor.use_torch_compile if it exists, otherwise 1 -use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} - -# [Will be deprecated, use log_prob_micro_batch_size_per_gpu] -# The batch size for one forward pass in the computation of log_prob. Global batch size. -log_prob_micro_batch_size: null - -# The batch size for one forward pass in the computation of log_prob. Local batch size per GPU. -log_prob_micro_batch_size_per_gpu: null - -# enable dynamic batch size (sequence packing) for log_prob computation -# same as actor_rollout_ref.actor.use_dynamic_bsz if it exists, otherwise false -log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} - -# the max token length per GPU -# same as actor_rollout_ref.actor.ppo_max_token_len_per_gpu if it exists, otherwise 16384 -log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} diff --git a/verl/trainer/constants_ppo.py b/verl/trainer/constants_ppo.py deleted file mode 100644 index 84350bbd9..000000000 --- a/verl/trainer/constants_ppo.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -PPO_RAY_RUNTIME_ENV = { - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "WARN", - "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true", - }, -} diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py deleted file mode 100644 index 866998003..000000000 --- a/verl/trainer/fsdp_sft_trainer.py +++ /dev/null @@ -1,665 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -A lightweight one-file FSDP SFT Trainer -TODO(zhangchi.usc1992) -- Add calculation of mfu -- Add validation -""" - -import os - -os.environ["NCCL_DEBUG"] = "WARN" -os.environ["TOKENIZERS_PARALLELISM"] = "true" - -import logging -import re -from contextlib import nullcontext - -import hydra -import torch -import torch.distributed -from peft import LoraConfig, TaskType, get_peft_model -from tensordict import TensorDict -from torch import nn, optim -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.utils.data import DataLoader, Dataset, DistributedSampler -from tqdm import tqdm -from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel - -import verl.utils.hdfs_io as hdfs_io -from verl.utils.dataset import SFTDataset -from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset -from verl.utils.device import get_device_id, get_device_name, is_cuda_available, is_npu_available -from verl.utils.distributed import destroy_global_process_group, initialize_global_process_group -from verl.utils.fs import copy_to_local -from verl.utils.fsdp_utils import ( - CPUOffloadPolicy, - MixedPrecisionPolicy, - apply_fsdp2, - fsdp2_clip_grad_norm_, - fsdp2_load_full_state_dict, - get_fsdp_wrap_policy, - get_init_weight_context_manager, - init_fn, -) -from verl.utils.profiler import log_gpu_memory_usage -from verl.utils.py_functional import convert_to_regular_types -from verl.utils.torch_dtypes import PrecisionType -from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup -from verl.utils.tracking import Tracking -from verl.utils.ulysses import ( - gather_outputs_and_unpad, - get_ulysses_sequence_parallel_world_size, - ulysses_pad_and_slice_inputs, -) -from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager - -if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input -elif is_npu_available: - from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) - - -def extract_step(path): - match = re.search(r"global_step_(\d+)", path) - if match: - return int(match.group(1)) - return None - - -class FSDPSFTTrainer: - def __init__( - self, - config, - device_mesh: DeviceMesh, - ulysses_device_mesh: DeviceMesh, - tokenizer, - train_dataset: Dataset, - val_dataset: Dataset, - ): - self.config = config - self.device_mesh = device_mesh - self.ulysses_device_mesh = ulysses_device_mesh - self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - self.tokenizer = tokenizer - if self.config.data.chat_template is not None: - raise ValueError("Apply Chat template from config is not supported yet.") - - # normalize dp size - self._normalize_config_bsz() - - # Set sequence parallel size - self.config.ulysses_sequence_parallel_size = getattr(self.config, "ulysses_sequence_parallel_size", 1) - self.use_remove_padding = getattr(self.config, "use_remove_padding", False) - if self.device_mesh.get_rank() == 0: - print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}") - print(f"Using remove padding: {self.use_remove_padding}") - - self._build_dataloader(train_dataset, val_dataset) - # build model - self._build_model_optimizer() - - # TODO: add checkpoint manager - if self.device_mesh.get_rank() == 0: - print(self.config) - self.device_name = get_device_name() - - def _normalize_config_bsz(self): - dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) - if self.device_mesh.get_rank() == 0: - print(f"Normalize batch size by dp {dp_size}") - - assert self.config.data.train_batch_size % dp_size == 0, ( - f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" - ) - - self.config.data.train_batch_size //= dp_size - - assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0 - - def _build_dataloader(self, train_dataset, val_dataset): - # build dataset - config = self.config - self.train_dataset, self.val_dataset = train_dataset, val_dataset - - # build dataloader - # Use data parallel rank and size instead of global rank and world size - - # If doing SP, we need to use the local rank and size - if self.config.ulysses_sequence_parallel_size > 1: - rank = self.ulysses_device_mesh.get_local_rank("dp") - world_size = self.ulysses_device_mesh.size(0) - if self.ulysses_device_mesh.get_rank() == 0: - print(f"Using SP rank {rank} and size {world_size} for data distribution") - print("Each SP rank gets different data, but the same data WITHIN the same rank") - else: - rank = self.device_mesh.get_rank() - world_size = self.device_mesh.size() - if self.device_mesh.get_rank() == 0: - print(f"Using FSDP rank {rank} and size {world_size} for data distribution") - - self.train_sampler = DistributedSampler( - self.train_dataset, shuffle=True, num_replicas=world_size, rank=rank, drop_last=True - ) - self.train_dataloader = DataLoader( - dataset=self.train_dataset, - batch_size=config.data.train_batch_size, - sampler=self.train_sampler, - num_workers=8, - pin_memory=True, - drop_last=True, - ) - - self.val_sampler = DistributedSampler( - self.val_dataset, shuffle=False, num_replicas=world_size, rank=rank, drop_last=True - ) - self.val_dataloader = DataLoader( - dataset=self.val_dataset, - batch_size=config.data.micro_batch_size_per_gpu, - sampler=self.val_sampler, - num_workers=8, - pin_memory=True, - drop_last=True, - ) - - def _build_model_optimizer(self): - # TODO (zhangchi.usc1992): - # 1. support pretrain from random weights - # 2. support init directly from sharded weights - local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) - - if self.config.model.get("external_lib", None) is not None: - # This is used to import external_lib into the huggingface systems - import importlib - - importlib.import_module(self.config.model.external_lib) - - log_gpu_memory_usage("Before model allocation", logger=logger) - - trust_remote_code = self.config.model.trust_remote_code - torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") - torch_dtype = PrecisionType.to_dtype(torch_dtype) - # load config first - config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) - self.model_config = config - if hasattr(self.model_config, "max_position_embeddings"): - self.model_config.max_position_embeddings = max( - self.model_config.max_position_embeddings, self.config.data.max_length - ) - if self.config.ulysses_sequence_parallel_size > 1: - assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" - - # This may be very large - init_context = get_init_weight_context_manager( - use_meta_tensor=not config.tie_word_embeddings, mesh=self.device_mesh - ) - - with init_context(): - self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - local_model_path, - config=config, - torch_dtype=torch_dtype, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - - if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1: - from verl.models.transformers.monkey_patch import apply_monkey_patch - - apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size) - - # Apply Liger kernel if use_liger is enabled - if self.config.model.get("use_liger", False): - from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance - - _apply_liger_kernel_to_instance(model=self.model) - - if self.config.model.get("lora_rank", 0) > 0: - self.model.enable_input_require_grads() - # Convert config to regular Python types before creating PEFT model - lora_config = { - "task_type": TaskType.CAUSAL_LM, - "r": self.config.model.lora_rank, - "lora_alpha": self.config.model.lora_alpha, - "target_modules": convert_to_regular_types(self.config.model.target_modules), - "bias": "none", - } - self.model = get_peft_model(self.model, LoraConfig(**lora_config)) - - if self.config.model.enable_gradient_checkpointing: - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - - log_gpu_memory_usage("After model allocation", logger=logger) - - mixed_precision = MixedPrecision( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 - ) - - auto_wrap_policy = get_fsdp_wrap_policy( - self.model, - config=self.config.model.fsdp_config.wrap_policy, - is_lora=self.config.model.get("lora_rank", 0) > 0, - ) - if self.device_mesh.get_rank() == 0: - print(auto_wrap_policy) - - if not self.config.model.fsdp_config.cpu_offload: - cpu_offload = None - else: - cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params) - - fsdp_strategy = self.config.model.strategy - if fsdp_strategy == "fsdp": - self.fsdp_model = FSDP( - self.model, - cpu_offload=cpu_offload, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=get_device_id(), - sharding_strategy=ShardingStrategy.FULL_SHARD, - mixed_precision=mixed_precision, - sync_module_states=True, - device_mesh=self.device_mesh, - forward_prefetch=False, - ) - elif fsdp_strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True - ) - - fsdp_kwargs = { - "mesh": self.device_mesh, - "mp_policy": mp_policy, - "offload_policy": cpu_offload, - "reshard_after_forward": True, - } - full_state = self.model.state_dict() - apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config) - fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload) - self.fsdp_model = self.model - else: - raise NotImplementedError(f"not implement {fsdp_strategy}") - - log_gpu_memory_usage("After FSDP wrapping", logger=logger) - - self.optimizer = optim.AdamW( - self.fsdp_model.parameters(), - lr=self.config.optim.lr, - betas=self.config.optim.betas, - weight_decay=self.config.optim.weight_decay, - ) - - log_gpu_memory_usage("After initialize optimizer", logger=logger) - - self.steps_per_epoch = len(self.train_dataloader) - self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs - - if self.device_mesh.get_rank() == 0: - print( - f"Number of steps/epoch {self.steps_per_epoch}, number of epochs " - f"{self.config.trainer.total_epochs}, total number of steps {self.total_steps}" - ) - - num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) - - if not hasattr(self.config.optim, "lr_scheduler") or self.config.optim.lr_scheduler == "cosine": - self.lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps - ) - elif self.config.optim.lr_scheduler == "wsd": - self.lr_scheduler = get_wsd_schedule_with_warmup( - optimizer=self.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=self.total_steps - ) - else: - raise ValueError(f"Unknown lr scheduler: {self.config.optim.lr_scheduler}") - - def _compute_loss_and_backward(self, batch, do_backward=True): - """Compute loss with optional sequence parallelism and remove padding features""" - use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 - - # Move inputs to GPU and prepare loss mask - input_ids = batch["input_ids"].to(self.device_name) - attention_mask = batch["attention_mask"].to(self.device_name) - position_ids = batch["position_ids"].to(self.device_name) - loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).to(self.device_name) - loss_fct = nn.CrossEntropyLoss(reduction="none") - - # Context manager for sequence parallel if needed - context = self.sharding_manager if use_sp else nullcontext() - with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): - if not use_sp: - # Standard forward pass without sequence parallel - labels = input_ids[:, 1:].contiguous() - output = self.fsdp_model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False - ) - logits = output.logits - - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels.contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, self.model.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - loss = loss * loss_mask.to(loss.device) - else: - # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks - # i.e., each GPU has <1 sequence, and each SP group has 1 sequence - # 1. All SP ranks will receive the *SAME* batch - # 2. Different SP groups will receive *DIFFERENT* batches - # This is implemented by the DistributedSampler - - batch_size, seqlen = input_ids.shape - # Remove padding - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # Unpad position_ids to align rotary - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # Pad and slice inputs for sequence parallelism - input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size() - ) - # For computing loss - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size() - ) - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) - - # Forward pass - output = self.fsdp_model( - input_ids=input_ids_rmpad_sliced, - attention_mask=None, # Not needed with flash attention varlen - position_ids=position_ids_rmpad_padded, - use_cache=False, - ) - - # Compute loss locally then aggregate - logits_rmpad = output.logits.squeeze(0) - input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) - loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) - # Gather and unpad for sequence parallelism - loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) - - # This is the loss collected from all ulysses ranks - full_loss = pad_input( - hidden_states=loss.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen - ) - full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss - full_loss = full_loss.reshape(-1) - loss_mask = loss_mask.to(full_loss.device) - loss = full_loss * loss_mask - - valid_token_this_rank = torch.sum(loss_mask) - - if self.config.data.balance_dp_token: - torch.distributed.all_reduce(valid_token_this_rank) - dp_size = self.ulysses_device_mesh.size("dp") if use_sp else torch.distributed.get_world_size() - else: - dp_size = 1 - - loss = torch.sum(loss) / (valid_token_this_rank + 1e-8) * dp_size - - if do_backward: - loss.backward() - return loss - - def training_step(self, batch: TensorDict): - self.fsdp_model.train() - - log_gpu_memory_usage("Before optimizer zero_grad", logger=logger) - - self.optimizer.zero_grad() - - log_gpu_memory_usage("After optimizer zero_grad", logger=logger) - - micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) - n_micro_batches = len(micro_batches) - step_loss = 0 - for micro_batch in micro_batches: - loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches - step_loss += loss.item() - - if self.config.model.strategy == "fsdp": - grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) - elif self.config.model.strategy == "fsdp2": - grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad) - else: - raise NotImplementedError(f"not implement {self.config.model.strategy}") - - log_gpu_memory_usage("Before optimizer step", logger=logger) - - # if grad_norm is not finite, skip the update - if not torch.isfinite(grad_norm): - print(f"WARN: grad_norm is not finite: {grad_norm}") - self.optimizer.zero_grad() - else: - self.optimizer.step() - - log_gpu_memory_usage("After optimizer step", logger=logger) - - self.lr_scheduler.step() - - # reduce loss across dp ranks - lr = self.lr_scheduler.get_last_lr()[0] - - log_gpu_memory_usage("After offload weights", logger=logger) - - step_loss = torch.tensor(step_loss).to(self.device_name) - if is_cuda_available: - torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) - elif is_npu_available: - torch.distributed.all_reduce(step_loss) - step_loss /= self.device_mesh.size(0) - return {"train/loss": step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3} - - def validation_step(self, batch: TensorDict): - self.fsdp_model.eval() - with torch.no_grad(): - loss = self._compute_loss_and_backward(batch, do_backward=False) - if is_cuda_available: - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) - elif is_npu_available: - torch.distributed.all_reduce(loss) - loss /= self.device_mesh.size(0) - return loss - - def save_checkpoint(self, step): - # save checkpoint - path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}") - - fsdp_strategy = self.config.model.strategy - if fsdp_strategy == "fsdp": - # FSDP1 checkpoint saving - from torch.distributed.fsdp import FullStateDictConfig, StateDictType - - cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg): - state_dict = self.fsdp_model.state_dict() - - # save huggingface model - if self.device_mesh.get_rank() == 0: - os.makedirs(path, exist_ok=True) - self.model.save_pretrained(path, state_dict=state_dict) - self.tokenizer.save_pretrained(path) - elif fsdp_strategy == "fsdp2": - # FSDP2 checkpoint saving - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict - - # Get full state dict with FSDP2 - options = StateDictOptions(full_state_dict=True, cpu_offload=True) - state_dict = get_model_state_dict(self.fsdp_model, options=options) - - # save huggingface model - if self.device_mesh.get_rank() == 0: - os.makedirs(path, exist_ok=True) - self.model.save_pretrained(path, state_dict=state_dict) - self.model_config.save_pretrained(path) - self.tokenizer.save_pretrained(path) - else: - raise NotImplementedError(f"not implement {fsdp_strategy}") - - # Copy to HDFS if configured - if self.device_mesh.get_rank() == 0 and self.config.trainer.default_hdfs_dir: - hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) - hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True) - - torch.distributed.barrier() - - def fit(self): - rank = self.device_mesh.get_rank() - - # TODO: add a unified tracking - if rank == 0: - tracking = Tracking( - project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - ) - - global_step = 0 - last_valid_metric = None - # compute the total training steps. - # the total training steps in SFT is mainly for early exit - total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs - - if self.config.trainer.total_training_steps is not None: - total_training_steps = self.config.trainer.total_training_steps - - self.total_training_steps = total_training_steps - print(f"Total training steps: {self.total_training_steps}") - - # TODO (zhangchi.usc1992) add back checkpoint manager. - # Currently, it blocks when uploading to hdfs. So very slow. - - for epoch in range(self.config.trainer.total_epochs): - self.train_sampler.set_epoch(epoch=epoch) - for data in tqdm( - self.train_dataloader, - total=self.steps_per_epoch, - desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", - disable=rank != 0, - ): - global_step += 1 - data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) - metric = self.training_step(data) - if rank == 0: - tracking.log(data=metric, step=global_step) - - is_last_step = global_step >= self.total_training_steps - is_valid_step = global_step % self.config.trainer.test_freq == 0 - is_save_step = global_step % self.config.trainer.save_freq == 0 - - # early exit or validation step - if is_last_step or (self.config.trainer.test_freq > 0 and is_valid_step): - # Perform validation - val_losses = [] - for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to( - self.device_name - ) - val_loss = self.validation_step(val_data) - val_losses.append(val_loss) - if rank == 0: - val_loss = torch.mean(torch.stack(val_losses)) - metric = {"val/loss": val_loss.detach().item()} - tracking.log(data=metric, step=global_step) - last_valid_metric = metric - torch.distributed.barrier() - - if is_last_step or (self.config.trainer.save_freq > 0 and is_save_step): - self.save_checkpoint(step=global_step) - - if is_last_step: - if rank == 0: - print(f"Final validation metrics: {last_valid_metric}") - return - - -def run_sft(config): - device_name = get_device_name() - local_rank, rank, world_size = initialize_global_process_group() - - device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) - dp_size = world_size // config.ulysses_sequence_parallel_size - ulysses_device_mesh = init_device_mesh( - device_type=device_name, - mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), - mesh_dim_names=("dp", "sp"), - ) - # build tokenizer and datasets first - from verl.utils import hf_tokenizer - - local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True) - tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code) - train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer) - val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer) - - trainer = FSDPSFTTrainer( - config=config, - device_mesh=device_mesh, - ulysses_device_mesh=ulysses_device_mesh, - tokenizer=tokenizer, - train_dataset=train_dataset, - val_dataset=val_dataset, - ) - - trainer.fit() - - destroy_global_process_group() - - -@hydra.main(config_path="config", config_name="sft_trainer", version_base=None) -def main(config): - run_sft(config) - - -def create_sft_dataset(data_paths, data_config, tokenizer): - """Create a dataset.""" - # build dataset - # First check if a custom dataset class is specified - if data_config.custom_cls.get("path", None): - from verl.utils.import_utils import load_extern_type - - dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name) - # Then check if multi-turn dataset should be used - elif data_config.get("multiturn", {}).get("enable", False): - dataset_cls = MultiTurnSFTDataset - # Default to single-turn dataset - else: - dataset_cls = SFTDataset - - # Create datasets based on the selected class - dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config) - return dataset - - -if __name__ == "__main__": - main() diff --git a/verl/trainer/main_eval.py b/verl/trainer/main_eval.py deleted file mode 100644 index 5b8246a8e..000000000 --- a/verl/trainer/main_eval.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Offline evaluate the performance of a generated file using reward model and ground truth verifier. -The input is a parquet file that contains N generated sequences and (optional) the ground truth. -""" - -from collections import defaultdict - -import hydra -import numpy as np -import pandas as pd -import ray -from tqdm import tqdm - -from verl.trainer.ppo.reward import get_custom_reward_fn -from verl.utils.fs import copy_to_local - - -@ray.remote -def process_item(reward_fn, data_source, response_lst, reward_data): - ground_truth = reward_data["ground_truth"] - score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst] - return data_source, np.mean(score_lst) - - -@hydra.main(config_path="config", config_name="evaluation", version_base=None) -def main(config): - local_path = copy_to_local(config.data.path, use_shm=config.data.get("use_shm", False)) - dataset = pd.read_parquet(local_path) - responses = dataset[config.data.response_key] - data_sources = dataset[config.data.data_source_key] - reward_model_data = dataset[config.data.reward_model_key] - - total = len(dataset) - - # Initialize Ray - if not ray.is_initialized(): - ray.init(num_cpus=config.ray_init.num_cpus) - - # evaluate test_score based on data source - data_source_reward = defaultdict(list) - compute_score = get_custom_reward_fn(config) - - # Create remote tasks - remote_tasks = [ - process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total) - ] - - # Process results as they come in - with tqdm(total=total) as pbar: - while len(remote_tasks) > 0: - # Use ray.wait to get completed tasks - done_ids, remote_tasks = ray.wait(remote_tasks) - for result_id in done_ids: - data_source, score = ray.get(result_id) - data_source_reward[data_source].append(score) - pbar.update(1) - - metric_dict = {} - for data_source, rewards in data_source_reward.items(): - metric_dict[f"test_score/{data_source}"] = np.mean(rewards) - - print(metric_dict) - - -if __name__ == "__main__": - main() diff --git a/verl/trainer/ppo/__init__.py b/verl/trainer/ppo/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/trainer/ppo/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py deleted file mode 100644 index 143d733c7..000000000 --- a/verl/trainer/ppo/core_algos.py +++ /dev/null @@ -1,1148 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Core functions to implement PPO algorithms. -The function implemented in this file should be used by trainer with different distributed strategies to -implement PPO-like algorithms. -""" - -__all__ = ["register_adv_est", "get_adv_estimator_fn", "AdvantageEstimator"] - -from collections import defaultdict -from enum import Enum -from typing import Optional - -import numpy as np -import torch - -import verl.utils.torch_functional as verl_F -from verl.trainer.config import AlgoConfig - -POLICY_LOSS_REGISTRY = {} - - -def register_policy_loss(name): - """Register a policy loss function with the given name. - - Args: - name (str): The name to register the policy loss function under. - - Returns: - function: Decorator function that registers the policy loss function. - """ - - def decorator(func): - POLICY_LOSS_REGISTRY[name] = func - return func - - return decorator - - -def get_policy_loss_fn(name): - """Get the policy loss with a given name. - - Args: - name: `(str)` - The name of the policy loss. - - Returns: - `(callable)`: The policy loss function. - """ - loss_name = name - if loss_name not in POLICY_LOSS_REGISTRY: - raise ValueError( - f"Unsupported loss mode: {loss_name}. Supported modes are: {list(POLICY_LOSS_REGISTRY.keys())}" - ) - return POLICY_LOSS_REGISTRY[loss_name] - - -ADV_ESTIMATOR_REGISTRY = {} - - -def register_adv_est(name_or_enum): - """Decorator to register a advantage estimator function with a given name. - - Args: - name_or_enum: `(str)` or `(AdvantageEstimator)` - The name or enum of the advantage estimator. - - """ - - def decorator(fn): - name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum - if name in ADV_ESTIMATOR_REGISTRY and ADV_ESTIMATOR_REGISTRY[name] != fn: - raise ValueError( - f"Adv estimator {name} has already been registered: {ADV_ESTIMATOR_REGISTRY[name]} vs {fn}" - ) - ADV_ESTIMATOR_REGISTRY[name] = fn - return fn - - return decorator - - -def get_adv_estimator_fn(name_or_enum): - """Get the advantage estimator function with a given name. - - Args: - name_or_enum: `(str)` or `(AdvantageEstimator)` - The name or enum of the advantage estimator. - - Returns: - `(callable)`: The advantage estimator function. - """ - name = name_or_enum.value if isinstance(name_or_enum, Enum) else name_or_enum - if name not in ADV_ESTIMATOR_REGISTRY: - raise ValueError(f"Unknown advantage estimator simply: {name}") - return ADV_ESTIMATOR_REGISTRY[name] - - -class AdvantageEstimator(str, Enum): - """Using an enumeration class to avoid spelling errors in adv_estimator. - - Note(haibin.lin): this enum class is immutable after creation. Extending this - enum for new estimators may not be necessary since users can always just call - `verl.trainer.ppo.core_algos.register` with string name for a custom advantage - estimator instead. - """ - - GAE = "gae" - GRPO = "grpo" - REINFORCE_PLUS_PLUS = "reinforce_plus_plus" - REINFORCE_PLUS_PLUS_BASELINE = "reinforce_plus_plus_baseline" - REMAX = "remax" - RLOO = "rloo" - OPO = "opo" - GRPO_PASSK = "grpo_passk" - GPG = "gpg" - - -class AdaptiveKLController: - """ - Adaptive KL controller described in the paper: - https://arxiv.org/pdf/1909.08593.pdf - """ - - def __init__(self, init_kl_coef, target_kl, horizon): - self.value = init_kl_coef - self.target = target_kl - self.horizon = horizon - - def update(self, current_kl, n_steps): - """Update the KL coefficient based on current KL divergence. - - Args: - current_kl (float): Current KL divergence value. - n_steps (int): Number of steps taken. - """ - target = self.target - proportional_error = np.clip(current_kl / target - 1, -0.2, 0.2) - mult = 1 + proportional_error * n_steps / self.horizon - self.value *= mult - - -class FixedKLController: - """Fixed KL controller.""" - - def __init__(self, kl_coef): - self.value = kl_coef - - def update(self, current_kl, n_steps): - """Update method for fixed KL controller (no-op). - - Args: - current_kl (float): Current KL divergence value (unused). - n_steps (int): Number of steps taken (unused). - """ - pass - - -def get_kl_controller(kl_ctrl): - """Factory function to create appropriate KL controller based on configuration. - - Args: - kl_ctrl: Configuration object containing KL controller settings. - - Returns: - KL controller instance (FixedKLController or AdaptiveKLController). - - Raises: - NotImplementedError: If controller type is not supported. - AssertionError: If adaptive controller horizon is not positive. - """ - if kl_ctrl.type == "fixed": - return FixedKLController(kl_coef=kl_ctrl.kl_coef) - elif kl_ctrl.type == "adaptive": - assert kl_ctrl.horizon > 0, f"horizon must be larger than 0. Got {kl_ctrl.horizon}" - return AdaptiveKLController(init_kl_coef=kl_ctrl.kl_coef, target_kl=kl_ctrl.target_kl, horizon=kl_ctrl.horizon) - else: - raise NotImplementedError - - -@register_adv_est(AdvantageEstimator.GAE) # or simply: @register_adv_est("gae") -def compute_gae_advantage_return( - token_level_rewards: torch.Tensor, - values: torch.Tensor, - response_mask: torch.Tensor, - gamma: torch.Tensor, - lam: torch.Tensor, -): - """Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py - - Args: - token_level_rewards: `(torch.Tensor)` - shape is (bs, response_length) - values: `(torch.Tensor)` - shape is (bs, response_length) - response_mask: `(torch.Tensor)` - shape is (bs, response_length). [EOS] mask. The token after [EOS] have mask zero. - gamma is `(float)` - discounted factor used in RL - lam: `(float)` - lambda value when computing Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438) - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - - """ - with torch.no_grad(): - nextvalues = 0 - lastgaelam = 0 - advantages_reversed = [] - gen_len = token_level_rewards.shape[-1] - - for t in reversed(range(gen_len)): - delta = token_level_rewards[:, t] + gamma * nextvalues - values[:, t] - lastgaelam_ = delta + gamma * lam * lastgaelam - - # skip values and TD-error on observation tokens - nextvalues = values[:, t] * response_mask[:, t] + (1 - response_mask[:, t]) * nextvalues - lastgaelam = lastgaelam_ * response_mask[:, t] + (1 - response_mask[:, t]) * lastgaelam - - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], dim=1) - - returns = advantages + values - advantages = verl_F.masked_whiten(advantages, response_mask) - return advantages, returns - - -# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. -@register_adv_est(AdvantageEstimator.GRPO) # or simply: @register_adv_est("grpo") -def compute_grpo_outcome_advantage( - token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: np.ndarray, - epsilon: float = 1e-6, - norm_adv_by_std_in_grpo: bool = True, - config: Optional[AlgoConfig] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute advantage for GRPO, operating only on Outcome reward - (with only one scalar reward for each response). - - Args: - token_level_rewards: `(torch.Tensor)` - shape is (bs, response_length) - response_mask: `(torch.Tensor)` - shape is (bs, response_length) - index: `(np.ndarray)` - index array for grouping - epsilon: `(float)` - small value to avoid division by zero - norm_adv_by_std_in_grpo: `(bool)` - whether to scale the GRPO advantage - config: `(Optional[AlgoConfig])` - algorithm configuration object - - Note: - If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. - If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783). - - Returns: - advantages: `(torch.Tensor)` - shape is (bs, response_length) - Returns: `(torch.Tensor)` - shape is (bs, response_length) - """ - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2mean = {} - id2std = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - id2std[idx] = torch.tensor(1.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - id2std[idx] = torch.std(torch.tensor([id2score[idx]])) - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - if norm_adv_by_std_in_grpo: - scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) - else: - scores[i] = scores[i] - id2mean[index[i]] - scores = scores.unsqueeze(-1) * response_mask - - return scores, scores - - -@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk") -def compute_grpo_passk_outcome_advantage( - token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: np.ndarray, - epsilon: float = 1e-6, - norm_adv_by_std_in_grpo: bool = True, - config: Optional[AlgoConfig] = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute advantage for Pass@k using a GRPO-style outcome reward formulation. - Only the best response per group gets a non-zero advantage: r_max - r_second_max. - - Implemented as described in https://arxiv.org/abs/2503.19595. - - Args: - token_level_rewards: (bs, response_length) - response_mask: (bs, response_length) - index: (bs,) → group ID per sample - epsilon: float for numerical stability - config: (AlgoConfig) algorithm settings, which contains "norm_adv_by_std_in_grpo" - - Returns: - advantages: (bs, response_length) - returns: (bs, response_length) - """ - assert config is not None - # if True, normalize advantage by std within group - norm_adv_by_std_in_grpo = config.get("norm_adv_by_std_in_grpo", True) - scores = token_level_rewards.sum(dim=-1) # (bs,) - advantages = torch.zeros_like(scores) - - id2scores = defaultdict(list) - id2indices = defaultdict(list) - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - idx = index[i] - id2scores[idx].append(scores[i]) - id2indices[idx].append(i) - - for idx in id2scores: - rewards = torch.stack(id2scores[idx]) # (k,) - if rewards.numel() < 2: - raise ValueError( - f"Pass@k requires at least 2 samples per group. Got {rewards.numel()} for group {idx}." - ) - topk, topk_idx = torch.topk(rewards, 2) - r_max, r_second_max = topk[0], topk[1] - i_max = id2indices[idx][topk_idx[0].item()] - advantage = r_max - r_second_max - if norm_adv_by_std_in_grpo: - std = torch.std(rewards) - advantage = advantage / (std + epsilon) - advantages[i_max] = advantage - - advantages = advantages.unsqueeze(-1) * response_mask - return advantages, advantages - - -@register_adv_est( - AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE -) # or simply: @register_adv_est("reinforce_plus_plus_baseline") -def compute_reinforce_plus_plus_baseline_outcome_advantage( - token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: torch.Tensor, - epsilon: float = 1e-6, - config: Optional[AlgoConfig] = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute advantage for RF++-baseline (https://arxiv.org/abs/2501.03262), operating only on Outcome reward - (with only one scalar reward for each response). - - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - response_mask: `(torch.Tensor)` - shape: (bs, response_length) - config: (AlgoConfig) algorithm config - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = token_level_rewards.shape[-1] - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2mean = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - scores[i] = scores[i] - id2mean[index[i]] - - scores = scores.unsqueeze(-1).tile([1, response_length]) * response_mask - scores = verl_F.masked_whiten(scores, response_mask) * response_mask - - return scores, scores - - -@register_adv_est(AdvantageEstimator.RLOO) # or simply: @register_adv_est("rloo") -def compute_rloo_outcome_advantage( - token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: np.ndarray, - epsilon: float = 1e-6, - config: Optional[AlgoConfig] = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute advantage for RLOO based on https://arxiv.org/abs/2402.14740 - - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - response_mask: `(torch.Tensor)` - shape: (bs, response_length) - config: (AlgoConfig) algorithm config - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2mean = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - response_num = len(id2score[index[i]]) - if response_num > 1: - scores[i] = scores[i] * response_num / (response_num - 1) - id2mean[index[i]] * response_num / ( - response_num - 1 - ) - scores = scores.unsqueeze(-1) * response_mask - - return scores, scores - - -@register_adv_est(AdvantageEstimator.OPO) # or simply: @register_adv_est("opo") -def compute_opo_outcome_advantage( - token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: np.ndarray, - epsilon: float = 1e-6, - config: Optional[AlgoConfig] = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute advantage for OPO based on https://arxiv.org/pdf/2505.23585 - - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - response_mask: `(torch.Tensor)` - shape: (bs, response_length) - config: (AlgoConfig) algorithm config - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - response_length = response_mask.sum(dim=-1) - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2len = defaultdict(list) - id2bsl = {} - - with torch.no_grad(): - bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - id2len[index[i]].append(response_length[i]) - - for idx in id2score: - if len(id2score[idx]) == 1: - id2bsl[idx] = torch.tensor(0.0) - elif len(id2score[idx]) > 1: - score_tensor = torch.tensor(id2score[idx]) - len_tensor = torch.tensor(id2len[idx]) - id2bsl[idx] = (len_tensor * score_tensor).sum() / len_tensor.sum() - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - scores[i] = scores[i] - id2bsl[index[i]] - scores = scores.unsqueeze(-1) * response_mask - - return scores, scores - - -@register_adv_est(AdvantageEstimator.REINFORCE_PLUS_PLUS) # or simply: @register_adv_est("reinforce_plus_plus") -def compute_reinforce_plus_plus_outcome_advantage( - token_level_rewards: torch.Tensor, response_mask: torch.Tensor, config: Optional[AlgoConfig] = None, **kwargs -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute advantage for REINFORCE++. - This implementation is based on the paper: https://arxiv.org/abs/2501.03262 - - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - response_mask: `(torch.Tensor)` - shape: (bs, response_length) - config: (AlgoConfig) algorithm config - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - assert config is not None - gamma = config.gamma - with torch.no_grad(): - returns = torch.zeros_like(token_level_rewards) - running_return = 0 - - for t in reversed(range(token_level_rewards.shape[1])): - running_return = token_level_rewards[:, t] + gamma * running_return - returns[:, t] = running_return - # Reset after EOS - running_return = running_return * response_mask[:, t] - - advantages = verl_F.masked_whiten(returns, response_mask) - advantages = advantages * response_mask - - return advantages, returns - - -@register_adv_est(AdvantageEstimator.REMAX) # or simply: @register_adv_est("remax") -def compute_remax_outcome_advantage( - token_level_rewards: torch.Tensor, - reward_baselines: torch.Tensor, - response_mask: torch.Tensor, - config: Optional[AlgoConfig] = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute advantage for ReMax, operating only on Outcome reward - This implementation is based on the paper: https://arxiv.org/abs/2310.10505 - (with only one scalar reward for each response). - - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - reward_baselines: `(torch.Tensor)` - shape: (bs,) - response_mask: `(torch.Tensor)` - shape: (bs, response_length) - config: (AlgoConfig) algorithm config - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - - with torch.no_grad(): - returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) - advantages = returns - reward_baselines.unsqueeze(-1) * response_mask - - return advantages, returns - - -@register_adv_est(AdvantageEstimator.GPG) # or simply: @register_adv_est("gpg") -def compute_gpg_outcome_advantage( - token_level_rewards: torch.Tensor, - response_mask: torch.Tensor, - index: np.ndarray, - epsilon: float = 1e-6, - f_norm: float = 1.0, - alpha: float = 1.0, - config=None, - **kwargs, -): - """ - Compute advantage for GPG, operating only on Outcome reward - (with only one scalar reward for each response). - Args: - token_level_rewards: `(torch.Tensor)` - shape: (bs, response_length) - response_mask: `(torch.Tensor)` - shape: (bs, response_length) - index: `(np.ndarray)` - shape: (bs,) - epsilon: (float) - f_norm: (float) - alpha: (float) - config: (dict) algorithm config - - Returns: - advantages: `(torch.Tensor)` - shape: (bs, response_length) - Returns: `(torch.Tensor)` - shape: (bs, response_length) - """ - scores = token_level_rewards.sum(dim=-1) - - id2score = defaultdict(list) - id2mean = {} - id2std = {} - - with torch.no_grad(): - bsz = scores.shape[0] - m = torch.count_nonzero(scores) - alpha = bsz / m.clamp(min=1) - - for i in range(bsz): - id2score[index[i]].append(scores[i]) - - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - id2std[idx] = torch.tensor(1.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - id2std[idx] = torch.std(torch.tensor([id2score[idx]])) - else: - raise ValueError(f"no score in prompt index: {idx}") - for i in range(bsz): - scores[i] = alpha * (scores[i] - id2mean[index[i]]) / (f_norm) - scores = scores.unsqueeze(-1) * response_mask - - return scores, scores - - -def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): - """Compute token-level rewards with KL penalty. - - Args: - token_level_scores (torch.Tensor): Token-level reward scores. - old_log_prob (torch.Tensor): Log probabilities from current policy. - ref_log_prob (torch.Tensor): Log probabilities from reference policy. - kl_ratio (float): KL penalty coefficient. - - Returns: - torch.Tensor: Token-level rewards with KL penalty applied. - """ - kl = old_log_prob - ref_log_prob - return token_level_scores - kl * kl_ratio - - -def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str): - """ - Aggregate the loss matrix into a scalar. - - Args: - loss_mat: `(torch.Tensor)`: - shape: (bs, response_length) - loss_mask: `(torch.Tensor)`: - shape: (bs, response_length) - loss_agg_mode: (str) choices: - method to aggregate the loss matrix into a scalar. - Returns: - loss: `a scalar torch.Tensor` - aggregated loss - """ - if loss_agg_mode == "token-mean": - loss = verl_F.masked_mean(loss_mat, loss_mask) - elif loss_agg_mode == "seq-mean-token-sum": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum - loss = torch.mean(seq_losses) # seq-mean - elif loss_agg_mode == "seq-mean-token-mean": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean - loss = torch.mean(seq_losses) # seq-mean - elif loss_agg_mode == "seq-mean-token-sum-norm": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) - loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor - # (loss_mask.shape[-1]) should ideally be constant - # throughout training to well-replicate the DrGRPO paper. - # TODO: Perhaps add user-defined normalizer argument to - # agg_loss to ensure divisor stays constant throughout. - else: - raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") - - return loss - - -def compute_policy_loss( - old_log_prob, - log_prob, - advantages, - response_mask, - cliprange=None, - cliprange_low=None, - cliprange_high=None, - clip_ratio_c=3.0, - loss_agg_mode: str = "token-mean", -): - """ - Compute the clipped policy objective and related metrics for PPO. - - Adapted from - https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1122 - - Args: - old_log_prob (torch.Tensor): - Log-probabilities of actions under the old policy, shape (batch_size, response_length). - log_prob (torch.Tensor): - Log-probabilities of actions under the current policy, shape (batch_size, response_length). - advantages (torch.Tensor): - Advantage estimates for each action, shape (batch_size, response_length). - response_mask (torch.Tensor): - Mask indicating which tokens to include in the loss, shape (batch_size, response_length). - cliprange (float, optional): - Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. - Defaults to None (must be provided). - cliprange_low (float, optional): - Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. - cliprange_high (float, optional): - Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. - clip_ratio_c (float, optional): - Lower bound of the ratio for dual-clip PPO. See https://arxiv.org/pdf/1912.09729. - Defaults to 3.0. - loss_agg_mode (str, optional): - Aggregation mode for `agg_loss`. Defaults to "token-mean". - """ - assert clip_ratio_c > 1.0, ( - "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," - + f" but get the value: {clip_ratio_c}." - ) - - negative_approx_kl = log_prob - old_log_prob - # Clamp negative_approx_kl for stability - negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) - ratio = torch.exp(negative_approx_kl) - ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) - - pg_losses1 = -advantages * ratio - - if cliprange_low is None: - cliprange_low = cliprange - if cliprange_high is None: - cliprange_high = cliprange - pg_losses2 = -advantages * torch.clamp( - ratio, 1 - cliprange_low, 1 + cliprange_high - ) # - clip(ratio, 1-cliprange, 1+cliprange) * A - clip_pg_losses1 = torch.maximum( - pg_losses1, pg_losses2 - ) # max(-ratio * A, -clip(ratio, 1-cliprange, 1+cliprange) * A) - pg_clipfrac = verl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask) - - pg_losses3 = -advantages * clip_ratio_c - clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) - pg_clipfrac_lower = verl_F.masked_mean( - torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask - ) - - pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - - return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower - - -@register_policy_loss("gpg") -def compute_policy_loss_gpg(old_log_prob, log_prob, advantages, response_mask, loss_agg_mode="token-mean", config=None): - """Adapted from - https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495 - Args: - log_prob: `(torch.Tensor)` - shape: (bs, response_length) - advantages: `(torch.Tensor)` - shape: (bs, response_length) - response_mask: `(torch.Tensor)` - shape: (bs, response_length) - return: - pg_loss: `a scalar torch.Tensor` - policy gradient loss computed via GPG - """ - pg_losses = -log_prob * advantages - - pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0) - - -@register_policy_loss("clip_cov") -def compute_policy_loss_clip_cov( - old_log_prob: torch.Tensor, - log_prob: torch.Tensor, - advantages: torch.Tensor, - response_mask: torch.Tensor, - loss_agg_mode: str = "token-mean", - config: Optional[AlgoConfig] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Compute the clipped policy objective and related metrics for Clip-Cov. - - Adapted from - https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py - - Args: - old_log_prob (torch.Tensor): - Log-probabilities of actions under the old policy, shape (batch_size, response_length). - log_prob (torch.Tensor): - Log-probabilities of actions under the current policy, shape (batch_size, response_length). - advantages (torch.Tensor): - Advantage estimates for each action, shape (batch_size, response_length). - response_mask (torch.Tensor): - Mask indicating which tokens to include in the loss, shape (batch_size, response_length). - cliprange (float, optional): - Clipping parameter ε for standard PPO. See https://arxiv.org/abs/1707.06347. - Defaults to None (must be provided). - cliprange_low (float, optional): - Lower clip range for dual-clip PPO. Defaults to same as `cliprange`. - cliprange_high (float, optional): - Upper clip range for dual-clip PPO. Defaults to same as `cliprange`. - loss_agg_mode (str, optional): - Aggregation mode for `agg_loss`. Defaults to "token-mean". - clip_cvo_ratio (float, optional): - Ratio for clipping the covariance. Defaults to 0.0002. - clip_cov_lb (float, optional): - Lower bound for clipping covariance. Defaults to 1.0. - clip_cov_ub (float, optional): - Upper bound for clipping covariance. Defaults to 5.0. - """ - clip_cov_ratio = config.policy_loss.clip_cov_ratio if config.policy_loss.clip_cov_ratio is not None else 0.0002 - cliprange = config.clip_ratio - cliprange_low = config.clip_ratio_low if config.clip_ratio_low is not None else cliprange - cliprange_high = config.clip_ratio_high if config.clip_ratio_high is not None else cliprange - clip_cov_ub = config.policy_loss.clip_cov_ub if config.policy_loss.clip_cov_ub is not None else 5.0 - clip_cov_lb = config.policy_loss.clip_cov_lb if config.policy_loss.clip_cov_lb is not None else 1.0 - - assert clip_cov_ratio > 0, "clip_ratio should be larger than 0." - - negative_approx_kl = log_prob - old_log_prob - ratio = torch.exp(negative_approx_kl) - ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) - - pg_losses1 = -advantages * ratio - if cliprange_low is None: - cliprange_low = cliprange - if cliprange_high is None: - cliprange_high = cliprange - - corr = torch.ones_like(advantages) - pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange_low, 1 + cliprange_high) - clip_by_origin = (pg_losses2 > pg_losses1) & (response_mask > 0) - - cov_all = (advantages - verl_F.masked_mean(advantages, response_mask)) * ( - log_prob - verl_F.masked_mean(log_prob.detach(), response_mask) - ) - cov_all[response_mask == 0] = -torch.inf - cov_all[clip_by_origin] = -torch.inf - - clip_num = max(int(clip_cov_ratio * response_mask.sum().item()), 1) - top_k_idx = (cov_all < clip_cov_ub) & (cov_all > clip_cov_lb) & (response_mask > 0) - top_k_idx = torch.nonzero(top_k_idx) - - if len(top_k_idx) > 0: - perm = torch.randperm(len(top_k_idx)) - top_k_idx = top_k_idx[perm[: min(clip_num, len(top_k_idx))]] - else: - top_k_idx = torch.empty((0, 2), device=cov_all.device, dtype=torch.long) - - corr[top_k_idx[:, 0], top_k_idx[:, 1]] = 0 - - pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask) - - pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr - pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - - return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0) - - -@register_policy_loss("kl_cov") -def compute_policy_loss_kl_cov( - old_log_prob: torch.Tensor, - log_prob: torch.Tensor, - advantages: torch.Tensor, - response_mask: torch.Tensor, - loss_agg_mode: str = "token-mean", - config: Optional[AlgoConfig] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Compute the clipped policy objective and related metrics for Clip-Cov. - - Adapted from - https://github.com/PRIME-RL/Entropy-Mechanism-of-RL/blob/main/verl/trainer/ppo/core_algos.py - - Args: - old_log_prob (torch.Tensor): - Log-probabilities of actions under the old policy, shape (batch_size, response_length). - log_prob (torch.Tensor): - Log-probabilities of actions under the current policy, shape (batch_size, response_length). - advantages (torch.Tensor): - Advantage estimates for each action, shape (batch_size, response_length). - response_mask (torch.Tensor): - Mask indicating which tokens to include in the loss, shape (batch_size, response_length). - loss_agg_mode (str, optional): - Aggregation mode for `agg_loss`. Defaults to "token-mean". - kl_cov_ratio (float, optional): - Ratio for selecting the top-k covariance values. Defaults to 0.0002. - ppo_kl_coef (float, optional): - Coefficient for the KL penalty term in the loss. Defaults to 1. - """ - kl_cov_ratio = config.policy_loss.kl_cov_ratio if config.policy_loss.kl_cov_ratio is not None else 0.0002 - ppo_kl_coef = config.policy_loss.ppo_kl_coef if config.policy_loss.ppo_kl_coef is not None else 1.0 - - assert kl_cov_ratio > 0, "kl_cov_ratio should be larger than 0." - - negative_approx_kl = log_prob - old_log_prob - abs_kl = negative_approx_kl.abs() - ratio = torch.exp(negative_approx_kl) - ppo_kl_abs = verl_F.masked_mean(negative_approx_kl.abs(), response_mask) - pg_losses1 = -advantages * ratio - pg_losses_kl = -advantages * ratio + ppo_kl_coef * abs_kl - pg_losses = pg_losses1 - - all_valid = response_mask > 0 - all_valid_idx = torch.nonzero(all_valid.reshape(-1), as_tuple=True)[0] - all_valid_adv = advantages[all_valid].detach().reshape(-1).cpu() - all_valid_logp = log_prob[all_valid].detach().reshape(-1).cpu() - - k = min(kl_cov_ratio, len(all_valid_adv)) - - if k != 0: - cov_lst_all = (all_valid_adv - all_valid_adv.mean()) * (all_valid_logp - all_valid_logp.mean()) - k_percent_nums = max(1, int(len(cov_lst_all) * kl_cov_ratio)) - large_cov_idxs = torch.topk(cov_lst_all, k_percent_nums, largest=True).indices - - if len(large_cov_idxs) != 0: - large_cov_idxs = all_valid_idx[large_cov_idxs] - pg_losses[large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]] = pg_losses_kl[ - large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1] - ] - - pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - - return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0) - - -def compute_entropy_loss(logits, response_mask, loss_agg_mode: str = "token-mean"): - """Compute categorical entropy loss (For backward compatibility) - - Args: - logits (torch.Tensor): shape is (bs, response_length, vocab_size) - response_mask (torch.Tensor): shape is (bs, response_length) - - Returns: - entropy: a scalar torch.Tensor - - """ - # compute entropy - token_entropy = verl_F.entropy_from_logits(logits) # (bs, response_len) - entropy_loss = agg_loss(loss_mat=token_entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - return entropy_loss - - -def compute_value_loss( - vpreds: torch.Tensor, - returns: torch.Tensor, - values: torch.Tensor, - response_mask: torch.Tensor, - cliprange_value: float, - loss_agg_mode: str = "token-mean", -): - """ - Compute the clipped value-function loss for PPO. - - Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1151 - - Args: - vpreds (torch.FloatTensor): - Predicted values from the value head, shape (batch_size, response_length). - values (torch.FloatTensor): - Old (baseline) values from the value head, shape (batch_size, response_length). - returns (torch.FloatTensor): - Ground-truth returns, shape (batch_size, response_length). - response_mask (torch.Tensor): - Mask indicating which tokens to include in the value loss calculation. - cliprange_value (float): - Clip range for value prediction updates. - loss_agg_mode (str, optional): - Aggregation mode for `agg_loss`. Defaults to "token-mean". - - Returns: - vf_loss (torch.FloatTensor): - A scalar tensor containing the aggregated value-function loss. - vf_clipfrac (float): - Fraction of elements where the clipped loss was used. - """ - vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value) - vf_losses1 = (vpreds - returns) ** 2 - vf_losses2 = (vpredclipped - returns) ** 2 - clipped_vf_losses = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * agg_loss(loss_mat=clipped_vf_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - vf_clipfrac = verl_F.masked_mean(torch.gt(vf_losses2, vf_losses1).float(), response_mask) - return vf_loss, vf_clipfrac - - -def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_penalty) -> torch.FloatTensor: - """Compute KL divergence given logprob and ref_logprob. - Copied from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L1104 - See more description in http://joschu.net/blog/kl-approx.html - - Args: - logprob: - ref_logprob: - - Returns: - - """ - if kl_penalty in ("kl", "k1"): - return logprob - ref_logprob - - if kl_penalty == "abs": - return (logprob - ref_logprob).abs() - - if kl_penalty in ("mse", "k2"): - return 0.5 * (logprob - ref_logprob).square() - - # J. Schulman. Approximating kl divergence, 2020. - # # URL http://joschu.net/blog/kl-approx.html. - if kl_penalty in ("low_var_kl", "k3"): - kl = ref_logprob - logprob - # For numerical stability - kl = torch.clamp(kl, min=-20, max=20) - ratio = torch.exp(kl) - kld = (ratio - kl - 1).contiguous() - return torch.clamp(kld, min=-10, max=10) - - if kl_penalty == "full": - # so, here logprob and ref_logprob should contain the logits for every token in vocabulary - raise NotImplementedError - - raise NotImplementedError - - -def compute_pf_ppo_reweight_data( - data, - reweight_method: str = "pow", - weight_pow: float = 2.0, -): - """Reweight the data based on the token_level_scores. - - Args: - data: DataProto object, containing batch, non_tensor_batch and meta_info - reweight_method: str, choices: "pow", "max_min", "max_random" - weight_pow: float, the power of the weight - - Returns: - - """ - - @torch.no_grad() - def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: float) -> torch.Tensor: - """Compute importance weights for resampling based on scores. - - Args: - scores (torch.Tensor): Tensor of scores to compute weights from. - reweight_method (str): Method for computing weights ('pow', 'max_min', 'max_random'). - weight_pow (float): Power exponent for 'pow' method. - - Returns: - torch.Tensor: Computed importance weights. - - Raises: - ValueError: If reweight_method is not supported. - """ - if reweight_method == "pow": - weights = torch.pow(torch.abs(scores), weight_pow) - elif reweight_method == "max_min": - max_score = torch.max(scores) - min_score = torch.min(scores) - weights = torch.where((scores == max_score) | (scores == min_score), 1.0, 0.0) - elif reweight_method == "max_random": - max_score = torch.max(scores) - weights = torch.where(scores == max_score, 0.4, 0.1) - else: - raise ValueError(f"Unsupported reweight_method: {reweight_method}") - return weights - - scores = data.batch["token_level_scores"].sum(dim=-1) - weights = compute_weights(scores, reweight_method, weight_pow) - weights = torch.clamp(weights + 1e-8, min=1e-8) - - batch_size = scores.shape[0] - sample_indices = torch.multinomial(weights, batch_size, replacement=True) - - resampled_batch = {key: tensor[sample_indices] for key, tensor in data.batch.items()} - - sample_indices_np = sample_indices.numpy() - resampled_non_tensor_batch = {} - for key, array in data.non_tensor_batch.items(): - if isinstance(array, np.ndarray): - resampled_non_tensor_batch[key] = array[sample_indices_np] - else: - resampled_non_tensor_batch[key] = [array[i] for i in sample_indices_np] - - resampled_meta_info = {} - for key, value in data.meta_info.items(): - if isinstance(value, list) and len(value) == batch_size: - resampled_meta_info[key] = [value[i] for i in sample_indices_np] - else: - resampled_meta_info[key] = value - - from copy import deepcopy - - resampled_data = deepcopy(data) - resampled_data.batch = type(data.batch)(resampled_batch) - resampled_data.batch.batch_size = data.batch.batch_size - resampled_data.non_tensor_batch = resampled_non_tensor_batch - resampled_data.meta_info = resampled_meta_info - - return resampled_data diff --git a/verl/trainer/ppo/reward.py b/verl/trainer/ppo/reward.py deleted file mode 100644 index 143b631bc..000000000 --- a/verl/trainer/ppo/reward.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2025 Individual Contributor: Thibaut Barroyer -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import multiprocessing -import os -from functools import partial - -import ray - -from verl import DataProto -from verl.utils.reward_score import default_compute_score - - -def _call_with_kwargs(raw_fn, extra_kwargs, *args, **kwargs): - """Calls `raw_fn` by merging `extra_kwargs` into call-time `kwargs`, with `extra_kwargs` taking precedence. - - This function is used to merge additional keyword arguments with the original function's arguments. - """ - merged_kwargs = {**kwargs, **extra_kwargs} - return raw_fn(*args, **merged_kwargs) - - -def get_custom_reward_fn(config): - """Load and return a custom reward function from external file. - - Dynamically imports a reward function from a specified file path and wraps - it with additional keyword arguments from the configuration. - - Args: - config (dict): Configuration dictionary containing custom_reward_function - settings with 'path', 'name', and 'reward_kwargs' fields. - - Returns: - callable or None: Wrapped reward function with merged kwargs, or None - if no custom reward function is configured. - - Raises: - FileNotFoundError: If the specified reward function file doesn't exist. - RuntimeError: If there's an error loading the module from file. - AttributeError: If the specified function name isn't found in the module. - """ - import importlib.util - import sys - - reward_fn_config = config.get("custom_reward_function") or {} - file_path = reward_fn_config.get("path") - if not file_path: - return None - - if not os.path.exists(file_path): - raise FileNotFoundError(f"Reward function file '{file_path}' not found.") - - spec = importlib.util.spec_from_file_location("custom_module", file_path) - module = importlib.util.module_from_spec(spec) - try: - sys.modules["custom_module"] = module - spec.loader.exec_module(module) - except Exception as e: - raise RuntimeError(f"Error loading module from '{file_path}': {e}") from e - - function_name = reward_fn_config.get("name") - if not hasattr(module, function_name): - raise AttributeError(f"Reward function '{function_name}' not found in '{file_path}'.") - - print(f"using customized reward function '{function_name}' from '{file_path}'") - raw_fn = getattr(module, function_name) - - reward_kwargs = dict(reward_fn_config.get("reward_kwargs", {})) - - return partial(_call_with_kwargs, raw_fn, reward_kwargs) - - -def load_reward_manager(config, tokenizer, num_examine, **reward_kwargs): - """ - Load and initialize a reward manager based on the configuration. - - Args: - config: PPO trainer configuration object containing reward_model fields. - tokenizer: Tokenizer object used for processing text. - num_examine: Number of samples to examine. - **reward_kwargs: Additional keyword arguments for the reward manager. - - Returns: - An instance of the specified reward manager class. - """ - from verl.workers.reward_manager import get_reward_manager_cls - - # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`: - # naive: NaiveRewardManager - # prime: PrimeRewardManager - # batch: BatchRewardManager - # dapo: DAPORewardManager - # Note(haibin.lin): For custom reward managers, please make sure they are imported and - # registered via `verl.workers.reward_manager.register` - # By default reward_manager is set to naive (NaiveRewardManager) - reward_manager_name = config.reward_model.get("reward_manager", "naive") - reward_manager_cls = get_reward_manager_cls(reward_manager_name) - - # Try to get a custom reward function based on the configuration - compute_score = get_custom_reward_fn(config) - final_compute_score = compute_score - - if compute_score is None: - sandbox_config = config.reward_model.get("sandbox_fusion") - sandbox_url = sandbox_config.get("url") if sandbox_config else None - memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024) - if sandbox_url: - sandbox_manager = multiprocessing.Manager() - # Create a semaphore to control concurrent access to the sandbox - _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64)) - final_compute_score = partial( - default_compute_score, - sandbox_fusion_url=sandbox_url, - concurrent_semaphore=_concurrent_semaphore, - memory_limit_mb=memory_limit_mb, - ) - else: - final_compute_score = default_compute_score - - # Instantiate and return the reward manager with the specified parameters - return reward_manager_cls( - tokenizer=tokenizer, - num_examine=num_examine, - compute_score=final_compute_score, - reward_fn_key=config.data.reward_fn_key, - **reward_kwargs, - ) - - -def compute_reward(data: DataProto, reward_fn): - """ - Compute reward for a batch of data. - Args: - data: DataProto object containing the input data. - reward_fn: Reward function to compute the reward. - Returns: - Tuple of reward tensor and extra info dictionary. - """ - try: - reward_result = reward_fn(data, return_dict=True) - reward_tensor = reward_result["reward_tensor"] - reward_extra_infos_dict = reward_result.get("reward_extra_info", {}) - except Exception as e: - print(f"Error in reward_fn: {e}") - reward_tensor = reward_fn(data) - reward_extra_infos_dict = {} - - return reward_tensor, reward_extra_infos_dict - - -@ray.remote(num_cpus=1) -def compute_reward_async(data: DataProto, config, tokenizer): - """ - Load the reward manager and compute the reward for a batch of data. - This is meant to be run in a separate Ray worker. - """ - reward_fn = load_reward_manager(config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})) - return compute_reward(data, reward_fn) diff --git a/verl/trainer/runtime_env.yaml b/verl/trainer/runtime_env.yaml deleted file mode 100644 index d29f2128b..000000000 --- a/verl/trainer/runtime_env.yaml +++ /dev/null @@ -1,4 +0,0 @@ -working_dir: ./ -excludes: ["/.git/"] -env_vars: - TORCH_NCCL_AVOID_RECORD_STREAMS: "1" diff --git a/verl/utils/__init__.py b/verl/utils/__init__.py deleted file mode 100644 index 034584945..000000000 --- a/verl/utils/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import config, tokenizer -from .config import omega_conf_to_dataclass -from .tokenizer import hf_processor, hf_tokenizer - -__all__ = tokenizer.__all__ + config.__all__ + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass"] diff --git a/verl/utils/activation_offload.py b/verl/utils/activation_offload.py deleted file mode 100644 index 73e2e83eb..000000000 --- a/verl/utils/activation_offload.py +++ /dev/null @@ -1,558 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functionality for CPU offloading of tensors saved for backward pass.""" - -from __future__ import annotations - -import functools -import logging -import os -from typing import Any, Optional - -import torch -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -from verl.utils.device import get_torch_device -from verl.utils.fsdp_utils import FSDPModule as FSDP2 - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -def _get_unique_tensor_key(tensor): - key = (tensor.untyped_storage().data_ptr() + tensor.storage_offset(), tensor.dtype) - return key - - -class FSDPParameterFilter: - def __init__(self): - self.model_parameters_storage = set() - - def __call__(self, tensor): - return tensor.untyped_storage().data_ptr() not in self.model_parameters_storage - - def update_model_parameters(self, model): - new_storage = set() - for p in model.parameters(): - new_storage.add(p.data.untyped_storage().data_ptr()) - self.model_parameters_storage = new_storage - - -class CpuOffloadHookWithOffloadHandler: - """Context-manager that offloads/recovers tensors through an offload hander. - - The hook just offloads/recovers the tensor object to the handler through `tensor_push` - and `tensor_pop` interface. How the offload-handler manages the offloading, recovering - or prefetching timing is transparent to this hook. - """ - - def __init__( - self, - offload_handler: OffloadHandler, - handler_extra_kwargs: Optional[dict[str, Any]] = None, - ) -> None: - if handler_extra_kwargs is None: - handler_extra_kwargs = {} - self.offload_handler: OffloadHandler = offload_handler - self.handler_extra_kwargs: dict[str, Any] = handler_extra_kwargs - self.inside_context = False - - def __enter__(self): - self.inside_context = True - torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) - - def __exit__(self, *args: Any): - self.inside_context = False - torch._C._autograd._pop_saved_tensors_default_hooks() - - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) - return retrieve_identifier - - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) - return tensor - - -class OffloadHandler: - """A base class for CPU offload-handler.""" - - def __init__(self) -> None: - pass - - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: - """Tensor push.""" - raise NotImplementedError( - "`tensor_push is not implented in OffloadHandler class. Inherit this class and implement your " - "custom tensor_push." - ) - - def tensor_pop(self, tensor_tag: Any, **kwargs): - """Tensor pop.""" - raise NotImplementedError( - "`tensor_pop is not implented in OffloadHandler class. Inherit this class and implement your " - "custom tensor_pop." - ) - - -class GroupCommitFunction(torch.autograd.Function): - """this is a dummy op with output identical to input. - However, it is necessary for marking a timepoint for offload handler to - accomplish all synchronizations. Implementing it as a function is necessary - because we need to actions in both forward and backward. - """ - - @staticmethod - def forward(ctx, tensor, cpu_offload_handler): - # pylint: disable=missing-function-docstring - cpu_offload_handler.on_group_commit_forward() - ctx.cpu_offload_handler = cpu_offload_handler - # return the identical tensor - return tensor - - @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - cpu_offload_handler = ctx.cpu_offload_handler - cpu_offload_handler.on_group_commit_backward() - return grad_output, None - - -group_prefetch_offload_commit = GroupCommitFunction.apply - - -class SynchronizedGroupOffloadHandler(OffloadHandler): - """Offload Handler that offloads/reloads in a synchronized way. - The device-to-host and host-to-device copying happen in the same stream - as the computation kernels, thus the copying will block computation. - """ - - def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True)) -> None: - super().__init__() - - self.num_offload_group = num_offload_group - self.tensor_need_offloading_checker = tensor_need_offloading_checker - - self.groupid_reset() - - def groupid_reset(self): - """Groupid reset.""" - # Data structures to label saved tensors and book-keep their cpu copies. - # Currently, on push, create a new cpu tensor and copies; on pop, copies - # the tensor back to gpu and deletes the cpu tensor. - # These will increment whenever `group_commit()` is invoked - self.current_group, self.tensor_count_current_group = (0, 0) - self.torch_tensor_count = 0 - self.tensor_tag_to_state = {} - - def on_group_commit_forward(self): - """On group commit forward.""" - # finishing up with updating current group and tensor count - self.current_group += 1 # increment - self.tensor_count_current_group = 0 # reset - - def on_group_commit_backward(self): - """On group commit backward.""" - self.current_group -= 1 - assert self.current_group >= 0 - - @staticmethod - def offload(src_tensor, pin_memory=True): - """Offload.""" - - cpu_backup = torch.empty( - src_tensor.size(), - dtype=src_tensor.dtype, - layout=src_tensor.layout, - device="cpu", - pin_memory=pin_memory, - ) - cpu_backup.copy_(src_tensor, non_blocking=True) - state = (src_tensor.device, cpu_backup) - return state - - @staticmethod - def reload(state, non_blocking=None): - """Reload.""" - dev, cpu_backup = state - if non_blocking is None: - non_blocking = cpu_backup.is_pinned() - return cpu_backup.to(dev, non_blocking=non_blocking) - - def tensor_push(self, tensor: torch.Tensor, **kwargs): - """Tensor push.""" - # obtain a unique tensor tag - tensor_tag = (self.current_group, self.tensor_count_current_group) - self.tensor_count_current_group += 1 - assert tensor_tag not in self.tensor_tag_to_state - if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): - state = SynchronizedGroupOffloadHandler.offload(tensor) - self.tensor_tag_to_state[tensor_tag] = state - else: - # will be offloaded together after group commit - self.tensor_tag_to_state[tensor_tag] = tensor - - return tensor_tag - - def tensor_pop(self, tensor_tag, **kwargs): - """Tensor pop.""" - assert tensor_tag in self.tensor_tag_to_state - state = self.tensor_tag_to_state.pop(tensor_tag) - if isinstance(state, tuple): - tensor = SynchronizedGroupOffloadHandler.reload(state) - else: - tensor = state - return tensor - - -class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): - """Compared to synchronize, this uses more memory because of the buffer but - achieves better performance due to the overlapping. D2h and h2d copying are - completely hidden behind computation if computation time of a layer is longer - than host-device communication time. Bulk offloading with delay and bulk reloading - with prefetch are implemented.""" - - def __init__( - self, - num_offload_group, # must be <= actual number of groups (number of commits) - num_model_group, - tensor_need_offloading_checker=(lambda t: True), - ) -> None: - super().__init__( - num_offload_group=num_offload_group, - tensor_need_offloading_checker=tensor_need_offloading_checker, - ) - # Number of layers in the model - self.num_layers = num_model_group - # Data Structure to maintain reference to activation tensors - self.tensor_tag_to_buf = {} - # Tracking the number of layers offloaded - self.offloaded_group_count = 0 - # Core data structure that decides the window for offloading - self.layer_window_map = {} - self.group_offload_mapping = {} - - # Logic to make offloading load balance across computation - # for optimal CPU/GPU interconnect usage - constant = 0 - for i in range(self.num_offload_group): - self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 - if i < (self.num_layers % self.num_offload_group): - self.layer_window_map[i] += i + 1 - constant = i + 1 - else: - self.layer_window_map[i] += constant - - # allocate streams and events for synchronization - self.d2h_stream = get_torch_device().Stream() - self.h2d_stream = get_torch_device().Stream() - - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: - torch_stray_tensor = isinstance( - tensor, - torch._subclasses.fake_tensor.FakeTensor | torch._subclasses.functional_tensor.FunctionalTensor, - ) - need_offload = not torch_stray_tensor - need_offload = need_offload and self.tensor_need_offloading_checker(tensor) - - if need_offload: - # obtain a unique tensor tag - tensor_tag = (self.current_group, self.tensor_count_current_group) - self.tensor_count_current_group += 1 - - assert tensor_tag not in self.tensor_tag_to_state - self.tensor_tag_to_state[tensor_tag] = tensor - - if self.current_group < self.num_offload_group: - self.tensor_tag_to_buf[tensor_tag] = tensor - else: - tensor_tag = tensor - return tensor_tag - - def tensor_pop(self, tensor_tag, **kwargs): - """Tensor pop.""" - if isinstance(tensor_tag, torch.Tensor): - return tensor_tag - assert tensor_tag in self.tensor_tag_to_state - tensor = self.tensor_tag_to_state.pop(tensor_tag) - self.tensor_tag_to_buf.pop(tensor_tag, None) - - # the tensor should have been copied back in on_group_commit_backward() - # which invokes bulk_reload_group. - assert not isinstance(tensor, tuple) - return tensor - - def bulk_offload_group(self, group_to_offload): - """Bulk offload group.""" - offload_mapping = {} - offload_size = 0 - with get_torch_device().stream(self.d2h_stream): - for tensor_tag, state in self.tensor_tag_to_state.items(): - group_id, _ = tensor_tag - if group_id == group_to_offload: - assert not isinstance(state, tuple) - key = _get_unique_tensor_key(state) - if key not in offload_mapping: - offload_mapping[key] = state - # if offload, return the reference to cpu copy - self.tensor_tag_to_state[tensor_tag] = (key, state.shape) - for key, tensor in offload_mapping.items(): - state = SynchronizedGroupOffloadHandler.offload(tensor) - offload_size += tensor.numel() * tensor.element_size() - offload_mapping[key] = state - - self.group_offload_mapping[group_to_offload] = offload_mapping - - def synchronize_on_group_commit_forward(self, current_group): - """Synchronize on group commit forward.""" - - # For the first group, kickstart the offload after we have - # the first compute completion - if current_group == 0: - self.d2h_stream.wait_stream(get_torch_device().current_stream()) - self.bulk_offload_group(current_group) - - # Window map data structure helps us synchronize based on number - # of layers offloaded - if self.layer_window_map[self.offloaded_group_count] == current_group: - # Stream synchronization both ways - self.d2h_stream.wait_stream(get_torch_device().current_stream()) - get_torch_device().current_stream().wait_stream(self.d2h_stream) - - # Time to free the activation memory after usage - for tensor_tag, _ in self.tensor_tag_to_buf.items(): - if tensor_tag[0] == self.offloaded_group_count: - self.tensor_tag_to_buf[tensor_tag] = None - - # Time to offload the next group - if self.offloaded_group_count < (self.num_offload_group - 1): - self.bulk_offload_group(self.offloaded_group_count + 1) - - # Increment the offload group count to keep track - self.offloaded_group_count += 1 - - def on_group_commit_forward(self): - """This function will cause host device synchronization""" - # handle synchronization events - self.synchronize_on_group_commit_forward(self.current_group) - - super().on_group_commit_forward() - - @torch.no_grad - def bulk_reload_group(self, group_to_reload): - """Bulk reload group.""" - assert group_to_reload < self.num_offload_group - - with get_torch_device().stream(self.h2d_stream): - # move back tensors - offload_mapping = self.group_offload_mapping.pop(group_to_reload) - assert offload_mapping is not None - for key, state in offload_mapping.items(): - offload_mapping[key] = SynchronizedGroupOffloadHandler.reload(state) - for tensor_label, state in self.tensor_tag_to_state.items(): - group_id, _ = tensor_label - if group_id == group_to_reload and not isinstance(state, torch.Tensor): - assert isinstance(state, tuple), f"{group_id} {state}" - key, shape = state - recovered_tensor = offload_mapping[key].view(shape) - self.tensor_tag_to_state[tensor_label] = recovered_tensor - - def on_group_commit_backward(self): - # first decrement the current group. - # after last commit in forward, the group will +1; in backward it -1. - # Finally it should be decremented to 0. - self.current_group -= 1 - assert self.current_group >= 0 - - # Layer window data structure helps us to reload at right times - if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: - # Stream synchronization both ways - self.h2d_stream.wait_stream(get_torch_device().current_stream()) - get_torch_device().current_stream().wait_stream(self.h2d_stream) - - # Time to reload the next group - self.bulk_reload_group(self.offloaded_group_count - 1) - - # Decrease the offloading group counter - self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 - - # Last group computation needs to wait till all the reloads complete - if self.current_group == 0: - get_torch_device().current_stream().wait_stream(self.h2d_stream) - self.offloaded_group_count = 0 - - -def get_activation_offload_context( - num_layers: int = 1, model_layers: int = 1, tensor_need_offloading_checker=(lambda t: True) -): - cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( - num_offload_group=num_layers, - num_model_group=model_layers, - tensor_need_offloading_checker=tensor_need_offloading_checker, - ) - - def group_prefetch_offload_commit_async(tensor): - return group_prefetch_offload_commit(tensor, cpu_offload_handler) - - return ( - CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), - group_prefetch_offload_commit_async, - ) - - -class ActivationHandler: - def __init__(self, offload_ctx, sync_func, tensor_filter, enable_ckpt): - self._offload_ctx = offload_ctx - self._sync_func = sync_func - self._enable_ckpt = enable_ckpt - self._tensor_filter = tensor_filter - if enable_ckpt: - self.checkpoint_fn = functools.partial( - torch.utils.checkpoint.checkpoint, - use_reentrant=True, - ) - - def pre_forward(self, module): - if module.training: - self._offload_ctx.__enter__() - self._tensor_filter.update_model_parameters(module) - - def post_forward(self, module): - if module.training: - self._offload_ctx.__exit__(None, None, None) - - def _pack_kwargs(self, *args, **kwargs): - kwarg_keys = [] - flat_args = list(args) - for k, v in kwargs.items(): - kwarg_keys.append(k) - flat_args.append(v) - - return tuple(flat_args), tuple(kwarg_keys) - - def _unpack_kwargs(self, flat_args, kwarg_keys): - assert len(kwarg_keys) <= len(flat_args), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" - if len(kwarg_keys) == 0: - return flat_args, {} - args = flat_args[: -len(kwarg_keys)] - kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :], strict=True)) - return args, kwargs - - def _ckpt_forward(self, forward_method, *args, **kwargs): - flat_args, kwarg_keys = self._pack_kwargs(*args, **kwargs) - - def my_function(*inputs): - # unpack back into args and kwargs - nonlocal forward_method, kwarg_keys - unpacked_args, unpacked_kwargs = self._unpack_kwargs(inputs, kwarg_keys) - # run original module - return forward_method(*unpacked_args, **unpacked_kwargs) - - return self.checkpoint_fn( - my_function, - *flat_args, - ) - - def forward(self, module, forward_method, *args, **kwargs): - if not module.training: - return forward_method(*args, **kwargs) - if not self._enable_ckpt: - ret = forward_method(*args, **kwargs) - else: - ret = self._ckpt_forward(forward_method, *args, **kwargs) - binded_tensor = ret - if isinstance(ret, tuple): - binded_tensor = ret[0] - binded_tensor = self._sync_func(binded_tensor) - final_ret = binded_tensor - if isinstance(ret, tuple): - final_ret = (final_ret,) + ret[1:] - return final_ret - - def wrap_module_forward_method(self, module): - orig_method = module.forward - handler = self - - @functools.wraps(orig_method) - def wrapped_method(model_self, *args, **kwargs): - nonlocal handler - handler.pre_forward(model_self) - out = handler.forward(model_self, orig_method, *args, **kwargs) - handler.post_forward(model_self) - return out - - module.forward = wrapped_method.__get__(module, type(module)) - - -def enable_activation_offloading(model, strategy, enable_ckpt=False): - """ - Enable activation offloading for the model. It groups activations by TransformerLayer and offloads activation - groups asynchronously. This means that the offloading of the i-th activation group and the computation of the i+1-th - activation group happen at the same time, and there are at most two activation groups in GPU memory. - - Args: - model: the model to enable activation offloading - strategy: the training strategy of the model, such as "fsdp" - enable_ckpt: whether activation checkpointing(also called gradient checkpointing) has been enabled for the model - - Note: - For best efficiency, activation offloading is usually combined with activation checkpointing. However, this - implementation of activation offloading is conflicted with the implementation of activation checkpointing in - some training strategies. This function resolves this conflict, and therefore requires the "strategy" and - "enable_ckpt" arguments. - - Returns: - - """ - - assert strategy == "fsdp" or strategy == "fsdp2", "activation offloading only supports fsdp strategy" - layers = [] - - def get_layers(module): - for name, child in module.named_children(): - if not isinstance(child, FSDP | FSDP2): - get_layers(child) - else: - wrapped_module = child - if isinstance(child, FSDP): - wrapped_module = child._fsdp_wrapped_module - # In some cases, torch.nn.Embedding is wrapped with FSDP alone. However, the activation - # size of torch.nn.Embedding is small, so it's not necessary to offload it. - if not isinstance(wrapped_module, torch.nn.Embedding): - layers.append(child) - - get_layers(model) - if len(layers) < 3: - logger.warning(f"Find only {len(layers)} fsdp layers, not neccessary to enable async activation offloading") - return - - tensor_filter = FSDPParameterFilter() - context, sync_func = get_activation_offload_context(len(layers) - 1, len(layers), tensor_filter) - if enable_ckpt: - # The implementation of activation checkpointing in transformers library is incompatible with - # activation offloading, - # so it will be disabled, but this implementation supports another version of activation checkpointing, so that - # these two features can be enabled at the same time. - for module in model.modules(): - if hasattr(module, "gradient_checkpointing_disable"): - module.gradient_checkpointing_disable() - - handler = ActivationHandler(context, sync_func, tensor_filter, enable_ckpt) - for layer in layers: - module = layer - if isinstance(layer, FSDP): - module = module._fsdp_wrapped_module - handler.wrap_module_forward_method(module) diff --git a/verl/utils/checkpoint/__init__.py b/verl/utils/checkpoint/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/utils/checkpoint/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py deleted file mode 100644 index 9659b7b89..000000000 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ /dev/null @@ -1,260 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import random -import shutil -import tempfile -from filelock import FileLock - -import numpy as np -import torch -import torch.distributed -from omegaconf import DictConfig -from transformers import PreTrainedTokenizer, ProcessorMixin - -from verl.utils.device import get_device_name, get_torch_device - - -class BaseCheckpointManager: - """ - A checkpoint manager that saves and loads - - model - - optimizer - - lr_scheduler - - extra_states - in a SPMD way. - - We save - - sharded model states and optimizer states - - full lr_scheduler states - - huggingface tokenizer and config for ckpt merge - """ - - def __init__( - self, - model, - optimizer: torch.optim.Optimizer, - lr_scheduler: torch.optim.lr_scheduler.LRScheduler = None, - processing_class: PreTrainedTokenizer | ProcessorMixin = None, - checkpoint_config: DictConfig = None, - ): - self.checkpoint_config = checkpoint_config - checkpoint_load_contents = checkpoint_config.get("load_contents", None) if checkpoint_config else None - checkpoint_save_contents = checkpoint_config.get("save_contents", None) if checkpoint_config else None - if checkpoint_load_contents is None: - checkpoint_load_contents = ["model", "optimizer", "extra"] - if checkpoint_save_contents is None: - checkpoint_save_contents = ["model", "optimizer", "extra"] - self.previous_global_step = None - self.previous_saved_paths = [] - - self.model = model - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.processing_class = processing_class - self.checkpoint_load_contents = checkpoint_load_contents - self.checkpoint_save_contents = checkpoint_save_contents - - self.rank = torch.distributed.get_rank() - self.world_size = torch.distributed.get_world_size() - - @property - def should_save_model(self) -> bool: - """ - Returns True if 'model' is in checkpoint_save_contents, indicating the model state should be saved. - """ - return "model" in self.checkpoint_save_contents - - @property - def should_save_optimizer(self) -> bool: - """ - Returns True if 'optimizer' is in checkpoint_save_contents, indicating the optimizer state should be saved. - """ - return "optimizer" in self.checkpoint_save_contents - - @property - def should_save_extra(self) -> bool: - """ - Returns True if 'extra' is in checkpoint_save_contents, indicating the extra state should be saved. - """ - return "extra" in self.checkpoint_save_contents - - @property - def should_save_hf_model(self) -> bool: - """ - Returns True if 'hf_model' is in checkpoint_save_contents, indicating the model should be converted to hf - model and saved. - """ - return "hf_model" in self.checkpoint_save_contents - - @property - def should_load_model(self) -> bool: - """ - Returns True if 'model' is in checkpoint_load_contents, indicating the model state should be loaded. - """ - return "model" in self.checkpoint_load_contents - - @property - def should_load_optimizer(self) -> bool: - """ - Returns True if 'optimizer' is in checkpoint_load_contents, indicating the optimizer state should be loaded. - """ - return "optimizer" in self.checkpoint_load_contents - - @property - def should_load_extra(self) -> bool: - """ - Returns True if 'extra' is in checkpoint_load_contents, indicating the extra state should be loaded. - """ - return "extra" in self.checkpoint_load_contents - - def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load: bool = False): - raise NotImplementedError - - def save_checkpoint( - self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep: int = None - ): - raise NotImplementedError - - @staticmethod - def checkpath(local_path: str, hdfs_path: str): - assert local_path is not None or hdfs_path is not None, "local_path and hdfs_path cannot be both None" - return local_path is not None, local_path if local_path is not None else hdfs_path - - def remove_previous_save_local_path(self, path): - if isinstance(path, str): - path = [path] - for p in path: - abs_path = os.path.abspath(p) - print(f"Checkpoint manager remove previous save local path: {abs_path}") - if not os.path.exists(abs_path): - continue - shutil.rmtree(abs_path, ignore_errors=True) - - @staticmethod - def local_mkdir(path): - if not os.path.isabs(path): - working_dir = os.getcwd() - path = os.path.join(working_dir, path) - - # Using hash value of path as lock file name to avoid long file name - lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock" - lock_path = os.path.join(tempfile.gettempdir(), lock_filename) - - try: - with FileLock(lock_path, timeout=60): # Add timeout - # make a new dir - os.makedirs(path, exist_ok=True) - except Exception as e: - print(f"Warning: Failed to acquire lock for {path}: {e}") - # Even if the lock is not acquired, try to create the directory - os.makedirs(path, exist_ok=True) - - return path - - @staticmethod - def get_rng_state(): - rng_state = { - "cpu": torch.get_rng_state(), - "numpy": np.random.get_state(), - "random": random.getstate(), - } - - if get_device_name() != "cpu": - rng_state[get_device_name()] = get_torch_device().get_rng_state() - - return rng_state - - @staticmethod - def load_rng_state(rng_state): - torch.set_rng_state(rng_state["cpu"]) - np.random.set_state(rng_state["numpy"]) - random.setstate(rng_state["random"]) - - if get_device_name() != "cpu": - get_torch_device().set_rng_state(rng_state[get_device_name()]) - - -def find_latest_ckpt_path(path, directory_format="global_step_{}"): - """ - Return the most recent checkpoint directory based on a tracker file. - - Args: - path (str): Base directory containing the checkpoint tracker. - directory_format (str): Template for checkpoint subfolders with one - placeholder for the iteration number (default "global_step_{}"). - - Returns: - str or None: Full path to the latest checkpoint directory, or - None if the tracker or checkpoint folder is missing. - """ - if path is None: - return None - - tracker_file = get_checkpoint_tracker_filename(path) - if not os.path.exists(tracker_file): - print(f"Checkpoint tracker file does not exist: {tracker_file}") - return None - - with open(tracker_file, "rb") as f: - iteration = int(f.read().decode()) - ckpt_path = os.path.join(path, directory_format.format(iteration)) - if not os.path.exists(ckpt_path): - print("Checkpoint does not exist: %s", ckpt_path) - return None - - print("Found checkpoint: %s", ckpt_path) - return ckpt_path - - -def get_checkpoint_tracker_filename(root_path: str): - """ - Tracker file rescords the latest chckpoint during training to restart from. - """ - return os.path.join(root_path, "latest_checkpointed_iteration.txt") - - -def should_save_ckpt_esi(max_steps_duration: float, save_ckpt_duration: float = 60, redundant_time: float = 0) -> bool: - """ - Determine if checkpoint should be saved based on capacity esi expiration. - - Args: - max_steps_duration: Max estimated time (seconds) required to complete one training step - save_ckpt_duration: Estimated time (seconds) required to save checkpoint (default: 60) - redundant_time: Additional buffer time (seconds) for unexpected delays (default: 0) - """ - exp_ts_mlp = os.getenv("MLP_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP") # vemlp - exp_ts_aws = os.getenv("SAGEMAKER_CURRENT_CAPACITY_BLOCK_EXPIRATION_TIMESTAMP") # aws - if exp_ts_mlp: - try: - import time - - remaining = float(exp_ts_mlp) - time.time() - except ValueError: - return False - return ( - remaining > 0 - and max_steps_duration > 0 - and remaining <= save_ckpt_duration + max_steps_duration + redundant_time - ) - elif exp_ts_aws: - from datetime import datetime, timedelta - - expiration_time = datetime.fromtimestamp(int(exp_ts_aws)) - time_difference = expiration_time - datetime.now() - threshold_minutes = (save_ckpt_duration + max_steps_duration + redundant_time) / 60 - return time_difference < timedelta(minutes=threshold_minutes) - else: - return False diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py deleted file mode 100644 index e81aebbd0..000000000 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ /dev/null @@ -1,351 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -import os -import warnings -from dataclasses import asdict, dataclass -from typing import Optional - -import torch -import torch.distributed -from accelerate import init_empty_weights -from omegaconf import DictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType -from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin - -from verl.utils.device import is_cuda_available -from verl.utils.fs import copy_to_local, is_non_local, local_mkdir_safe -from verl.utils.fsdp_utils import fsdp_version, get_fsdp_full_state_dict, get_fsdp_state_ctx -from verl.utils.logger import log_with_rank - -from .checkpoint_manager import BaseCheckpointManager - -# Setup logging -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) - - -@dataclass -class FSDPConfig: - """Configuration for FSDP checkpointing. - - Args: - FSDP_version (int): Version of FSDP being used. - world_size (int): Number of processes in the distributed training setup. - """ - - FSDP_version: int - world_size: int - - -class FSDPCheckpointManager(BaseCheckpointManager): - """ - Manage FSDP checkpointing in SPMD training. - - - Saves/loads per-rank sharded model & optimizer states - - Persists full lr_scheduler and RNG state - - Stores HF tokenizer/processor and model/config for unified restore - - Args: - model (FSDP): Wrapped model instance. - optimizer (Optimizer): Training optimizer. - lr_scheduler (LRScheduler): Learning-rate scheduler. - processing_class (PreTrainedTokenizer or ProcessorMixin, optional): - Pre-/post-processing artifact handler. - checkpoint_contents DictConfig: Configuration for checkpoint contents. - - 'load': Components to load; must contain 'model'. Defaults to ['model', 'optimizer', 'extra']. - - 'save': Components to save; must contain 'model'. Defaults to ['model', 'optimizer', 'extra']. - """ - - def __init__( - self, - model: FSDP, - optimizer: Optional[torch.optim.Optimizer] = None, - lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, - processing_class: PreTrainedTokenizer | ProcessorMixin = None, - checkpoint_config: DictConfig = None, - **kwargs, - ): - if processing_class is None: - assert "tokenizer" in kwargs, "tokenizer or processor must be provided" - warnings.warn( - "`tokenizer` is deprecated. use `processing_class` instead.", DeprecationWarning, stacklevel=2 - ) - processing_class = kwargs.pop("tokenizer") - - super().__init__( - model, - optimizer, - lr_scheduler=lr_scheduler, - processing_class=processing_class, - checkpoint_config=checkpoint_config, - ) - - def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): - """ - Load an FSDP checkpoint for this rank. - - Downloads and loads: - - model and optimizer shards - - extra state dict (scheduler + RNG) - - Args: - local_path: Directory with per-rank checkpoint files. - hdfs_path: Unused (for API compatibility). - del_local_after_load: Remove local files after loading. - """ - if local_path is None: - return - - # check if the checkpoint_load_contents is valid - if self.should_load_model: - assert self.model is not None, "model must be provided when checkpoint_contents.load includes ['model']" - if self.should_load_optimizer: - assert self.optimizer is not None, ( - "optimizer must be provided when checkpoint_contents.load includes ['optimizer']" - ) - - # every rank download its own checkpoint - state_dict_cfg = ( - ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - if self.should_load_model - else None - ) - optim_cfg = ( - ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - if self.should_load_optimizer - else None - ) - with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): - if self.should_load_model: - remote_model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") - local_model_path = copy_to_local(remote_model_path) - model_state_dict = torch.load(local_model_path, weights_only=False) - self.model.load_state_dict(model_state_dict) - log_with_rank(f"Loaded model from {remote_model_path}", rank=self.rank, logger=logger) - - if self.should_load_optimizer: - remote_optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") - local_optim_path = copy_to_local(remote_optim_path) - optimizer_state_dict = torch.load(local_optim_path, weights_only=False) - self.optimizer.load_state_dict(optimizer_state_dict) - log_with_rank(f"Loaded optimizer from {remote_optim_path}", rank=self.rank, logger=logger) - - if self.should_load_extra: - remote_extra_state_path = os.path.join( - local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" - ) - local_extra_state_path = copy_to_local(remote_extra_state_path) - extra_state_dict = torch.load(local_extra_state_path, weights_only=False) - # recover random state - if "rng" in extra_state_dict: - # 'rng' may not exist for backward compatibility - self.load_rng_state(extra_state_dict["rng"]) - log_with_rank(f"Loaded rng from {remote_extra_state_path}", rank=self.rank, logger=logger) - - lr_scheduler_state_dict = extra_state_dict["lr_scheduler"] - if lr_scheduler_state_dict is not None and self.lr_scheduler is not None: - self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) - log_with_rank(f"Loaded lr_scheduler from {remote_extra_state_path}", rank=self.rank, logger=logger) - - if self.rank == 0 and del_local_after_load: - try: - os.remove(local_model_path) if is_non_local(local_model_path) else None - os.remove(local_optim_path) if is_non_local(local_optim_path) else None - os.remove(local_extra_state_path) if is_non_local(local_extra_state_path) else None - except Exception as e: - log_with_rank( - f"remove local resume ckpt file after loading failed, exception {e} will be ignored", - rank=self.rank, - logger=logger, - ) - - # wait for everyone to load checkpoints - torch.distributed.barrier() - - def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): - """ - Save an FSDP checkpoint for this rank. - - Writes: - - model & optimizer shard files - - extra state dict (scheduler + RNG) - - HF tokenizer/processor and model/config on rank 0 - - optional full HF model under 'huggingface/' if requested - - Rotates old checkpoints, keeping at most `max_ckpt_to_keep`. - - Args: - local_path: Target directory for checkpoint files. - hdfs_path: Unused (for API compatibility). - global_step: Current training step (used for bookkeeping). - max_ckpt_to_keep: Number of recent checkpoints to retain. - """ - if local_path is None: - return - - # record the previous global step - self.previous_global_step = global_step - - # remove previous local_path, only rank 0 should do this - if ( - self.rank == 0 - and max_ckpt_to_keep - and isinstance(max_ckpt_to_keep, int) - and max_ckpt_to_keep > 0 - and len(self.previous_saved_paths) >= max_ckpt_to_keep - ): - keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 - self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) - self.previous_saved_paths = self.previous_saved_paths[keep_start:] - - # if self.rank == 0: # added by Reasoning360: file system got problem on rank0 when co-current making dirs, so we make dirs on rank0 only - local_path = local_mkdir_safe(local_path) - torch.distributed.barrier() - - # check if the checkpoint_save_contents is valid - if self.should_save_model: - assert self.model is not None, "model must be provided when checkpoint_contents.save includes ['model']" - if self.should_save_optimizer: - assert self.optimizer is not None, ( - "optimizer must be provided when checkpoint_contents.save includes ['optimizer']" - ) - - # every rank will save its own model and optim shard - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - with get_fsdp_state_ctx(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): - model_path = os.path.join(local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt") - optim_path = os.path.join(local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt") - extra_path = os.path.join(local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt") - - if self.should_save_model: - model_state_dict = self.model.state_dict() - torch.save(model_state_dict, model_path) - log_with_rank(f"Saved model to {os.path.abspath(model_path)}", rank=self.rank, logger=logger) - - if self.should_save_optimizer: - optimizer_state_dict = self.optimizer.state_dict() - torch.save(optimizer_state_dict, optim_path) - log_with_rank(f"Saved optim to {os.path.abspath(optim_path)}", rank=self.rank, logger=logger) - - if self.should_save_extra: - lr_scheduler_state_dict = self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None - extra_state_dict = { - "lr_scheduler": lr_scheduler_state_dict, - "rng": self.get_rng_state(), - } - torch.save(extra_state_dict, extra_path) - log_with_rank(f"Saved extra_state to {os.path.abspath(extra_path)}", rank=self.rank, logger=logger) - - if self.rank == 0: - # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether - # huggingface model is requested to be saved or not. - - if fsdp_version(self.model) == 1: - unwrap_model = self.model._fsdp_wrapped_module - else: - unwrap_model = self.model - - hf_config_tokenizer_path = os.path.join(local_path, "huggingface") - local_mkdir_safe(hf_config_tokenizer_path) - model_config = unwrap_model.config - generation_config = None - if unwrap_model.can_generate() and hasattr(model_config, "name_or_path") and model_config.name_or_path: - try: - # Some model's name_or_path is empty if not initialized from pretrained, - # in this cases, we don't save generation config. - generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) - generation_config.save_pretrained(hf_config_tokenizer_path) - except Exception: - # if the generation config isn't available, we don't save it - pass - - model_config.save_pretrained(hf_config_tokenizer_path) - self.processing_class.save_pretrained(hf_config_tokenizer_path) - log_with_rank( - f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) - - # Also save runtime FSDP config - fsdp_config_path = os.path.join(local_path, "fsdp_config.json") - fsdp_config = FSDPConfig( - FSDP_version=fsdp_version(self.model), - world_size=self.world_size, - ) - with open(fsdp_config_path, "w") as f: - json.dump(asdict(fsdp_config), f, indent=4) - - # wait for everyone to dump to local - torch.distributed.barrier() - - if self.should_save_hf_model: - # Only rank 0 will save hf model and, - # offload to cpu to save LLMs which may be too large to fit in one GPU - state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True) - - if self.rank == 0: - hf_local_path = os.path.join(local_path, "huggingface") - os.makedirs(hf_local_path, exist_ok=True) - - if "ForTokenClassification" in model_config.architectures[0]: - from transformers import AutoModelForTokenClassification - - auto_model_cls = AutoModelForTokenClassification - elif "ForCausalLM" in model_config.architectures[0]: - from transformers import AutoModelForCausalLM - - auto_model_cls = AutoModelForCausalLM - elif "ForConditionalGeneration" in model_config.architectures[0]: - from transformers import AutoModelForVision2Seq - - auto_model_cls = AutoModelForVision2Seq - else: - raise NotImplementedError(f"Unknown architecture {model_config['architectures']}") - - with init_empty_weights(): - save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16) - save_model.to_empty(device="cpu") - - if save_model.can_generate(): - if generation_config is not None: - save_model.generation_config = generation_config - else: - print( - f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found " - f"in, using a generation config created from the model config when saving hf_model." - ) - - save_model.save_pretrained(hf_local_path, state_dict=state_dict) - log_with_rank( - f"Saved hf_model to {os.path.abspath(hf_local_path)}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) - del state_dict - del save_model - - # wait for rank0 to dump hf_model to local - torch.distributed.barrier() - - self.previous_saved_paths.append(local_path) diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py deleted file mode 100644 index b9fcc551b..000000000 --- a/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ /dev/null @@ -1,526 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import logging -import os -import random -from collections.abc import Callable -from dataclasses import asdict - -import numpy as np -import torch -import torch.distributed -from megatron.core import mpu, tensor_parallel -from megatron.core.dist_checkpointing.mapping import ShardedObject -from megatron.core.transformer.enums import AttnBackend -from transformers import GenerationConfig - -from verl.models.weight_loader_registry import get_weight_saver -from verl.utils.device import get_device_name, get_torch_device -from verl.utils.fs import is_non_local, local_mkdir_safe -from verl.utils.logger import log_with_rank -from verl.utils.megatron.dist_checkpointing import load_dist_checkpointing, save_dist_checkpointing -from verl.utils.megatron_utils import ( - get_dist_checkpoint_path, - get_hf_model_checkpoint_path, - get_transformer_config_checkpoint_path, -) - -from .checkpoint_manager import BaseCheckpointManager - -# Setup logging -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) - - -class MegatronCheckpointManager(BaseCheckpointManager): - """ - Checkpoint manager for Megatron-LM distributed training. - - This class manages the saving and loading of model checkpoints in a Megatron-LM - distributed training environment. It handles various aspects of checkpointing - including model states, optimizer states, learning rate schedulers, and random - number generator states, ensuring compatibility with HuggingFace formats. - - Key features: - - Distributed checkpoint saving and loading using Megatron's dist_checkpointing - - Support for tensor parallel, pipeline parallel, and data parallel configurations - - Automatic handling of model state dictionaries across multiple pipeline stages - - Integration with HuggingFace model configurations and tokenizers - - Random number generator state management for reproducibility - - Support for both synchronous and asynchronous checkpoint operations - - The manager automatically handles: - - Directory structure creation based on global steps and process ranks - - Model configuration and tokenizer saving in HuggingFace format - - Optimizer and scheduler state persistence - - CUDA RNG state management for deterministic training - - Checkpoint cleanup and retention policies - - Args: - model: The Megatron model instance to checkpoint - optimizer: The optimizer instance (optional) - lr_scheduler: The learning rate scheduler instance (optional) - - Attributes: - model: Reference to the Megatron model being checkpointed - optimizer: Reference to the optimizer (if provided) - lr_scheduler: Reference to the learning rate scheduler (if provided) - rank: Current process rank in the distributed setup - - Example: - ```python - checkpoint_manager = MegatronCheckpointManager( - model=megatron_model, - optimizer=optimizer, - lr_scheduler=scheduler - ) - - checkpoint_manager.save_checkpoint( - local_path="checkpoints/step_1000", - global_step=1000 - ) - - checkpoint_manager.load_checkpoint( - local_path="checkpoints/step_1000" - ) - ``` - """ - - def __init__( - self, - config, - checkpoint_config, - model_config, - transformer_config, - role, - model: torch.nn.ModuleList, - arch: str, - hf_config, - param_dtype: torch.dtype, - share_embeddings_and_output_weights: bool, - processing_class, - optimizer, - optimizer_scheduler, - use_distributed_optimizer: bool, - use_checkpoint_opt_param_scheduler: bool = False, - use_dist_checkpointing: bool = True, - bridge=None, - **kwargs, - ): - super().__init__( - model, - optimizer=optimizer, - lr_scheduler=optimizer_scheduler, - processing_class=processing_class, - checkpoint_config=checkpoint_config, - ) - self.arch = arch - self.config = config - self.transformer_config = transformer_config - self.role = role - self.is_value_model = False - if self.role in ["reward", "critic"]: - self.is_value_model = True - self.model_config = model_config - self.hf_config = hf_config - self.param_dtype = param_dtype - self.share_embeddings_and_output_weights = share_embeddings_and_output_weights - self.model_path = self.config.model.path - self.use_distributed_optimizer = use_distributed_optimizer - self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler - self.bridge = bridge - self.rank = torch.distributed.get_rank() - self.use_dist_checkpointing = use_dist_checkpointing or not self.bridge or self.is_value_model - self.use_hf_checkpoint = not self.use_dist_checkpointing - - self.weight_saver = get_weight_saver(self.arch) - - def get_rng_state(self, use_dist_ckpt: bool = True, data_parallel_random_init: bool = False): - """collect rng state across data parallel ranks""" - rng_state = { - "random_rng_state": random.getstate(), - "np_rng_state": np.random.get_state(), - "torch_rng_state": torch.get_rng_state(), - "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), - } - - if get_device_name() != "cpu": - rng_state[f"{get_device_name()}_rng_state"] = get_torch_device().get_rng_state() - - rng_state_list = None - if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: - rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())] - torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group()) - else: - rng_state_list = [rng_state] - - if use_dist_ckpt: - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() - rng_state_list = ShardedObject( - "rng_state", - rng_state_list, - (pp_size, tp_size), - (pp_rank, tp_rank), - replica_id=mpu.get_data_parallel_rank(with_context_parallel=True), - ) - - return rng_state_list - - def get_checkpoint_name( - self, - checkpoints_path, - pipeline_parallel=None, - tensor_rank=None, - pipeline_rank=None, - cp_rank=None, - expert_parallel=None, - expert_rank=None, - return_base_dir=True, - basename="model.pt", - ): - """Determine the directory name for this rank's checkpoint.""" - # Use both the tensor and pipeline MP rank. - if pipeline_parallel is None: - pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1 - if tensor_rank is None: - tensor_rank = mpu.get_tensor_model_parallel_rank() - if pipeline_rank is None: - pipeline_rank = mpu.get_pipeline_model_parallel_rank() - if cp_rank is None: - cp_rank = mpu.get_context_parallel_rank() - if expert_parallel is None: - expert_parallel = mpu.get_expert_model_parallel_world_size() > 1 - if expert_rank is None: - expert_rank = mpu.get_expert_model_parallel_rank() - - # Use both the tensor and pipeline MP rank. If using the distributed - # optimizer, then the optimizer's path must additionally include the - # data parallel rank. - - # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path - if not pipeline_parallel: - common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}") - else: - common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}") - - if expert_parallel: - common_path = common_path + f"_{expert_rank:03d}" - - # NOTE: Added by Reasoning360: replace os.makedirs by local_mkdir - local_mkdir_safe(common_path) - - if return_base_dir: - return common_path - return os.path.join(common_path, basename) - - def generate_state_dict(self): - # For save dist checkpointing - state_dict = {} - - # All ranks Save Model to reduce memory pressure - if self.should_save_model or self.should_load_model: - # Get sharded state dict, notice that state_dict will collect among dp groups, causing memory pressure - for vpp_rank, model in enumerate(self.model): - if len(self.model) > 1: - mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) - key = f"model{vpp_rank}" if len(self.model) > 1 else "model" - else: - key = "model" - if hasattr(model, "module"): - model = model.module - state_dict[key] = model.sharded_state_dict() - - # Optimizer State Dict - if self.should_save_optimizer or self.should_load_optimizer: - torch.distributed.barrier() - optimizer_sharded_states = self.optimizer.sharded_state_dict(state_dict) - state_dict["optimizer"] = optimizer_sharded_states - - if self.lr_scheduler is not None: - lr_state_dict = self.lr_scheduler.state_dict() - state_dict["lr_scheduler"] = lr_state_dict - - # RNG States State Dict - if self.should_save_extra or self.should_load_extra: - torch.distributed.barrier() - rng_state = self.get_rng_state() - state_dict["rng_state"] = rng_state - - return state_dict - - def load_rng_states(self, rng_states, data_parallel_random_init=False, use_dist_ckpt=True): - # access rng_state for data parallel rank - if data_parallel_random_init: - rng_states = rng_states[mpu.get_data_parallel_rank()] - else: - rng_states = rng_states[0] - random.setstate(rng_states["random_rng_state"]) - np.random.set_state(rng_states["np_rng_state"]) - torch.set_rng_state(rng_states["torch_rng_state"]) - - if get_device_name() != "cpu": - get_torch_device().set_rng_state(rng_states[f"{get_device_name()}_rng_state"]) - - # Check for empty states array - if not rng_states["rng_tracker_states"]: - raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states(rng_states["rng_tracker_states"]) - - def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False): - if local_path is not None: - assert os.path.exists(local_path), f"Checkpoint path {local_path} does not exist." - - dist_checkpoint_path = get_dist_checkpoint_path(local_path) - - # Get State Dict for loading - sharded_state_dict = self.generate_state_dict() - log_with_rank(f"Generated state dict for saving: {sharded_state_dict.keys()}", rank=self.rank, logger=logger) - for vpp_rank, model in enumerate(self.model): - if len(self.model) > 1: - model_i_keys = sharded_state_dict[f"model{vpp_rank}"].keys() - log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger) - else: - log_with_rank( - f"Generated state dict for saving: {sharded_state_dict['model'].keys()}", - rank=self.rank, - logger=logger, - ) - - # Load Dist Checkpointing - state_dict = load_dist_checkpointing( - sharded_state_dict=sharded_state_dict, - ckpt_dir=dist_checkpoint_path, - ) - - if self.should_load_model and self.use_dist_checkpointing: - assert "model" in state_dict or any( - f"model{vpp_rank}" in state_dict for vpp_rank in range(len(self.model)) - ), f"Model state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." - for vpp_rank, model in enumerate(self.model): - if len(self.model) == 1: - model_state_dict = state_dict["model"] - else: - assert f"model{vpp_rank}" in state_dict, f"model{vpp_rank} not found in state_dict" - model_state_dict = state_dict[f"model{vpp_rank}"] - mpu.set_virtual_pipeline_model_parallel_rank(vpp_rank) - self.model[vpp_rank].load_state_dict(model_state_dict) - log_with_rank(f"Loaded sharded model checkpoint from {local_path}", rank=self.rank, logger=logger) - elif self.should_load_model and self.use_hf_checkpoint: - hf_model_path = get_hf_model_checkpoint_path(local_path) - self.bridge.load_weights(self.model, hf_model_path) - log_with_rank(f"Loaded HF model checkpoint from {hf_model_path} with bridge", rank=self.rank, logger=logger) - - if self.should_load_optimizer: - assert "optimizer" in state_dict, ( - f"Optimizer state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." - ) - optimizer_state_dict = state_dict["optimizer"] - self.optimizer.load_state_dict(optimizer_state_dict) - log_with_rank(f"Loaded optimizer checkpoint from {local_path}", rank=self.rank, logger=logger) - if self.use_checkpoint_opt_param_scheduler: - assert "lr_scheduler" in state_dict, ( - f"LR scheduler state dict not found in {state_dict.keys()}. Please check the checkpoint file " - f"{local_path}." - ) - lr_scheduler_state_dict = state_dict["lr_scheduler"] - if self.lr_scheduler is not None: - self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) - log_with_rank(f"Loaded LR scheduler checkpoint from {local_path}", rank=self.rank, logger=logger) - - if self.should_load_extra: - assert "rng_state" in state_dict, ( - f"RNG state dict not found in {state_dict.keys()}. Please check the checkpoint file {local_path}." - ) - rng_state = state_dict["rng_state"] - self.load_rng_states(rng_state) - log_with_rank(f"Loaded RNG states from {local_path}", rank=self.rank, logger=logger) - - if del_local_after_load: - try: - os.remove(local_path) if is_non_local(local_path) else None - except Exception as e: - log_with_rank( - f"remove local resume ckpt file after loading failed, exception {e} will be ignored", - rank=self.rank, - logger=logger, - ) - - def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None): - # record the previous global step - self.previous_global_step = global_step - - # remove previous local_path - if ( - max_ckpt_to_keep - and isinstance(max_ckpt_to_keep, int) - and max_ckpt_to_keep > 0 - and len(self.previous_saved_paths) >= max_ckpt_to_keep - ): - keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 - self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) - self.previous_saved_paths = self.previous_saved_paths[keep_start:] - - local_path = local_mkdir_safe(local_path) - dist_checkpoint_path = get_dist_checkpoint_path(local_path) - - if self.use_dist_checkpointing: - # Generate state dict for saving - state_dict = self.generate_state_dict() - log_with_rank(f"Generated state dict for saving: {state_dict.keys()}", rank=self.rank, logger=logger) - for vpp_rank, model in enumerate(self.model): - if len(self.model) > 1: - model_i_keys = state_dict[f"model{vpp_rank}"].keys() - log_with_rank(f"Generated state dict for saving: {model_i_keys}", rank=self.rank, logger=logger) - else: - log_with_rank( - f"Generated state dict for saving: {state_dict['model'].keys()}", rank=self.rank, logger=logger - ) - # Start Async save if enabled - async_save_request = save_dist_checkpointing( - sharded_state_dict=state_dict, - ckpt_path=dist_checkpoint_path, - async_save=self.checkpoint_config.async_save, - ) - - # Synchronize all async save requests - if not self.checkpoint_config.async_save: - assert async_save_request is None, "Async save request should be None when not using async save." - torch.distributed.barrier() - else: - assert self.use_hf_checkpoint, "use_hf_checkpoint should be True when not using dist checkpointing" - log_with_rank(f"Saving HF model checkpoint to {local_path} with bridge", rank=self.rank, logger=logger) - hf_ckpt_path = get_hf_model_checkpoint_path(local_path) - self.bridge.save_weights(self.model, hf_ckpt_path) - log_with_rank(f"Saved bridge checkpoint to {hf_ckpt_path}", rank=self.rank, logger=logger) - - if self.should_save_model: - # Only rank 0 saves the hf config and tokenizer to huggingface path - # No matter whether we save hf model or not - if self.rank == 0: - # Save tokenizer - hf_config_tokenizer_path = get_hf_model_checkpoint_path(local_path) - self.processing_class.save_pretrained(hf_config_tokenizer_path) - # Save huggingface config - self.hf_config.save_pretrained(hf_config_tokenizer_path) - if hasattr(self.hf_config, "name_or_path") and self.hf_config.name_or_path: - try: - generation_config = GenerationConfig.from_pretrained(self.hf_config.name_or_path) - generation_config.save_pretrained(hf_config_tokenizer_path) - except Exception: - # if the generation config isn't available, we don't save it - pass - log_with_rank( - f"Saved Huggingface config and tokenizer to {hf_config_tokenizer_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) - - if self.should_save_extra: - if self.rank == 0: - # Save transformer config - print(self.transformer_config) - transformer_config_dict = asdict(self.transformer_config) - to_convert_types = {torch.dtype: str, AttnBackend: str} - ignore_types = [Callable] - pop_keys = [] - for key, value in transformer_config_dict.items(): - if type(value) in to_convert_types: - transformer_config_dict[key] = to_convert_types[type(value)](value) - if type(value) in ignore_types: - pop_keys.append(key) - if callable(value): - pop_keys.append(key) - for key in pop_keys: - transformer_config_dict.pop(key) - transformer_config_path = get_transformer_config_checkpoint_path(local_path) - with open(transformer_config_path, "w") as f: - json.dump(transformer_config_dict, f, indent=2) - - if self.should_save_hf_model: - # wait for everyone to dump to local - state_dict = self.weight_saver( - self.model, - self.hf_config, - dtype=self.param_dtype, - is_value_model=self.is_value_model, - tie_word_embeddings=self.share_embeddings_and_output_weights, - ) - - torch.distributed.barrier() - if self.rank == 0: - hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path) - import warnings - - from accelerate import init_empty_weights - - with init_empty_weights(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - if "mistral7b-rm" in self.config.model.path: - from transformers import MistralForSequenceClassification - - model = MistralForSequenceClassification.from_pretrained( - self.config.model.path - ) # use score head instead of lm_head - state_dict["score.weight"] = state_dict["score.weight"] - else: - from transformers import AutoModelForCausalLM - - model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto") - model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict) - log_with_rank( - f"Saved Huggingface config and tokenizer to {hf_model_ckpt_path}", - rank=self.rank, - logger=logger, - log_only_rank_0=True, - ) - - if hdfs_path is not None: - log_with_rank( - f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True - ) - from verl.utils import hdfs_io - - hdfs_io.makedirs(hdfs_path, exist_ok=True) - hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True) - log_with_rank( - f"HDFS checkpoint uploaded to {hdfs_path}", rank=self.rank, logger=logger, log_only_rank_0=True - ) - - def finalize_save_fn(): - # Rank 0 uploads checkpoint to HDFS if hdfs_path is provided - log_with_rank( - f"Dist checkpointing save completed for {dist_checkpoint_path}", rank=self.rank, logger=logger - ) - if self.rank == 0: - if hdfs_path is not None: - log_with_rank(f"Uploading checkpoint to {hdfs_path}", rank=self.rank, logger=logger) - from verl.utils import hdfs_io - - hdfs_io.makedirs(hdfs_path, exist_ok=True) - hdfs_io.copy(src=dist_checkpoint_path, dst=hdfs_path, dirs_exist_ok=True) - hdfs_io.copy(src=hf_config_tokenizer_path, dst=hdfs_path, dirs_exist_ok=True) - - if self.checkpoint_config.async_save: - assert async_save_request is not None, "Async save request should not be None when using async save." - async_save_request.add_finalize_fn(finalize_save_fn) - else: - finalize_save_fn() - - self.previous_saved_paths.append(local_path) diff --git a/verl/utils/config.py b/verl/utils/config.py deleted file mode 100644 index f1c301f24..000000000 --- a/verl/utils/config.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import is_dataclass -from typing import Any, Optional - -from omegaconf import DictConfig, ListConfig, OmegaConf - -__all__ = ["omega_conf_to_dataclass"] - - -def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any: - """ - Convert an OmegaConf DictConfig to a dataclass. - - Args: - config: The OmegaConf DictConfig or dict to convert. - dataclass_type: The dataclass type to convert to. When dataclass_type is None, - the DictConfig must contain _target_ to be instantiated via hydra.instantiate API. - - Returns: - The dataclass instance. - """ - # Got an empty config - if not config: - return dataclass_type if dataclass_type is None else dataclass_type() - # Got an object - if not isinstance(config, DictConfig | ListConfig | dict | list): - return config - - if dataclass_type is None: - assert "_target_" in config, ( - "When dataclass_type is not provided, config must contain _target_." - "See trainer/config/ppo_trainer.yaml algorithm section for an example." - ) - from hydra.utils import instantiate - - return instantiate(config, _convert_="partial") - - if not is_dataclass(dataclass_type): - raise ValueError(f"{dataclass_type} must be a dataclass") - cfg = OmegaConf.create(config) # in case it's a dict - cfg_from_dataclass = OmegaConf.structured(dataclass_type) - # let cfg override the existing vals in `cfg_from_dataclass` - cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg) - # now convert to `dataclass_type` - config_object = OmegaConf.to_object(cfg_merged) - return config_object - - -def update_dict_with_config(dictionary: dict, config: DictConfig): - for key in dictionary: - if hasattr(config, key): - dictionary[key] = getattr(config, key) diff --git a/verl/utils/data_process/filter.py b/verl/utils/data_process/filter.py deleted file mode 100644 index a8b406479..000000000 --- a/verl/utils/data_process/filter.py +++ /dev/null @@ -1,40 +0,0 @@ -import transformers -import abc -from typing import Dict - - -class Filter(abc.ABC): - """ - Filter class for filtering data. - """ - def __init__(self): - pass - - @abc.abstractmethod - def check(self, data_entry: Dict) -> bool: - pass - - -class LengthFilter(Filter): - """ - Filter class for filtering data by length. - """ - def __init__(self, tokenizer: transformers.PreTrainedTokenizer = None, min_length: int = 0, max_length: int = 2048, length_tolerance: int = 100): - if tokenizer is None: - self.tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B") - else: - self.tokenizer = tokenizer - self.min_length = min_length - self.max_length = max_length - self.length_tolerance = length_tolerance - - def check(self, data_entry: Dict) -> bool: - if data_entry["prompt"]: - - prompt_tokens = self.tokenizer.tokenize(self.tokenizer.apply_chat_template(data_entry["prompt"], tokenize=False)) - elif data_entry["raw_prompt"]: - prompt_tokens = self.tokenizer.tokenize(data_entry["raw_prompt"]) - else: - raise ValueError("No prompt found in data") - # print(f"Prompt length: {len(prompt_tokens)}") - return self.min_length <= len(prompt_tokens) <= self.max_length - self.length_tolerance diff --git a/verl/utils/data_process/prompt.py b/verl/utils/data_process/prompt.py deleted file mode 100644 index 710906f51..000000000 --- a/verl/utils/data_process/prompt.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Prompt utils for data preprocessing -""" - -ZERO_STYLE_PROMPT_TEMPLATE = """{{START_TOKEN}}{{system_symbol}} -{{system_prompt}}{{END_TOKEN}} -{{START_TOKEN}}{{user_symbol}} -{{prompt}}{{extra_instruction}}{{END_TOKEN}} -{{START_TOKEN}}{{assistant_symbol}} -""" - -# deprecated -ZERO_STYLE_PROMPT_TEMPLATE_2 = """{{START_TOKEN}}{{user_symbol}} -{{prompt}}{{extra_instruction}}{{END_TOKEN}} -{{START_TOKEN}}{{assistant_symbol}} -""" - -SYSTEM_PROMPT = """A conversation between a user and an assistant. The assistant first thinks through the problem step-by-step inside ..., then provides the final response to user.""" -SYSTEM_SYMBOL = "system" -USER_SYMBOL = "user" -ASSISTANT_SYMBOL = "assistant" -START_TOKEN = "<|im_start|>" -END_TOKEN = "<|im_end|>" - - -def build_zero_style_prompt( - template: str = ZERO_STYLE_PROMPT_TEMPLATE, - prompt: str = "", - extra_instruction: str = "", - model_name: str = "Qwen/Qwen2.5-7B" - ) -> str: - if extra_instruction: - extra_instruction = "\n" + extra_instruction - if "Qwen" in model_name: - replacements = { - "{{START_TOKEN}}": START_TOKEN, - "{{END_TOKEN}}": END_TOKEN, - "{{system_symbol}}": SYSTEM_SYMBOL, - "{{user_symbol}}": USER_SYMBOL, - "{{assistant_symbol}}": ASSISTANT_SYMBOL, - "{{system_prompt}}": SYSTEM_PROMPT, - "{{prompt}}": prompt, - "{{extra_instruction}}": extra_instruction, - } - for key, val in replacements.items(): - template = template.replace(key, val) - else: - raise ValueError(f"Unsupported model: {model_name}. Only Qwen is supported for now.") - - return template - - -if __name__ == "__main__": - print("=" * 10) - prompt = "What is the sum of 1 and 2?" - print(build_zero_style_prompt(template=ZERO_STYLE_PROMPT_TEMPLATE, prompt=prompt)) - - print("=" * 10) - prompt = "First thinks through the problem step-by-step inside ..., then provides the final answer. What is the sum of 1 and 2?" - print(build_zero_style_prompt(template=ZERO_STYLE_PROMPT_TEMPLATE_2, prompt=prompt)) diff --git a/verl/utils/data_process/utils.py b/verl/utils/data_process/utils.py deleted file mode 100644 index edd139df2..000000000 --- a/verl/utils/data_process/utils.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import random -import numpy as np -import torch - -def add_suffix(filename, sample_size): - if sample_size < 1000: - size_str = f"{sample_size}" - elif (sample_size / 1000) % 1 != 0: - size_str = f"{sample_size / 1000:.1f}k" - else: - size_str = f"{sample_size // 1000}k" - return f"{filename}_{size_str}" - - -def set_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - - -def sample_dataset(dataset, sample_size): - """ - Sample a dataset to a given size. - """ - if sample_size is not None: - indices = list(range(len(dataset))) - random.shuffle(indices) - indices = indices[:min(sample_size, len(dataset))] - dataset = dataset.select(indices) - return dataset - - -def save_dataset(dataset, output_dir, filename_prefix, sample_size=None): - """ - Save a dataset to a parquet file with appropriate naming. - - Args: - dataset: The dataset to save - output_dir: Directory to save the dataset - filename_prefix: Base filename to use - sample_size: Sample size to add as suffix to filename - - Returns: - str: Path to the saved file - """ - # Add suffix based on actual dataset size if sample_size is None - if sample_size is None: - sample_size = len(dataset) - - # Create filename with appropriate suffix - filename = add_suffix(filename_prefix, sample_size) - output_path = os.path.join(output_dir, f"{filename}.parquet") - - # Save dataset - dataset.to_parquet(output_path) - - return output_path \ No newline at end of file diff --git a/verl/utils/dataset/README.md b/verl/utils/dataset/README.md deleted file mode 100644 index f886a70aa..000000000 --- a/verl/utils/dataset/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Dataset Format -## RLHF dataset -We combine all the data sources into a single parquet files. We directly organize the prompt into the chat format so that multi-turn chats can be easily incorporated. In the prompt, we may add instruction following texts to guide the model output the answers in a particular format so that we can extract the answers. - -Math problems -```json -{ - "data_source": "openai/gsm8k", - "prompt": [{"role": "user", "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Let's think step by step and output the final answer after \"####\""}], - "ability": "math", - "reward_model": { - "style": "rule", - "ground_truth": ["72"] - }, -} -``` diff --git a/verl/utils/dataset/__init__.py b/verl/utils/dataset/__init__.py deleted file mode 100644 index 6032d68c8..000000000 --- a/verl/utils/dataset/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .rl_dataset import RLHFDataset -from .rm_dataset import RMDataset -from .sft_dataset import SFTDataset - -__all__ = ["RLHFDataset", "RMDataset", "SFTDataset"] diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py deleted file mode 100644 index e3eed0fd6..000000000 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2025 ModelBest Inc. and/or its affiliates - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Multi-turn SFT dataset that supports training on conversation data with multiple turns -""" - -import json -import logging -from typing import Any, Optional - -import numpy as np -import pandas as pd -import torch -from torch.utils.data import Dataset -from transformers import PreTrainedTokenizer - -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_local_path_from_hdfs - - -def convert_nested_value_to_list_recursive(data_item): - if isinstance(data_item, dict): - return {k: convert_nested_value_to_list_recursive(v) for k, v in data_item.items()} - elif isinstance(data_item, list): - return [convert_nested_value_to_list_recursive(elem) for elem in data_item] - elif isinstance(data_item, np.ndarray): - # Convert to list, then recursively process the elements of the new list - return convert_nested_value_to_list_recursive(data_item.tolist()) - else: - # Base case: item is already a primitive type (int, str, float, bool, etc.) - return data_item - - -class MultiTurnSFTDataset(Dataset): - """ - Dataset for multi-turn conversations where each assistant response should be trained - """ - - def __init__(self, parquet_files: str | list[str], tokenizer, config=None): - # Set defaults and extract parameters from config if provided - config = config or {} - self.truncation = config.get("truncation", "error") - self.max_length = config.get("max_length", 1024) - # Get messages_key from the new multiturn config structure - multiturn_config = config.get("multiturn", {}) - self.messages_key = multiturn_config.get("messages_key", "messages") - self.tools_key = multiturn_config.get("tools_key", "tools") - self.enable_thinking_key = multiturn_config.get("enable_thinking_key", "enable_thinking") - assert self.truncation in ["error", "left", "right"] - - if not isinstance(parquet_files, list): - parquet_files = [parquet_files] - - self.parquet_files = parquet_files - if isinstance(tokenizer, str): - tokenizer = hf_tokenizer(tokenizer) - self.tokenizer: PreTrainedTokenizer = tokenizer - - self._download() - self._read_files_and_process() - - def _download(self): - for i, parquet_file in enumerate(self.parquet_files): - self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) - - def _read_files_and_process(self): - def series_to_item(ls): - import numpy - import pandas - - while isinstance(ls, pandas.core.series.Series | numpy.ndarray) and len(ls) == 1: - ls = ls[0] - return ls - - dataframes = [] - for parquet_file in self.parquet_files: - dataframe = pd.read_parquet(parquet_file) - dataframes.append(dataframe) - self.dataframe = pd.concat(dataframes) - - # Extract messages list from dataframe - self.messages = self.dataframe[self.messages_key].apply(series_to_item).tolist() - - # Extract tools list from dataframe - if self.tools_key in self.dataframe.columns: - self.tools = self.dataframe[self.tools_key].apply(convert_nested_value_to_list_recursive).tolist() - else: - self.tools = None - # Extract enable_thinking list from dataframe - if self.enable_thinking_key in self.dataframe.columns: - self.enable_thinking = self.dataframe[self.enable_thinking_key].tolist() - else: - self.enable_thinking = None - - def __len__(self): - return len(self.messages) - - def _process_message_tokens( - self, - messages: list[dict[str, Any]], - start_idx: int, - end_idx: int, - is_assistant: bool = False, - enable_thinking: Optional[bool] = None, - tools: Optional[list[dict[str, Any]]] = None, - ) -> tuple[list[int], list[int], list[int]]: - """ - Process tokens for a single message or a group of messages. - - Args: - messages: List of message dictionaries - start_idx: Start index in messages list - end_idx: End index in messages list - is_assistant: Whether this is an assistant message - enable_thinking: Whether to enable thinking mode - - Returns: - Tuple of (tokens, loss_mask, attention_mask) - """ - if start_idx > 0: - prev_applied_text = self.tokenizer.apply_chat_template( - messages[:start_idx], - tokenize=False, - add_generation_prompt=False, - enable_thinking=enable_thinking, - tools=tools, - ) - if is_assistant: - prev_applied_text_w_generation_prompt = self.tokenizer.apply_chat_template( - messages[:start_idx], - tokenize=False, - add_generation_prompt=True, - enable_thinking=enable_thinking, - tools=tools, - ) - - else: - prev_applied_text = "" - - cur_applied_text = self.tokenizer.apply_chat_template( - messages[:end_idx], - tokenize=False, - add_generation_prompt=False, - enable_thinking=enable_thinking, - tools=tools, - ) - # Get tokens for the current message only - if is_assistant: - generation_prompt_text = prev_applied_text_w_generation_prompt[len(prev_applied_text) :] - generation_prompt_tokens = self.tokenizer.encode( - generation_prompt_text, - add_special_tokens=False, - ) - _message_tokens = self.tokenizer.encode( - cur_applied_text[len(prev_applied_text_w_generation_prompt) :], - add_special_tokens=False, - ) - message_tokens = generation_prompt_tokens + _message_tokens - loss_mask = [0] * (len(generation_prompt_tokens)) + [1] * ( - len(message_tokens) - len(generation_prompt_tokens) - ) - else: - message_tokens = self.tokenizer.encode( - cur_applied_text[len(prev_applied_text) :], - add_special_tokens=False, - ) - loss_mask = [0] * len(message_tokens) - - attention_mask = [1] * len(message_tokens) - - return message_tokens, loss_mask, attention_mask - - def _validate_and_convert_tokens( - self, - full_tokens: torch.Tensor, - concat_tokens: list[int], - concat_loss_mask: list[int], - concat_attention_mask: list[int], - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Validate tokenization and convert to tensors. - - Args: - full_tokens: Full conversation tokens - concat_tokens: Concatenated tokens - concat_loss_mask: Concatenated loss mask - concat_attention_mask: Concatenated attention mask - - Returns: - Tuple of (input_ids, loss_mask, attention_mask) as tensors - """ - full_tokens_list = full_tokens.tolist() - - if len(concat_tokens) != len(full_tokens_list) or not all( - a == b for a, b in zip(concat_tokens, full_tokens_list, strict=True) - ): - logging.warning( - f"Token mismatch detected! Full tokenization length: {len(full_tokens_list)}, Concatenated tokens " - f"length: {len(concat_tokens)}. Using concatenated version." - # f"full tokens text: {self.tokenizer.decode(full_tokens_list)}" - # f"concat tokens text: {self.tokenizer.decode(concat_tokens)}" - ) - return ( - torch.tensor(concat_tokens, dtype=torch.long), - torch.tensor(concat_loss_mask, dtype=torch.long), - torch.tensor(concat_attention_mask, dtype=torch.long), - ) - - return ( - full_tokens, - torch.tensor(concat_loss_mask, dtype=torch.long), - torch.tensor(concat_attention_mask, dtype=torch.long), - ) - - def __getitem__(self, item): - tokenizer = self.tokenizer - messages = self.messages[item] - tools = self.tools[item] if self.tools is not None else None - enable_thinking = self.enable_thinking[item] if self.enable_thinking is not None else None - - if self.tools is not None: - tools = json.loads(self.tools[item]) - else: - tools = None - - # First, get the full conversation tokens - try: - full_tokens = tokenizer.apply_chat_template( - messages, - tools=tools, - tokenize=True, - return_tensors="pt", - add_generation_prompt=False, - enable_thinking=enable_thinking, - ) - except Exception as e: - logging.error( - f"Error applying chat template: {e}\nMessages: {messages}\nTools: {tools}\nEnable thinking: " - f"{enable_thinking}" - ) - raise - - # Track concatenated tokens for validation - concat_tokens = [] - concat_loss_mask = [] - concat_attention_mask = [] - - i = 0 - while i < len(messages): - cur_messages = messages[i] - if cur_messages["role"] == "assistant": - # Process assistant message - tokens, loss_mask, attention_mask = self._process_message_tokens( - messages, i, i + 1, is_assistant=True, enable_thinking=enable_thinking, tools=tools - ) - concat_tokens.extend(tokens) - concat_loss_mask.extend(loss_mask) - concat_attention_mask.extend(attention_mask) - i += 1 - elif cur_messages["role"] == "tool": - # Process consecutive tool messages - st = i - ed = i + 1 - while ed < len(messages) and messages[ed]["role"] == "tool": - ed += 1 - tokens, loss_mask, attention_mask = self._process_message_tokens( - messages, st, ed, enable_thinking=enable_thinking, tools=tools - ) - concat_tokens.extend(tokens) - concat_loss_mask.extend(loss_mask) - concat_attention_mask.extend(attention_mask) - i = ed - elif cur_messages["role"] in ["user", "system"]: - # Process user or system message - if cur_messages["role"] == "system" and i != 0: - raise ValueError("System message should be the first message") - tokens, loss_mask, attention_mask = self._process_message_tokens( - messages, i, i + 1, enable_thinking=enable_thinking, tools=tools - ) - concat_tokens.extend(tokens) - concat_loss_mask.extend(loss_mask) - concat_attention_mask.extend(attention_mask) - i += 1 - else: - raise ValueError(f"Unknown role: {cur_messages['role']}") - - # Validate and convert tokens - input_ids, loss_mask, attention_mask = self._validate_and_convert_tokens( - full_tokens[0], concat_tokens, concat_loss_mask, concat_attention_mask - ) - - # Handle sequence length - sequence_length = input_ids.shape[0] - if sequence_length < self.max_length: - # Pad sequences - pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - padded_input_ids = torch.full((self.max_length - sequence_length,), pad_token_id, dtype=input_ids.dtype) - padded_attention_mask = torch.zeros((self.max_length - sequence_length,), dtype=attention_mask.dtype) - padded_loss_mask = torch.zeros((self.max_length - sequence_length,), dtype=loss_mask.dtype) - - input_ids = torch.cat((input_ids, padded_input_ids)) - attention_mask = torch.cat((attention_mask, padded_attention_mask)) - loss_mask = torch.cat((loss_mask, padded_loss_mask)) - elif sequence_length > self.max_length: - if self.truncation == "left": - input_ids = input_ids[-self.max_length :] - attention_mask = attention_mask[-self.max_length :] - loss_mask = loss_mask[-self.max_length :] - elif self.truncation == "right": - input_ids = input_ids[: self.max_length] - attention_mask = attention_mask[: self.max_length] - loss_mask = loss_mask[: self.max_length] - elif self.truncation == "error": - raise ValueError(f"{sequence_length=} is larger than {self.max_length=}") - else: - raise ValueError(f"Unknown truncation method {self.truncation}") - - # Create position IDs - position_ids = torch.arange(len(input_ids), dtype=torch.long) - # Zero out position IDs for padding - position_ids = position_ids * attention_mask - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "position_ids": position_ids, - "loss_mask": loss_mask, - } diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py deleted file mode 100644 index 87036f37e..000000000 --- a/verl/utils/dataset/rl_dataset.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The file is temporarily reverted by Reasoning360 to use `dataframe` rather than `Dataset`, -# to support heterogeneous keys of multi-domain data - -import copy -import os -from collections import defaultdict -from typing import Optional - -import numpy as np -import pandas as pd -import torch -from omegaconf import ListConfig -from torch.utils.data import Dataset -from transformers import PreTrainedTokenizer, ProcessorMixin - -import verl.utils.torch_functional as verl_F -from verl.utils.model import compute_position_id_with_mask - - -def collate_fn(data_list: list[dict]) -> dict: - tensors = defaultdict(list) - non_tensors = defaultdict(list) - - for data in data_list: - for key, val in data.items(): - if isinstance(val, torch.Tensor): - tensors[key].append(val) - else: - non_tensors[key].append(val) - - for key, val in tensors.items(): - tensors[key] = torch.stack(val, dim=0) - - for key, val in non_tensors.items(): - non_tensors[key] = np.array(val, dtype=object) - - return {**tensors, **non_tensors} - - -def process_image(image: dict, max_pixels: int = 2048 * 2048, min_pixels: int = 512 * 512): - import math - from io import BytesIO - - from PIL import Image - - if isinstance(image, dict): - image = Image.open(BytesIO(image["bytes"])) - - if (image.width * image.height) > max_pixels: - resize_factor = math.sqrt(max_pixels / (image.width * image.height)) - width, height = int(image.width * resize_factor), int(image.height * resize_factor) - image = image.resize((width, height)) - - if (image.width * image.height) < min_pixels: - resize_factor = math.sqrt(min_pixels / (image.width * image.height)) - width, height = int(image.width * resize_factor), int(image.height * resize_factor) - image = image.resize((width, height)) - - if image.mode != "RGB": - image = image.convert("RGB") - - return image - - -class RLHFDataset(Dataset): - """ - We assume the dataset contains a column that contains prompts and other information - """ - - def __init__( - self, - data_files: str | list[str], - tokenizer: PreTrainedTokenizer, - processor: Optional[ProcessorMixin] = None, - prompt_key="prompt", - image_key="images", - max_prompt_length=1024, - filter_prompts=True, - cache_dir="~/.cache/verl/rlhf", - chat_template_func=None, - return_raw_chat=False, - truncation="error", - filter_overlong_prompts=False, - config=None, - ): - self.config = config - # Prioritize config if provided - if config is not None: - self.cache_dir = os.path.expanduser(config.get("cache_dir", cache_dir)) - self.prompt_key = config.get("prompt_key", prompt_key) - self.image_key = config.get("image_key", image_key) - self.video_key = config.get("video_key", None) - self.max_prompt_length = config.get("max_prompt_length", max_prompt_length) - self.return_raw_chat = config.get("return_raw_chat", return_raw_chat) - self.return_full_prompt = config.get("return_full_prompt", False) - self.truncation = config.get("truncation", truncation) - self.filter_overlong_prompts = config.get("filter_overlong_prompts", filter_overlong_prompts) - self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)) - self.num_workers = min(self.num_workers, os.cpu_count()) - self.use_shm = config.get("use_shm", False) - self.chat_template_func = config.get("chat_template_func", chat_template_func) - self.need_tools_kwargs = config.get("need_tools_kwargs", False) - self.filter_prompts = config.get("filter_prompts", filter_prompts) - else: - self.cache_dir = os.path.expanduser(cache_dir) - self.prompt_key = prompt_key - self.image_key = image_key - self.video_key = None - self.max_prompt_length = max_prompt_length - self.return_raw_chat = return_raw_chat - self.return_full_prompt = False - self.truncation = truncation - self.filter_overlong_prompts = filter_overlong_prompts - self.num_workers = max(1, os.cpu_count() // 4) - self.num_workers = min(self.num_workers, os.cpu_count()) - self.use_shm = False - self.chat_template_func = chat_template_func - self.need_tools_kwargs = False - self.filter_prompts = filter_prompts - - parquet_files = data_files - if not isinstance(parquet_files, list | ListConfig): - parquet_files = [parquet_files] - self.parquet_files = copy.deepcopy(parquet_files) - self.original_parquet_files = copy.deepcopy(parquet_files) # use for resume - self.tokenizer = tokenizer - self.processor = processor - # whether to store the dataset in state_dict() - # default not store - self.serialize_dataset = False - self._download() - self._read_files_and_tokenize() - - def _download(self, use_origin_parquet=False): - from verl.utils.fs import copy_to_local - - parquet_files = self.parquet_files if not use_origin_parquet else self.original_parquet_files - for i, parquet_file in enumerate(parquet_files): - self.parquet_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir) - - def _read_files_and_tokenize(self): - dataframes = [] - for parquet_file in self.parquet_files: - # read parquet files and cache - try: - dataframe = pd.read_parquet(parquet_file) - except Exception: - # if pandas fails (most likely due to nested columns), use polars - # NOTE: added by Reasoning360 - import polars as pl - - dataframe = pl.read_parquet(parquet_file).to_pandas() - dataframes.append(dataframe) - self.dataframe = pd.concat(dataframes) - - print(f"dataset len: {len(self.dataframe)}") - - print(self.dataframe.head()) - - # Safely check if apply_chat_template exists in dataframe - # NOTE: added by Reasoning360 - if "apply_chat_template" not in self.dataframe: - print("Warning: apply_chat_template column not found in dataframe. Defaulting to True.") - self.dataframe["apply_chat_template"] = [True] * len(self.dataframe) - - # filter out too long prompts - if self.filter_overlong_prompts: - tokenizer = self.tokenizer - prompt_key = self.prompt_key - self.dataframe = self.dataframe[ - self.dataframe.apply( - lambda doc: len( - tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True) - if doc["apply_chat_template"] - else tokenizer.encode(doc["raw_prompt"]) - ) - <= self.max_prompt_length, - axis=1, - ) - ] - - print(f"filter dataset len: {len(self.dataframe)}") - - def resume_dataset_state(self): - self.serialize_dataset = False if hasattr(self, "original_parquet_files") else True - # resume dataframe if not it's serialized in data.pt - if not self.serialize_dataset: - self._download(use_origin_parquet=True) # download and resume from original parquet files - self._read_files_and_tokenize() - else: - print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance") - - def __len__(self): - return len(self.dataframe) - - def __getitem__(self, item): - """ - Note that we also return the raw_input_ids so that it can be combined with other chat template - """ - row_dict: dict = self.dataframe.iloc[item].to_dict() - - chat = row_dict.pop(self.prompt_key) - - prompt_with_chat_template = ( - self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False) - if row_dict["apply_chat_template"] - else row_dict["raw_prompt"] - ) - - is_multi_modal = self.image_key in row_dict - if is_multi_modal: # expand image token - raw_prompt = prompt_with_chat_template.replace("", "<|vision_start|><|image_pad|><|vision_end|>") - row_dict["multi_modal_data"] = {"image": [process_image(image) for image in row_dict.pop(self.image_key)]} - image_inputs = self.processor.image_processor(row_dict["multi_modal_data"]["image"], return_tensors="pt") - image_grid_thw = image_inputs["image_grid_thw"] - row_dict["multi_modal_inputs"] = {key: val for key, val in image_inputs.items()} - - if image_grid_thw is not None: - merge_length = self.processor.image_processor.merge_size**2 - index = 0 - while "" in prompt_with_chat_template: - prompt_with_chat_template = prompt_with_chat_template.replace( - "", - "<|vision_start|>" - + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length) - + "<|vision_end|>", - 1, - ) - index += 1 - - prompt_with_chat_template = prompt_with_chat_template.replace( - "<|placeholder|>", self.processor.image_token - ) - else: - raw_prompt = prompt_with_chat_template - - input_ids, attention_mask = verl_F.tokenize_and_postprocess_data( - prompt=prompt_with_chat_template, - tokenizer=self.tokenizer, - max_length=self.max_prompt_length, - pad_token_id=self.tokenizer.pad_token_id, - left_pad=True, - truncation=self.truncation, - ) - - if is_multi_modal: - from verl.models.transformers.qwen2_vl import get_rope_index - - position_ids = get_rope_index( - self.processor, - input_ids=input_ids[0], - image_grid_thw=image_grid_thw, - attention_mask=attention_mask[0], - ) # (3, seq_len) - else: - position_ids = compute_position_id_with_mask(attention_mask) - - row_dict["input_ids"] = input_ids[0] - row_dict["attention_mask"] = attention_mask[0] - row_dict["position_ids"] = position_ids[0] - row_dict["raw_prompt_ids"] = self.tokenizer.encode(raw_prompt, add_special_tokens=False) - - # encode prompts without chat template - if self.return_raw_chat: - row_dict["raw_prompt"] = chat.tolist() - - # add index for each prompt - index = row_dict.get("extra_info", {}).get("index", 0) - row_dict["index"] = index - - return row_dict - - def __getstate__(self): - if not self.serialize_dataset: - state = self.__dict__.copy() - - if "dataframe" in state: - del state["dataframe"] - return state - return self.__dict__.copy() diff --git a/verl/utils/dataset/rm_dataset.py b/verl/utils/dataset/rm_dataset.py deleted file mode 100644 index 7af792343..000000000 --- a/verl/utils/dataset/rm_dataset.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import pandas as pd -import torch -from torch.utils.data import Dataset - -from verl.utils import hf_tokenizer - - -def download_files_distributed(download_fn): - import torch.distributed - - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - # download files - download_fn() - - torch.distributed.barrier() - else: - # download anyway - download_fn() - - -class RMDataset(Dataset): - def __init__( - self, - parquet_files: str | list[str], - tokenizer, - prompt_key="prompt", - chosen_key="chosen", - rejected_key="rejected", - max_length=1024, - add_eos=True, - cache_dir="~/.cache/verl/rm", - ): - if not isinstance(parquet_files, list): - parquet_files = [parquet_files] - - self.parquet_files = parquet_files - self.cache_dir = os.path.expanduser(cache_dir) - if isinstance(tokenizer, str): - tokenizer = hf_tokenizer(tokenizer) - self.tokenizer = tokenizer - - self.prompt_key = prompt_key - self.chosen_key = chosen_key - self.rejected_key = rejected_key - - self.add_eos = add_eos - self.max_length = max_length - - self._download() - self._read_files_and_tokenize() - - def _download(self): - def _download_files(): - from verl.utils.fs import copy, is_non_local - - os.makedirs(self.cache_dir, exist_ok=True) - assert os.path.exists(self.cache_dir) - for i, parquet_file in enumerate(self.parquet_files): - if is_non_local(parquet_file): - dst = os.path.join(self.cache_dir, os.path.basename(parquet_file)) - if not os.path.exists(dst): - copy(src=parquet_file, dst=dst) - self.parquet_files[i] = dst - - download_files_distributed(_download_files) - - def _read_files_and_tokenize(self): - dataframes = [] - for parquet_file in self.parquet_files: - # read parquet files and cache - dataframe = pd.read_parquet(parquet_file) - dataframes.append(dataframe) - self.dataframe = pd.concat(dataframes) - self.prompts = self.dataframe[self.prompt_key].tolist() - self.chosen_responses = self.dataframe[self.chosen_key].tolist() - self.rejected_responses = self.dataframe[self.rejected_key].tolist() - - def __len__(self): - return len(self.prompts) - - def _pad_to_length(self, input_ids, attention_mask): - curr_length = input_ids.shape[-1] - - if curr_length < self.max_length: - input_ids = torch.cat( - (input_ids, torch.zeros(size=(self.max_length - curr_length,), dtype=input_ids.dtype)), dim=-1 - ) - attention_mask = torch.cat( - (attention_mask, torch.zeros(size=(self.max_length - curr_length,), dtype=attention_mask.dtype)), dim=-1 - ) - elif curr_length > self.max_length: - input_ids = input_ids[: self.max_length] - attention_mask = attention_mask[: self.max_length] - - return input_ids, attention_mask - - def __getitem__(self, item): - prompt = self.prompts[item] - chosen_response = self.chosen_responses[item] - rejected_response = self.rejected_responses[item] - - prompt_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0] - chosen_response_ids = self.tokenizer(chosen_response, return_tensors="pt")["input_ids"][0] - rejected_response_ids = self.tokenizer(rejected_response, return_tensors="pt")["input_ids"][0] - - if self.add_eos: - chosen_response_ids = torch.cat((chosen_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1) - rejected_response_ids = torch.cat( - (rejected_response_ids, torch.tensor([self.tokenizer.eos_token_id])), dim=-1 - ) - - chosen_input_ids = torch.cat((prompt_ids, chosen_response_ids), dim=-1) - chosen_attention_mask = torch.ones_like(chosen_input_ids) - - rejected_input_ids = torch.cat((prompt_ids, rejected_response_ids), dim=-1) - rejected_attention_mask = torch.ones_like(rejected_input_ids) - - chosen_input_ids, chosen_attention_mask = self._pad_to_length(chosen_input_ids, chosen_attention_mask) - rejected_input_ids, rejected_attention_mask = self._pad_to_length(rejected_input_ids, rejected_attention_mask) - - input_ids = torch.stack((chosen_input_ids, rejected_input_ids), dim=0) - attention_mask = torch.stack((chosen_attention_mask, rejected_attention_mask), dim=0) - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - } diff --git a/verl/utils/dataset/vision_utils.py b/verl/utils/dataset/vision_utils.py deleted file mode 100644 index 75cce7f6a..000000000 --- a/verl/utils/dataset/vision_utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from io import BytesIO -from typing import Optional - -import torch -from PIL import Image -from qwen_vl_utils import fetch_image, fetch_video - - -def process_image(image: dict | Image.Image) -> Image.Image: - if isinstance(image, Image.Image): - return image.convert("RGB") - - if "bytes" in image: - assert "image" not in image, "Cannot have both `bytes` and `image`" - image["image"] = BytesIO(image["bytes"]) - - return fetch_image(image) - - -VIDEO_FORMAT_HELP = """Currently, we only support the video formats introduced in qwen2-vl. -Refer to https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#using---transformers-to-chat. - -eg. -{ - "type": "video", - "video": [ - "file:///path/to/frame1.jpg", - "file:///path/to/frame2.jpg" - ] -} - -{ - "type": "video", - "video": "file:///path/to/video.mp4" -} -# Defaults to fps=2, min_frames=4, max_frames=768 - -{ - "type": "video", - "video": "file:///path/to/video.mp4", - "fps": 2, - "min_frames": 1, - "max_frames": 32 -} -""" - - -def process_video( - video: dict, - nframes: Optional[int] = None, - fps: Optional[float] = None, - fps_min_frames: Optional[int] = None, - fps_max_frames: Optional[int] = None, -) -> torch.Tensor: - """Converts a video dict into a [n_frames, 3, H, W] tensor - - Add video sample FPS in a future MR - """ - - if not isinstance(video, dict) or "video" not in video: - raise NotImplementedError(VIDEO_FORMAT_HELP) - assert nframes is None or fps is None, "Can't use both `nframes` or `fps`" - - # Shallow copy... since we might want to add some keys - video = dict(video) - - contains_sampling_rules = "nframes" in video or "fps" in video - if not contains_sampling_rules: - if nframes is not None: - video["nframes"] = nframes - elif fps is not None: - video["fps"] = fps - if fps_min_frames is not None: - video["min_frames"] = fps_min_frames - if fps_max_frames is not None: - video["max_frames"] = fps_max_frames - - return fetch_video(video) - - -def process_multi_modal_inputs_for_minicpmo(input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs): - # Adjust image bounds based on left padding and cumulative sequence lengths - # This is necessary for MiniCPM-o's vision-language alignment - left_padding_length = torch.argmax(attention_mask, dim=1) - image_bounds = [] - for i in range(len(multi_modal_inputs["image_bound"])): - image_bound = ( - multi_modal_inputs["image_bound"][i].to(left_padding_length.device) - left_padding_length[i] + cu_seqlens[i] - ) - image_bounds.append(image_bound) - - # Flatten pixel values list for MiniCPM-o processing - pixel_values = [] - for i in range(len(multi_modal_inputs["pixel_values"])): - pixel_values.extend([p for p in multi_modal_inputs["pixel_values"][i]]) - - multi_modal_inputs["pixel_values"] = [pixel_values] - multi_modal_inputs["image_bound"] = [torch.vstack(image_bounds)] - multi_modal_inputs["tgt_sizes"] = [torch.vstack(multi_modal_inputs["tgt_sizes"])] - multi_modal_inputs["input_ids"] = input_ids - multi_modal_inputs["attention_mask"] = attention_mask - multi_modal_inputs["position_ids"] = position_ids - return {"data": multi_modal_inputs} diff --git a/verl/utils/debug/__init__.py b/verl/utils/debug/__init__.py deleted file mode 100644 index eb67df1b7..000000000 --- a/verl/utils/debug/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# APIs kept for backward compatibility purpose -# For new features please develop in verl/utils/profiler/ -from ..profiler import * # noqa diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py deleted file mode 100644 index 9186e125a..000000000 --- a/verl/utils/debug/performance.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# APIs kept for backward compatibility purpose -# This file is deprecated, for new features please develop in profiler/performance.py -from verl.utils.profiler.performance import simple_timer, reduce_timing # noqa diff --git a/verl/utils/debug/trajectory_tracker.py b/verl/utils/debug/trajectory_tracker.py deleted file mode 100644 index 73afb8540..000000000 --- a/verl/utils/debug/trajectory_tracker.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Trajectory tracker can be inserted into code to save the intermediate results. -The results will be dump to hdfs for offline comparison. -Each process will have a client that first move all the tensors to CPU -""" - -import io -import os -import tempfile -from collections import deque - -import ray -import torch - -from verl.utils.hdfs_io import copy, makedirs - -remote_copy = ray.remote(copy) - - -@ray.remote -def save_to_hdfs(data: io.BytesIO, name, hdfs_dir, verbose): - filename = name + ".pth" - with tempfile.TemporaryDirectory() as tmpdirname: - local_filepath = os.path.join(tmpdirname, filename) - with open(local_filepath, "wb") as f: - f.write(data.getbuffer()) - # upload to hdfs - - if verbose: - print(f"Saving {local_filepath} to {hdfs_dir}") - try: - copy(local_filepath, hdfs_dir) - except Exception as e: - print(e) - - -@ray.remote -class TrajectoryTracker: - def __init__(self, hdfs_dir, verbose) -> None: - self.hdfs_dir = hdfs_dir - makedirs(hdfs_dir) - self.verbose = verbose - - self.handle = deque() - - def dump(self, data: io.BytesIO, name): - # get a temp file and write to it - self.handle.append(save_to_hdfs.remote(data, name, self.hdfs_dir, self.verbose)) - - def wait_for_hdfs(self): - while len(self.handle) != 0: - future = self.handle.popleft() - ray.get(future) - - -def dump_data(data, name): - enable = os.getenv("VERL_ENABLE_TRACKER", "0") == "1" - if not enable: - return - buffer = io.BytesIO() - torch.save(data, buffer) - tracker = get_trajectory_tracker() - ray.get(tracker.dump.remote(buffer, name)) - - -def get_trajectory_tracker(): - hdfs_dir = os.getenv("VERL_TRACKER_HDFS_DIR", default=None) - verbose = os.getenv("VERL_TRACKER_VERBOSE", default="0") == "1" - assert hdfs_dir is not None - tracker = TrajectoryTracker.options(name="global_tracker", get_if_exists=True, lifetime="detached").remote( - hdfs_dir, verbose - ) - return tracker - - -if __name__ == "__main__": - # testing - os.environ["VERL_ENABLE_TRACKER"] = "1" - os.environ["VERL_TRACKER_HDFS_DIR"] = "~/debug/test" - - @ray.remote - def process(iter): - data = {"obs": torch.randn(10, 20)} - dump_data(data, f"process_{iter}_obs") - - ray.init() - - output_lst = [] - - for i in range(10): - output_lst.append(process.remote(i)) - - out = ray.get(output_lst) - - tracker = get_trajectory_tracker() - ray.get(tracker.wait_for_hdfs.remote()) diff --git a/verl/utils/device.py b/verl/utils/device.py deleted file mode 100644 index ed85b0d5b..000000000 --- a/verl/utils/device.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# This code is inspired by the torchtune. -# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py -# -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license in https://github.com/pytorch/torchtune/blob/main/LICENSE - -import logging - -import torch - -logger = logging.getLogger(__name__) - - -def is_torch_npu_available() -> bool: - """Check the availability of NPU""" - try: - import torch_npu # noqa: F401 - - return torch.npu.is_available() - except ImportError: - return False - - -is_cuda_available = torch.cuda.is_available() -is_npu_available = is_torch_npu_available() - - -def get_visible_devices_keyword() -> str: - """Function that gets visible devices keyword name. - Returns: - 'CUDA_VISIBLE_DEVICES' or `ASCEND_RT_VISIBLE_DEVICES` - """ - return "CUDA_VISIBLE_DEVICES" if is_cuda_available else "ASCEND_RT_VISIBLE_DEVICES" - - -def get_device_name() -> str: - """Function that gets the torch.device based on the current machine. - This currently only supports CPU, CUDA, NPU. - Returns: - device - """ - if is_cuda_available: - device = "cuda" - elif is_npu_available: - device = "npu" - else: - device = "cpu" - return device - - -def get_torch_device() -> any: - """Return the corresponding torch attribute based on the device type string. - Returns: - module: The corresponding torch device namespace, or torch.cuda if not found. - """ - device_name = get_device_name() - try: - return getattr(torch, device_name) - except AttributeError: - logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") - return torch.cuda - - -def get_device_id() -> int: - """Return current device id based on the device type. - Returns: - device index - """ - return get_torch_device().current_device() - - -def get_nccl_backend() -> str: - """Return nccl backend type based on the device type. - Returns: - nccl backend type string. - """ - if is_cuda_available: - return "nccl" - elif is_npu_available: - return "hccl" - else: - raise RuntimeError(f"No available nccl backend found on device type {get_device_name()}.") diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py deleted file mode 100644 index 610b5d4c9..000000000 --- a/verl/utils/distributed.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utilities for distributed training.""" - -import os - -import torch.distributed - -from verl.utils.device import get_nccl_backend, get_torch_device - - -def initialize_global_process_group(timeout_second=36000): - from datetime import timedelta - - torch.distributed.init_process_group( - get_nccl_backend(), - timeout=timedelta(seconds=timeout_second), - init_method=os.environ.get("DIST_INIT_METHOD", None), - ) - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if torch.distributed.is_initialized(): - get_torch_device().set_device(local_rank) - return local_rank, rank, world_size - - -def destroy_global_process_group(): - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() diff --git a/verl/utils/experimental/__init__.py b/verl/utils/experimental/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/utils/experimental/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/utils/experimental/torch_functional.py b/verl/utils/experimental/torch_functional.py deleted file mode 100644 index 0b4ce5c61..000000000 --- a/verl/utils/experimental/torch_functional.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch - - -def _fused_linear_for_ppo_fwd( - hidden_states: torch.FloatTensor, - vocab_weights: torch.FloatTensor, - input_ids: torch.LongTensor, - temperature: float = 1.0, -) -> tuple[torch.FloatTensor, torch.FloatTensor]: - logits = (hidden_states @ vocab_weights.t()) / temperature - orig_dtype = logits.dtype - logits = logits.to(torch.float32) - - # Slower but more numerically stable to do log_softmax than probs.log() - probs = logits.softmax(dim=-1) - log_probs = logits.log_softmax(dim=-1) - - token_log_probs = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1) - - return token_log_probs.to(orig_dtype), entropy.to(orig_dtype) - - -def _fused_linear_for_ppo_bwd( - dlog_probs: Optional[torch.FloatTensor], - dentropy: Optional[torch.FloatTensor], - hidden_states: torch.FloatTensor, - vocab_weights: torch.FloatTensor, - input_ids: torch.LongTensor, - temperature: float = 1.0, -) -> tuple[torch.FloatTensor, torch.FloatTensor]: - logits = (hidden_states @ vocab_weights.t()) / temperature - orig_dtype = logits.dtype - logits = logits.to(torch.float32) - - probs = logits.softmax(dim=-1) - - dlogits = 0 - - # Gradient from log_probs - if dlog_probs is not None: - one_hot_input = torch.zeros_like(logits).scatter_(-1, input_ids.unsqueeze(-1), 1) - dlogits += dlog_probs.to(torch.float32).unsqueeze(-1) * (one_hot_input - probs) - - # Gradient from entropy - if dentropy is not None: - log_probs = logits.log_softmax(dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1) - dlogits += probs * (log_probs + entropy.unsqueeze(-1)) * (-dentropy.unsqueeze(-1)) - - dlogits = dlogits.to(orig_dtype) / temperature - - dhidden_states = dlogits @ vocab_weights - dvocab_weights = dlogits.t() @ hidden_states - - return dhidden_states, dvocab_weights - - -class FusedLinearForPPOFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - hidden_states: torch.FloatTensor, - vocab_weights: torch.FloatTensor, - input_ids: torch.LongTensor, - temperature: float = 1.0, - chunk_size: int = 512, - ) -> tuple[torch.FloatTensor, torch.FloatTensor]: - ctx.set_materialize_grads(False) - - # Cast to a 2D tensor of the shape [T, D] for ease of working - orig_ndim = hidden_states.ndim - assert orig_ndim in (2, 3), f"Invalid hidden_states shape, received {hidden_states.shape}" - - orig_batch_size = -1 - if orig_ndim == 3: - assert input_ids.ndim == 2, f"input_ids shape doesn't match, {hidden_states.shape} {input_ids.shape}" - orig_batch_size = hidden_states.shape[0] - hidden_states = hidden_states.flatten(0, 1) - input_ids = input_ids.flatten(0, 1) - - T = hidden_states.shape[0] - - # Allocate memory for outputs - output_requires_grad = hidden_states.requires_grad or vocab_weights.requires_grad - log_probs = hidden_states.new_zeros(T, requires_grad=output_requires_grad) - entropy = hidden_states.new_zeros(T, requires_grad=output_requires_grad) - - # Perform forward one chunk at a time - for chunk_start in range(0, T, chunk_size): - chunk_end = min(chunk_start + chunk_size, T) - - chunk_log_probs, chunk_entropy = _fused_linear_for_ppo_fwd( - hidden_states=hidden_states[chunk_start:chunk_end], - vocab_weights=vocab_weights, - input_ids=input_ids[chunk_start:chunk_end], - temperature=temperature, - ) - log_probs[chunk_start:chunk_end] = chunk_log_probs - entropy[chunk_start:chunk_end] = chunk_entropy - - # Cast the output back to the original input dimension - if orig_ndim == 3: - log_probs = log_probs.view(orig_batch_size, -1) - entropy = entropy.view(orig_batch_size, -1) - - ctx.save_for_backward(hidden_states, vocab_weights, input_ids) - ctx.orig_batch_size = orig_batch_size - ctx.orig_ndim = orig_ndim - ctx.temperature = temperature - ctx.chunk_size = chunk_size - - return log_probs, entropy - - @staticmethod - def backward(ctx, dlog_probs: Optional[torch.FloatTensor], dentropy: Optional[torch.FloatTensor]): - assert dlog_probs is not None or dentropy is not None - - hidden_states, vocab_weights, input_ids = ctx.saved_tensors - orig_batch_size = ctx.orig_batch_size - orig_ndim = ctx.orig_ndim - temperature = ctx.temperature - chunk_size = ctx.chunk_size - - # Here orig_ndim refers to the orig_ndim of hidden_states - if orig_ndim == 3: - if dlog_probs is not None: - dlog_probs = dlog_probs.flatten() - if dentropy is not None: - dentropy = dentropy.flatten() - - T = hidden_states.shape[0] - - # Allocate memory for outputs - dhidden_states = None - if hidden_states.requires_grad: - dhidden_states = torch.zeros_like(hidden_states) - dvocab_weights = None - if vocab_weights.requires_grad: - dvocab_weights = torch.zeros_like(vocab_weights) - - # Perform backward one chunk at a time - for chunk_start in range(0, T, chunk_size): - chunk_end = min(chunk_start + chunk_size, T) - chunk_dlog_probs = None - if dlog_probs is not None: - chunk_dlog_probs = dlog_probs[chunk_start:chunk_end] - chunk_dentropy = None - if dentropy is not None: - chunk_dentropy = dentropy[chunk_start:chunk_end] - - h, v = _fused_linear_for_ppo_bwd( - dlog_probs=chunk_dlog_probs, - dentropy=chunk_dentropy, - hidden_states=hidden_states[chunk_start:chunk_end], - vocab_weights=vocab_weights, - input_ids=input_ids[chunk_start:chunk_end], - temperature=temperature, - ) - - if hidden_states.requires_grad: - dhidden_states[chunk_start:chunk_end] += h - if vocab_weights.requires_grad: - dvocab_weights += v - - # Cast the output back to the original input dimension - if orig_ndim == 3 and hidden_states.requires_grad: - hidden_size = hidden_states.shape[-1] - dhidden_states = dhidden_states.view(orig_batch_size, -1, hidden_size) - - return ( - dhidden_states, # hidden_states - dvocab_weights, # vocab_weights - None, # input_ids - None, # temperature - None, # chunk_size - ) - - -class FusedLinearForPPO(torch.nn.Module): - def __init__(self, chunk_size: int = 512): - super().__init__() - - self.chunk_size = chunk_size - - def forward( - self, - hidden_states: torch.FloatTensor, - vocab_weights: torch.FloatTensor, - input_ids: torch.LongTensor, - temperature: float = 1.0, - ) -> tuple[torch.FloatTensor, torch.FloatTensor]: - input_ids = input_ids.to(torch.int64) - return FusedLinearForPPOFunction.apply( - hidden_states, - vocab_weights, - input_ids, - temperature, - self.chunk_size, - ) diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py deleted file mode 100644 index 1bed92902..000000000 --- a/verl/utils/flops_counter.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from transformers import PretrainedConfig - -from verl.utils.device import get_torch_device - -VALID_CONFIG_TYPE = { - "llama", - "qwen2", - "qwen2_vl", - "qwen2_5_vl", - "qwen3", - "qwen3_moe", - "deepseek_v3", - "minicpmv", - "minicpmo", -} - - -def get_device_flops(unit="T"): - def unit_convert(number, level): - units = ["B", "K", "M", "G", "T", "P"] - if number <= 0: - return number - ptr = 0 - while ptr < len(units) and units[ptr] != level: - number /= 1000 - ptr += 1 - return number - - device_name = get_torch_device().get_device_name() - flops = float("inf") # INF flops for unkown gpu type - - if "MI300X" in device_name: - flops = 1336e12 - elif "H100" in device_name or "H800" in device_name or "H200" in device_name: - flops = 989e12 - elif "A100" in device_name or "A800" in device_name: - flops = 312e12 - elif "L40" in device_name: - flops = 181.05e12 - elif "L20" in device_name: - flops = 119.5e12 - elif "H20" in device_name: - flops = 148e12 - elif "910B" in device_name: - flops = 354e12 - elif "RTX 3070 Ti" in device_name: - flops = 21.75e12 - flops_unit = unit_convert(flops, unit) - return flops_unit - - -class FlopsCounter: - """ - Used to count mfu during training loop - - Example: - flops_counter = FlopsCounter(config) - flops_achieved, flops_promised = flops_counter.estimate_flops(tokens_list, delta_time) - - """ - - def __init__(self, config: PretrainedConfig): - if config.model_type not in VALID_CONFIG_TYPE: - print( - f"Only support config type of {VALID_CONFIG_TYPE}, but got {config.model_type}. MFU will always be " - f"zero." - ) - - self.estimate_func = { - "qwen2": self._estimate_qwen2_flops, - "llama": self._estimate_qwen2_flops, - "qwen2_moe": self._estimate_qwen2_moe_flops, - "qwen2_vl": self._estimate_qwen2_flops, - "qwen2_5_vl": self._estimate_qwen2_flops, - "qwen3": self._estimate_qwen2_flops, - "qwen3_moe": self._estimate_qwen2_moe_flops, - "deepseek_v3": self._estimate_deepseek_v3_flops, - "minicpmv": self._estimate_qwen2_flops, - "minicpmo": self._estimate_qwen2_flops, - } - self.config = config - - def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time): - return 0 - - def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time): - hidden_size = self.config.hidden_size - vocab_size = self.config.vocab_size - num_hidden_layers = self.config.num_hidden_layers - num_key_value_heads = self.config.num_key_value_heads - num_attention_heads = self.config.num_attention_heads - intermediate_size = self.config.intermediate_size - - head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) - q_size = num_attention_heads * head_dim - k_size = num_key_value_heads * head_dim - v_size = num_key_value_heads * head_dim - - # non-attn per layer parm - # Qwen2/LLama use SwiGelu, gate, having up and down linear layer in mlp - mlp_N = hidden_size * intermediate_size * 3 - attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) - emd_and_lm_head_N = vocab_size * hidden_size * 2 - # non-attn all_layer parm - dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N - # non-attn all_layer & all_token fwd & bwd flops - dense_N_flops = 6 * dense_N * tokens_sum - - # attn all_layer & all_token fwd & bwd flops - seqlen_square_sum = 0 - for seqlen in batch_seqlens: - seqlen_square_sum += seqlen * seqlen - attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers - - # all_layer & all_token fwd & bwd flops - flops_all_token = dense_N_flops + attn_qkv_flops - flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 - return flops_achieved - - def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time): - hidden_size = self.config.hidden_size - vocab_size = self.config.vocab_size - moe_intermediate_size = self.config.moe_intermediate_size - num_hidden_layers = self.config.num_hidden_layers - first_k_dense_replace = self.config.first_k_dense_replace - num_query_heads = self.config.num_attention_heads - moe_num_expert = self.config.n_routed_experts - - moe_topk = self.config.num_experts_per_tok - share_expert_num = self.config.n_shared_experts - - # non-attn per layer parm - moe_gata_N = hidden_size * moe_num_expert - # moe has fc1_1, fc1_2 and fc2 using SwiGLU in ExpertMlp layer & shared experts - moe_expertmlp_N = hidden_size * moe_intermediate_size * (moe_topk + share_expert_num) * 3 - # MLA attn - attn_linear_N = 0 - q_head_dim = self.config.qk_nope_head_dim + self.config.qk_rope_head_dim - if self.config.q_lora_rank is None: - attn_linear_N += hidden_size * num_query_heads * q_head_dim - else: - attn_linear_N += hidden_size * self.config.q_lora_rank - attn_linear_N += num_query_heads * q_head_dim * self.config.q_lora_rank - - attn_linear_N += hidden_size * (self.config.kv_lora_rank + self.config.qk_rope_head_dim) - attn_linear_N += ( - num_query_heads - * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim) - * self.config.kv_lora_rank - ) - attn_linear_N += num_query_heads * self.config.v_head_dim * hidden_size - emd_and_lm_head_N = vocab_size * hidden_size * 2 - # non-attn all_layer parm - moe_N = ( - (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace) - + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace - + emd_and_lm_head_N - ) - # non-attn all_layer & all_token fwd & bwd flops - dense_N_flops = 6 * moe_N * tokens_sum - - # attn all_layer & all_token fwd & bwd flops - seqlen_square_sum = 0 - for seqlen in batch_seqlens: - seqlen_square_sum += seqlen * seqlen * num_hidden_layers - - attn_qkv_flops = 12 * seqlen_square_sum * q_head_dim * num_query_heads - # all_layer & all_token fwd & bwk flops - flops_all_token = dense_N_flops + attn_qkv_flops - flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 - - return flops_achieved - - def _estimate_qwen2_moe_flops(self, tokens_sum, batch_seqlens, delta_time): - hidden_size = self.config.hidden_size - vocab_size = self.config.vocab_size - num_hidden_layers = self.config.num_hidden_layers - num_key_value_heads = self.config.num_key_value_heads - num_attention_heads = self.config.num_attention_heads - moe_intermediate_size = self.config.moe_intermediate_size - moe_topk = self.config.num_experts_per_tok - num_experts = self.config.num_experts - - head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) - q_size = num_attention_heads * head_dim - k_size = num_key_value_heads * head_dim - v_size = num_key_value_heads * head_dim - - # non-attn per layer parm - # gate + moe export - moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts - attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim) - emd_and_lm_head_N = vocab_size * hidden_size * 2 - # non-attn all_layer parm - dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N - # non-attn all_layer & all_token fwd & bwd flops - dense_N_flops = 6 * dense_N * tokens_sum - - # attn all_layer & all_token fwd & bwd flops - seqlen_square_sum = 0 - for seqlen in batch_seqlens: - seqlen_square_sum += seqlen * seqlen - attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers - - # all_layer & all_token fwd & bwd flops - flops_all_token = dense_N_flops + attn_qkv_flops - flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12 - return flops_achieved - - def estimate_flops(self, batch_seqlens, delta_time): - """ - Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken. - - Args: - batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the - current batch. - delta_time (float): The time taken to process the batch, in seconds. - - Returns: - estimated_flops (float): The estimated FLOPS based on the input tokens and time. - promised_flops (float): The expected FLOPS of the current device. - """ - tokens_sum = sum(batch_seqlens) - func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops) - estimated_flops = func(tokens_sum, batch_seqlens, delta_time) - promised_flops = get_device_flops() - return estimated_flops, promised_flops diff --git a/verl/utils/fs.py b/verl/utils/fs.py deleted file mode 100644 index 7cc11300f..000000000 --- a/verl/utils/fs.py +++ /dev/null @@ -1,292 +0,0 @@ -#!/usr/bin/env python -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -*- coding: utf-8 -*- -"""File-system agnostic IO APIs""" - -import hashlib -import os -import shutil -import tempfile - -try: - from hdfs_io import copy, exists, makedirs # for internal use only -except ImportError: - from .hdfs_io import copy, exists, makedirs - -__all__ = ["copy", "exists", "makedirs"] - -_HDFS_PREFIX = "hdfs://" - - -def is_non_local(path): - """Check if a path is a non-local (HDFS) path. - - Args: - path (str): The path to check. - - Returns: - bool: True if the path is an HDFS path, False otherwise. - """ - return path.startswith(_HDFS_PREFIX) - - -def md5_encode(path: str) -> str: - """Generate an MD5 hash of a path string. - - This function is used to create unique identifiers for paths, typically - for creating cache directories or lock files. - - Args: - path (str): The path to encode. - - Returns: - str: The hexadecimal MD5 hash of the path. - """ - return hashlib.md5(path.encode()).hexdigest() - - -def get_local_temp_path(hdfs_path: str, cache_dir: str) -> str: - """Generate a unique local cache path for an HDFS resource. - Creates a MD5-hashed subdirectory in cache_dir to avoid name conflicts, - then returns path combining this subdirectory with the HDFS basename. - - Args: - hdfs_path (str): Source HDFS path to be cached - cache_dir (str): Local directory for storing cached files - - Returns: - str: Absolute local filesystem path in format: - {cache_dir}/{md5(hdfs_path)}/{basename(hdfs_path)} - """ - # make a base64 encoding of hdfs_path to avoid directory conflict - encoded_hdfs_path = md5_encode(hdfs_path) - temp_dir = os.path.join(cache_dir, encoded_hdfs_path) - os.makedirs(temp_dir, exist_ok=True) - dst = os.path.join(temp_dir, os.path.basename(hdfs_path)) - return dst - - -def verify_copy(src: str, dest: str) -> bool: - """ - verify the copy of src to dest by comparing their sizes and file structures. - - return: - bool: True if the copy is verified, False otherwise. - """ - if not os.path.exists(src): - return False - if not os.path.exists(dest): - return False - - if os.path.isfile(src) != os.path.isfile(dest): - return False - - if os.path.isfile(src): - src_size = os.path.getsize(src) - dest_size = os.path.getsize(dest) - if src_size != dest_size: - return False - return True - - src_files = set() - dest_files = set() - - for root, dirs, files in os.walk(src): - rel_path = os.path.relpath(root, src) - dest_root = os.path.join(dest, rel_path) if rel_path != "." else dest - - if not os.path.exists(dest_root): - return False - - for entry in os.listdir(root): - src_entry = os.path.join(root, entry) - src_files.add(os.path.relpath(src_entry, src)) - - for entry in os.listdir(dest_root): - dest_entry = os.path.join(dest_root, entry) - dest_files.add(os.path.relpath(dest_entry, dest)) - - if src_files != dest_files: - return False - - for rel_path in src_files: - src_entry = os.path.join(src, rel_path) - dest_entry = os.path.join(dest, rel_path) - - if os.path.isdir(src_entry) != os.path.isdir(dest_entry): - return False - - if os.path.isfile(src_entry): - src_size = os.path.getsize(src_entry) - dest_size = os.path.getsize(dest_entry) - if src_size != dest_size: - return False - - return True - - -def copy_to_shm(src: str): - """ - Load the model into /dev/shm to make the process of loading the model multiple times more efficient. - """ - shm_model_root = "/dev/shm/verl-cache/" - src_abs = os.path.abspath(os.path.normpath(src)) - dest = os.path.join(shm_model_root, hashlib.md5(src_abs.encode("utf-8")).hexdigest()) - os.makedirs(dest, exist_ok=True) - dest = os.path.join(dest, os.path.basename(src_abs)) - if os.path.exists(dest) and verify_copy(src, dest): - # inform user and depends on him - print( - f"[WARNING]: The memory model path {dest} already exists. If it is not you want, please clear it and " - f"restart the task." - ) - else: - if os.path.isdir(src): - shutil.copytree(src, dest, symlinks=False, dirs_exist_ok=True) - else: - shutil.copy2(src, dest) - return dest - - -def _record_directory_structure(folder_path): - record_file = os.path.join(folder_path, ".directory_record.txt") - with open(record_file, "w") as f: - for root, dirs, files in os.walk(folder_path): - for dir_name in dirs: - relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path) - f.write(f"dir:{relative_dir}\n") - for file_name in files: - if file_name != ".directory_record.txt": - relative_file = os.path.relpath(os.path.join(root, file_name), folder_path) - f.write(f"file:{relative_file}\n") - return record_file - - -def _check_directory_structure(folder_path, record_file): - if not os.path.exists(record_file): - return False - existing_entries = set() - for root, dirs, files in os.walk(folder_path): - for dir_name in dirs: - relative_dir = os.path.relpath(os.path.join(root, dir_name), folder_path) - existing_entries.add(f"dir:{relative_dir}") - for file_name in files: - if file_name != ".directory_record.txt": - relative_file = os.path.relpath(os.path.join(root, file_name), folder_path) - existing_entries.add(f"file:{relative_file}") - with open(record_file) as f: - recorded_entries = set(f.read().splitlines()) - return existing_entries == recorded_entries - - -def copy_to_local( - src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False, use_shm: bool = False -) -> str: - """Copy files/directories from HDFS to local cache with validation. - - Args: - src (str): Source path - HDFS path (hdfs://...) or local filesystem path - cache_dir (str, optional): Local directory for cached files. Uses system tempdir if None - filelock (str): Base name for file lock. Defaults to ".file.lock" - verbose (bool): Enable copy operation logging. Defaults to False - always_recopy (bool): Force fresh copy ignoring cache. Defaults to False - use_shm (bool): Enable shared memory copy. Defaults to False - - Returns: - str: Local filesystem path to copied resource - """ - # Save to a local path for persistence. - local_path = copy_local_path_from_hdfs(src, cache_dir, filelock, verbose, always_recopy) - # Load into shm to improve efficiency. - if use_shm: - return copy_to_shm(local_path) - return local_path - - -def copy_local_path_from_hdfs( - src: str, cache_dir=None, filelock=".file.lock", verbose=False, always_recopy=False -) -> str: - """Deprecated. Please use copy_to_local instead.""" - from filelock import FileLock - - assert src[-1] != "/", f"Make sure the last char in src is not / because it will cause error. Got {src}" - - if is_non_local(src): - # download from hdfs to local - if cache_dir is None: - # get a temp folder - cache_dir = tempfile.gettempdir() - os.makedirs(cache_dir, exist_ok=True) - assert os.path.exists(cache_dir) - local_path = get_local_temp_path(src, cache_dir) - # get a specific lock - filelock = md5_encode(src) + ".lock" - lock_file = os.path.join(cache_dir, filelock) - with FileLock(lock_file=lock_file): - if always_recopy and os.path.exists(local_path): - if os.path.isdir(local_path): - shutil.rmtree(local_path, ignore_errors=True) - else: - os.remove(local_path) - if not os.path.exists(local_path): - if verbose: - print(f"Copy from {src} to {local_path}") - copy(src, local_path) - if os.path.isdir(local_path): - _record_directory_structure(local_path) - elif os.path.isdir(local_path): - # always_recopy=False, local path exists, and it is a folder: check whether there is anything missed - record_file = os.path.join(local_path, ".directory_record.txt") - if not _check_directory_structure(local_path, record_file): - if verbose: - print(f"Recopy from {src} to {local_path} due to missing files or directories.") - shutil.rmtree(local_path, ignore_errors=True) - copy(src, local_path) - _record_directory_structure(local_path) - return local_path - else: - return src - - -def local_mkdir_safe(path): - """_summary_ - Thread-safe directory creation function that ensures the directory is created - even if multiple processes attempt to create it simultaneously. - - Args: - path (str): The path to create a directory at. - """ - - from filelock import FileLock - - if not os.path.isabs(path): - working_dir = os.getcwd() - path = os.path.join(working_dir, path) - - # Using hash value of path as lock file name to avoid long file name - lock_filename = f"ckpt_{hash(path) & 0xFFFFFFFF:08x}.lock" - lock_path = os.path.join(tempfile.gettempdir(), lock_filename) - - try: - with FileLock(lock_path, timeout=60): # Add timeout - # make a new dir - os.makedirs(path, exist_ok=True) - except Exception as e: - print(f"Warning: Failed to acquire lock for {path}: {e}") - # Even if the lock is not acquired, try to create the directory - os.makedirs(path, exist_ok=True) - - return path diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py deleted file mode 100644 index 7465b400e..000000000 --- a/verl/utils/fsdp_utils.py +++ /dev/null @@ -1,556 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import itertools -import json -import math -import os -from collections import OrderedDict -from contextlib import contextmanager, nullcontext - -import torch -import torch.distributed as dist -import torch.nn as nn -from packaging import version -from torch.distributed import DeviceMesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp._runtime_utils import _lazy_init -from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy -from transformers.trainer_pt_utils import get_module_class_from_name - -from verl.utils.device import get_device_id, get_device_name, get_torch_device - -if version.parse(torch.__version__) >= version.parse("2.6"): - from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard -elif version.parse(torch.__version__) >= version.parse("2.4"): - from torch.distributed._composable.fsdp import CPUOffloadPolicy, FSDPModule, MixedPrecisionPolicy, fully_shard -else: - fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = None, None, None, None - - -def init_fn(x: torch.nn.Module): - if torch.distributed.get_rank() != 0: - x = x.to_empty(device=get_device_id(), recurse=False) - get_torch_device().empty_cache() - return x - - -def get_init_weight_context_manager(use_meta_tensor=True, mesh: DeviceMesh = None): - from accelerate import init_empty_weights - - cpu_init_weights = lambda: torch.device("cpu") - if use_meta_tensor: - if mesh is None: - init_context = init_empty_weights if torch.distributed.get_rank() != 0 else cpu_init_weights - else: - init_context = init_empty_weights if mesh.get_coordinate()[-1] != 0 else cpu_init_weights - else: - init_context = cpu_init_weights - return init_context - - -# Copyright 2020-present the HuggingFace Inc. team. -# Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py -def get_fsdp_wrap_policy(module, config=None, is_lora=False): - """Get FSDP wrap policy for the module. - - Args: - module: The module to get wrap policy for - config: Configuration for wrap policy - is_lora: Whether to enable lambda policy for LoRA modules - """ - if config is None: - config = {} - - # NOTE: This is a temporary workaround to be compatible with the OmegaConf & dataclass. We will remove this - # once we have make all config in verl from OmegaConf to data class. - def _get_attr(attr_name, default_value=None): - if hasattr(config, "get"): - return config.get(attr_name, default_value) - else: - return config.__getattribute__(attr_name) - - if _get_attr("disable", False): - return None - - default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None) - fsdp_transformer_layer_cls_to_wrap = _get_attr( - "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap - ) - min_num_params = _get_attr("min_num_params", 0) - auto_wrap_policy = None - - policies = [] - - from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy - - # Add lambda policy for LoRA modules if is_lora is True - if is_lora: - - def lambda_policy_fn(module): - return bool( - len(list(module.named_children())) == 0 - and getattr(module, "weight", None) is not None - and module.weight.requires_grad - ) - - lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) - policies.append(lambda_policy) - - if min_num_params > 0: - size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) - policies.append(size_policy) - elif fsdp_transformer_layer_cls_to_wrap is not None: - transformer_cls_to_wrap = set() - for layer_class in fsdp_transformer_layer_cls_to_wrap: - transformer_cls = get_module_class_from_name(module, layer_class) - if transformer_cls is None: - raise Exception("Could not find the transformer layer class to wrap in the model.") - else: - transformer_cls_to_wrap.add(transformer_cls) - - transformer_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=transformer_cls_to_wrap, - ) - policies.append(transformer_policy) - - if len(policies) > 0: - auto_wrap_policy = functools.partial(_or_policy, policies=policies) - - return auto_wrap_policy - - -@torch.no_grad() -def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): - if fsdp_version(model) == 2: - offload_fsdp2_model_to_cpu(model, empty_cache) - return - - assert isinstance(model, FSDP) - # lazy init FSDP model - _lazy_init(model, model) - assert model._is_root, "Only support root model offloading to CPU" - for handle in model._all_handles: - if handle._offload_params: - continue - flat_param = handle.flat_param - assert ( - flat_param.data.data_ptr() == flat_param._local_shard.data_ptr() - and id(flat_param.data) != id(flat_param._local_shard) - and flat_param.data.size() == flat_param._local_shard.size() - ) - handle.flat_param_to(torch.device("cpu"), non_blocking=True) - # the following still keeps id(._local_shard) != id(.data) - flat_param._local_shard = flat_param.data - assert id(flat_param._local_shard) != id(flat_param.data) - if empty_cache: - get_torch_device().empty_cache() - - -@torch.no_grad() -def offload_fsdp2_model_to_cpu(model, empty_cache: bool = True): - for param in model.parameters(): - param.data = param.data.to(torch.device("cpu"), non_blocking=True) - if empty_cache: - get_torch_device().empty_cache() - - -@torch.no_grad() -def load_fsdp_model_to_gpu(model: FSDP): - if fsdp_version(model) == 2: - load_fsdp2_model_to_gpu(model) - return - - assert isinstance(model, FSDP) - # lazy init FSDP model - _lazy_init(model, model) - assert model._is_root, "Only support root model loading to GPU" - device_id = get_device_id() - for handle in model._all_handles: - if handle._offload_params: - continue - flat_param = handle.flat_param - handle.flat_param_to(torch.device(f"{get_device_name()}:{device_id}"), non_blocking=True) - # the following still keeps id(._local_shard) != id(.data) - flat_param._local_shard = flat_param.data - - -@torch.no_grad() -def load_fsdp2_model_to_gpu(model): - device = get_device_id() - for param in model.parameters(): - param.data = param.data.to(device, non_blocking=True) - - -@torch.no_grad() -def offload_fsdp_optimizer(optimizer): - if not optimizer.state: - return - for param_group in optimizer.param_groups: - for param in param_group["params"]: - state = optimizer.state[param] - for key, value in state.items(): - if isinstance(value, torch.Tensor): - state[key] = value.to("cpu", non_blocking=True) - - -@torch.no_grad() -def load_fsdp_optimizer(optimizer, device_id): - if not optimizer.state: - return - for param_group in optimizer.param_groups: - for param in param_group["params"]: - state = optimizer.state[param] - for key, value in state.items(): - if isinstance(value, torch.Tensor): - state[key] = value.to(device_id, non_blocking=True) - - -@contextmanager -def meta_device_init(): - """ - Create model parameters with meta device. - - Note buffers in model will still be initialized in default device (e.g., CPU), - since the buffers can be non-persistent and filled with expected values that can - NOT be captured in meta device. - """ - device = torch.device("meta") - old_register_parameter = nn.Module.register_parameter - registered = set() - - def register_empty_parameter(module, name, param): - old_register_parameter(module, name, param) - # we will skip register shared parameters as it - # is already registered previously - if param is not None and param not in registered: - param_cls = type(module._parameters[name]) - kwargs = module._parameters[name].__dict__ - kwargs["requires_grad"] = param.requires_grad - module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) - registered.add(module._parameters[name]) - - try: - nn.Module.register_parameter = register_empty_parameter - yield - finally: - registered.clear() - nn.Module.register_parameter = old_register_parameter - - -def parallel_load_safetensors(filepath): - """ - Parallel load safetensors from huggingface checkpoint - - Huggingface checkpoint contains: - - - config.json: a json file for model configuration - - model.safetensor.index.json: a json file for safetensors (parameters & buffers) index - - model-000x-of-ooxx.safetensors: a binary file for safetensors (parameters & buffers) chunks - - Or (when model is small), - - - model.safetensors: a binary file for all parameters and buffers - - Each rank will own a part of model chunks and load them directly into GPU memory. - """ - from safetensors.torch import load_file - - safetensors2param = {} - - index_file = os.path.join(filepath, "model.safetensors.index.json") - if os.path.exists(index_file): - index = json.load(open(index_file, "rb")) - for param_name, filename in index["weight_map"].items(): - safetensors2param.setdefault(filename, []).append(param_name) - else: - # in this case, the model is small and we can load it all at once - param_file = os.path.join(filepath, "model.safetensors") - assert os.path.exists(param_file), f"Cannot find {param_file}" - states = load_file(param_file) - for param_name in states: - safetensors2param.setdefault("model.safetensors", []).append(param_name) - del states - - total_files = len(safetensors2param) - ckpt_chunks = sorted(safetensors2param.keys()) - world_size = dist.get_world_size() - size = int(math.ceil(total_files / world_size)) - ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] - - shard_states = {} - device = get_device_id() - for rank, files in enumerate(ckpt_chunks): - if rank == dist.get_rank(): - for file in files: - file = os.path.join(filepath, file) - states = load_file(file, device=device) - # print(f"rank {rank} loading {file}...") - shard_states.update(states) - else: - for file in files: - for param_name in safetensors2param[file]: - shard_states[param_name] = rank - return shard_states - - -def parallel_init_module_fn(module: torch.nn.Module, shard_states: dict[str, torch.nn.Parameter]): - """ - Generate a function to initialize sub-modules in the `module` with `shard_states` - from huggingface checkpoint. - - Args: - module (torch.nn.Module): the global module to be initialized - shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint - - Returns: - init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states` - """ - - state2fqn = {} - for name, state in itertools.chain( - module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False) - ): - state2fqn.setdefault(state, []).append(name) - # remove standalone parameters and buffers - shared = {s for s, names in state2fqn.items() if len(names) > 1} - materialized_states = {} - - @torch.no_grad() - def create_and_sync_state(param_name, state, is_param): - assert param_name in shard_states, f"{param_name} not loaded" - device = get_device_id() - if is_param: - param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) - else: # buffer - param = torch.empty_like(state.data, device=device) - loaded = shard_states[param_name] - if isinstance(loaded, torch.nn.Parameter | torch.Tensor): - # NOTE: loaded.dtype can be different with param.dtype - param.data.copy_(loaded.data) - dist.broadcast(param.data, src=dist.get_rank()) - else: - assert isinstance(loaded, int) # the rank that holds the state - dist.broadcast(param.data, src=loaded) - shard_states.pop(param_name) - del loaded - return param - - def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): - param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False)) - # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0]) - for name, state in param_and_buffers: - if not state.is_meta: - continue - is_param = name in sub_mod._parameters - fqn = state2fqn[state].pop(0) - # non-persistent buffers will not be saved in state dict, we can safely skip it - if (not is_param) and fqn not in shard_states: - if state.is_meta: - raise RuntimeError( - f"find a non-persistent buffer ({fqn}) initiated with device meta. Such buffer is not saved " - f"in checkpoint and user should guarantee to init in CPU / GPU device." - ) - continue - # for shared parameter, we get it from the first time it is created - if state in shared: - if state not in materialized_states: - materialized_states[state] = create_and_sync_state(fqn, state, is_param) - else: - if fqn in shard_states: - shard_states.pop(fqn) - materialize_state = materialized_states[state] - # for not shared parameter, we create it directly - else: - materialize_state = create_and_sync_state(fqn, state, is_param) - if is_param: - sub_mod._parameters[name] = materialize_state - else: - sub_mod._buffers[name] = materialize_state - if recurse: - for module in sub_mod.children(): - init_fn(module, recurse=True) - - # for debug - # if len(shard_states) == 0: print("clear") - return sub_mod - - return init_fn - - -def fsdp_version(model): - if isinstance(model, FSDP): - return 1 - elif isinstance(model, FSDPModule): - return 2 - else: - return 0 - - -def get_fsdp_state_ctx(model, state_type, state_cfg, optim_cfg): - if fsdp_version(model) == 1: - return FSDP.state_dict_type(model, state_type, state_cfg, optim_cfg) - else: - return nullcontext() - - -def get_fsdp_full_state_dict(model: torch.nn.Module, offload_to_cpu: bool = True, rank0_only: bool = True): - """ - Get the full state dict from an FSDP model. - - Args: - model (torch.nn.Module): The FSDP model to get state dict from - offload_to_cpu (bool, optional): Whether to offload the state dict to CPU. Defaults to True. - rank0_only (bool, optional): Whether to only get state dict on rank 0. Defaults to True. - - Returns: - dict: The full state dict of the model - - Raises: - NotImplementedError: If the FSDP version is unknown - """ - if fsdp_version(model) == 1: - from torch.distributed.fsdp import FullStateDictConfig, StateDictType - - state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only) - with get_fsdp_state_ctx( - model, state_type=StateDictType.FULL_STATE_DICT, state_cfg=state_dict_config, optim_cfg=None - ): - state_dict = model.state_dict() - return state_dict - elif fsdp_version(model) == 2: - from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict - - state_dict_config = StateDictOptions( - full_state_dict=True, cpu_offload=offload_to_cpu, broadcast_from_rank0=not rank0_only - ) - state_dict = get_model_state_dict(model, options=state_dict_config) - return state_dict - else: - raise NotImplementedError(f"Unknown FSDP version {fsdp_version}") - - -def fsdp2_load_full_state_dict(model: torch.nn.Module, full_state: dict, device_mesh=None, cpu_offload=None): - """ - Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the - parameters from rank 0 to all other ranks. This function modifies the model in-place. - - Args: - model (`torch.nn.Module`): The model to load the state dict into - full_state (`dict`): The full state dict to load, can only be on rank 0 - """ - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict - - # To broadcast, it needs to be instantiated in the GPU. - if dist.get_rank() == 0: - model = model.to(device=get_device_id(), non_blocking=True) - else: - model = model.to_empty(device=get_device_id()) - - cpu_offload = cpu_offload is not None - options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload, broadcast_from_rank0=True) - set_model_state_dict(model, full_state, options=options) - - # rotary_emb is not in state_dict, so we need to broadcast it manually - for name, buf in model.named_buffers(): - dist.broadcast(buf, src=0) - - if cpu_offload: - model.to("cpu", non_blocking=True) - for buf in model.buffers(): - buf.data = buf.data.to(get_device_id()) - - -def apply_fsdp2(model, fsdp_kwargs, config): - """model: AutoModelForCausalLM""" - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - - default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) - fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get( - "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap - ) - - if isinstance(fsdp_transformer_layer_cls_to_wrap, str): - fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] - - assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None - - modules = [] - for name, module in model.named_modules(): - if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or ( - isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings - ): - modules.append(module) - - for idx, module in enumerate(modules): - fully_shard(module, **fsdp_kwargs) - fully_shard(model, **fsdp_kwargs) # fsdp2 will not reshard_after_forward for root module - - -def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None): - """torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor""" - from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm - - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - else: - # prevent generators from being exhausted - parameters = list(parameters) - grads = [p.grad for p in parameters if p.grad is not None] - total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) - total_norm = total_norm.to(get_device_id(), non_blocking=True) - _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) - return total_norm - - -def layered_summon_lora_params(fsdp_module) -> OrderedDict: - from peft.utils.save_and_load import get_peft_model_state_dict - - def __prefix_submodules(module, prefix): - for name, submodule in module.named_modules(): - if name.startswith(prefix) and "." not in name[len(prefix) :]: - yield name, submodule - - lora_params = OrderedDict() - prefix_list = [ - # fsdp - "_fsdp_wrapped_module.base_model.model.", - "_fsdp_wrapped_module.base_model.model.model.", - "_fsdp_wrapped_module.base_model.model.model.layers.", - # fsdp2 - "base_model.model.", - "base_model.model.model.", - "base_model.model.model.layers.", - ] - peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module) - for prefix in prefix_list: - for name, submodule in __prefix_submodules(fsdp_module, prefix): - prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.") - if name.endswith(".model") or name.endswith(".layers"): - continue - if fsdp_version(submodule) > 0: - with FSDP.summon_full_params(submodule, writeback=False): - sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict()) - sub_lora_params = { - f"{prefix}.{name}": param.full_tensor().detach().cpu() - if hasattr(param, "full_tensor") - else param.detach().cpu() - for name, param in sub_lora_params.items() - } - lora_params.update(sub_lora_params) - submodule._is_root = False - get_torch_device().empty_cache() - return lora_params diff --git a/verl/utils/hdfs_io.py b/verl/utils/hdfs_io.py deleted file mode 100644 index 31edda1f6..000000000 --- a/verl/utils/hdfs_io.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import shutil - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) - -_HDFS_PREFIX = "hdfs://" - -_HDFS_BIN_PATH = shutil.which("hdfs") - - -def exists(path: str, **kwargs) -> bool: - r"""Works like os.path.exists() but supports hdfs. - - Test whether a path exists. Returns False for broken symbolic links. - - Args: - path (str): path to test - - Returns: - bool: True if the path exists, False otherwise - """ - if _is_non_local(path): - return _exists(path, **kwargs) - return os.path.exists(path) - - -def _exists(file_path: str): - """hdfs capable to check whether a file_path is exists""" - if file_path.startswith("hdfs"): - return _run_cmd(_hdfs_cmd(f"-test -e {file_path}")) == 0 - return os.path.exists(file_path) - - -def makedirs(name, mode=0o777, exist_ok=False, **kwargs) -> None: - r"""Works like os.makedirs() but supports hdfs. - - Super-mkdir; create a leaf directory and all intermediate ones. Works like - mkdir, except that any intermediate path segment (not just the rightmost) - will be created if it does not exist. If the target directory already - exists, raise an OSError if exist_ok is False. Otherwise no exception is - raised. This is recursive. - - Args: - name (str): directory to create - mode (int): file mode bits - exist_ok (bool): if True, do not raise an exception if the directory already exists - kwargs: keyword arguments for hdfs - - """ - if _is_non_local(name): - # TODO(haibin.lin): - # - handle OSError for hdfs(?) - # - support exist_ok for hdfs(?) - _mkdir(name, **kwargs) - else: - os.makedirs(name, mode=mode, exist_ok=exist_ok) - - -def _mkdir(file_path: str) -> bool: - """hdfs mkdir""" - if file_path.startswith("hdfs"): - _run_cmd(_hdfs_cmd(f"-mkdir -p {file_path}")) - else: - os.makedirs(file_path, exist_ok=True) - return True - - -def copy(src: str, dst: str, **kwargs) -> bool: - r"""Works like shutil.copy() for file, and shutil.copytree for dir, and supports hdfs. - - Copy data and mode bits ("cp src dst"). Return the file's destination. - The destination may be a directory. - If source and destination are the same file, a SameFileError will be - raised. - - Arg: - src (str): source file path - dst (str): destination file path - kwargs: keyword arguments for hdfs copy - - Returns: - str: destination file path - - """ - if _is_non_local(src) or _is_non_local(dst): - # TODO(haibin.lin): - # - handle SameFileError for hdfs files(?) - # - return file destination for hdfs files - return _copy(src, dst) - else: - if os.path.isdir(src): - return shutil.copytree(src, dst, **kwargs) - else: - return shutil.copy(src, dst, **kwargs) - - -def _copy(from_path: str, to_path: str, timeout: int = None) -> bool: - if to_path.startswith("hdfs"): - if from_path.startswith("hdfs"): - returncode = _run_cmd(_hdfs_cmd(f"-cp -f {from_path} {to_path}"), timeout=timeout) - else: - returncode = _run_cmd(_hdfs_cmd(f"-put -f {from_path} {to_path}"), timeout=timeout) - else: - if from_path.startswith("hdfs"): - returncode = _run_cmd( - _hdfs_cmd( - f"-get \ - {from_path} {to_path}" - ), - timeout=timeout, - ) - else: - try: - shutil.copy(from_path, to_path) - returncode = 0 - except shutil.SameFileError: - returncode = 0 - except Exception as e: - logger.warning(f"copy {from_path} {to_path} failed: {e}") - returncode = -1 - return returncode == 0 - - -def _run_cmd(cmd: str, timeout=None): - return os.system(cmd) - - -def _hdfs_cmd(cmd: str) -> str: - return f"{_HDFS_BIN_PATH} dfs {cmd}" - - -def _is_non_local(path: str): - return path.startswith(_HDFS_PREFIX) diff --git a/verl/utils/import_utils.py b/verl/utils/import_utils.py deleted file mode 100644 index fc75541e6..000000000 --- a/verl/utils/import_utils.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Utilities to check if packages are available. -We assume package availability won't change during runtime. -""" - -import importlib -import importlib.util -import os -import warnings -from functools import cache, wraps -from typing import Optional - - -@cache -def is_megatron_core_available(): - try: - mcore_spec = importlib.util.find_spec("megatron.core") - except ModuleNotFoundError: - mcore_spec = None - return mcore_spec is not None - - -@cache -def is_vllm_available(): - try: - vllm_spec = importlib.util.find_spec("vllm") - except ModuleNotFoundError: - vllm_spec = None - return vllm_spec is not None - - -@cache -def is_sglang_available(): - try: - sglang_spec = importlib.util.find_spec("sglang") - except ModuleNotFoundError: - sglang_spec = None - return sglang_spec is not None - - -@cache -def is_nvtx_available(): - try: - nvtx_spec = importlib.util.find_spec("nvtx") - except ModuleNotFoundError: - nvtx_spec = None - return nvtx_spec is not None - - -@cache -def is_trl_available(): - try: - trl_spec = importlib.util.find_spec("trl") - except ModuleNotFoundError: - trl_spec = None - return trl_spec is not None - - -def import_external_libs(external_libs=None): - if external_libs is None: - return - if not isinstance(external_libs, list): - external_libs = [external_libs] - import importlib - - for external_lib in external_libs: - importlib.import_module(external_lib) - - -def load_extern_type(file_path: Optional[str], type_name: Optional[str]) -> type: - """Load a external data type based on the file path and type name""" - if not file_path: - return None - - if file_path.startswith("pkg://"): - # pkg://verl.utils.dataset.rl_dataset - # pkg://verl/utils/dataset/rl_dataset - module_name = file_path[6:].replace("/", ".") - module = importlib.import_module(module_name) - - else: - # file://verl/utils/dataset/rl_dataset - # file:///path/to/verl/utils/dataset/rl_dataset.py - # or without file:// prefix - if file_path.startswith("file://"): - file_path = file_path[7:] - - if not os.path.exists(file_path): - raise FileNotFoundError(f"Custom type file '{file_path}' not found.") - - spec = importlib.util.spec_from_file_location("custom_module", file_path) - module = importlib.util.module_from_spec(spec) - try: - spec.loader.exec_module(module) - except Exception as e: - raise RuntimeError(f"Error loading module from '{file_path}'") from e - - if not hasattr(module, type_name): - raise AttributeError(f"Custom type '{type_name}' not found in '{file_path}'.") - - return getattr(module, type_name) - - -def _get_qualified_name(func): - """Get full qualified name including module and class (if any).""" - module = func.__module__ - qualname = func.__qualname__ - return f"{module}.{qualname}" - - -def deprecated(replacement: str = ""): - """Decorator to mark functions or classes as deprecated.""" - - def decorator(obj): - qualified_name = _get_qualified_name(obj) - - if isinstance(obj, type): - original_init = obj.__init__ - - @wraps(original_init) - def wrapped_init(self, *args, **kwargs): - msg = f"Warning: Class '{qualified_name}' is deprecated." - if replacement: - msg += f" Please use '{replacement}' instead." - warnings.warn(msg, category=FutureWarning, stacklevel=2) - return original_init(self, *args, **kwargs) - - obj.__init__ = wrapped_init - return obj - - else: - - @wraps(obj) - def wrapped(*args, **kwargs): - msg = f"Warning: Function '{qualified_name}' is deprecated." - if replacement: - msg += f" Please use '{replacement}' instead." - warnings.warn(msg, category=FutureWarning, stacklevel=2) - return obj(*args, **kwargs) - - return wrapped - - return decorator diff --git a/verl/utils/kernel/__init__.py b/verl/utils/kernel/__init__.py deleted file mode 100644 index e32d583d3..000000000 --- a/verl/utils/kernel/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py deleted file mode 100644 index a125bacda..000000000 --- a/verl/utils/kernel/kernels.py +++ /dev/null @@ -1,1553 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Implementations of the linear cross entropy with token entropy kernel. -""" - -import typing -from dataclasses import dataclass - -import torch -import torch.distributed as dist -import triton -import triton.language as tl - -from verl.utils.device import get_torch_device - - -@dataclass -class EntropyReductionEnum: - """ - Enum for the reduction method of cross entropy. - """ - - _None = 0 - _Sum = 1 - _Mean = 2 - - -def get_entropy_reduction_enum_number(reduction: str) -> int: - """ - Get the enum number for the reduction method of cross entropy. - """ - _enum = EntropyReductionEnum._None - if reduction == "none": - _enum = EntropyReductionEnum._None - elif reduction == "sum": - _enum = EntropyReductionEnum._Sum - elif reduction == "mean": - _enum = EntropyReductionEnum._Mean - else: - raise ValueError(f"Invalid reduction: {reduction}") - return _enum - - -def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: - """ - Get the enum for the reduction method of cross entropy. - """ - _enum = EntropyReductionEnum._None - if ce_reduction == 0: - _enum = EntropyReductionEnum._None - elif ce_reduction == 1: - _enum = EntropyReductionEnum._Sum - elif ce_reduction == 2: - _enum = EntropyReductionEnum._Mean - else: - raise ValueError(f"Invalid ce_reduction: {ce_reduction}") - return _enum - - -@dataclass -class BackwardEnum: - """ - Enum for the backward method. - """ - - _Total_Fuse_MN = ( - 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight - ) - _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight - _Split_Dlogits_N = 2 # split d_logits along its N dimension, aka. vocab_size - _Split_Dlogits_M = 3 # split d_logits along its M dimension, aka. num_tokens - - -@dataclass -class Config: - _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N - _use_triton: bool = True - - -_config = Config() - - -def set_backward_method(backward_method: BackwardEnum): - """ - Set the backward method. - """ - global _config - _config._backward = backward_method - - -@triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=3, num_warps=8)], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_kernel_general_mainloop( - rank, - hidden_ptr, - weight_ptr, - labels_ptr, - num_tokens, - hidden_size, - vocab_size, - vocab_per_split, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - max_ptr, - stride_max_m: tl.int64, - stride_max_n: tl.int64, - accu_ptr, - stride_accu_m: tl.int64, - stride_accu_n: tl.int64, - entropy_b_ptr, - stride_entropy_b_m: tl.int64, - stride_entropy_b_n: tl.int64, - global_logprobs_ptr, - stride_global_logprobs: tl.int64, - global_logprobs_scalar_ptr, - rcp_temperature: tl.float32, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - """ - forward mainloop - """ - pid = tl.program_id(axis=0) - num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) - pid_m = pid % num_pid_m - pid_n = pid // num_pid_m - - if pid_m == 0 and pid_n == 0: - tl.store(global_logprobs_scalar_ptr, 0.0) - - # create pointers for the first blocks of hidden - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_k = tl.arange(0, BLOCK_SIZE_K) - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - - # load labels for this block - labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) - - # traverse over N dimension - # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) - _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - for n in range(0, num_pid_n): - offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - - # iterate over K dimension - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - # load the next block of hidden and weight - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0, - ) - # _weight = tl.load(weight_ptrs, - # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < (min( - # (pid_n + 1) * vocab_per_split, vocab_size))), - # other=0.0) - - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))), - other=0.0, - ) - - # GEMM - logits = tl.dot(_hidden, _weight.trans(), logits) - - # advance the ptrs to the next K block - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - # reset hidden_ptrs for next iteration - hidden_ptrs -= hidden_size * stride_hidden_k - - # scale logits by temperature - logits *= rcp_temperature - - # update global maximum - _max_old = _max - m_pid_n = tl.max(logits, axis=1) - _max = tl.maximum(_max_old, m_pid_n) - - exp_logits = tl.exp(logits - _max[:, None]) - coeff = tl.exp(_max_old - _max) - _accu = coeff * _accu + tl.sum(exp_logits, axis=1) - - _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) - - label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] - _logprobs += tl.sum(logits * label_mask, axis=1) - - # store maximum - offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_max_n = pid_n - maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m - tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) - - # store entropy - accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m - tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) - entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m - tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) - - # store logprobs - vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size - vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size - mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx) - mask &= offs_am < num_tokens - global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs - # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask) - tl.store(global_logprobs_ptrs, _logprobs, mask=mask) - - -@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) -@triton.jit -def efficient_entropy_triton_kernel_epilogue( - max_ptr, - stride_max_m: tl.int64, - stride_max_n: tl.int64, - num_tokens, - num_splits, - global_max_ptr, - stride_global_max: tl.int64, - accu_ptr, - stride_accu_m: tl.int64, - stride_accu_n: tl.int64, - global_accu_ptr, - stride_global_accu: tl.int64, - entropy_b_ptr, - stride_entropy_b_m: tl.int64, - stride_entropy_b_n: tl.int64, - global_entropy_b_ptr, - stride_global_entropy_b: tl.int64, - global_entropy_ptr, - stride_global_entropy: tl.int64, - global_logprobs_ptr, - stride_global_logprobs: tl.int64, - global_logprobs_scalar_ptr, - reduction: int, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - """ - foward epilogue - """ - pid_m = tl.program_id(axis=0) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n - - _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) - - accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n - _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) - - entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n - _entropy_b = tl.load( - entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0 - ) - - # local reduction - _max_old = global_max - _local_max = tl.max(_max, axis=1) - global_max = tl.maximum(global_max, _local_max) - - _scale = tl.exp(_max - global_max[:, None]) - _coeff = tl.exp(_max_old - global_max) - global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) - global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) - - # store - maximum_ptrs = global_max_ptr + offs_m * stride_global_max - tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) - - # store entropy_b - global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b - tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) - - # store entropy - global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu - tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) - global_entropy = tl.log(global_accu) + global_max - global_entropy_b # entropy_a - global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy - tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens) - # update logprobs - global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs - global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) - global_logprobs = global_max + tl.log(global_accu) - global_logprobs - - global_logprobs = -1 * global_logprobs - if reduction == 0: - tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) - elif reduction == 1: - global_logprobs_scalar = tl.sum(global_logprobs, axis=0) - tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) - elif reduction == 2: - global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) - tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) - - -@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) -@triton.jit -def efficient_entropy_triton_kernel_epilogue_tp( - num_tokens, - num_splits, - reduced_max_ptr, - stride_reduced_max_m: tl.int64, - stride_reduced_max_n: tl.int64, - original_max_ptr, - stride_original_max_m: tl.int64, - stride_original_max_n: tl.int64, - accu_ptr, - stride_accu_m: tl.int64, - stride_accu_n: tl.int64, - entropy_b_ptr, - stride_entropy_b_m: tl.int64, - stride_entropy_b_n: tl.int64, - global_max_ptr, - stride_global_max: tl.int64, - global_accu_ptr, - stride_global_accu: tl.int64, - global_entropy_b_ptr, - stride_global_entropy_b: tl.int64, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - - global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - _reduced_max = tl.load( - reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n, - mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), - other=0.0, - ) - _original_max = tl.load( - original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n, - mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), - other=0.0, - ) - _accu = tl.load( - accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, - mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), - other=0.0, - ) - - # local reduce-max - _max_old = global_max - _local_max = tl.max(_reduced_max, axis=1) - global_max = tl.maximum(global_max, _local_max) - - # update accumulate - _coeff = tl.exp(_max_old - global_max) - _scale = tl.exp(_original_max - global_max[:, None]) - global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) - - # update entropy_b - _entropy_b = tl.load( - entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n, - mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), - other=0.0, - ) - global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) - - # store - tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) - tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) - tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) - - -@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) -@triton.jit -def efficient_entropy_triton_epilogue_tp_update( - num_tokens, - logprobs_ptr, - stride_logprobs: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accumulate_ptr, - stride_accumulate: tl.int64, - entropy_b_ptr, - stride_entropy_b: tl.int64, - entropy_ptr, - stride_entropy: tl.int64, - logprobs_scalar_ptr, - reduction: int, - BLOCK_SIZE_M: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - - maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) - accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) - - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens) - entropy_b = tl.fdiv(entropy_b, accumulate) - tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens) - - entropy = tl.log(accumulate) + maximum - entropy_b - tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) - - logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) - logprobs = maximum + tl.log(accumulate) - logprobs - - logprobs = -1 * logprobs - if reduction == 0: - tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) - elif reduction == 1: - logprobs_scalar = tl.sum(logprobs, axis=0) - tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) - elif reduction == 2: - logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32) - tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) - - -_dedicated_stream, _dedicated_events = None, None - - -def efficient_entropy_forward( - hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - reduction: typing.Optional[int] = 2, - temperature: typing.Optional[float] = 1.0, - dist_process_group: typing.Optional[dist.ProcessGroup] = None, -) -> list[torch.Tensor]: - """ - forward host function - """ - assert hidden.is_cuda and weight.is_cuda and labels.is_cuda - assert weight.device == hidden.device and labels.device == hidden.device - assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 - assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() - - assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] - - _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) - _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) - - if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"): - global _dedicated_stream, _dedicated_events - _dedicated_stream = get_torch_device().Stream(hidden.device) - _dedicated_events = [get_torch_device().Event() for _ in range(2)] - efficient_entropy_forward._initialized = True - - num_tokens, hidden_size = hidden.shape - num_tokens = labels.shape[0] - vocab_size, hidden_size = weight.shape - assert hidden_size % 128 == 0 - - REDUCTION = get_entropy_reduction_enum(reduction) - - if REDUCTION == EntropyReductionEnum._None: - if dist_process_group is None: - logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) - else: - logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) - elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): - logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) - else: - raise ValueError(f"Invalid reduction: {reduction}") - - entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) - assert logprobs.is_contiguous() and entropy.is_contiguous() - - maximum = torch.empty_like(entropy) - accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32) - accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) - accumulate = accumulate_and_entropy_b_view[0, :] - entropy_b = accumulate_and_entropy_b_view[1, :] - assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous() - - vocab_per_split = 1024 - assert vocab_per_split % 128 == 0 - num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - - _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - - if REDUCTION == EntropyReductionEnum._None: - _logprobs = logprobs - else: - _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) - - assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() - assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda - - if _config._use_triton: - # 1D kernel launch, then split the tile - def mainloop_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) - - efficient_entropy_kernel_general_mainloop[mainloop_grid]( - _rank, - hidden, - weight, - labels, - num_tokens, - hidden_size, - vocab_size, - vocab_per_split, - hidden.stride(0), - hidden.stride(1), - weight.stride(0), - weight.stride(1), - _max, - _max.stride(0), - _max.stride(1), - _accu, - _accu.stride(0), - _accu.stride(1), - _entropy_b, - _entropy_b.stride(0), - _entropy_b.stride(1), - _logprobs, - _logprobs.stride(0), - logprobs, - 1.0 / temperature, - ) - else: - raise AssertionError("Triton is required for efficient entropy kernel") - - # reduction on maximum and maximum_indices - def epilogue_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) - - if dist_process_group is None: - efficient_entropy_triton_kernel_epilogue[epilogue_grid]( - _max, - _max.stride(0), - _max.stride(1), - num_tokens, - num_splits, - maximum, - maximum.stride(0), - _accu, - _accu.stride(0), - _accu.stride(1), - accumulate, - accumulate.stride(0), - _entropy_b, - _entropy_b.stride(0), - _entropy_b.stride(1), - entropy_b, - entropy_b.stride(0), - entropy, - entropy.stride(0), - _logprobs, - _logprobs.stride(0), - logprobs, - REDUCTION, - ) - else: - # tensor-parallel - _max_backup = _max.clone() - dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group) - - get_torch_device().current_stream().record_event(_dedicated_events[0]) - with get_torch_device().stream(_dedicated_stream): - _dedicated_stream.wait_event(_dedicated_events[0]) - dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group) - _dedicated_stream.record_event(_dedicated_events[1]) - - efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid]( - num_tokens, - num_splits, - _max, - _max.stride(0), - _max.stride(1), - _max_backup, - _max_backup.stride(0), - _max_backup.stride(1), - _accu, - _accu.stride(0), - _accu.stride(1), - _entropy_b, - _entropy_b.stride(0), - _entropy_b.stride(1), - maximum, - maximum.stride(0), - accumulate, - accumulate.stride(0), - entropy_b, - entropy_b.stride(0), - ) - get_torch_device().current_stream().wait_event(_dedicated_events[1]) - - dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group) - - # update logprobs & entropy - efficient_entropy_triton_epilogue_tp_update[epilogue_grid]( - num_tokens, - _logprobs, - _logprobs.stride(0), - maximum, - maximum.stride(0), - accumulate, - accumulate.stride(0), - entropy_b, - entropy_b.stride(0), - entropy, - entropy.stride(0), - logprobs, - REDUCTION, - ) - - return (logprobs, entropy, maximum, accumulate, entropy_b) - - -# NOTE: merge d_weight & d_hidden here, split along M & N -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ) - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_general_mainloop_MN( - num_tokens: int, - hidden_size: int, - vocab_size: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b: tl.int64, - d_hidden_ptr, - stride_d_hidden_m: tl.int64, - stride_d_hidden_k: tl.int64, - d_weight_ptr, - stride_d_weight_n: tl.int64, - stride_d_weight_k: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - backward mainloop, where d_logits & d_hidden & d_weight are fused - """ - # block swizzling - # pid = tl.program_id(axis=0) - # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - # pid_m = pid % num_pid_m - # pid_n = pid // num_pid_m - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - maximum_ptrs = maximum_ptr + offs_am * stride_maximum - maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) - accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero - accu_rcp = tl.fdiv(1.0, accu) - - d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy - d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) - if reduction == 0: # none - d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs - d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) - elif reduction == 1: # sum - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: # mean - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - - entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b - entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) - - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - labels_ptrs = labels_ptr + offs_am * stride_labels - labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) - - d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k - # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n - d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k - - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0, - ) - # _weight = tl.load(weight_ptrs, - # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), - # other=0.0) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), - other=0.0, - ) - - logits = tl.dot(_hidden, _weight.trans(), logits) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - hidden_ptrs -= hidden_size * stride_hidden_k - weight_ptrs -= hidden_size * stride_weight_k - - # scale logits by temperature - logits *= rcp_temperature - - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - # scale d_logits by temperature - d_logits *= rcp_temperature - - # loop for d_weight & d_hidden - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0, - ) - # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) - # tl.atomic_add(d_weight_ptrs, - # _d_weight, - # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size)) - _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32)) - tl.atomic_add( - d_weight_ptrs, - _d_weight, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), - ) - - # _weight = tl.load(weight_ptrs, - # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), - # other=0.0) - # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), - other=0.0, - ) - _d_hidden = tl.dot(d_logits, _weight.to(tl.float32)) - tl.atomic_add( - d_hidden_ptrs, - _d_hidden, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - ) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k - d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ), - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_d_hidden( - num_tokens: int, - hidden_size: int, - vocab_size: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b: tl.int64, - d_hidden_ptr, - stride_d_hidden_m: tl.int64, - stride_d_hidden_k: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - backward d_hidden - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - pid_m = pid % num_pid_m - pid_k = pid // num_pid_m - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_k = tl.arange(0, BLOCK_SIZE_K) - result_offs_k = pid_k * BLOCK_SIZE_K + offs_k - - maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) - accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) - if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) - elif reduction == 1: - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) - labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) - - # iterate over vocab_size - d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) - for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)): - offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - - # iterate over hidden_size to get logits - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), - other=0.0, - ) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), - other=0.0, - ) - - logits = tl.dot(_hidden, _weight.trans(), logits) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - - # scale logits by temperature - logits *= rcp_temperature - - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - # scale d_logits - d_logits *= rcp_temperature - - # calculate d_hidden - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k) - _weight = tl.load( - weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0 - ) - d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden) - - # write back - tl.store( - d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k, - d_hidden, - mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size), - ) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ), - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_d_weight( - num_tokens: int, - hidden_size: int, - vocab_size: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b: tl.int64, - d_weight_ptr, - stride_d_weight_n: tl.int64, - stride_d_weight_k: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) - pid_n = pid % num_pid_n - pid_k = pid // num_pid_n - - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - result_offs_k = pid_k * BLOCK_SIZE_K + offs_k - - d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) - for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)): - offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - - maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) - accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) - if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) - elif reduction == 1: - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) - labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) - - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), - other=0.0, - ) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), - other=0.0, - ) - - logits = tl.dot(_hidden, _weight.trans(), logits) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - - logits *= rcp_temperature - - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - d_logits *= rcp_temperature - - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k) - _hidden = tl.load( - hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0 - ) - d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight) - - # write back - tl.store( - d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k, - d_weight, - mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size), - ) - - -# NOTE: split tile from d_logits' perspective -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ), - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_general_d_logits( - num_tokens: int, - hidden_size: int, - vocab_size: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b, - d_logits_ptr, - stride_d_logits_m: tl.int64, - stride_d_logits_n: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - backward d_logits - """ - # block swizzling - # pid = tl.program_id(axis=0) - # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - # pid_m = pid % num_pid_m - # pid_n = pid // num_pid_m - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - maximum_ptrs = maximum_ptr + offs_am * stride_maximum - maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) - accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero - accu_rcp = tl.fdiv(1.0, accu) - - d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy - d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) - if reduction == 0: # none - d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs - d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) - elif reduction == 1: # sum - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: # mean - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - - entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b - entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) - - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - labels_ptrs = labels_ptr + offs_am * stride_labels - labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) - - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0, - ) - # _weight = tl.load(weight_ptrs, - # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), - # other=0.0) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), - other=0.0, - ) - - logits = tl.dot(_hidden, _weight.trans(), logits) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - hidden_ptrs -= hidden_size * stride_hidden_k - weight_ptrs -= hidden_size * stride_weight_k - - # scale logits by temperature - logits *= rcp_temperature - - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - # scale d_logits by temperature - d_logits *= rcp_temperature - - # store d_logits - d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n - tl.store( - d_logits_ptrs, - d_logits, # will be implicitly converted to d_logits_ptrs.dtype.element_ty - mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size), - ) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ), - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_general_d_logits_split_N( - split_idx: int, - num_tokens: int, - hidden_size: int, - vocab_size: int, - vocab_per_split: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b, - d_logits_ptr, - stride_d_logits_m: tl.int64, - stride_d_logits_n: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6) - accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0) - if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) - elif reduction == 1: - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0) - labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0) - - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - - vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0, - ) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound), - other=0.0, - ) - logits = tl.dot(_hidden, _weight.trans(), logits) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - - logits *= rcp_temperature - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - d_logits *= rcp_temperature - - # filter d_logits with mask - result_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split) - - tl.store( - d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask - ) - - -def efficient_entropy_backward( - dlogprobs: torch.Tensor, - dentropy: torch.Tensor, - hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - maximum: torch.Tensor, - acc: torch.Tensor, - entropy_b: torch.Tensor, - reduction: typing.Optional[int] = 2, - should_return_fp32_grad: bool = False, - temperature: typing.Optional[float] = 1.0, - dist_process_group: typing.Optional[dist.ProcessGroup] = None, -) -> list[torch.Tensor]: - """ - backward host function - """ - assert hidden.is_cuda and weight.is_cuda and labels.is_cuda - assert weight.device == hidden.device and labels.device == hidden.device - assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 - assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() - assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] - - _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) - _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) - - num_tokens, hidden_size = hidden.shape - num_tokens = labels.shape[0] - vocab_size, hidden_size = weight.shape - assert hidden_size % 128 == 0 - - REDUCTION = get_entropy_reduction_enum(reduction) - - if REDUCTION == EntropyReductionEnum._None: - assert dlogprobs.shape == (num_tokens,) - else: - assert dlogprobs.dim() == 0 - - assert dlogprobs.is_contiguous() and dentropy.is_contiguous() - assert dlogprobs.is_cuda and dentropy.is_cuda - assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device - assert dentropy.shape == (num_tokens,) - - d_hidden, d_weight = None, None - if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad: - d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) - d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) - else: - d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device) - d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) - assert d_hidden.is_contiguous() and d_weight.is_contiguous() - - assert maximum.is_contiguous() and acc.is_contiguous() - assert maximum.device == hidden.device and acc.device == hidden.device - assert maximum.shape == labels.shape == acc.shape - assert maximum.is_cuda and acc.is_cuda - - vocab_per_split = 1024 - assert vocab_per_split % 128 == 0 - num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - - assert entropy_b.is_contiguous() and entropy_b.is_cuda - assert entropy_b.shape == (num_tokens,) - - if _config._backward == BackwardEnum._Total_Fuse_MN: - # --- Triton doesn't materialize d_logits at all. Split tiles at the perspective of d_logits. - def mainloop_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) - - efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( - num_tokens, - hidden_size, - vocab_size, - _rank, - hidden, - hidden.stride(0), - hidden.stride(1), - weight, - weight.stride(0), - weight.stride(1), - labels, - labels.stride(0), - maximum, - maximum.stride(0), - acc, - acc.stride(0), - dentropy, - dentropy.stride(0), - dlogprobs, - dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, - REDUCTION, - entropy_b, - entropy_b.stride(0), - d_hidden, - d_hidden.stride(0), - d_hidden.stride(1), - d_weight, - d_weight.stride(0), - d_weight.stride(1), - 1.0 / temperature, - ) - - elif _config._backward == BackwardEnum._Total_Separate: - _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous() - assert _d_logits.is_contiguous() - - if _config._use_triton: - - def d_logits_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) - - efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( - num_tokens, - hidden_size, - vocab_size, - _rank, - hidden, - hidden.stride(0), - hidden.stride(1), - weight, - weight.stride(0), - weight.stride(1), - labels, - labels.stride(0), - maximum, - maximum.stride(0), - acc, - acc.stride(0), - dentropy, - dentropy.stride(0), - dlogprobs, - dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, - REDUCTION, - entropy_b, - entropy_b.stride(0), - _d_logits, - _d_logits.stride(0), - _d_logits.stride(1), - 1.0 / temperature, - ) - - torch.matmul(_d_logits, weight, out=d_hidden) - torch.matmul(_d_logits.T, hidden, out=d_weight) - else: - raise AssertionError("Triton is required for efficient entropy kernel") - - elif _config._backward == BackwardEnum._Split_Dlogits_N: - vocab_per_split = 9504 - num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - - _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() - assert _d_logits.is_contiguous() - - def d_logits_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]),) - - for split_idx in range(num_splits): - efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( - split_idx, - num_tokens, - hidden_size, - vocab_size, - vocab_per_split, - _rank, - hidden, - hidden.stride(0), - hidden.stride(1), - weight, - weight.stride(0), - weight.stride(1), - labels, - labels.stride(0), - maximum, - maximum.stride(0), - acc, - acc.stride(0), - dentropy, - dentropy.stride(0), - dlogprobs, - dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, - REDUCTION, - entropy_b, - entropy_b.stride(0), - _d_logits, - _d_logits.stride(0), - _d_logits.stride(1), - 1.0 / temperature, - ) - - if split_idx == (num_splits - 1): - vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split - _d_logits = _d_logits[:, :vocab_right_bound].contiguous() - - if split_idx == 0: - torch.matmul( - _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden - ) - else: - d_hidden += torch.matmul( - _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] - ) - torch.matmul( - _d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] - ) - - elif _config._backward == BackwardEnum._Split_Dlogits_M: - raise NotImplementedError("BackwardEnum._Split_Dlogits_M is not implemented yet") - - return d_hidden, d_weight diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py deleted file mode 100644 index 733a8152a..000000000 --- a/verl/utils/kernel/linear_cross_entropy.py +++ /dev/null @@ -1,117 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import typing - -import torch -import torch.distributed as dist - -from . import kernels - - -class LinearCrossEntropy(torch.autograd.Function): - @staticmethod - def forward( - ctx, - hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - temperature: typing.Optional[float] = 1.0, - reduction: typing.Optional[str] = "none", - dist_process_group: typing.Optional[dist.ProcessGroup] = None, - ) -> list[torch.Tensor]: - """_summary_ - - Args: - ctx (_type_): _description_ - hidden (torch.Tensor): (batch_size, num_tokens, hidden_size) -> (batch_size * num_tokens, hidden_size) - weight (torch.Tensor): (vocab_size, hidden_size) - labels (torch.Tensor): (batch_size, num_tokens) -> (batch_size * num_tokens, ) - temperature (typing.Optional[float], optional): _description_. Defaults to 1.0. - reduction (typing.Optional[str], optional): _description_. Defaults to "none". - dist_process_group (typing.Optional[dist.ProcessGroup], optional): _description_. Defaults to None. - - Returns: - typing.List[torch.Tensor]: _description_ - """ - - assert isinstance(temperature, float), f"temperature must be a float, but got {type(temperature)}" - assert isinstance(reduction, str), f"reduction must be a str, but got {type(reduction)}" - with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): - REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) - - original_hidden_shape = hidden.shape - if len(hidden.shape) != 2: - hidden = hidden.view(-1, hidden.shape[-1]) # (batch_size * num_tokens, hidden_size) - if len(labels.shape) != 1: - labels = labels.view(-1) - - logprobs, entropy, _maximum, _accumulate, _entropy_b = kernels.efficient_entropy_forward( - hidden, weight, labels, REDUCTION, temperature, dist_process_group - ) - - ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) - ctx.original_hidden_shape = original_hidden_shape - ctx.REDUCTION = REDUCTION - ctx.dist_process_group = dist_process_group - ctx.should_return_fp32_grad = False - ctx.temperature = temperature - return logprobs, entropy - - @staticmethod - def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> list[torch.Tensor]: - with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): - (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors - REDUCTION = ctx.REDUCTION - dist_process_group = ctx.dist_process_group - should_return_fp32_grad = ctx.should_return_fp32_grad - temperature = ctx.temperature - - d_hidden, d_weight = kernels.efficient_entropy_backward( - dlogprobs, - dentropy, - hidden, - weight, - labels, - _maximum, - _accumulate, - _entropy_b, - REDUCTION, - should_return_fp32_grad, - temperature, - dist_process_group, - ) - d_hidden = d_hidden.view(ctx.original_hidden_shape) - - return (d_hidden, d_weight, None, None, None, None) - - -linear_cross_entropy = LinearCrossEntropy.apply diff --git a/verl/utils/logger/__init__.py b/verl/utils/logger/__init__.py deleted file mode 100644 index e3184368b..000000000 --- a/verl/utils/logger/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from .aggregate_logger import ( - DecoratorLoggerBase, - LocalLogger, - log_with_rank, - print_rank_0, - print_with_rank, - print_with_rank_and_timer, -) - -__all__ = [ - "LocalLogger", - "DecoratorLoggerBase", - "print_rank_0", - "print_with_rank", - "print_with_rank_and_timer", - "log_with_rank", -] diff --git a/verl/utils/logger/aggregate_logger.py b/verl/utils/logger/aggregate_logger.py deleted file mode 100644 index d29698acc..000000000 --- a/verl/utils/logger/aggregate_logger.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -A Ray logger will receive logging info from different processes. -""" - -import datetime -import logging -import numbers -import pprint - -import torch - - -def concat_dict_to_str(dict: dict, step): - output = [f"step:{step}"] - for k, v in dict.items(): - if isinstance(v, numbers.Number): - output.append(f"{k}:{pprint.pformat(v)}") - output_str = " - ".join(output) - return output_str - - -class LocalLogger: - """ - A local logger that logs messages to the console. - - Args: - print_to_console (bool): Whether to print to the console. - """ - - def __init__(self, print_to_console=True): - self.print_to_console = print_to_console - - def flush(self): - pass - - def log(self, data, step): - if self.print_to_console: - print(concat_dict_to_str(data, step=step), flush=True) - - -class DecoratorLoggerBase: - """ - Base class for all decorators that log messages. - - Args: - role (str): The role (the name) of the logger. - logger (logging.Logger): The logger instance to use for logging. - level (int): The logging level. - rank (int): The rank of the process. - log_only_rank_0 (bool): If True, only log for rank 0. - """ - - def __init__( - self, role: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0, log_only_rank_0: bool = True - ): - self.role = role - self.logger = logger - self.level = level - self.rank = rank - self.log_only_rank_0 = log_only_rank_0 - self.logging_function = self.log_by_logging - if logger is None: - self.logging_function = self.log_by_print - - def log_by_print(self, log_str): - if not self.log_only_rank_0 or self.rank == 0: - print(f"{self.role} {log_str}", flush=True) - - def log_by_logging(self, log_str): - if self.logger is None: - raise ValueError("Logger is not initialized") - if not self.log_only_rank_0 or self.rank == 0: - self.logger.log(self.level, f"{self.role} {log_str}") - - -def print_rank_0(message): - """If distributed is initialized, print only on rank 0.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - print(message, flush=True) - else: - print(message, flush=True) - - -def print_with_rank(message: str, rank: int = 0, log_only_rank_0: bool = False): - """_summary_ - Print a message with rank information. - This function prints the message only if `log_only_rank_0` is False or if the rank is 0. - - Args: - message (str): _description_ - rank (int, optional): _description_. Defaults to 0. - log_only_rank_0 (bool, optional): _description_. Defaults to False. - """ - if not log_only_rank_0 or rank == 0: - print(f"[Rank {rank}] {message}", flush=True) - - -def print_with_rank_and_timer(message: str, rank: int = 0, log_only_rank_0: bool = False): - """_summary_ - Print a message with rank information and a timestamp. - This function prints the message only if `log_only_rank_0` is False or if the rank is 0. - - Args: - message (str): _description_ - rank (int, optional): _description_. Defaults to 0. - log_only_rank_0 (bool, optional): _description_. Defaults to False. - """ - now = datetime.datetime.now() - message = f"[{now.strftime('%Y-%m-%d %H:%M:%S')}] [Rank {rank}] {message}" - if not log_only_rank_0 or rank == 0: - print(message, flush=True) - - -def log_with_rank(message: str, rank, logger: logging.Logger, level=logging.INFO, log_only_rank_0: bool = False): - """_summary_ - Log a message with rank information using a logger. - This function logs the message only if `log_only_rank_0` is False or if the rank is 0. - Args: - message (str): The message to log. - rank (int): The rank of the process. - logger (logging.Logger): The logger instance to use for logging. - level (int, optional): The logging level. Defaults to logging.INFO. - log_only_rank_0 (bool, optional): If True, only log for rank 0. Defaults to False. - """ - if not log_only_rank_0 or rank == 0: - logger.log(level, f"[Rank {rank}] {message}") diff --git a/verl/utils/logging_utils.py b/verl/utils/logging_utils.py deleted file mode 100644 index 13fa9170b..000000000 --- a/verl/utils/logging_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os - -import torch - - -def set_basic_config(level): - """ - This function sets the global logging format and level. It will be called when import verl - """ - logging.basicConfig(format="%(levelname)s:%(asctime)s:%(message)s", level=level) - - -def log_to_file(string): - print(string) - if os.path.isdir("logs"): - with open(f"logs/log_{torch.distributed.get_rank()}", "a+") as f: - f.write(string + "\n") diff --git a/verl/utils/megatron/__init__.py b/verl/utils/megatron/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/utils/megatron/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/utils/megatron/dist_checkpointing.py b/verl/utils/megatron/dist_checkpointing.py deleted file mode 100644 index d95752a45..000000000 --- a/verl/utils/megatron/dist_checkpointing.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from megatron.core import dist_checkpointing, mpu -from megatron.core.dist_checkpointing.serialization import ( - get_default_load_sharded_strategy, - get_default_save_sharded_strategy, -) -from megatron.core.dist_checkpointing.strategies.fully_parallel import ( - FullyParallelLoadStrategyWrapper, - FullyParallelSaveStrategyWrapper, -) - - -def save_dist_checkpointing(sharded_state_dict, ckpt_path, async_save=False): - validate_sharding_integrity = True - # Get checkpointing strategies - save_strategy = get_default_save_sharded_strategy("torch_dist") - save_strategy = FullyParallelSaveStrategyWrapper( - save_strategy, mpu.get_data_parallel_group(with_context_parallel=True) - ) - - # Save model sharded state dicts - async_save_request = dist_checkpointing.save( - sharded_state_dict, - ckpt_path, - sharded_strategy=save_strategy, - async_sharded_save=async_save, - validate_access_integrity=validate_sharding_integrity, - ) - - return async_save_request - - -def load_dist_checkpointing(sharded_state_dict, ckpt_dir): - # Get checkpointing strategies - load_strategy = get_default_load_sharded_strategy(ckpt_dir) - load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) - ) - - # Load model sharded state dicts - state_dict = dist_checkpointing.load(sharded_state_dict, ckpt_dir, sharded_strategy=load_strategy) - - return state_dict diff --git a/verl/utils/megatron/memory.py b/verl/utils/megatron/memory.py deleted file mode 100644 index bc62d427e..000000000 --- a/verl/utils/megatron/memory.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from verl.utils.device import get_device_id - - -class MemoryBuffer: - def __init__(self, numel, numel_padded, dtype): - self.numel = numel - self.numel_padded = numel_padded - self.dtype = dtype - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_id(), requires_grad=False) - - def zero(self): - """Reset the buffer to zero.""" - self.data.zero_() - - def get(self, shape, start_index): - """Return a tensor with the input `shape` as a view into the - 1-D data starting at `start_index`.""" - end_index = start_index + shape.numel() - assert end_index <= self.numel, "requested tensor is out of the buffer range." - buffer_tensor = self.data[start_index:end_index] - buffer_tensor = buffer_tensor.view(shape) - return buffer_tensor diff --git a/verl/utils/megatron/optimizer.py b/verl/utils/megatron/optimizer.py deleted file mode 100644 index 100c161a5..000000000 --- a/verl/utils/megatron/optimizer.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from megatron.core.optimizer import OptimizerConfig -from megatron.core.optimizer import get_megatron_optimizer as get_megatron_optimizer_native -from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler - - -def get_megatron_optimizer( - model, - config: OptimizerConfig, - no_weight_decay_cond=None, - scale_lr_cond=None, - lr_mult=1.0, -): - # Base optimizer. - return get_megatron_optimizer_native( - config=config, - model_chunks=model, - no_weight_decay_cond=no_weight_decay_cond, - scale_lr_cond=scale_lr_cond, - lr_mult=lr_mult, - ) - - -def get_megatron_optimizer_param_scheduler( - optimizer, - config, -): - """ - Get the optimizer parameter scheduler for Megatron. - """ - if config.get("lr_decay_steps", None) is None: - config.lr_decay_steps = config.total_training_steps - wsd_decay_steps = None - if config.get("lr_wsd_decay_steps", None) is not None: - wsd_decay_steps = config.lr_wsd_decay_steps - if config.get("lr_warmup_steps_ratio", None) is not None and ( - config.get("lr_warmup_steps", None) is None or config.lr_warmup_steps <= 0 - ): - config.lr_warmup_steps = int(config.lr_warmup_steps_ratio * config.lr_decay_steps) - - opt_param_scheduler = OptimizerParamScheduler( - optimizer, - init_lr=config.lr_warmup_init, - max_lr=config.lr, - min_lr=config.min_lr, - lr_warmup_steps=config.lr_warmup_steps, - lr_decay_steps=config.lr_decay_steps, - lr_decay_style=config.lr_decay_style, - start_wd=config.weight_decay, - end_wd=config.weight_decay, - wd_incr_steps=config.total_training_steps, - wd_incr_style=config.weight_decay_incr_style, - use_checkpoint_opt_param_scheduler=config.use_checkpoint_opt_param_scheduler, - override_opt_param_scheduler=(not config.use_checkpoint_opt_param_scheduler), - wsd_decay_steps=wsd_decay_steps, - lr_wsd_decay_style=config.lr_wsd_decay_style, - ) - - return opt_param_scheduler - - -def get_megatron_last_lr(optimizer): - """ - Get the last learning rate from the optimizer parameter scheduler. - """ - return optimizer.param_groups[0]["lr"] diff --git a/verl/utils/megatron/pipeline_parallel.py b/verl/utils/megatron/pipeline_parallel.py deleted file mode 100644 index 50ba69736..000000000 --- a/verl/utils/megatron/pipeline_parallel.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from megatron.core import parallel_state as mpu - -from .sequence_parallel import pad_to_sequence_parallel - - -def compute_transformers_input_shapes(batches, meta_info): - from flash_attn.bert_padding import unpad_input # flash 2 is a must for Megatron - - # pre-compute input shapes for each micro-batch at each pp stage - input_shapes = [] - for model_inputs in batches: - input_ids = model_inputs["input_ids"] - attention_mask = model_inputs["attention_mask"] - input_ids_rmpad = unpad_input(input_ids.unsqueeze(dim=-1), attention_mask)[0] # (total_nnz, 1) - if meta_info["sequence_parallel"]: - input_ids_rmpad = pad_to_sequence_parallel(input_ids_rmpad) - # compute shapes for model_inputs - input_shapes.append( - torch.Size( - [ - input_ids_rmpad.shape[0] // mpu.get_tensor_model_parallel_world_size(), - 1, - meta_info["hidden_size"], - ] - ) - ) - else: - # compute shapes for model_inputs - input_shapes.append(torch.Size([input_ids_rmpad.shape[0], 1, meta_info["hidden_size"]])) - return input_shapes - - -def make_batch_generator(batches, vpp_size): - """ - Creates a batch generator suitable for Megatron pipeline parallelism, - handling virtual pipeline parallelism (VPP). - - If VPP is used (vpp_size > 1), it duplicates the batch iterator for each - virtual pipeline stage. Otherwise, it returns a single iterator. - - Args: - batches: An iterable (e.g., list) of micro-batches. - vpp_size (int): The virtual pipeline model parallel size. - - Returns: - An iterator or a list of iterators over the micro-batches. - """ - if vpp_size > 1: - # has vpp - batch_generator = [batches] * vpp_size # number of vpp chunks - batch_generator = [iter(b) for b in batch_generator] - else: - # no vpp - batch_generator = iter(batches) - return batch_generator diff --git a/verl/utils/megatron/sequence_parallel.py b/verl/utils/megatron/sequence_parallel.py deleted file mode 100644 index 52fda9b30..000000000 --- a/verl/utils/megatron/sequence_parallel.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn.functional as F -from megatron.core import parallel_state as mpu - - -def mark_parameter_as_sequence_parallel(parameter): - parameter.sequence_parallel = True - - -def is_sequence_parallel_param(param): - return hasattr(param, "sequence_parallel") and param.sequence_parallel - - -def pad_to_sequence_parallel(unpad_tokens: torch.Tensor): - """pad the tokens such that the total length is a multiple of sp world size - - Args: - unpad_tokens: (total_nnz, ...). Tokens after removing padding - - Returns: - the padded tokens: (total_nnz + pad_size,...) - - """ - total_nnz = unpad_tokens.shape[0] - sp_world_size = mpu.get_tensor_model_parallel_world_size() - - pad_size = 0 if total_nnz % sp_world_size == 0 else sp_world_size - total_nnz % sp_world_size - - if pad_size > 0: - if unpad_tokens.ndim == 1: - unpad_tokens = F.pad(unpad_tokens, (0, pad_size)) - elif unpad_tokens.ndim == 2: - unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size)) - else: - raise NotImplementedError(f"Padding dim {unpad_tokens.ndim()} is not supported") - - return unpad_tokens diff --git a/verl/utils/megatron/tensor_parallel.py b/verl/utils/megatron/tensor_parallel.py deleted file mode 100644 index d4a99b9d8..000000000 --- a/verl/utils/megatron/tensor_parallel.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Utilities for using tensor_parallel in megatron -""" - -from typing import TYPE_CHECKING - -import torch -import torch.distributed as dist -from megatron.core import parallel_state as mpu -from torch.nn import init - -if TYPE_CHECKING: - from megatron.core import ModelParallelConfig - - -def update_kwargs_with_config(dictionary: dict, config: "ModelParallelConfig"): - dictionary["config"] = config - return dictionary - - -def get_default_kwargs_for_model_parallel_config(): - model_parallel_config_kwargs = { - "params_dtype": torch.float32, - "use_cpu_initialization": False, - "perform_initialization": True, - "gradient_accumulation_fusion": False, - "sequence_parallel": False, - } - return model_parallel_config_kwargs - - -def get_default_model_parallel_config(): - from megatron.core import ModelParallelConfig - - return ModelParallelConfig(**get_default_kwargs_for_model_parallel_config()) - - -def get_common_default_kwargs_for_parallel_linear(): - default_model_parallel_config = get_default_model_parallel_config() - common_default_kwargs = { - "init_method": init.xavier_normal_, - "stride": 1, - "keep_master_weight_for_test": False, - "config": default_model_parallel_config, - } - return common_default_kwargs - - -def get_default_kwargs_for_column_parallel_linear(): - from megatron.core import ModelParallelConfig - - model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config() - column_parallel_config_kwargs = { - "async_tensor_model_parallel_allreduce": False, - } - model_parallel_config_kwargs.update(column_parallel_config_kwargs) - column_default_kwargs = { - "config": ModelParallelConfig(**model_parallel_config_kwargs), - } - common_default_kwargs = get_common_default_kwargs_for_parallel_linear() - common_default_kwargs.update(column_default_kwargs) - return common_default_kwargs - - -def get_default_kwargs_for_row_parallel_linear(): - common_default_kwargs = get_common_default_kwargs_for_parallel_linear() - return common_default_kwargs - - -def get_default_kwargs_for_parallel_embedding(): - from megatron.core import ModelParallelConfig - - model_parallel_config_kwargs = get_default_kwargs_for_model_parallel_config() - embedding_default_kwargs = { - "init_method": init.xavier_normal_, - "config": ModelParallelConfig(**model_parallel_config_kwargs), - } - return embedding_default_kwargs - - -def is_tensor_parallel_param(param): - return hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel - - -def get_tensor_parallel_partition_dim(param): - assert is_tensor_parallel_param(param) - return param.partition_dim - - -def get_tensor_parallel_partition_stride(param): - assert is_tensor_parallel_param(param) - return param.partition_stride - - -class _VocabParallelEntropy(torch.autograd.Function): - @staticmethod - def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor: - @torch.compile(dynamic=True) - def mul_reduce(a, b): - return (a * b).sum(dim=-1, keepdim=True) - - logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values - dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group()) - normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max - normalized_exp_logits = normalized_vocab_parallel_logits.exp_() - normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True) - dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group()) - softmax_logits = normalized_exp_logits.div_(normalized_sum_exp_logits) - sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits) - dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group()) - entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits - ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits) - return entropy.squeeze(dim=-1) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: - vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors - # reuse softmax_logits as grad - vocab_parallel_logits.sub_(sum_softmax_times_logits) - softmax_logits.mul_(vocab_parallel_logits) - softmax_logits.mul_(grad_output.unsqueeze(dim=-1)) - # recover vocab_parallel_logits - vocab_parallel_logits.add_(sum_softmax_times_logits) - softmax_logits.mul_(-1) - return softmax_logits - - -def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor: - """Compute entropy when the logits are sharded in tp ranks - - Args: - vocab_parallel_logits: (total_nnz, vocab_size // tp_size) - - Returns: (total_nnz,) - - """ - return _VocabParallelEntropy.apply(vocab_parallel_logits) - - -def vocab_parallel_log_probs_from_logits(logits, labels): - """TODO(zhangchi.usc1992): We may change the implementation later""" - from megatron.core import tensor_parallel - - return -tensor_parallel.vocab_parallel_cross_entropy(vocab_parallel_logits=logits, target=labels) - - -def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length): - """Similar to log_probs_from_logits_response_rmpad, but the logits_rmpad is now spliited across tensor parallel - region. - This will further reduce the peak memory usage during training - - Args: - input_ids: [batch_size, seqlen] - attention_mask: [batch_size, seqlen] - logits_rmpad: [total_nnz, vocab_size // tp_size] - response_length: int - - """ - from flash_attn.bert_padding import pad_input, unpad_input - - batch_size, seqlen = input_ids.shape - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask) - input_ids_rmpad = input_ids_rmpad.squeeze(-1) - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) - full_log_probs_rmpad = vocab_parallel_log_probs_from_logits( - logits=logits_rmpad, labels=input_ids_rmpad_rolled - ) # (total_nnz,) - full_output = pad_input( - hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen - ) - output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] - return output diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py deleted file mode 100644 index e59d1ed36..000000000 --- a/verl/utils/megatron_utils.py +++ /dev/null @@ -1,1017 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Pretrain utilities.""" - -import gc -import os -import warnings -from typing import Any - -import torch -import torch.nn.functional as F -from megatron.core import ModelParallelConfig, mpu, tensor_parallel -from megatron.core.distributed import DistributedDataParallel as DDP -from megatron.core.distributed import DistributedDataParallelConfig -from megatron.core.enums import ModelType -from megatron.core.optimizer import ChainedOptimizer, OptimizerConfig -from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.module import Float16Module -from megatron.core.utils import get_attr_wrapped_model -from transformers import PretrainedConfig - -import verl.utils.megatron.tensor_parallel as tp_utils -from verl.utils.device import get_device_id, get_device_name, get_torch_device -from verl.utils.fs import local_mkdir_safe -from verl.utils.model import normalize_model_name -from verl.utils.torch_dtypes import PrecisionType - - -def get_model_config(model): - return get_attr_wrapped_model(model, "config", allow_none=False) - - -def get_model( - model_provider_func, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - use_distributed_optimizer=True, - transformer_config=None, - override_ddp_config=None, -): - """Build the model.""" - # Build model. - if ( - mpu.get_pipeline_model_parallel_world_size() > 1 - and mpu.get_virtual_pipeline_model_parallel_world_size() is not None - ): - assert model_type != ModelType.encoder_and_decoder, ( - "Interleaved schedule not supported for model with both encoder and decoder" - ) - model = [] - for i in range(mpu.get_virtual_pipeline_model_parallel_world_size()): - mpu.set_virtual_pipeline_model_parallel_rank(i) - # Set pre_process and post_process only after virtual rank is set. - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - this_model = model_provider_func(pre_process=pre_process, post_process=post_process) - this_model.model_type = model_type - model.append(this_model) - mpu.set_virtual_pipeline_model_parallel_rank(0) - else: - pre_process = mpu.is_pipeline_first_stage() - post_process = mpu.is_pipeline_last_stage() - add_encoder = True - add_decoder = True - if model_type == ModelType.encoder_and_decoder: - if mpu.get_pipeline_model_parallel_world_size() > 1: - assert mpu.get_pipeline_model_parallel_split_rank() is not None, ( - "Split rank needs to be specified for model with both encoder and decoder" - ) - rank = mpu.get_pipeline_model_parallel_rank() - split_rank = mpu.get_pipeline_model_parallel_split_rank() - world_size = mpu.get_pipeline_model_parallel_world_size() - pre_process = rank == 0 or rank == split_rank - post_process = (rank == (split_rank - 1)) or (rank == (world_size - 1)) - add_encoder = mpu.is_pipeline_stage_before_split() - add_decoder = mpu.is_pipeline_stage_after_split() - model = model_provider_func( - pre_process=pre_process, post_process=post_process, add_encoder=add_encoder, add_decoder=add_decoder - ) - else: - model = model_provider_func(pre_process=pre_process, post_process=post_process) - model.model_type = model_type - - if not isinstance(model, list): - model = [model] - - # Set tensor model parallel attributes if not set. - # Only parameters that are already tensor model parallel have these - # attributes set for them. We should make sure the default attributes - # are set for all params so the optimizer can use them. - for model_module in model: - for param in model_module.parameters(): - tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) - - # Print number of parameters. - if mpu.get_data_parallel_rank() == 0: - print( - " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( - mpu.get_tensor_model_parallel_rank(), - mpu.get_pipeline_model_parallel_rank(), - sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]), - ), - flush=True, - ) - - # GPU allocation. - if transformer_config is None or (not transformer_config.use_cpu_initialization): - for model_module in model: - model_module.to(f"{get_device_name()}:{get_device_id()}") - - # Fp16 conversion. - config: TransformerConfig = get_model_config(model[0]) - config.fp8 = None - tfconfig: TransformerConfig = model[0].config - if config.fp16 or config.bf16: # the ModelParallelConfig in GPTModel - model = [Float16Module(config, model_module) for model_module in model] - - if wrap_with_ddp: - ddp_models = [] - ddp_config_dict = { - "use_distributed_optimizer": use_distributed_optimizer, - "grad_reduce_in_fp32": True, - "overlap_grad_reduce": False, - } - if override_ddp_config is not None: - ddp_config_dict.update(override_ddp_config) - ddp_config = DistributedDataParallelConfig(**ddp_config_dict) - for model_chunk_idx, model_chunk in enumerate(model): - ddp_model = DDP( - config=tfconfig, - module=model_chunk, - disable_bucketing=(model_chunk_idx > 0), - ddp_config=ddp_config, - ) - ddp_models.append(ddp_model) - model = ddp_models - # # Broadcast params from data parallel src rank to other data parallel ranks. - # # if args.data_parallel_random_init: - for model_module in model: - model_module.broadcast_params() - return model - - -ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) - - -def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): - return_list = True - if not isinstance(model, list): - model = [model] - return_list = False - unwrapped_model = [] - for model_module in model: - while isinstance(model_module, module_instances): - model_module = model_module.module - unwrapped_model.append(model_module) - if not return_list: - return unwrapped_model[0] - return unwrapped_model - - -def convert_config(hf_config: PretrainedConfig, megatron_config) -> TransformerConfig: - print(f"megatron config {megatron_config}") - dt = PrecisionType.to_dtype(megatron_config.params_dtype) - print(f"pipeline_dtype=megatron_config {dt}") - qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False) - overlap_p2p_comm = ( - mpu.get_virtual_pipeline_model_parallel_world_size() is not None - and mpu.get_virtual_pipeline_model_parallel_world_size() > 1 - ) - batch_p2p_comm = False - transformer_config = TransformerConfig( - num_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - num_attention_heads=hf_config.num_attention_heads, - num_query_groups=hf_config.num_key_value_heads, - ffn_hidden_size=hf_config.intermediate_size, - # max_position_embeddings=hf_config.max_position_embeddings, - activation_func=F.silu, - normalization="RMSNorm", - # rotary_percent=False, # default, - gated_linear_unit=True, # for llama - use_cpu_initialization=True, - apply_residual_connection_post_layernorm=False, # check what's this mean - add_bias_linear=False, - tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(), - pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(), - virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(), - context_parallel_size=mpu.get_context_parallel_world_size(), - overlap_p2p_comm=overlap_p2p_comm, - batch_p2p_comm=batch_p2p_comm, - pipeline_dtype=dt, - params_dtype=dt, - sequence_parallel=mpu.get_tensor_model_parallel_world_size() > 1, - variable_seq_lengths=True, - masked_softmax_fusion=True, - moe_token_dispatcher_type="alltoall", - attention_dropout=hf_config.attention_dropout, - hidden_dropout=getattr(hf_config, "hidden_dropout", 0.0), - add_qkv_bias=qkv_bias, - bf16=dt is torch.bfloat16, - ) - - return transformer_config - - -def init_megatron_optim_config(optim_config: dict) -> OptimizerConfig: - config = OptimizerConfig( - optimizer=optim_config.get("optimizer", "adam"), - lr=optim_config.get("lr"), - min_lr=optim_config.get("min_lr", None), - clip_grad=optim_config.get("clip_grad", 1.0), - weight_decay=optim_config.get("weight_decay", 0.01), - bf16=True, - params_dtype=torch.bfloat16, - use_distributed_optimizer=True, - ) - return config - - -def mcore_model_parallel_config( - sequence_parallel: bool, - params_dtype: torch.dtype, -) -> ModelParallelConfig: - # WARNING: Code should not reach this point. This function is deprecated and will be removed. - # Please use hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead. - warnings.warn( - "Code should not reach this point. This function is deprecated and will be removed. Please use " - "hf_to_mcore_config_dense() from verl.models.mcore.config_converter instead.", - DeprecationWarning, - stacklevel=2, - ) - return ModelParallelConfig( - tensor_model_parallel_size=mpu.get_tensor_model_parallel_world_size(), - pipeline_model_parallel_size=mpu.get_pipeline_model_parallel_world_size(), - virtual_pipeline_model_parallel_size=mpu.get_virtual_pipeline_model_parallel_world_size(), - context_parallel_size=mpu.get_context_parallel_world_size(), - sequence_parallel=sequence_parallel, - params_dtype=params_dtype, - pipeline_dtype=params_dtype, - bf16=True, - fp16=False, - timers=None, - ) - - -@torch.no_grad() -def offload_megatron_model_to_cpu(models): - """ - In megatron, the model and optimizer storage are: - - bf16 parameter data chunked in model parallel group - - fp32 grad chunked in model parallel group - - fp32 main_parameter chunked in model and dp group - - fp32 optimizer state chunked in model and dp group - """ - for model_chunk in models: - if isinstance(model_chunk, DDP): - model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] - for buffers in model_chunk_all_buffers: - for buffer in buffers: - # offload parameters - if buffer.param_data.storage().size() > 0: - buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory() - buffer.param_data_size = buffer.param_data.storage().size() - buffer.param_data.storage().resize_(0) - - assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size() - - if buffer.grad_data.storage().size() > 0: - # if the grad_data size is already zero, we assume that it is already offloaded - buffer.grad_data_size = buffer.grad_data.storage().size() - buffer.grad_data.storage().resize_(0) - else: - # we need this for ref module - for _, param in model_chunk.named_parameters(): - param.data = param.data.to("cpu", non_blocking=True) - if param.grad is not None: - param.grad = param.grad.to("cpu", non_blocking=True) - gc.collect() - get_torch_device().empty_cache() - - -@torch.no_grad() -def load_megatron_model_to_gpu(models, load_grad=True): - for model_chunk in models: - if isinstance(model_chunk, DDP): - model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] - for buffers in model_chunk_all_buffers: - for buffer in buffers: - # sometimes, we don't want to load grad for pure inference - if load_grad: - buffer.grad_data.storage().resize_(buffer.grad_data_size) - buffer.grad_data.zero_() - - if buffer.param_data.storage().size() == 0: - buffer.param_data.storage().resize_(buffer.param_data_size) - # copy data from cpu to cuda - buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True) - else: - # we need this for ref module - device_id = get_device_id() - for _, param in model_chunk.named_parameters(): - param.data = param.data.to(device_id, non_blocking=True) - if param.grad is not None: - param.grad = param.grad.to(device_id, non_blocking=True) - gc.collect() - get_torch_device().empty_cache() - - -@torch.no_grad() -def offload_megatron_copy_params(optimizers): - """ - Offload optimizer parameters to CPU. Supports both Megatron optimizers - and `ChainedOptimizer`, which wraps a list of underlying optimizers. - - Args: - optimizers: The optimizer or ChainedOptimizer instance. - """ - - def _iter_opts(opt): - if isinstance(opt, ChainedOptimizer): - return opt.chained_optimizers - return [opt] - - def offload_tensor_to_cpu(tensor): - if tensor is None: - return - tensor.data = tensor.data.to("cpu", non_blocking=True) - - def offload_group_to_cpu(group): - if group is None: - return - - if isinstance(group, list): - for param_group in group: - if isinstance(param_group, list): - for param in param_group: - offload_tensor_to_cpu(param) - else: - offload_tensor_to_cpu(param_group) - else: - offload_tensor_to_cpu(group) - - # Offload all parameter groups to CPU for each underlying optimizer - - for _opt in _iter_opts(optimizers): - if hasattr(_opt, "shard_fp32_from_float16_groups"): - offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) - - -@torch.no_grad() -def load_megatron_copy_params(optimizers): - """ - Load optimizer parameters back to GPU. Handles ChainedOptimizer. - - Args: - optimizers: Optimizer or ChainedOptimizer instance. - """ - - def _iter_opts(opt): - if isinstance(opt, ChainedOptimizer): - return opt.chained_optimizers - return [opt] - - def load_tensor_to_gpu(tensor): - if tensor is None: - return - device_id = get_device_id() - tensor.data = tensor.data.to(device_id, non_blocking=True) - - def load_group_to_gpu(group): - if group is None: - return - - if isinstance(group, list): - for param_group in group: - if isinstance(param_group, list): - for param in param_group: - load_tensor_to_gpu(param) - else: - load_tensor_to_gpu(param_group) - else: - load_tensor_to_gpu(group) - - # Load all parameter groups to GPU for each underlying optimizer - - for _opt in _iter_opts(optimizers): - if hasattr(_opt, "shard_fp32_from_float16_groups"): - load_group_to_gpu(_opt.shard_fp32_from_float16_groups) - - -@torch.no_grad() -def offload_megatron_optimizer(optimizers): - def _iter_opts(opt): - if isinstance(opt, ChainedOptimizer): - return opt.chained_optimizers - return [opt] - - for _opt in _iter_opts(optimizers): - offload_megatron_copy_params(_opt) - opt_state_dict_values = _opt.optimizer.state.values() - for v in opt_state_dict_values: - if "exp_avg" in v: - v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True) - if "exp_avg_sq" in v: - v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) - gc.collect() - get_torch_device().empty_cache() - - -@torch.no_grad() -def load_megatron_optimizer(optimizers): - def _iter_opts(opt): - if isinstance(opt, ChainedOptimizer): - return opt.chained_optimizers - return [opt] - - for _opt in _iter_opts(optimizers): - load_megatron_copy_params(_opt) - opt_state_dict_values = _opt.optimizer.state.values() - for v in opt_state_dict_values: - if "exp_avg" in v: - v["exp_avg"] = v["exp_avg"].to(get_device_id(), non_blocking=True) - if "exp_avg_sq" in v: - v["exp_avg_sq"] = v["exp_avg_sq"].to(get_device_id(), non_blocking=True) - gc.collect() - get_torch_device().empty_cache() - - -def get_dist_checkpoint_path(checkpoint_path): - local_mkdir_safe(checkpoint_path) - local_mkdir_safe(os.path.join(checkpoint_path, "dist_ckpt")) - return os.path.join(checkpoint_path, "dist_ckpt") - - -def get_hf_model_checkpoint_path(checkpoint_path): - local_mkdir_safe(checkpoint_path) - local_mkdir_safe(os.path.join(checkpoint_path, "huggingface")) - return os.path.join(checkpoint_path, "huggingface") - - -def get_transformer_config_checkpoint_path(checkpoint_path): - os.makedirs(checkpoint_path, exist_ok=True) - return os.path.join(checkpoint_path, "transformer_config.json") - - -def convert_megatron_model_to_transformers_model( - name, - param, - config: PretrainedConfig, - tp_size: int, - num_query_groups: int, - convert_qkv_gate_up_by_trunk_concat=False, -): - """Convert megatron model to transformers model.""" - new_params = {} - - def convert_qkv_shard(full_tensor, q_name, k_name, v_name): - nonlocal config - nonlocal tp_size - nonlocal num_query_groups - - q_shard_list = [] - k_shard_list = [] - v_shard_list = [] - hidden_size_per_head = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - - if config.num_key_value_heads >= tp_size: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = num_query_groups // tp_size - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] - q_shard_list.append(q_part) - k_shard_list.append(k_part) - v_shard_list.append(v_part) - else: - q_size_tp = hidden_size_per_head * config.num_attention_heads // tp_size - kv_size_tp = hidden_size_per_head - total_size = q_size_tp + 2 * kv_size_tp - for i in range(tp_size): - num_query_groups_per_partition = num_query_groups // tp_size - qkv_part = full_tensor[i * total_size : (i + 1) * total_size] - q_size_chunk = q_size_tp // num_query_groups_per_partition - kv_size_chunk = kv_size_tp // num_query_groups_per_partition - for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition): - q_part = qkv_part_chunk[:q_size_chunk] - k_part = qkv_part_chunk[q_size_chunk : q_size_chunk + kv_size_chunk] - v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk :] - q_shard_list.append(q_part) - if i * config.num_key_value_heads % tp_size == 0: - k_shard_list.append(k_part) - v_shard_list.append(v_part) - - new_params[q_name] = torch.cat(q_shard_list, dim=0) - new_params[k_name] = torch.cat(k_shard_list, dim=0) - new_params[v_name] = torch.cat(v_shard_list, dim=0) - - def convert_gate_up_shard(full_tensor, gate_name, up_name): - nonlocal config - nonlocal tp_size - - intermediate_size_tp = config.intermediate_size // tp_size - gate_weight_list = [] - up_weight_list = [] - for i in range(tp_size): - gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i : intermediate_size_tp * 2 * (i + 1)] - gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp] - up_weight_tp = gate_up_weight_tp[intermediate_size_tp:] - gate_weight_list.append(gate_weight_tp) - up_weight_list.append(up_weight_tp) - - new_params[gate_name] = torch.cat(gate_weight_list, dim=0) - new_params[up_name] = torch.cat(up_weight_list, dim=0) - - if name == "embedding.word_embeddings.weight": - new_params["model.embed_tokens.weight"] = param - elif "self_attention" in name: - splitted_name = name.split(".") - layer_number = splitted_name[2] - component = splitted_name[4] - param_type = splitted_name[5] - if component == "linear_proj": - new_params[f"model.layers.{layer_number}.self_attn.o_proj.weight"] = param - elif component == "linear_qkv" and not isinstance(param, list): - if param_type == "layer_norm_weight": - new_params[f"model.layers.{layer_number}.input_layernorm.weight"] = param - else: - if convert_qkv_gate_up_by_trunk_concat: - convert_qkv_shard( - param, - f"model.layers.{layer_number}.self_attn.q_proj.{param_type}", - f"model.layers.{layer_number}.self_attn.k_proj.{param_type}", - f"model.layers.{layer_number}.self_attn.v_proj.{param_type}", - ) - else: - new_params[f"model.layers.{layer_number}.self_attn.qkv_proj.{param_type}"] = param - elif component == "q_layernorm" or component == "k_layernorm": - hf_component = component.replace("layer", "") - new_params[f"model.layers.{layer_number}.self_attn.{hf_component}.weight"] = param - else: - assert isinstance(param, list) and len(param) == 3 - assert param_type == "weight" or param_type == "bias" - new_params[f"model.layers.{layer_number}.self_attn.q_proj.{param_type}"] = param[0] - new_params[f"model.layers.{layer_number}.self_attn.k_proj.{param_type}"] = param[1] - new_params[f"model.layers.{layer_number}.self_attn.v_proj.{param_type}"] = param[2] - elif "mlp" in name: - splitted_name = name.split(".") - layer_number = splitted_name[2] - component = splitted_name[4] - param_type = splitted_name[5] - if component == "linear_fc1" and not isinstance(param, list): - if param_type == "layer_norm_weight": - new_params[f"model.layers.{layer_number}.post_attention_layernorm.weight"] = param - elif param_type == "weight": - if convert_qkv_gate_up_by_trunk_concat: - convert_gate_up_shard( - param, - f"model.layers.{layer_number}.mlp.gate_proj.weight", - f"model.layers.{layer_number}.mlp.up_proj.weight", - ) - else: - new_params[f"model.layers.{layer_number}.mlp.gate_up_proj.weight"] = param - elif component == "linear_fc1" and isinstance(param, list): - assert len(param) == 2 - assert param_type == "weight" or param_type == "bias" - new_params[f"model.layers.{layer_number}.mlp.gate_proj.weight"] = param[0] - new_params[f"model.layers.{layer_number}.mlp.up_proj.weight"] = param[1] - elif component == "linear_fc2": - new_params[f"model.layers.{layer_number}.mlp.down_proj.weight"] = param - elif name == "decoder.final_layernorm.weight": - new_params["model.norm.weight"] = param - elif name == "output_layer.weight": - new_params["lm_head.weight"] = param - else: - raise ValueError(f"Unknown param name: {name}") - return new_params.keys(), new_params.values() - - -def broadcast_from_megatron_pp(tensor: torch.Tensor): - # tensor is not None only in one of the pp ranks - if tensor is not None: - shape = tensor.shape - dtype = tensor.dtype - tensor_parallel = getattr(tensor, "tensor_model_parallel", None) - partition_dim = getattr(tensor, "partition_dim", None) - tensor_spec = (shape, dtype, tensor_parallel, partition_dim) - else: - tensor_spec = None - tensor_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object( - object_list=tensor_spec_output, obj=tensor_spec, group=mpu.get_pipeline_model_parallel_group() - ) - # find the src rank - target_tensor_spec = None - src_rank = None - for rank, tensor_spec in enumerate(tensor_spec_output): - if tensor_spec is not None: - if target_tensor_spec is None: - target_tensor_spec = tensor_spec - else: - raise ValueError("A tensor exists on two pp ranks") - src_rank = rank - assert target_tensor_spec is not None - if tensor is None: - tensor = torch.empty(size=target_tensor_spec[0], dtype=target_tensor_spec[1], device=get_device_id()) - if target_tensor_spec[2] is not None: - tensor.tensor_model_parallel = target_tensor_spec[2] - if target_tensor_spec[3] is not None: - tensor.partition_dim = target_tensor_spec[3] - - global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank) - torch.distributed.broadcast(tensor=tensor, src=global_rank, group=mpu.get_pipeline_model_parallel_group()) - return tensor - - -def broadcast_str_from_megatron_pp(obj: Any): - obj_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object(object_list=obj_output, obj=obj, group=mpu.get_pipeline_model_parallel_group()) - - src_rank = None - target_obj = None - for rank, item in enumerate(obj_output): - if item is not None: - if target_obj is not None: - raise ValueError("An object exists on two pp ranks") - target_obj = item - src_rank = rank - - assert target_obj is not None, "No valid object found to broadcast." - - global_rank = torch.distributed.get_global_rank(group=mpu.get_pipeline_model_parallel_group(), group_rank=src_rank) - - obj_output = [None] * torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) - obj_output[0] = target_obj - torch.distributed.broadcast_object_list( - object_list=obj_output, src=global_rank, group=mpu.get_pipeline_model_parallel_group() - ) - - return obj_output[0] - - -def default_tp_concat_fn( - layer_name_mapping, - name, - train_params, - infer_params, - model_config, - hf_config=None, - convert_qkv_gate_up_by_simple_split=False, -): - """ - name: name of the parameter - train_params: training parameters - infer_params (Iterable[torch.Tensor]): a iterator towards list of parameters all-gathered from micro_dp_group - model_config: huggingface model_config - TODO(zhangchi.usc1992): currently, the implementation is adhoc. We can move this function to the model - definition so that it is model-agnostic. If the model doesn't implement this function, - we can throw an error to force user disable TP HybridEngine. - """ - from megatron.core import mpu - - train_tp_size = mpu.get_tensor_model_parallel_world_size() - if layer_name_mapping.get("qkv_layer_name") in name and "layer_norm" not in name: - # if the tensor is qkv, for each param on tp, split into q, k, v - # concat q, k, v separately. - q_lst = [] - k_lst = [] - v_lst = [] - num_attention_heads = model_config.num_attention_heads - num_key_value_heads = model_config.num_key_value_heads - if "vision_model" in name: - num_attention_heads = hf_config.vision_config.num_heads - num_key_value_heads = hf_config.vision_config.num_heads - assert num_attention_heads % num_key_value_heads == 0 - num_q_per_kv = num_attention_heads // num_key_value_heads - assert infer_params[0].shape[0] % (num_q_per_kv + 2) == 0, ( - f"param '{name}' shape '{infer_params[0].shape}' dim0 is not divisible by {num_q_per_kv + 2}" - ) - kv_size_per_tp = infer_params[0].shape[0] // (num_q_per_kv + 2) - split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp] - for infer_param in infer_params: - num_query_groups_per_partition = num_key_value_heads // train_tp_size - for chunk in infer_param.chunk(num_query_groups_per_partition): - split_size = [ - kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - kv_size_per_tp // num_query_groups_per_partition, - ] - q, k, v = chunk.split(split_size) - q_lst.append(q) - k_lst.append(k) - v_lst.append(v) - q = torch.cat(q_lst, dim=0) - k = torch.cat(k_lst, dim=0) - v = torch.cat(v_lst, dim=0) - infer_params = torch.cat((q, k, v), dim=0) if not convert_qkv_gate_up_by_simple_split else [q, k, v] - - elif ( - layer_name_mapping.get("gate_proj_layer_name") in name - and "layer_norm" not in name - and "vision_model.projection" not in name - ): - # if the tensor is gate and proj - gate_lst = [] - up_lst = [] - for infer_param in infer_params: - gate, up = infer_param.chunk(2) - gate_lst.append(gate) - up_lst.append(up) - gate = torch.cat(gate_lst, dim=0) - up = torch.cat(up_lst, dim=0) - infer_params = torch.cat((gate, up), dim=0) if not convert_qkv_gate_up_by_simple_split else [gate, up] - - elif "mlp.experts.linear_fc2.weight" in name: # moe - infer_params = torch.cat(infer_params, dim=1) - - else: - # concat tensor - infer_params = torch.cat(infer_params, dim=tp_utils.get_tensor_parallel_partition_dim(train_params)) - - return infer_params - - -def per_tensor_generator( - actor_module, - model_config, - weight_converter, - transformer_config, - layer_name_mapping, - convert_qkv_gate_up_by_simple_split=True, -): - from megatron.core import parallel_state as mpu - - pp_rank = mpu.get_pipeline_model_parallel_rank() - ep_size = mpu.get_expert_model_parallel_world_size() - etp_size = mpu.get_expert_tensor_parallel_world_size() - ep_group = mpu.get_expert_model_parallel_group() - etp_group = mpu.get_expert_tensor_parallel_group() - vpp_size = len(actor_module) - all_gather_group = mpu.get_tensor_model_parallel_group() - all_gather_group_size = torch.distributed.get_world_size(group=all_gather_group) - - def tensor_generator(): - for scan_vpp_idx in range(vpp_size): - existing_keys = set() - model = unwrap_model(actor_module[scan_vpp_idx]) - for name, param in model.named_parameters(): - existing_keys.add(name) - yield name, param - # note - # there is a bug in megatron GPTModel - # decoder.layers[n].mlp.router.expert_bias" in GPTModel is not registered in named_parameter, but in - # state_dict(). for now we patch it by adding those keys to extra_keys. - extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] - for name in extra_keys: - yield name, model.state_dict()[name].to(get_device_id()) - - # we need first make all rank get full model information - meta_info = [] - for scan_vpp_idx in range(vpp_size): - existing_keys = set() - model = unwrap_model(actor_module[scan_vpp_idx]) - for idx, (name, _) in enumerate(model.named_parameters()): - existing_keys.add(name) - meta_info.append((pp_rank, scan_vpp_idx, idx, name)) - extra_keys = [x for x in model.state_dict().keys() if "_extra_state" not in x and x not in existing_keys] - for name in extra_keys: - meta_info.append((pp_rank, scan_vpp_idx, idx, name)) - - obj_spec_output = [None] * mpu.get_pipeline_model_parallel_world_size() - torch.distributed.all_gather_object( - object_list=obj_spec_output, obj=meta_info, group=mpu.get_pipeline_model_parallel_group() - ) - layer_list_meta = [item for sublist in obj_spec_output for item in sublist] - - gen_func = tensor_generator() - - # lazy load tensor for full model - for cur_pp_rank, scan_vpp_idx, idx, name in layer_list_meta: - if model_config.tie_word_embeddings and ("output_layers" in name): - import warnings - - warnings.warn( - "Current model sharing word and embedding weights, skip output layer conversion", stacklevel=2 - ) - continue - - if cur_pp_rank == pp_rank: - try: - cur_name, cur_tensor = next(gen_func) - except StopIteration: - cur_name, cur_tensor = None, None - cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, transformer_config) - else: - cur_tensor, cur_name = None, None - - # pp broadcast model tensor and name - cur_name = broadcast_str_from_megatron_pp(cur_name) - broad_pp_tensor = broadcast_from_megatron_pp(cur_tensor) - - # (xya): this is a hack to fix the name of the parameters - while cur_name.startswith("module."): - cur_name = cur_name[len("module.") :] - - # EP - if ".mlp.experts.linear_fc" in cur_name and ep_size > 1: - num_experts = weight_converter.mcore_config.num_moe_experts - num_experts_per_rank = num_experts // ep_size - infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(ep_size)] - torch.distributed.all_gather(infer_params, broad_pp_tensor, group=ep_group) - - name_prefix, local_expert_id = cur_name.split(".weight") - local_expert_id = int(local_expert_id) - global_expert_ids = [num_experts_per_rank * ep_rank + local_expert_id for ep_rank in range(ep_size)] - global_expert_names = [f"{name_prefix}.weight{expert_id}" for expert_id in global_expert_ids] - - for name, param in zip(global_expert_names, infer_params, strict=True): - if etp_size > 1: - # gather etp - etp_params = [torch.empty_like(param) for _ in range(etp_size)] - torch.distributed.all_gather(etp_params, param, group=etp_group) - params = etp_params - else: - params = [param] - - merge_params = default_tp_concat_fn( - layer_name_mapping, - name, - broad_pp_tensor, - params, - model_config, - weight_converter.hf_config, - convert_qkv_gate_up_by_simple_split, - ) - if not isinstance(merge_params, list): - merge_params = [merge_params] - converted_names, converted_params = weight_converter.convert_param(name, merge_params) - - yield from zip(converted_names, converted_params, strict=True) - continue - - # tp all gather - if tp_utils.is_tensor_parallel_param(broad_pp_tensor): - # allocate a new tensor with proper size - if all_gather_group_size <= 1: - infer_params = [broad_pp_tensor] - else: - infer_params = [torch.empty_like(broad_pp_tensor) for _ in range(all_gather_group_size)] - torch.distributed.all_gather(infer_params, broad_pp_tensor, group=mpu.get_tensor_model_parallel_group()) - infer_params = default_tp_concat_fn( - layer_name_mapping, - cur_name, - broad_pp_tensor, - infer_params, - model_config, - weight_converter.hf_config, - convert_qkv_gate_up_by_simple_split, - ) - else: - infer_params = broad_pp_tensor - - if not isinstance(infer_params, list): - infer_params = [infer_params] - converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params) - - yield from zip(converted_names, converted_params, strict=True) - - -def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConfig): - ''' - Get the index offset of any pipeline stage, given the level of pipelining. - - Make pp_rank and vpp_rank as two arguments to make it more flexible, - which is able to fetch layer offset for any pipeline stage. - The original function only returns the layer offset for current pipeline stage. - - Extension to https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py::get_transformer_layer_offset""" - ''' - if config.pipeline_model_parallel_size > 1: - if ( - config.num_layers_in_first_pipeline_stage is not None - or config.num_layers_in_last_pipeline_stage is not None - ): - # Calculate number of pipeline stages to distribute the remaining Transformer - # layers after deducting the Transformer layers in the first or the last stages - middle_pipeline_stages = config.pipeline_model_parallel_size - middle_pipeline_stages -= sum( - [ - 1 if x is not None else 0 - for x in ( - config.num_layers_in_first_pipeline_stage, - config.num_layers_in_last_pipeline_stage, - ) - ] - ) - - # Calculate layers to distribute in each pipeline stage. If the - # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage - # are not set, we will not enable uneven pipeline. All layers will be treated - # as middle layers. - num_layers_in_first_pipeline_stage = ( - 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage - ) - num_layers_in_last_pipeline_stage = ( - 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage - ) - - middle_num_layers = ( - config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage - ) - - if mpu.get_virtual_pipeline_model_parallel_world_size() is not None: - vp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - - # Calculate number of layers in each virtual model chunk - # If the num_layers_in_first_pipeline_stage and - # num_layers_in_last_pipeline_stage are not set, all pipeline stages - # will be treated as middle pipeline stages in the calculation - num_layers_per_virtual_model_chunk_in_first_pipeline_stage = ( - 0 - if config.num_layers_in_first_pipeline_stage is None - else config.num_layers_in_first_pipeline_stage // vp_size - ) - - num_layers_per_virtual_model_chunk_in_last_pipeline_stage = ( - 0 - if config.num_layers_in_last_pipeline_stage is None - else config.num_layers_in_last_pipeline_stage // vp_size - ) - - num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size - - # First stage + middle stage + last stage - total_virtual_chunks = ( - num_layers_per_virtual_model_chunk_in_first_pipeline_stage - + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage - + num_layers_per_virtual_model_chunk_in_last_pipeline_stage - ) - - # Calculate the layer offset with interleaved uneven pipeline parallelism - if pipeline_rank == 0: - offset = vp_rank * total_virtual_chunks - else: - offset = ( - vp_rank * total_virtual_chunks - + num_layers_per_virtual_model_chunk_in_first_pipeline_stage - + (pipeline_rank - 1) - * (num_layers_per_vritual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages) - ) - else: - if middle_pipeline_stages > 0: - num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages - else: - num_layers_per_pipeline_rank = 0 - - middle_pipeline_rank = ( - pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1 - ) - - if pipeline_rank == 0: - offset = 0 - else: - offset = (middle_pipeline_rank * num_layers_per_pipeline_rank) + num_layers_in_first_pipeline_stage - else: - num_layers = config.num_layers - - # Increase the number of layers by one if we include the embedding (loss) - # layer into pipeline parallelism partition and placement - if config.account_for_embedding_in_pipeline_split: - num_layers += 1 - - if config.account_for_loss_in_pipeline_split: - num_layers += 1 - - num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size - - if mpu.get_virtual_pipeline_model_parallel_world_size() is not None: - vp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - - num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size - total_virtual_chunks = num_layers // vp_size - offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank) - - # Reduce the offset of embedding layer from the total layer number - if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage(): - offset -= 1 - else: - offset = pipeline_rank * num_layers_per_pipeline_rank - - # Reduce the offset of embedding layer from the total layer number - if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage(): - offset -= 1 - else: - offset = 0 - return offset diff --git a/verl/utils/memory_buffer.py b/verl/utils/memory_buffer.py deleted file mode 100644 index 9386f0d88..000000000 --- a/verl/utils/memory_buffer.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This file contains utilities to manipulate torch memory buffers -""" - -from typing import Optional - -import torch -from torch import nn - -from verl.utils.device import get_device_name - - -class MemoryBuffer: - """ - A memory buffer is a contiguous torch tensor that may combine multiple tensors sharing with the underlying - memory. It must have a unique type to support this behavior. - """ - - def __init__(self, numel: int, numel_padded: int, dtype: torch.dtype, source: Optional[torch.Tensor] = None): - self.numel = numel - self.numel_padded = numel_padded - self.dtype = dtype - if source is not None: - self.data = source - else: - self.data = torch.zeros(self.numel_padded, dtype=self.dtype, device=get_device_name(), requires_grad=False) - - def zero(self): - """Reset the buffer to zero.""" - self.data.zero_() - - def get(self, shape, start_index): - """Return a tensor with the input `shape` as a view into the - 1-D data starting at `start_index`.""" - end_index = start_index + shape.numel() - assert end_index <= self.numel, "requested tensor is out of the buffer range." - buffer_tensor = self.data[start_index:end_index] - buffer_tensor = buffer_tensor.view(shape) - return buffer_tensor - - -def calc_padded_numel(shape: torch.Size, dtype: torch.dtype): - """for cuda memory alignment, make sure alignment by 128-bits""" - align_numel = 128 // torch.finfo(dtype).bits - numel = shape.numel() - return (numel + align_numel - 1) // align_numel * align_numel - - -def get_weight_buffer_meta_from_module(module: nn.Module) -> dict[str, dict]: - """ - Return a dictionary containing name to a shape and dtype. - """ - weight_buffer_meta = {} - for name, param in sorted(module.named_parameters()): - weight_buffer_meta[name] = {"shape": param.shape, "dtype": param.dtype} - return weight_buffer_meta - - -def build_memory_buffer(weight_buffer_meta: dict[str, dict]) -> dict[torch.dtype, MemoryBuffer]: - """Build the memory buffer given weight_buffer_meta - - Args: - weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors - - Returns: a large memory buffer for each dtype that can hold all the tensors - - """ - memory_buffers = {} - total_numel_map = {} # map from dtype to the total numel - for name, meta_info in sorted(weight_buffer_meta.items()): - shape = meta_info["shape"] - dtype = meta_info["dtype"] - - assert isinstance(shape, torch.Size) - assert isinstance(dtype, torch.dtype) - - if dtype not in total_numel_map: - total_numel_map[dtype] = 0 - - total_numel_map[dtype] += calc_padded_numel(shape, dtype) - - for dtype, total_numel in total_numel_map.items(): - memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype) - - return memory_buffers - - -def build_memory_reference_from_module( - module: torch.nn.Module, memory_buffers: dict[torch.dtype, MemoryBuffer], maintain_weight=True -): - start_index = {} - for dtype in memory_buffers: - start_index[dtype] = 0 - for name, param in sorted(module.named_parameters()): - memory_buffer = memory_buffers[param.dtype] - buffer = memory_buffer.get(shape=param.shape, start_index=start_index[param.dtype]) - # need to increment start_index - start_index[param.dtype] += calc_padded_numel(param.shape, param.dtype) - if maintain_weight: - buffer.copy_(param.data) - param.data = buffer - - -def build_memory_reference(weight_buffer_meta: dict[str, dict], memory_buffers: dict[torch.dtype, MemoryBuffer]): - """Build the memory references. The memory buffers are built using the build_memory_buffer API. - This API will allocate a weight buffer pointer to the memory buffer according to the weight_buffer_meta. - - Args: - weight_buffer_meta: - memory_buffers: - - Returns: - - """ - start_idx = {} - weight_buffers = {} - for dtype in memory_buffers: - start_idx[dtype] = 0 - - for name, meta_info in sorted(weight_buffer_meta.items()): - shape = meta_info["shape"] - dtype = meta_info["dtype"] - - buffer = memory_buffers[dtype].get(shape, start_index=start_idx[dtype]) - start_idx[dtype] += calc_padded_numel(shape, dtype) - weight_buffers[name] = buffer - - return weight_buffers - - -class MemoryBufferModuleWrapper: - """ - Note that we do not design MemoryBufferModuleWrapper as an nn.Module due to - - It will change the checkpoint name - """ - - def __init__(self, module: nn.Module): - super().__init__() - self.module = module - self.weight_buffer_meta = get_weight_buffer_meta_from_module(self.module) - self.memory_buffers = build_memory_buffer(self.weight_buffer_meta) - build_memory_reference_from_module(self.module, self.memory_buffers) - - def get_memory_buffers(self): - return self.memory_buffers - - def get_weight_buffer_meta(self): - return self.weight_buffer_meta - - -class MegatronMemoryBufferForRollout: - """ - We assume that - - inference engine has tp + dp - - actor has tp + pp + dp - - the tp between inference engine and actor should be the same - - memory_buffers: contains a list of memory_buffers, each is a dict from dtype to MemoryBuffer - - weight_buffers: contains a list of weight_buffers, each is a dict from name to param - - named_parameters: a dict from name to parameter that normalizes the names from pp and vpp. Note that - the named_parameters may not be directly compatible with inference engine. User has to take care of - this part such as the layout mismatches. (e.g. qkv transpose) - - Note that weight_buffer, named_parameters and memory_buffers share the same underlying GPU memory. - - When doing weight sync, the data is transfer via memory buffers - """ - - def __init__(self, transform_memory_param_fn): - self._memory_buffers = [] - self._weight_buffers = [] - self._named_parameters = {} - self.transform_memory_param_fn = transform_memory_param_fn - - def initialize_weight_buffer(self, weight_buffer_meta_pp: list[dict[str, dict]]): - """ - Initialize the weight buffer. The weight buffer is obtained according to the actor. We will construct - a large buffer for each dtype in the weight_buffer. - - Args: - weight_buffer_meta: contains pp models, each pp models contains a dictionary of mapping from - - Returns: None - - """ - self.weight_buffer_meta_pp = weight_buffer_meta_pp - - for weight_buffer_meta in self.weight_buffer_meta_pp: - memory_buffer = build_memory_buffer(weight_buffer_meta) - self._memory_buffers.append(memory_buffer) - self._weight_buffers.append(None) - - def build_memory_reference(self): - for i, weight_buffer_meta in enumerate(self.weight_buffer_meta_pp): - self._weight_buffers[i] = build_memory_reference(weight_buffer_meta, self._memory_buffers[i]) - self._named_parameters = self.transform_memory_param_fn(self._weight_buffers) - - @property - def named_parameters(self): - return self._named_parameters - - @property - def weight_buffers(self): - return self._weight_buffers - - @property - def memory_buffers(self): - return self._memory_buffers diff --git a/verl/utils/metric/__init__.py b/verl/utils/metric/__init__.py deleted file mode 100644 index 1e19d3f79..000000000 --- a/verl/utils/metric/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .utils import reduce_metrics - -__all__ = ["reduce_metrics"] diff --git a/verl/utils/metric/utils.py b/verl/utils/metric/utils.py deleted file mode 100644 index f9e7cd511..000000000 --- a/verl/utils/metric/utils.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Metrics utils. -""" - -from typing import Any - -import numpy as np - - -def reduce_metrics(metrics: dict[str, list[Any]]) -> dict[str, Any]: - """ - Reduces a dictionary of metric lists by computing the mean, max, or min of each list. - The reduce operation is determined by the key name: - - If the key contains "max", np.max is used - - If the key contains "min", np.min is used - - Otherwise, np.mean is used - - Args: - metrics: A dictionary mapping metric names to lists of metric values. - - Returns: - A dictionary with the same keys but with each list replaced by its reduced value. - - Example: - >>> metrics = { - ... "loss": [1.0, 2.0, 3.0], - ... "accuracy": [0.8, 0.9, 0.7], - ... "max_reward": [5.0, 8.0, 6.0], - ... "min_error": [0.1, 0.05, 0.2] - ... } - >>> reduce_metrics(metrics) - {"loss": 2.0, "accuracy": 0.8, "max_reward": 8.0, "min_error": 0.05} - """ - for key, val in metrics.items(): - if "max" in key: - metrics[key] = np.max(val) - elif "min" in key: - metrics[key] = np.min(val) - else: - metrics[key] = np.mean(val) - return metrics diff --git a/verl/utils/model.py b/verl/utils/model.py deleted file mode 100644 index 04cc34fe5..000000000 --- a/verl/utils/model.py +++ /dev/null @@ -1,664 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Utilities to create common models from huggingface -""" - -import os -import re -import warnings -from dataclasses import dataclass -from typing import Optional - -import numpy as np -import torch -from torch import nn -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - GenerationConfig, - MistralForSequenceClassification, - PretrainedConfig, - PreTrainedModel, -) -from transformers.modeling_outputs import CausalLMOutputWithPast - -from verl.models.registry import ModelRegistry -from verl.utils.import_utils import is_trl_available - - -class LambdaLayer(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - - def forward(self, *args, **kwargs): - return self.fn(*args, **kwargs) - - -def squeeze(x): - return torch.squeeze(x, dim=-1) - - -def update_model_config(module_config, override_config_kwargs): - """Update the module config with the override_config_kwargs. - Args: - module_config: The module config from Huggingface Transformers. - override_config_kwargs: The kwargs to override the module config. - """ - for key, val in override_config_kwargs.items(): - if isinstance(val, dict): - update_model_config(getattr(module_config, key), val) - else: - setattr(module_config, key, val) - - -def get_huggingface_actor_config(model_name: str, override_config_kwargs=None, trust_remote_code=False) -> dict: - if override_config_kwargs is None: - override_config_kwargs = {} - assert isinstance(override_config_kwargs, dict), ( - f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" - ) - module_config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code) - update_model_config(module_config, override_config_kwargs) - - return module_config - - -def get_generation_config( - model: str, - trust_remote_code: bool = False, -) -> Optional[GenerationConfig]: - try: - return GenerationConfig.from_pretrained(model) - except OSError: # Not found - try: - config = get_huggingface_actor_config( - model, - trust_remote_code=trust_remote_code, - ) - return GenerationConfig.from_model_config(config) - except OSError: # Not found - return None - - -def create_huggingface_actor(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module: - """ - - Args: - model_name: - override_config_kwargs: - - Returns: - - """ - if override_config_kwargs is None: - override_config_kwargs = {} - if automodel_kwargs is None: - automodel_kwargs = {} - assert isinstance(override_config_kwargs, dict), ( - f"override_config_kwargs must be a dict, got {type(override_config_kwargs)}" - ) - module_config = get_huggingface_actor_config( - model_name, override_config_kwargs, trust_remote_code=automodel_kwargs.get("trust_remote_code", False) - ) - module: nn.Module = AutoModelForCausalLM.from_config(module_config, **automodel_kwargs) - return module - - -def create_huggingface_critic(model_name: str, override_config_kwargs=None, automodel_kwargs=None) -> nn.Module: - """ - - Args: - model_name: - override_config_kwargs: - - Returns: - - """ - critic_module: nn.Module = create_huggingface_actor( - model_name, override_config_kwargs=override_config_kwargs, automodel_kwargs=automodel_kwargs - ) - if automodel_kwargs is None: - automodel_kwargs = {} - torch_dtype = automodel_kwargs.get("torch_dtype", torch.float32) - critic_module.lm_head = nn.Sequential( - nn.Linear(critic_module.config.hidden_size, 1, dtype=torch_dtype), LambdaLayer(fn=squeeze) - ) - return critic_module - - -def get_model_size(model: nn.Module, scale="auto"): - n_params = sum(p.numel() for p in model.parameters()) - - if scale == "auto": - if n_params > 1e9: - scale = "B" - elif n_params > 1e6: - scale = "M" - elif n_params > 1e3: - scale = "K" - else: - scale = "" - - if scale == "B": - n_params = n_params / 1e9 - elif scale == "M": - n_params = n_params / 1e6 - elif scale == "K": - n_params = n_params / 1e3 - elif scale == "": - pass - else: - raise NotImplementedError(f"Unknown scale {scale}") - - return n_params, scale - - -def print_model_size(model: nn.Module, name: str = None): - n_params, scale = get_model_size(model, scale="auto") - if name is None: - name = model.__class__.__name__ - print(f"{name} contains {n_params:.2f}{scale} parameters") - - -def create_random_mask( - input_ids: torch.Tensor, - max_ratio_of_valid_token: float, - max_ratio_of_left_padding: float, - min_ratio_of_valid_token: float = 0, -): - """Create a random mask given input_ids. Support left padding and right padding. - Process: - - Sample valid token length - - Sample left_padding length - - Generate padding - - Args: - input_ids: - shape (batch_size, seq_len) - - Returns: - - """ - assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.0 - assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.0 - assert min_ratio_of_valid_token <= max_ratio_of_valid_token - - batch_size, sequence_length = input_ids.shape - max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token) - min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token)) - max_left_padding = int(sequence_length * max_ratio_of_left_padding) - assert max_num_valid_tokens + max_left_padding <= sequence_length - assert max_num_valid_tokens > 0 and max_ratio_of_valid_token <= sequence_length - masks = torch.ones_like(input_ids, dtype=torch.int64) - # TODO: we can make this faster - for i in range(batch_size): - num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64) - num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64) - - for index in range(num_left_padding): - masks[i, index] = 0 - - for index in range(num_left_padding + num_valid, sequence_length): - masks[i, index] = 0 - return masks - - -def compute_position_id_with_mask(mask): - return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None) - - -def convert_weight_keys(state_dict: dict[str, torch.Tensor], model: PreTrainedModel): - # convert state dict keys: https://github.com/huggingface/transformers/pull/38385 - if not hasattr(model, "_checkpoint_conversion_mapping"): - return state_dict - - reverse_key_mapping = {v: k for k, v in model._checkpoint_conversion_mapping.items()} - original_weights = {} - for key, value in state_dict.items(): - for pattern, replacement in reverse_key_mapping.items(): - replacement = replacement.lstrip("^") # strip off un-needed chars and patterns - replacement = re.sub(r"\(.*\)", "", replacement) - key, n_replace = re.subn(pattern, replacement, key) - # Early exit of the loop - if n_replace > 0: - break - - original_weights[key] = value - - return original_weights - - -def check_exclude_modules(config, key: str) -> bool: - """ - A helper method to check if the passed module's key name matches any of the exclude modules in the adapter_config. - Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py - - Args: - config (`LoraConfig` | `LycorisConfig`): A config to match exclude modules from - key (`str`): A key to search any matches in config - - Returns: - True of match object if key matches any exclude modules from config, False if no match found - """ - if hasattr(config, "exclude_modules") and config.exclude_modules: - if isinstance(config.exclude_modules, str): - if re.fullmatch(config.exclude_modules, key): - return True - elif key in config.exclude_modules: - return True - elif any(key.endswith(f".{exclude_key}") for exclude_key in config.exclude_modules): - return True - return False - - -def check_target_modules(config, key: str) -> bool: - """ - A helper method to check if the passed module's key name matches any of the target modules in the adapter_config. - Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/tuners_utils.py - - Args: - config (`LoraConfig` | `LycorisConfig`): A config to match target modules from - key (`str`): A key to search any matches in config - - Returns: - True of match object if key matches any target modules from config, False if no match found - """ - if isinstance(config.target_modules, str): - target_module_found = re.fullmatch(config.target_modules, key) - elif key in config.target_modules: - # this module is specified directly in target_modules - target_module_found = True - else: - target_module_found = any(key.endswith(f".{target_key}") for target_key in config.target_modules) - - layer_indexes = getattr(config, "layers_to_transform", None) - layers_pattern = getattr(config, "layers_pattern", None) - - is_using_layer_indexes = layer_indexes is not None and ( - len(layer_indexes) != 0 if isinstance(layer_indexes, list) else True - ) - if is_using_layer_indexes and target_module_found: - layer_index = None - # TODO: It's still unclear how empty layers_pattern (None, [], or "") should behave - # For now, empty layers_pattern means any layer pattern is ok - if layers_pattern is None or len(layers_pattern) == 0: - layer_index = re.match(r".*\.[^.]*\.(\d+)\.", key) - else: - layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern - for pattern in layers_pattern: - layer_index = re.match(rf".*\.{pattern}\.(\d+)\.", key) - if layer_index is not None: - break - - if layer_index is None: - target_module_found = False - else: - layer_index = int(layer_index.group(1)) - if isinstance(layer_indexes, int): - target_module_found = layer_index == layer_indexes - else: - target_module_found = layer_index in layer_indexes - - return target_module_found - - -def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"): - """ - Transform the model name in each model_chunk in each pp stage into the name in inference engine - """ - from verl.utils.megatron_utils import get_transformer_layer_offset - - layer_offset = get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config) - - if layer_name in name: # belong to an intermediate layer - split_name = name.split(".") - # find the num next to split_name - for i, name in enumerate(split_name): - if name == layer_name: - break - layer_num_idx = i + 1 - # check the name - assert len(split_name) >= layer_num_idx + 1, f"split_name = {split_name}" - assert split_name[layer_num_idx].isdigit(), f"split_name = {split_name}" - # increment layer_num_idx by layer_offset - split_name[layer_num_idx] = str(int(split_name[layer_num_idx]) + layer_offset) - name = ".".join(split_name) # weight name in inference_tp_model - return name - - -def normalize_pp_vpp_params(params, num_hidden_layers, layer_name="layers"): - """ - Normalize the pp vpp params into a complete named parameters. - This is useful when gather parameters from pp ranks and passed to a model without pp - - params: Iterable[List[Dict[str, param]]] - params contains a list of pp, with a list of vpp named_parameters in each vpp chunk. - output: Dict[str, param] - - """ - pp_size = len(params) - for pp_rank in range(len(params)): - vpp_size = len(params[pp_rank]) - for vpp_rank in range(vpp_size): - for name, param in params[pp_rank][vpp_rank].items(): - normalized_name = normalize_model_name( - name, pp_rank, vpp_rank, pp_size, vpp_size, num_hidden_layers, layer_name=layer_name - ) - yield normalized_name, param - - -def get_parallel_model_from_config( - config, megatron_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False -): - from megatron.core import ModelParallelConfig - - assert isinstance(megatron_config, ModelParallelConfig) - model_class = _get_parallel_model_architecture_from_config(config, value) - - model = model_class( - config, - megatron_config, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - ) - return model - - -def _get_parallel_model_architecture_from_config(config: PretrainedConfig, value=False) -> type[nn.Module]: - architectures = getattr(config, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch, value) - print("after load model cls") - if model_cls is not None: - return model_cls - raise ValueError( - f"Model architectures {architectures} are not supported for now. Supported architectures: " - f"{ModelRegistry.get_supported_archs()}" - ) - - -def _load_hf_model(config, model_config, is_value_model, local_cache_path): - """Helper function containing the loading hf model logic""" - from accelerate import init_empty_weights - from megatron.core import parallel_state as mpu - - from verl.models.mcore.saver import _megatron_calc_global_rank - - assert hasattr(model_config, "architectures"), "architectures cannot be empty when load weight!" - architectures = getattr(model_config, "architectures", []) - local_cache_path = os.path.expanduser(local_cache_path) - - if config.model.path.startswith("hdfs:"): - from verl.utils.fs import copy_to_local - - print(f"start download from {config.model.path}") - local_model_path = copy_to_local( - src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False) - ) - print("finish download") - else: - local_model_path = config.model.path - print(f"load from local dir {local_model_path}") - - src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=mpu.get_context_parallel_rank()) - cpu_init_weights = lambda: torch.device("cpu") - init_context = init_empty_weights if torch.distributed.get_rank() != src_rank else cpu_init_weights - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - # TODO: to find a better way to load mistral7b-rm lm_head - if "mistral7b-rm" in config.model.path: - model = MistralForSequenceClassification.from_pretrained( - local_model_path, - torch_dtype="auto", - # device_map="auto", # disable auto device_map, the HF weight is only loaded to CPU in src_rank - # low_cpu_mem_usage=True - ) # use score head instead of lm_head - state_dict = model.state_dict() - state_dict["lm_head.weight"] = state_dict["score.weight"] - state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"][ - :32000 - ] # workaround, 32001 -> 32000 - is_value_model = True - else: - model = AutoModelForCausalLM.from_pretrained( - local_model_path, - torch_dtype="auto", - # device_map="auto", # disable auto device_map, the HF weight is only loaded to CPU in src_rank - # low_cpu_mem_usage=True - ) - state_dict = model.state_dict() - - return architectures, model, state_dict, is_value_model - - -def get_hf_model_path(config, local_cache_path="~/.cache/verl/rlhf"): - local_cache_path = os.path.expanduser(local_cache_path) - if config.model.path.startswith("hdfs:"): - from verl.utils.fs import copy_to_local - - local_model_path = copy_to_local( - src=config.model.path, cache_dir=local_cache_path, use_shm=config.model.get("use_shm", False) - ) - else: - local_model_path = config.model.path - return local_model_path - - -def load_megatron_model_weights( - config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf" -): - """Load weights for verl customized model.""" - architectures, model, state_dict, is_value_model = _load_hf_model( - config, model_config, is_value_model, local_cache_path - ) - - from verl.models.weight_loader_registry import get_weight_loader - - print(f"before weight loader: architectures = {architectures}...") - for arch in architectures: - print(f"call weight loader arch = {arch}, model config = {model.config}") - weight_loader = get_weight_loader(arch) - weight_loader( - state_dict=state_dict, - wrapped_models=parallel_model, - config=model.config, - params_dtype=params_dtype, - is_value_model=is_value_model, - tie_word_embeddings=model_config.tie_word_embeddings, - ) - return model.config - - -def load_megatron_gptmodel_weights( - config, model_config, parallel_model, params_dtype, is_value_model=False, local_cache_path="~/.cache/verl/rlhf" -): - """Load weights for mcore GPT model.""" - _, model, state_dict, is_value_model = _load_hf_model(config, model_config, is_value_model, local_cache_path) - - from verl.models.mcore.loader import load_state_dict_to_megatron_gptmodel - - load_state_dict_to_megatron_gptmodel( - state_dict=state_dict, - wrapped_models=parallel_model, - config=model.config, - params_dtype=params_dtype, - is_value_model=is_value_model, - ) - del state_dict, model - - -# pad input_ids_rmpad, cu_seqlens and max_seqlen_in_batch to be divisible by tp -def pad_packed_inputs(unpad_tokens: torch.Tensor, cu_seqlens, max_seqlen_in_batch, size): - """pad the tokens such that the total length is a multiple of size. - This function is useful when applying sequence parallel and context parallel - - Args: - unpad_tokens: (total_nnz, ...). Tokens after removing padding - cu_seqlens: (total_nnz + 1,) - max_seqlen_in_batch: int - - Returns: - - """ - F = nn.functional - - total_nnz = unpad_tokens.shape[0] - - pad_size = 0 if total_nnz % size == 0 else size - total_nnz % size - - # we assume adding a new data in the batch with seqlen pad_size - if pad_size > 0: - if unpad_tokens.ndim == 1: - unpad_tokens = F.pad(unpad_tokens, (0, pad_size)) - elif unpad_tokens.ndim == 2: - unpad_tokens = F.pad(unpad_tokens, (0, 0, 0, pad_size)) - else: - raise NotImplementedError(f"Padding dim {unpad_tokens.ndim()} is not supported") - - cu_seqlens = F.pad(cu_seqlens, (0, 1), value=pad_size + cu_seqlens[-1]) - max_seqlen_in_batch = max(max_seqlen_in_batch, pad_size) - - return unpad_tokens, cu_seqlens, max_seqlen_in_batch - - -def load_mcore_dist_weights(parallel_model, dist_weight_path, is_value_model=False): - from megatron.core import dist_checkpointing - from megatron.core.dist_checkpointing.serialization import StrictHandling - - from verl.utils.megatron_utils import unwrap_model - - # strict = StrictHandling.IGNORE_ALL if is_value_model else StrictHandling.ASSUME_OK_UNEXPECTED - strict = StrictHandling.ASSUME_OK_UNEXPECTED - for model in parallel_model: - ssd = unwrap_model(model).sharded_state_dict() - if is_value_model: - for k in list(ssd.keys()): - if "output_layer" in k: - ssd.pop(k) - dist_checkpointing.load(ssd, dist_weight_path, strict=strict) - - return - - -def get_parallel_gptmodel_from_config( - tfconfig, hf_config, pre_process=None, post_process=None, share_embeddings_and_output_weights=False, value=False -): - from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec - from megatron.core.models.gpt.gpt_model import GPTModel - - use_te = True - assert tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now" - transformer_layer_spec = get_gpt_decoder_block_spec(tfconfig, use_transformer_engine=use_te) - rope_scaling_args = {} - if hf_config.rope_scaling is not None: - assert hf_config.rope_scaling["type"] == "linear", "only linear scaling is supported for now" - rope_scaling_args["seq_len_interpolation_factor"] = hf_config.rope_scaling["factor"] - parallel_model = GPTModel( - config=tfconfig, - transformer_layer_spec=transformer_layer_spec, - vocab_size=hf_config.vocab_size, - max_sequence_length=hf_config.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - share_embeddings_and_output_weights=share_embeddings_and_output_weights, - position_embedding_type="rope", - rotary_base=hf_config.rope_theta, - **rope_scaling_args, - ) - # # for layer in parallel_model.decoder.layers: - # layer.self_attention.core_attention.flash_attention.softmax_scale = None - if post_process and value: - from verl.models.llama.megatron.layers.parallel_linear import LinearForLastLayer - - parallel_model.output_layer = LinearForLastLayer( - input_size=tfconfig.hidden_size, output_size=1, config=tfconfig - ) - return parallel_model - - -def patch_valuehead_model(model) -> None: - from types import MethodType - - from transformers import PreTrainedModel - from trl import AutoModelForCausalLMWithValueHead - - def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: - if isinstance(self.pretrained_model, PreTrainedModel): - self.pretrained_model.tie_weights() - - def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: - if isinstance(self.pretrained_model, PreTrainedModel): - return self.pretrained_model.get_input_embeddings() - - def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: - if isinstance(self.pretrained_model, PreTrainedModel): - return self.pretrained_model.get_output_embeddings() - - def can_generate(self): - return False - - ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] - model._keys_to_ignore_on_save = ignore_modules - model.tie_weights = MethodType(tie_weights, model) - model.get_input_embeddings = MethodType(get_input_embeddings, model) - model.get_output_embeddings = MethodType(get_output_embeddings, model) - model.can_generate = MethodType(can_generate, model) - model._no_split_modules = getattr(model.pretrained_model, "_no_split_modules", []) - - -def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code): - from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForVision2Seq - - try: - model = AutoModelForTokenClassification.from_pretrained( - pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=model_config, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - return model - except BaseException as e: - if not is_trl_available(): - raise RuntimeError( - f"model({local_path}) is not a value head model, please install trl to make it valid" - ) from e - - assert is_trl_available() - - from trl import AutoModelForCausalLMWithValueHead - - if type(model_config) in AutoModelForVision2Seq._model_mapping.keys(): - module_class = AutoModelForVision2Seq - else: - module_class = AutoModelForCausalLM - ori_model = module_class.from_pretrained( - pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=model_config, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - model = AutoModelForCausalLMWithValueHead.from_pretrained(ori_model) - patch_valuehead_model(model) - return model - - -@dataclass -class CausalLMOutputForPPO(CausalLMOutputWithPast): - log_probs: Optional[torch.FloatTensor] = None - entropy: Optional[torch.FloatTensor] = None diff --git a/verl/utils/net_utils.py b/verl/utils/net_utils.py deleted file mode 100644 index 138821cf3..000000000 --- a/verl/utils/net_utils.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import ipaddress - - -def is_ipv4(ip_str: str) -> bool: - """ - Check if the given string is an IPv4 address - - Args: - ip_str: The IP address string to check - - Returns: - bool: Returns True if it's an IPv4 address, False otherwise - """ - try: - ipaddress.IPv4Address(ip_str) - return True - except ipaddress.AddressValueError: - return False - - -def is_ipv6(ip_str: str) -> bool: - """ - Check if the given string is an IPv6 address - - Args: - ip_str: The IP address string to check - - Returns: - bool: Returns True if it's an IPv6 address, False otherwise - """ - try: - ipaddress.IPv6Address(ip_str) - return True - except ipaddress.AddressValueError: - return False diff --git a/verl/utils/profiler/__init__.py b/verl/utils/profiler/__init__.py deleted file mode 100644 index 2242c24fe..000000000 --- a/verl/utils/profiler/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ..device import is_npu_available -from ..import_utils import is_nvtx_available -from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer -from .profile import DistProfilerExtension, ProfilerConfig - -if is_nvtx_available(): - from .nvtx_profile import NsightSystemsProfiler as DistProfiler - from .nvtx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer -elif is_npu_available: - from .mstx_profile import NPUProfiler as DistProfiler - from .mstx_profile import mark_annotate, mark_end_range, mark_start_range, marked_timer -else: - from .performance import marked_timer - from .profile import DistProfiler, mark_annotate, mark_end_range, mark_start_range - -__all__ = [ - "GPUMemoryLogger", - "log_gpu_memory_usage", - "mark_start_range", - "mark_end_range", - "mark_annotate", - "DistProfiler", - "DistProfilerExtension", - "ProfilerConfig", - "simple_timer", - "marked_timer", -] diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py deleted file mode 100644 index d4fb53650..000000000 --- a/verl/utils/profiler/config.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field -from typing import ClassVar - -from verl.base_config import BaseConfig - - -@dataclass -class ProfilerConfig(BaseConfig): - """Worker profiler config. Currently only support Nsight system profiler. - - The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. - - Args: - discrete (bool): True for each task has its own database, False for all tasks in one training step - share one database. - all_ranks (bool): Whether to profile all ranks. - ranks (list[int]): The ranks that will be profiled. Defaults to []. - """ - - # the fields expected to be frozen - _frozen_fields: ClassVar[set[str]] = {"discrete", "all_ranks", "ranks"} - - discrete: bool = False - - all_ranks: bool = False - - ranks: list[int] = field(default_factory=list) - - def union(self, other: "ProfilerConfig") -> "ProfilerConfig": - return ProfilerConfig( - all_ranks=self.all_ranks or other.all_ranks, - ranks=list(set(self.ranks or []) | set(other.ranks or [])), - discrete=self.discrete or other.discrete, - ) - - def intersect(self, other: "ProfilerConfig") -> "ProfilerConfig": - return ProfilerConfig( - all_ranks=self.all_ranks and other.all_ranks, - ranks=list(set(self.ranks or []) & set(other.ranks or [])), - discrete=self.discrete and other.discrete, - ) - - def __post_init__(self) -> None: - """config validation logics go here""" - assert isinstance(self.ranks, set | list | tuple), ( - f"Profiler ranks must be of type list, got {type(self.ranks)}" - ) diff --git a/verl/utils/profiler/empty_annotations.py b/verl/utils/profiler/empty_annotations.py deleted file mode 100644 index ed18dd359..000000000 --- a/verl/utils/profiler/empty_annotations.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Optional - - -def mark_start_range( - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, -) -> None: - pass - - -def mark_end_range(range_id: str) -> None: - pass - - -def mark_annotate( - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, -) -> Callable: - def decorator(func): - return func - - return decorator diff --git a/verl/utils/profiler/mstx_profile.py b/verl/utils/profiler/mstx_profile.py deleted file mode 100644 index c5c35cec0..000000000 --- a/verl/utils/profiler/mstx_profile.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Inspired from https://gitee.com/ascend/MindSpeed-RL/blob/master/mindspeed_rl/utils/utils.py -import functools -import os -from contextlib import contextmanager -from typing import Callable, Optional - -import torch_npu -from omegaconf import DictConfig -from torch_npu.npu import mstx - -from .profile import DistProfiler, ProfilerConfig - - -def mark_start_range(message: Optional[str] = None) -> None: - """Start a mark range in the profiler. - - Args: - message (str, optional): - The message to be displayed in the profiler. Defaults to None. - """ - return mstx.range_start(message=message) - - -def mark_end_range(range_id: str) -> None: - """End a mark range in the profiler. - - Args: - range_id (str): - The id of the mark range to end. - """ - return mstx.range_end(range_id) - - -def mark_annotate(message: Optional[str] = None) -> Callable: - """Decorate a function to annotate a mark range along with the function life cycle. - - Args: - message (str, optional): - The message to be displayed in the profiler. Defaults to None. - """ - - def decorator(func): - profile_message = message or func.__name__ - return mstx.mstx_range(profile_message)(func) - - return decorator - - -@contextmanager -def marked_timer(name: str, timing_raw: dict[str, float], **kwargs): - """Context manager for timing with MSTX markers. - - This utility function measures the execution time of code within its context, - accumulates the timing information, and adds MSTX markers for profiling. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - - Yields: - None: This is a context manager that yields control back to the code block. - """ - mark_range = mark_start_range(message=name) - from .performance import _timer - - yield from _timer(name, timing_raw) - mark_end_range(mark_range) - - -def get_npu_profiler(option: DictConfig, role: Optional[str] = None, profile_step: Optional[str] = None): - """Generate and return an NPU profiler object. - - Args: - option (DictConfig): - The options to control npu profiler. - role (str, optional): - The role of the current data collection. Defaults to None. - profile_step(str, optional): - The current training step. Defaults to None. - """ - if option.level == "level_none": - profile_level = torch_npu.profiler.ProfilerLevel.Level_none - elif option.level == "level0": - profile_level = torch_npu.profiler.ProfilerLevel.Level0 - elif option.level == "level1": - profile_level = torch_npu.profiler.ProfilerLevel.Level1 - elif option.level == "level2": - profile_level = torch_npu.profiler.ProfilerLevel.Level2 - else: - raise ValueError(f"level only supports level0, 1, 2, and level_none, but gets {option.level}") - - profile_save_path = option.save_path - if profile_step: - profile_save_path = os.path.join(profile_save_path, profile_step) - if role: - profile_save_path = os.path.join(profile_save_path, role) - - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=profile_level, - export_type=torch_npu.profiler.ExportType.Text, - data_simplification=True, - msprof_tx=True, - ) - - activites = [] - if option.with_npu: - activites.append(torch_npu.profiler.ProfilerActivity.NPU) - if option.with_cpu: - activites.append(torch_npu.profiler.ProfilerActivity.CPU) - - prof = torch_npu.profiler.profile( - with_modules=option.with_module, - with_stack=option.with_stack, - record_shapes=option.record_shapes, - profile_memory=option.with_memory, - activities=activites, - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path, analyse_flag=option.analysis), - experimental_config=experimental_config, - ) - return prof - - -class NPUProfiler(DistProfiler): - """ - NPU profiler. Initialized in a worker to control the NPU profiler. - """ - - _define_count = 0 - - def __init__(self, rank: int, config: ProfilerConfig, **kwargs): - """Initialize the NsightSystemsProfiler. - - Args: - rank (int): The rank of the current process. - config (Optional[ProfilerConfig]): Configuration for the profiler. If None, a default configuration is used. - """ - if not config: - config = ProfilerConfig(ranks=[]) - self.this_step: bool = False - self.discrete: bool = config.discrete - self.this_rank: bool = False - self.profile_npu = None - self.profile_option = kwargs.get("option", None) - if config.all_ranks: - self.this_rank = True - elif config.ranks: - self.this_rank = rank in config.ranks - - def start(self, **kwargs): - role, profile_step = kwargs.get("role", None), kwargs.get("profile_step", None) - profile_step = str(profile_step) if profile_step is not None else None - if self.this_rank and self.profile_option is not None: - self.this_step = True - if not self.discrete and NPUProfiler._define_count == 0: - self.profile_npu = get_npu_profiler(option=self.profile_option, role=role, profile_step=profile_step) - self.profile_npu.start() - NPUProfiler._define_count += 1 - - def stop(self): - if self.this_rank and self.profile_option is not None: - self.this_step = False - if not self.discrete and NPUProfiler._define_count == 1: - self.profile_npu.step() - self.profile_npu.stop() - NPUProfiler._define_count -= 1 - - @staticmethod - def annotate(message: Optional[str] = None, role: Optional[str] = None, **kwargs) -> Callable: - """Decorate a Worker member function to profile the current rank in the current training step. - - Requires the target function to be a member function of a Worker, - which has a member field `profiler` with NPUProfiler type. - - Args: - message (str, optional): - The message to be displayed in the profiler. Defaults to None. - role (str, optional): - The role of the current data collection. Defaults to None. - """ - - def decorator(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - profile_name = message or func.__name__ - - if self.profiler.this_step and self.profile_option is not None: - if self.profiler.discrete: - profile_npu = get_npu_profiler(option=self.profile_option, role=role) - profile_npu.start() - mark_range = mark_start_range(message=profile_name) - - result = func(self, *args, **kwargs) - - if self.profiler.this_step and self.profile_option is not None: - mark_end_range(mark_range) - if self.profiler.discrete: - profile_npu.step() - profile_npu.stop() - - return result - - return wrapper - - return decorator diff --git a/verl/utils/profiler/nvtx_profile.py b/verl/utils/profiler/nvtx_profile.py deleted file mode 100644 index 9ebce374f..000000000 --- a/verl/utils/profiler/nvtx_profile.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -from contextlib import contextmanager -from typing import Callable, Optional - -import nvtx -import torch - -from .profile import DistProfiler, ProfilerConfig - - -def mark_start_range( - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, -) -> None: - """Start a mark range in the profiler. - - Args: - message (str, optional): - The message to be displayed in the profiler. Defaults to None. - color (str, optional): - The color of the range. Defaults to None. - domain (str, optional): - The domain of the range. Defaults to None. - category (str, optional): - The category of the range. Defaults to None. - """ - return nvtx.start_range(message=message, color=color, domain=domain, category=category) - - -def mark_end_range(range_id: str) -> None: - """End a mark range in the profiler. - - Args: - range_id (str): - The id of the mark range to end. - """ - return nvtx.end_range(range_id) - - -def mark_annotate( - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, -) -> Callable: - """Decorate a function to annotate a mark range along with the function life cycle. - - Args: - message (str, optional): - The message to be displayed in the profiler. Defaults to None. - color (str, optional): - The color of the range. Defaults to None. - domain (str, optional): - The domain of the range. Defaults to None. - category (str, optional): - The category of the range. Defaults to None. - """ - - def decorator(func): - profile_message = message or func.__name__ - return nvtx.annotate(profile_message, color=color, domain=domain, category=category)(func) - - return decorator - - -@contextmanager -def marked_timer( - name: str, - timing_raw: dict[str, float], - color: str = None, - domain: Optional[str] = None, - category: Optional[str] = None, -): - """Context manager for timing with NVTX markers. - - This utility function measures the execution time of code within its context, - accumulates the timing information, and adds NVTX markers for profiling. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - color (Optional[str]): Color for the NVTX marker. Defaults to None. - domain (Optional[str]): Domain for the NVTX marker. Defaults to None. - category (Optional[str]): Category for the NVTX marker. Defaults to None. - - Yields: - None: This is a context manager that yields control back to the code block. - """ - mark_range = mark_start_range(message=name, color=color, domain=domain, category=category) - from .performance import _timer - - yield from _timer(name, timing_raw) - mark_end_range(mark_range) - - -class NsightSystemsProfiler(DistProfiler): - """Nsight system profiler. Installed in a worker to control the Nsight system profiler.""" - - def __init__(self, rank: int, config: Optional[ProfilerConfig], **kwargs): - """Initialize the NsightSystemsProfiler. - - Args: - rank (int): The rank of the current process. - config (Optional[ProfilerConfig]): Configuration for the profiler. If None, a default configuration is used. - """ - # If no configuration is provided, create a default ProfilerConfig with an empty list of ranks - if not config: - config = ProfilerConfig(ranks=[]) - self.this_step: bool = False - self.discrete: bool = config.discrete - self.this_rank: bool = False - if config.all_ranks: - self.this_rank = True - elif config.ranks: - self.this_rank = rank in config.ranks - - def start(self, **kwargs): - if self.this_rank: - self.this_step = True - if not self.discrete: - torch.cuda.profiler.start() - - def stop(self): - if self.this_rank: - self.this_step = False - if not self.discrete: - torch.cuda.profiler.stop() - - @staticmethod - def annotate( - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, - **kwargs, - ) -> Callable: - """Decorate a Worker member function to profile the current rank in the current training step. - - Requires the target function to be a member function of a Worker, which has a member field `profiler` with - NightSystemsProfiler type. - - Args: - message (str, optional): - The message to be displayed in the profiler. Defaults to None. - color (str, optional): - The color of the range. Defaults to None. - domain (str, optional): - The domain of the range. Defaults to None. - category (str, optional): - The category of the range. Defaults to None. - """ - - def decorator(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - profile_name = message or func.__name__ - - if self.profiler.this_step: - if self.profiler.discrete: - torch.cuda.profiler.start() - mark_range = mark_start_range(message=profile_name, color=color, domain=domain, category=category) - - result = func(self, *args, **kwargs) - - if self.profiler.this_step: - mark_end_range(mark_range) - if self.profiler.discrete: - torch.cuda.profiler.stop() - - return result - - return wrapper - - return decorator diff --git a/verl/utils/profiler/performance.py b/verl/utils/profiler/performance.py deleted file mode 100644 index 8991896a2..000000000 --- a/verl/utils/profiler/performance.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import datetime -import inspect -import logging -from contextlib import contextmanager -from typing import Any, Optional - -import torch -import torch.distributed as dist -from codetiming import Timer - -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.logger import DecoratorLoggerBase - - -def _get_current_mem_info(unit: str = "GB", precision: int = 2) -> tuple[str]: - """Get current memory usage.""" - assert unit in ["GB", "MB", "KB"] - divisor = 1024**3 if unit == "GB" else 1024**2 if unit == "MB" else 1024 - mem_allocated = get_torch_device().memory_allocated() - mem_reserved = get_torch_device().memory_reserved() - # use get_torch_device().mem_get_info to profile device memory - # since vllm's sleep mode works below pytorch - # see https://github.com/vllm-project/vllm/pull/11743#issuecomment-2754338119 - mem_free, mem_total = get_torch_device().mem_get_info() - mem_used = mem_total - mem_free - mem_allocated = f"{mem_allocated / divisor:.{precision}f}" - mem_reserved = f"{mem_reserved / divisor:.{precision}f}" - mem_used = f"{mem_used / divisor:.{precision}f}" - mem_total = f"{mem_total / divisor:.{precision}f}" - return mem_allocated, mem_reserved, mem_used, mem_total - - -def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): - """Log GPU memory usage information. - - Args: - head (str): A descriptive header for the memory usage log message. - logger (logging.Logger, optional): Logger instance to use for logging. If None, prints to stdout. - level: Logging level to use. Defaults to logging.DEBUG. - rank (int): The rank of the process to log memory for. Defaults to 0. - """ - if (not dist.is_initialized()) or (rank is None) or (dist.get_rank() == rank): - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = ( - f"{head}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " - f"device memory used/total (GB): {mem_used}/{mem_total}" - ) - - if logger is None: - print(message) - else: - logger.log(msg=message, level=level) - - -class GPUMemoryLogger(DecoratorLoggerBase): - """A decorator class to log GPU memory usage. - - Example: - >>> from verl.utils.profiler.performance import GPUMemoryLogger - >>> @GPUMemoryLogger(role="actor") - >>> def update_actor(self, batch): - ... # real actor update logics - ... return - """ - - def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): - if dist.is_initialized() and dist.get_world_size() > 1: - rank = dist.get_rank() - else: - rank = 0 - super().__init__(role, logger, level, rank, log_only_rank_0) - - def __call__(self, decorated_function: callable): - def f(*args, **kwargs): - return self.log(decorated_function, *args, **kwargs) - - return f - - def log(self, func, *args, **kwargs): - name = func.__name__ - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = ( - f"Before {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " - f"device memory used/total (GB): {mem_used}/{mem_total}" - ) - self.logging_function(message) - - output = func(*args, **kwargs) - - mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info() - message = ( - f"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, " - f"device memory used/total (GB): {mem_used}/{mem_total}" - ) - - self.logging_function(message) - return output - - -def log_print(ctn: Any): - current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - frame = inspect.currentframe().f_back - function_name = frame.f_code.co_name - line_number = frame.f_lineno - file_name = frame.f_code.co_filename.split("/")[-1] - print(f"[{current_time}-{file_name}:{line_number}:{function_name}]: {ctn}") - - -def _timer(name: str, timing_raw: dict[str, float]): - """Inner function that handles the core timing logic. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - """ - with Timer(name=name, logger=None) as timer: - yield - if name not in timing_raw: - timing_raw[name] = 0 - timing_raw[name] += timer.last - - -@contextmanager -def simple_timer(name: str, timing_raw: dict[str, float]): - """Context manager for basic timing without NVTX markers. - - This utility function measures the execution time of code within its context - and accumulates the timing information in the provided dictionary. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - - Yields: - None: This is a context manager that yields control back to the code block. - """ - yield from _timer(name, timing_raw) - - -@contextmanager -def marked_timer( - name: str, - timing_raw: dict[str, float], - color: str = None, - domain: Optional[str] = None, - category: Optional[str] = None, -): - """Context manager for timing with platform markers. - - This utility function measures the execution time of code within its context, - accumulates the timing information, and adds platform markers for profiling. - This function is a default implementation when hardware profiler is not available. - - Args: - name (str): The name/identifier for this timing measurement. - timing_raw (Dict[str, float]): Dictionary to store timing information. - color (Optional[str]): Color for the marker. Defaults to None. - domain (Optional[str]): Domain for the marker. Defaults to None. - category (Optional[str]): Category for the marker. Defaults to None. - - Yields: - None: This is a context manager that yields control back to the code block. - """ - yield from _timer(name, timing_raw) - - -def reduce_timing(timing_raw: dict[str, float]) -> dict[str, float]: - """Reduce timing information across all processes. - - This function uses distributed communication to gather and sum the timing - information from all processes in a distributed environment. - - Args: - timing_raw (Dict[str, float]): Dictionary containing timing information. - - Returns: - Dict[str, float]: Reduced timing information. - """ - if not dist.is_initialized(): - return timing_raw - - key_list, timing_list = [], [] - for key in sorted(timing_raw.keys()): - key_list.append(key) - timing_list.append(timing_raw[key]) - timing_list = torch.tensor(timing_list, dtype=torch.float32, device=get_device_id()) - torch.distributed.all_reduce(timing_list, op=torch.distributed.ReduceOp.AVG) - timing_list = [tensor.item() for tensor in timing_list.to("cpu")] - timing_generate = {key_list[i]: timing_list[i] for i in range(len(key_list))} - return timing_generate diff --git a/verl/utils/profiler/profile.py b/verl/utils/profiler/profile.py deleted file mode 100644 index 4e7ce4fd3..000000000 --- a/verl/utils/profiler/profile.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from typing import Callable, Optional - -import torch -import torch.distributed - -from .config import ProfilerConfig - - -class Profiler: - """A PyTorch profiler wrapper class for collecting performance metrics. - - TODO(haibin.lin): this should implement the DistProfiler interface, and the config should be unified. - - This profiler provides a convenient interface for profiling PyTorch operations, - with support for: - - - CPU and CUDA activity profiling - - Configurable profiling schedule (wait/warmup/active steps) - - Multi-rank profiling support - - Chrome trace export - - Args: - config: Configuration object containing profiling parameters - """ - - def __init__(self, config): - # note : if we do not set use_profile, it will be set as None, so that all function will be skip - self.config = config - self.skip_prof = False - self.saved = False - self.prof = None - self.rank = torch.distributed.get_rank() - # we need to validate the config before using the profiler - self._validate() - if config.use_profile and self.rank in self.config.profile_ranks: - print(f"[Profiler] Profiler init for rank {self.rank}") - - self.prof = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule( - wait=max(self.config.step_start - 1, 0), - warmup=1 if self.config.step_start > 0 else 0, - active=self.config.step_end - self.config.step_start, - repeat=1, - ), - record_shapes=True, - with_stack=True, - ) - - def _validate(self): - if self.config.use_profile: - if self.config.profile_ranks is None: - print("[WARNING] Profile ranks is not set, default to rank 0") - self.config.profile_ranks = [0] - assert self.config.step_start >= 0, "[ERROR] Profile step start must be greater than 0" - assert self.config.step_end >= 0, "[ERROR] Profile step end must be greater than 0" - assert self.config.step_start < self.config.step_end, ( - "[ERROR] Profile step start must be less than step end" - ) - - def check(self): - return self.prof is not None and not self.skip_prof - - def start(self): - if self.check(): - print(f"[Profiler] started for rank {self.rank}") - self.prof.start() - - def step(self): - if self.check(): - self.prof.step() - - def stop(self): - if self.check(): - print(f"[Profiler] stopped for rank {self.rank}") - self.prof.stop() - - def save(self): - if self.prof is not None and not self.saved: - if not os.path.exists(self.config.save_path): - os.makedirs(self.config.save_path) - save_file_name = f"/prof_start_{self.config.step_start}_end_{self.config.step_end}_rank_{self.rank}.json" - print(f"[Profiler] Saving trace to {self.config.save_path + save_file_name}") - self.prof.export_chrome_trace(self.config.save_path + save_file_name) - self.skip_prof = True - self.saved = True - - def stop_and_save(self): - if self.check(): - self.stop() - self.save() - - def stop_trace(self): - if self.check(): - print(f"[Profiler] Trace stopped for rank {self.rank}") - self.skip_prof = True - - -def mark_start_range( - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, -) -> None: - """Start a profiling range marker (no-op implementation). - - Args: - message (Optional[str]): Message to associate with the range marker. - color (Optional[str]): Color for the marker visualization. - domain (Optional[str]): Domain for the marker. - category (Optional[str]): Category for the marker. - """ - pass - - -def mark_end_range(range_id: str) -> None: - """End a profiling range marker (no-op implementation). - - Args: - range_id (str): Identifier of the range to end. - """ - pass - - -def mark_annotate( - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, -) -> Callable: - """Decorator to annotate a function with profiling markers (no-op implementation). - - Args: - message (Optional[str]): Message to associate with the annotation. - color (Optional[str]): Color for the marker visualization. - domain (Optional[str]): Domain for the marker. - category (Optional[str]): Category for the marker. - - Returns: - Callable: Decorator function that returns the original function unchanged. - """ - - def decorator(func): - return func - - return decorator - - -class DistProfiler: - """A distributed profiler class for collecting performance metrics across multiple ranks. - - This profiler is designed to work in distributed training environments, allowing selective - profiling of specific ranks or all ranks. It provides basic start/stop functionality and - supports annotation of code sections for detailed profiling. - - Args: - rank (int): The rank of the current process - config (ProfilerConfig, optional): Configuration for the profiler. - """ - - def __init__(self, rank: int, config: Optional[ProfilerConfig] = None, **kwargs): - pass - - def start(self, **kwargs): - pass - - def stop(self): - pass - - @staticmethod - def annotate( - message: Optional[str] = None, - color: Optional[str] = None, - domain: Optional[str] = None, - category: Optional[str] = None, - **kwargs, - ) -> Callable: - def decorator(func): - return func - - return decorator - - -class DistProfilerExtension: - """An extension class for DistProfiler that provides distributed profiling capabilities. - - It is intended for workers in verl that single controller invokes. - - This class wraps a DistProfiler instance and provides methods to start/stop profiling - that can be dispatched across multiple ranks in a distributed training environment. - - Args: - profiler (DistProfiler): The base distributed profiler instance to extend - """ - - def __init__(self, profiler: DistProfiler): - self.profiler = profiler - - from verl.single_controller.base.decorator import Dispatch, register - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def start_profile(self, **kwargs) -> None: - """Start profiling for the current rank in the current training step.""" - self.profiler.start(**kwargs) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def stop_profile(self) -> None: - """Stop profiling for the current rank in the current training step.""" - self.profiler.stop() diff --git a/verl/utils/py_functional.py b/verl/utils/py_functional.py deleted file mode 100644 index affe4ed9a..000000000 --- a/verl/utils/py_functional.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Contain small python utility functions -""" - -import importlib -import multiprocessing -import os -import queue # Import the queue module for exception type hint -import signal -from contextlib import contextmanager -from functools import wraps -from types import SimpleNamespace -from typing import Any, Callable, Iterator, Optional - - -# --- Top-level helper for multiprocessing timeout --- -# This function MUST be defined at the top level to be pickleable -def _mp_target_wrapper(target_func: Callable, mp_queue: multiprocessing.Queue, args: tuple, kwargs: dict[str, Any]): - """ - Internal wrapper function executed in the child process. - Calls the original target function and puts the result or exception into the queue. - """ - try: - result = target_func(*args, **kwargs) - mp_queue.put((True, result)) # Indicate success and put result - except Exception as e: - # Ensure the exception is pickleable for the queue - try: - import pickle - - pickle.dumps(e) # Test if the exception is pickleable - mp_queue.put((False, e)) # Indicate failure and put exception - except (pickle.PicklingError, TypeError): - # Fallback if the original exception cannot be pickled - mp_queue.put((False, RuntimeError(f"Original exception type {type(e).__name__} not pickleable: {e}"))) - - -# Renamed the function from timeout to timeout_limit -def timeout_limit(seconds: float, use_signals: bool = False): - """ - Decorator to add a timeout to a function. - - Args: - seconds: The timeout duration in seconds. - use_signals: (Deprecated) This is deprecated because signals only work reliably in the main thread - and can cause issues in multiprocessing or multithreading contexts. - Defaults to False, which uses the more robust multiprocessing approach. - - Returns: - A decorated function with timeout. - - Raises: - TimeoutError: If the function execution exceeds the specified time. - RuntimeError: If the child process exits with an error (multiprocessing mode). - NotImplementedError: If the OS is not POSIX (signals are only supported on POSIX). - """ - - def decorator(func): - if use_signals: - if os.name != "posix": - raise NotImplementedError(f"Unsupported OS: {os.name}") - # Issue deprecation warning if use_signals is explicitly True - print( - "WARN: The 'use_signals=True' option in the timeout decorator is deprecated. \ - Signals are unreliable outside the main thread. \ - Please use the default multiprocessing-based timeout (use_signals=False)." - ) - - @wraps(func) - def wrapper_signal(*args, **kwargs): - def handler(signum, frame): - # Update function name in error message if needed (optional but good practice) - raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (signal)!") - - old_handler = signal.getsignal(signal.SIGALRM) - signal.signal(signal.SIGALRM, handler) - # Use setitimer for float seconds support, alarm only supports integers - signal.setitimer(signal.ITIMER_REAL, seconds) - - try: - result = func(*args, **kwargs) - finally: - # Reset timer and handler - signal.setitimer(signal.ITIMER_REAL, 0) - signal.signal(signal.SIGALRM, old_handler) - return result - - return wrapper_signal - else: - # --- Multiprocessing based timeout (existing logic) --- - @wraps(func) - def wrapper_mp(*args, **kwargs): - q = multiprocessing.Queue(maxsize=1) - process = multiprocessing.Process(target=_mp_target_wrapper, args=(func, q, args, kwargs)) - process.start() - process.join(timeout=seconds) - - if process.is_alive(): - process.terminate() - process.join(timeout=0.5) # Give it a moment to terminate - if process.is_alive(): - print(f"Warning: Process {process.pid} did not terminate gracefully after timeout.") - # Update function name in error message if needed (optional but good practice) - raise TimeoutError(f"Function {func.__name__} timed out after {seconds} seconds (multiprocessing)!") - - try: - success, result_or_exc = q.get(timeout=0.1) # Small timeout for queue read - if success: - return result_or_exc - else: - raise result_or_exc # Reraise exception from child - except queue.Empty as err: - exitcode = process.exitcode - if exitcode is not None and exitcode != 0: - raise RuntimeError( - f"Child process exited with error (exitcode: {exitcode}) before returning result." - ) from err - else: - # Should have timed out if queue is empty after join unless process died unexpectedly - # Update function name in error message if needed (optional but good practice) - raise TimeoutError( - f"Operation timed out or process finished unexpectedly without result " - f"(exitcode: {exitcode})." - ) from err - finally: - q.close() - q.join_thread() - - return wrapper_mp - - return decorator - - -def union_two_dict(dict1: dict, dict2: dict): - """Union two dict. Will throw an error if there is an item not the same object with the same key. - - Args: - dict1: - dict2: - - Returns: - - """ - for key, val in dict2.items(): - if key in dict1: - assert dict2[key] == dict1[key], f"{key} in meta_dict1 and meta_dict2 are not the same object" - dict1[key] = val - - return dict1 - - -def append_to_dict(data: dict, new_data: dict): - """Append values from new_data to lists in data. - - For each key in new_data, this function appends the corresponding value to a list - stored under the same key in data. If the key doesn't exist in data, a new list is created. - - Args: - data (Dict): The target dictionary containing lists as values. - new_data (Dict): The source dictionary with values to append. - - Returns: - None: The function modifies data in-place. - """ - for key, val in new_data.items(): - if key not in data: - data[key] = [] - data[key].append(val) - - -class NestedNamespace(SimpleNamespace): - """A nested version of SimpleNamespace that recursively converts dictionaries to namespaces. - - This class allows for dot notation access to nested dictionary structures by recursively - converting dictionaries to NestedNamespace objects. - - Example: - config_dict = {"a": 1, "b": {"c": 2, "d": 3}} - config = NestedNamespace(config_dict) - # Access with: config.a, config.b.c, config.b.d - - Args: - dictionary: The dictionary to convert to a nested namespace. - **kwargs: Additional attributes to set on the namespace. - """ - - def __init__(self, dictionary, **kwargs): - super().__init__(**kwargs) - for key, value in dictionary.items(): - if isinstance(value, dict): - self.__setattr__(key, NestedNamespace(value)) - else: - self.__setattr__(key, value) - - -class DynamicEnumMeta(type): - def __iter__(cls) -> Iterator[Any]: - return iter(cls._registry.values()) - - def __contains__(cls, item: Any) -> bool: - # allow `name in EnumClass` or `member in EnumClass` - if isinstance(item, str): - return item in cls._registry - return item in cls._registry.values() - - def __getitem__(cls, name: str) -> Any: - return cls._registry[name] - - def __reduce_ex__(cls, protocol): - # Always load the existing module and grab the class - return getattr, (importlib.import_module(cls.__module__), cls.__name__) - - def names(cls): - return list(cls._registry.keys()) - - def values(cls): - return list(cls._registry.values()) - - -class DynamicEnum(metaclass=DynamicEnumMeta): - _registry: dict[str, "DynamicEnum"] = {} - _next_value: int = 0 - - def __init__(self, name: str, value: int): - self.name = name - self.value = value - - def __repr__(self): - return f"<{self.__class__.__name__}.{self.name}: {self.value}>" - - def __reduce_ex__(self, protocol): - """ - Unpickle via: getattr(import_module(module).Dispatch, 'ONE_TO_ALL') - so the existing class is reused instead of re-executed. - """ - module = importlib.import_module(self.__class__.__module__) - enum_cls = getattr(module, self.__class__.__name__) - return getattr, (enum_cls, self.name) - - @classmethod - def register(cls, name: str) -> "DynamicEnum": - key = name.upper() - if key in cls._registry: - raise ValueError(f"{key} already registered") - member = cls(key, cls._next_value) - cls._registry[key] = member - setattr(cls, key, member) - cls._next_value += 1 - return member - - @classmethod - def remove(cls, name: str): - key = name.upper() - member = cls._registry.pop(key) - delattr(cls, key) - return member - - @classmethod - def from_name(cls, name: str) -> Optional["DynamicEnum"]: - return cls._registry.get(name.upper()) - - -@contextmanager -def temp_env_var(key: str, value: str): - """Context manager for temporarily setting an environment variable. - - This context manager ensures that environment variables are properly set and restored, - even if an exception occurs during the execution of the code block. - - Args: - key: Environment variable name to set - value: Value to set the environment variable to - - Yields: - None - - Example: - >>> with temp_env_var("MY_VAR", "test_value"): - ... # MY_VAR is set to "test_value" - ... do_something() - ... # MY_VAR is restored to its original value or removed if it didn't exist - """ - original = os.environ.get(key) - os.environ[key] = value - try: - yield - finally: - if original is None: - os.environ.pop(key, None) - else: - os.environ[key] = original - - -def convert_to_regular_types(obj): - """Convert Hydra configs and other special types to regular Python types.""" - from omegaconf import DictConfig, ListConfig - - if isinstance(obj, ListConfig | DictConfig): - return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj) - elif isinstance(obj, list | tuple): - return [convert_to_regular_types(x) for x in obj] - elif isinstance(obj, dict): - return {k: convert_to_regular_types(v) for k, v in obj.items()} - return obj diff --git a/verl/utils/ray_utils.py b/verl/utils/ray_utils.py deleted file mode 100644 index a738c0f3d..000000000 --- a/verl/utils/ray_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Contains commonly used utilities for ray -""" - -import concurrent.futures -import os -from typing import Any, Optional - -import ray - - -def ray_noset_visible_devices(env_vars=os.environ): - # Refer to - # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 - # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 - # https://github.com/ray-project/ray/blob/3b9e729f6a669ffd85190f901f5e262af79771b0/python/ray/_private/accelerators/amd_gpu.py#L114-L115 - # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 - # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 - # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 - # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 - # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 - NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", - "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", - "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", - "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", - "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", - "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", - "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", - ] - return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) - - -def parallel_put(data_list: list[Any], max_workers: Optional[int] = None): - """ - Puts a list of data into the Ray object store in parallel using a thread pool. - - Args: - data_list (List[Any]): A list of Python objects to be put into the Ray object store. - max_workers (int, optional): The maximum number of worker threads to use. - Defaults to min(len(data_list), 16). - - Returns: - List[ray.ObjectRef]: A list of Ray object references corresponding to the input data_list, - maintaining the original order. - """ - assert len(data_list) > 0, "data_list must not be empty" - - def put_data(index, data): - return index, ray.put(data) - - if max_workers is None: - max_workers = min(len(data_list), 16) - - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - data_list_f = [executor.submit(put_data, i, data) for i, data in enumerate(data_list)] - res_lst = [] - for future in concurrent.futures.as_completed(data_list_f): - res_lst.append(future.result()) - - # reorder based on index - output = [None for _ in range(len(data_list))] - for res in res_lst: - index, data_ref = res - output[index] = data_ref - - return output diff --git a/verl/utils/rendezvous/__init__.py b/verl/utils/rendezvous/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/utils/rendezvous/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/utils/rendezvous/ray_backend.py b/verl/utils/rendezvous/ray_backend.py deleted file mode 100644 index d9911815d..000000000 --- a/verl/utils/rendezvous/ray_backend.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time - -import ray -from cupy.cuda.nccl import NcclCommunicator, get_unique_id -from ray.util import list_named_actors - - -@ray.remote -class NCCLIDStore: - def __init__(self, nccl_id): - self._nccl_id = nccl_id - - def get(self): - return self._nccl_id - - -def get_nccl_id_store_by_name(name): - all_actors = list_named_actors(all_namespaces=True) - matched_actors = [actor for actor in all_actors if actor.get("name", None) == name] - if len(matched_actors) == 1: - actor = matched_actors[0] - return ray.get_actor(**actor) - elif len(matched_actors) > 1: - logging.warning("multiple actors with same name found: %s", matched_actors) - elif len(matched_actors) == 0: - logging.info("failed to get any actor named %s", name) - return None - - -def create_nccl_communicator_in_ray( - rank: int, world_size: int, group_name: str, max_retries: int = 100, interval_s: int = 5 -): - if rank == 0: - nccl_id = get_unique_id() - nccl_id_store = NCCLIDStore.options(name=group_name).remote(nccl_id) - - assert ray.get(nccl_id_store.get.remote()) == nccl_id - communicator = NcclCommunicator( - ndev=world_size, - commId=nccl_id, - rank=0, - ) - return communicator - else: - for i in range(max_retries): - nccl_id_store = get_nccl_id_store_by_name(group_name) - if nccl_id_store is not None: - logging.info("nccl_id_store %s got", group_name) - nccl_id = ray.get(nccl_id_store.get.remote()) - logging.info("nccl id for %s got: %s", group_name, nccl_id) - communicator = NcclCommunicator( - ndev=world_size, - commId=nccl_id, - rank=rank, - ) - return communicator - logging.info("failed to get nccl_id for %d time, sleep for %d seconds", i + 1, interval_s) - time.sleep(interval_s) diff --git a/verl/utils/rollout_trace.py b/verl/utils/rollout_trace.py deleted file mode 100644 index e34e285d0..000000000 --- a/verl/utils/rollout_trace.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import contextlib -import functools -import inspect -import os -from typing import Optional - - -class RolloutTraceConfig: - _instance: Optional["RolloutTraceConfig"] = None - backend: Optional[str] = None - client: Optional[object] = None - token2text: bool = False - _initialized: bool = False - project_name: str = None - experiment_name: str = None - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - @classmethod - def get_instance(cls) -> "RolloutTraceConfig": - if cls._instance is None: - cls._instance = cls() - return cls._instance - - @classmethod - def init(cls, project_name: str, experiment_name: str, backend: str, token2text: bool = False): - config = cls.get_instance() - if config._initialized: - return - - config.backend = backend - config.token2text = token2text - config.project_name = project_name - config.experiment_name = experiment_name - - if backend == "weave": - import weave - - config.client = weave.init(project_name) - elif backend == "mlflow": - import mlflow - - mlflow.config.enable_async_logging() - config.client = mlflow - - MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db") - mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) - - mlflow.set_experiment(project_name) - else: - config.client = None - - config._initialized = True - - @classmethod - def get_backend(cls) -> Optional[str]: - return cls.get_instance().backend - - @classmethod - def get_client(cls) -> Optional[object]: - return cls.get_instance().client - - @classmethod - def enable_token2text(cls) -> Optional[bool]: - return cls.get_instance().token2text - - @classmethod - def reset(cls): - cls._instance = None - - -@contextlib.contextmanager -def rollout_trace_attr(sample_index=None, step=None, rollout_n=None, name="rollout_trace", validate=False): - """A context manager to add attributes to a trace for the configured backend.""" - backend = RolloutTraceConfig.get_backend() - attributes = {} - if backend: - if sample_index is not None: - attributes["sample_index"] = sample_index - if step is not None: - attributes["step"] = step - if rollout_n is not None: - attributes["rollout_n"] = rollout_n - attributes["validate"] = validate - attributes["experiment_name"] = RolloutTraceConfig.get_instance().experiment_name - - if not attributes or backend is None: - yield - return - - if backend == "weave": - import weave - - with weave.attributes(attributes): - yield - elif backend == "mlflow": - import mlflow - - with mlflow.start_span(name=name) as span: - trace_id = span.trace_id - for key, value in attributes.items(): - mlflow.set_trace_tag(trace_id, str(key), str(value)) - yield - else: - yield - - -def rollout_trace_op(func): - @functools.wraps(func) - async def async_wrapper(self, *args, **kwargs): - backend = RolloutTraceConfig.get_backend() - enable_token2text = RolloutTraceConfig.enable_token2text() - if backend is None: - return await func(self, *args, **kwargs) - - sig = inspect.signature(func) - bound_args = sig.bind(self, *args, **kwargs) - bound_args.apply_defaults() - inputs = dict(bound_args.arguments) - del inputs["self"] - - async def add_token2text(self, result): - if hasattr(result, "prompt_ids") and hasattr(self, "tokenizer") and hasattr(self.tokenizer, "decode"): - _result = vars(result) - loop = asyncio.get_running_loop() - if hasattr(result, "prompt_ids"): - prompt_text = await loop.run_in_executor(None, self.tokenizer.decode, result.prompt_ids) - _result["prompt_text"] = prompt_text - - if hasattr(result, "response_ids"): - response_text = await loop.run_in_executor(None, self.tokenizer.decode, result.response_ids) - _result["response_text"] = response_text - return _result - return result - - if backend == "weave": - tracer = RolloutTraceConfig.get_client() - from weave.trace.context import call_context - - cur_attributes = {**call_context.call_attributes.get()} - call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes) - try: - result = await func(self, *args, **kwargs) - - if enable_token2text: - _result = await add_token2text(self, result) - tracer.finish_call(call, output=_result) - else: - tracer.finish_call(call, output=result) - - return result - - except Exception as e: - tracer.finish_call(call, exception=e) - raise e - elif backend == "mlflow": - import mlflow - - with mlflow.start_span(name=func.__qualname__) as span: - span.set_inputs(inputs) - result = await func(self, *args, **kwargs) - if enable_token2text: - _result = await add_token2text(self, result) - span.set_outputs(_result) - else: - span.set_outputs(result) - - return result - - else: - return await func(self, *args, **kwargs) - - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - backend = RolloutTraceConfig.get_backend() - if backend is None: - return func(self, *args, **kwargs) - - sig = inspect.signature(func) - bound_args = sig.bind(self, *args, **kwargs) - bound_args.apply_defaults() - inputs = dict(bound_args.arguments) - del inputs["self"] - - if backend == "weave": - tracer = RolloutTraceConfig.get_client() - from weave.trace.context import call_context - - cur_attributes = {**call_context.call_attributes.get()} - call = tracer.create_call(op=func.__qualname__, inputs=inputs, attributes=cur_attributes) - try: - result = func(self, *args, **kwargs) - tracer.finish_call(call, output=result) - return result - except Exception as e: - tracer.finish_call(call, exception=e) - raise e - elif backend == "mlflow": - import mlflow - - return mlflow.trace(func)(self, *args, **kwargs) - else: - return func(self, *args, **kwargs) - - return async_wrapper if inspect.iscoroutinefunction(func) else wrapper diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py deleted file mode 100644 index bde116adf..000000000 --- a/verl/utils/seqlen_balancing.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import heapq -from itertools import chain - -import torch -from torch import distributed as dist - -from verl.protocol import DataProto -from verl.utils.device import get_device_name - - -def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool): - # see: https://en.wikipedia.org/wiki/Largest_differencing_method - class Set: - def __init__(self) -> None: - self.sum = 0 - self.items = [] - - def add(self, idx: int, val: int): - self.items.append((idx, val)) - self.sum += val - - def merge(self, other): - for idx, val in other.items: - self.items.append((idx, val)) - self.sum += val - - def __lt__(self, other): - if self.sum != other.sum: - return self.sum < other.sum - if len(self.items) != len(other.items): - return len(self.items) < len(other.items) - return self.items < other.items - - class State: - def __init__(self, items: list[tuple[int, int]], k: int) -> None: - self.k = k - # sets should always be decreasing order - self.sets = [Set() for _ in range(k)] - assert len(items) in [1, k], f"{len(items)} not in [1, {k}]" - for i, (idx, seqlen) in enumerate(items): - self.sets[i].add(idx=idx, val=seqlen) - self.sets = sorted(self.sets, reverse=True) - - def get_partitions(self): - partitions = [] - for i in range(len(self.sets)): - cur_partition = [] - for idx, _ in self.sets[i].items: - cur_partition.append(idx) - partitions.append(cur_partition) - return partitions - - def merge(self, other): - for i in range(self.k): - self.sets[i].merge(other.sets[self.k - 1 - i]) - self.sets = sorted(self.sets, reverse=True) - - @property - def spread(self) -> int: - return self.sets[0].sum - self.sets[-1].sum - - def __lt__(self, other): - # least heap, let the state with largest spread to be popped first, - # if the spread is the same, let the state who has the largest set - # to be popped first. - if self.spread != other.spread: - return self.spread > other.spread - return self.sets[0] > other.sets[0] - - def __repr__(self) -> str: - repr_str = "[" - for i in range(self.k): - if i > 0: - repr_str += "," - repr_str += "{" - for j, (_, seqlen) in enumerate(self.sets[i].items): - if j > 0: - repr_str += "," - repr_str += str(seqlen) - repr_str += "}" - repr_str += "]" - return repr_str - - sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)]) - states_pq = [] - if equal_size: - assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0" - for offset in range(0, len(sorted_seqlen_list), k_partitions): - items = [] - for i in range(k_partitions): - seqlen, idx = sorted_seqlen_list[offset + i] - items.append((idx, seqlen)) - heapq.heappush(states_pq, State(items=items, k=k_partitions)) - else: - for seqlen, idx in sorted_seqlen_list: - heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions)) - - while len(states_pq) > 1: - state0 = heapq.heappop(states_pq) - state1 = heapq.heappop(states_pq) - # merge states - state0.merge(state1) - heapq.heappush(states_pq, state0) - - final_state = states_pq[0] - partitions = final_state.get_partitions() - if equal_size: - for i, partition in enumerate(partitions): - assert len(partition) * k_partitions == len(seqlen_list), ( - f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" - ) - return partitions - - -def greedy_partition(seqlen_list: list[int], k_partitions: int, equal_size: bool): - bias = sum(seqlen_list) + 1 if equal_size else 0 - sorted_seqlen = [(seqlen + bias, i) for i, seqlen in enumerate(seqlen_list)] - partitions = [[] for _ in range(k_partitions)] - partition_sums = [0 for _ in range(k_partitions)] - for seqlen, i in sorted_seqlen: - min_idx = None - for j in range(k_partitions): - if min_idx is None or partition_sums[j] < partition_sums[min_idx]: - min_idx = j - partitions[min_idx].append(i) - partition_sums[min_idx] += seqlen - if equal_size: - for i, partition in enumerate(partitions): - assert len(partition) * k_partitions == len(seqlen_list), ( - f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" - ) - return partitions - - -def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool): - """ - Calculates partitions of indices from seqlen_list such that the sum of sequence lengths - in each partition is balanced. Uses the Karmarkar-Karp differencing method. - - This is useful for balancing workload across devices or batches, especially when - dealing with variable sequence lengths. - - Args: - seqlen_list (List[int]): A list of sequence lengths for each item. - k_partitions (int): The desired number of partitions. - equal_size (bool): If True, ensures that each partition has the same number of items. - Requires len(seqlen_list) to be divisible by k_partitions. - If False, partitions can have varying numbers of items, focusing - only on balancing the sum of sequence lengths. - - Returns: - List[List[int]]: A list containing k_partitions lists. Each inner list contains the - original indices of the items assigned to that partition. The indices - within each partition list are sorted. - - Raises: - AssertionError: If len(seqlen_list) < k_partitions. - AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions. - AssertionError: If any resulting partition is empty. - """ - assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" - - def _check_and_sort_partitions(partitions): - assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}" - seen_idx = set() - sorted_partitions = [None] * k_partitions - for i, partition in enumerate(partitions): - assert len(partition) > 0, f"the {i}-th partition is empty" - for idx in partition: - seen_idx.add(idx) - sorted_partitions[i] = sorted(partition) - assert seen_idx == set(range(len(seqlen_list))) - return sorted_partitions - - partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size) - return _check_and_sort_partitions(partitions) - - -def log_seqlen_unbalance(seqlen_list: list[int], partitions: list[list[int]], prefix): - """ - Calculate and log metrics related to sequence length imbalance before and after partitioning. - - Args: - seqlen_list (List[int]): A list of sequence lengths for each item. - partitions (List[List[int]]): A list of partitions, where each inner list contains indices - from seqlen_list assigned to that partition. - prefix (str): A prefix to be added to each metric key in the returned dictionary. - - Returns: - dict: A dictionary containing metrics related to sequence length imbalance. - """ - # Get the number of partitions - k_partition = len(partitions) - # assert len(seqlen_list) % k_partition == 0 - batch_size = len(seqlen_list) // k_partition - min_sum_seqlen = None - max_sum_seqlen = None - total_sum_seqlen = 0 - - # Iterate over each batch of sequence lengths - for offset in range(0, len(seqlen_list), batch_size): - cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size]) - if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen: - min_sum_seqlen = cur_sum_seqlen - if max_sum_seqlen is None or cur_sum_seqlen > max_sum_seqlen: - max_sum_seqlen = cur_sum_seqlen - total_sum_seqlen += cur_sum_seqlen - - balanced_sum_seqlen_list = [] - for partition in partitions: - cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition]) - balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced) - # print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list) - min_sum_seqlen_balanced = min(balanced_sum_seqlen_list) - max_sum_seqlen_balanced = max(balanced_sum_seqlen_list) - - return { - f"{prefix}/min": min_sum_seqlen, - f"{prefix}/max": max_sum_seqlen, - f"{prefix}/minmax_diff": max_sum_seqlen - min_sum_seqlen, - f"{prefix}/balanced_min": min_sum_seqlen_balanced, - f"{prefix}/balanced_max": max_sum_seqlen_balanced, - f"{prefix}/mean": total_sum_seqlen / len(partitions), - } - - -def ceildiv(a, b): - return -(a // -b) - - -def roundup_divisible(a, b): - return ((a + b - 1) // b) * b - - -def rearrange_micro_batches( - batch, - max_token_len, - dp_group=None, - num_batches_divided_by=None, - same_micro_num_in_dp=True, - min_num_micro_batch=None, - use_dynamic_bsz_balance=True, -): - """ - Split a batch into micro-batches by total token count, with optional DP sync and padding. - - Args: - batch (TensorDict): must include "attention_mask" (B*S); other fields are sliced similarly. - max_token_len (int): max sum of attention_mask per micro-batch. - dp_group (optional): torch.distributed group for data-parallel sync. - num_batches_divided_by (optional): virtual pipeline parallel size, for megatron. - same_micro_num_in_dp (bool): if True and dp_group set, pad all ranks to the same count. - min_num_micro_batch (int, optional): force at least this many splits (pads empty ones). - use_dynamic_bsz_balance (bool, optional): balance the computational workload between micro-batches - - Returns: - List[TensorDict]: the micro-batches. - List[List[int]]: index lists mapping each micro-batch back to original positions. - """ - # this is per local micro_bsz - max_seq_len = batch["attention_mask"].shape[-1] - assert max_token_len >= max_seq_len, ( - f"max_token_len must be greater than the sequence length. Got {max_token_len=} and {max_seq_len=}" - ) - seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1) - total_seqlen = seq_len_effective.sum().item() - # NOTE: num_microbatches <= batch_size, so take the min of this two. - num_micro_batches = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len)) - if min_num_micro_batch is not None: - # used to support pp - num_micro_batches = max(min_num_micro_batch, num_micro_batches) - if dist.is_initialized() and same_micro_num_in_dp: - num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name()) - dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) - num_micro_batches = num_micro_batches.cpu().item() - if num_batches_divided_by is not None: - num_micro_batches = roundup_divisible(num_micro_batches, num_batches_divided_by) - - seq_len_effective = seq_len_effective.tolist() - assert num_micro_batches <= len(seq_len_effective) - - micro_bsz_idx = get_seqlen_balanced_partitions(seq_len_effective, num_micro_batches, equal_size=False) - - if use_dynamic_bsz_balance: - # Use the sum of squared sequence lengths to approximate attention computation workload - micro_bsz_idx.sort( - key=lambda partition: ( - sum(seq_len_effective[idx] ** 2 for idx in partition), - min(partition) if partition else 0, - ), - reverse=True, - ) - - micro_batches = [] - - for partition in micro_bsz_idx: - curr_micro_batch = [] - for idx in partition: - curr_micro_batch.append(batch[idx : idx + 1]) - curr_micro_batch = torch.cat(curr_micro_batch) - - micro_batches.append(curr_micro_batch) - - return micro_batches, micro_bsz_idx - - -def get_reverse_idx(idx_map): - """ - Build the inverse of an index mapping. - - Args: - idx_map (Sequence[int]): Sequence where idx_map[i] = j. - - Returns: - List[int]: Inverse mapping list such that output[j] = i for each i. - """ - reverse_idx_map = copy.deepcopy(idx_map) - - for i, idx in enumerate(idx_map): - reverse_idx_map[idx] = i - - return reverse_idx_map - - -def prepare_dynamic_batch(data: DataProto, max_token_len: int) -> tuple[list[DataProto], list[list[int]]]: - """ - Prepare a batch for dynamic batching. - - Args: - data (DataProto): The input data. - max_token_len (int): The maximum token length for dynamic batching. - - Returns: - Tuple[List[DataProto], List[List[int]]]: A tuple containing a list of DataProto objects - and a list of index lists. - """ - batch, batch_idx_list = rearrange_micro_batches(data.batch, max_token_len=max_token_len) - micro_batches = [] - for i, batch_idx in enumerate(batch_idx_list): - tensors = dict(batch[i]) - non_tensors = {key: value[batch_idx] for key, value in data.non_tensor_batch.items()} - micro_batches.append(DataProto.from_dict(tensors, non_tensors)) - - return micro_batches, batch_idx_list - - -def restore_dynamic_batch(data: torch.Tensor, batch_idx_list: list[list[int]]) -> torch.Tensor: - """ - Restore a batch from dynamic batching. - - Args: - data (torch.Tensor): The input data. - batch_idx_list (List[List[int]]): The list of index lists. - - Returns: - torch.Tensor: The restored data. - """ - indices = list(chain.from_iterable(batch_idx_list)) - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - return data[revert_indices] diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py deleted file mode 100644 index 668ea3e14..000000000 --- a/verl/utils/tokenizer.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Utils for tokenization.""" - -import warnings - -__all__ = ["hf_tokenizer", "hf_processor"] - - -def set_pad_token_id(tokenizer): - """Set pad_token_id to eos_token_id if it is None. - - Args: - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set. - - """ - if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = tokenizer.eos_token_id - warnings.warn(f"tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}", stacklevel=1) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - warnings.warn(f"tokenizer.pad_token is None. Now set to {tokenizer.eos_token}", stacklevel=1) - - -def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs): - """Create a huggingface pretrained tokenizer which correctness handles eos and pad tokens. - - Args: - - name (str): The name of the tokenizer. - correct_pad_token (bool): Whether to correct the pad token id. - correct_gemma2 (bool): Whether to correct the gemma2 tokenizer. - - Returns: - - transformers.PreTrainedTokenizer: The pretrained tokenizer. - - """ - from transformers import AutoTokenizer - - if correct_gemma2 and isinstance(name_or_path, str) and "gemma-2-2b-it" in name_or_path: - # the EOS token in gemma2 is ambiguious, which may worsen RL performance. - # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a - warnings.warn( - "Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.", stacklevel=1 - ) - kwargs["eos_token"] = "" - kwargs["eos_token_id"] = 107 - tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) - if correct_pad_token: - set_pad_token_id(tokenizer) - return tokenizer - - -def hf_processor(name_or_path, **kwargs): - """Create a huggingface processor to process multimodal data. - - Args: - name_or_path (str): The name of the processor. - - Returns: - transformers.ProcessorMixin: The pretrained processor. - """ - from transformers import AutoProcessor - - try: - processor = AutoProcessor.from_pretrained(name_or_path, **kwargs) - except Exception as e: - processor = None - # TODO(haibin.lin): try-catch should be removed after adding transformer version req to setup.py to avoid - # silent failure - warnings.warn(f"Failed to create processor: {e}. This may affect multimodal processing", stacklevel=1) - # Avoid load tokenizer, see: - # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344 - if processor is not None and "Processor" not in processor.__class__.__name__: - processor = None - return processor diff --git a/verl/utils/torch_dtypes.py b/verl/utils/torch_dtypes.py deleted file mode 100644 index f2f445c26..000000000 --- a/verl/utils/torch_dtypes.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Adapted from Cruise. -""" - -import torch - -HALF_LIST = [16, "16", "fp16", "float16", torch.float16] -FLOAT_LIST = [32, "32", "fp32", "float32", torch.float32] -BFLOAT_LIST = ["bf16", "bfloat16", torch.bfloat16] - - -class PrecisionType: - """Type of precision used. - - >>> PrecisionType.HALF == 16 - True - >>> PrecisionType.HALF in (16, "16") - True - """ - - HALF = "16" - FLOAT = "32" - FULL = "64" - BFLOAT = "bf16" - MIXED = "mixed" - - @staticmethod - def supported_type(precision: str | int) -> bool: - return any(x == precision for x in PrecisionType) - - @staticmethod - def supported_types() -> list[str]: - return [x.value for x in PrecisionType] - - @staticmethod - def is_fp16(precision): - return precision in HALF_LIST - - @staticmethod - def is_fp32(precision): - return precision in FLOAT_LIST - - @staticmethod - def is_bf16(precision): - return precision in BFLOAT_LIST - - @staticmethod - def to_dtype(precision): - if precision in HALF_LIST: - return torch.float16 - elif precision in FLOAT_LIST: - return torch.float32 - elif precision in BFLOAT_LIST: - return torch.bfloat16 - else: - raise RuntimeError(f"unexpected precision: {precision}") - - @staticmethod - def to_str(precision): - if precision == torch.float16: - return "fp16" - elif precision == torch.float32: - return "fp32" - elif precision == torch.bfloat16: - return "bf16" - else: - raise RuntimeError(f"unexpected precision: {precision}") diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py deleted file mode 100644 index df91ad778..000000000 --- a/verl/utils/torch_functional.py +++ /dev/null @@ -1,771 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Contain small torch utilities -""" - -import math -from contextlib import contextmanager -from typing import Optional - -import torch -import torch.distributed -import torch.nn.functional as F -from tensordict import TensorDict -from torch import nn -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR -from transformers import PreTrainedTokenizer - -from verl.utils.device import get_device_name, get_torch_device - -try: - from flash_attn.ops.triton.cross_entropy import cross_entropy_loss - - FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True -except ImportError: - FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False - - -try: - import torch_npu - - NPU_CROSS_ENTROPY_LOSS_AVAILABLE = hasattr(torch_npu, "npu_cross_entropy_loss") -except ImportError: - NPU_CROSS_ENTROPY_LOSS_AVAILABLE = False - - -def gather_from_labels(data, label): - """Gather the label from data. The value in label should be [0, vocab_size) - - Args: - data: (..., vocab_size) - label (torch.IntTensor) : (...,) - - Returns: - - """ - - output = torch.gather(data, -1, label.unsqueeze(-1)).squeeze(-1) - return output - - -def logprobs_from_logits(logits, labels, inplace_backward=True): - """ - Compute per-token log-probabilities for the given labels. - - Uses a Flash-Attention–based cross-entropy (if available) for efficient backward, - otherwise falls back to a standard log-softmax+gather approach. - - See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 - - Args: - logits (Tensor): Model outputs of shape (..., vocab_size). - labels (LongTensor): True class indices of shape matching logits[..., :-1]. - inplace_backward (bool): If True and Flash-Attn is available, perform backward in-place. - - Returns: - Tensor: Log-probabilities of the target labels, shape logits.shape[:-1]. - """ - if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE: - batch_dim = logits.shape[:-1] - last_dim = logits.shape[-1] - logits = logits.reshape(-1, last_dim) - labels = labels.reshape(-1) - output = logprobs_from_logits_flash_attn(logits, labels, inplace_backward=inplace_backward) - output = output.view(*batch_dim) - elif NPU_CROSS_ENTROPY_LOSS_AVAILABLE: - output = logprobs_from_logits_torch_npu(logits, labels) - else: - output = logprobs_from_logits_v2(logits, labels) - return output - - -def logprobs_from_logits_flash_attn(logits, labels, inplace_backward=True): - output = cross_entropy_loss(logits, labels, inplace_backward=inplace_backward) - assert isinstance(output, tuple), ( - "please make sure flash-attn>=2.4.3 where cross_entropy_loss returns Tuple[losses, z_losses]." - ) - return -output[0] - - -def logprobs_from_logits_torch_npu(logits, labels): - batch_dim = logits.shape[:-1] - logits = logits.reshape(-1, logits.shape[-1]) - loss, _, _, _ = torch_npu.npu_cross_entropy_loss(logits, labels.reshape(-1), reduction="none") - return -loss.view(*batch_dim) - - -def logprobs_from_logits_naive(logits, labels): - logp = F.log_softmax(logits, dim=-1) - logpy = gather_from_labels(logp, labels) - return logpy - - -def logprobs_from_logits_v2(logits: torch.FloatTensor, labels): - """ - A memory efficient implementation of logprobs_from_logits - """ - if logits.dtype in [torch.float32, torch.float64]: - logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - # loop to reduce peak mem consumption - logsumexp_values = torch.stack([torch.logsumexp(logit, dim=-1) for logit in logits]) - logprobs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) - else: - # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach - logprobs_labels = [] - for row_logits, row_labels in zip(logits, labels, strict=True): # loop to reduce peak mem consumption - row_logprobs = F.log_softmax(row_logits, dim=-1) - row_logprobs_labels = row_logprobs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) - logprobs_labels.append(row_logprobs_labels) - logprobs_labels = torch.stack(logprobs_labels) - return logprobs_labels - - -def clip_by_value(x, tensor_min, tensor_max): - """ - Tensor extenstion to torch.clamp - https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713 - """ - clipped = torch.max(torch.min(x, tensor_max), tensor_min) - return clipped - - -def entropy_from_logits(logits: torch.Tensor): - """Calculate entropy from logits.""" - pd = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1) - return entropy - - -def entropy_from_logits_with_chunking(logits: torch.Tensor, chunk_size: int = 2048): - """Memory-efficient entropy calculation with chunking.""" - entropy = torch.zeros(logits.shape[0], device=logits.device) - for i in range(0, logits.shape[0], chunk_size): - logits_chunk = logits[i : i + chunk_size].float() - pd_chunk = torch.nn.functional.softmax(logits_chunk, dim=-1) - entropy_chunk = torch.logsumexp(logits_chunk, dim=-1) - torch.sum(pd_chunk * logits_chunk, dim=-1) - entropy[i : i + chunk_size] = entropy_chunk - return entropy - - -def masked_sum(values, mask, axis=None): - """Compute mean of tensor with a masked values.""" - # If NaNs exist out of mask, replace NaNs in values with a value that - # won't affect the sum (e.g., 0 for masked regions) - valid_values = torch.where(mask.bool(), values, 0.0) - return (valid_values * mask).sum(axis=axis) - - -def masked_mean(values, mask, axis=None): - """ - Compute the mean of `values` over elements selected by `mask`. - - Args: - values (Tensor): Input tensor. - mask (Tensor): Boolean or numeric mask of the same shape as `values`. - axis (int or tuple of int, optional): Dimension(s) along which to compute the mean. - Defaults to None (over all elements). - - Returns: - Tensor: Masked mean, with shape equal to `values` reduced over `axis`. - """ - s = masked_sum(values, mask, axis) - return s / (mask.sum(axis=axis) + 1e-8) - - -def masked_var(values, mask, unbiased=True): - """Compute variance of tensor with masked values.""" - mean = masked_mean(values, mask) - centered_values = values - mean - variance = masked_mean(centered_values**2, mask) - if unbiased: - mask_sum = mask.sum() - if mask_sum == 0: - raise ValueError("At least one element in the mask has to be 1.") - # note that if mask_sum == 1, then there is a division by zero issue - # to avoid it you just need to use a larger minibatch_size - if mask_sum == 1: - raise ValueError("The sum of the mask is one, which can cause a division by zero.") - bessel_correction = mask_sum / (mask_sum - 1) - variance = variance * bessel_correction - return variance - - -def masked_whiten(values, mask, shift_mean=True): - """ - Whiten `values` by normalizing with mean and variance computed over `mask`. - - Args: - values (torch.Tensor): Input tensor. - mask (torch.Tensor): Boolean tensor of same shape, selects elements for stats. - shift_mean (bool): If True (default), output is zero-mean; - if False, the original mean is re-added after scaling. - - Returns: - torch.Tensor: Whitened tensor of same shape as `values`. - """ - mean, var = masked_mean(values, mask), masked_var(values, mask) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -def get_response_mask(response_id: torch.Tensor, eos_token: int | list[int] = 2, dtype=torch.int64): - """ - end of sentence token can be int or list: 1 or [1, 2] - e.g. - response_id = torch.tensor([[20, 10, 34, 1, 0, 0, 0], - [78, 0, 76, 2, 1, 0, 0], - [23, 98, 1, 0, 0, 0, 0], - [33, 3, 98, 45, 1, 0, 0]]) - #eos_token=1 - response_mask: tensor([[1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0]]) - #eos_token=[1,2] - response_mask: tensor([[1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 0, 0]]) - """ - eos_mask = torch.isin(response_id, torch.tensor(eos_token, device=response_id.device)).int() - return (eos_mask.cumsum(dim=1) - eos_mask).eq(0).to(dtype) - - -def compute_grad_norm(model: nn.Module): - total_grad_square = 0 - for param in model.parameters(): - if param.grad is not None: - total_grad_square += torch.sum(torch.square(param.grad.detach())).item() - return total_grad_square - - -def broadcast_dict_tensor(tensors: dict[str, torch.Tensor] | TensorDict, src, group): - """ - TODO: optimize this. Technically, we only need one broadcast - """ - - for key in tensors.sorted_keys: - torch.distributed.broadcast(tensors[key], src=src, group=group, async_op=False) - - -def allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size, group, dim=0): - """ - TODO: optimize this. - - We can use async ops - - We can use only one allgather - Args: - tensors: - size: - group: - - Returns: - - """ - if isinstance(tensors, TensorDict): - is_tensor_dict = True - tensors_as_dict = tensors.to_dict() - else: - tensors_as_dict = tensors - is_tensor_dict = False - - output = {} - sorted_keys = sorted(tensors_as_dict.keys()) - for key in sorted_keys: - val = tensors_as_dict[key] - output[key] = [torch.empty_like(val) for _ in range(size)] - torch.distributed.all_gather(output[key], val, group=group, async_op=False) - output[key] = torch.cat(output[key], dim=dim) - - if is_tensor_dict: - output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size) - - return output - - -def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> list[TensorDict]: - assert tensors.batch_size[0] % batch_size == 0, ( - f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}" - ) - return tensors.split(batch_size) - - -def pad_2d_list_to_length(response, pad_token_id, max_length=None): - """ - pad a 2D list (e.g. responses, logprobs) to a 2D tensor. - """ - response_length = max(len(sub_list) for sub_list in response) - target_length = max_length if max_length is not None and max_length > response_length else response_length - padded_response = [tuple(sub_list) + (pad_token_id,) * (target_length - len(sub_list)) for sub_list in response] - tensor = torch.tensor(padded_response) - return tensor - - -def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False): - """ - pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length. - input shape: [bs, seq_length] - output shape: [bs, max_seq_length] - """ - if tensors.shape[-1] >= max_seq_len: - return tensors - # (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad - pad_tuple = (max_seq_len - tensors.shape[-1], 0) if left_pad else (0, max_seq_len - tensors.shape[-1]) - return F.pad(tensors, pad_tuple, "constant", pad_token_id) - - -def postprocess_data( - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - max_length: int, - pad_token_id: int, - left_pad=True, - truncation="error", -): - """Process tokenizer outputs to consistent shapes via padding/truncation. - - Args: - input_ids: Token indices [batch_size, seq_len] - attention_mask: Mask [batch_size, seq_len] - max_length: Target sequence length - pad_token_id: Padding token ID - left_pad: Pad left if True - truncation: "left", "right", "middle" or "error" - - Returns: - (input_ids, attention_mask) padded/truncated to max_length - """ - assert truncation in ["left", "right", "middle", "error"] - assert input_ids.ndim == 2 - - sequence_length = input_ids.shape[-1] - if sequence_length < max_length: - input_ids = pad_sequence_to_length( - input_ids, max_seq_len=max_length, pad_token_id=pad_token_id, left_pad=left_pad - ) - attention_mask = pad_sequence_to_length( - attention_mask, max_seq_len=max_length, pad_token_id=0, left_pad=left_pad - ) - elif sequence_length > max_length: - if truncation == "left": - # actually, left truncation may not be reasonable - input_ids = input_ids[:, -max_length:] - attention_mask = attention_mask[:, -max_length:] - elif truncation == "right": - input_ids = input_ids[:, :max_length] - attention_mask = attention_mask[:, :max_length] - elif truncation == "middle": - left_half = max_length // 2 - right_half = max_length - left_half - input_ids = torch.cat([input_ids[:, :left_half], input_ids[:, -right_half:]], dim=-1) - attention_mask = torch.cat([attention_mask[:, :left_half], attention_mask[:, -right_half:]], dim=-1) - elif truncation == "error": - raise NotImplementedError(f"{sequence_length=} is larger than {max_length=}") - else: - raise NotImplementedError(f"Unknown truncation method {truncation}") - - return input_ids, attention_mask - - -def tokenize_and_postprocess_data( - prompt: str, tokenizer: PreTrainedTokenizer, max_length: int, pad_token_id: int, left_pad=True, truncation="error" -): - """Tokenize text and process outputs to consistent tensor shapes. - - Args: - prompt: Input text to tokenize - tokenizer: HuggingFace tokenizer instance - max_length: Target sequence length - pad_token_id: Padding token ID - left_pad: Pad left if True - truncation: Truncation strategy ("left"/"right"/"error") - - Returns: - Tuple of (input_ids, attention_mask) from postprocess_data - """ - input_data = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) - input_ids = input_data["input_ids"] - attention_mask = input_data["attention_mask"] - - return postprocess_data(input_ids, attention_mask, max_length, pad_token_id, left_pad, truncation) - - -def remove_pad_token(input_ids: torch.Tensor, attention_mask: torch.Tensor): - """Remove the pad token. - - Args: - input_ids shape: [bs, seq_length] - attention_mask shape: [bs, seq_length] - Returns: - no_padding_batch(List[List[int]]): contains the rmpad token ids per query. - """ - no_padding_batch = [] - for ids, mask in zip(input_ids, attention_mask, strict=True): - no_padding_batch.append((ids[len(ids) - mask.sum() :]).cpu().numpy().tolist()) - return no_padding_batch - - -def log_probs_from_logits_response(input_ids, logits, response_length): - """Compute the response log_probs from full logits. Note that logits = model(input_ids) - - Args: - input_ids: [batch_size, seqlen] - logits: [batch_size, seqlen, vocab_size] - - Returns: - response_log_prob: - """ - response_logits = logits[:, -response_length - 1 : -1] - response = input_ids[:, -response_length:] - response_log_prob = logprobs_from_logits(logits=response_logits, labels=response) - return response_log_prob - - -def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length): - """Compute the log_probs from logits with rmpad logits and pad input. Note that - logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between - logits and input_ids. - The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive - for large vocab_size - - Args: - input_ids: [batch_size, seqlen] - attention_mask: [batch_size, seqlen] - logits_rmpad: [total_nnz, vocab_size] - response_length: int - """ - from flash_attn.bert_padding import pad_input, unpad_input - - batch_size, seqlen = input_ids.shape - input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask=attention_mask) - input_ids_rmpad = input_ids_rmpad.squeeze(-1) - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) - full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) - full_output = pad_input( - hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen - ) - output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] - return output - - -def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length): - """Compute the log_probs from logits with rmpad input_ids and logits. Note that - logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between - logits and input_ids. - The reason for this function to is to compute logprobs_from_logits in rmpad mode because it is memory-intensive - for large vocab_size - - Args: - input_ids_rmpad: [1, total_nnz] - logits_rmpad: [total_nnz, vocab_size] - indices: [total_nnz] - batch_size: int - seqlen: int - response_length: int - """ - from flash_attn.bert_padding import pad_input - - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # transpose back to [total_nnz, 1] - input_ids_rmpad = input_ids_rmpad.squeeze(-1) - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0) - full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,) - full_output = pad_input( - hidden_states=full_log_probs_rmpad.unsqueeze(-1), indices=indices, batch=batch_size, seqlen=seqlen - ) - output = full_output.squeeze(-1)[:, -response_length - 1 : -1] # [batch_size, response_length] - return output - - -def post_process_logits(input_ids, logits, temperature, top_k, top_p): - if temperature != 1.0: - logits = logits.div_(temperature) # inplace operation to avoid OOM - # TODO: add them back - # if top_k is not None and top_k > 0: - # logits = TopKLogitsWarper(top_k=top_k)(input_ids, logits) - # if top_p is not None and top_p < 1.0 and top_p > 0.0: - # logits = TopPLogitsWarper(top_p=top_p)(input_ids, logits) - return logits - - -""" -Optimizer related -""" - - -def get_cosine_schedule_with_warmup( - optimizer: Optimizer, - num_warmup_steps: int, - num_training_steps: int, - min_lr_ratio: float = 0.0, - num_cycles: float = 0.5, - last_epoch: int = -1, -): - """ - Create a schedule with a learning rate that decreases following the values of the cosine function between the - initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the - initial lr set in the optimizer. - Args: - optimizer (:class:`~torch.optim.Optimizer`): - The optimizer for which to schedule the learning rate. - num_warmup_steps (:obj:`int`): - The number of steps for the warmup phase. - num_training_steps (:obj:`int`): - The total number of training steps. - min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0): - The minimum lr ratio w.r.t the maximum. - num_cycles (:obj:`float`, `optional`, defaults to 0.5): - The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 - following a half-cosine). - last_epoch (:obj:`int`, `optional`, defaults to -1): - The index of the last epoch when resuming training. - Return: - :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - min_lr_ratio = 0.0 if min_lr_ratio is None else min_lr_ratio - assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0 - coef = (1 - min_lr_ratio) * 0.5 - intercept = (1 + min_lr_ratio) * 0.5 - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return min_lr_ratio + (1.0 - min_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps))) - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - x = math.cos(math.pi * float(num_cycles) * 2.0 * progress) - return max(min_lr_ratio, x * coef + intercept) - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def get_constant_schedule_with_warmup( - optimizer: Optimizer, - num_warmup_steps: int, - last_epoch: int = -1, -): - """ - Create a constant LR schedule with a linear warmup phase. - - Args: - optimizer (Optimizer): Wrapped optimizer. - num_warmup_steps (int): Number of steps to ramp up the LR from 0 to initial value. - last_epoch (int, optional): The index of the last epoch when resuming training. Defaults to -1. - - Returns: - LambdaLR: Scheduler that increases LR linearly during warmup, then holds it constant. - """ - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1.0, num_warmup_steps)) - return 1.0 - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -def get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def get_wsd_schedule_with_warmup( - optimizer: Optimizer, - num_warmup_steps: int, - num_training_steps: int, - min_lr_ratio: float = 0.0, - num_cycles: float = 0.5, - last_epoch: int = -1, - stable_ratio: float = 0.9, -): - """ - Create a Warmup-Stable-Decay learning rate scheduler. - - The schedule follows three phases: - 1. Warmup: Learning rate increases linearly from 0 to the initial LR - 2. Stable: Learning rate remains constant at the initial LR - 3. Decay: Learning rate decreases following a cosine curve to min_lr_ratio * initial LR - - Args: - optimizer (:class:`~torch.optim.Optimizer`): - The optimizer for which to schedule the learning rate. - num_warmup_steps (:obj:`int`): - The number of steps for the warmup phase. - num_training_steps (:obj:`int`): - The total number of training steps. - min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0): - The minimum learning rate ratio w.r.t the initial learning rate. - num_cycles (:obj:`float`, `optional`, defaults to 0.5): - The number of waves in the cosine schedule during decay phase. - last_epoch (:obj:`int`, `optional`, defaults to -1): - The index of the last epoch when resuming training. - stable_ratio (:obj:`float`, `optional`, defaults to 0.0): - The ratio of non-warmup steps that should maintain a constant learning rate. - Set to 0.0 to behave exactly like cosine schedule. - - Return: - :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - remaining_steps = max(0, num_training_steps - num_warmup_steps) - num_stable_steps = int(remaining_steps * stable_ratio) - num_decay_steps = remaining_steps - num_stable_steps - - def lr_lambda(current_step): - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - if current_step < num_warmup_steps + num_stable_steps: - return 1.0 - if current_step < num_training_steps: - progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps)) - value = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) - return (1.0 - min_lr_ratio) * value + min_lr_ratio - return min_lr_ratio - - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -@contextmanager -def check_device_is_available(): - """ - Some modules must be imported after CUDA is initialized. Such as sglang's sharding manager. - - This context manager checks if CUDA is available and raises an error if it is not. - """ - if not get_torch_device().is_available(): - raise RuntimeError("Device {} must be initialized before importing this module.".format(get_device_name())) - - yield - - -def distributed_mean_max_min_std(local_tensor, compute_max=True, compute_min=True, compute_std=True): - """Compute distributed statistics across all processes. - - Args: - local_tensor: Tensor containing local values - compute_max: Include maximum value calculation - compute_min: Include minimum value calculation - compute_std: Include standard deviation calculation - - Returns: - Tuple containing (mean, max, min, std) in this order. None for disabled metrics. - """ - # Sum the local tensor across all processes - local_sum = torch.sum(local_tensor) - local_num = torch.tensor(torch.numel(local_tensor), device=get_device_name()) - - torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM) - - global_mean = local_sum / local_num - - if compute_max: - local_max = torch.max(local_tensor) - torch.distributed.all_reduce(local_max, op=torch.distributed.ReduceOp.MAX) - else: - local_max = None - - if compute_min: - local_min = torch.min(local_tensor) - torch.distributed.all_reduce(local_min, op=torch.distributed.ReduceOp.MIN) - else: - local_min = None - - if compute_std: - square_diff = torch.sum(torch.pow(local_tensor - global_mean, 2)) - torch.distributed.all_reduce(square_diff, op=torch.distributed.ReduceOp.SUM) - global_std = torch.sqrt(square_diff / (local_num - 1)) - else: - global_std = None - - return global_mean, local_max, local_min, global_std - - -def distributed_masked_mean(local_tensor, local_mask): - """Compute global mean of non-masked elements across distributed processes. - - Args: - local_tensor (torch.Tensor): Input tensor with local values - local_mask (torch.Tensor): Binary mask (1=valid, 0=ignore) matching local_tensor shape - - Returns: - torch.Tensor: Global mean of all valid elements across processes - """ - local_tensor = local_tensor * local_mask - - local_sum = torch.sum(local_tensor) - local_num = torch.sum(local_mask) - - torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce(local_num, op=torch.distributed.ReduceOp.SUM) - - global_mean = local_sum / local_num - return global_mean diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py deleted file mode 100644 index 877368650..000000000 --- a/verl/utils/tracking.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -A unified tracking interface that supports logging data to different backend -""" - -import dataclasses -import os -from enum import Enum -from functools import partial -from pathlib import Path -from typing import Any - - -class Tracking: - """A unified tracking interface for logging experiment data to multiple backends. - - This class provides a centralized way to log experiment metrics, parameters, and artifacts - to various tracking backends including WandB, MLflow, SwanLab, TensorBoard, and console. - - Attributes: - supported_backend: List of supported tracking backends. - logger: Dictionary of initialized logger instances for each backend. - """ - - supported_backend = ["wandb", "mlflow", "swanlab", "vemlp_wandb", "tensorboard", "console", "clearml"] - - def __init__(self, project_name, experiment_name, default_backend: str | list[str] = "console", config=None): - if isinstance(default_backend, str): - default_backend = [default_backend] - for backend in default_backend: - if backend == "tracking": - import warnings - - warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning, stacklevel=2) - else: - assert backend in self.supported_backend, f"{backend} is not supported" - - self.logger = {} - - if "tracking" in default_backend or "wandb" in default_backend: - import wandb - - settings = None - if config and config["trainer"].get("wandb_proxy", None): - settings = wandb.Settings(https_proxy=config["trainer"]["wandb_proxy"]) - wandb.init(project=project_name, name=experiment_name, config=config, settings=settings) - self.logger["wandb"] = wandb - - if "mlflow" in default_backend: - import os - - import mlflow - - MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "sqlite:////tmp/mlruns.db") - mlflow.set_tracking_uri(MLFLOW_TRACKING_URI) - - # Project_name is actually experiment_name in MLFlow - # If experiment does not exist, will create a new experiment - experiment = mlflow.set_experiment(project_name) - mlflow.start_run(experiment_id=experiment.experiment_id, run_name=experiment_name) - mlflow.log_params(_compute_mlflow_params_from_objects(config)) - self.logger["mlflow"] = _MlflowLoggingAdapter() - - if "swanlab" in default_backend: - import os - - import swanlab - - SWANLAB_API_KEY = os.environ.get("SWANLAB_API_KEY", None) - SWANLAB_LOG_DIR = os.environ.get("SWANLAB_LOG_DIR", "swanlog") - SWANLAB_MODE = os.environ.get("SWANLAB_MODE", "cloud") - if SWANLAB_API_KEY: - swanlab.login(SWANLAB_API_KEY) # NOTE: previous login information will be overwritten - - if config is None: - config = {} # make sure config is not None, otherwise **config will raise error - swanlab.init( - project=project_name, - experiment_name=experiment_name, - config={"FRAMEWORK": "verl", **config}, - logdir=SWANLAB_LOG_DIR, - mode=SWANLAB_MODE, - ) - self.logger["swanlab"] = swanlab - - if "vemlp_wandb" in default_backend: - import os - - import volcengine_ml_platform - from volcengine_ml_platform import wandb as vemlp_wandb - - volcengine_ml_platform.init( - ak=os.environ["VOLC_ACCESS_KEY_ID"], - sk=os.environ["VOLC_SECRET_ACCESS_KEY"], - region=os.environ["MLP_TRACKING_REGION"], - ) - - vemlp_wandb.init( - project=project_name, - name=experiment_name, - config=config, - sync_tensorboard=True, - ) - self.logger["vemlp_wandb"] = vemlp_wandb - - if "tensorboard" in default_backend: - self.logger["tensorboard"] = _TensorboardAdapter(project_name, experiment_name) - - if "console" in default_backend: - from verl.utils.logger import LocalLogger - - self.console_logger = LocalLogger(print_to_console=True) - self.logger["console"] = self.console_logger - - if "clearml" in default_backend: - self.logger["clearml"] = ClearMLLogger(project_name, experiment_name, config) - - def log(self, data, step, backend=None): - for default_backend, logger_instance in self.logger.items(): - if backend is None or default_backend in backend: - logger_instance.log(data=data, step=step) - - def __del__(self): - if "wandb" in self.logger: - self.logger["wandb"].finish(exit_code=0) - if "swanlab" in self.logger: - self.logger["swanlab"].finish() - if "vemlp_wandb" in self.logger: - self.logger["vemlp_wandb"].finish(exit_code=0) - if "tensorboard" in self.logger: - self.logger["tensorboard"].finish() - - if "clearnml" in self.logger: - self.logger["clearnml"].finish() - - -class ClearMLLogger: - def __init__(self, project_name: str, experiment_name: str, config): - self.project_name = project_name - self.experiment_name = experiment_name - - import clearml - - self._task: clearml.Task = clearml.Task.init( - task_name=experiment_name, - project_name=project_name, - continue_last_task=True, - output_uri=False, - ) - - self._task.connect_configuration(config, name="Hyperparameters") - - def _get_logger(self): - return self._task.get_logger() - - def log(self, data, step): - import numpy as np - import pandas as pd - - # logs = self._rewrite_logs(data) - logger = self._get_logger() - for k, v in data.items(): - title, series = k.split("/", 1) - - if isinstance(v, int | float | np.floating | np.integer): - logger.report_scalar( - title=title, - series=series, - value=v, - iteration=step, - ) - elif isinstance(v, pd.DataFrame): - logger.report_table( - title=title, - series=series, - table_plot=v, - iteration=step, - ) - else: - logger.warning( - f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}". This ' - f"invocation of ClearML logger's function is incorrect so this attribute was dropped. " - ) - - def finish(self): - self._task.mark_completed() - - -class _TensorboardAdapter: - def __init__(self, project_name, experiment_name): - import os - - from torch.utils.tensorboard import SummaryWriter - - tensorboard_dir = os.environ.get("TENSORBOARD_DIR", f"tensorboard_log/{project_name}/{experiment_name}") - os.makedirs(tensorboard_dir, exist_ok=True) - print(f"Saving tensorboard log to {tensorboard_dir}.") - self.writer = SummaryWriter(tensorboard_dir) - - def log(self, data, step): - for key in data: - self.writer.add_scalar(key, data[key], step) - - def finish(self): - self.writer.close() - - -class _MlflowLoggingAdapter: - def log(self, data, step): - import mlflow - - results = {k.replace("@", "_at_"): v for k, v in data.items()} - mlflow.log_metrics(metrics=results, step=step) - - -def _compute_mlflow_params_from_objects(params) -> dict[str, Any]: - if params is None: - return {} - - return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep="/") - - -def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): - _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict) - - if dataclasses.is_dataclass(x): - return _transform(dataclasses.asdict(x)) - if isinstance(x, dict): - return {k: _transform(v) for k, v in x.items()} - if isinstance(x, list): - if convert_list_to_dict: - return {"list_len": len(x)} | {f"{i}": _transform(v) for i, v in enumerate(x)} - else: - return [_transform(v) for v in x] - if isinstance(x, Path): - return str(x) - if isinstance(x, Enum): - return x.value - - return x - - -def _flatten_dict(raw: dict[str, Any], *, sep: str) -> dict[str, Any]: - import pandas as pd - - ans = pd.json_normalize(raw, sep=sep).to_dict(orient="records")[0] - assert isinstance(ans, dict) - return ans - - -@dataclasses.dataclass -class ValidationGenerationsLogger: - def log(self, loggers, samples, step): - if "wandb" in loggers: - self.log_generations_to_wandb(samples, step) - if "swanlab" in loggers: - self.log_generations_to_swanlab(samples, step) - if "mlflow" in loggers: - self.log_generations_to_mlflow(samples, step) - - if "clearml" in loggers: - self.log_generations_to_clearml(samples, step) - if "tensorboard" in loggers: - self.log_generations_to_tensorboard(samples, step) - - if "vemlp_wandb" in loggers: - self.log_generations_to_vemlp_wandb(samples, step) - - def log_generations_to_vemlp_wandb(self, samples, step): - from volcengine_ml_platform import wandb as vemlp_wandb - - self._log_generations_to_wandb(samples, step, vemlp_wandb) - - def log_generations_to_wandb(self, samples, step): - import wandb - - self._log_generations_to_wandb(samples, step, wandb) - - def _log_generations_to_wandb(self, samples, step, wandb): - """Log samples to wandb as a table""" - - # Create column names for all samples - columns = ["step"] + sum( - [[f"input_{i + 1}", f"output_{i + 1}", f"score_{i + 1}"] for i in range(len(samples))], [] - ) - - if not hasattr(self, "validation_table"): - # Initialize the table on first call - self.validation_table = wandb.Table(columns=columns) - - # Create a new table with same columns and existing data - # Workaround for https://github.com/wandb/wandb/issues/2981#issuecomment-1997445737 - new_table = wandb.Table(columns=columns, data=self.validation_table.data) - - # Add new row with all data - row_data = [] - row_data.append(step) - for sample in samples: - row_data.extend(sample) - - new_table.add_data(*row_data) - - # Update reference and log - wandb.log({"val/generations": new_table}, step=step) - self.validation_table = new_table - - def log_generations_to_swanlab(self, samples, step): - """Log samples to swanlab as text""" - import swanlab - - swanlab_table = swanlab.echarts.Table() - - # Create column names - headers = ["step", "input", "output", "score"] - - swanlab_row_list = [[step, *sample] for sample in samples] - swanlab_table.add(headers=headers, rows=swanlab_row_list) - - # Log to swanlab - swanlab.log({"val/generations": swanlab_table}, step=step) - - def log_generations_to_mlflow(self, samples, step): - """Log validation generation to mlflow as artifacts""" - # https://mlflow.org/docs/latest/api_reference/python_api/mlflow.html?highlight=log_artifact#mlflow.log_artifact - - import json - import tempfile - - import mlflow - - try: - with tempfile.TemporaryDirectory() as tmp_dir: - validation_gen_step_file = Path(tmp_dir, f"val_step{step}.json") - row_data = [] - for sample in samples: - data = {"input": sample[0], "output": sample[1], "score": sample[2]} - row_data.append(data) - with open(validation_gen_step_file, "w") as file: - json.dump(row_data, file) - mlflow.log_artifact(validation_gen_step_file) - except Exception as e: - print(f"WARNING: save validation generation file to mlflow failed with error {e}") - - def log_generations_to_clearml(self, samples, step): - """Log validation generation to clearml as table""" - - import clearml - import pandas as pd - - task: clearml.Task | None = clearml.Task.current_task() - if task is None: - return - - table = [ - { - "step": step, - "input": sample[0], - "output": sample[1], - "score": sample[2], - } - for sample in samples - ] - - logger = task.get_logger() - logger.report_table( - series="Validation generations", - title="Validation", - table_plot=pd.DataFrame.from_records(table), - iteration=step, - ) - - def log_generations_to_tensorboard(self, samples, step): - """Log samples to tensorboard as text""" - # Initialize tensorboard writer if not exists - if not hasattr(self, "writer"): - from torch.utils.tensorboard import SummaryWriter - - tensorboard_dir = os.environ.get("TENSORBOARD_DIR", "tensorboard_log") - os.makedirs(tensorboard_dir, exist_ok=True) - self.writer = SummaryWriter(log_dir=tensorboard_dir) - - # Format the samples data into readable text - text_content = f"**Generation Results - Step {step}**\n\n" - - for i, sample in enumerate(samples): - text_content += f"### Sample {i + 1}\n" - - # Assuming sample contains [input, output, score] - if len(sample) >= 3: - input_text, output_text, score = sample[0], sample[1], sample[2] - - text_content += f"**Input:** {input_text}\n\n" - text_content += f"**Output:** {output_text}\n\n" - text_content += f"**Score:** {score}\n\n" - else: - # Handle cases where sample format might be different - text_content += f"**Data:** {sample}\n\n" - - text_content += "---\n\n" - - # Log to tensorboard as text - self.writer.add_text("val/generations", text_content, step) - # Flush to ensure data is written - self.writer.flush() diff --git a/verl/utils/ulysses.py b/verl/utils/ulysses.py deleted file mode 100644 index 1669f6f32..000000000 --- a/verl/utils/ulysses.py +++ /dev/null @@ -1,328 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Utilities for DeepSpeed Ulysses Sequence Parallelism. -DeepSpeed Ulysses Paper: https://arxiv.org/abs/2309.14509 -Inspired from: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/sequence/layer.py -""" - -from typing import Any, Optional - -import torch -import torch.distributed as dist -from torch import Tensor -from torch.distributed import ProcessGroup - -_ULYSSES_SEQUENCE_PARALLEL_GROUP = None - - -def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup): - """ - Set ulysses sequence parallel process group. - """ - global _ULYSSES_SEQUENCE_PARALLEL_GROUP - _ULYSSES_SEQUENCE_PARALLEL_GROUP = group - - -def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]: - """ - Get ulysses sequence parallel process group. - """ - global _ULYSSES_SEQUENCE_PARALLEL_GROUP - return _ULYSSES_SEQUENCE_PARALLEL_GROUP - - -def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int: - """ - Get ulysses sequence parallel world size. - """ - group = get_ulysses_sequence_parallel_group() if group is None else group - return dist.get_world_size(group) if group else 1 - - -def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int: - """ - Get ulysses sequence parallel rank. - """ - group = get_ulysses_sequence_parallel_group() if group is None else group - return dist.get_rank(group) if group else 0 - - -def gather_seq_scatter_heads( - x: Tensor, - seq_dim: int, - head_dim: int, - unpadded_dim_size: int = 0, - group: ProcessGroup = None, -) -> Tensor: - """ - A func to sync embedding input with alltoall in sequence parallel - gather sequence dimension and scatter head dim: - e.g. seq_dim: 1, head_dim: 2 - [bsz, seq/n, h, ...] -> [bsz, seq, h/n, ...] - """ - group = get_ulysses_sequence_parallel_group() if group is None else group - if not group: - return x - sp_world = get_ulysses_sequence_parallel_world_size(group) - x = SeqAllToAll.apply(group, x, head_dim, seq_dim) - if unpadded_dim_size and unpadded_dim_size % sp_world != 0: - padding_size = x.size(seq_dim) - unpadded_dim_size - x = _unpad_tensor(x, seq_dim, padding_size) - return x - - -def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor: - """ - A func to sync attention result with alltoall in sequence parallel - gather head dimension and scatter seq dim: - e.g. seq_dim: 1, head_dim: 2 - [bsz, seq, h/n, ...] -> [bsz, seq/n, h, ...] - """ - group = get_ulysses_sequence_parallel_group() if group is None else group - if not group: - return x - dim_size = x.size(seq_dim) - sp_world = get_ulysses_sequence_parallel_world_size(group) - if dim_size % sp_world != 0: - padding_size = sp_world - (dim_size % sp_world) - x = _pad_tensor(x, seq_dim, padding_size) - return SeqAllToAll.apply(group, x, seq_dim, head_dim, False) - - -def _pad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: - shape = list(x.shape) - shape[dim] = padding_size - pad = torch.zeros(shape, dtype=x.dtype, device=x.device) - return torch.cat([x, pad], dim=dim) - - -def _unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor: - slc = [slice(None)] * len(x.shape) - slc[dim] = slice(0, -padding_size) - return x[slc] - - -def slice_input_tensor(x: Tensor, dim: int, padding: bool = True, group: ProcessGroup = None) -> Tensor: - group = get_ulysses_sequence_parallel_group() if group is None else group - sp_world_size = dist.get_world_size(group) - sp_rank = get_ulysses_sequence_parallel_rank() - dim_size = x.size(dim) - # pad before slice - if padding and dim_size % sp_world_size: - padding_size = sp_world_size - (dim_size % sp_world_size) - x = _pad_tensor(x, dim, padding_size) - # slice the input tensor - parts = x.size(dim) // sp_world_size - slc = [slice(None)] * len(x.shape) - slc[dim] = slice(sp_rank * parts, (sp_rank + 1) * parts) - return x[slc].contiguous() - - -def all_to_all_tensor( - local_input: Tensor, - scatter_dim: int, - gather_dim: int, - group: Optional[dist.ProcessGroup] = None, - async_op: bool = False, -): - group = get_ulysses_sequence_parallel_group() if group is None else group - seq_world_size = dist.get_world_size(group) - input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)] - output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] - comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op) - if async_op: - - def wait(): - comm.wait() - return torch.cat(output_list, dim=gather_dim).contiguous() - - return wait - return torch.cat(output_list, dim=gather_dim).contiguous() - - -def all_gather_tensor(local_tensor: Tensor, group: Optional[dist.ProcessGroup] = None, async_op: bool = False): - group = get_ulysses_sequence_parallel_group() if group is None else group - sp_world_size = dist.get_world_size(group=group) - output_shape = list(local_tensor.shape) - output_shape[0] = output_shape[0] * sp_world_size - output = torch.empty(output_shape, dtype=local_tensor.dtype, device=local_tensor.device) - dist.all_gather_into_tensor(output, local_tensor, group=group, async_op=async_op) - return output - - -class SeqAllToAll(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - group: dist.ProcessGroup, - local_input: Tensor, - scatter_dim: int, - gather_dim: int, - async_op: bool = False, - ) -> Tensor: - ctx.group = group - ctx.scatter_dim = scatter_dim - ctx.gather_dim = gather_dim - ctx.async_op = async_op - return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op) - - @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]: - input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous() if ctx.async_op else grad_output[0] - return ( - None, - all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False), - None, - None, - None, - None, - ) - - -class Gather(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - group: dist.ProcessGroup, - local_tensor: Tensor, - gather_dim: int, - grad_scaler: bool = True, - async_op=False, - ) -> Tensor: - ctx.group = group - ctx.gather_dim = gather_dim - ctx.grad_scaler = grad_scaler - ctx.async_op = async_op - - sp_world_size = dist.get_world_size(group=group) - ctx.sp_world_size = sp_world_size - - sp_rank = dist.get_rank(group=group) - ctx.sp_rank = sp_rank - - local_shape = list(local_tensor.size()) - split_size = local_shape[0] - part_size = local_shape[gather_dim] # store original size - ctx.part_size = part_size - - output = all_gather_tensor(local_tensor, group, async_op) - return torch.cat(output.split(split_size, dim=0), dim=gather_dim) - - @staticmethod - def backward(ctx: Any, grad_output: Tensor) -> Any: - if ctx.grad_scaler: - grad_output = grad_output * ctx.sp_world_size - return ( - None, - grad_output.split(ctx.part_size, dim=ctx.gather_dim)[ctx.sp_rank].contiguous(), - None, - None, - None, - None, - ) - - -def gather_outpus_and_unpad(*args, **kwargs): - raise RuntimeError( - "please use verl.utils.ulysses.gather_outputs_and_unpad instead of verl.utils.ulysses.gather_outpus_and_unpad" - ) - - -def gather_outputs_and_unpad( - x: Tensor, - gather_dim: int, - unpad_dim: int = None, - padding_size: int = 0, - grad_scaler: bool = True, - group: Optional[dist.ProcessGroup] = None, -): - """ - Gather a tensor across a process group and optionally unpad its padded elements. - - Args: - x (Tensor): Input tensor to gather. - gather_dim (int): Dimension along which to gather across ranks. - unpad_dim (int, optional): Dimension from which to remove padding. If None, no unpadding. - padding_size (int): Number of padding elements to remove on `unpad_dim`. Defaults to 0. - grad_scaler (bool): Whether to apply gradient scaling during gather. Defaults to True. - group (ProcessGroup, optional): Process group for gathering. If None, uses - `get_ulysses_sequence_parallel_group()`. If still None, returns `x` unchanged. - - Returns: - Tensor: The gathered tensor, with padding removed if requested. - """ - group = get_ulysses_sequence_parallel_group() if group is None else group - if group is None: - return x - x = Gather.apply(group, x, gather_dim, grad_scaler) - if unpad_dim is not None: - assert isinstance(padding_size, int), "padding size is not given or is not an integer" - if padding_size == 0: - return x - x = _unpad_tensor(x, unpad_dim, padding_size) - return x - - -def ulysses_pad(input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1): - if position_ids_rmpad is not None: - assert position_ids_rmpad.size(-2) == 1 - assert input_ids_rmpad.size(-1) == position_ids_rmpad.size(-1) - if sp_size <= 1: - return input_ids_rmpad, position_ids_rmpad, 0 - _, total_seq_len = input_ids_rmpad.shape - pad_size = (sp_size - total_seq_len % sp_size) % sp_size - if pad_size > 0: - input_ids_rmpad = torch.nn.functional.pad(input_ids_rmpad, (0, pad_size), value=0) - if position_ids_rmpad is not None: - pad_pos_ids = torch.arange(pad_size, device=position_ids_rmpad.device).unsqueeze(0) - if position_ids_rmpad.dim() == 3: - pad_pos_ids = pad_pos_ids.unsqueeze(0).repeat(3, 1, 1) - position_ids_rmpad = torch.cat((position_ids_rmpad, pad_pos_ids), dim=-1) - return input_ids_rmpad, position_ids_rmpad, pad_size - - -def ulysses_pad_and_slice_inputs( - input_ids_rmpad: torch.Tensor, position_ids_rmpad: Optional[torch.Tensor] = None, sp_size: int = 1 -): - """ - Pad and slice input_ids to be divisible by sp_size - Pad position_ids to be divisible by sp_size. - - Note both input_ids_rmpad and position_ids_rmpad will be padded and sliced. - - The is the utility of pre-forward for ulysses sequence parallelism - - Args: - input_ids_rmpad: shape of [bsz, seqlen] - position_ids_rmpad: shape of [bsz, seqlen], where bsz must be 1 - sp_size (int): ulysses sequence parallelism size - - Returns: - torch.Tensor: padded and sliced input_ids - torch.Tensor: padded and sliced position_ids - int: pad size - """ - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad(input_ids_rmpad, position_ids_rmpad, sp_size) - input_ids_rmpad = slice_input_tensor(input_ids_rmpad, dim=1, padding=False) - if position_ids_rmpad is not None: - position_ids_rmpad = slice_input_tensor(position_ids_rmpad, dim=1, padding=False) - return input_ids_rmpad, position_ids_rmpad, pad_size - - -def validate_ulysses_config(num_heads, ulysses_sequence_size): - if ulysses_sequence_size > 1: - assert num_heads % ulysses_sequence_size == 0, ( - f"num_heads ({num_heads}) must be divisible by ulysses sequence size({ulysses_sequence_size})" - ) diff --git a/verl/utils/vllm_utils.py b/verl/utils/vllm_utils.py deleted file mode 100644 index 25ee6656d..000000000 --- a/verl/utils/vllm_utils.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from msgspec import field -from packaging import version as vs -from vllm.lora.models import LoRAModel -from vllm.lora.request import LoRARequest -from vllm.lora.utils import get_adapter_absolute_path -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager - -from verl.third_party.vllm import get_version - -# To support different vLLM versions, we add the model into SUPPORTED_MOE_MODELS separately to avoid triggering -# unsupported issues. -SUPPORTED_MOE_MODELS = [] - -try: - from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM - - SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM) - SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM) -except ImportError: - pass - -try: - from vllm.model_executor.models.mixtral import MixtralForCausalLM - - SUPPORTED_MOE_MODELS.append(MixtralForCausalLM) -except ImportError: - pass - -try: - from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM - - SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM) -except ImportError: - pass - -try: - from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM - - SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM) -except ImportError: - pass - -try: - from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration - - SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration) -except ImportError: - pass - - -def patch_vllm_moe_model_weight_loader(model): - # this is a work around to load the weight of vllm fused moe model - # it is from a bug from vllm 0.8.2 - # all the weights are supposed to have a weight_loader, but the moe weights - # do not have a weight_loader, so we need to patch it - # (True, 'model.embed_tokens.weight') - # (True, 'model.layers.0.self_attn.qkv_proj.weight') - # (True, 'model.layers.0.self_attn.qkv_proj.bias') - # (True, 'model.layers.0.self_attn.o_proj.weight') - # (True, 'model.layers.0.mlp.gate.weight') - # (True, 'model.layers.0.mlp.shared_expert.gate_up_proj.weight') - # (True, 'model.layers.0.mlp.shared_expert.down_proj.weight') - # (False, 'model.layers.0.mlp.shared_expert_gate.weight') use default - # (False, 'model.layers.0.input_layernorm.weight') use default - # (False, 'model.layers.0.post_attention_layernorm.weight') use default - # (False, 'model.layers.0.mlp.experts.w13_weight') use mlp.experts.weight_loader - # (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader - - # Define MLP attribute mapping for different model types - MLP_ATTR_MAPPING = { - MixtralForCausalLM: "block_sparse_moe", - } - DEFAULT_MLP_ATTR = "mlp" - - if not isinstance(model, tuple(SUPPORTED_MOE_MODELS)): - return - - model = getattr(model, "model", None) or getattr(model, "language_model", None) - if model is None: - raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.") - - for layer in model.layers: - mlp_attr = MLP_ATTR_MAPPING.get(type(model), DEFAULT_MLP_ATTR) - mlp = getattr(layer, mlp_attr) - - param_dict = dict(mlp.named_parameters()) - for name, param in param_dict.items(): - if "w13_weight" in name or "w2_weight" in name: - param.weight_loader = mlp.experts.weight_loader - - -class TensorLoRARequest(LoRARequest): - peft_config: dict = field(default=None) - lora_tensors: dict = field(default=None) - - -class VLLMHijack: - @staticmethod - def hijack(): - def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel: - """ - based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors - - Reason: - VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths. - To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to - load memory-based LoRA tensors. - """ - try: - supported_lora_modules = self._adapter_manager.supported_lora_modules - packed_modules_mapping = self._adapter_manager.packed_modules_mapping - expected_lora_modules: list[str] = [] - for module in supported_lora_modules: - if module in packed_modules_mapping: - expected_lora_modules.extend(packed_modules_mapping[module]) - else: - expected_lora_modules.append(module) - - expected_lora_modules = list(set(expected_lora_modules)) - - lora_tensors = None - from vllm.lora.peft_helper import PEFTHelper - - if isinstance(lora_request, TensorLoRARequest): - peft_config = lora_request.peft_config - lora_tensors = lora_request.lora_tensors - peft_helper = PEFTHelper.from_dict(peft_config) - else: - lora_path = get_adapter_absolute_path(lora_request.lora_path) - - peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings) - - # Validates the LoRA configuration against requirements before - # loading weights, throwing an exception if validation fails. - peft_helper.validate_legal(self.lora_config) - - # For some models like Qwen2VL, we need to use hf_to_vllm_mapper - # to ensure correct loading of lora weights. - model = self._adapter_manager.model - hf_to_vllm_mapper = None - if hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None: - hf_to_vllm_mapper = model.hf_to_vllm_mapper - - if isinstance(lora_request, TensorLoRARequest): - lora = self._lora_model_cls.from_lora_tensors( - lora_model_id=lora_request.lora_int_id, - tensors=lora_tensors, - peft_helper=peft_helper, - device="cpu", - dtype=self.lora_config.lora_dtype, - embeddings=None, - target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, - embedding_modules=self.embedding_modules, - embedding_padding_modules=self.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper, - ) - else: - lora = self._lora_model_cls.from_local_checkpoint( - lora_path, - expected_lora_modules, - peft_helper=peft_helper, - lora_model_id=lora_request.lora_int_id, - device="cpu", - dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size, - embedding_modules=self.embedding_modules, - embedding_padding_modules=self.embedding_padding_modules, - weights_mapper=hf_to_vllm_mapper, - ) - except Exception as e: - raise e - - if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: - raise ValueError( - f"LoRA added vocab size {lora.extra_vocab_size} is greater than lora_extra_vocab_size " - f"{self.lora_config.lora_extra_vocab_size}." - ) - return lora - - def do_hijack(target_cls, target_method_name, hooking_method): - setattr(target_cls, target_method_name, hooking_method) - - do_hijack(LRUCacheWorkerLoRAManager, "_load_adapter", hijack__load_adapter) - - -def is_version_ge(pkg: str = "vllm", minver: str = "0.7.3"): - """check if the package version is greater than or equal to the minimum version""" - return vs.parse(get_version(pkg)) >= vs.parse(minver) diff --git a/verl/version/version b/verl/version/version deleted file mode 100644 index 04c4d903d..000000000 --- a/verl/version/version +++ /dev/null @@ -1 +0,0 @@ -0.4.1.dev diff --git a/verl/workers/__init__.py b/verl/workers/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/workers/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/workers/actor/__init__.py b/verl/workers/actor/__init__.py deleted file mode 100644 index 7a1404e17..000000000 --- a/verl/workers/actor/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .base import BasePPOActor -from .dp_actor import DataParallelPPOActor - -__all__ = ["BasePPOActor", "DataParallelPPOActor"] diff --git a/verl/workers/actor/base.py b/verl/workers/actor/base.py deleted file mode 100644 index 2d1ba290d..000000000 --- a/verl/workers/actor/base.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -The base class for Actor -""" - -from abc import ABC, abstractmethod - -import torch - -from verl import DataProto - -__all__ = ["BasePPOActor"] - - -class BasePPOActor(ABC): - def __init__(self, config): - """The base class for PPO actor - - Args: - config (DictConfig): a config passed to the PPOActor. We expect the type to be - DictConfig (https://omegaconf.readthedocs.io/), but it can be any namedtuple in general. - """ - super().__init__() - self.config = config - - @abstractmethod - def compute_log_prob(self, data: DataProto) -> torch.Tensor: - """Compute logits given a batch of data. - - Args: - data (DataProto): a batch of data represented by DataProto. It must contain key ```input_ids```, - ```attention_mask``` and ```position_ids```. - - Returns: - DataProto: a DataProto containing the key ```log_probs``` - - - """ - pass - - @abstractmethod - def update_policy(self, data: DataProto) -> dict: - """Update the policy with an iterator of DataProto - - Args: - data (DataProto): an iterator over the DataProto that returns by - ```make_minibatch_iterator``` - - Returns: - Dict: a dictionary contains anything. Typically, it contains the statistics during updating the model - such as ```loss```, ```grad_norm```, etc,. - - """ - pass diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py deleted file mode 100644 index d5cea3620..000000000 --- a/verl/workers/actor/dp_actor.py +++ /dev/null @@ -1,486 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Single Process Actor -""" - -import logging -import os - -import torch -from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -import verl.utils.torch_functional as verl_F -from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty -from verl.utils.device import get_device_name, is_cuda_available, is_npu_available -from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ -from verl.utils.profiler import GPUMemoryLogger -from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch -from verl.utils.torch_functional import logprobs_from_logits -from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad, ulysses_pad_and_slice_inputs -from verl.workers.actor import BasePPOActor - -if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input -elif is_npu_available: - from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input - - -__all__ = ["DataParallelPPOActor"] - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class DataParallelPPOActor(BasePPOActor): - def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim.Optimizer = None): - """When optimizer is None, it is Reference Policy""" - super().__init__(config) - self.actor_module = actor_module - self.actor_optimizer = actor_optimizer - - self.use_remove_padding = self.config.get("use_remove_padding", False) - if torch.distributed.get_rank() == 0: - print(f"Actor use_remove_padding={self.use_remove_padding}") - self.use_fused_kernels = self.config.get("use_fused_kernels", False) - if torch.distributed.get_rank() == 0: - print(f"Actor use_fused_kernels={self.use_fused_kernels}") - - self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size - self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 - - if self.config.entropy_from_logits_with_chunking: - entropy_from_logits = verl_F.entropy_from_logits_with_chunking - else: - entropy_from_logits = verl_F.entropy_from_logits - - self.compute_entropy_from_logits = ( - torch.compile(entropy_from_logits, dynamic=True) - if self.config.get("use_torch_compile", True) # use torch compile by default - else entropy_from_logits - ) - self.device_name = get_device_name() - - def _forward_micro_batch( - self, micro_batch, temperature, calculate_entropy=False - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Returns: - entropy: # (bs, response_len) - log_probs: # (bs, response_len) - """ - response_length = micro_batch["responses"].size(-1) - multi_modal_inputs = {} - if "multi_modal_inputs" in micro_batch.keys(): - if "image_bound" in micro_batch["multi_modal_inputs"][0]: # minicpm-o logic - for key in micro_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = [inputs[key] for inputs in micro_batch["multi_modal_inputs"]] - else: - for key in micro_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat( - [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 - ) - - with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - entropy = None - if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) - - if self.use_remove_padding: - input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - if position_ids.dim() == 3: - position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) - .transpose(0, 1) - .unsqueeze(1) - ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) - else: - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - if "image_bound" in multi_modal_inputs: - from verl.utils.dataset.vision_utils import process_multi_modal_inputs_for_minicpmo - - multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( - input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs - ) - - # for compute the log_prob - input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) - - # pad and slice the inputs if sp > 1 - if self.use_ulysses_sp: - is_vlm_model = "multi_modal_inputs" in micro_batch.keys() - if is_vlm_model: - # vlm model's inputs will be sliced after embedding - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - else: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, - position_ids_rmpad=position_ids_rmpad, - sp_size=self.ulysses_sequence_parallel_size, - ) - input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( - input_ids_rmpad_rolled, - position_ids_rmpad=None, - sp_size=self.ulysses_sequence_parallel_size, - ) - - input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) - - # only pass input_ids and position_ids to enable flash_attn_varlen - extra_args = {} - if self.use_fused_kernels: - extra_args["temperature"] = temperature - extra_args["return_dict"] = True - - output = self.actor_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - use_cache=False, - **extra_args, - ) # prevent model thinks we are generating - - if self.use_fused_kernels: - log_probs = output.log_probs.squeeze(0) # (total_nnz,) - entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) - - else: - logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - logits_rmpad.div_(temperature) - - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - inplace_backward = True - if calculate_entropy: - inplace_backward = False - log_probs = logprobs_from_logits( - logits=logits_rmpad, - labels=input_ids_rmpad_rolled, - inplace_backward=inplace_backward, - ) - - # compute entropy - if calculate_entropy: - if not self.config.entropy_checkpointing: - entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) - else: - entropy_rmpad = torch.utils.checkpoint.checkpoint( - self.compute_entropy_from_logits, logits_rmpad - ) - - # gather log_prob if sp > 1 - if self.use_ulysses_sp: - # gather and unpad for the ulysses sp - log_probs = gather_outputs_and_unpad( - log_probs, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size, - ) - if calculate_entropy: - entropy_rmpad = gather_outputs_and_unpad( - entropy_rmpad, - gather_dim=0, - unpad_dim=0, - padding_size=pad_size, - ) - # pad back to (bsz, seqlen) - if calculate_entropy: - full_entropy = pad_input( - hidden_states=entropy_rmpad.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - full_log_probs = pad_input( - hidden_states=log_probs.unsqueeze(-1), - indices=indices, - batch=batch_size, - seqlen=seqlen, - ) - - # only return response part: - if calculate_entropy: - entropy = full_entropy.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) - log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1 : -1] # (bsz, response_length) - - else: # not using rmpad and no ulysses sp - extra_args = {} - if self.use_fused_kernels: - extra_args["temperature"] = temperature - extra_args["return_dict"] = True - - output = self.actor_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False, - **extra_args, - ) # prevent model thinks we are generating - - if self.use_fused_kernels: - log_probs = output.log_probs[:, -response_length - 1 : -1] - entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) - - else: - logits = output.logits - - logits.div_(temperature) - logits = logits[:, -response_length - 1 : -1, :] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, micro_batch["responses"]) - if calculate_entropy: - if not self.config.entropy_checkpointing: - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) - else: - entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits) - - return entropy, log_probs - - def _optimizer_step(self): - assert self.config.grad_clip is not None - - if isinstance(self.actor_module, FSDP): - grad_norm = self.actor_module.clip_grad_norm_(max_norm=self.config.grad_clip) - elif isinstance(self.actor_module, FSDPModule): - grad_norm = fsdp2_clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) - else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.actor_module.parameters(), max_norm=self.config.grad_clip) - - # if grad_norm is not finite, skip the update - if not torch.isfinite(grad_norm): - print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}") - self.actor_optimizer.zero_grad() - else: - self.actor_optimizer.step() - return grad_norm - - @GPUMemoryLogger(role="dp actor", logger=logger) - def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: - """Compute the log probability of the responses given input_ids, attention_mask and position_ids - - Args: - data (DataProto): a DataProto containing keys - - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the - concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. - - ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. - - Returns: - torch.Tensor: the log_prob tensor - """ - # set to eval - self.actor_module.eval() - - micro_batch_size = data.meta_info["micro_batch_size"] - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error - use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - - data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - - if use_dynamic_bsz: - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) - else: - micro_batches = data.split(micro_batch_size) - - log_probs_lst = [] - entropy_lst = [] - for micro_batch in micro_batches: - model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} - with torch.no_grad(): - entropy, log_probs = self._forward_micro_batch( - model_inputs, temperature=temperature, calculate_entropy=calculate_entropy - ) - log_probs_lst.append(log_probs) - if calculate_entropy: - entropy_lst.append(entropy) - - log_probs = torch.concat(log_probs_lst, dim=0) - entropys = None - if calculate_entropy: - entropys = torch.concat(entropy_lst, dim=0) - - if use_dynamic_bsz: - log_probs = restore_dynamic_batch(log_probs, batch_idx_list) - if calculate_entropy: - entropys = restore_dynamic_batch(entropys, batch_idx_list) - - return log_probs, entropys - - @GPUMemoryLogger(role="dp actor", logger=logger) - def update_policy(self, data: DataProto): - # make sure we are in training mode - self.actor_module.train() - - temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error - - select_keys = [ - "responses", - "response_mask", - "input_ids", - "attention_mask", - "position_ids", - "old_log_probs", - "advantages", - ] - if self.config.use_kl_loss: - select_keys.append("ref_log_prob") - - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - - data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - - # Split to make minibatch iterator for updating the actor - # See PPO paper for details. https://arxiv.org/abs/1707.06347 - mini_batches = data.split(self.config.ppo_mini_batch_size) - - metrics = {} - for _ in range(self.config.ppo_epochs): - for batch_idx, mini_batch in enumerate(mini_batches): - if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) - else: - self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - ) - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - - self.actor_optimizer.zero_grad() - - for micro_batch in micro_batches: - micro_batch_metrics = {} - model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} - response_mask = model_inputs["response_mask"] - old_log_prob = model_inputs["old_log_probs"] - advantages = model_inputs["advantages"] - - clip_ratio = self.config.clip_ratio - clip_ratio_low = ( - self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio - ) - clip_ratio_high = ( - self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio - ) - clip_ratio_c = self.config.get("clip_ratio_c", 3.0) - entropy_coeff = self.config.entropy_coeff - loss_agg_mode = self.config.loss_agg_mode - - # all return: (bsz, response_length) - calculate_entropy = False - if entropy_coeff != 0: - calculate_entropy = True - entropy, log_prob = self._forward_micro_batch( - model_inputs, temperature=temperature, calculate_entropy=calculate_entropy - ) - - loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") - - if self.config.policy_loss.loss_mode == "vanilla": - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode, - ) - - else: - policy_loss_fn = get_policy_loss_fn(loss_mode) - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - loss_agg_mode=loss_agg_mode, - config=self.config, - ) - - if entropy_coeff != 0: - entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - - # compute policy loss - policy_loss = pg_loss - entropy_loss * entropy_coeff - else: - policy_loss = pg_loss - - if self.config.use_kl_loss: - ref_log_prob = model_inputs["ref_log_prob"] - # compute kl loss - kld = kl_penalty( - logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type - ) - kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - - policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef - micro_batch_metrics["actor/kl_loss"] = kl_loss.detach().item() - micro_batch_metrics["actor/kl_coef"] = self.config.kl_loss_coef - - if self.config.use_dynamic_bsz: - # relative to the dynamic bsz - loss = policy_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size) - else: - loss = policy_loss / self.gradient_accumulation - loss.backward() - - micro_batch_metrics.update( - { - "actor/pg_loss": pg_loss.detach().item(), - "actor/pg_clipfrac": pg_clipfrac.detach().item(), - "actor/ppo_kl": ppo_kl.detach().item(), - "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), - } - ) - append_to_dict(metrics, micro_batch_metrics) - - grad_norm = self._optimizer_step() - mini_batch_metrics = {"actor/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, mini_batch_metrics) - self.actor_optimizer.zero_grad() - return metrics diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py deleted file mode 100644 index ce52956d0..000000000 --- a/verl/workers/actor/megatron_actor.py +++ /dev/null @@ -1,658 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Megatron Actor. -In megatron actor, the differences are: -1. We only make minibatch - -Note that our model doesn't have to be `MegatronModule` because we don't share embedding in the last layer -""" - -import itertools -import logging -import os -from functools import partial -from typing import Iterable - -import torch -import torch.distributed -from megatron.core import parallel_state as mpu -from megatron.core.distributed import finalize_model_grads - -# from megatron.core.optimizer import DistributedOptimizer -from megatron.core.optimizer import DistributedOptimizer -from megatron.core.pipeline_parallel import get_forward_backward_func -from omegaconf import OmegaConf -from torch import nn - -from verl import DataProto -from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, get_policy_loss_fn, kl_penalty -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.megatron.pipeline_parallel import make_batch_generator -from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits -from verl.utils.megatron_utils import get_model_config -from verl.utils.profiler import GPUMemoryLogger -from verl.utils.profiler.profile import Profiler -from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches -from verl.utils.torch_functional import broadcast_dict_tensor -from verl.workers.actor import BasePPOActor - -__all__ = ["MegatronPPOActor"] - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class MegatronPPOActor(BasePPOActor): - def __init__( - self, - config, - model_config, - hf_config, - tf_config, - actor_module: nn.ModuleList, - actor_optimizer: DistributedOptimizer, - ): - """MeagtronPPOActor class. This class implements the simple PPO logics when the model is built with Megatron. - - Args: - config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain - - ``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo. - - ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data. - - ``ppo_epochs``: number of epochs to update the actor using the batch data. - - ``shuffle``: whether to shuffle the data after each ppo epoch. - - ``clip_ratio``: clip ratio of the ppo algorithm. See https://arxiv.org/abs/1707.06347. - - ``entropy_coeff``: entropy coefficient of the PPO loss. See https://arxiv.org/abs/1707.06347. - model_config (OmegaConf): model configuration. It must contains ``model_config.vocab_size`` and - ``model_config.hidden_size`` - hf_config (PretrainedConfig): huggingface config - tf_config (TransformerConfig): mcore transformer config - actor_module (nn.ModuleList): actor module is a ModuleList that contains a list of nn.Module in this - pp stage. - each nn.Module in this rank holds a vpp module chunk. See https://arxiv.org/pdf/2104.04473.pdf for - more details. - The actor module has some constraints to follow in order to use the updating logics implemented here - - 1. It must implement unpad_input before any computation and pad_input after all the computation. - Remove padding is an - optimization that removes the padding tokens. See unpad_input and pad_input function in flash-attn - (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py). - - 2. Each pp stage must return the hidden state with the same shape [total_nnz, 1, hidden_size], - where total_nnz is the number of valid tokens in this batch. If sequence parallel is enabled, the size - of the hidden state is [total_nnz // tp, 1, hidden_size]. - actor_optimizer (DistributedOptimizer): currently, we only support DistributedOptimizer in Megatron. - It implements - zero1 optimizer that shards the optimizer state across dp ranks. - - >>> from megatron.training import get_model - >>> from megatron.optimizer import get_megatron_optimizer - >>> actor_module = get_model(megatron_actor_model_provider, wrap_with_ddp=True) - >>> actor_module = nn.ModuleList(actor_module) - >>> actor_optimizer = get_megatron_optimizer(actor_module) - >>> actor = MegatronPPOActor(config=config, - >>> model_config=actor_model_config, - >>> hf_config=hf_config, - >>> tf_config=tf_config, - >>> actor_module=actor_module, - >>> actor_optimizer=actor_optimizer) - """ - super().__init__(config) - self._validate_config(config) - self.model_config = model_config - self.hf_config = hf_config - self.tf_config = tf_config - self.actor_module = actor_module - self.actor_optimizer: DistributedOptimizer = actor_optimizer - self.prof = Profiler(self.config.profile) - self.use_fused_kernels = self.config.get("use_fused_kernels", False) - if self.use_fused_kernels: - from verl.models.mcore.model_forward_fused import patch_fused_forward - - for model in self.actor_module: - patch_fused_forward(model) - - self.optimizer_step_args = OmegaConf.create( - { - "skip_grad": None, - "overlap_dp_param_comm": False, - "overlap_dp_grad_comm": False, - "gradient_accumulation_steps": 1, - "sequence_parallel": self.tf_config.sequence_parallel, - "DDP_impl": "local", - "layernorm_allreduce_bucket_threshold": 0, - "pipeline_model_parallel_split_rank": None, - "reduce_grads_use_alltoall": False, - } - ) - - config = get_model_config(self.actor_module[0]) - print(config) - config.finalize_model_grads_func = finalize_model_grads - - def _validate_config(self, config) -> None: - """Validate config options not implemented for Megatron backend""" - assert config.get("ulysses_sequence_parallel_size", 1) == 1 - if config.get("shuffle", False): - assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set" - if config.megatron.tensor_model_parallel_size == 1: - print("[Warining] Because actor tp size == 1, set sp to False") - config.megatron.sequence_parallel = False - self.config = config - - @GPUMemoryLogger(role="megatron actor", logger=logger) - def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor: - """Compute the log probability of the responses given input_ids, attention_mask and position_ids - - Args: - data (DataProto): a DataProto containing keys - - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. Note that input_ids is the - concatenation of prompt and response. Note that ``sequence_length = prompt_length + response_length``. - - ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64. - - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. - - Returns: - DataProto: torch.Tensor: the log_prob tensor - """ - data.to(get_device_id()) - data.batch = data.batch.contiguous() - use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) - micro_batch_size = data.meta_info.get("micro_batch_size", None) - max_token_len = data.meta_info.get("max_token_len", None) - assert micro_batch_size is not None, "micro batch size is needed for forward compute" - if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" - max_token_len = max_token_len * self.config.megatron.context_parallel_size - - def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None): - response = data["responses"] - response_length = response.size(1) - log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous() - return {"log_probs": log_probs} - - # We make recompute_old_log_prob by default here. - # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be - # handled by user outside - recompute_old_log_prob = self.config.get("recompute_old_log_prob", True) - - entropys = torch.Tensor() - if recompute_old_log_prob: - select_keys = ["responses", "input_ids", "attention_mask", "position_ids"] - batch = data.select(batch_keys=select_keys).batch - input_ids = batch["input_ids"] - batch_size = input_ids.size(0) - response = batch["responses"] - response_length = response.size(1) - with torch.no_grad(): - output = self.forward_backward_batch( - data, - forward_only=True, - post_process_fn=compute_logprobs_fn, - calculate_entropy=calculate_entropy, - use_dynamic_bsz=use_dynamic_bsz, - micro_batch_size=micro_batch_size, - max_token_len=max_token_len, - ) - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # only on last rank. It should be on every tp rank - if calculate_entropy: - log_probs = [o[0]["log_probs"] for o in output["output"]] # (bs, seq_size) - else: - log_probs = [o["log_probs"] for o in output["output"]] # (bs, seq_size) - log_probs = torch.cat(log_probs, dim=0).to(torch.float32) - if use_dynamic_bsz: - indices = output["indices"] - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - log_probs = log_probs[revert_indices] - else: - log_probs = torch.empty( - size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device - ) - - # broadcast across pp ranks - torch.distributed.broadcast( - tensor=log_probs, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - async_op=False, - ) - if calculate_entropy: - # Note that o[0] is metrics, o[1] is entropy - if mpu.is_pipeline_last_stage(ignore_virtual=True): - entropys = torch.cat([o[1] for o in output["output"]], dim=0) - entropys = entropys.to(torch.float32) - if use_dynamic_bsz: - indices = output["indices"] - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - entropys = entropys[revert_indices] - else: - entropys = torch.empty( - size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device - ) - # broadcast across pp ranks - torch.distributed.broadcast( - tensor=entropys, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - async_op=False, - ) - - # add empty cache after each compute - get_torch_device().empty_cache() - - return log_probs, entropys - - def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: - """Make minibatch iterator for updating the actor - - Args: - data (DataProto): a DataProto containing keys - - ``input_ids``: tensor of shape [batch_size, sequence_length]. torch.int64, where - ``sequence_length = prompt_length + response_length`` - - ``attention_mask``: tensor of shape [batch_size, sequence_length]. torch.int64 - - ``position_ids``: tensor of shape [batch_size, sequence_length]. torch.int64 - - ``responses``: tensor of shape [batch_size, response_length]. torch.int64. Note that - responses = input_ids[:, -response_length:] - - ``old_log_probs``: tensor of shape [batch_size, response_length]. torch.float32. The log probability - of responses. - - ``advantages``: tensor of shape [batch_size, response_length]. torch.float32. The advantages of - responses. - See PPO paper for details. https://arxiv.org/abs/1707.06347 - - Returns: - - """ - select_keys = [ - "responses", - "input_ids", - "attention_mask", - "response_mask", - "position_ids", - "old_log_probs", - "advantages", - ] - if self.config.use_kl_loss: - select_keys.append("ref_log_prob") - self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - if self.has_multi_modal_inputs: - data = data.select(select_keys, ["multi_modal_inputs"]) - else: - data = data.select(batch_keys=select_keys) - return data.make_iterator( - mini_batch_size=self.config.ppo_mini_batch_size, - epochs=self.config.ppo_epochs, - seed=self.config.data_loader_seed, - dataloader_kwargs={"shuffle": self.config.shuffle}, - ) - - def forward_backward_batch( - self, - data: DataProto, - forward_only=False, - post_process_fn=None, - calculate_entropy=False, - use_dynamic_bsz=False, - micro_batch_size=None, - max_token_len=None, - mini_batch_size=None, - ): - """ - We assume: - - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input - - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled - """ - # broadcast from last pp rank to all other pp ranks - # TODO: actually, we just need to control the sampling order. - mini_batch = data - broadcast_dict_tensor( - mini_batch.batch, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - ) - # split into micro-batches - mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) - self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() - if self.has_multi_modal_inputs: - mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] - mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor( - list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"]))) - ).to(torch.int64) - - if mini_batch.batch["position_ids"].dim() == 3: # qwen2vl mrope [bs, 3, seq_len] - mini_batch.batch["position_ids"] = mini_batch.batch["position_ids"][ - :, 0 - ] # mcore patch recompute qwen2vl's pos ids during forward - - indices = None - if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" - vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - if vpp_size is not None and vpp_size > 1: - microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage - micro_batches, indices = rearrange_micro_batches( - batch=mini_batch.batch, - num_batches_divided_by=microbatch_group_size_per_vp_stage, - max_token_len=max_token_len, - ) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( - f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " - f"{microbatch_group_size_per_vp_stage} for megatron backend" - ) - else: - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) - total_seqlen = max_token_len - else: - assert micro_batch_size is not None, ( - "micro_batch_size is needed to be passed in when not using dynamic batch size" - ) - micro_batches = mini_batch.batch.split(micro_batch_size) - seq_len = micro_batches[0]["input_ids"].shape[1] - total_seqlen = micro_batch_size * seq_len - # compute input shapes for pp stages - n_micro_batch = len(micro_batches) - - forward_backward_func = get_forward_backward_func() - - def loss_func(output, data, meta_info): - # For memory efficiency - # We move calculation of entropy to compute_log_probs, forward_only == True - device = output["log_probs"].device - metrics = {} - if forward_only: - if post_process_fn is None: - pass - # metrics["logits"] = output - else: - stats = post_process_fn(output, data) - metrics.update(stats) - if not calculate_entropy: - return torch.tensor(1.0, device=device), metrics - - responses = data["responses"] - response_length = responses.size(1) - response_mask = data["response_mask"].to(bool) - loss_agg_mode = self.config.loss_agg_mode - - # compute policy loss - log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous() - ret_entropy = None - stats = {} - if not forward_only: - old_log_prob = data["old_log_probs"] - advantages = data["advantages"] - - clip_ratio = self.config.clip_ratio - clip_ratio_low = self.config.clip_ratio_low if self.config.clip_ratio_low is not None else clip_ratio - clip_ratio_high = self.config.clip_ratio_high if self.config.clip_ratio_high is not None else clip_ratio - - clip_ratio_c = self.config.get("clip_ratio_c", 3.0) - entropy_coeff = self.config.entropy_coeff - loss_agg_mode = self.config.loss_agg_mode - - loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") - - if self.config.policy_loss.loss_mode == "vanilla": - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - cliprange=clip_ratio, - cliprange_low=clip_ratio_low, - cliprange_high=clip_ratio_high, - clip_ratio_c=clip_ratio_c, - loss_agg_mode=loss_agg_mode, - ) - - else: - policy_loss_fn = get_policy_loss_fn(loss_mode) - pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn( - old_log_prob=old_log_prob, - log_prob=log_prob, - advantages=advantages, - response_mask=response_mask, - loss_agg_mode=loss_agg_mode, - config=self.config, - ) - - stats.update( - { - "actor/pg_loss": pg_loss.detach().item(), - "actor/pg_clipfrac": pg_clipfrac.detach().item(), - "actor/ppo_kl": ppo_kl.detach().item(), - "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), - } - ) - policy_loss = pg_loss - - if calculate_entropy: - entropy = output["entropy"][:, -response_length - 1 : -1].contiguous() - if not forward_only: - entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) - entropy_coeff = meta_info["entropy_coeff"] - policy_loss = pg_loss - entropy_coeff * entropy_loss - else: - ret_entropy = entropy - - if forward_only: - policy_loss = torch.tensor(1.0, device=device) - else: - if self.config.use_kl_loss: - ref_log_prob = data["ref_log_prob"] - # compute kl loss - kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type) - kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode) - - policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef - metrics["actor/kl_loss"] = kl_loss.detach().item() - metrics["actor/kl_coef"] = self.config.kl_loss_coef - - # return loss and stats - - append_to_dict(metrics, stats) - return policy_loss, [metrics, ret_entropy] - - def forward_step(batch_iter, model): - batch = next(batch_iter) - input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"].to(bool) - position_ids = batch["position_ids"] - - multi_modal_inputs = {} - if "multi_modal_inputs" in batch: - for key in batch["multi_modal_inputs"][0].keys(): - idxs = batch["multi_modal_inputs_idx"] - mmi = batch["multi_modal_inputs"] - multi_modal_inputs[key] = torch.cat( - [mmi[idx].get(key) for idx in idxs if mmi[idx].get(key) is not None], dim=0 - ) - responses = batch["responses"] - response_length = responses.size(1) - label = position_ids.clone() - label[:, -response_length - 1 : -1] = responses - label_mask = attention_mask.clone() - label_mask[:, : -response_length - 1] = False - label_mask[:, -1] = False - - from verl.models.mcore import get_mcore_forward_fn, get_mcore_forward_fused_fn - - if self.use_fused_kernels: - forward_fn = get_mcore_forward_fused_fn(self.hf_config) - # return dict of [logits, entropy] - output = forward_fn( - model, - input_ids, - position_ids, - attention_mask, - sequence_parallel=self.tf_config.sequence_parallel, - multi_modal_inputs=multi_modal_inputs, - labels=label, - labels_mask=label_mask, - ) - else: - forward_fn = get_mcore_forward_fn(self.hf_config) - - def logits_processor(logits, label, label_mask): - assert logits.shape[:2] == label.shape[:2] - assert label.shape == label_mask.shape - ret = {} - if calculate_entropy: - entropy = vocab_parallel_entropy(logits) - ret["entropy"] = entropy - log_probs = vocab_parallel_log_probs_from_logits(logits, label) - log_probs = log_probs.masked_fill(~label_mask, 0.0) - ret["log_probs"] = log_probs - return ret - - logits_processor_args = {"label": label, "label_mask": label_mask} - output = forward_fn( - model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.tf_config.sequence_parallel, - multi_modal_inputs=multi_modal_inputs, - logits_processor=logits_processor, - logits_processor_args=logits_processor_args, - ) - - if forward_only: - meta_info = None - else: - clip_ratio_c = self.config.get("clip_ratio_c", 3.0) - meta_info = { - "clip_ratio": self.config.clip_ratio, - "entropy_coeff": self.config.entropy_coeff, - "clip_ratio_c": clip_ratio_c, - } - return output, partial(loss_func, data=batch, meta_info=meta_info) - - # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.actor_module)) - - # TODO: we may use the new schedule instead - # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) - if mpu.get_pipeline_model_parallel_world_size() > 1: - losses_reduced = forward_backward_func( - forward_step_func=forward_step, - data_iterator=batch_generator, - model=self.actor_module, - num_microbatches=n_micro_batch, - seq_length=total_seqlen, # no use when input_shapes was set - micro_batch_size=1, # no use when input_shapes was set - forward_only=forward_only, - ) - else: - losses_reduced = forward_backward_func( - forward_step_func=forward_step, - data_iterator=batch_generator, - model=self.actor_module, - num_microbatches=n_micro_batch, - seq_length=total_seqlen, # in use for pp = 1 - micro_batch_size=1, # in use for pp = 1 - forward_only=forward_only, - ) - # loss_reduces contains the stats returned from loss_func - - if self.has_multi_modal_inputs: - data.batch.pop("multi_modal_inputs") - data.batch.pop("multi_modal_inputs_idx") - data.non_tensor_batch.pop("multi_modal_inputs") - - losses_reduced = {"output": losses_reduced} - if use_dynamic_bsz: - losses_reduced["indices"] = indices - return losses_reduced - - @GPUMemoryLogger(role="megatron actor", logger=logger) - def update_policy(self, dataloader: Iterable[DataProto]) -> dict: - """Update the policy with an iterator of DataProto - - Args: - dataloader (Iterable[DataProto]): an iterator over the DataProto that returns by ``make_minibatch_iterator`` - The keys of each data batch is described in the make_minibatch_iterator. - - Returns: - Dict: a dictionary containing the statistics. Note that the statistics are only valid in the last pp stage - and users have to combine the output in each dp rank manually. - - """ - metrics = {} - self.prof.start() - for data in dataloader: - data.to(get_device_id()) - self.actor_optimizer.zero_grad() - # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm - for chunk in self.actor_module: - # if use distributed optimizer, zero grad buffer will be handled by optimizer - chunk.zero_grad_buffer() - - calculate_entropy = self.config.entropy_coeff != 0 - if data.meta_info.get("micro_batch_size", None) is not None: - micro_batch_size = data.meta_info["micro_batch_size"] - else: - micro_batch_size = self.config.ppo_micro_batch_size_per_gpu - max_token_len = None - if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size - metric_micro_batch = self.forward_backward_batch( - data, - calculate_entropy=calculate_entropy, - use_dynamic_bsz=self.config.use_dynamic_bsz, - micro_batch_size=micro_batch_size, - max_token_len=max_token_len, - mini_batch_size=self.config.ppo_mini_batch_size, - ) - metric_micro_batch = metric_micro_batch["output"] - for metric in metric_micro_batch: - # Note that o[0] is metrics, o[1] is entropy, o[2] is response_mask - append_to_dict(metrics, metric[0]) # append the metric from this micro-batch to global metrics. - - update_successful, grad_norm, num_zeros_in_grad = self.actor_optimizer.step() - data = {"actor/grad_norm": grad_norm} - append_to_dict(metrics, data) - - if update_successful: - # allgather already execute in optimizer.step in new megatron - pass - else: - raise NotImplementedError - self.prof.step() - # add empty cache after each compute - self.prof.stop_and_save() - self.prof.stop_trace() - get_torch_device().empty_cache() - return metrics diff --git a/verl/workers/critic/__init__.py b/verl/workers/critic/__init__.py deleted file mode 100644 index 80808f106..000000000 --- a/verl/workers/critic/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .base import BasePPOCritic -from .dp_critic import DataParallelPPOCritic - -__all__ = ["BasePPOCritic", "DataParallelPPOCritic"] diff --git a/verl/workers/critic/base.py b/verl/workers/critic/base.py deleted file mode 100644 index 8201758f3..000000000 --- a/verl/workers/critic/base.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Base class for a critic -""" - -from abc import ABC, abstractmethod - -import torch - -from verl import DataProto - -__all__ = ["BasePPOCritic"] - - -class BasePPOCritic(ABC): - def __init__(self, config): - super().__init__() - self.config = config - - @abstractmethod - def compute_values(self, data: DataProto) -> torch.Tensor: - """Compute values""" - pass - - @abstractmethod - def update_critic(self, data: DataProto): - """Update the critic""" - pass diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py deleted file mode 100644 index 4d7c87ef7..000000000 --- a/verl/workers/critic/dp_critic.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Implement a multiprocess PPOCritic -""" - -import logging -import os - -import torch -import torch.distributed -from torch import nn, optim -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -from verl import DataProto -from verl.trainer.ppo import core_algos -from verl.utils.device import get_device_name, is_cuda_available, is_npu_available -from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_ -from verl.utils.profiler import GPUMemoryLogger -from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import prepare_dynamic_batch, restore_dynamic_batch -from verl.utils.torch_functional import masked_mean -from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs -from verl.workers.critic import BasePPOCritic - -if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input -elif is_npu_available: - from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class DataParallelPPOCritic(BasePPOCritic): - def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer): - super().__init__(config=config) - self.critic_module = critic_module - self.critic_optimizer = critic_optimizer - self.use_remove_padding = self.config.model.get("use_remove_padding", False) - print(f"Critic use_remove_padding={self.use_remove_padding}") - - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - self.device_name = get_device_name() - - def _forward_micro_batch(self, micro_batch): - response_length = micro_batch["responses"].size(-1) - multi_modal_inputs = {} - if "multi_modal_inputs" in micro_batch.keys(): - for key in micro_batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat( - [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 - ) - - with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): - input_ids = micro_batch["input_ids"] - batch, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) - - if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - if position_ids.dim() == 3: - position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) - .transpose(0, 1) - .unsqueeze(1) - ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) - else: - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # pad and slice the inputs if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size - ) - - # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.critic_module( - input_ids=input_ids_rmpad, - attention_mask=None, - position_ids=position_ids_rmpad, - **multi_modal_inputs, - use_cache=False, - ) # prevent model thinks we are generating - - if hasattr(self.critic_module, "v_head"): - # For trl.AutoModelForCausalLMWithValueHead - values_rmpad = output[2].squeeze(0).unsqueeze(-1) - else: - values_rmpad = output.logits - values_rmpad = values_rmpad.squeeze(0) # (total_nnz) - - # gather output if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - values_rmpad = gather_outputs_and_unpad( - values_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - - # pad it back - values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1) - values = values[:, -response_length - 1 : -1] - else: - output = self.critic_module( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - **multi_modal_inputs, - use_cache=False, - ) # prevent model thinks we are generating - if hasattr(self.critic_module, "v_head"): - # For trl.AutoModelForCausalLMWithValueHead - values = output[2] - else: - values = output.logits - values = values[:, -response_length - 1 : -1].squeeze(-1) - return values - - def _optimizer_step(self): - assert self.config.grad_clip is not None - - if isinstance(self.critic_module, FSDP): - grad_norm = self.critic_module.clip_grad_norm_(self.config.grad_clip) - elif isinstance(self.critic_module, FSDPModule): - grad_norm = fsdp2_clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) - else: - grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_module.parameters(), max_norm=self.config.grad_clip) - - # if grad_norm is not finite, skip the update - if not torch.isfinite(grad_norm): - print(f"WARN: grad_norm is not finite: {grad_norm}") - self.critic_optimizer.zero_grad() - else: - self.critic_optimizer.step() - return grad_norm - - @GPUMemoryLogger(role="dp critic", logger=logger) - def compute_values(self, data: DataProto) -> torch.Tensor: - self.critic_module.eval() - micro_batch_size = data.meta_info["micro_batch_size"] - use_dynamic_bsz = data.meta_info["use_dynamic_bsz"] - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - select_keys = ["responses", "input_ids", "response_mask", "attention_mask", "position_ids"] - non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - - data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - - if use_dynamic_bsz: - max_token_len = data.meta_info["max_token_len"] * self.ulysses_sequence_parallel_size - micro_batches, batch_idx_list = prepare_dynamic_batch(data, max_token_len=max_token_len) - else: - micro_batches = data.split(micro_batch_size) - - values_lst = [] - for micro_batch in micro_batches: - model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} - with torch.no_grad(): - values = self._forward_micro_batch(model_inputs) - values_lst.append(values) - values = torch.concat(values_lst, dim=0) - - if use_dynamic_bsz: - values = restore_dynamic_batch(values, batch_idx_list) - - response_mask = data.batch["response_mask"] - values = values * response_mask # Only action tokens have values - return values - - @GPUMemoryLogger(role="dp critic", logger=logger) - def update_critic(self, data: DataProto): - # make sure we are in training mode - self.critic_module.train() - metrics = {} - - select_keys = ["input_ids", "responses", "response_mask", "attention_mask", "position_ids", "values", "returns"] - has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys() - non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else [] - - data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys) - - # Split to make minibatch iterator for updating the actor - # See PPO paper for details. https://arxiv.org/abs/1707.06347 - mini_batches = data.split(self.config.ppo_mini_batch_size) - - for _ in range(self.config.ppo_epochs): - for batch_idx, mini_batch in enumerate(mini_batches): - if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, _ = prepare_dynamic_batch(mini_batch, max_token_len=max_token_len) - else: - self.gradient_accumulation = ( - self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu - ) - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) - - self.critic_optimizer.zero_grad() - - for micro_batch in micro_batches: - micro_batch_metrics = {} - model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} - response_mask = model_inputs["response_mask"] - values = model_inputs["values"] - returns = model_inputs["returns"] - - vpreds = self._forward_micro_batch(model_inputs) - vf_loss, vf_clipfrac = core_algos.compute_value_loss( - vpreds=vpreds, - values=values, - returns=returns, - response_mask=response_mask, - cliprange_value=self.config.cliprange_value, - loss_agg_mode=self.config.loss_agg_mode, - ) - if self.config.use_dynamic_bsz: - # relative to the dynamic bsz - loss = vf_loss * (response_mask.shape[0] / self.config.ppo_mini_batch_size) - else: - loss = vf_loss / self.gradient_accumulation - - loss.backward() - - micro_batch_metrics.update( - { - "critic/vf_loss": vf_loss.detach().item(), - "critic/vf_clipfrac": vf_clipfrac.detach().item(), - "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), - } - ) - - append_to_dict(metrics, micro_batch_metrics) - - grad_norm = self._optimizer_step() - mini_batch_metrics = {"critic/grad_norm": grad_norm.detach().item()} - append_to_dict(metrics, mini_batch_metrics) - self.critic_optimizer.zero_grad() - return metrics diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py deleted file mode 100644 index 1d44a8876..000000000 --- a/verl/workers/critic/megatron_critic.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Implement a multiprocess PPOCritic -""" - -import itertools -import logging -import os -from functools import partial -from typing import Iterable - -import torch -import torch.distributed -from megatron.core import parallel_state as mpu -from megatron.core.optimizer import DistributedOptimizer, OptimizerConfig -from megatron.core.pipeline_parallel import get_forward_backward_func -from omegaconf import OmegaConf -from torch import nn - -from verl import DataProto -from verl.trainer.ppo import core_algos -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.megatron.pipeline_parallel import make_batch_generator -from verl.utils.profiler import GPUMemoryLogger -from verl.utils.py_functional import append_to_dict -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches -from verl.utils.torch_functional import broadcast_dict_tensor, masked_mean -from verl.workers.critic import BasePPOCritic - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class MegatronPPOCritic(BasePPOCritic): - def __init__( - self, - config, - model_config, - hf_config, - tf_config, - critic_module: nn.ModuleList, - critic_optimizer: DistributedOptimizer, - critic_optimizer_config: OptimizerConfig, - ): - super().__init__(config=config) - self._validate_config(config) - self.model_config = model_config - self.hf_config = hf_config # huggingface config - self.tf_config = tf_config # mcore transformer config - - self.critic_module = critic_module - self.critic_optimizer = critic_optimizer - self.critic_optimizer_config = critic_optimizer_config - - # we create a separate nametuple for optimizer step so that global args won't affect it. - self.optimizer_step_args = OmegaConf.create( - { - "skip_grad": None, - "overlap_dp_param_comm": False, - "overlap_dp_grad_comm": False, - "gradient_accumulation_steps": 1, - "sequence_parallel": self.tf_config.sequence_parallel, - "DDP_impl": "local", - "layernorm_allreduce_bucket_threshold": 0, - "pipeline_model_parallel_split_rank": None, - "reduce_grads_use_alltoall": False, - } - ) - - def _validate_config(self, config) -> None: - """Validate config options not implemented for Megatron backend""" - assert config.get("ulysses_sequence_parallel_size", 1) == 1 - if config.shuffle: - assert config.data_loader_seed is not None, "If shuffle dataloader, seed must be manually set" - if config.megatron.tensor_model_parallel_size == 1: - print("[Warining] Because critic tp size == 1, set sp to False") - config.megatron.sequence_parallel = False - self.config = config - - @GPUMemoryLogger("megatron critic", logger=logger) - def compute_values(self, data: DataProto) -> DataProto: - data.to(get_device_id()) - responses = data.batch["responses"] - attention_mask = data.batch["attention_mask"] - use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) - micro_batch_size = data.meta_info.get("micro_batch_size", None) - max_token_len = data.meta_info.get("max_token_len", None) - assert micro_batch_size is not None, "micro batch size is needed for forward compute" - if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" - max_token_len = max_token_len * self.config.megatron.context_parallel_size - response_length = responses.size(1) - with torch.no_grad(): - output = self.forward_backward_batch( - data=data, - forward_only=True, - use_dynamic_bsz=use_dynamic_bsz, - micro_batch_size=micro_batch_size, - max_token_len=max_token_len, - mini_batch_size=None, - ) - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # only on last rank. It should be on every tp rank - values = [o["vpreds"] for o in output["output"]] # (bs, seq_size, vocal_size) - values = torch.cat(values, dim=0).to(torch.float32) - if use_dynamic_bsz: - indices = output["indices"] - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == values.size(0), f"{len(indices)} vs. {values.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - values = values[revert_indices] - else: - values = torch.empty_like(attention_mask, dtype=torch.float32) - - # each tp ranks should contain the same value - values = values[ - :, -response_length - 1 : -1 - ] # Values are predicted at the ends of prefixes, e.g., the last prompt token - response_mask = attention_mask[:, -response_length:] - values = values * response_mask # Only action tokens have values - values = values.contiguous() - - # sync among pp ranks - torch.distributed.broadcast( - tensor=values, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - ) - - # add empty cache after each compute - get_torch_device().empty_cache() - - return values - - def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: - select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"] - data = data.select(batch_keys=select_keys) - return data.make_iterator( - mini_batch_size=self.config.ppo_mini_batch_size, - epochs=self.config.ppo_epochs, - seed=self.config.data_loader_seed, - dataloader_kwargs={"shuffle": self.config.shuffle}, - ) - - def forward_backward_batch( - self, - data: DataProto, - forward_only=False, - use_dynamic_bsz=False, - micro_batch_size=None, - max_token_len=None, - mini_batch_size=None, - ): - # broadcast from last pp rank to all other pp ranks - mini_batch = data - mini_batch.to(get_device_id()) - mini_batch.batch = mini_batch.batch.contiguous() - broadcast_dict_tensor( - mini_batch.batch, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - ) - # split into micro-batches - mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) - - indices = None - if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" - vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - if vpp_size is not None and vpp_size > 1: - microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage - micro_batches, indices = rearrange_micro_batches( - batch=mini_batch.batch, - num_batches_divided_by=microbatch_group_size_per_vp_stage, - max_token_len=max_token_len, - ) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( - f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " - f"{microbatch_group_size_per_vp_stage} for megatron backend" - ) - else: - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) - total_seqlen = max_token_len - else: - assert micro_batch_size is not None, ( - "micro_batch_size is needed to be passed in when not using dynamic batch size" - ) - micro_batches = mini_batch.batch.split(micro_batch_size) - seq_len = micro_batches[0]["input_ids"].shape[1] - total_seqlen = micro_batch_size * seq_len - n_micro_batch = len(micro_batches) - - forward_backward_func = get_forward_backward_func() - - def loss_func(output, data, meta_info): - nonlocal use_dynamic_bsz - - if forward_only: - return torch.tensor(1.0, device=output.device), {"vpreds": output} - - responses = data["responses"] - attention_mask = data["attention_mask"] - values = data["values"] - returns = data["returns"] - response_length = responses.size(1) - - response_mask = attention_mask[:, -response_length:] - - cliprange_value = self.config.cliprange_value - - vpreds = output # (bs, sequence_length) - vpreds = vpreds[:, -response_length - 1 : -1] - - vf_loss, vf_clipfrac = core_algos.compute_value_loss( - vpreds=vpreds, - values=values, - returns=returns, - response_mask=response_mask, - cliprange_value=cliprange_value, - loss_agg_mode=self.config.loss_agg_mode, - ) - - stats = { - "critic/vf_loss": vf_loss.detach().item(), - "critic/vf_clipfrac": vf_clipfrac.detach().item(), - "critic/vpred_mean": masked_mean(vpreds, response_mask).detach().item(), - } - - return vf_loss, stats - - def forward_step(batch_iter, model): - batch = next(batch_iter) - input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"] - position_ids = batch["position_ids"] - from verl.models.mcore import get_mcore_forward_fn - - forward_fn = get_mcore_forward_fn(self.hf_config) - - output = forward_fn( - model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.tf_config.sequence_parallel, - value_model=True, - ) - - return output, partial(loss_func, data=batch, meta_info={}) - - # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.critic_module)) - - # TODO: we may use the new schedule instead - # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) - if mpu.get_pipeline_model_parallel_world_size() > 1: - losses_reduced = forward_backward_func( - forward_step_func=forward_step, - data_iterator=batch_generator, - model=self.critic_module, - num_microbatches=n_micro_batch, - seq_length=total_seqlen, # no use when input_shapes was set - micro_batch_size=1, # no use when input_shapes was set - forward_only=forward_only, - ) - else: - losses_reduced = forward_backward_func( - forward_step_func=forward_step, - data_iterator=batch_generator, - model=self.critic_module, - num_microbatches=n_micro_batch, - seq_length=total_seqlen, # in use for pp = 1 - micro_batch_size=1, # in use for pp = 1 - forward_only=forward_only, - ) - # loss_reduces contains the stats returned from loss_func - losses_reduced = {"output": losses_reduced} - if use_dynamic_bsz: - losses_reduced["indices"] = indices - return losses_reduced - - @GPUMemoryLogger("megatron critic", logger=logger) - def update_critic(self, dataloader: Iterable[DataProto]): - metrics = {} - - for data in dataloader: - # data = data.batch.to(self.critic_module.device) - self.critic_optimizer.zero_grad() - # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm - for chunk in self.critic_module: - chunk.zero_grad_buffer() - - micro_batch_size = self.config.ppo_micro_batch_size_per_gpu - max_token_len = None - if self.config.use_dynamic_bsz: - max_token_len = self.config.ppo_max_token_len_per_gpu * self.config.megatron.context_parallel_size - metric_micro_batch = self.forward_backward_batch( - data, - forward_only=False, - use_dynamic_bsz=self.config.use_dynamic_bsz, - micro_batch_size=micro_batch_size, - max_token_len=max_token_len, - mini_batch_size=self.config.ppo_mini_batch_size, - ) - metric_micro_batch = metric_micro_batch["output"] - update_successful, grad_norm, num_zeros_in_grad = self.critic_optimizer.step() - learning_rate = self.critic_optimizer.param_groups[-1]["lr"] - data = {"critic/grad_norm": grad_norm, "critic/lr": learning_rate} - append_to_dict(metrics, data) - - if update_successful: - # allgather already execute in optimizer.step in new megatron - pass - else: - raise NotImplementedError - - for metric in metric_micro_batch: - append_to_dict(metrics, metric) # append the metric from this micro-batch to global metrics. - - # add empty cache after each compute - get_torch_device().empty_cache() - return metrics diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py deleted file mode 100644 index 30e117000..000000000 --- a/verl/workers/fsdp_workers.py +++ /dev/null @@ -1,1709 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -The main entry point to run the PPO algorithm -""" - -import json -import logging -import os -import warnings -from dataclasses import asdict -from typing import Any - -import psutil -import torch -import torch.distributed -import torch.distributed as dist -from codetiming import Timer -from omegaconf import DictConfig, OmegaConf, open_dict -from peft import LoraConfig, TaskType, get_peft_model -from safetensors.torch import save_file -from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -import verl.utils.torch_functional as verl_F -from verl import DataProto -from verl.models.transformers.monkey_patch import apply_monkey_patch -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import Dispatch, register -from verl.utils import hf_processor, hf_tokenizer -from verl.utils.activation_offload import enable_activation_offloading -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager -from verl.utils.config import omega_conf_to_dataclass -from verl.utils.device import ( - get_device_id, - get_device_name, - get_nccl_backend, - get_torch_device, - is_cuda_available, - is_npu_available, -) -from verl.utils.flops_counter import FlopsCounter -from verl.utils.fs import copy_to_local -from verl.utils.fsdp_utils import ( - CPUOffloadPolicy, - MixedPrecisionPolicy, - apply_fsdp2, - fsdp2_load_full_state_dict, - fsdp_version, - get_fsdp_wrap_policy, - get_init_weight_context_manager, - init_fn, - layered_summon_lora_params, - load_fsdp_model_to_gpu, - load_fsdp_optimizer, - offload_fsdp_model_to_cpu, - offload_fsdp_optimizer, -) -from verl.utils.import_utils import import_external_libs -from verl.utils.model import compute_position_id_with_mask -from verl.utils.profiler import DistProfiler, DistProfilerExtension, log_gpu_memory_usage, simple_timer -from verl.utils.profiler.performance import reduce_timing -from verl.utils.py_functional import convert_to_regular_types -from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -device_name = get_device_name() - - -def create_device_mesh(world_size, fsdp_size): - if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) - else: - device_mesh = init_device_mesh( - device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] - ) - return device_mesh - - -def get_sharding_strategy(device_mesh): - from torch.distributed.fsdp import ShardingStrategy - - if device_mesh.ndim == 1: - sharding_strategy = ShardingStrategy.FULL_SHARD - elif device_mesh.ndim == 2: - sharding_strategy = ShardingStrategy.HYBRID_SHARD - else: - raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") - return sharding_strategy - - -class ActorRolloutRefWorker(Worker, DistProfilerExtension): - """ - This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy - or a hybrid engine based on the config.rollout - """ - - def __init__(self, config: DictConfig, role: str, **kwargs): - Worker.__init__(self) - - self.config = config - self.profile_option = kwargs.get("profile_option", None) - import torch.distributed - - if not torch.distributed.is_initialized(): - rank = int(os.environ.get("RANK", 0)) - world_size = int(os.environ.get("WORLD_SIZE", 1)) - torch.distributed.init_process_group( - backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}", - rank=rank, - world_size=world_size, - init_method=os.environ.get("DIST_INIT_METHOD", None), - ) - - # build device mesh for FSDP - world_size = torch.distributed.get_world_size() - # TODO(sgm): support FSDP hybrid shard for larger model - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size) - - # build device mesh for Ulysses Sequence Parallel - self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1) - dp = world_size // self.ulysses_sequence_parallel_size - if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh( - device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] - ) - - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - self._lora_rank = self.config.model.get("lora_rank", 0) - self._is_lora = self._lora_rank > 0 - - self.role = role - assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] - - self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] - self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] - self._is_ref = self.role in ["ref", "actor_rollout_ref"] - - # TODO(haibin.lin): - # As of now the type of config is DictConfig, if we assign config.profiler with ProfilerConfig, - # it will actually convert the ProfilerConfig dataclass back to a DictConfig. - # We can still use ProfilerConfig for testing purpose (tests/utils/test_nvtx_profile.py) - # as they provides DictConfig-like interface - # The benefit of creating the dataclass config is to perform validation during __post_init__ - profiler_config = omega_conf_to_dataclass(config.get("profiler")) - DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=profiler_config, option=self.profile_option) - ) - - self._is_offload_param = False - self._is_offload_optimizer = False - if self._is_actor: - self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False) - self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False) - elif self._is_ref: - # TODO: it seems that manual offload is slowly than FSDP offload - self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False) - - # normalize config - if self._is_actor: - self.config.actor.ppo_mini_batch_size *= self.config.rollout.n - self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size - assert self.config.actor.ppo_mini_batch_size > 0, ( - f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after " - f"normalization" - ) - # micro bsz - if self.config.actor.ppo_micro_batch_size is not None: - self.config.actor.ppo_micro_batch_size //= ( - self.device_mesh.size() // self.ulysses_sequence_parallel_size - ) - self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size - - if self.config.actor.ppo_micro_batch_size_per_gpu is not None: - assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( - f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by " - f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" - ) - assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( - f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than " - f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" - ) - - # normalize rollout config - if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: - self.config.rollout.log_prob_micro_batch_size //= ( - self.device_mesh.size() // self.ulysses_sequence_parallel_size - ) - self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size - # normalize ref config - if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: - self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size - self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size - - def _build_model_optimizer( - self, - model_path, - fsdp_config, - optim_config, - override_model_config, - use_remove_padding=False, - use_fused_kernels=False, - enable_gradient_checkpointing=False, - trust_remote_code=False, - use_liger=False, - role="actor", - enable_activation_offload=False, - ): - from torch import optim - from torch.distributed.fsdp import CPUOffload, MixedPrecision - from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq - - from verl.utils.model import get_generation_config, print_model_size, update_model_config - from verl.utils.torch_dtypes import PrecisionType - - assert role in ["actor", "ref"] - - log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger) - local_path = model_path - - # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect - # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) - self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code) - - if self.config.model.get("custom_chat_template", None) is not None: - if self.processor is not None: - self.processor.chat_template = self.config.model.custom_chat_template - else: - self.tokenizer.chat_template = self.config.model.custom_chat_template - - torch_dtype = fsdp_config.get("model_dtype", None) - if torch_dtype is None: - torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 - else: - torch_dtype = PrecisionType.to_dtype(torch_dtype) - - # override model kwargs - actor_model_config = AutoConfig.from_pretrained( - local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2" - ) - - # patch for kimi-vl - if getattr(actor_model_config, "model_type", None) == "kimi_vl": - actor_model_config.text_config.topk_method = "greedy" - - self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code) - - override_config_kwargs = { - "bos_token_id": self.tokenizer.bos_token_id, - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - } - override_config_kwargs.update(override_model_config) - update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs) - if self.rank == 0: - print(f"Model config after override: {actor_model_config}") - - # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang - init_context = get_init_weight_context_manager( - use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh - ) - - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys(): - actor_module_class = AutoModelForVision2Seq - else: - actor_module_class = AutoModelForCausalLM - - actor_module = actor_module_class.from_pretrained( - pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, - config=actor_model_config, - trust_remote_code=trust_remote_code, - ) - - # Apply Liger kernel to the model if use_liger is set to True - if use_liger: - from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance - - _apply_liger_kernel_to_instance(model=actor_module) - - fused_kernel_options = self.config.model.get("fused_kernel_options", None) - fused_kernels_backend = ( - fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None - ) - - apply_monkey_patch( - model=actor_module, - use_remove_padding=use_remove_padding, - ulysses_sp_size=self.ulysses_sequence_parallel_size, - use_fused_kernels=use_fused_kernels, - fused_kernels_backend=fused_kernels_backend, - ) - - # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2 - actor_module.to(torch_dtype) - - if enable_gradient_checkpointing: - actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - if self._is_lora: - print("Applying LoRA to actor module") - actor_module.enable_input_require_grads() - # Convert config to regular Python types before creating PEFT model - lora_config = { - "task_type": TaskType.CAUSAL_LM, - "r": self.config.model.lora_rank, - "lora_alpha": self.config.model.lora_alpha, - "target_modules": convert_to_regular_types(self.config.model.target_modules), - "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules), - "bias": "none", - } - actor_module = get_peft_model(actor_module, LoraConfig(**lora_config)) - torch.distributed.barrier() - - if self.rank == 0: - print_model_size(actor_module) - - log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger) - - # We wrap FSDP for rollout as well - mixed_precision_config = fsdp_config.get("mixed_precision", None) - if mixed_precision_config is not None: - param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) - buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) - else: - param_dtype = torch.bfloat16 - reduce_dtype = torch.float32 - buffer_dtype = torch.float32 - - mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) - - auto_wrap_policy = get_fsdp_wrap_policy( - module=actor_module, - config=fsdp_config.get("wrap_policy", None), - is_lora=self.config.model.get("lora_rank", 0) > 0, - ) - - if self._is_rollout and self.config.rollout.name == "hf": - # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma - auto_wrap_policy = None - - if self.rank == 0: - print(f"wrap_policy: {auto_wrap_policy}") - - fsdp_mesh = self.device_mesh - sharding_strategy = get_sharding_strategy(fsdp_mesh) - - # TODO: add transformer policy - # We force reference policy to use CPUOffload to save memory. - # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation - cpu_offload = None if role == "actor" else CPUOffload(offload_params=True) - fsdp_strategy = self.config.actor.strategy - if fsdp_strategy == "fsdp": - actor_module_fsdp = FSDP( - actor_module, - cpu_offload=cpu_offload, - param_init_fn=init_fn, - auto_wrap_policy=auto_wrap_policy, - device_id=get_device_id(), - sharding_strategy=sharding_strategy, # zero3 - mixed_precision=mixed_precision, - sync_module_states=True, - device_mesh=self.device_mesh, - use_orig_params=self.config.actor.fsdp_config.get("use_orig_params", False), - forward_prefetch=self.config.actor.fsdp_config.get("forward_prefetch", False), - ) - elif fsdp_strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - mp_policy = MixedPrecisionPolicy( - param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True - ) - if role == "actor" and fsdp_config.offload_policy: - cpu_offload = CPUOffloadPolicy(pin_memory=True) - self._is_offload_param = False - self._is_offload_optimizer = False - else: - cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True) - - fsdp_kwargs = { - "mesh": fsdp_mesh, - "mp_policy": mp_policy, - "offload_policy": cpu_offload, - "reshard_after_forward": fsdp_config.reshard_after_forward, - } - full_state = actor_module.state_dict() - apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config) - fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload) - actor_module_fsdp = actor_module - else: - raise NotImplementedError(f"not implement {fsdp_strategy}") - - if enable_activation_offload: - enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing) - - log_gpu_memory_usage(f"After {role} FSDP init", logger=logger) - - # TODO: add more optimizer args into config - if role == "actor" and optim_config is not None: - from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup - - actor_optimizer = optim.AdamW( - actor_module_fsdp.parameters(), - lr=optim_config.lr, - betas=optim_config.get("betas", (0.9, 0.999)), - weight_decay=optim_config.get("weight_decay", 1e-2), - ) - - total_steps = optim_config.get("total_training_steps", 0) - num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1)) - warmup_style = optim_config.get("warmup_style", "constant") - min_lr_ratio = optim_config.get("min_lr_ratio", 0.0) - num_cycles = optim_config.get("num_cycles", 0.5) - if num_warmup_steps < 0: - num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0) - num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - - if self.rank == 0: - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - - if warmup_style == "constant": - actor_lr_scheduler = get_constant_schedule_with_warmup( - optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps - ) - elif warmup_style == "cosine": - actor_lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=actor_optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=total_steps, - min_lr_ratio=min_lr_ratio, - num_cycles=num_cycles, - ) - else: - raise NotImplementedError(f"Warmup style {warmup_style} is not supported") - - log_gpu_memory_usage(f"After {role} optimizer init", logger=logger) - else: - actor_optimizer = None - actor_lr_scheduler = None - - return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config - - def _build_rollout(self, trust_remote_code=False): - from torch.distributed.device_mesh import init_device_mesh - - # TODO(sgm): support FSDP hybrid shard for larger model - infer_tp = self.config.rollout.tensor_model_parallel_size - dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, ( - f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - ) - rollout_device_mesh = init_device_mesh( - device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] - ) - rollout_name = self.config.rollout.name - if rollout_name == "hf": - from verl.workers.rollout import HFRollout - from verl.workers.sharding_manager.base import BaseShardingManager - - rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout) - rollout_sharding_manager = BaseShardingManager() - # TODO: a sharding manager that do nothing? - - elif rollout_name == "vllm": - from verl.workers.rollout.vllm_rollout import vLLMRollout - from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager - - log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) - local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False)) - lora_kwargs = ( - {"lora_kwargs": {"enable_lora": True, "max_loras": 1, "max_lora_rank": self._lora_rank}} - if self._is_lora - else {} - ) - # lora_kwargs = {} - from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout - - vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout - rollout = vllm_rollout_cls( - model_path=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - device_mesh=rollout_device_mesh, - trust_remote_code=trust_remote_code, - **lora_kwargs, - ) - - log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) - full_params = torch.distributed.get_world_size() == 1 - rollout_sharding_manager = FSDPVLLMShardingManager( - module=self.actor_module_fsdp, - inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - rollout_config=self.config.rollout, - full_params=full_params, - device_mesh=rollout_device_mesh, - offload_param=self._is_offload_param, - load_format=self.config.rollout.load_format, - layered_summon=self.config.rollout.get("layered_summon", False), - ) - log_gpu_memory_usage("After building sharding manager", logger=logger) - - elif rollout_name == "sglang": - from verl.workers.rollout.sglang_rollout import SGLangRollout - - # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to - # SGLang's model_runner would check CUDA device capability. However, due to verl's setting, - # the main process of ray can not find any CUDA device, which would potentially lead to: - # "RuntimeError: No CUDA GPUs are available". - # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and - # we import it here use the abs path. - # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 - from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager - - local_path = copy_to_local(self.config.model.path) - log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) - rollout = SGLangRollout( - actor_module=local_path, - config=self.config.rollout, - processing_class=self.processor if self.processor is not None else self.tokenizer, - model_hf_config=self.actor_model_config, - trust_remote_code=trust_remote_code, - ) - log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) - - if torch.distributed.get_world_size() == 1: - self.config.rollout.load_format = "dummy_hf" - rollout_sharding_manager = FSDPSGLangShardingManager( - module=self.actor_module_fsdp, - inference_engine=rollout._engine, - model_config=self.actor_model_config, - rollout_config=self.config.rollout, - full_params="hf" in self.config.rollout.load_format, - device_mesh=rollout_device_mesh, - offload_param=self._is_offload_param, - multi_stage_wake_up=self.config.rollout.multi_stage_wake_up, - ) - log_gpu_memory_usage("After building sharding manager", logger=logger) - - else: - raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported") - - return rollout, rollout_sharding_manager - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - from verl.workers.actor import DataParallelPPOActor - - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - - use_remove_padding = self.config.model.get("use_remove_padding", False) - use_shm = self.config.model.get("use_shm", False) - use_fused_kernels = self.config.model.get("use_fused_kernels", False) - - if self._is_actor or self._is_rollout: - # we need the model for actor and rollout - if self._is_actor: - optim_config = self.config.actor.optim - fsdp_config = self.config.actor.fsdp_config - else: - optim_config = None - fsdp_config = OmegaConf.create() - - local_path = copy_to_local(self.config.model.path, use_shm=use_shm) - ( - self.actor_module_fsdp, - self.actor_optimizer, - self.actor_lr_scheduler, - self.actor_model_config, - ) = self._build_model_optimizer( - model_path=local_path, - fsdp_config=fsdp_config, - optim_config=optim_config, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, - enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False), - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="actor", - enable_activation_offload=self.config.model.get("enable_activation_offload", False), - ) - - # get the original unwrapped module - if fsdp_version(self.actor_module_fsdp) == 1: - self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage("After offload actor model during init", logger=logger) - - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) - - if self._is_actor: - OmegaConf.set_struct(self.config.actor, True) - with open_dict(self.config.actor): - self.config.actor.use_remove_padding = use_remove_padding - self.config.actor.use_fused_kernels = use_fused_kernels - self.actor = DataParallelPPOActor( - config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer - ) - - if self._is_rollout: - self.rollout, self.rollout_sharding_manager = self._build_rollout( - trust_remote_code=self.config.model.get("trust_remote_code", False) - ) - - if self._is_ref: - local_path = copy_to_local(self.config.model.path, use_shm=use_shm) - self.ref_module_fsdp = self._build_model_optimizer( - model_path=local_path, - fsdp_config=self.config.ref.fsdp_config, - optim_config=None, - override_model_config=override_model_config, - use_remove_padding=use_remove_padding, - use_fused_kernels=use_fused_kernels, - trust_remote_code=self.config.model.get("trust_remote_code", False), - use_liger=self.config.model.get("use_liger", False), - role="ref", - )[0] - OmegaConf.set_struct(self.config.ref, True) - with open_dict(self.config.ref): - self.config.ref.use_remove_padding = use_remove_padding - self.config.ref.use_fused_kernels = use_fused_kernels - self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp) - - if self._is_actor: - self.flops_counter = FlopsCounter(self.actor_model_config) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.actor_module_fsdp, - optimizer=self.actor.actor_optimizer, - lr_scheduler=self.actor_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_config=self.config.actor.checkpoint, - ) - - if not self._is_actor and self._is_rollout: - # If ActorRolloutRefWorker is initialized as a standalone rollout, - # create a checkpoint manager for FSDP model to allow loading FSDP checkpoints for rollout. - - checkpoint_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []}) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.actor_module_fsdp, - optimizer=None, - lr_scheduler=None, - processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_config=checkpoint_contents, - ) - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @DistProfiler.annotate(color="red", role="actor_update") - def update_actor(self, data: DataProto): - # Support all hardwares - data = data.to(get_device_id()) - - assert self._is_actor - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id()) - - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) - # perform training - with Timer(name="update_policy", logger=None) as timer: - metrics = self.actor.update_policy(data=data) - delta_time = timer.last - global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/actor"] = ( - estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size - ) - metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) - metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) - - lr = self.actor_lr_scheduler.get_last_lr()[0] - metrics["actor/lr"] = lr - self.actor_lr_scheduler.step() - - # TODO: here, we should return all metrics - output = DataProto(meta_info={"metrics": metrics}) - - output = self.ulysses_sharding_manager.postprocess_data(data=output) - output = output.to("cpu") - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage("After offload actor model during update_actor", logger=logger) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) - - return output - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @DistProfiler.annotate(color="red", role="rollout_generate") - def generate_sequences(self, prompts: DataProto): - # Support all hardwares - prompts = prompts.to(get_device_id()) - - assert self._is_rollout - - meta_info = { - "eos_token_id": self.generation_config.eos_token_id - if self.generation_config is not None - else self.tokenizer.eos_token_id, - "pad_token_id": self.generation_config.pad_token_id - if self.generation_config is not None - else self.tokenizer.pad_token_id, - } - prompts.meta_info.update(meta_info) - timing_generate = {} - with self.rollout_sharding_manager: - log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) - - prompts = self.rollout_sharding_manager.preprocess_data(prompts) - with simple_timer("generate_sequences", timing_generate): - output = self.rollout.generate_sequences(prompts=prompts) - - log_gpu_memory_usage("After rollout generation", logger=logger) - - output = self.rollout_sharding_manager.postprocess_data(output) - - timing_generate.update(self.rollout_sharding_manager.timing) - # We calculate the average timing across all ranks - # to make sure meta_info["timing"] is the same - timing_generate = reduce_timing(timing_generate) - output.meta_info["timing"] = timing_generate - output = output.to("cpu") - - # clear kv cache - get_torch_device().empty_cache() - return output - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") - def compute_log_prob(self, data: DataProto): - # when is_lora is True, we use the actor without lora applied to calculate the log_prob - # which is mostly used for ref log_prob calculation - assert self._is_actor - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - - # Support all hardwares - from contextlib import nullcontext - - is_lora = data.meta_info.pop("is_lora", False) - adapter_ctx = self.actor.actor_module.disable_adapter() if is_lora else nullcontext() - data = data.to(get_device_id()) - # we should always recompute old_log_probs when it is HybridEngine - data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz - data.meta_info["temperature"] = self.config.rollout.temperature - # perform recompute log_prob - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data) - with adapter_ctx: - output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) - output = DataProto.from_dict( - tensors={"old_log_probs": output, "entropys": entropys}, - meta_info={"temperature": self.config.rollout.temperature}, - ) - output = self.ulysses_sharding_manager.postprocess_data(output) - - output = output.to("cpu") - - # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes - # unshard the root FSDP module - if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1: - self.actor.actor_module._handle.reshard(True) - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger) - - return output - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") - def compute_ref_log_prob(self, data: DataProto): - if self._is_lora: - # if _is_lora, actor without lora applied is the ref - data.meta_info["is_lora"] = True - data = self.compute_log_prob(data) - # this old_log_probs is in fact ref_log_prob - data = DataProto.from_dict(tensors={"ref_log_prob": data.batch["old_log_probs"]}) - return data - assert self._is_ref - # else: - # otherwise, the class have a standalone ref model - # Support all hardwares - data = data.to(get_device_id()) - - micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu - data.meta_info["micro_batch_size"] = micro_batch_size - data.meta_info["temperature"] = self.config.rollout.temperature - data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data) - output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) - output = DataProto.from_dict(tensors={"ref_log_prob": output}) - output = self.ulysses_sharding_manager.postprocess_data(output) - - output = output.to("cpu") - - # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes - # unshard the root FSDP module - if self.world_size > 1 and fsdp_version(self.ref_policy.actor_module) == 1: - self.ref_policy.actor_module._handle.reshard(True) - - return output - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): - from verl.utils.logger import log_with_rank - - # only support save and load ckpt for actor - assert self._is_actor - - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - - self.checkpoint_manager.save_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep - ) - dist.barrier() - - if self._is_lora and hasattr(getattr(self, "actor_module", self.actor_module_fsdp), "peft_config"): - lora_save_path = os.path.join(local_path, "lora_adapter") - peft_model = getattr(self, "actor_module", self.actor_module_fsdp) - peft_config = {} - if dist.get_rank() == 0: - os.makedirs(lora_save_path, exist_ok=True) - peft_config = asdict(peft_model.peft_config.get("default", {})) - peft_config["task_type"] = peft_config["task_type"].value - peft_config["peft_type"] = peft_config["peft_type"].value - peft_config["target_modules"] = list(peft_config["target_modules"]) - try: - if fsdp_version(self.actor_module_fsdp) > 0: - self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name()) - lora_params = layered_summon_lora_params(self.actor_module_fsdp) - if dist.get_rank() == 0: - save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors")) - with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f: - json.dump(peft_config, f, ensure_ascii=False, indent=4) - except Exception as e: - log_with_rank( - f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=True - ) - - dist.barrier() - log_with_rank( - f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}", - rank=dist.get_rank(), - logger=logger, - log_only_rank_0=True, - ) - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): - assert self._is_actor or (not self._is_actor and self._is_rollout), ( - f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got " - f"{self._is_actor} and {self._is_rollout}" - ) - - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) - - self.checkpoint_manager.load_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load - ) - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - - if self._is_offload_optimizer: - offload_fsdp_optimizer(self.actor_optimizer) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def start_profile(self, **kwargs) -> None: - """Start profiling for the current rank in the current training step.""" - self.profiler.start(**kwargs) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def stop_profile(self) -> None: - """Stop profiling for the current rank in the current training step.""" - self.profiler.stop() - - -class CriticWorker(Worker, DistProfilerExtension): - def __init__(self, config): - Worker.__init__(self) - DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) - ) - import torch.distributed - - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group( - backend=get_nccl_backend(), init_method=os.environ.get("DIST_INIT_METHOD", None) - ) - self.config = config - - # build device mesh for Ulysses Sequence Parallel - world_size = torch.distributed.get_world_size() - from torch.distributed.device_mesh import init_device_mesh - - fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) - - self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - dp = world_size // self.ulysses_sequence_parallel_size - if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh( - device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] - ) - - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - - # set FSDP offload params - self._is_offload_param = self.config.model.fsdp_config.param_offload - self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload - - # normalize config - self.config.ppo_mini_batch_size *= self.config.rollout_n - self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size - if self.config.ppo_micro_batch_size is not None: - self.config.ppo_micro_batch_size //= ( - torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size - ) - self.config.forward_micro_batch_size //= ( - torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size - ) - self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size - self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size - - if self.config.ppo_micro_batch_size_per_gpu is not None: - assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( - f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " - f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" - ) - assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( - f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " - f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" - ) - self._is_lora = self.config.model.get("lora_rank", 0) > 0 - - def _build_critic_model_optimizer(self, config): - # the following line is necessary - from torch import optim - from torch.distributed.fsdp import MixedPrecision - - from verl.utils.model import load_valuehead_model, print_model_size - from verl.utils.torch_dtypes import PrecisionType - - use_shm = config.model.get("use_shm", False) - local_path = copy_to_local(config.model.path, use_shm=use_shm) - # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info - # using random initialized model from any architecture. May not be the same as Actor. - - tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) - self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) - self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) - - if self.config.model.get("custom_chat_template", None) is not None: - if self.processor is not None: - self.processor.chat_template = self.config.model.custom_chat_template - else: - self.tokenizer.chat_template = self.config.model.custom_chat_template - - override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - override_config_kwargs = { - "bos_token_id": self.tokenizer.bos_token_id, - "eos_token_id": self.tokenizer.eos_token_id, - "pad_token_id": self.tokenizer.pad_token_id, - } - override_config_kwargs.update(override_config) - if self.rank == 0: - print(f"Critic overriding config {override_config_kwargs}") - - torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") - torch_dtype = PrecisionType.to_dtype(torch_dtype) - - from transformers import AutoConfig - - critic_model_config = AutoConfig.from_pretrained( - local_path, - attn_implementation="flash_attention_2", - trust_remote_code=config.model.get("trust_remote_code", False), - ) - critic_model_config.num_labels = 1 - # patch for kimi-vl - if getattr(critic_model_config, "model_type", None) == "kimi_vl": - critic_model_config.text_config.topk_method = "greedy" - - init_context = get_init_weight_context_manager( - use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh - ) - - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - critic_model_config.classifier_dropout = 0.0 - critic_model_config.hidden_dropout = "0" - critic_model_config.summary_dropout_prob = 0.0 - - critic_module = load_valuehead_model( - local_path, - torch_dtype, - critic_model_config, - config.model.get("trust_remote_code", False), - ) - - use_remove_padding = config.model.get("use_remove_padding", False) - - apply_monkey_patch( - model=critic_module, - use_remove_padding=use_remove_padding, - ulysses_sp_size=self.ulysses_sequence_parallel_size, - ) - - # some parameters may not in torch_dtype - critic_module.to(torch_dtype) - - if config.model.get("enable_gradient_checkpointing", False): - critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - - if self._is_lora: - print("Applying LoRA to critic module") - critic_module.enable_input_require_grads() - # Convert config to regular Python types before creating PEFT model - lora_config = { - "task_type": TaskType.CAUSAL_LM, - "r": self.config.model.lora_rank, - "lora_alpha": self.config.model.lora_alpha, - "target_modules": convert_to_regular_types(self.config.model.target_modules), - "bias": "none", - } - critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) - - if self.rank == 0: - print_model_size(critic_module) - - self.critic_model_config = critic_model_config - - fsdp_config = self.config.model.fsdp_config - mixed_precision_config = fsdp_config.get("mixed_precision", None) - if mixed_precision_config is not None: - param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) - reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) - buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) - else: - param_dtype = torch.bfloat16 - reduce_dtype = torch.float32 - buffer_dtype = torch.float32 - - mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) - - auto_wrap_policy = get_fsdp_wrap_policy( - module=critic_module, - config=self.config.model.fsdp_config.wrap_policy, - is_lora=self.config.model.get("lora_rank", 0) > 0, - ) - - log_gpu_memory_usage("Before critic FSDP", logger=None) - - fsdp_mesh = self.device_mesh - sharding_strategy = get_sharding_strategy(fsdp_mesh) - - # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation - if config.strategy == "fsdp": - critic_module = FSDP( - critic_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=get_device_id(), - sharding_strategy=sharding_strategy, - mixed_precision=mixed_precision, - sync_module_states=True, - forward_prefetch=self.config.model.fsdp_config.forward_prefetch, - device_mesh=self.device_mesh, - cpu_offload=None, - ) - elif config.strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - mp_policy = MixedPrecisionPolicy( - param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True - ) - offload_policy = None - if fsdp_config.offload_policy: - self._is_offload_param = False - self._is_offload_optimizer = False - offload_policy = CPUOffloadPolicy(pin_memory=True) - - fsdp_kwargs = { - "mesh": fsdp_mesh, - "mp_policy": mp_policy, - "offload_policy": offload_policy, - "reshard_after_forward": fsdp_config.reshard_after_forward, - } - full_state = critic_module.state_dict() - apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) - fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) - else: - raise NotImplementedError(f"Unknown strategy {config.strategy}") - - if config.model.get("enable_activation_offload", False): - enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) - enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing) - - log_gpu_memory_usage("After critic FSDP", logger=None) - - critic_optimizer = optim.AdamW( - critic_module.parameters(), - lr=config.optim.lr, - betas=config.optim.get("betas", (0.9, 0.999)), - weight_decay=config.optim.get("weight_decay", 1e-2), - ) - - total_steps = config.optim.get("total_training_steps", 0) - num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) - warmup_style = config.optim.get("warmup_style", "constant") - if num_warmup_steps < 0: - num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) - num_warmup_steps = int(num_warmup_steps_ratio * total_steps) - - if self.rank == 0: - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") - - from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup - - if warmup_style == "constant": - critic_lr_scheduler = get_constant_schedule_with_warmup( - optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps - ) - elif warmup_style == "cosine": - critic_lr_scheduler = get_cosine_schedule_with_warmup( - optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps - ) - else: - raise NotImplementedError(f"Warmup style {warmup_style} is not supported") - - return critic_module, critic_optimizer, critic_lr_scheduler - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - - from verl.workers.critic import DataParallelPPOCritic - - self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( - self.config - ) - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) - log_gpu_memory_usage("After offload critic model during init", logger=logger) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.critic_optimizer) - log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) - - self.critic = DataParallelPPOCritic( - config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer - ) - - self.flops_counter = FlopsCounter(self.critic_model_config) - self.checkpoint_manager = FSDPCheckpointManager( - model=self.critic_module, - optimizer=self.critic_optimizer, - lr_scheduler=self.critic_lr_scheduler, - processing_class=self.processor if self.processor is not None else self.tokenizer, - checkpoint_config=self.config.checkpoint, - ) - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @DistProfiler.annotate(color="cyan") - def compute_values(self, data: DataProto): - # Support all hardwares - data = data.to(get_device_id()) - - if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) - micro_batch_size = self.config.forward_micro_batch_size_per_gpu - data.meta_info["micro_batch_size"] = micro_batch_size - data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz - # perform forward computation - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) - values = self.critic.compute_values(data=data) - output = DataProto.from_dict(tensors={"values": values}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - - output = output.to("cpu") - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) - return output - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @DistProfiler.annotate(color="pink") - def update_critic(self, data: DataProto): - # Support all hardwares - data = data.to(get_device_id()) - if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) - if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id()) - - # perform forward computation - with self.ulysses_sharding_manager: - data = self.ulysses_sharding_manager.preprocess_data(data=data) - - with Timer(name="update_critic", logger=None) as timer: - metrics = self.critic.update_critic(data=data) - delta_time = timer.last - - global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size - - lr = self.critic_lr_scheduler.get_last_lr()[0] - metrics["critic/lr"] = lr - self.critic_lr_scheduler.step() - - output = DataProto(batch=None, meta_info={"metrics": metrics}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.critic_optimizer) - - output = output.to("cpu") - return output - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): - import torch - - if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) - - self.checkpoint_manager.save_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep - ) - - torch.distributed.barrier() - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True): - import torch - - if self._is_offload_param: - load_fsdp_model_to_gpu(self.critic_module) - - self.checkpoint_manager.load_checkpoint( - local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load - ) - - torch.distributed.barrier() - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.critic_module) - - if self._is_offload_optimizer: - offload_fsdp_optimizer(self.critic_optimizer) - - -# TODO(sgm): we may need to extract it to dp_reward_model.py -class RewardModelWorker(Worker, DistProfilerExtension): - """ - Note that we only implement the reward model that is subclass of AutoModelForTokenClassification. - """ - - def __init__(self, config): - Worker.__init__(self) - DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) - ) - - import torch.distributed - - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group( - backend=get_nccl_backend(), init_method=os.environ.get("DIST_INIT_METHOD", None) - ) - self.config = config - - # build device mesh for Ulysses Sequence Parallel - world_size = torch.distributed.get_world_size() - from torch.distributed.device_mesh import init_device_mesh - - fsdp_size = self.config.model.fsdp_config.fsdp_size - self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) - - self.ulysses_device_mesh = None - self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) - dp = world_size // self.ulysses_sequence_parallel_size - if self.ulysses_sequence_parallel_size > 1: - self.ulysses_device_mesh = init_device_mesh( - device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] - ) - - self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) - - self.use_remove_padding = self.config.model.get("use_remove_padding", False) - - # normalize config - if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= torch.distributed.get_world_size() - self.config.micro_batch_size_per_gpu = self.config.micro_batch_size - - def _build_model(self, config): - # the following line is necessary - from torch.distributed.fsdp import CPUOffload - from transformers import AutoConfig, AutoModelForTokenClassification - - use_shm = config.model.get("use_shm", False) - # download the checkpoint from hdfs - local_path = copy_to_local(config.model.path, use_shm=use_shm) - - if self.config.model.input_tokenizer is None: - self._do_switch_chat_template = False - else: - self._do_switch_chat_template = True - input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer, use_shm=use_shm) - self.input_tokenizer = hf_tokenizer( - input_tokenizer_local_path, trust_remote_code=config.model.get("trust_remote_code", False) - ) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get("trust_remote_code", False)) - - trust_remote_code = config.model.get("trust_remote_code", False) - model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) - model_config.num_labels = 1 - - # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect - init_context = get_init_weight_context_manager( - use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh - ) - - with init_context(), warnings.catch_warnings(): - warnings.simplefilter("ignore") - model_config.classifier_dropout = 0.0 - reward_module = AutoModelForTokenClassification.from_pretrained( - pretrained_model_name_or_path=local_path, - config=model_config, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - trust_remote_code=trust_remote_code, - ) - - apply_monkey_patch( - model=reward_module, - use_remove_padding=config.model.get("use_remove_padding", False), - ulysses_sp_size=self.ulysses_sequence_parallel_size, - ) - - reward_module.to(torch.bfloat16) - - auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config) - - fsdp_mesh = self.device_mesh - sharding_strategy = get_sharding_strategy(fsdp_mesh) - - if config.strategy == "fsdp": - reward_module = FSDP( - reward_module, - param_init_fn=init_fn, - use_orig_params=False, - auto_wrap_policy=auto_wrap_policy, - device_id=get_device_id(), - sharding_strategy=sharding_strategy, # zero3 - sync_module_states=True, - cpu_offload=CPUOffload(offload_params=True), - forward_prefetch=self.config.model.fsdp_config.forward_prefetch, - device_mesh=self.device_mesh, - ) - elif config.strategy == "fsdp2": - assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" - cpu_offload = CPUOffloadPolicy(pin_memory=True) - fsdp_kwargs = { - "mesh": fsdp_mesh, - "offload_policy": cpu_offload, - "reshard_after_forward": config.model.fsdp_config.reshard_after_forward, - } - full_state = reward_module.state_dict() - apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config) - fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload) - else: - raise NotImplementedError(f"Unknown strategy: {config.strategy}") - return reward_module - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - # This is used to import external_lib into the huggingface systems - import_external_libs(self.config.model.get("external_lib", None)) - self.reward_module = self._build_model(config=self.config) - - def _forward_micro_batch(self, micro_batch): - if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input - elif is_npu_available: - from transformers.integrations.npu_flash_attention import ( - index_first_axis, - pad_input, - rearrange, - unpad_input, - ) - - from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs - - with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): - input_ids = micro_batch["input_ids"] - batch_size, seqlen = input_ids.shape - attention_mask = micro_batch["attention_mask"] - position_ids = micro_batch["position_ids"] - if position_ids.dim() == 3: # qwen2vl mrope - position_ids = position_ids.transpose(0, 1) # (bsz, 3, seqlen) -> (3, bsz, seqlen) - - if self.use_remove_padding: - input_ids_rmpad, indices, *_ = unpad_input( - input_ids.unsqueeze(-1), attention_mask - ) # input_ids_rmpad (total_nnz, ...) - input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) - - # unpad the position_ids to align the rotary - if position_ids.dim() == 3: - position_ids_rmpad = ( - index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices) - .transpose(0, 1) - .unsqueeze(1) - ) # (3, bsz, seqlen) -> (3, 1, bsz * seqlen) - else: - position_ids_rmpad = index_first_axis( - rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices - ).transpose(0, 1) - - # pad and slice the inputs if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs( - input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size - ) - - # only pass input_ids and position_ids to enable flash_attn_varlen - output = self.reward_module( - input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False - ) - reward_rmpad = output.logits - reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz) - - # gather output if sp > 1 - if self.ulysses_sequence_parallel_size > 1: - reward_rmpad = gather_outputs_and_unpad( - reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size - ) - - # pad it back - rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1) - else: - output = self.reward_module( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False - ) - rm_score = output.logits # (batch_size, seq_len, 1) - rm_score = rm_score.squeeze(-1) - - # extract the result of the last valid token - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - rm_score = rm_score[torch.arange(batch_size), eos_mask_idx] - return rm_score - - def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor): - batch_size = data.batch.batch_size[0] - # expand as token_level_reward - attention_mask = data.batch["attention_mask"] - position_ids = data.batch["position_ids"] - response_length = data.batch["responses"].shape[-1] - if position_ids.dim() == 3: # qwen2vl mrope [bs, 3, seq_len] - position_ids = position_ids[:, 0, :] - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bsz,) - token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype) # (bsz, seqlen) - token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores - - # select the response part - token_level_scores = token_level_scores[:, -response_length:] - - return token_level_scores - - def _switch_chat_template(self, data: DataProto): - - # NOTE: added by Reasoning360. FIXME - raise NotImplementedError( - "We have revised the code to enable the model to not use the chat template. " - "The current code for _switch_chat_template still assumes the chat template is used. " - "So we raise an error here. " - "You can remove this error temporarily if you want to use the chat template." - ) - - src_max_length = data.batch["attention_mask"].shape[-1] - - src_tokenizer = self.input_tokenizer - target_tokenizer = self.tokenizer - - rm_input_ids = [] - rm_attention_mask = [] - - for i in range(data.batch.batch_size[0]): - # extract raw prompt - if isinstance(data.non_tensor_batch["raw_prompt"][i], list): - chat: list = data.non_tensor_batch["raw_prompt"][i] - else: - chat: list = data.non_tensor_batch["raw_prompt"][i].tolist() - - # extract response - response_ids = data.batch["responses"][i] - response_length = response_ids.shape[-1] - valid_response_length = data.batch["attention_mask"][i][-response_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - response = src_tokenizer.decode(valid_response_ids) - # remove bos and eos - response = response.replace(src_tokenizer.eos_token, "") - - chat.append({"role": "assistant", "content": response}) - - prompt_with_chat_template = target_tokenizer.apply_chat_template( - chat, add_generation_prompt=False, tokenize=False - ) - if self.rank == 0 and i == 0: - # for debugging purpose - print(f"Switch template. chat: {prompt_with_chat_template}") - - # the maximum length is actually determined by the reward model itself - max_length = self.config.get("max_length", src_max_length) - if max_length is None: - max_length = src_max_length - - model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) - input_ids, attention_mask = verl_F.postprocess_data( - input_ids=model_inputs["input_ids"], - attention_mask=model_inputs["attention_mask"], - max_length=max_length, - pad_token_id=target_tokenizer.pad_token_id, - left_pad=False, # right padding - truncation=self.config.get("truncation", "right"), - ) # truncate from the right - - rm_input_ids.append(input_ids) - rm_attention_mask.append(attention_mask) - - rm_input_ids = torch.cat(rm_input_ids, dim=0) - rm_attention_mask = torch.cat(rm_attention_mask, dim=0) - - rm_position_ids = compute_position_id_with_mask(rm_attention_mask) - - rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids} - - return DataProto.from_dict(rm_inputs) - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @DistProfiler.annotate(color="brown") - def compute_rm_score(self, data: DataProto): - import itertools - - from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches - - # Support all hardwares - data = data.to(get_device_id()) - if self._do_switch_chat_template: - rm_data = self._switch_chat_template(data) - else: - rm_input_ids = data.batch["input_ids"] - rm_attention_mask = data.batch["attention_mask"] - rm_position_ids = data.batch["position_ids"] - rm_inputs = { - "input_ids": rm_input_ids, - "attention_mask": rm_attention_mask, - "position_ids": rm_position_ids, - } - rm_data = DataProto.from_dict(rm_inputs) - - # Support all hardwares - rm_data.batch = rm_data.batch.to(get_device_id()) - - # perform forward computation - with self.ulysses_sharding_manager: - rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data) - data = self.ulysses_sharding_manager.preprocess_data(data=data) - - use_dynamic_bsz = self.config.use_dynamic_bsz - if use_dynamic_bsz: - max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size - micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) - else: - micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) - output = [] - for micro_batch in micro_batches: - rm_score = self._forward_micro_batch(micro_batch) - output.append(rm_score) - scores = torch.cat(output, dim=0) # (batch_size) - - if use_dynamic_bsz: - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - scores = scores[revert_indices] - - token_level_scores = self._expand_to_token_level(data, scores) - # Note that this is only the scores, may not be the final rewards used to train RL - output = DataProto.from_dict(tensors={"rm_scores": token_level_scores}) - output = self.ulysses_sharding_manager.postprocess_data(data=output) - - # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes - # unshard the root FSDP module - if self.world_size > 1 and fsdp_version(self.reward_module) == 1: - self.reward_module._handle.reshard(True) - - output = output.to("cpu") - return output - - -# ================================= Async related workers ================================= -class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): - def _build_rollout(self, trust_remote_code=False): - rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) - - # NOTE: rollout is not actually initialized here, it's deferred - # to be initialized by AsyncvLLMServer. - - self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size - self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size - self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size - - # used for sleep/wake_up - rollout.sharding_manager = rollout_sharding_manager - - return rollout, rollout_sharding_manager - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def generate_sequences(self, prompts: DataProto): - raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences") - - # ============================ vLLM related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def execute_method(self, method: str | bytes, *args, **kwargs): - """Called by ExternalRayDistributedExecutor collective_rpc.""" - return self.rollout.execute_method(method, *args, **kwargs) - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def get_zeromq_address(self): - return self.rollout.get_zeromq_address() - - # ============================ SGLang related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def chat_completion(self, json_request): - ret = await self.rollout.chat_completion(json_request) - return ret - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: - ret = await self.rollout.generate(prompt_ids, sampling_params, request_id) - return ret - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def wake_up(self): - if self.config.rollout.free_cache_engine: - await self.rollout.wake_up() - # return something to block the caller - return True - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def sleep(self): - if self.config.rollout.free_cache_engine: - await self.rollout.sleep() - # return something to block the caller - return True diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py deleted file mode 100644 index de7267dc9..000000000 --- a/verl/workers/megatron_workers.py +++ /dev/null @@ -1,1200 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -The main entry point to run the PPO algorithm -""" - -import datetime -import logging -import os -import time -from typing import Any - -import psutil -import torch -import torch.distributed -from codetiming import Timer -from megatron.core import parallel_state as mpu -from omegaconf import DictConfig, OmegaConf, open_dict - -from verl import DataProto -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.base.megatron.worker import MegatronWorker -from verl.utils import hf_tokenizer -from verl.utils.checkpoint.megatron_checkpoint_manager import MegatronCheckpointManager -from verl.utils.config import omega_conf_to_dataclass -from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device -from verl.utils.flops_counter import FlopsCounter -from verl.utils.fs import copy_to_local -from verl.utils.megatron_utils import ( - load_megatron_model_to_gpu, - load_megatron_optimizer, - offload_megatron_model_to_cpu, - offload_megatron_optimizer, -) -from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights -from verl.utils.profiler import ( - DistProfiler, - DistProfilerExtension, - GPUMemoryLogger, - log_gpu_memory_usage, - simple_timer, -) -from verl.utils.profiler.performance import reduce_timing -from verl.utils.torch_functional import broadcast_dict_tensor # NOTE: added by Reasoning360 -from verl.workers.actor.megatron_actor import MegatronPPOActor -from verl.workers.critic.megatron_critic import MegatronPPOCritic -from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -def set_random_seed(seed): - import random - - import numpy as np - import torch - - torch.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) - if get_torch_device().device_count() > 0: - from megatron.core import tensor_parallel - - tensor_parallel.model_parallel_cuda_manual_seed(seed) - # FIXME: torch cumsum not support deterministic (used in vllm sampler), - # https://github.com/pytorch/pytorch/issues/89492 - # torch.use_deterministic_algorithms(True, warn_only=True) - # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' - - -def megatron_pp_dummy_output(data: DataProto): - from verl.single_controller.base.decorator import _make_dummy_data_proto - - if ( - mpu.get_pipeline_model_parallel_rank() != mpu.get_pipeline_model_parallel_world_size() - 1 # not the last stage - or mpu.get_tensor_model_parallel_rank() != 0 # not the first tensor parallel rank - ): - return _make_dummy_data_proto(data) - return data - - -class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension): - """ - This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy - or a hybrid engine based on the config.rollout - """ - - def __init__(self, config: DictConfig, role: str, **kwargs): - MegatronWorker.__init__(self) - self.config = config - - # NOTE(sgm): We utilize colocate WorkerGroup by default. - # As a result, Workers for different model share the same process. - # Therefore, we only require one distribute initialization. - # To utilize different parallel strategy in different models: - # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, - # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 - if not torch.distributed.is_initialized(): - rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group( - backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), - init_method=os.environ.get("DIST_INIT_METHOD", None), - ) - get_torch_device().set_device(rank) - - if self.config.actor.megatron.sequence_parallel: - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - mpu.initialize_model_parallel( - tensor_model_parallel_size=self.config.actor.megatron.tensor_model_parallel_size, - pipeline_model_parallel_size=self.config.actor.megatron.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=self.config.actor.megatron.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_split_rank=None, - use_sharp=False, - context_parallel_size=self.config.actor.megatron.context_parallel_size, - expert_model_parallel_size=self.config.actor.megatron.expert_model_parallel_size, - expert_tensor_parallel_size=self.config.actor.megatron.expert_tensor_parallel_size, - nccl_communicator_config_path=None, - ) - - set_random_seed(seed=self.config.actor.megatron.seed) - - self.role = role - assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"] - - self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"] - self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"] - self._is_ref = self.role in ["ref", "actor_rollout_ref"] - - profiler_config = omega_conf_to_dataclass(config.get("profiler")) - DistProfilerExtension.__init__(self, DistProfiler(rank=self.rank, config=profiler_config)) - - # TODO(sgm): Currently, we only support reference model param offload - # will support other offload later - self._is_offload_param = False - self._is_offload_grad = False - self._is_offload_optimizer = False - - # normalize config - if self._is_actor and self._is_rollout: - self.config.actor.ppo_mini_batch_size *= self.config.rollout.n - self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() - if self.config.actor.get("ppo_micro_batch_size", None): - self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size - self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size - - self._is_offload_param = self.config.actor.megatron.get("param_offload", False) - self._is_offload_grad = self.config.actor.megatron.get("grad_offload", False) - self._is_offload_optimizer = self.config.actor.megatron.get("optimizer_offload", False) - elif self._is_ref: - if self.config.ref.get("log_prob_micro_batch_size", None): - self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size - else: - assert self.config.ref.get("log_prob_micro_batch_size_per_gpu", None) is not None, ( - "Please note that in the ref policy configuration, `log_prob_micro_batch_size_per_gpu` and " - "`log_prob_micro_batch_size` should not be None at the same time." - ) - self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False) - - def _build_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config): - from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler - from verl.utils.megatron_utils import get_model, init_megatron_optim_config - from verl.utils.model import get_generation_config, print_model_size - - self._init_hf_config_and_tf_config( - model_path, - model_path, - self.dtype, - override_model_config, - override_transformer_config, - self.config.model.get("trust_remote_code", False), - self.config.actor.megatron.use_mbridge, - ) - self.generation_config = get_generation_config(self.local_path) - - def make_model(wrap_with_ddp=False): - if self.bridge is not None: - from verl.models.mcore.mbridge import freeze_moe_router - - post_model_creation_callbacks = [] - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): - post_model_creation_callbacks.append(freeze_moe_router) - return self.bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=wrap_with_ddp - ) - else: - - def megatron_actor_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - self.tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - value=False, - freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False), - ) - parallel_model.to(get_device_name()) - return parallel_model - - override_ddp_config = OmegaConf.to_container( - self.config.actor.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True - ) - return get_model( - megatron_actor_model_provider, - wrap_with_ddp=wrap_with_ddp, - use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, - override_ddp_config=override_ddp_config, - ) - - if self._is_actor and self._is_rollout: - actor_module = make_model(wrap_with_ddp=True) - print(f"actor_module: {len(actor_module)}") - if self.config.actor.load_weight: - if self.config.actor.megatron.use_dist_checkpointing: - load_mcore_dist_weights( - actor_module, self.config.actor.megatron.dist_checkpointing_path, is_value_model=False - ) - else: - if self.bridge is not None: - local_model_path = get_hf_model_path(self.config) - self.bridge.load_weights(actor_module, local_model_path) - else: - load_megatron_gptmodel_weights( - self.config, self.hf_config, actor_module, params_dtype=self.dtype, is_value_model=False - ) - - if self.rank == 0: - print_model_size(actor_module[0]) - log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) - elif self._is_ref: - print(f"self.config.ref.load_weight: {self.config.ref.load_weight}") - ref_module = make_model(wrap_with_ddp=False) - if self.config.ref.load_weight: # should align with the actor: - assert self.config.actor.load_weight == self.config.ref.load_weight - print("load ref weight start") - if self.config.ref.megatron.use_dist_checkpointing: - load_mcore_dist_weights( - ref_module, self.config.ref.megatron.dist_checkpointing_path, is_value_model=False - ) - else: - if self.bridge is not None: - local_model_path = get_hf_model_path(self.config) - self.bridge.load_weights(ref_module, local_model_path) - else: - load_megatron_gptmodel_weights( - self.config, self.hf_config, ref_module, params_dtype=self.dtype, is_value_model=False - ) - log_gpu_memory_usage("After ref module init", logger=logger) - return ref_module, self.hf_config - - # TODO: add more optimizer args into config - if self._is_actor: - optim_config_megatron = init_megatron_optim_config(optim_config) - actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config_megatron) - actor_optimizer_scheduler = get_megatron_optimizer_param_scheduler( - optimizer=actor_optimizer, config=optim_config - ) - else: - optim_config = None - actor_optimizer = None - actor_optimizer_scheduler = None - - log_gpu_memory_usage("After actor optimizer init", logger=logger) - - return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config - - def _build_rollout(self, trust_remote_code=False): - from torch.distributed.device_mesh import init_device_mesh - - layer_name_mapping = { - "qkv_layer_name": "self_attention.linear_qkv.", - "gate_proj_layer_name": "linear_fc1.", - } - if self.config.rollout.name == "vllm": - from torch.distributed.device_mesh import init_device_mesh - - from verl.workers.rollout.vllm_rollout import vLLMRollout - from verl.workers.sharding_manager.megatron_vllm import MegatronVLLMShardingManager - - # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor, - # we will reorganize their weight format when resharding from actor to rollout. - - infer_tp = self.config.rollout.tensor_model_parallel_size - dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, ( - f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - ) - rollout_device_mesh = init_device_mesh( - get_device_name(), mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] - ) - log_gpu_memory_usage("Before building vllm rollout", logger=None) - - local_path = copy_to_local(self.config.model.path, use_shm=self.config.model.get("use_shm", False)) - from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout - - vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout - rollout = vllm_rollout_cls( - model_path=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - device_mesh=rollout_device_mesh, - trust_remote_code=trust_remote_code, - ) - log_gpu_memory_usage("After building vllm rollout", logger=logger) - - # perform weight resharding between actor and rollout - from verl.models.mcore import get_mcore_weight_converter - - weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) - sharding_manager = MegatronVLLMShardingManager( - inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - transformer_config=self.tf_config, - rollout_config=self.config.rollout, - layer_name_mapping=layer_name_mapping, - actor_module=self.actor.actor_module, - weight_converter=weight_converter, - device_mesh=rollout_device_mesh, - offload_param=self._is_offload_param, - bridge=self.bridge, - ) - log_gpu_memory_usage("After building sharding manager", logger=logger) - - elif self.config.rollout.name == "sglang": - from verl.workers.rollout.sglang_rollout import SGLangRollout - - # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's - # model_runner would check CUDA device capability. - # However, due to verl's setting, the main process of ray can not find any CUDA device, which would - # potentially lead to: "RuntimeError: No CUDA GPUs are available". - # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it - # here use the abs path. - # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 - from verl.workers.sharding_manager.megatron_sglang import MegatronSGLangShardingManager - - infer_tp = self.config.rollout.tensor_model_parallel_size - dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, ( - f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - ) - rollout_device_mesh = init_device_mesh( - "cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp") - ) - - local_path = copy_to_local(self.config.model.path) - log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None) - rollout = SGLangRollout( - actor_module=local_path, - config=self.config.rollout, - processing_class=self.processor if self.processor is not None else self.tokenizer, - model_hf_config=self.actor_model_config, - trust_remote_code=trust_remote_code, - device_mesh=rollout_device_mesh, - ) - log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None) - - from verl.models.mcore import get_mcore_weight_converter - - weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) - sharding_manager = MegatronSGLangShardingManager( - actor_module=self.actor.actor_module, - inference_engine=rollout._engine, - model_config=self.actor_model_config, - rollout_config=self.config.rollout, - transformer_config=self.tf_config, - layer_name_mapping=layer_name_mapping, - weight_converter=weight_converter, - bridge=self.bridge, - device_mesh=rollout_device_mesh, - offload_param=self._is_offload_param, - ) - log_gpu_memory_usage("After building sharding manager", logger=logger) - else: - raise NotImplementedError("Only vllmRollout is supported with Megatron now") - print(f"rollout and sharding manager init done sharding_manager: {sharding_manager}") - return rollout, sharding_manager - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - if self.config.model.get("external_lib", None) is not None: - # This is used to import external_lib into the huggingface systems - import importlib - - importlib.import_module(self.config.model.external_lib) - - from verl.utils.torch_dtypes import PrecisionType - - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - if self._is_actor: - override_transformer_config = OmegaConf.to_container( - self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True - ) - elif self._is_ref: - override_transformer_config = OmegaConf.to_container( - self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True - ) - else: - override_transformer_config = None - self.param_dtype = torch.bfloat16 - log_gpu_memory_usage("Before init actor model and optimizer", logger=logger) - self.dtype = PrecisionType.to_dtype(self.param_dtype) - if self._is_actor or self._is_rollout: - # we need the model for actor and rollout - optim_config = self.config.actor.optim if self._is_actor else None - ( - self.actor_module, - self.actor_optimizer, - self.actor_optimizer_scheduler, - self.actor_model_config, - self.actor_optim_config, - ) = self._build_model_optimizer( - model_path=self.config.model.path, - optim_config=optim_config, - override_model_config=override_model_config, - override_transformer_config=override_transformer_config, - ) - if self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) - log_gpu_memory_usage("After offload actor params and grad during init", logger=logger) - if self._is_offload_optimizer: - offload_megatron_optimizer(self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during init", logger=logger) - - if self._is_actor: - OmegaConf.set_struct(self.config.actor, True) - with open_dict(self.config.actor): - use_fused_kernels = self.config.model.get("use_fused_kernels", False) - self.config.actor.use_fused_kernels = use_fused_kernels - self.actor = MegatronPPOActor( - config=self.config.actor, - model_config=self.actor_model_config, - hf_config=self.hf_config, - tf_config=self.tf_config, - actor_module=self.actor_module, - actor_optimizer=self.actor_optimizer, - ) - log_gpu_memory_usage("After MegatronPPOActor init", logger=logger) - - if self._is_rollout: - self.rollout, self.sharding_manager = self._build_rollout( - trust_remote_code=self.config.model.get("trust_remote_code", False) - ) - # used for sleep/wake_up - self.rollout.sharding_manager = self.sharding_manager - log_gpu_memory_usage("After rollout init", logger=logger) - - if self._is_ref: - self.ref_module, self.ref_model_config = self._build_model_optimizer( - model_path=self.config.model.path, - optim_config=None, - override_model_config=override_model_config, - override_transformer_config=override_transformer_config, - ) - log_gpu_memory_usage("After ref model init", logger=logger) - self.ref_policy = MegatronPPOActor( - config=self.config.ref, - model_config=self.ref_model_config, - hf_config=self.hf_config, - tf_config=self.tf_config, - actor_module=self.ref_module, - actor_optimizer=None, - ) - if self._ref_is_offload_param: - offload_megatron_model_to_cpu(self.ref_module) - log_gpu_memory_usage("After offload ref params during init", logger=logger) - - if self._is_actor: - self.flops_counter = FlopsCounter(self.actor_model_config) - self.checkpoint_mananager = MegatronCheckpointManager( - config=self.config, - checkpoint_config=self.config.actor.checkpoint, - model_config=self.actor_model_config, - transformer_config=self.tf_config, - role="actor", - model=self.actor_module, - arch=self.architectures[0], - hf_config=self.hf_config, - param_dtype=self.param_dtype, - share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, - processing_class=self.processor if self.processor is not None else self.tokenizer, - optimizer=self.actor_optimizer, - optimizer_scheduler=self.actor_optimizer_scheduler, - use_distributed_optimizer=self.config.actor.megatron.use_distributed_optimizer, - use_checkpoint_opt_param_scheduler=self.config.actor.optim.use_checkpoint_opt_param_scheduler, - bridge=self.bridge, - use_dist_checkpointing=self.config.actor.megatron.use_dist_checkpointing, - ) - get_torch_device().empty_cache() - log_gpu_memory_usage("After init_model finish", logger=logger) - - # Modified by Reasoning360. - @register(dispatch_mode=Dispatch.MEGATRON_PP_DUMMY_PROTO) - @GPUMemoryLogger(role="update_actor", logger=logger) - @DistProfiler.annotate(color="red") - def update_actor(self, data: DataProto): - assert self._is_actor - if self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module) - log_gpu_memory_usage("After load actor params and grad during update_actor", logger=logger) - if self._is_offload_optimizer: - load_megatron_optimizer(self.actor_optimizer) - log_gpu_memory_usage("After load actor optimizer during update_actor", logger=logger) - data.batch = data.batch.to(get_device_name()) - - # NOTE: added by Reasoning360. - broadcast_dict_tensor( - data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() - ) - - micro_batch_size = self.config.actor.ppo_micro_batch_size_per_gpu - data.meta_info["micro_batch_size"] = micro_batch_size - dataloader = self.actor.make_minibatch_iterator(data=data) - with Timer(name="update_policy", logger=None) as timer: - metrics = self.actor.update_policy(dataloader=dataloader) - delta_time = timer.last - global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size - metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) - metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) - from verl.utils.megatron.optimizer import get_megatron_last_lr - - metrics["actor/lr"] = get_megatron_last_lr(self.actor_optimizer) - self.actor_optimizer_scheduler.step(1) - - # TODO: here, we should return all metrics - output = DataProto(meta_info={"metrics": metrics}) - output = output.to("cpu") - - if self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) - log_gpu_memory_usage("After offload actor params and grad during update_actor", logger=logger) - if self._is_offload_optimizer: - offload_megatron_optimizer(self.actor_optimizer) - log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger) - - get_torch_device().empty_cache() - return output - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - @GPUMemoryLogger(role="generate_sequences", logger=logger) - @DistProfiler.annotate(color="red") - def generate_sequences(self, prompts: DataProto): - assert self._is_rollout - prompts.batch = prompts.batch.to(get_device_name()) - meta_info = { - "eos_token_id": self.generation_config.eos_token_id - if self.generation_config is not None - else self.tokenizer.eos_token_id, - "pad_token_id": self.generation_config.pad_token_id - if self.generation_config is not None - else self.tokenizer.pad_token_id, - } - prompts.meta_info.update(meta_info) - if self._is_offload_optimizer: - offload_megatron_optimizer(self.actor_optimizer) - - timing_generate = {} - with self.sharding_manager: - log_gpu_memory_usage("After entering sharding manager", logger=logger) - prompts = self.sharding_manager.preprocess_data(prompts) - with simple_timer("generate_sequences", timing_generate): - output = self.rollout.generate_sequences(prompts=prompts) - output = self.sharding_manager.postprocess_data(output) - log_gpu_memory_usage("After rollout generation", logger=logger) - - timing_generate.update(self.sharding_manager.timing) - # We calculate the average timing across all ranks - # to make sure meta_info["timing"] is the same - timing_generate = reduce_timing(timing_generate) - output.meta_info["timing"] = timing_generate - output = output.to("cpu") - # clear kv cache - get_torch_device().empty_cache() - return output - - # Modified by Reasoning360. - @register(dispatch_mode=Dispatch.MEGATRON_PP_DUMMY_PROTO) - @GPUMemoryLogger(role="compute_ref_log_prob", logger=logger) - @DistProfiler.annotate(color="olive") - def compute_ref_log_prob(self, data: DataProto): - assert self._is_ref - if self._ref_is_offload_param: - load_megatron_model_to_gpu(self.ref_module, load_grad=False) - log_gpu_memory_usage("After load ref params and grad during compute_ref_log_prob", logger=logger) - micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu - data.meta_info["micro_batch_size"] = micro_batch_size - data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz - data.meta_info["temperature"] = self.config.rollout.temperature - data = data.to(get_device_id()) - # NOTE: added by Reasoning360. - broadcast_dict_tensor( - data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() - ) - - # NOTE: this function internally broadcasts the last stage's input and output to all ranks. - output, _ = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False) - output = DataProto.from_dict(tensors={"ref_log_prob": output}) - output = output.to("cpu") - if self._ref_is_offload_param: - offload_megatron_model_to_cpu(self.ref_module) - log_gpu_memory_usage("After offload ref params and grad during compute_ref_log_prob", logger=logger) - get_torch_device().empty_cache() - # NOTE: added by Reasoning360. - return megatron_pp_dummy_output(output) - - # Modified by Reasoning360. - @register(dispatch_mode=Dispatch.MEGATRON_PP_DUMMY_PROTO) - @GPUMemoryLogger(role="compute_log_prob", logger=logger) - @DistProfiler.annotate(color="blue") - def compute_log_prob(self, data: DataProto): - assert self._is_actor - if self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module, load_grad=False) - log_gpu_memory_usage("After load actor params and grad during compute_log_prob", logger=logger) - # we should always recompute old_log_probs when it is HybridEngine - data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz - data.meta_info["temperature"] = self.config.rollout.temperature - data = data.to(get_device_id()) - - # NOTE: added by Reasoning360. - broadcast_dict_tensor( - data.batch, src=mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() - ) - - output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True) - output = DataProto.from_dict( - tensors={"old_log_probs": output, "entropys": entropys}, - meta_info={"temperature": self.config.rollout.temperature}, - ) - output = output.to("cpu") - # clear kv cache - if self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) - log_gpu_memory_usage("After offload actor params and grad during compute_log_prob", logger=logger) - get_torch_device().empty_cache() - - # NOTE: modified by Reasoning360. - return megatron_pp_dummy_output(output) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): - if self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module) - self.checkpoint_mananager.load_checkpoint( - local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load - ) - if self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) - if self._is_offload_optimizer: - offload_megatron_optimizer(self.actor_optimizer) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def load_pretrained_model(self, checkpoint_path, del_local_after_load=True): - pass - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): - if self._is_offload_param: - load_megatron_model_to_gpu(self.actor_module) - self.checkpoint_mananager.save_checkpoint( - local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep - ) - torch.distributed.barrier() - if self._is_offload_param: - offload_megatron_model_to_cpu(self.actor_module) - - -class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): - def _build_rollout(self, trust_remote_code=False): - rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) - - # NOTE: rollout is not actually initialized here, it's deferred - # to be initialized by AsyncvLLMServer. - - self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size - self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size - self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size - - # used for sleep/wake_up - rollout.sharding_manager = rollout_sharding_manager - - return rollout, rollout_sharding_manager - - # ============================ vLLM related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def execute_method(self, method: str | bytes, *args, **kwargs): - """Called by ExternalRayDistributedExecutor collective_rpc.""" - if self.vllm_tp_rank == 0 and method != "execute_model": - print( - f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: " - f"{method if isinstance(method, str) else 'Callable'}" - ) - return self.rollout.execute_method(method, *args, **kwargs) - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def get_zeromq_address(self): - return self.rollout.get_zeromq_address() - - # ============================ SGLang related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def chat_completion(self, json_request): - ret = await self.rollout.chat_completion(json_request) - return ret - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: - ret = await self.rollout.generate(prompt_ids, sampling_params, request_id) - return ret - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def wake_up(self): - if self.config.rollout.free_cache_engine: - await self.rollout.wake_up() - # return something to block the caller - return True - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - async def sleep(self): - if self.config.rollout.free_cache_engine: - await self.rollout.sleep() - # return something to block the caller - return True - - -class CriticWorker(MegatronWorker, DistProfilerExtension): - def __init__(self, config): - MegatronWorker.__init__(self) - DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) - ) - self.config = config - - # NOTE(sgm): We utilize colocate WorkerGroup by default. - # As a result, Workers for different model share the same process. - # Therefore, we only require one distribute initialization. - # To utilize different parallel strategy in different models: - # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, - # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 - if not torch.distributed.is_initialized(): - rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group( - backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), - init_method=os.environ.get("DIST_INIT_METHOD", None), - ) - get_torch_device().set_device(rank) - - if self.config.megatron.sequence_parallel: - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - mpu.initialize_model_parallel( - tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size, - pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_split_rank=None, - use_sharp=False, - context_parallel_size=self.config.megatron.context_parallel_size, - expert_model_parallel_size=self.config.megatron.expert_model_parallel_size, - expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size, - nccl_communicator_config_path=None, - ) - - set_random_seed(seed=self.config.megatron.seed) - - # set FSDP offload params - self._is_offload_param = self.config.megatron.param_offload - self._is_offload_optimizer = self.config.megatron.optimizer_offload - - # normalize config - self.config.ppo_mini_batch_size *= self.config.rollout_n - self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() - if self.config.get("ppo_micro_batch_size", None): - self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size - - # TODO(sgm): support critic model offload - - def _build_critic_model_optimizer( - self, model_path, optim_config, override_model_config, override_transformer_config - ): - from megatron.core.models.gpt.gpt_model import ModelType - - from verl.utils.megatron.optimizer import get_megatron_optimizer, get_megatron_optimizer_param_scheduler - from verl.utils.megatron_utils import get_model, init_megatron_optim_config - from verl.utils.model import print_model_size - - self._init_hf_config_and_tf_config( - model_path, - self.config.model.tokenizer_path, - self.dtype, - override_model_config, - override_transformer_config, - self.config.model.get("trust_remote_code", False), - self.config.megatron.use_mbridge, - ) - - if self.bridge is not None: - from verl.models.mcore.mbridge import freeze_moe_router, make_value_model - - post_model_creation_callbacks = [make_value_model] - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): - post_model_creation_callbacks.append(freeze_moe_router) - critic_module = self.bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=True - ) - else: - - def megatron_critic_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - self.tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - value=True, - freeze_moe_router=override_model_config.get("moe_config", {}).get("freeze_moe_router", False), - ) - parallel_model.to(get_device_name()) - return parallel_model - - override_ddp_config = OmegaConf.to_container( - self.config.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True - ) - # Step 3: initialize the megatron model - critic_module = get_model( - model_provider_func=megatron_critic_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=True, - use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, - override_ddp_config=override_ddp_config, - ) - # note that here critic_module will be a list to be compatible with the construction of interleaved pp (vpp). - # but here, we do not use pp (vpp) yet. For simplicity, we remove the list - # critic_module = nn.ModuleList(critic_module) - - if self.config.load_weight: - t0 = time.time() - if self.config.megatron.use_dist_checkpointing: - load_mcore_dist_weights( - critic_module, self.config.megatron.dist_checkpointing_path, is_value_model=True - ) - else: - if self.bridge is not None: - local_model_path = get_hf_model_path(self.config) - self.bridge.load_weights(critic_module, local_model_path) - else: - load_megatron_gptmodel_weights( - self.config, self.hf_config, critic_module, params_dtype=self.dtype, is_value_model=True - ) - t1 = time.time() - if torch.distributed.get_rank() == 0: - print(f"critic load_weight time: {t1 - t0}") - if self.rank == 0: - print_model_size(critic_module[0]) - - # TODO: add more optimizer args into config - optim_config_megatron = init_megatron_optim_config(optim_config) - critic_optimizer = get_megatron_optimizer(model=critic_module, config=optim_config_megatron) - critic_optimizer_scheduler = get_megatron_optimizer_param_scheduler( - optimizer=critic_optimizer, config=optim_config - ) - get_torch_device().empty_cache() - return critic_module, critic_optimizer, critic_optimizer_scheduler, self.hf_config, optim_config - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - # create critic - - from verl.utils.torch_dtypes import PrecisionType - - if self.config.model.get("external_lib", None) is not None: - # This is used to import external_lib into the huggingface systems - import importlib - - importlib.import_module(self.config.model.external_lib) - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - override_transformer_config = OmegaConf.to_container( - self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True - ) - self.param_dtype = torch.bfloat16 - self.dtype = PrecisionType.to_dtype(self.param_dtype) - ( - self.critic_module, - self.critic_optimizer, - self.critic_optimizer_scheduler, - self.critic_model_config, - critic_optimizer_config, - ) = self._build_critic_model_optimizer( - model_path=self.config.model.path, - optim_config=self.config.optim, - override_model_config=override_model_config, - override_transformer_config=override_transformer_config, - ) - if self._is_offload_param: - offload_megatron_model_to_cpu(self.critic_module) - if self._is_offload_optimizer: - offload_megatron_optimizer(self.critic_optimizer) - - self.critic = MegatronPPOCritic( - config=self.config, - model_config=self.critic_model_config, - hf_config=self.hf_config, - tf_config=self.tf_config, - critic_module=self.critic_module, - critic_optimizer=self.critic_optimizer, - critic_optimizer_config=critic_optimizer_config, - ) - self.flops_counter = FlopsCounter(self.critic_model_config) - self.checkpoint_mananager = MegatronCheckpointManager( - config=self.config, - checkpoint_config=self.config.checkpoint, - model_config=self.critic_model_config, - transformer_config=self.tf_config, - role="critic", - model=self.critic_module, - arch=self.architectures[0], - hf_config=self.hf_config, - param_dtype=self.param_dtype, - share_embeddings_and_output_weights=False, - processing_class=self.processor if self.processor is not None else self.tokenizer, - optimizer=self.critic_optimizer, - optimizer_scheduler=self.critic_optimizer_scheduler, - use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, - use_checkpoint_opt_param_scheduler=self.config.optim.use_checkpoint_opt_param_scheduler, - bridge=self.bridge, - use_dist_checkpointing=self.config.megatron.use_dist_checkpointing, - ) - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - @DistProfiler.annotate(color="cyan") - def compute_values(self, data: DataProto): - micro_batch_size = self.config.ppo_micro_batch_size_per_gpu - data.meta_info["micro_batch_size"] = micro_batch_size - data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz - data = data.to(get_device_id()) - if self._is_offload_param: - load_megatron_model_to_gpu(self.critic_module) - values = self.critic.compute_values(data=data) - output = DataProto.from_dict(tensors={"values": values}) - output = output.to("cpu") - if self._is_offload_param: - offload_megatron_model_to_cpu(self.critic_module) - return output - - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - @DistProfiler.annotate(color="pink") - def update_critic(self, data: DataProto): - data = data.to(get_device_id()) - - if self._is_offload_param: - load_megatron_model_to_gpu(self.critic_module) - if self._is_offload_optimizer: - load_megatron_optimizer(self.critic_optimizer) - - dataloader = self.critic.make_minibatch_iterator(data) - with Timer(name="update_critic", logger=None) as timer: - metrics = self.critic.update_critic(dataloader=dataloader) - delta_time = timer.last - global_num_tokens = data.meta_info["global_token_num"] - estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time) - metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size - from verl.utils.megatron.optimizer import get_megatron_last_lr - - metrics["critic/lr"] = get_megatron_last_lr(self.critic_optimizer) - self.critic_optimizer_scheduler.step(1) - - output = DataProto(batch=None, meta_info={"metrics": metrics}) - - if self._is_offload_param: - offload_megatron_model_to_cpu(self.critic_module) - if self._is_offload_optimizer: - offload_megatron_optimizer(self.critic_optimizer) - output = output.to("cpu") - return output - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def load_checkpoint(self, checkpoint_path, hdfs_path=None, del_local_after_load=True): - if self._is_offload_param: - load_megatron_model_to_gpu(self.critic_module) - self.checkpoint_mananager.load_checkpoint( - local_path=checkpoint_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load - ) - if self._is_offload_param: - offload_megatron_model_to_cpu(self.critic_module) - if self._is_offload_optimizer: - offload_megatron_optimizer(self.critic_optimizer) - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, checkpoint_path, hdfs_path=None, global_steps=0, max_ckpt_to_keep=None): - if self._is_offload_param: - load_megatron_model_to_gpu(self.critic_module) - self.checkpoint_mananager.save_checkpoint( - local_path=checkpoint_path, hdfs_path=hdfs_path, global_step=global_steps, max_ckpt_to_keep=max_ckpt_to_keep - ) - if self._is_offload_param: - offload_megatron_model_to_cpu(self.critic_module) - - -class RewardModelWorker(MegatronWorker, DistProfilerExtension): - """ - Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification. - """ - - def __init__(self, config): - MegatronWorker.__init__(self) - DistProfilerExtension.__init__( - self, DistProfiler(rank=self.rank, config=omega_conf_to_dataclass(config.get("profiler"))) - ) - self.config = config - - # NOTE(sgm): We utilize colocate WorkerGroup by default. - # As a result, Workers for different model share the same process. - # Therefore, we only require one distribute initialization. - # To utilize different parallel strategy in different models: - # 1, users should disable WorkerDict; 2.assign different ResourcePool to different models, - # 3. and apply the following patch in ray==2.10, https://github.com/ray-project/ray/pull/44385 - if not torch.distributed.is_initialized(): - rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group( - backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), - init_method=os.environ.get("DIST_INIT_METHOD", None), - ) - get_torch_device().set_device(rank) - - if self.config.megatron.sequence_parallel: - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" - mpu.initialize_model_parallel( - tensor_model_parallel_size=self.config.megatron.tensor_model_parallel_size, - pipeline_model_parallel_size=self.config.megatron.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size=self.config.megatron.virtual_pipeline_model_parallel_size, - pipeline_model_parallel_split_rank=None, - use_sharp=False, - context_parallel_size=self.config.megatron.context_parallel_size, - expert_model_parallel_size=self.config.megatron.expert_model_parallel_size, - expert_tensor_parallel_size=self.config.megatron.expert_tensor_parallel_size, - nccl_communicator_config_path=None, - ) - - set_random_seed(seed=self.config.megatron.seed) - - # normalize config - if self.config.micro_batch_size is not None: - self.config.micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.micro_batch_size_per_gpu = self.config.micro_batch_size - - def _build_rm_model(self, model_path, tokenizer, override_model_config, override_transformer_config): - from megatron.core.models.gpt.gpt_model import ModelType - - from verl.utils.megatron_utils import get_model - - self._init_hf_config_and_tf_config( - model_path, - tokenizer, - self.dtype, - override_model_config, - override_transformer_config, - self.config.model.get("trust_remote_code", False), - self.config.megatron.use_mbridge, - ) - if self.bridge is not None: - from verl.models.mcore.mbridge import freeze_moe_router, make_value_model - - post_model_creation_callbacks = [make_value_model] - if override_model_config.get("moe_config", {}).get("freeze_moe_router", False): - post_model_creation_callbacks.append(freeze_moe_router) - reward_model = self.bridge.get_model( - post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=False - ) - else: - - def megatron_rm_model_provider(pre_process, post_process): - from verl.models.mcore import init_mcore_model - - parallel_model = init_mcore_model( - self.tf_config, - self.hf_config, - pre_process, - post_process, - share_embeddings_and_output_weights=False, - value=True, - ) - parallel_model.to(get_device_name()) - return parallel_model - - # Step 3: initialize the megatron model - reward_model = get_model( - model_provider_func=megatron_rm_model_provider, - model_type=ModelType.encoder_or_decoder, - wrap_with_ddp=False, - use_distributed_optimizer=self.config.megatron.use_distributed_optimizer, - ) - # note that here reward_model will be a list to be compatible with the construction of interleaved pp (vpp) - # but here, we do not use pp (vpp) yet. For simplicity, we remove the list - # reward_model = nn.ModuleList(reward_model) - - if self.config.load_weight: - if self.config.megatron.use_dist_checkpointing: - load_mcore_dist_weights(reward_model, self.config.megatron.dist_checkpointing_path, is_value_model=True) - else: - if self.bridge is not None: - local_model_path = get_hf_model_path(self.config) - self.bridge.load_weights(reward_model, local_model_path) - else: - load_megatron_gptmodel_weights( - self.config, self.hf_config, reward_model, params_dtype=self.dtype, is_value_model=True - ) - - # TODO: add more optimizer args into config - get_torch_device().empty_cache() - return reward_model, self.hf_config - - @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def init_model(self): - # create critic - - from verl.utils.torch_dtypes import PrecisionType - - if self.config.model.get("external_lib", None) is not None: - # This is used to import external_lib into the huggingface systems - import importlib - - importlib.import_module(self.config.model.external_lib) - override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create())) - override_transformer_config = OmegaConf.to_container( - self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True - ) - - use_shm = self.config.model.get("use_shm", False) - sft_tokenizer_local_path = copy_to_local(self.config.model.input_tokenizer, use_shm=use_shm) - sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path) - rm_tokenizer_path = self.config.model.get("rm_tokenizer", None) - rm_tokenizer = None - if rm_tokenizer_path is not None: - rm_tokenizer_local_path = copy_to_local(rm_tokenizer_path, use_shm=use_shm) - rm_tokenizer = hf_tokenizer( - rm_tokenizer_local_path, trust_remote_code=self.config.model.get("trust_remote_code", False) - ) - - self.param_dtype = torch.bfloat16 - self.dtype = PrecisionType.to_dtype(self.param_dtype) - - reward_model_module, reward_model_config = self._build_rm_model( - model_path=self.config.model.path, - tokenizer=rm_tokenizer, - override_model_config=override_model_config, - override_transformer_config=override_transformer_config, - ) - # FIXME(sgm): reward model param offload is implemented in MegatronRewardModel - # should be implemented in workers - self.rm = MegatronRewardModel( - config=self.config, - reward_model_module=reward_model_module, - model_config=reward_model_config, - hf_config=self.hf_config, - tf_config=self.tf_config, - sft_tokenizer=sft_tokenizer, - rm_tokenizer=rm_tokenizer, - ) - - # TODO: reward model use itself tokenizer instead of sft tokenizer - # the input_ids, responses, attention_mask and position_ids may be different! - @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO) - @DistProfiler.annotate(color="brown") - def compute_rm_score(self, data: DataProto): - data.meta_info["micro_batch_size"] = self.config.micro_batch_size_per_gpu - data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu - data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz - data = data.to(get_device_id()) - output = self.rm.compute_reward(data) - output = output.to("cpu") - return output diff --git a/verl/workers/reward_manager/batch.py b/verl/workers/reward_manager/batch.py deleted file mode 100644 index 8d1b11228..000000000 --- a/verl/workers/reward_manager/batch.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2025 Individual Contributor: Mert Unsal -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import defaultdict - -import torch - -from verl import DataProto -from verl.workers.reward_manager import register - - -@register("batch") -class BatchRewardManager: - """ - A batch reward manager that computes rewards for a batch of data. - - Args: - tokenizer (Tokenizer): The tokenizer to use for decoding the responses. - num_examine (int): The number of responses to examine. - compute_score (callable): The function to compute the rewards. - reward_fn_key (str): The key to use for the reward function. - reward_kwargs (dict): The keyword arguments to pass to the reward function. - """ - - def __init__(self, tokenizer, num_examine, compute_score, reward_fn_key="data_source", **reward_kwargs): - self.tokenizer = tokenizer - self.num_examine = num_examine - self.compute_score = compute_score - self.reward_fn_key = reward_fn_key - self.reward_kwargs = reward_kwargs - - def verify(self, data): - prompt_ids = data.batch["prompts"] - response_ids = data.batch["responses"] - attention_mask = data.batch["attention_mask"] - - prompt_len = prompt_ids.shape[-1] - valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1) - - responses_str = [] - for i in range(len(data)): - valid_len = valid_response_lengths[i] - valid_response_ids = response_ids[i][:valid_len] - response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) - responses_str.append(response_str) - - ground_truths = [item.non_tensor_batch["reward_model"].get("ground_truth", None) for item in data] - data_sources = data.non_tensor_batch[self.reward_fn_key] - extras = data.non_tensor_batch.get("extra_info", [None] * len(data)) - - scores = self.compute_score( - data_sources=data_sources, - solution_strs=responses_str, - ground_truths=ground_truths, - extra_infos=extras, - **self.reward_kwargs, - ) - - return scores - - def __call__(self, data: DataProto, return_dict=False): - # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if "rm_scores" in data.batch.keys(): - if return_dict: - return {"reward_tensor": data.batch["rm_scores"]} - else: - return data.batch["rm_scores"] - - reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) - reward_extra_info = defaultdict(list) - prompt_ids = data.batch["prompts"] - prompt_len = prompt_ids.shape[-1] - attention_mask = data.batch["attention_mask"] - valid_response_lengths = attention_mask[:, prompt_len:].sum(dim=-1) - data_sources = data.non_tensor_batch[self.reward_fn_key] - - scores = self.verify(data) - rewards = [] - already_printed = {} - - for i in range(len(data)): - length = valid_response_lengths[i].item() - score = scores[i] - - if isinstance(score, dict): - reward = score["score"] - for key, value in score.items(): - reward_extra_info[key].append(value) - else: - reward = score - - rewards.append(reward) - reward_tensor[i, length - 1] = reward - - data_source = data_sources[i] - if already_printed.get(data_source, 0) < self.num_examine: - response_str = self.tokenizer.decode(data.batch["responses"][i][:length], skip_special_tokens=True) - prompt_str = self.tokenizer.decode(data.batch["prompts"][i], skip_special_tokens=True) - ground_truth = data[i].non_tensor_batch["reward_model"].get("ground_truth", None) - print("[prompt]", prompt_str) - print("[response]", response_str) - print("[ground_truth]", ground_truth) - print("[score]", scores[i]) - already_printed[data_source] = already_printed.get(data_source, 0) + 1 - - data.batch["acc"] = torch.tensor(rewards, dtype=torch.float32, device=prompt_ids.device) - - if return_dict: - return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info} - else: - return reward_tensor diff --git a/verl/workers/reward_manager/naive.py b/verl/workers/reward_manager/naive.py deleted file mode 100644 index f6f979eef..000000000 --- a/verl/workers/reward_manager/naive.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import defaultdict - -import torch - -from verl import DataProto -from verl.utils.reward_score import default_compute_score -from verl.workers.reward_manager import register - - -@register("naive") -class NaiveRewardManager: - """The reward manager.""" - - def __init__(self, tokenizer, num_examine, compute_score=None, reward_fn_key="data_source") -> None: - """ - Initialize the NaiveRewardManager instance. - - Args: - tokenizer: The tokenizer used to decode token IDs into text. - num_examine: The number of batches of decoded responses to print to the console for debugging purpose. - compute_score: A function to compute the reward score. If None, `default_compute_score` will be used. - reward_fn_key: The key used to access the data source in the non-tensor batch data. Defaults to - "data_source". - """ - self.tokenizer = tokenizer # Store the tokenizer for decoding token IDs - self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or default_compute_score - self.reward_fn_key = reward_fn_key # Store the key for accessing the data source - - def __call__(self, data: DataProto, return_dict=False): - """We will expand this function gradually based on the available datasets""" - - # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if "rm_scores" in data.batch.keys(): - if return_dict: - return {"reward_tensor": data.batch["rm_scores"]} - else: - return data.batch["rm_scores"] - - reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) - reward_extra_info = defaultdict(list) - - already_print_data_sources = {} - - for i in range(len(data)): - data_item = data[i] # DataProtoItem - - prompt_ids = data_item.batch["prompts"] - - prompt_length = prompt_ids.shape[-1] - - valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() - valid_prompt_ids = prompt_ids[-valid_prompt_length:] - - response_ids = data_item.batch["responses"] - valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() - valid_response_ids = response_ids[:valid_response_length] - - # decode - prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) - response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) - - ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] - data_source = data_item.non_tensor_batch[self.reward_fn_key] - extra_info = data_item.non_tensor_batch.get("extra_info", {}) - num_turns = data_item.non_tensor_batch.get("__num_turns__", None) - extra_info["num_turns"] = num_turns - - score = self.compute_score( - data_source=data_source, - solution_str=response_str, - ground_truth=ground_truth, - extra_info=extra_info, - ) - - if isinstance(score, dict): - reward = score["score"] - # Store the information including original reward - for key, value in score.items(): - reward_extra_info[key].append(value) - else: - reward = score - - reward_tensor[i, valid_response_length - 1] = reward - - if data_source not in already_print_data_sources: - already_print_data_sources[data_source] = 0 - - if already_print_data_sources[data_source] < self.num_examine: - already_print_data_sources[data_source] += 1 - print("[prompt]", prompt_str) - print("[response]", response_str) - print("[ground_truth]", ground_truth) - if isinstance(score, dict): - for key, value in score.items(): - print(f"[{key}]", value) - else: - print("[score]", score) - - if return_dict: - return { - "reward_tensor": reward_tensor, - "reward_extra_info": reward_extra_info, - } - else: - return reward_tensor diff --git a/verl/workers/reward_manager/prime.py b/verl/workers/reward_manager/prime.py deleted file mode 100644 index f2c526b63..000000000 --- a/verl/workers/reward_manager/prime.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright 2024 PRIME team and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from concurrent.futures import ProcessPoolExecutor -from functools import partial -from typing import Callable, Optional - -import psutil -import torch -from transformers import PreTrainedTokenizer - -from verl import DataProto -from verl.utils.reward_score import default_compute_score -from verl.workers.reward_manager import register - - -async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0): - loop = asyncio.get_running_loop() - try: - # Ensure process_completion is called properly - future = loop.run_in_executor(executor, partial(evaluation_func, task, completion, reference, task_extra_info)) - return await asyncio.wait_for(future, timeout=timeout) - except asyncio.TimeoutError: - print(f"[Timeout] Task timeout: {completion}") - return None # Default value for timed-out rows - except Exception as e: - print(f"[Error] Task failed: {e}, completion: {completion[:80]}") - return None # Default value for failed rows - - -async def parallel_compute_score_async( - evaluation_func, completions, references, tasks, extra_info=None, num_processes=64 -): - if extra_info is None: - extra_info = [None] * len(tasks) - scores = [] - with ProcessPoolExecutor(max_workers=num_processes) as executor: - # to prevent very occasional starvation caused by some anomalous programs ( like infinite loop ), the - # exceptions in async programs will instantly halt the evaluation, and all summoned processes will be killed. - try: - # Create tasks for all rows - tasks_async = [ - single_compute_score(evaluation_func, c, r, t, ei, executor, timeout=300.0) - for c, r, t, ei in zip(completions, references, tasks, extra_info, strict=True) - ] - results = await asyncio.gather(*tasks_async, return_exceptions=False) - except Exception as e: - print(f"[Exception] async gather failed: {e}") - raise - finally: - terminated_count = 0 - for pid, proc in executor._processes.items(): - try: - p = psutil.Process(pid) - p.terminate() - try: - p.wait(timeout=5) - except psutil.TimeoutExpired: - p.kill() - terminated_count += 1 - except Exception: - pass - print(f"[Shutdown] {terminated_count} subprocess(es) terminated.") - - # Process results - for result, completion, reference, task in zip(results, completions, references, tasks, strict=True): - if isinstance(result, Exception) or result is None: - # Handle failed or timed-out tasks - scores.append(0.0) - elif isinstance(result, int | float | bool): - scores.append(float(result)) - else: - scores.append(float(result[0])) - return scores - - -def run_reward_scoring(evaluation_func, completions, references, tasks, extra_info=None, num_processes=64): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete( - parallel_compute_score_async(evaluation_func, completions, references, tasks, extra_info, num_processes) - ) - finally: - loop.close() - - -@register("prime") -class PrimeRewardManager: - """ - The Reward Manager used in https://github.com/PRIME-RL/PRIME - """ - - def __init__( - self, - tokenizer: PreTrainedTokenizer, - num_examine: int, - compute_score: Optional[Callable] = None, - reward_fn_key: str = "data_source", - ) -> None: - self.tokenizer = tokenizer - self.num_examine = num_examine # the number of batches of decoded responses to print to the console - self.compute_score = compute_score or default_compute_score - self.reward_fn_key = reward_fn_key - - def verify(self, data): - """ - verify the batch and save as ``acc`` tensor - """ - # batched scoring - prompt_ids = data.batch["prompts"] - - response_ids = data.batch["responses"] - sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) - ground_truth = [data_item.non_tensor_batch["reward_model"]["ground_truth"] for data_item in data] - data_sources = data.non_tensor_batch[self.reward_fn_key] - extra_info = data.non_tensor_batch.get("extra_info", None) - - assert len(sequences_str) == len(ground_truth) == len(data_sources) - try: - scores = run_reward_scoring( - self.compute_score, - completions=sequences_str, - references=ground_truth, - tasks=data_sources, - extra_info=extra_info, - num_processes=64, - ) - except asyncio.TimeoutError: - print("[Timeout] Global reward scoring timed out. Setting all as 0.") - scores = [0.0 for _ in range(len(sequences_str))] - except Exception as e: - print(f"[Error] Unexpected error during scoring. Setting all as 0. {e}") - scores = [0.0 for _ in range(len(sequences_str))] - data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device) - return scores - - def __call__(self, data: DataProto, return_dict: bool = False): - """We will expand this function gradually based on the available datasets""" - - # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn - if "rm_scores" in data.batch.keys(): - return data.batch["rm_scores"] - - reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) - - already_print_data_sources = {} - - # batched scoring - prompt_ids = data.batch["prompts"] - prompt_length = prompt_ids.shape[-1] - - response_ids = data.batch["responses"] - valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1) - sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True) - data_sources = data.non_tensor_batch["data_source"] - - scores = self.verify(data) - - for i in range(len(data)): - data_source = data_sources[i] - reward_tensor[i, valid_response_length[i].item() - 1] = scores[i] - - if data_source not in already_print_data_sources: - already_print_data_sources[data_source] = 0 - - if already_print_data_sources[data_source] < self.num_examine: - already_print_data_sources[data_source] += 1 - print(sequences_str) - - if return_dict: - return {"reward_tensor": reward_tensor} - else: - return reward_tensor diff --git a/verl/workers/reward_manager/registry.py b/verl/workers/reward_manager/registry.py deleted file mode 100644 index 3fc34efaa..000000000 --- a/verl/workers/reward_manager/registry.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ["register", "get_reward_manager_cls"] - -REWARD_MANAGER_REGISTRY = {} - - -def register(name): - """Decorator to register a reward manager class with a given name. - - Args: - name: `(str)` - The name of the reward manager. - """ - - def decorator(cls): - if name in REWARD_MANAGER_REGISTRY and REWARD_MANAGER_REGISTRY[name] != cls: - raise ValueError( - f"Reward manager {name} has already been registered: {REWARD_MANAGER_REGISTRY[name]} vs {cls}" - ) - REWARD_MANAGER_REGISTRY[name] = cls - return cls - - return decorator - - -def get_reward_manager_cls(name): - """Get the reward manager class with a given name. - - Args: - name: `(str)` - The name of the reward manager. - - Returns: - `(type)`: The reward manager class. - """ - if name not in REWARD_MANAGER_REGISTRY: - raise ValueError(f"Unknown reward manager: {name}") - return REWARD_MANAGER_REGISTRY[name] diff --git a/verl/workers/reward_model/base.py b/verl/workers/reward_model/base.py deleted file mode 100644 index cb719bd0f..000000000 --- a/verl/workers/reward_model/base.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -The base class for reward model -""" - -from abc import ABC, abstractmethod - -from verl import DataProto - - -class BasePPORewardModel(ABC): - def __init__(self, config): - self.config = config - - @abstractmethod - def compute_reward(self, data: DataProto) -> DataProto: - """Computing reward given input_ids. The transformers should output a tensor with shape - [batch_size, sequence_length], and the value at [EOS] mask should be gathered. - - Args: - data: must contain keys "input_ids", "attention_mask" and "position_ids". - - input_ids: [batch_size, sequence_length] - - attention_mask: [batch_size, sequence_length] - - position_ids: [batch_size, sequence_length] - - Returns: a data pass protocol containing "reward". Only the [EOS] position contains the reward. - Other position should have zero reward. Note that this may change in the future if we use - dense reward. So, we leave the interface for general case. - - reward: [batch_size, sequence_length]. - - """ - pass diff --git a/verl/workers/reward_model/megatron/__init__.py b/verl/workers/reward_model/megatron/__init__.py deleted file mode 100644 index 5bd4da2ba..000000000 --- a/verl/workers/reward_model/megatron/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .reward_model import MegatronRewardModel - -__all__ = ["MegatronRewardModel"] diff --git a/verl/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py deleted file mode 100644 index 01b132497..000000000 --- a/verl/workers/reward_model/megatron/reward_model.py +++ /dev/null @@ -1,350 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Megatron Reward Model. -""" - -import itertools - -import torch -import torch.distributed -from megatron.core import parallel_state as mpu -from megatron.core.pipeline_parallel import get_forward_backward_func -from tensordict import TensorDict - -from verl import DataProto -from verl.utils.device import get_device_id, get_device_name, get_torch_device -from verl.utils.megatron.pipeline_parallel import make_batch_generator -from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches -from verl.utils.torch_functional import broadcast_dict_tensor, pad_sequence_to_length -from verl.workers.reward_model.base import BasePPORewardModel - - -class MegatronRewardModel(BasePPORewardModel): - def __init__( - self, - config, - model_config, - reward_model_module: torch.nn.ModuleList, - hf_config, - tf_config, - sft_tokenizer=None, - rm_tokenizer=None, - ): - self.config = config - self.reward_model_module = reward_model_module - self.hf_config = hf_config - self.tf_config = tf_config - self.model_config = model_config - self.device = "cuda" - self.sft_tokenizer = sft_tokenizer - self.rm_tokenizer = rm_tokenizer - self.use_different_tokenizer = rm_tokenizer is not None - - print(f"MegatronRewardModel.config: {self.config}") - - if self.config.megatron.param_offload: - self.offload_params_to_cpu() - - def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto: - assert self.use_different_tokenizer, "re-encode need rm tokenizer not be None!" - # need to use rm tokenizer to re-generate input_ids, attention_mask and position_ids - # 1. remove pad for each sequence - # 2. decode by sft_tokenizer, remove sft system prompts - # 3. encode by rm_tokenizer with rm system prompts, get rm_input_ids - # 4. generate attention_mask and position_ids - input_ids = data.batch["input_ids"] # (bs, seq_len) - attention_mask = data.batch["attention_mask"] - position_ids = data.batch["position_ids"] - ori_values = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids} - _, ori_seqlen = input_ids.size(0), input_ids.size(1) - input_ids_for_rm = [] - attention_mask_for_rm = [] - position_ids_for_rm = [] - print_decode = True - ori_seqlen = ori_seqlen + 128 - for id, mask in zip(input_ids, attention_mask, strict=True): - # 1. remove pad for each sequence - non_zero_indices = torch.nonzero(mask).view(-1) - begin_pos, end_pos = non_zero_indices[0].item(), non_zero_indices[-1].item() - valid_id = id[begin_pos : end_pos + 1] - # 2. decode by sft_tokenizer, remove sft system prompts - decode_result = self.sft_tokenizer.decode(valid_id) - # workaround - decode_with_rm_chat = ( - decode_result.replace("<|user|>\n", "[INST] ") - .replace("\n<|assistant|>\n", " [/INST]") - .replace(" \n<|assistant|>\n", " [/INST]") - + "" - ) - if print_decode and torch.distributed.get_rank() == 0: - # only print first decode result - print( - f"device {get_device_id()}: sft decode result:\n{decode_result}\n \ - \ndevice {get_device_id()}: sft decode result with \ - rm chat template:\n{decode_with_rm_chat}\n\n" - ) - print_decode = False - # 3. encode by rm_tokenizer - rm_input_ids = self.rm_tokenizer(decode_with_rm_chat, return_tensors="pt")["input_ids"][0].to( - input_ids.device - ) - # 4. generate attention_mask and position_ids - rm_attention_mask = torch.ones_like(rm_input_ids, device=input_ids.device) - cur_seqlen = rm_input_ids.shape[-1] - # NOTE(gh): the later reward compute will process the shape (bs, seqlen_pad_128) - if cur_seqlen > ori_seqlen: - print(f"warninig: rm encode seqlen {cur_seqlen} > sft encode seqlen {ori_seqlen}") - rm_input_ids = rm_input_ids[:ori_seqlen] - rm_attention_mask = rm_attention_mask[:ori_seqlen] - else: - # right padding - rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id) - rm_attention_mask = pad_sequence_to_length(rm_attention_mask, ori_seqlen, 0) - rm_position_ids = torch.arange(0, ori_seqlen, device=input_ids.device) - input_ids_for_rm.append(torch.unsqueeze(rm_input_ids, dim=0)) - attention_mask_for_rm.append(torch.unsqueeze(rm_attention_mask, dim=0)) - position_ids_for_rm.append(torch.unsqueeze(rm_position_ids, dim=0)) - input_ids_for_rm = torch.cat(input_ids_for_rm, dim=0) - attention_mask_for_rm = torch.cat(attention_mask_for_rm, dim=0) - position_ids_for_rm = torch.cat(position_ids_for_rm, dim=0) - - # (bs, seqlen) will not change, but input_ids, attention_mask and position_ids will change - # NOTE(gh): need to replace into origin values after compute reward! - data.batch["input_ids"] = input_ids_for_rm - data.batch["attention_mask"] = attention_mask_for_rm - data.batch["position_ids"] = position_ids_for_rm - - return data, ori_values - - @torch.no_grad() - def compute_reward(self, data: DataProto) -> DataProto: - if self.config.megatron.param_offload: - self.load_params_to_cuda() - - if self.use_different_tokenizer: - data, ori_values = self.re_encode_by_rm_tokenizer(data) - - input_ids = data.batch["input_ids"] # (bs, seq_len') - attention_mask = data.batch["attention_mask"] - position_ids = data.batch["position_ids"] - use_dynamic_bsz = data.meta_info.get("use_dynamic_bsz", False) - micro_batch_size = data.meta_info.get("micro_batch_size", None) - max_token_len = data.meta_info.get("max_token_len", None) - assert micro_batch_size is not None, "micro batch size is needed for forward compute" - if use_dynamic_bsz: - assert max_token_len is not None, "use_dynamic_bsz is True, but max_token_len is None!" - max_token_len = max_token_len * self.config.megatron.context_parallel_size - - responses = data.batch["responses"] - batch_size = responses.size(0) - response_length = responses.size(1) - - with torch.no_grad(): - output = self.forward_batch( - data, use_dynamic_bsz=use_dynamic_bsz, micro_batch_size=micro_batch_size, max_token_len=max_token_len - ) - if mpu.is_pipeline_last_stage(ignore_virtual=True): - logits = torch.cat(output["output"], dim=0) - if use_dynamic_bsz: - indices = output["indices"] - indices = list(itertools.chain.from_iterable(indices)) - assert len(indices) == logits.size(0), f"{len(indices)} vs. {logits.size()}" - revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long) - logits = logits[revert_indices] - else: - logits = torch.empty( - (input_ids.shape[0], input_ids.shape[1]), - device=input_ids.device, - ) - logits = logits.to(torch.float32) - - # broadcast across pp ranks - torch.distributed.broadcast( - tensor=logits, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - async_op=False, - ) - - # (bs, seqlen', hidden_size) -> (bs, seqlen', 1) -> (bs, seqlen') - token_level_rewards = logits - # find the last token reward - ends = attention_mask.cumsum(dim=-1).argmax(dim=-1).view(-1, 1) # (bs, 1) - rewards = torch.gather(token_level_rewards, dim=1, index=ends) # (bs, 1) - - if self.use_different_tokenizer: - data.batch.update(ori_values) - input_ids = ori_values["input_ids"] - attention_mask = ori_values["attention_mask"] - position_ids = ori_values["position_ids"] - - token_level_rewards = rewards.expand(attention_mask.shape[0], attention_mask.shape[1]) # (bs, ori_seqlen) - - # assign last valid token reward to ori position - if position_ids.dim() == 3: # qwen2vl mrope [bs, 3, seq_len] - position_ids = position_ids[:, 0, :] - eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1) # (bs,) - eos_mask = torch.zeros_like(attention_mask) - eos_mask[torch.arange(batch_size), eos_mask_idx] = 1.0 - - token_level_rewards = token_level_rewards * eos_mask - token_level_rewards = token_level_rewards[:, -response_length:] - - if self.config.megatron.param_offload: - self.offload_params_to_cpu() - else: - # add empty cache after each compute - get_torch_device().empty_cache() - - batch = TensorDict({"rm_scores": token_level_rewards}, batch_size=input_ids.shape[0]) - - return DataProto(batch=batch) - - def forward_batch(self, data: DataProto, use_dynamic_bsz=False, micro_batch_size=None, max_token_len=None): - """ - We assume: - - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input - - The communication shape is (total_nnz_pad_to_sp // tp_size, 1, hidden_size) if sequence parallel is enabled - """ - # broadcast from last pp rank to all other pp ranks - # TODO: actually, we just need to control the sampling order. - mini_batch = data - mini_batch.batch = mini_batch.batch.contiguous() - broadcast_dict_tensor( - mini_batch.batch, - src=mpu.get_pipeline_model_parallel_last_rank(), - group=mpu.get_pipeline_model_parallel_group(), - ) - - mini_batch.batch["attention_mask"] = mini_batch.batch["attention_mask"].to(bool) - - self.has_multi_modal_inputs = "multi_modal_inputs" in mini_batch.non_tensor_batch.keys() - if self.has_multi_modal_inputs: - mini_batch.batch["multi_modal_inputs"] = mini_batch.non_tensor_batch["multi_modal_inputs"] - mini_batch.batch["multi_modal_inputs_idx"] = torch.Tensor( - list(range(len(mini_batch.non_tensor_batch["multi_modal_inputs"]))) - ).to(torch.int64) - - indices = None - if use_dynamic_bsz: - assert max_token_len is not None, "max_token_len must be set when use_dynamic_bsz is True" - vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size() - if vpp_size is not None and vpp_size > 1: - microbatch_group_size_per_vp_stage = self.tf_config.microbatch_group_size_per_vp_stage - micro_batches, indices = rearrange_micro_batches( - batch=mini_batch.batch, - num_batches_divided_by=microbatch_group_size_per_vp_stage, - max_token_len=max_token_len, - ) - assert len(micro_batches) % self.tf_config.microbatch_group_size_per_vp_stage == 0, ( - f"micro_batches {micro_batches} must be divisible by microbatch_group_size_per_vp_stage " - f"{microbatch_group_size_per_vp_stage} for megatron backend" - ) - else: - micro_batches, indices = rearrange_micro_batches(batch=mini_batch.batch, max_token_len=max_token_len) - total_seqlen = max_token_len - else: - assert micro_batch_size is not None, ( - "micro_batch_size is needed to be passed in when not using dynamic batch size" - ) - micro_batches = mini_batch.batch.split(micro_batch_size) - seq_len = micro_batches[0]["input_ids"].shape[1] - total_seqlen = micro_batch_size * seq_len - n_micro_batch = len(micro_batches) - - # compute input shapes for pp stages - forward_backward_func = get_forward_backward_func() - - def loss_func(output): - return torch.tensor(1.0, device=output.device), output - - def forward_step(batch_iter, model): - batch = next(batch_iter) - input_ids = batch["input_ids"] - attention_mask = batch["attention_mask"] - position_ids = batch["position_ids"] - from verl.models.mcore import get_mcore_forward_fn - - forward_fn = get_mcore_forward_fn(self.hf_config) - - multi_modal_inputs = {} - if "multi_modal_inputs" in batch: - for key in batch["multi_modal_inputs"][0].keys(): - multi_modal_inputs[key] = torch.cat( - [batch["multi_modal_inputs"][i][key] for i in batch["multi_modal_inputs_idx"]], dim=0 - ) - - output = forward_fn( - model, - input_ids, - attention_mask, - position_ids, - sequence_parallel=self.tf_config.sequence_parallel, - value_model=True, - multi_modal_inputs=multi_modal_inputs, - ) - - return output, loss_func - - # batch should be a list of batches inside micro-batches - batch_generator = make_batch_generator(micro_batches, vpp_size=len(self.reward_model_module)) - - # TODO: we may use the new schedule instead - # for flash-attn: (seq_len, batch_size, hidden_size) = (mbs*seq_len, 1, hidden_size) - if mpu.get_pipeline_model_parallel_world_size() > 1: - losses_reduced = forward_backward_func( - forward_step_func=forward_step, - data_iterator=batch_generator, - model=self.reward_model_module, - num_microbatches=n_micro_batch, - seq_length=total_seqlen, # no use when input_shapes was set - micro_batch_size=1, # no use when input_shapes was set - forward_only=True, - ) - else: - losses_reduced = forward_backward_func( - forward_step_func=forward_step, - data_iterator=batch_generator, - model=self.reward_model_module, - num_microbatches=n_micro_batch, - seq_length=total_seqlen, # in use for pp = 1 - micro_batch_size=1, # in use for pp = 1 - forward_only=True, - ) - - if self.has_multi_modal_inputs: - data.batch.pop("multi_modal_inputs") - data.batch.pop("multi_modal_inputs_idx") - data.non_tensor_batch.pop("multi_modal_inputs") - # loss_reduces contains the stats returned from loss_func - losses_reduced = {"output": losses_reduced} - if use_dynamic_bsz: - losses_reduced["indices"] = indices - return losses_reduced - - def offload_params_to_cpu(self): - if self.device in ["cuda", "npu"]: - for reward_model_module in self.reward_model_module: - for name, param in reward_model_module.named_parameters(): - param.data = param.data.to("cpu", non_blocking=True) - self.device = "cpu" - get_torch_device().empty_cache() - - def load_params_to_cuda(self): - if self.device == "cpu": - for reward_model_module in self.reward_model_module: - for name, param in reward_model_module.named_parameters(): - param.data = param.data.to(get_device_id(), non_blocking=True) - self.device = get_device_name() diff --git a/verl/workers/rollout/__init__.py b/verl/workers/rollout/__init__.py deleted file mode 100644 index 5efcd337d..000000000 --- a/verl/workers/rollout/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .base import BaseRollout -from .hf_rollout import HFRollout -from .naive import NaiveRollout - -__all__ = ["BaseRollout", "NaiveRollout", "HFRollout"] diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py deleted file mode 100644 index 8b09c0b59..000000000 --- a/verl/workers/rollout/async_server.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import logging -import os -import socket -import threading -from abc import ABC, abstractmethod -from contextlib import asynccontextmanager -from typing import Any, Optional - -import fastapi -import ray -import uvicorn -from omegaconf import DictConfig -from starlette.requests import Request - -from verl.protocol import DataProto -from verl.single_controller.ray.base import RayWorkerGroup -from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler - -logger = logging.getLogger(__file__) - - -def _get_free_port(): - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - -class AsyncServerBase(ABC): - """Base class for AsyncServer.""" - - def __init__(self): - self.address = ray.util.get_node_ip_address() - self.port = None - self.server_ready = asyncio.Event() - asyncio.create_task(self._start_fastapi_server()) - - async def _start_fastapi_server(self): - @asynccontextmanager - async def lifespan(app: fastapi.FastAPI): - print(f"FastAPI listen on {self.address}:{self.port}") - self.server_ready.set() - yield - - # There's no way to gracefully restart uvicorn server if port is already in use, - # so we exit the process directly and let AsyncLLMServerManager restart it. - print("FastAPI shutdown, maybe address already in use, exit process immediately.") - os._exit(-1) - - app = fastapi.FastAPI(lifespan=lifespan) - app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"]) - - self.port = _get_free_port() - config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") - server = uvicorn.Server(config) - await server.serve() - - async def get_server_address(self) -> tuple[str, int]: - """Get FastAPI server address.""" - await self.server_ready.wait() - return f"{self.address}:{self.port}" - - @abstractmethod - async def chat_completion(self, raw_request: Request): - """OpenAI chat completion API. - - API reference: https://platform.openai.com/docs/api-reference/chat/create - """ - raise NotImplementedError - - @abstractmethod - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: - """Generate response ids given prompt ids. - - Args: - prompt_ids (List[int]): prompt ids - sampling_params (Dict[str, Any]): sampling params - request_id (str): request id - - Returns: - List[int]: response ids - """ - raise NotImplementedError - - @abstractmethod - async def init_engine(self): - """Init async LLM engine.""" - raise NotImplementedError - - @abstractmethod - async def wake_up(self): - """Wake up engine to load model weights and build kv cache.""" - raise NotImplementedError - - @abstractmethod - async def sleep(self): - """Sleep engine to offload model weights and discard kv cache.""" - raise NotImplementedError - - -class AsyncLLMServerManager: - """AsyncLLMServerManager manage a group of vllm instances, i.e AsyncvLLMServer.""" - - def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): - """Initialize AsyncLLMServerManager. - - Args: - config: DictConfig, actor_rollout_ref config. - worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. - """ - self.full_config = config - self.config = config.actor_rollout_ref - self.worker_group = worker_group - - self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size - self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size - - register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") - workers_info = ray.get(register_center.get_worker_info.remote()) - assert len(workers_info) == self.worker_group.world_size - - self.async_llm_servers = [None] * self.rollout_dp_size - self.server_addresses = [None] * self.rollout_dp_size - - if self.config.rollout.agent.custom_async_server: - server_class = async_server_class( - rollout_backend=self.config.rollout.name, - rollout_backend_module=self.config.rollout.agent.custom_async_server.path, - rollout_backend_class=self.config.rollout.agent.custom_async_server.name, - ) - else: - server_class = async_server_class(rollout_backend=self.config.rollout.name) - - # Start all server instances, restart if address already in use. - unready_dp_ranks = set(range(self.rollout_dp_size)) - while len(unready_dp_ranks) > 0: - servers = { - rollout_dp_rank: server_class.options( - # make sure AsyncvLLMServer colocates with its corresponding workers - scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( - node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], - soft=False, - ), - name=f"async_llm_server_{rollout_dp_rank}", - ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) - for rollout_dp_rank in unready_dp_ranks - } - - for rollout_dp_rank, server in servers.items(): - try: - address = ray.get(server.get_server_address.remote()) - self.server_addresses[rollout_dp_rank] = address - self.async_llm_servers[rollout_dp_rank] = server - unready_dp_ranks.remove(rollout_dp_rank) - except Exception: - ray.kill(server) - print(f"rollout server {rollout_dp_rank} failed, maybe address already in use, restarting...") - - # All server instances are ready, init AsyncLLM engine. - ray.get([server.init_engine.remote() for server in self.async_llm_servers]) - - # Init user provided chat scheduler in sperate thread. - self.chat_scheduler: ChatCompletionScheduler = None - self.chat_scheduler_exception: Exception = None - self.chat_scheduler_loop = None - self.chat_scheduler_ready = threading.Event() - self.chat_scheduler_thread = threading.Thread(target=self._init_chat_scheduler, daemon=True) - self.chat_scheduler_thread.start() - self.chat_scheduler_ready.wait() - - def _init_chat_scheduler(self): - self.chat_scheduler_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.chat_scheduler_loop) - - try: - self.chat_scheduler = ChatCompletionScheduler( - config=self.full_config, - server_addresses=self.server_addresses, - ) - except Exception as e: - logger.exception(f"chat_scheduler init error: {e}") - self.chat_scheduler_exception = e - finally: - self.chat_scheduler_ready.set() - self.chat_scheduler_loop.run_forever() - - def wake_up(self): - """Wake up all vllm instances.""" - if self.config.rollout.free_cache_engine: - ray.get([server.wake_up.remote() for server in self.async_llm_servers]) - - def sleep(self): - """Sleep all vllm instances.""" - if self.config.rollout.free_cache_engine: - ray.get([server.sleep.remote() for server in self.async_llm_servers]) - - def submit_chat_completions( - self, - messages: list[dict[str, str]], - sampling_params: dict[str, Any], - ): - """Submit a chat completion request to chat scheduler and wait until it is done. - To submit multiple requests in parallel, please use `generate_sequences` instead. - - Args: same as ChatCompletionScheduler.submit_chat_completions. - """ - assert self.chat_scheduler is not None, "chat scheduler is not initialized." - future = asyncio.run_coroutine_threadsafe( - self.chat_scheduler._submit_chat_completions_semaphore( - messages=messages, - request_id=None, - sampling_params=sampling_params, - ), - self.chat_scheduler_loop, - ) - future.result() - - def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: - """Generate multiple sequences in parallel via chat scheduler.""" - assert self.chat_scheduler is not None, "chat scheduler is not initialized." - - future = asyncio.run_coroutine_threadsafe( - self.chat_scheduler.generate_sequences(prompts, **sampling_params), self.chat_scheduler_loop - ) - return future.result() - - -def async_server_class( - rollout_backend: str, rollout_backend_module: Optional[str] = None, rollout_backend_class: Optional[str] = None -) -> type[AsyncServerBase]: - """Get async server class. - - Args: - rollout_backend: str, rollout backend type (alias), should be "vllm" or "sglang". - rollout_backend_module: Optional[str], import path of the rollout backend. - rollout_backend_class: Optional[str], class name of the rollout backend. - - Returns: - Type[AsyncServerBase]: async server class. - """ - if rollout_backend_class is None and rollout_backend_module is None: - # If both are None, use the default backend class - # Do not change the original import behavior - # importlib.import_module and from ... import ... have subtle differences in ray - - if rollout_backend == "vllm": - from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer - - return AsyncvLLMServer - elif rollout_backend == "sglang": - from verl.workers.rollout.sglang_rollout.async_sglang_server import AsyncSglangServer - - return AsyncSglangServer - else: - raise NotImplementedError(f"rollout backend {rollout_backend} is not supported") - - if rollout_backend_module is None or rollout_backend_class is None: - raise ValueError("rollout_backend_module and rollout_backend_class must be both provided for customization") - - from verl.utils.import_utils import load_extern_type - - return load_extern_type(rollout_backend_module, rollout_backend_class) diff --git a/verl/workers/rollout/base.py b/verl/workers/rollout/base.py deleted file mode 100644 index 031982464..000000000 --- a/verl/workers/rollout/base.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod - -from verl import DataProto - -__all__ = ["BaseRollout"] - - -class BaseRollout(ABC): - """Base class for rollout.""" - - @abstractmethod - def generate_sequences(self, prompts: DataProto) -> DataProto: - """Generate sequences""" - pass diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py deleted file mode 100644 index 268c82d02..000000000 --- a/verl/workers/rollout/chat_scheduler.py +++ /dev/null @@ -1,444 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import heapq -import importlib -import itertools -import json -import logging -import time -from abc import ABC, abstractmethod -from typing import Any -from uuid import uuid4 - -import aiohttp -import numpy as np -import torch -from cachetools import LRUCache -from omegaconf import DictConfig -from openai import AsyncOpenAI -from openai.types.chat.chat_completion import ChatCompletion -from tensordict import TensorDict - -from verl.protocol import DataProto -from verl.tools.utils.tool_registry import initialize_tools_from_config -from verl.utils import hf_tokenizer -from verl.utils.fs import copy_to_local -from verl.utils.import_utils import deprecated - -logger = logging.getLogger(__file__) - - -class CompletionCallback(ABC): - def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"): - self.config = config - self.scheduler = scheduler - - # Initialize tools from config file - self.max_assistant_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns - tool_config_path = config.actor_rollout_ref.rollout.multi_turn.tool_config_path - tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else [] - self.tools = {tool.name: tool for tool in tool_list} - self._tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list] - print(f"Initialized tools: {self.tools}", flush=True) - - local_path = copy_to_local(config.actor_rollout_ref.model.path) - self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) - - @property - def tool_schemas(self): - """OpenAI JSON tool schemas.""" - return self._tool_schemas - - @property - def extra_body(self) -> dict[str, Any]: - """Extra body pass to OpenAI API.""" - return None - - @abstractmethod - async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): - """Call back function to process completions. - - Args: - messages: List of messages including raw prompt and assistant, tool response generated so far. - completions: Chat completions from OpenAI compatible server. - info: Any other auxiliary information pass across multi-turn. - """ - raise NotImplementedError - - @abstractmethod - def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int) -> DataProto: - """Post process batch data. - - Args: - batch: Batch input messages from RLHFDataset. - batch_conversations: List of messages including raw prompt, assistant response, tool response. - Note that `len(batch_conversations) == len(batch) * n`, e.g n=2, - batch_conversations=[messages_0_0, messages_0_1, messages_1_0, messages_1_1, ...] - n: How many chat completion choices to generate for each input message. - - Returns: - Batch data, should include ["prompts", "responses", "response_mask", "input_ids", "attention_mask", - "position_ids"]. - """ - raise NotImplementedError - - -class ToolCompletionCallback(CompletionCallback): - def __init__(self, config: DictConfig, scheduler: "ChatCompletionScheduler"): - super().__init__(config, scheduler) - - # TODO: add reward manager to calculate reward score once a sample finish - - async def __call__(self, messages: list[dict[str, str]], completions: ChatCompletion, info: dict[str, Any]): - message = completions.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) - if "content" not in message: - message["content"] = "" - messages.append(message) - finish_reason = completions.choices[0].finish_reason - - # STEP 0: check if we reach max turns - if self.max_assistant_turns and len(messages) >= self.max_assistant_turns: - print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Reach max turns, done!") - return - - # STEP 1: check if the model called tools - if finish_reason != "tool_calls": - print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] No tool called, done!") - return - - # STEP 2: call tools - tool_calls = completions.choices[0].message.tool_calls - print(f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Call {len(tool_calls)} tools") - tasks = [] - for tool_call in tool_calls: - tasks.append(self._call_tool(tool_call)) - tool_responses = await asyncio.gather(*tasks) - if any(isinstance(item, Exception) for item in tool_responses): - print( - f"[id={completions.id},turn={len(messages)},finish_reason={finish_reason}] Error when calling tools, " - f"done!" - ) - return - messages.extend(tool_responses) - - # STEP 3: resubmit completion request with tool responses - self.scheduler.submit_chat_completions(messages=messages, request_id=completions.id, info=info) - - async def _call_tool(self, tool_call) -> dict[str, str]: - """Call tool and return tool response.""" - tool_name = tool_call.function.name - tool_args = json.loads(tool_call.function.arguments) - tool = self.tools[tool_name] - - instance_id = await tool.create() - try: - tool_response, tool_reward_score, tool_metrics = await tool.execute(instance_id, tool_args) - except Exception as e: - logger.exception(f"Error when executing tool: {e}") - return e - finally: - await tool.release(instance_id) - - return { - "role": "tool", - "content": tool_response, - "tool_call_id": tool_call.id, - } - - def postprocess(self, batch: DataProto, batch_conversations: list[list[dict[str, str]]], n: int) -> DataProto: - # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py - # prompts: left pad - # responses: right pad - # input_ids: prompt + response - # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - - # prompts: [prompt] from input dataset - prompts = [ - self.tokenizer.apply_chat_template( - prompt, tools=self.tool_schemas, add_generation_prompt=True, tokenize=False - ) - for prompt in batch.non_tensor_batch["raw_prompt"] - ] - assert len(batch_conversations) == len(prompts) * n - - # sequences: [prompt + response] - sequences = [ - self.tokenizer.apply_chat_template( - conversation, tools=self.tool_schemas, add_generation_prompt=False, tokenize=False - ) - for conversation in batch_conversations - ] - - # responses: [response] - responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)] - - prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left") - responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right") - if n > 1: - prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0) - prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0) - - # response_mask: response mask with tools calling masked out - response_mask = self._mask_out_tools_calling_tokens( - batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0), - batch_conversations, - responses["input_ids"], - responses["attention_mask"], - ) - - input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1) - attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1) - position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask - - batch = TensorDict( - { - "prompts": prompts["input_ids"], # [bsz, prompt_length] - "responses": responses["input_ids"], # [bsz, response_length] - "response_mask": response_mask, # [bsz, response_length] - "input_ids": input_ids, # [bsz, prompt_length + response_length] - "attention_mask": attention_mask, # [bsz, prompt_length + response_length] - "position_ids": position_ids, # [bsz, prompt_length + response_length] - }, - batch_size=len(input_ids), - ) - - num_turns = np.array([len(conversation) for conversation in batch_conversations], dtype=np.int32) - return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}) - - def _mask_out_tools_calling_tokens( - self, - raw_prompts: list[list[dict[str, str]]], - batch_conversations: list[list[dict[str, str]]], - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - """Mask out tools calling tokens in the responses. - - Args: - raw_prompts: [prompt] from input dataset - batch_conversations: [prompt + response] - input_ids: responses tokens - attention_mask: responses attention mask - - Returns: - mask: (batch_size, response_length) - """ - batch_size = input_ids.size(0) - assert len(raw_prompts) == batch_size, f"{len(raw_prompts)} != {batch_size}" - assert len(batch_conversations) == batch_size, f"{len(batch_conversations)} != {batch_size}" - - # Deduplicate adjacent tool calls, since they're merged into one turn. - # [user, assistant, tool, tool, assistant] -> [user, assistant, tool, assistant] - # TODO: it's chat_template specific, find a more generic way to do this. - def deduplicate_adjacent_tool_calls(roles): - result = [] - for role, group in itertools.groupby(roles): - if role == "tool": - result.append(role) - else: - result.extend(group) - return result - - loss_mask = attention_mask.clone() - for i in range(batch_size): - responses = batch_conversations[i][len(raw_prompts[i]) :] - assert len(responses) > 0, f"responses is empty: {responses}" - - roles = deduplicate_adjacent_tool_calls([response["role"] for response in responses]) - # Each turn should be: [BOS]...[EOS] - eos_indices = input_ids[i].eq(self.tokenizer.eos_token_id).nonzero().squeeze(1)[: len(roles)] - for j in range(len(roles)): - if roles[j] == "tool": - bos = eos_indices[j - 1] + 1 if j > 0 else 0 - eos = eos_indices[j] - loss_mask[i, bos : eos + 1] = 0 - - return loss_mask - - -@deprecated("verl.experimental.agent_loop.AgentLoopManager") -class ChatCompletionScheduler: - def __init__( - self, - config: DictConfig, - server_addresses: list[str], - max_cache_size: int = 10000, - ): - """ - Args: - config: DictConfig. - server_addresses: List[str], OpenAI compatible server addresses. - max_cache_size: int, max cache size of request_id to address mapping. - """ - self.config = config.actor_rollout_ref.rollout - model_path = config.actor_rollout_ref.model.path - self.model_name = "/".join(model_path.split("/")[-2:]) - - # Least requests load balancing - self.weighted_addresses = [[0, address] for address in server_addresses] - heapq.heapify(self.weighted_addresses) - - # LRU cache to map request_id to address - self.request_id_to_address = LRUCache(maxsize=max_cache_size) - - self.background_tasks = set() - if self.config.multi_turn.completion_callback is None: - self.completion_callback = ToolCompletionCallback(config, self) - logger.warning("completion_callback is None, use ToolCompletionCallback") - else: - module_path, class_name = self.config.multi_turn.completion_callback.rsplit(".", 1) - module = importlib.import_module(module_path) - self.completion_callback = getattr(module, class_name)(config, self) - - def submit_chat_completions(self, *, messages: list[dict[str, str]], request_id: str, info: dict[str, Any]): - """Submit chat completion request without wait, completion_callback will be called when the request is done. - - Args: - messages: List of messages. - request_id: Request id. - info: Any other auxiliary information pass across multi-turn. - """ - info["__depth__"] += 1 - task = asyncio.create_task(self._submit_chat_completions_and_callback(messages, request_id, info)) - - # “fire-and-forget” background tasks - self.background_tasks.add(task) - task.add_done_callback(self.background_tasks.discard) - - async def _submit_chat_completions_and_callback( - self, - messages: list[dict[str, str]], - request_id: str, - info: dict[str, Any], - ): - """Submit chat completion request, wait request finish and do callback.""" - if request_id: - request_id = request_id.removeprefix("chatcmpl-") - assert request_id in self.request_id_to_address - address = self.request_id_to_address.pop(request_id) - else: - address = self.weighted_addresses[0][1] - self.weighted_addresses[0][0] += 1 - heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0]) - - # use new request_id to avoid duplicate request_id problem - request_id = uuid4().hex - self.request_id_to_address[request_id] = address - - completions, exception = None, None - try: - # NOTE: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. - completions = await self._chat_completions_aiohttp( - address, - messages=messages, - tools=self.completion_callback.tool_schemas, - extra_body=self.completion_callback.extra_body, - extra_headers={"x-request-id": request_id}, - **info["__sampling_params__"], - ) - except Exception as e: - # Let user handle the exception - exception = e - - info["__depth__"] -= 1 - - if exception is not None: - logger.exception(f"chat completion failed with exception: {exception}") - else: - try: - await self.completion_callback(messages, completions, info) - except Exception as e: - logger.exception(f"completion callback failed with exception: {e}") - - # No more ongoing completion requests - if info["__depth__"] == 0: - info["__done__"].set() - - async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: - client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123", timeout=None, max_retries=0) - return await client.chat.completions.create(**chat_complete_request) - - async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: - try: - extra_body = chat_complete_request.pop("extra_body", {}) - chat_complete_request.update(extra_body or {}) - extra_headers = chat_complete_request.pop("extra_headers") - timeout = aiohttp.ClientTimeout(total=None) - session = aiohttp.ClientSession(timeout=timeout) - async with session.post( - url=f"http://{address}/v1/chat/completions", - headers={"Authorization": "Bearer token-abc123", **extra_headers}, - json=chat_complete_request, - ) as resp: - data = await resp.json() - return ChatCompletion(**data) - finally: - await session.close() - - async def generate_sequences(self, batch: DataProto) -> DataProto: - t_start = time.time() - kwargs = dict( - model=self.model_name, - temperature=self.config.temperature, - top_p=self.config.top_p, - ) - - # override sampling params for validation - if batch.meta_info.get("validate", False): - kwargs["top_p"] = self.config.val_kwargs.top_p - kwargs["temperature"] = self.config.val_kwargs.temperature - - print(f"[ChatCompletionScheduler] generate_sequences sampling params: {kwargs}") - - # NOTE: For multi-turn rollout, repeat raw_prompt n times and process each prompt independently, - # validation dataset has already been repeated in `PPOTrainer._validate`. - n = 1 if batch.meta_info.get("validate", False) else self.config.n - tasks, batch_conversations = [], [None] * len(batch) * n - for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"].repeat(n, axis=0)): - # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] - batch_conversations[batch_index] = conversation.tolist() - - tasks.append( - asyncio.create_task( - self._submit_chat_completions_semaphore( - messages=batch_conversations[batch_index], - request_id=None, - sampling_params=kwargs, - ) - ) - ) - - await asyncio.gather(*tasks) - output_batch = self.completion_callback.postprocess(batch, batch_conversations, n=n) - output_batch.meta_info["timing"] = {"generate_sequences": time.time() - t_start} - print("[ChatCompletionScheduler] generate_sequences done") - return output_batch - - async def _submit_chat_completions_semaphore( - self, messages: list[dict[str, str]], request_id: str, sampling_params: dict[str, Any] - ): - done = asyncio.Event() - - info = { - "__done__": done, - "__depth__": 0, # indicate how many ongoing completion requests - "__sampling_params__": sampling_params, - } - - self.submit_chat_completions(messages=messages, request_id=request_id, info=info) - - # Wait until all completion requests are done - await done.wait() diff --git a/verl/workers/rollout/hf_rollout.py b/verl/workers/rollout/hf_rollout.py deleted file mode 100644 index 32d0bc8a5..000000000 --- a/verl/workers/rollout/hf_rollout.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Rollout with huggingface models. -TODO: refactor this class. Currently, it will hang when using FSDP HybridShard. We should actually create a single -GPU model. Then, get full state_dict and bind the state_dict to the single GPU model. Then, use the single GPU model -to perform generation. -""" - -import contextlib - -import torch -import torch.distributed -from tensordict import TensorDict -from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from transformers import GenerationConfig - -from verl import DataProto -from verl.utils.device import get_device_name, get_torch_device -from verl.utils.torch_functional import get_response_mask - -from .base import BaseRollout - -__all__ = ["HFRollout"] - - -class HFRollout(BaseRollout): - def __init__(self, module: nn.Module, config): - super().__init__() - self.config = config - self.module = module - - def generate_sequences(self, prompts: DataProto) -> DataProto: - batch_size = prompts.batch.batch_size[0] - num_chunks = max(batch_size // self.config.get("micro_batch_size", batch_size), 1) - batch_prompts = prompts.chunk(chunks=num_chunks) - output = [self._generate_minibatch(p) for p in batch_prompts] - output = DataProto.concat(output) - return output - - @torch.no_grad() - def _generate_minibatch(self, prompts: DataProto) -> DataProto: - # make sampling args can be overridden by inputs - do_sample = prompts.meta_info.get("do_sample", self.config.do_sample) - is_validate = prompts.meta_info.get("validate", False) - - temperature = prompts.meta_info.get("temperature", self.config.temperature) - response_length = prompts.meta_info.get("response_length", self.config.response_length) - top_p = prompts.meta_info.get("top_p", self.config.get("top_p", 1.0)) - top_k = max(0, prompts.meta_info.get("top_k", self.config.get("top_k", 0))) # to be compatible with vllm - - if not do_sample: - # do_sample==False -> greedy decoding - kwargs = { - "do_sample": False, - "num_beams": 1, - } - elif is_validate: - # do validate and do sample -> use val_kwargs - kwargs = { - "do_sample": True, - "num_beams": 1, - "top_k": max(0, self.config.val_kwargs.top_k), # to be compatible with vllm - "top_p": self.config.val_kwargs.top_p, - "temperature": self.config.val_kwargs.temperature, - "num_return_sequences": 1, # if validate, already repeat in ray_trainer - } - else: - # do_sample -> use rollout config - kwargs = { - "do_sample": True, - "num_beams": 1, - "top_p": top_p, - "top_k": top_k, - "temperature": temperature, - "num_return_sequences": self.config.n, - } - - # make config according to generate mode - generation_config = GenerationConfig(**kwargs) - - idx = prompts.batch["input_ids"] # (bs, prompt_length) - prompt_length = idx.size(1) - attention_mask = prompts.batch["attention_mask"] # left-padded attention_mask - position_ids = prompts.batch["position_ids"] - - # used to construct attention_mask - eos_token_id = prompts.meta_info["eos_token_id"] - pad_token_id = prompts.meta_info["pad_token_id"] - - self.module.eval() - param_ctx = contextlib.nullcontext() - - if isinstance(self.module, FSDP): - # recurse need to set to False according to https://github.com/pytorch/pytorch/issues/100069 - param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False) - with param_ctx, torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16): - output = self.module.generate( - input_ids=idx, - attention_mask=attention_mask, - position_ids=position_ids, - do_sample=do_sample, - max_new_tokens=response_length, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - generation_config=generation_config, - output_scores=False, # this is potentially very large - return_dict_in_generate=True, - use_cache=True, - ) - - # TODO: filter out the seq with no answers like ds-chat - seq = output.sequences - generated_batch_size = seq.size(0) # bs * num_return_sequences - - # huggingface generate will stop generating when all the batch reaches [EOS]. - # We have to pad to response_length - sequence_length = prompt_length + self.config.response_length - delta_length = sequence_length - seq.shape[1] - - if delta_length > 0: - delta_tokens = torch.ones(size=(generated_batch_size, delta_length), device=seq.device, dtype=seq.dtype) - delta_tokens = pad_token_id * delta_tokens - seq = torch.cat((seq, delta_tokens), dim=1) - assert seq.shape[1] == sequence_length - - # make necessary reputations if num_return_sequences > 1 - num_return_sequences = kwargs.get("num_return_sequences", 1) - if num_return_sequences > 1: - position_ids = position_ids.repeat_interleave(num_return_sequences, dim=0) - attention_mask = attention_mask.repeat_interleave(num_return_sequences, dim=0) - - prompt = seq[:, :prompt_length] # (generated_batch_size, prompt_length) - response = seq[:, prompt_length:] # (generated_batch_size, response_length) - - response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(generated_batch_size, 1) - - response_position_ids = position_ids[:, -1:] + delta_position_id - position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - - response_attention_mask = get_response_mask( - response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype - ) - attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) - - batch = TensorDict( - { - "prompts": prompt, - "responses": response, - "input_ids": seq, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=generated_batch_size, - ) - - # empty cache before compute old_log_prob - get_torch_device().empty_cache() - - self.module.train() - return DataProto(batch=batch) diff --git a/verl/workers/rollout/naive/__init__.py b/verl/workers/rollout/naive/__init__.py deleted file mode 100644 index cb6c23bf4..000000000 --- a/verl/workers/rollout/naive/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .naive_rollout import NaiveRollout - -__all__ = ["NaiveRollout"] diff --git a/verl/workers/rollout/naive/naive_rollout.py b/verl/workers/rollout/naive/naive_rollout.py deleted file mode 100644 index fe56dc4c9..000000000 --- a/verl/workers/rollout/naive/naive_rollout.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -In single GPU rollout, the sequences are generated directly by sampling from the model. -The output will contain -1. output_ids -2. attention_masks (left padding) -3. eos_masks -4. log_probs -""" - -import torch -import torch.nn.functional as F -from tensordict import TensorDict -from torch import nn - -from verl import DataProto -from verl.utils.torch_functional import logprobs_from_logits - -from ..base import BaseRollout - -__all__ = ["NaiveRollout"] - - -class NaiveRollout(BaseRollout): - def __init__(self, module: nn.Module, config): - """A naive rollout. It requires the module to be compatible with huggingface APIs. That is: - The module should define __call__ to receive input_ids, attention_mask and position_ids. - It outputs a structure that contains logits field. - - Args: - module: module here follows huggingface APIs - config: DictConfig - """ - super().__init__() - self.config = config - self.module = module - - @torch.no_grad() - def generate_sequences(self, prompts: DataProto) -> DataProto: - """Generate sequences""" - idx = prompts.batch["input_ids"] # (bs, prompt_length) - attention_mask = prompts.batch["attention_mask"] # left-padded attention_mask - position_ids = prompts.batch["position_ids"] - - # used to construct attention_mask - eos_token_id = prompts.meta_info["eos_token_id"] - - batch_size = idx.size(0) - prompt_length = idx.size(1) - - self.module.eval() - - prev_attention_mask = torch.ones(size=(batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device) - - logits_lst = [] - for _ in range(self.config.response_length): - # if the sequence context is growing too long we must crop it at block_size - # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] - idx_cond = idx - # forward the model to get the logits for the index in the sequence - # we use huggingface APIs here - output = self.module(input_ids=idx_cond, attention_mask=attention_mask, position_ids=position_ids) - logits = output.logits - # pluck the logits at the final step and scale by desired temperature - logits = logits[:, -1, :] / self.config.temperature # (bs, vocab_size) - # optionally crop the logits to only the top k options - if self.config.top_k is not None: - v, _ = torch.topk(logits, min(self.config.top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float("Inf") - # apply softmax to convert logits to (normalized) probabilities - probs = F.softmax(logits, dim=-1) - # sample from the distribution - if self.config.do_sample: - idx_next = torch.multinomial(probs, num_samples=1) - else: - idx_next = torch.argmax(probs, dim=-1, keepdim=True) - - attention_mask = torch.cat((attention_mask, prev_attention_mask), dim=-1) - - for token_id in eos_token_id: - prev_attention_mask = torch.logical_and(idx_next != token_id, prev_attention_mask.bool()) - prev_attention_mask.to(attention_mask.dtype) - - position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1) - - # append sampled index to the running sequence and continue - idx = torch.cat((idx, idx_next), dim=1) - logits_lst.append(logits) - - logits = torch.stack(logits_lst, dim=1) # (bs, response_length, vocab_size) - prompts = idx[:, :prompt_length] # (bs, prompt_length) - response = idx[:, prompt_length:] # (bs, response_length) - log_probs = logprobs_from_logits(logits=logits, labels=response) - batch = TensorDict( - { - "input_ids": prompts, - "responses": response, - "sequences": idx, - "old_log_probs": log_probs, - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=batch_size, - ) - - self.module.train() - - return DataProto(batch=batch) diff --git a/verl/workers/rollout/schemas.py b/verl/workers/rollout/schemas.py deleted file mode 100644 index 99f860acd..000000000 --- a/verl/workers/rollout/schemas.py +++ /dev/null @@ -1,675 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import difflib -import logging -import os -from enum import Enum -from typing import Any, Optional - -import torch -from pydantic import BaseModel, ConfigDict, model_validator -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin - -from verl.tools.schemas import OpenAIFunctionToolCall, OpenAIFunctionToolSchema -from verl.utils.model import compute_position_id_with_mask - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -BASE_CHAT_HISTORY = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "I am a user."}, -] - - -class FinishReasonTypeEnum(str, Enum): - """The enum for finish reason type.""" - - LENGTH = "length" - STOP = "stop" - TOOL_CALL = "tool_calls" - - @classmethod - def from_str(cls, value: str) -> "FinishReasonTypeEnum": - if value == "stop": - return cls.STOP - elif value == "length": - return cls.LENGTH - elif value == "tool_calls": - return cls.TOOL_CALL - else: - raise ValueError(f"Unsupported finish reason type: {value}") - - -class Message(BaseModel): - role: str - content: str | dict[str, Any] | list[dict[str, Any]] - tool_calls: Optional[list[OpenAIFunctionToolCall]] = None - - -class AsyncRolloutRequestStateEnum(str, Enum): - """The enum for async rollout request state.""" - - PENDING = "pending" - RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" - TOOL_CALLING = "tool_calling" - INTERACTING = "interacting" - - -class TokenizationSanityCheckModeEnum(str, Enum): - """The enum for tokenization sanity check mode.""" - - DISABLE = "disable" - STRICT = "strict" - IGNORE_STRIPPABLE = "ignore_strippable" - - -class AsyncRolloutRequest(BaseModel): - """The data model for async rollout.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - batch_data_id: int = 0 - rollout_offset: int = 0 - request_id: str - state: AsyncRolloutRequestStateEnum - messages: list[Message] - multi_modal_keys: Optional[list[str]] = None - multi_modal_data: Optional[dict[str, Any]] = None - multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None - tool_schemas: Optional[list[OpenAIFunctionToolSchema]] = None - tools_kwargs: dict[str, Any] = {} - interaction_kwargs: dict[str, Any] = {} - input_ids: Optional[torch.Tensor] = None - prompt_ids: Optional[torch.Tensor] = None - response_ids: Optional[torch.Tensor] = None - attention_mask: Optional[torch.Tensor] = None - prompt_attention_mask: Optional[torch.Tensor] = None - response_attention_mask: Optional[torch.Tensor] = None - position_ids: Optional[torch.Tensor] = None - prompt_position_ids: Optional[torch.Tensor] = None - response_position_ids: Optional[torch.Tensor] = None - loss_mask: Optional[torch.Tensor] = None - prompt_loss_mask: Optional[torch.Tensor] = None - response_loss_mask: Optional[torch.Tensor] = None - reward_scores: dict[str, float] - max_prompt_len: int - max_response_len: int = 8192 - max_model_len: int = 32768 - metrics: dict[str, list[Any]] = {} - - use_inference_chat_template: bool - tokenization_sanity_check_mode: TokenizationSanityCheckModeEnum - generation_prompt_ids: Optional[torch.Tensor] = None - base_conv_wo_gen_prompt_end_pos: int - base_conv_with_gen_prompt_end_pos: int - - @model_validator(mode="before") - @classmethod - def initialize_request(cls, values): - if not (messages := values.get("messages")): - raise ValueError("messages is required for AsyncRolloutRequest initialization") - if not (max_prompt_len := values.get("max_prompt_len")): - raise ValueError("max_prompt_len is required for AsyncRolloutRequest initialization") - if not (processing_class := values.pop("processing_class", None)): - raise ValueError("processing_class is required for AsyncRolloutRequest initialization") - - values["messages"] = [Message.model_validate(msg) for msg in messages] - - # If there is no multi_modal_keys, we assume the multi-modal data is image and video. - if not values.get("multi_modal_keys"): - values["multi_modal_keys"] = ["image", "video"] - if not values.get("multi_modal_data"): - values["multi_modal_data"] = {key: [] for key in values["multi_modal_keys"]} - else: - # check if all multi_modal_keys are in multi_modal_data - for key in values["multi_modal_keys"]: - if key not in values["multi_modal_data"]: - values["multi_modal_data"][key] = [] - if not values.get("multi_modal_inputs"): - values["multi_modal_inputs"] = {} - - tools = ( - [tool.model_dump() for tool in tool_schemas] if (tool_schemas := values.get("tool_schemas", [])) else None - ) - - multi_modal_data = values["multi_modal_data"] - tokens_without_prompt = cls._handle_apply_chat_template( - processing_class, - messages, - multi_modal_data=multi_modal_data, - tools=tools, - add_generation_prompt=False, - tokenize=True, - ) - if ( - values.get("input_ids") is None - or values.get("attention_mask") is None - or values.get("position_ids") is None - ): - tokenization_dict_with_prompt = cls._handle_apply_chat_template( - processing_class, - messages, - multi_modal_data=multi_modal_data, - tools=tools, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - ) - - values["input_ids"], values["attention_mask"] = ( - tokenization_dict_with_prompt["input_ids"], - tokenization_dict_with_prompt["attention_mask"], - ) - if values["input_ids"].shape[-1] > max_prompt_len: - # Only log the warning to avoid truncating in the middle of generation prompt. Consider raising an - # error for this case in the future. - logger.warning( - f"Prompt {values['batch_data_id']} has length {values['input_ids'].shape[-1]} " - f"which is greater than max_prompt_len {max_prompt_len} after applied chat template with tools." - ) - - # Process multi_modal_inputs - multi_modal_inputs = tokenization_dict_with_prompt.copy() - multi_modal_inputs.pop("input_ids", None) - multi_modal_inputs.pop("attention_mask", None) - values["multi_modal_inputs"] = multi_modal_inputs - - values["position_ids"] = values["prompt_position_ids"] = cls._get_position_ids( - processing_class, values["input_ids"], values["attention_mask"], multi_modal_inputs - ) - - values["prompt_ids"], values["prompt_attention_mask"] = values["input_ids"], values["attention_mask"] - values["loss_mask"] = values["prompt_loss_mask"] = torch.zeros_like(values["input_ids"], dtype=torch.bool) - values["generation_prompt_ids"] = values["input_ids"][..., tokens_without_prompt.shape[-1] :] - values["base_conv_wo_gen_prompt_end_pos"] = cls._handle_apply_chat_template( - processing_class, - BASE_CHAT_HISTORY, - multi_modal_data=multi_modal_data, - tools=tools, - add_generation_prompt=False, - tokenize=True, - ).shape[-1] - - values["base_conv_with_gen_prompt_end_pos"] = cls._handle_apply_chat_template( - processing_class, - BASE_CHAT_HISTORY, - multi_modal_data=multi_modal_data, - tools=tools, - add_generation_prompt=True, - tokenize=True, - ).shape[-1] - - return values - - @staticmethod - def _handle_apply_chat_template( - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - messages: list[Message], - multi_modal_data: dict[str, Any], - tools: Optional[list[OpenAIFunctionToolSchema]] = None, - add_generation_prompt: bool = False, - tokenize: bool = False, - return_dict: bool = False, - ): - raw_prompt = processing_class.apply_chat_template( - messages, tools=tools, add_generation_prompt=add_generation_prompt, tokenize=False - ) - if not tokenize: - return raw_prompt - - if isinstance(processing_class, PreTrainedTokenizer) or isinstance(processing_class, PreTrainedTokenizerFast): - if any(len(values) > 0 for values in multi_modal_data.values()): - logger.warning( - "There is multi_modal_data but you are not using a processor. Multi-modal data will be ignored." - ) - model_inputs = processing_class(text=[raw_prompt], return_tensors="pt") - elif isinstance(processing_class, ProcessorMixin): - # When we update multi_model_keys, we also need to update this logic - images = images if len(images := multi_modal_data.get("image", [])) > 0 else None - videos = videos if len(videos := multi_modal_data.get("video", [])) > 0 else None - model_inputs = processing_class(text=[raw_prompt], images=images, videos=videos, return_tensors="pt") - else: - raise ValueError(f"Unsupported processing class type: {type(processing_class)}") - - model_inputs = dict(model_inputs) - if return_dict: - return model_inputs - else: - return model_inputs["input_ids"] - - @staticmethod - def _get_position_ids( - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None, - ) -> torch.Tensor: - # special case for qwen2vl - is_qwen2vl = ( - hasattr(processing_class, "image_processor") - and "Qwen2VLImageProcessor" in processing_class.image_processor.__class__.__name__ - ) - if is_qwen2vl: - from verl.models.transformers.qwen2_vl import get_rope_index - - image_grid_thw = video_grid_thw = second_per_grid_ts = None - if multi_modal_inputs: - image_grid_thw = multi_modal_inputs.get("image_grid_thw") - video_grid_thw = multi_modal_inputs.get("video_grid_thw") - second_per_grid_ts = multi_modal_inputs.get("second_per_grid_ts") - - assert input_ids.dim() == 2 and input_ids.shape[0] == 1, ( - f"input_ids should be 2D with batch size 1, but got shape {input_ids.shape}" - ) - assert attention_mask.dim() == 2 and attention_mask.shape[0] == 1, ( - f"attention_mask should be 2D with batch size 1, but got shape {attention_mask.shape}" - ) - new_position_ids = get_rope_index( - processing_class, - input_ids=input_ids.squeeze(0), - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - attention_mask=attention_mask.squeeze(0), - ) - return new_position_ids # (3, seq_len) - else: - return compute_position_id_with_mask(attention_mask) # (1, seq_len) - - def _update_input_ids( - self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - new_input_ids: torch.Tensor, - attention_mask: bool, - loss_mask: bool, - new_multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None, - ) -> None: - """ - Update the input_ids, attention_mask, position_ids, and loss_mask of the request in additive manner. - """ - self.input_ids = torch.cat([self.input_ids, new_input_ids], dim=-1) - attention_mask = torch.ones_like(new_input_ids) * int(attention_mask) - self.attention_mask = torch.cat([self.attention_mask, attention_mask], dim=-1) - loss_mask = torch.ones_like(new_input_ids) * int(loss_mask) - self.loss_mask = torch.cat([self.loss_mask, loss_mask], dim=-1) - - if new_multi_modal_inputs: - self._update_multi_modal_inputs(new_multi_modal_inputs) - - new_position_ids = self._get_position_ids( - processing_class, new_input_ids, attention_mask, new_multi_modal_inputs - ) - - last_pos = self.position_ids[..., -1:] - new_position_ids = new_position_ids + (last_pos + 1) - - self.position_ids = torch.cat([self.position_ids, new_position_ids], dim=-1) - - assert ( - self.input_ids.shape[-1] - == self.attention_mask.shape[-1] - == self.position_ids.shape[-1] - == self.loss_mask.shape[-1] - ), f"""Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, - {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}""" - - def _update_multi_modal_inputs(self, new_multi_modal_inputs: dict[str, torch.Tensor]) -> None: - """ - Update the multi_modal_inputs of the request in additive manner. - """ - for key in new_multi_modal_inputs: - input_tensor = new_multi_modal_inputs[key] - self.multi_modal_inputs[key] = ( - torch.cat([self.multi_modal_inputs[key], input_tensor], dim=0) - if key in self.multi_modal_inputs - else input_tensor - ) - - def get_generation_prompt_ids( - self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin - ) -> list[int]: - """ - Get the generation prompt ids for rollout engine. - - Because rollout engine(SGLang) requires the ids to be a list, we need to convert the tensor to a list. - """ - generation_prompt_ids = ( - None - if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all() - else self.generation_prompt_ids - ) - if generation_prompt_ids is not None: - self._update_input_ids(processing_class, generation_prompt_ids, attention_mask=True, loss_mask=False) - - if self.use_inference_chat_template: - messages = [msg.model_dump() for msg in self.messages] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None - generation_prompt_ids = self._handle_apply_chat_template( - processing_class, - messages, - multi_modal_data=self.multi_modal_data, - tools=tools, - add_generation_prompt=True, - tokenize=True, - ) - return generation_prompt_ids.squeeze(0).tolist() - else: - return self.input_ids.squeeze(0).tolist() - - def add_user_message( - self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - content: str, - ) -> None: - self.messages.append(Message(role="user", content=content)) - messages = [*BASE_CHAT_HISTORY, self.messages[-1]] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None - - # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine - # Inference, it is pure text. - content_ids = self._handle_apply_chat_template( - processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True - )[..., self.base_conv_wo_gen_prompt_end_pos :] - self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=False) - - def add_assistant_message( - self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - content: str, - tool_calls: Optional[list[OpenAIFunctionToolCall]] = None, - ) -> None: - self.messages.append(Message(role="assistant", content=content, tool_calls=tool_calls)) - - messages = [*BASE_CHAT_HISTORY, self.messages[-1]] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None - - # We don't need to pass multi_modal_data here because we don't have any multi-modal data from Engine - # Inference, it is pure text. - content_ids = self._handle_apply_chat_template( - processing_class, messages, multi_modal_data={}, tools=tools, add_generation_prompt=False, tokenize=True - )[..., self.base_conv_with_gen_prompt_end_pos :] - self._update_input_ids(processing_class, content_ids, attention_mask=True, loss_mask=True) - - def add_tool_response_messages( - self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - contents: list[str | dict[str, Any]], - ) -> None: - if not contents: - return - # We also handle the case when tool returns image - # We require the processing of the image and video to be done at tool.execute() level - delta_multi_modal_data = {key: [] for key in self.multi_modal_keys} - for content in contents: - if isinstance(content, dict): - content_list = [] - # When we update multi_model_keys, we also need to update this logic - if "image" in content: - if not isinstance(content["image"], list): - raise ValueError( - f"Image must be a list, but got {type(content['image'])}. Please check the tool.execute(). " - f"For single images, wrap in a list: [image]. " - f"Example: {{'image': [img1]}} or {{'image': [img1, img2, ...]}}." - ) - - content_list.extend([{"type": "image"} for _ in content["image"]]) - delta_multi_modal_data["image"].extend(content["image"]) - if "video" in content: - if not isinstance(content["video"], list): - raise ValueError( - f"Video must be a list, but got {type(content['video'])}. Please check the tool.execute(). " - f"For single videos, wrap in a list: [video]. " - f"Example: {{'video': [video1]}} or {{'video': [video1, video2, ...]}}." - ) - - content_list.extend([{"type": "video"} for _ in content["video"]]) - delta_multi_modal_data["video"].extend(content["video"]) - if "text" in content: - content_list.append({"type": "text", "text": content["text"]}) - for key in content: - if key not in ["image", "video", "text"]: - logger.warning( - f"Tool response message contains unexpected key: {key} " - f"while we only support `image`, `video`, and `text`." - ) - self.messages.append(Message(role="tool", content=content_list)) - else: - self.messages.append(Message(role="tool", content=content)) - - messages = [*BASE_CHAT_HISTORY, *self.messages[-len(contents) :]] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None - - for key in self.multi_modal_keys: - if len(delta_multi_modal_data[key]) > 0: - self.multi_modal_data[key].extend(delta_multi_modal_data[key]) - - # We just passed the new multi-modal data to the chat template to update the input_ids. - content_info = self._handle_apply_chat_template( - processing_class, - messages, - multi_modal_data=delta_multi_modal_data, - tools=tools, - add_generation_prompt=False, - tokenize=True, - return_dict=True, - ) - content_ids = content_info["input_ids"][..., self.base_conv_wo_gen_prompt_end_pos :] - - # process multi_modal_inputs - multi_modal_inputs = content_info.copy() - multi_modal_inputs.pop("input_ids", None) - multi_modal_inputs.pop("attention_mask", None) - self._update_input_ids( - processing_class, - content_ids, - attention_mask=True, - loss_mask=False, - new_multi_modal_inputs=multi_modal_inputs, - ) - - def update_metrics(self, metrics: Any, tool_id: str) -> None: - """ - metrics: should be a dict of tools_name -> Any - """ - if self.metrics.get(tool_id) is None: - self.metrics[tool_id] = [] - self.metrics[tool_id].append(metrics) - - def _get_prompt_diffs( - self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - full_prompt_ids: torch.Tensor, - current_prompt_ids: torch.Tensor, - diff_surrounding_chars: int = 10, - ) -> list[dict[str, Any]]: - """Get differences between full prompt and current prompt with surrounding context. - - This function helps debug tokenization mismatches by showing the differences between - full prompt and current prompt with surrounding context. Instead of just showing - the exact diff, it includes additional tokens before and after to help locate - the issue in the chat template. - - For example, if the actual diff is a newline change from "\n\n" to "\n", with - diff_surrounding_chars the output might look like: - - full_prompt_chunk: "<|im_start|>assistant\n\nI think..." - current_prompt_chunk: "<|im_start|>assistant\nI think..." - - This context makes it much easier to identify where in the chat template the - mismatch occurs. - - Args: - processing_class: The processing class to use for decoding the token IDs - full_prompt_ids: Token IDs from applying chat template to all messages at once - current_prompt_ids: Token IDs from incremental chat template application - diff_surrounding_chars: Number of surrounding characters to include for context (default: 10) - - Returns: - List of dicts containing the differing chunks with context and their indices - """ - full_prompt_ids = full_prompt_ids.squeeze(0) - current_prompt_ids = current_prompt_ids.squeeze(0) - full_prompt = processing_class.decode(full_prompt_ids, skip_special_tokens=False) - current_prompt = processing_class.decode(current_prompt_ids, skip_special_tokens=False) - s = difflib.SequenceMatcher(None, full_prompt, current_prompt, autojunk=False) - diffs = [] - for tag, i1, i2, j1, j2 in s.get_opcodes(): - if tag == "equal": - continue - - # Get the surrounding context for better readability - start_i = max(0, i1 - diff_surrounding_chars) - end_i = min(len(full_prompt), i2 + diff_surrounding_chars) - start_j = max(0, j1 - diff_surrounding_chars) - end_j = min(len(current_prompt), j2 + diff_surrounding_chars) - - diffs.append( - { - "full_prompt_chunk": full_prompt[start_i:end_i], - "current_prompt_chunk": current_prompt[start_j:end_j], - "indices": (start_i, end_i, start_j, end_j), - } - ) - return diffs - - def finalize( - self, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - reward_scores: dict[str, list[float]], - finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP, - ) -> None: - self.state = AsyncRolloutRequestStateEnum.COMPLETED - self.reward_scores = reward_scores - - # In case we failed to generate the assistant message and the generation prompt ids were already added to - # input_ids, remove them from the end of input_ids - if self.input_ids[..., -self.generation_prompt_ids.shape[-1] :].eq(self.generation_prompt_ids).all(): - self.input_ids = self.input_ids[..., : -self.generation_prompt_ids.shape[-1]] - self.attention_mask = self.attention_mask[..., : -self.generation_prompt_ids.shape[-1]] - self.position_ids = self.position_ids[..., : -self.generation_prompt_ids.shape[-1]] - self.loss_mask = self.loss_mask[..., : -self.generation_prompt_ids.shape[-1]] - - self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :] - - if self.tokenization_sanity_check_mode != TokenizationSanityCheckModeEnum.DISABLE: - # When there is a diff, we log the diffs with diff_surrounding_chars context - diff_surrounding_chars = 10 - - messages = [msg.model_dump() for msg in self.messages] - tools = [tool.model_dump() for tool in self.tool_schemas] if self.tool_schemas else None - full_prompt_info = self._handle_apply_chat_template( - processing_class, - messages, - multi_modal_data=self.multi_modal_data, - tools=tools, - add_generation_prompt=False, - tokenize=True, - return_dict=True, - ) - full_prompt_ids = full_prompt_info["input_ids"] - - # We must use dict(full_prompt_info) to convert BatchFeature values to a new dict - # because np.array() only keeps the keys for BatchFeature. - full_prompt_multi_modal_inputs = full_prompt_info.copy() - full_prompt_multi_modal_inputs.pop("input_ids", None) - full_prompt_multi_modal_inputs.pop("attention_mask", None) - - for multi_modal_inputs_key in self.multi_modal_inputs: - if multi_modal_inputs_key in full_prompt_multi_modal_inputs: - if ( - not self.multi_modal_inputs[multi_modal_inputs_key] - .eq(full_prompt_multi_modal_inputs[multi_modal_inputs_key]) - .all() - ): - logger.warning( - f"Multi-modal data {multi_modal_inputs_key} is not consistent. " - f"This may lead to unexpected behavior during training. " - f"Please review your multi_modal_inputs logic." - ) - else: - logger.warning( - f"Multi-modal inputs key {multi_modal_inputs_key} is not found in the multi_modal_inputs. " - f"This may lead to unexpected behavior during training." - f"Please review your multi_modal_inputs logic." - ) - - if diffs := self._get_prompt_diffs( - processing_class, full_prompt_ids, self.input_ids, diff_surrounding_chars=diff_surrounding_chars - ): - log_warning = False - if self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.STRICT: - log_warning = True - elif self.tokenization_sanity_check_mode == TokenizationSanityCheckModeEnum.IGNORE_STRIPPABLE: - non_strippable_diffs_exist = any( - d["full_prompt_chunk"].strip() or d["current_prompt_chunk"].strip() for d in diffs - ) - if non_strippable_diffs_exist: - log_warning = True - - if log_warning: - mode_str = f" ({self.tokenization_sanity_check_mode.value})" - logger.warning( - f"Inconsistent training and inference tokenization detected{mode_str}. This may lead to " - f"unexpected behavior during training. Please review your chat template to determine if this " - f"is intentional. For more information, refer to the multiturn README.md." - ) - logger.warning( - f"Showing {diff_surrounding_chars} characters before and after the diffs for context and " - f"better readability." - ) - diff_details_list = [] - for d in diffs: - i1, i2, j1, j2 = d["indices"] - diff_details_list.append( - f"idx {i1}:{i2} -> {j1}:{j2} | full_prompt_chunk: {repr(d['full_prompt_chunk'])} | " - f"current_prompt_chunk: {repr(d['current_prompt_chunk'])}" - ) - diff_details = "\n".join(diff_details_list) - logger.warning(f"Found differences:\n{diff_details}") - - if finish_reason_type == FinishReasonTypeEnum.STOP: - pass - elif finish_reason_type == FinishReasonTypeEnum.LENGTH: - pass - else: - raise ValueError(f"Unsupported finalize finish reason type: {finish_reason_type}") - self.truncate_output_ids(processing_class) - - assert ( - self.input_ids.shape[-1] - == self.attention_mask.shape[-1] - == self.position_ids.shape[-1] - == self.loss_mask.shape[-1] - ), f"""Request {self.request_id} has different length of {self.input_ids.shape[-1]=}, - {self.attention_mask.shape[-1]=}, {self.position_ids.shape[-1]=}, {self.loss_mask.shape[-1]=}""" - - def truncate_output_ids( - self, processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin - ) -> None: - self.input_ids = self.input_ids[..., : self.max_model_len] - self.attention_mask = self.attention_mask[..., : self.max_model_len] - self.position_ids = self.position_ids[..., : self.max_model_len] - self.loss_mask = self.loss_mask[..., : self.max_model_len] - self.response_ids = self.input_ids[..., self.prompt_ids.shape[-1] :][..., : self.max_response_len] - self.response_attention_mask = self.attention_mask[..., self.prompt_attention_mask.shape[-1] :][ - ..., : self.max_response_len - ] - self.response_position_ids = self.position_ids[..., self.prompt_position_ids.shape[-1] :][ - ..., : self.max_response_len - ] - self.response_loss_mask = self.loss_mask[..., self.prompt_loss_mask.shape[-1] :][..., : self.max_response_len] diff --git a/verl/workers/rollout/sglang_rollout/__init__.py b/verl/workers/rollout/sglang_rollout/__init__.py deleted file mode 100644 index 43a1eebb4..000000000 --- a/verl/workers/rollout/sglang_rollout/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -from .sglang_rollout import SGLangRollout - -__all__ = ["SGLangRollout"] diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py deleted file mode 100644 index df26765c2..000000000 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import logging -from typing import Any - -import ray -from omegaconf import DictConfig -from starlette.requests import Request -from starlette.responses import JSONResponse - -from verl.workers.rollout.async_server import AsyncServerBase - -logger = logging.getLogger(__file__) - - -@ray.remote(num_cpus=1) -class AsyncSglangServer(AsyncServerBase): - def __init__(self, config: DictConfig, dp_size: int, dp_rank: int, wg_prefix: str): - super().__init__() - self.config = config.actor_rollout_ref - self._tp_size = self.config.rollout.get("tensor_model_parallel_size", 1) - self._dp_size = dp_size - self._dp_rank = dp_rank - self.wg_prefix = wg_prefix - self.workers = [] - self.master_worker = None - - async def init_engine(self): - if self.workers: - # avoid init twice - return - all_actors = ray.util.list_named_actors(all_namespaces=True) - matched_actors = [ - actor for actor in all_actors if actor.get("name", None).startswith(self.wg_prefix + "WorkerDict_") - ] - - for matched_actor in matched_actors: - fields = matched_actor["name"].split(":") - assert len(fields) == 2, f"invalid actor name: {matched_actor['name']}" - pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1]) - - if (self._dp_size * pg_index + local_rank) // self._tp_size == self._dp_rank: - worker = ray.get_actor(**matched_actor) - self.workers.append(worker) - if (self._dp_size * pg_index + local_rank) / self._tp_size == self._dp_rank: - self.master_worker = worker - - async def chat_completion(self, raw_request: Request): - request = await raw_request.json() - - # only send request to master worker in tp rank 0 - output_future = self.master_worker.chat_completion.remote(request) - [outputs] = await asyncio.gather(output_future) - return JSONResponse(outputs) - - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: - return await self.master_worker.generate.remote(prompt_ids, sampling_params, request_id) - - async def wake_up(self): - if not self.config.rollout.free_cache_engine: - return - - tasks = [worker.wake_up.remote() for worker in self.workers] - if tasks: - await asyncio.gather(*tasks) - - async def sleep(self): - if not self.config.rollout.free_cache_engine: - return - - tasks = [worker.sleep.remote() for worker in self.workers] - if tasks: - await asyncio.gather(*tasks) diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py deleted file mode 100644 index 3c6694325..000000000 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ /dev/null @@ -1,1391 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import asyncio -import logging -import multiprocessing as mp -import os -import time -from copy import deepcopy -from json import JSONDecodeError -from typing import Any, List, Optional, Tuple -from uuid import uuid4 - -import numpy as np -import sglang.srt.entrypoints.engine -import torch -import torch.distributed as dist -from omegaconf import DictConfig -from sglang.srt.managers.tokenizer_manager import ( - ReleaseMemoryOccupationReqInput, - ResumeMemoryOccupationReqInput, - UpdateWeightsFromTensorReqInput, -) -from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ( - MultiprocessingSerializer, - assert_pkg_version, - get_ip, - get_open_port, - is_cuda, - maybe_set_triton_cache_manager, - set_prometheus_multiproc_dir, - set_ulimit, -) -from tensordict import TensorDict -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.nn.utils.rnn import pad_sequence -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin - -from verl import DataProto -from verl.interactions.base import BaseInteraction -from verl.interactions.utils.interaction_registry import initialize_interactions_from_config -from verl.third_party.sglang import parallel_state as sglang_ps -from verl.tools.base_tool import BaseTool -from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall -from verl.tools.utils.tool_registry import initialize_tools_from_config -from verl.utils.net_utils import is_ipv6 -from verl.utils.profiler import GPUMemoryLogger -from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length -from verl.workers.rollout.base import BaseRollout -from verl.workers.rollout.schemas import ( - AsyncRolloutRequest, - AsyncRolloutRequestStateEnum, - FinishReasonTypeEnum, - Message, -) -from verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj - -try: - from sglang.srt.function_call.function_call_parser import FunctionCallParser -except ImportError: - from sglang.srt.function_call_parser import FunctionCallParser - -try: - from sglang.srt.entrypoints.openai.protocol import Tool -except ImportError: - from sglang.srt.openai_api.protocol import Tool - - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723 -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" - os.environ["CUDA_MODULE_LOADING"] = "AUTO" - - # Set prometheus env vars - if server_args.enable_metrics: - set_prometheus_multiproc_dir() - - # Set ulimit - set_ulimit() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if server_args.attention_backend == "flashinfer": - assert_pkg_version( - "flashinfer_python", - "0.2.5", - "Please uninstall the old version and reinstall the latest version by following the instructions at https://docs.flashinfer.ai/installation.html.", - ) - if is_cuda(): - assert_pkg_version( - "sgl-kernel", - "0.1.1", - "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", - ) - - # Set mp start method - mp.set_start_method("spawn", force=True) - - -sglang.srt.entrypoints.engine._set_envs_and_config = _set_envs_and_config - - -# because chatCompletion is an async method, it makes the whole ray actor be an async actor -# which can not call loop.run_until_complete. So we need to make the engine to be an async class -class AsyncEngine(sglang.srt.entrypoints.engine.Engine): - def __init__(self, **kwargs): - super().__init__(**kwargs) - # default to use dummy load format, which need to reload weights in first time - self._need_reload = True - - async def release_memory_occupation(self, tags: Optional[list[str]] = None): - """Release GPU occupation temporarily.""" - if tags is None: - obj = ReleaseMemoryOccupationReqInput() - else: - obj = ReleaseMemoryOccupationReqInput(tags=tags) - return await self.tokenizer_manager.release_memory_occupation(obj, None) - - async def resume_memory_occupation(self, tags: Optional[list[str]] = None): - """Resume GPU occupation.""" - # because __init__ is a sync method, it can not call the async release_memory_occupation - # have to move release_memory_occupation from __init__ to here - # For multi-stage awake, we run release weight and kv_cache when we resume weights for the first time. - if self._need_reload: - await self.release_memory_occupation() - self._need_reload = False - - if tags is None: - obj = ResumeMemoryOccupationReqInput() - else: - obj = ResumeMemoryOccupationReqInput(tags=tags) - return await self.tokenizer_manager.resume_memory_occupation(obj, None) - - async def update_weights_from_tensor( - self, - named_tensors: List[Tuple[str, torch.Tensor]], # noqa: UP006 - load_format: Optional[str] = None, - flush_cache: bool = True, - ): - """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false - to avoid duplicated cache cleaning operation.""" - obj = UpdateWeightsFromTensorReqInput( - serialized_named_tensors=[ - MultiprocessingSerializer.serialize(named_tensors) for _ in range(self.server_args.tp_size) - ], - load_format=load_format, - flush_cache=flush_cache, - ) - return await self.tokenizer_manager.update_weights_from_tensor(obj, None) - - async def flush_cache(self): - return await self.tokenizer_manager.flush_cache() - - -# NOTE(sgm): add for verl. We can optimize it by making -# the dataloader yield List[int] without padding. -def _pre_process_inputs( - pad_token_id, - prompt_token_ids: torch.Tensor, -) -> torch.Tensor: - # remove the left padding in the prompt token_id - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - return prompt_token_ids[non_pad_index:] - - -# NOTE(linjunrong): adhoc -def _post_process_outputs(processing_class, output): - try: - # This is when processing_class is a processor - tokenizer = processing_class.tokenizer - except AttributeError: - try: - # This is when processing_class is a tokenizer - tokenizer = processing_class - except AttributeError as e: - raise ValueError(f"Cannot get tokenizer from processing_class {processing_class}") from e - - def _map_each_response(resp): - output_token_logprobs = resp["meta_info"]["output_token_logprobs"] - log_probs, output_token_ids = zip( - *[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs], strict=True - ) - return torch.tensor(output_token_ids), torch.tensor(log_probs) - - out_map = map(lambda x: _map_each_response(x), output) - batched_output_token_ids = [] - batched_logprobs = [] - for output_token_ids, log_probs in out_map: - batched_output_token_ids.append(output_token_ids) - batched_logprobs.append(log_probs) - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - batched_output_token_ids = pad_sequence(batched_output_token_ids, batch_first=True, padding_value=pad_token_id) - if len(batched_logprobs) > 0: - batched_logprobs = pad_sequence(batched_logprobs, batch_first=True, padding_value=pad_token_id) - return batched_output_token_ids, batched_logprobs - - -def get_tool_call_parser_type( - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, -) -> str: - items = FunctionCallParser.ToolCallParserEnum.items() - for parser_type, parser_cls in items: - parser = parser_cls() - try: - # This is when processing_class is a tokenizer - tokenizer_vocab = processing_class.get_vocab() - except AttributeError: - try: - # This is when processing_class is a processor - tokenizer_vocab = processing_class.tokenizer.get_vocab() - except AttributeError as e: - raise ValueError(f"Cannot get vocab from processing_class {processing_class}") from e - - if parser.bot_token.strip() in tokenizer_vocab and ( - parser.eot_token == "" or parser.eot_token.strip() in tokenizer_vocab - ): - return parser_type - else: - raise ValueError(f"No tool call parser found for processing_class {processing_class}") - - -class SGLangRollout(BaseRollout): - def __init__( - self, - actor_module: str, - config: DictConfig, - processing_class: PreTrainedTokenizer | PreTrainedTokenizerFast | ProcessorMixin, - model_hf_config, - port=None, - trust_remote_code: bool = False, - device_mesh: DeviceMesh | None = None, - **kwargs, - ): - """Synchronized SGLang rollout engine. - - Args: - actor_module: Huggingface model name or path to the model. The - model should be supported by SGLang. - config: A DictConfig object containing SGLang-specific operational - parameters and rollout settings. - Refer to https://docs.sglang.ai/backend/server_arguments.html - processing_class: The tokenizer or processor instance compatible with the actor_module. - model_hf_config: The Hugging Face model's configuration (e.g., - `transformers.PretrainedConfig`). It provides architectural - details and hyperparameters like `max_position_embeddings`, - used by SGLang for correct model initialization. This is - the model's inherent design, not SGLang's runtime behavior. - port: Optional port for multi-node initialization when nnodes > 1. - trust_remote_code: Whether or not to allow for custom models - defined on the Hub in their own modeling files. - device_mesh: Optional `DeviceMesh` object for distributed setup. - **kwargs: Additional keyword arguments, primarily `train_tp` for - Megatron Backend integration to initialize hybrid engine - process groups. - """ - super().__init__() - self.config = config - self._device_mesh_cpu = device_mesh - os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") - - ( - self._tool_schemas, - self._tool_map, - self._tool_call_parser_type, - self._sgl_tools, - self._function_call_parser, - ) = self._initialize_tools(config, processing_class) - self.interaction_map: dict[str, BaseInteraction] = self._initialize_interactions(config) - # If turn on `free_cache_engine`, SGLang engine's KV cache - # will be freed after each `generate_sequences` call. - logger.info( - f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: " - f"{self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: " - f"{self._function_call_parser}" - ) - - self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs) - - self._verify_config(model_hf_config=model_hf_config) - # initialize the inference engine - self._init_inference_engine(trust_remote_code, actor_module, port) - - self._init_sampling_params(**kwargs) - - self.processing_class = processing_class - - try: - # This is when processing_class is a tokenizer - self.pad_token_id = self.processing_class.pad_token_id - except AttributeError: - try: - # This is when processing_class is a processor - self.pad_token_id = self.processing_class.tokenizer.pad_token_id - except AttributeError as e: - raise ValueError(f"Cannot get pad_token_id from processing_class {self.processing_class}") from e - - def _init_distributed_env(self, device_mesh_cpu, **kwargs): - self._device_mesh_cpu = device_mesh_cpu - os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") - self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert self.tensor_parallel_size <= dist.get_world_size(), ( - "tensor parallel size should be less than or equal to the world size" - ) - self.train_tp = kwargs.get("train_tp", None) - if self.train_tp is not None: - # deployed with megatron - os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" - os.environ["MEGATRON_IMPORT_TIMERS"] = "0" - train_tp = kwargs.get("train_tp", None) - num_tp_per_train_tp = train_tp // self.tensor_parallel_size - sglang_ps.initialize_parallel_state( - tensor_model_parallel_size=self.tensor_parallel_size, - num_tp_per_train_tp=num_tp_per_train_tp, - ) - - tp_size = self.tensor_parallel_size - world_size = int(os.getenv("WORLD_SIZE", "-1")) - - # init device mesh - if self._device_mesh_cpu is None: - device_mesh_kwargs = dict( - mesh_shape=(world_size // tp_size, tp_size, 1), - mesh_dim_names=["dp", "tp", "pp"], - ) - - self._device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) - - self._rank = self._device_mesh_cpu.get_rank() - self._tp_rank = self._device_mesh_cpu["tp"].get_local_rank() - self._tp_size = self._device_mesh_cpu["tp"].size() - if self._rank == 0: - logger.info(f"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}") - # get tp_rank of this process in this tp group - visible_devices = [None] * self._device_mesh_cpu.size(1) - - torch.distributed.all_gather_object( - visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp") - ) - self.visible_devices_set = set(",".join(visible_devices).split(",")) - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(self.visible_devices_set))) - - def _verify_config(self, model_hf_config): - if not self.config.get("max_model_len", None): - self.config.max_model_len = self.config.prompt_length + self.config.response_length - assert ( - self.config.max_model_len >= self.config.prompt_length + self.config.response_length - ), f"""max_model_len should be greater than total sequence length (prompt_length + response_length): - {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}""" - max_position_embeddings = None - if hasattr(model_hf_config, "max_position_embeddings"): - max_position_embeddings = model_hf_config.max_position_embeddings - elif hasattr(model_hf_config, "llm_config") and hasattr(model_hf_config.llm_config, "max_position_embeddings"): - max_position_embeddings = model_hf_config.llm_config.max_position_embeddings - elif hasattr(model_hf_config, "text_config") and hasattr( - model_hf_config.text_config, "max_position_embeddings" - ): - max_position_embeddings = model_hf_config.text_config.max_position_embeddings - if max_position_embeddings is None: - raise ValueError("max_position_embeddings not found in model_hf_config") - rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) - if not rope_scaling_config: - assert max_position_embeddings >= self.config.prompt_length + self.config.response_length, ( - "model context length should be greater than total sequence length" - ) - else: - # handle type where there's a length extend factor - # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support - # for using yarn as an example - rope_scaling_factor = rope_scaling_config.get("factor", 1.0) - - assert ( - model_hf_config.max_position_embeddings * rope_scaling_factor - >= self.config.prompt_length + self.config.response_length - ), ( - f"model context length should be greater than total sequence length, " - f"got rope_scaling_factor={rope_scaling_factor} and " - f"max_position_embeddings={model_hf_config.max_position_embeddings}" - ) - - # currently max_assistant_turns stand for max number of tool calls - if self.config.multi_turn.max_assistant_turns is None: - self.config.multi_turn.max_assistant_turns = self.config.max_model_len // 3 - if self.config.multi_turn.max_user_turns is None: - self.config.multi_turn.max_user_turns = self.config.max_model_len // 3 - - def _init_inference_engine(self, trust_remote_code, actor_module, port): - # initialize the inference engine - nnodes = -(-self._tp_size // len(self.visible_devices_set)) - if nnodes > 1: - ip = get_ip() - port = get_open_port() if port is None else port - [ip, port] = broadcast_pyobj( - [ip, port], - rank=self._rank, - dist_group=self._device_mesh_cpu.get_group("tp"), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, - ) - dist_init_addr = f"[{ip}]:{port}" if is_ipv6(ip) else f"{ip}:{port}" - else: - dist_init_addr = None - - load_format = "dummy" if self.config.load_format.startswith("dummy") else self.config.load_format - tp_size_per_node = self._tp_size // nnodes - node_rank = self._tp_rank // tp_size_per_node - first_rank_in_node = self._tp_rank % tp_size_per_node == 0 - - if first_rank_in_node: - rank = dist.get_rank() - os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - self._engine = AsyncEngine( - model_path=actor_module, - dtype=self.config.dtype, - mem_fraction_static=self.config.gpu_memory_utilization, - enable_memory_saver=True, - base_gpu_id=0, - gpu_id_step=1, - tp_size=self._tp_size, - node_rank=node_rank, - load_format=load_format, - dist_init_addr=dist_init_addr, - nnodes=nnodes, - trust_remote_code=trust_remote_code, - # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new - # when random.seed is being set during training - port=30000 + rank, - # NOTE(Chenyang): if you want to debug the SGLang engine output - # please set the following parameters - # Otherwise, it will make the engine run too slow - # log_level="INFO", - # log_requests=True, - # log_requests_level=2, - # max_running_requests=1, - mm_attention_backend="fa3", - attention_backend="fa3", - # In async mode, we want token in token out. - skip_tokenizer_init=self.config.mode == "async", - ) - else: - self._engine = None - - self.sharding_manager = None - self.is_sleep = True - - def _init_sampling_params(self, **kwargs): - kwargs = dict( - n=1, - max_new_tokens=self.config.response_length, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - ) - # supporting adding any sampling params from the config file - for k in self.config.keys(): - if hasattr(SamplingParams(), str(k)) or "stop" in str(k): - kwargs[k] = self.config.get(k) - kwargs["n"] = 1 # already repeat in ray_trainer - self.sampling_params = kwargs - - def _initialize_tools(self, config, processing_class): - """Initialize tools from configuration. - - Args: - config: Configuration object containing tool-related settings, - specifically `config.multi_turn.tool_config_path`. - tokenizer: The tokenizer instance used for parsing tool calls from - the model's generated text. - - Returns: - tuple: A tuple containing: - - tool_schemas (list[dict]): OpenAI-formatted JSON schemas - defining each tool's capabilities. - - tool_map (dict[str, BaseTool]): A dictionary mapping tool - names to their executable `BaseTool` objects. - - tool_call_parser_type (str): The identifier for the specific - parser type (e.g., 'json_mode', 'tool_code') used to extract - tool calls. - - sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool - definitions optimized for SGLang's internal engine. - - function_call_parser (sglang.srt.function_call_parser.FunctionCallParser): - The active parser instance responsible for extracting - structured tool calls from model outputs. - """ - if config.multi_turn.tool_config_path is None: - return [], {}, None, [], None - - tools_config_file = config.multi_turn.tool_config_path - tool_list = initialize_tools_from_config(tools_config_file) - - logger.info(f"Initialize tools from configuration.: tool_list: {tool_list}") - tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list] - tool_map = {tool.name: tool for tool in tool_list} - tool_call_parser_type = get_tool_call_parser_type(processing_class) - sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas] - function_call_parser = FunctionCallParser( - sgl_tools, - tool_call_parser_type, - ) - - return ( - tool_schemas, - tool_map, - tool_call_parser_type, - sgl_tools, - function_call_parser, - ) - - def _initialize_interactions(self, config): - """Initialize interactions from configuration. - - Returns: - dict[str, BaseInteraction]: A dictionary mapping interaction names to interaction instances. - """ - if config.multi_turn.interaction_config_path is None: - return {} - - interaction_config_file = config.multi_turn.interaction_config_path - interaction_map = initialize_interactions_from_config(interaction_config_file) - - logger.info(f"Initialize interactions from configuration: interaction_map: {list(interaction_map.keys())}") - return interaction_map - - @GPUMemoryLogger(role="sglang rollout", logger=logger) - @torch.no_grad() - def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - """Generate sequences for a batch of prompts. - - Args: - batch (DataProto): Input batch. - - Returns: - DataProto: Output batch. - - prompts: [bsz, prompt_length], prompt token ids from dataset. - - responses: [bsz, response_length], output token ids include response tokens - from LLM generation and observation tokens from tool_calls. - - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens. - - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens - and response tokens. - - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens. - - position_ids: [bsz, prompt_length + response_length], incremental position ids. - - For multi-turn conversations: - responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->| - response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0| - """ - if self.config.multi_turn.enable: - return self._req_level_generate_sequences(prompts, **kwargs) - return self._batch_level_generate_sequences(prompts, **kwargs) - - @GPUMemoryLogger(role="sglang rollout", logger=logger) - @torch.no_grad() - def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - """Generates single-turn sequences for a batch of prompts. - For single-turn generation, all prompts are processed in one request. - `_batch_level_generate_sequences` involves: - 1. Extracting and pre-processing prompt token IDs from the input - `prompts`. This includes handling padding and preparing raw - token ID lists. - 2. Preparing inputs for the SGLang engine, including multi-modal - data if present. - 3. Invoking the SGLang engine (`self._engine.async_generate`, - an async coroutine) with the batch of processed inputs and - specified sampling parameters on the master TP rank. - 4. Broadcasting the results from the master TP rank to all - other TP ranks. - 5. Post-processing the engine's output to format the generated - token IDs and (if applicable) log probabilities. - 6. Constructing the final sequences by concatenating original - prompts with the generated responses. - 7. Updating attention masks and position IDs to reflect the full - concatenated sequences. - 8. If `self.config.free_cache_engine` is true, the SGLang engine's - KV cache is flushed after generation on the master TP rank. - Args: - prompts: A `DataProto` object containing the batch of - input prompts, including tensor data (like `input_ids`, - `attention_mask`) and meta-information (like `eos_token_id`, - `do_sample`). - **kwargs: Additional keyword arguments that can override the - default sampling parameters (e.g., `temperature`, `top_p`, - `max_new_tokens`). These are temporarily applied using - `update_sampling_params`. - Returns: - DataProto: A `DataProto` object containing the batch of - generated sequences. This includes tensors for `prompts` - (original input IDs), `responses` (generated token IDs), - `input_ids` (concatenated prompt and response), - `attention_mask`, and `position_ids` for the full - sequences. - Note that in GRPO, if the prompts are validated, we repeat the prompts for rollout.n times in ray_trainer. - Thus we do not need to repeat the prompts here and set the sampling parameter n to 1. - """ - # input ids: (bs, prompt_length), left-padded - idx = prompts.batch["input_ids"] - # attention_mask: (bs, seq_length), left-padded - attention_mask = prompts.batch["attention_mask"] - position_ids = prompts.batch["position_ids"] - - # used to generate attention mask for the - # response based on EOS token position - eos_token_id = prompts.meta_info["eos_token_id"] - - batch_size = idx.size(0) - - # Extract non-tensor data - non_tensor_batch = prompts.non_tensor_batch - if "raw_prompt_ids" not in non_tensor_batch: - non_tensor_batch["raw_prompt_ids"] = np.array( - [_pre_process_inputs(self.pad_token_id, idx[i]).tolist() for i in range(batch_size)], - dtype=object, - ) - - if "multi_modal_data" in non_tensor_batch: - sglang_inputs = [] - for raw_prompt_ids, multi_modal_data in zip( - non_tensor_batch.pop("raw_prompt_ids"), - non_tensor_batch.pop("multi_modal_data"), - strict=True, - ): - sglang_inputs.append( - { - "prompt_token_ids": raw_prompt_ids, - "multi_modal_data": multi_modal_data, - "image_data": ( - multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None - ), - } - ) - else: - sglang_inputs = [ - {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") - ] - - # Ensure token IDs are lists or numpy arrays - for input_data in sglang_inputs: - if isinstance(input_data["prompt_token_ids"], np.ndarray): - input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() - elif not isinstance(input_data["prompt_token_ids"], list): - raise TypeError( - f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" - ) - - # Extract token IDs and image data for SGLang Engine - idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] - image_list = [input_data.get("image_data", None) for input_data in sglang_inputs] - - do_sample = prompts.meta_info.get("do_sample", True) - is_validate = prompts.meta_info.get("validate", False) - - # Create request-level sampling parameters - request_sampling_params = self.sampling_params.copy() - if not do_sample: - request_sampling_params.update( - { - "n": 1, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "repetition_penalty": 1.0, - "temperature": 0, - "top_p": 1, - "top_k": -1, - "ignore_eos": False, - "min_new_tokens": 0, - "max_new_tokens": self.config.response_length, - "skip_special_tokens": True, - "spaces_between_special_tokens": True, - } - ) - elif is_validate: - request_sampling_params.update( - { - "top_k": self.config.val_kwargs.top_k, - "top_p": self.config.val_kwargs.top_p, - "temperature": self.config.val_kwargs.temperature, - "n": 1, # if validate, already repeat in ray_trainer - } - ) - - # Update with any additional kwargs - request_sampling_params.update(kwargs) - - if self._tp_rank == 0: - loop = asyncio.get_event_loop() - output = loop.run_until_complete( - self._engine.async_generate( - prompt=None, # because we have already convert it to prompt token id - sampling_params=request_sampling_params, - return_logprob=True, - input_ids=idx_list, - image_data=image_list, - ) - ) - else: - output = None - - # Most naive implementation, can extract tensor and send via gloo if too slow - dist.barrier() - [output] = broadcast_pyobj( - data=[output], - rank=self._rank, - dist_group=self._device_mesh_cpu["tp"].get_group(), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, - ) - out = _post_process_outputs(self.processing_class, output) - - response = out[0].to(idx.device) - rollout_log_probs = None - if self.config.calculate_log_probs: - rollout_log_probs = out[1].to(idx.device) - - if response.shape[1] < self.config.response_length: - response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - if self.config.calculate_log_probs: - rollout_log_probs = pad_sequence_to_length( - rollout_log_probs, self.config.response_length, self.pad_token_id - ) - - seq = torch.cat([idx, response], dim=-1) - - response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) - if position_ids.dim() == 3: # qwen2vl mrope - delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) - - # TODO(sgm): fix position_ids on right_pad - # prompt: left pad + response: right pad - # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - response_position_ids = position_ids[..., -1:] + delta_position_id - position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask( - response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype - ) - attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) - - # all the tp ranks should contain the same data here. data in all ranks are valid - batch = TensorDict( - { - "prompts": idx, - "responses": response, - "input_ids": seq, # here input_ids become the whole sentences - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=batch_size, - ) - if self.config.calculate_log_probs: - # we will recompute old log prob with actor - batch["rollout_log_probs"] = rollout_log_probs - - # free cache engine - if self._engine is not None and self._tp_rank == 0: - loop = asyncio.get_event_loop() - loop.run_until_complete(self._engine.flush_cache()) - - return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) - - async def _async_rollout_a_request( - self, - req: AsyncRolloutRequest, - do_sample: bool = True, - is_validate: bool = False, - **kwargs, - ) -> AsyncRolloutRequest: - assert self._tp_rank == 0, "only the master process can call this function" - _req = deepcopy(req) - finish_reason_type = None - output = None - - current_turns = 0 - user_turns = 0 - user_turn_rewards = [] - - # Create request-level sampling parameters - request_sampling_params = self.sampling_params.copy() - if not do_sample: - request_sampling_params.update( - { - "n": 1, - "presence_penalty": 0.0, - "frequency_penalty": 0.0, - "repetition_penalty": 1.0, - "temperature": 0, - "top_p": 1, - "top_k": -1, - "ignore_eos": False, - "min_new_tokens": 0, - "max_new_tokens": self.config.response_length, - "skip_special_tokens": True, - "spaces_between_special_tokens": True, - } - ) - elif is_validate: - request_sampling_params.update( - { - "top_k": self.config.val_kwargs.top_k, - "top_p": self.config.val_kwargs.top_p, - "temperature": self.config.val_kwargs.temperature, - "n": 1, # if validate, already repeat in ray_trainer - } - ) - - # Update with any additional kwargs - request_sampling_params.update(kwargs) - - while current_turns < self.config.multi_turn.max_assistant_turns: - if _req.state == AsyncRolloutRequestStateEnum.PENDING: - await self._handle_pending_state(_req) - _req.state = AsyncRolloutRequestStateEnum.RUNNING - elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING: - if _req.messages[-1].tool_calls is not None: - parsed_tool_calls = _req.messages[-1].tool_calls - tool_call_results = await asyncio.gather( - *[ - self._tool_map[tool_call.function.name].execute( - _req.request_id, - tool_call.function.arguments, - **_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}), - ) - for tool_call in parsed_tool_calls - ] - ) - _req.add_tool_response_messages(self.processing_class, [resp for resp, _, _ in tool_call_results]) - for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results, strict=True): - _req.update_metrics(metrics, tool_call.function.name) - if len(_req.input_ids) >= self.config.max_model_len: - finish_reason_type = FinishReasonTypeEnum.STOP - break - _req.state = AsyncRolloutRequestStateEnum.RUNNING - else: - raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") - elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: - # Only continue the conversation if the prompt length is not greater than max_model_len - 1, - # since SGLang raises an error when max_new_tokens + 1 is greater to max_model_len (the extra - # token accounts for the EOS token). - if len(_req.get_generation_prompt_ids(self.processing_class)) + 1 >= self.config.max_model_len: - finish_reason_type = FinishReasonTypeEnum.LENGTH - break - - # Video support is not implemented yet - image_data = ( - _req.multi_modal_data["image"] - if _req.multi_modal_data and "image" in _req.multi_modal_data - else None - ) - video_data = ( - _req.multi_modal_data["video"] - if _req.multi_modal_data and "video" in _req.multi_modal_data - else None - ) - if video_data: - logger.warning( - "video support is not implemented yet, current length of video data is %d", len(video_data) - ) - - output = await self._handle_engine_call(_req, request_sampling_params, image_data=image_data) - content = output["text"] - finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) - current_turns += 1 - if finish_reason_type == FinishReasonTypeEnum.LENGTH: - _req.add_assistant_message(self.processing_class, content) - break - else: - if self._function_call_parser and self._function_call_parser.has_tool_call(content): - finish_reason_type = FinishReasonTypeEnum.TOOL_CALL - _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING - try: - normed_content, tool_calls = self._function_call_parser.parse_non_stream(content) - except JSONDecodeError: - normed_content = content - tool_calls = [] - except AttributeError: - normed_content = content - tool_calls = [] - parsed_tool_calls = [] - for tool_call in tool_calls: - function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema( - OpenAIFunctionParsedSchema( - name=tool_call.name, - arguments=tool_call.parameters, - ) - ) - # Drop the tool call if its arguments has decode error - if has_decode_error: - continue - parsed_tool_calls.append( - OpenAIFunctionToolCall( - id=str(tool_call.tool_index), - function=function, - ) - ) - if len(parsed_tool_calls) > 0: - _req.add_assistant_message( - self.processing_class, normed_content, tool_calls=parsed_tool_calls - ) - else: - _req.add_assistant_message(self.processing_class, content) - finish_reason_type = FinishReasonTypeEnum.STOP - _req.state = AsyncRolloutRequestStateEnum.COMPLETED - break - else: - _req.add_assistant_message( - self.processing_class, - content, - ) - if ( - _req.interaction_kwargs - and self.interaction_map - and user_turns < self.config.multi_turn.max_user_turns - and current_turns < self.config.multi_turn.max_assistant_turns - ): - _req.state = AsyncRolloutRequestStateEnum.INTERACTING - else: - break - elif _req.state == AsyncRolloutRequestStateEnum.INTERACTING: - user_turns += 1 - messages = [{"role": x.role, "content": x.content} for x in _req.messages] - - # Get interaction by name from interaction_kwargs - interaction_name = _req.interaction_kwargs.get( - "name", "gsm8k" - ) # Default to gsm8k for backward compatibility - if interaction_name not in self.interaction_map: - raise ValueError( - f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: " - f"{list(self.interaction_map.keys())}" - ) - - interaction = self.interaction_map[interaction_name] - should_terminate_sequence, content, reward, metrics = await interaction.generate_response( - _req.request_id, messages, **_req.interaction_kwargs - ) - user_turn_rewards.append(reward) - if should_terminate_sequence: - finish_reason_type = FinishReasonTypeEnum.STOP - _req.state = AsyncRolloutRequestStateEnum.COMPLETED - break - else: - _req.add_user_message(self.processing_class, content) - if len(_req.input_ids) >= self.config.max_model_len: - finish_reason_type = FinishReasonTypeEnum.STOP - break - else: - _req.state = AsyncRolloutRequestStateEnum.RUNNING - - if current_turns >= self.config.multi_turn.max_assistant_turns: - finish_reason_type = FinishReasonTypeEnum.STOP - - # Calculate the reward for each tool - async def calc_reward_and_release_fn(name: str, tool: BaseTool): - reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {})) - await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {})) - return name, reward - - tool_reward_tasks = [] - for name in _req.tools_kwargs.keys(): - tool = self._tool_map[name] - tool_reward_tasks.append(calc_reward_and_release_fn(name, tool)) - tool_reward_scores = await asyncio.gather(*tool_reward_tasks) - tool_reward_scores = dict(tool_reward_scores) - all_rewards = {**tool_reward_scores, **{"user_turn_rewards": user_turn_rewards}} - _req.finalize(self.processing_class, all_rewards, finish_reason_type) - - return _req - - async def _handle_engine_call( - self, _req: AsyncRolloutRequest, sampling_params: dict, image_data: Optional[list[Any]] = None - ) -> dict: - generation_prompt_ids = _req.get_generation_prompt_ids(self.processing_class) - return await self._handle_engine_generate(generation_prompt_ids, sampling_params, image_data) - - async def _handle_engine_generate( - self, generation_prompt_ids: list[int], sampling_params: dict, image_data: Optional[list[Any]] = None - ) -> dict: - max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) - kwargs = sampling_params.copy() - kwargs["max_new_tokens"] = max_new_tokens - kwargs["n"] = 1 # group size is supported in preprocess - output = await self._engine.async_generate( - input_ids=generation_prompt_ids, - sampling_params=kwargs, - return_logprob=False, - image_data=image_data, - ) - return output - - async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest: - if _req.tool_schemas is not None: - tool_creation_coroutines = [] - for tool_schema in _req.tool_schemas: - tool = self._tool_map[tool_schema.function.name] - create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {}) - tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs)) - await asyncio.gather(*tool_creation_coroutines) - if _req.interaction_kwargs and self.interaction_map: - interaction_kwargs = _req.interaction_kwargs - # Get interaction by name from interaction_kwargs - interaction_name = interaction_kwargs.get("name", "gsm8k") # Default to gsm8k for backward compatibility - if interaction_name not in self.interaction_map: - raise ValueError( - f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: " - f"{list(self.interaction_map.keys())}" - ) - - interaction = self.interaction_map[interaction_name] - await interaction.start_interaction(_req.request_id, **interaction_kwargs) - - @GPUMemoryLogger(role="sglang rollout", logger=logger) - @torch.no_grad() - def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto: - logger.warning( - "`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)`", - DeprecationWarning, - stacklevel=2, - ) - return self._req_level_generate_sequences(prompts, **kwargs) - - @GPUMemoryLogger(role="sglang rollout", logger=logger) - @torch.no_grad() - def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - """Generates multi-turn sequences for a batch of prompts. - For multi-turn generation, each prompt is processed separately via - `_req_level_generate_sequences` for better tool calling control. - Note that in multi-turn generation, we repeat the prompts for rollout.n times in ray_trainer. - Thus we do not need to repeat the prompts here and set the sampling parameter n to 1. - """ - # Async rollout with tools support - do_sample = prompts.meta_info.get("do_sample", True) - is_validate = prompts.meta_info.get("validate", False) - tgt_device = prompts.batch["input_ids"].device - if self._tp_rank == 0: - req_list = self._preprocess_prompt_to_async_rollout_requests( - prompts, - ) - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather( - *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list], - ) - ) - sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset)) - else: - sorted_output_req_list = None - - dist.barrier() - [sorted_output_req_list] = broadcast_pyobj( - data=[sorted_output_req_list], - rank=self._rank, - dist_group=self._device_mesh_cpu["tp"].get_group(), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, - ) - # Construct the batch data - prompt_ids, response_ids = [], [] - prompt_attention_mask, response_attention_mask = [], [] - prompt_position_ids, response_position_ids = [], [] - prompt_loss_mask, response_loss_mask = [], [] - messages = [] - reward_scores = [] - multi_modal_inputs = [] - - for req in sorted_output_req_list: - assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed" - assert ( - req.input_ids.shape[-1] - == req.attention_mask.shape[-1] - == req.position_ids.shape[-1] - == req.loss_mask.shape[-1] - ), f"""Request {req.request_id} has different length of - {req.input_ids.shape[-1]=}, {req.attention_mask.shape[-1]=}, - {req.position_ids.shape[-1]=}, {req.loss_mask.shape[-1]=}""" - error_message_lines = [ - f"""Request {req.request_id} has input_ids length {req.input_ids.shape[-1]} - greater than max_model_len {self.config.max_model_len}""", - f"Decoded input_ids: {self.processing_class.decode(req.input_ids.squeeze(0))}", - f"Decoded prompt_ids: {self.processing_class.decode(req.prompt_ids.squeeze(0))}", - f"Decoded response_ids: {self.processing_class.decode(req.response_ids.squeeze(0))}", - f"Messages: {req.messages}", - f"Max model length: {req.max_model_len}", - ] - error_message = "\n".join(error_message_lines) - assert req.input_ids.shape[-1] <= self.config.max_model_len, error_message - - prompt_ids.append(req.prompt_ids.to(tgt_device).squeeze(0)) - response_ids.append(req.response_ids.to(tgt_device).squeeze(0)) - if req.response_ids.shape[-1] > self.config.response_length: - logger.warning( - f"""{req.request_id=} has response_ids length {req.response_ids.shape[-1]} - greater than max_response_len {self.config.response_length},\n{req=}""" - ) - prompt_attention_mask.append(req.prompt_attention_mask.to(tgt_device).squeeze(0)) - response_attention_mask.append(req.response_attention_mask.to(tgt_device).squeeze(0)) - prompt_position_ids.append(req.prompt_position_ids.to(tgt_device).squeeze(0)) - response_position_ids.append(req.response_position_ids.to(tgt_device).squeeze(0)) - prompt_loss_mask.append(req.prompt_loss_mask.to(tgt_device).squeeze(0)) - response_loss_mask.append(req.response_loss_mask.to(tgt_device).squeeze(0)) - messages.append({"messages": req.messages}) - reward_scores.append(req.reward_scores) - multi_modal_inputs.append(req.multi_modal_inputs) - - prompt_ids = pad_sequence( - prompt_ids, - batch_first=True, - padding_value=self.pad_token_id, - padding_side="left", - ) - if prompt_ids.shape[-1] < self.config.prompt_length: - prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True) - response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) - if response_ids.shape[-1] < self.config.response_length: - response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id) - prompt_attention_mask = pad_sequence( - prompt_attention_mask, - batch_first=True, - padding_value=0, - padding_side="left", - ) - if prompt_attention_mask.shape[-1] < self.config.prompt_length: - prompt_attention_mask = pad_sequence_to_length( - prompt_attention_mask, self.config.prompt_length, 0, left_pad=True - ) - response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) - if response_attention_mask.shape[-1] < self.config.response_length: - response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0) - - # padding prompt_position_ids - if prompt_position_ids[0].dim() == 2: - # if prompt_position_ids is a 2D tensor - # e.g. from qwen2vl, prompt_position_ids.shape = (3, seq_len) - transposed_prompt_position_ids = [p.transpose(0, 1) for p in prompt_position_ids] - prompt_position_ids = pad_sequence( - transposed_prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" - ) - prompt_position_ids = prompt_position_ids.transpose(1, 2) - else: - prompt_position_ids = pad_sequence( - prompt_position_ids, batch_first=True, padding_value=0, padding_side="left" - ) - if prompt_position_ids.shape[-1] < self.config.prompt_length: - prompt_position_ids = pad_sequence_to_length( - prompt_position_ids, self.config.prompt_length, 0, left_pad=True - ) - - # padding response_position_ids - if response_position_ids[0].dim() == 2: - # if response_position_ids is a 2D tensor - # e.g. from qwen2vl, response_position_ids.shape = (3, seq_len) - transposed_response_position_ids = [p.transpose(0, 1) for p in response_position_ids] - response_position_ids = pad_sequence( - transposed_response_position_ids, batch_first=True, padding_value=0, padding_side="left" - ) - response_position_ids = response_position_ids.transpose(1, 2) - else: - response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0) - if response_position_ids.shape[-1] < self.config.response_length: - response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0) - - prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left") - if prompt_loss_mask.shape[1] < self.config.prompt_length: - prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True) - response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0) - if response_loss_mask.shape[1] < self.config.response_length: - response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0) - - input_ids = torch.cat((prompt_ids, response_ids), dim=-1) - attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) - position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1) - - # Construct the batch data - batch = TensorDict( - { - "prompts": prompt_ids, - "responses": response_ids, - "response_mask": response_loss_mask, - "input_ids": input_ids, # here input_ids become the whole sentences - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=len(sorted_output_req_list), - ) - - # free cache engine - if self._engine is not None and self._tp_rank == 0: - loop = asyncio.get_event_loop() - loop.run_until_complete(self._engine.flush_cache()) - - return DataProto( - batch=batch, - non_tensor_batch={ - "messages": np.array(messages), - "reward_scores": np.array(reward_scores), - "multi_modal_inputs": np.array(multi_modal_inputs, dtype=object), - }, - ) - - def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int = 1) -> list[AsyncRolloutRequest]: - assert "raw_prompt" in prompts.non_tensor_batch, ( - "need data.return_raw_chat=True, due to no official way do parse_messages" - ) - logger.info( - "n is deprecated for SGLang rollout since ray ppo trainer will repeat the prompts for rollout.n times" - ) - req_list = [] - multi_modal_data_list = prompts.non_tensor_batch.get( - "multi_modal_data", [None] * len(prompts.non_tensor_batch["raw_prompt"]) - ) - - for data_idx, (raw_prompt, multi_modal_data) in enumerate( - zip(prompts.non_tensor_batch["raw_prompt"], multi_modal_data_list, strict=True) - ): - if self._tool_schemas: - _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx] - _tool_schemas = [self._tool_map[k].get_openai_tool_schema() for k in _tools_kwargs.keys()] - _input_ids = None - _attention_mask = None - else: - _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx]) - _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx]) - _tools_kwargs = {} - _tool_schemas = None - - if self.interaction_map: - _interaction_kwargs = prompts.non_tensor_batch["interaction_kwargs"][data_idx] - else: - _interaction_kwargs = {} - - req = AsyncRolloutRequest( - batch_data_id=data_idx, - rollout_offset=0, - request_id=str(uuid4()), - state=AsyncRolloutRequestStateEnum.PENDING, - messages=raw_prompt.tolist(), - multi_modal_data=multi_modal_data, - tool_schemas=_tool_schemas, - tools_kwargs=_tools_kwargs, - interaction_kwargs=_interaction_kwargs, - input_ids=_input_ids, - response_ids=None, - attention_mask=_attention_mask, - response_attention_mask=None, - response_position_ids=None, - response_loss_mask=None, - reward_scores={}, - max_prompt_len=self.config.prompt_length, - max_response_len=self.config.response_length, - max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), - use_inference_chat_template=self.config.multi_turn.use_inference_chat_template, - tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode, - processing_class=self.processing_class, - ) - error_message = f"""Request {req.request_id} has mismatched lengths: - input_ids={req.input_ids.shape[-1]}, - attention_mask={req.attention_mask.shape[-1]}, - position_ids={req.position_ids.shape[-1]}, - loss_mask={req.loss_mask.shape[-1]}""" - assert ( - req.input_ids.shape[-1] - == req.attention_mask.shape[-1] - == req.position_ids.shape[-1] - == req.loss_mask.shape[-1] - ), error_message - req_list.append(req) - - return req_list - - async def chat_completion(self, json_request): - assert self._tp_rank == 0, "only called in tp rank 0" - _input_ids = None - _attention_mask = None - _position_ids = None - _tool_schemas = [] - _tools_kwargs = {} - - req = AsyncRolloutRequest( - request_id=str(uuid4()), - state=AsyncRolloutRequestStateEnum.PENDING, - messages=[Message.model_validate(msg) for msg in json_request["messages"]], - tool_schemas=_tool_schemas, - tools_kwargs=_tools_kwargs, - input_ids=_input_ids, - prompt_ids=_input_ids, - response_ids=None, - attention_mask=_attention_mask, - prompt_attention_mask=_attention_mask, - response_attention_mask=None, - position_ids=_position_ids, - prompt_position_ids=_position_ids, - response_position_ids=None, - loss_mask=None, - prompt_loss_mask=None, - response_loss_mask=None, - reward_scores={}, - max_prompt_len=self.config.prompt_length, - max_response_len=self.config.response_length, - max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), - use_inference_chat_template=self.config.multi_turn.use_inference_chat_template, - tokenization_sanity_check_mode=self.config.multi_turn.tokenization_sanity_check_mode, - processing_class=self.processing_class, - ) - - # json_request already contains sampling_params - # Filter only valid SamplingParams arguments - valid_sampling_params = {} - temp_sampling_params = SamplingParams() # Create temporary instance to check valid attributes - for k, v in json_request.items(): - if k not in ["messages", "model", "tools"] and hasattr(temp_sampling_params, k): - valid_sampling_params[k] = v - output = await self._handle_engine_call(req, valid_sampling_params) - # it can be Dict or AsyncIterator[Dict] - if isinstance(output, dict): - outputs = [output] - else: - outputs = output - - # build openai chat completion format - choices = [] - id = None - for i, content in enumerate(outputs): - choices.append( - { - "index": i, - "message": { - "role": "assistant", - "content": content["text"], - }, - "finish_reason": content["meta_info"]["finish_reason"]["type"], - } - ) - id = content["meta_info"]["id"] - - return { - "id": "chatcmpl-" + id, - "object": "chat.completion", - "created": int(time.time()), - "model": json_request.get("model", "sglang_model"), - "choices": choices, - } - - # this function is left for uniform train-inference resharding - - async def generate( - self, prompt_ids: torch.Tensor, sampling_params: dict[str, Any], request_id: str - ) -> torch.Tensor: - request_sampling_params = self.sampling_params.copy() - request_sampling_params.update(sampling_params) - output = await self._handle_engine_generate(prompt_ids, request_sampling_params) - return output["output_ids"] - - async def wake_up(self): - if not self.is_sleep: - return - await self.sharding_manager.wake_up() # pylint: disable=C2801 - self.is_sleep = False - - # this function is left for uniform train-inference resharding - async def sleep(self): - if self.is_sleep: - return - await self.sharding_manager.sleep() - self.is_sleep = True diff --git a/verl/workers/rollout/sglang_rollout/utils.py b/verl/workers/rollout/sglang_rollout/utils.py deleted file mode 100644 index f64bf63b8..000000000 --- a/verl/workers/rollout/sglang_rollout/utils.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pickle -from typing import Any, Iterator, Optional - -import numpy as np -import torch -import torch.distributed as dist - -from verl.utils.device import get_device_name - - -def broadcast_pyobj( - data: list[Any], - rank: int, - dist_group: Optional[torch.distributed.ProcessGroup] = None, - src: int = 0, - force_cpu_device: bool = False, -): - """from https://github.com/sgl-project/sglang/blob/844e2f227ab0cce6ef818a719170ce37b9eb1e1b/python/sglang/srt/utils.py#L905 - - Broadcast inputs from src rank to all other ranks with torch.dist backend. - The `rank` here refer to the source rank on global process group (regardless - of dist_group argument). - """ - device = torch.device(get_device_name() if not force_cpu_device else "cpu") - - if rank == src: - if len(data) == 0: - tensor_size = torch.tensor([0], dtype=torch.long, device=device) - dist.broadcast(tensor_size, src=src, group=dist_group) - else: - serialized_data = pickle.dumps(data) - size = len(serialized_data) - - tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device) - tensor_size = torch.tensor([size], dtype=torch.long, device=device) - - dist.broadcast(tensor_size, src=src, group=dist_group) - dist.broadcast(tensor_data, src=src, group=dist_group) - return data - else: - tensor_size = torch.tensor([0], dtype=torch.long, device=device) - dist.broadcast(tensor_size, src=src, group=dist_group) - size = tensor_size.item() - - if size == 0: - return [] - - tensor_data = torch.empty(size, dtype=torch.uint8, device=device) - dist.broadcast(tensor_data, src=src, group=dist_group) - - serialized_data = bytes(tensor_data.cpu().numpy()) - data = pickle.loads(serialized_data) - return data - - -def get_named_tensor_buckets( - iterable: Iterator[tuple[str, torch.Tensor]], bucket_bytes: int -) -> Iterator[list[tuple[str, torch.Tensor]]]: - """ - Group tensors into buckets based on a specified size in megabytes. - - Args: - iterable: An iterator of tuples containing tensor names and tensors. - bucket_bytes: The maximum size of each bucket in bytes. - - Yields: - Lists of tuples, where each tuple contains a tensor name and its corresponding tensor. - - Example: - >>> tensors = [('tensor1', torch.randn(1000, 1000)), ('tensor2', torch.randn(2000, 2000))] - >>> for bucket in get_named_tensor_buckets(tensors, bucket_size_mb=10): - ... print(bucket) - [('tensor1', tensor(...)), ('tensor2', tensor(...))] - - """ - if bucket_bytes <= 0: - raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}") - - current_bucket = [] - current_size = 0 - for name, tensor in iterable: - tensor_size = tensor.element_size() * tensor.numel() - if current_size + tensor_size > bucket_bytes: - if current_bucket: - yield current_bucket - current_bucket = [(name, tensor)] - current_size = tensor_size - else: - current_bucket.append((name, tensor)) - current_size += tensor_size - - if current_bucket: - yield current_bucket diff --git a/verl/workers/rollout/tokenizer.py b/verl/workers/rollout/tokenizer.py deleted file mode 100644 index 1e1212e50..000000000 --- a/verl/workers/rollout/tokenizer.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -The base tokenizer class, required for any hybrid engine based rollout or inference with vLLM. -""" - -from abc import ABC, abstractmethod - -import numpy as np -import torch - -__all__ = ["HybridEngineBaseTokenizer"] - - -class HybridEngineBaseTokenizer(ABC): - """the tokenizer property and function name should align with HF's to meet vllm requirement""" - - @property - @abstractmethod - def vocab_size(self): - """ - `int`: Size of the base vocabulary (without the added tokens). - """ - pass - - @property - @abstractmethod - def pad_token_id(self): - """ - `Optional[int]`: Id of the padding token in the vocabulary. Returns `None` if the token has not been set. - """ - pass - - @property - @abstractmethod - def eos_token_id(self): - """ - `Optional[int]`: Id of the end of sentence token in the vocabulary. Returns `None` if the token has not been - set. - """ - pass - - @property - @abstractmethod - def all_special_ids(self) -> list[int]: - """ - `List[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. - """ - pass - - @property - @abstractmethod - def all_special_tokens(self) -> list[str]: - """ - `List[str]`: A list of the unique special tokens (`''`, `''`, ..., etc.). - - Convert tokens of `tokenizers.AddedToken` type to string. - """ - pass - - @abstractmethod - def encode(self, text): - """ - Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. - - Args: - text (`str`, `List[str]` or `List[int]`): - The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the - `tokenize` method) or a list of integers. - - text_pair (`str`, `List[str]` or `List[int]`, *optional*): - Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using - the `tokenize` method) or a list of integers. - """ - pass - - @abstractmethod - def decode( - self, - token_ids: int | list[int] | np.ndarray | torch.Tensor, - skip_special_tokens: bool = False, - clean_up_tokenization_spaces: bool = None, - **kwargs, - ) -> str: - """ - Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special - tokens and clean up tokenization spaces. - - Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. - - Args: - token_ids (`Union[int, List[int], np.ndarray, torch.Tensor]`): - List of tokenized input ids. Can be obtained using the `__call__` method. - skip_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to remove special tokens in the decoding. - clean_up_tokenization_spaces (`bool`, *optional*): - Whether or not to clean up the tokenization spaces. If `None`, will default to - `self.clean_up_tokenization_spaces`. - kwargs (additional keyword arguments, *optional*): - Will be passed to the underlying model specific decode method. - - Returns: - `str`: The decoded sentence. - """ - pass - - @abstractmethod - def convert_ids_to_tokens(self, ids: int | list[int], skip_special_tokens: bool = False) -> str | list[str]: - """ - Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and - added tokens. - - Args: - ids (`int` or `List[int]`): - The token id (or token ids) to convert to tokens. - skip_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to remove special tokens in the decoding. - - Returns: - `str` or `List[str]`: The decoded token(s). - """ - pass - - @abstractmethod - def get_added_vocab(self) -> dict[str, int]: - """ - Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from - the fast call because for now we always add the tokens even if they are already in the vocabulary. This is - something we should change. - - Returns: - `Dict[str, int]`: The added tokens. - """ - pass - - @abstractmethod - def convert_tokens_to_string(self, tokens: list[str]) -> str: - """ - Converts a sequence of tokens in a single string. The most simple way to do it is `" ".join(tokens)` but we - often want to remove sub-word tokenization artifacts at the same time. - - Args: - tokens (`List[str]`): The token to join in a string. - - Returns: - `str`: The joined tokens. - """ - pass - - @property - def is_fast(self): - return False diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py deleted file mode 100644 index 767858fe3..000000000 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from importlib.metadata import PackageNotFoundError, version - -from .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout # noqa: F401 - - -def get_version(pkg): - try: - return version(pkg) - except PackageNotFoundError: - return None - - -vllm_package_name = "vllm" -vllm_package_version = get_version(vllm_package_name) -if vllm_package_version is None: - raise PackageNotFoundError( - "To use vllm rollout, please ensure the 'vllm' package is properly installed. See " - "https://verl.readthedocs.io/en/latest/start/install.html for more details" - ) - -if "ROCM_PATH" in os.environ: - import re - - match = re.match(r"(\d+\.\d+\.?\d*)", vllm_package_version) - if match: - vllm_package_version = match.group(1) - else: - raise ValueError(f"Warning: Could not parse version format: {vllm_package_version}") diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py deleted file mode 100644 index 988dac407..000000000 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import os -import pickle -from typing import Any, Callable, Optional - -import ray -import zmq -from omegaconf import DictConfig -from starlette.requests import Request -from starlette.responses import JSONResponse, StreamingResponse -from vllm import SamplingParams -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels -from vllm.inputs import TokensPrompt -from vllm.outputs import RequestOutput -from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.executor.abstract import Executor -from vllm.worker.worker_base import WorkerWrapperBase - -from verl.utils.fs import copy_to_local -from verl.workers.rollout.async_server import AsyncServerBase - -logger = logging.getLogger(__file__) - - -def _get_model_runner_workers(vllm_config, init_ray: bool = True): - assert vllm_config.instance_id is not None, "instance_id must be set for external ray actors." - - fields = vllm_config.instance_id.split(":") - assert len(fields) == 4, ( - f"instance_id: {vllm_config.instance_id} must be in the format of " - f":::." - ) - namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3]) - - # Make sure subprocess in same namespace as parent actor. - # actor name format: {name_prefix}WorkerDict_{pg_idx}:{local_rank} - if init_ray: - ray.init(namespace=namespace) - actor_names = [ - actor_name for actor_name in ray.util.list_named_actors() if actor_name.startswith(f"{wg_prefix}WorkerDict") - ] - - vllm_tp_size = vllm_config.parallel_config.tensor_parallel_size - assert len(actor_names) == vllm_dp_size * vllm_tp_size, ( - f"instance_id: {vllm_config.instance_id} has {len(actor_names)} actors, but vllm_dp_size: " - f"{vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = {vllm_dp_size * vllm_tp_size} is expected." - ) - - def get_pg_index_and_local_rank(actor_name) -> tuple[int, int]: - fields = actor_name.split(":") - assert len(fields) == 2, f"invalid actor name: {actor_name}" - pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1]) - return pg_index, local_rank - - # sort actor names by pg_index and local_rank - actor_names = sorted(actor_names, key=get_pg_index_and_local_rank) - actor_names = actor_names[vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size] - workers: list[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names] - print(f"instance_id: {vllm_config.instance_id} initializes with external actors: {actor_names}") - - return workers - - -class ExternalRayDistributedExecutor(Executor): - """An executor that engines are launched by external ray actors.""" - - uses_ray: bool = False - - def _init_executor(self) -> None: - self.workers = _get_model_runner_workers(vllm_config=self.vllm_config, init_ray=True) - - kwargs = dict( - vllm_config=self.vllm_config, - local_rank=None, - rank=None, - distributed_init_method="env://", - is_driver_worker=True, - ) - self.collective_rpc("init_worker", args=([kwargs],)) - self.collective_rpc("init_device") - self.collective_rpc("load_model") - print(f"instance_id: {self.vllm_config.instance_id} initializes finished.") - - def collective_rpc( - self, - method: str | Callable, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, - ) -> list[Any]: - # TODO(wuxibin): support ray compiled graph - if isinstance(method, str): - sent_method = method - else: - sent_method = pickle.dumps(method) - del method - - # ~3ms overhead per schedule step due to SchedulerOutput/ModelRunnerOutput serialization/deserialization. - outputs = ray.get( - [worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers] - ) - return outputs - - def check_health(self): - return - - -class ExternalZeroMQDistributedExecutor(Executor): - """An executor that engines are launched by external ray actors.""" - - uses_ray: bool = False - - def _init_executor(self) -> None: - addresses = os.environ["VERL_VLLM_ZMQ_ADDRESSES"].split(",") - self.context = zmq.Context() - self.sockets = [] - for address in addresses: - socket = self.context.socket(zmq.REQ) - socket.connect(address) - self.sockets.append(socket) - - kwargs = dict( - vllm_config=self.vllm_config, - local_rank=None, - rank=None, - distributed_init_method="env://", - is_driver_worker=True, - ) - self.collective_rpc("init_worker", args=([kwargs],)) - self.collective_rpc("init_device") - self.collective_rpc("load_model") - - def collective_rpc( - self, - method: str | Callable, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, - ) -> list[Any]: - if isinstance(method, str): - sent_method = method - else: - sent_method = pickle.dumps(method) - del method - - message = pickle.dumps((sent_method, args, kwargs or {})) - for socket in self.sockets: - socket.send(message, zmq.DONTWAIT) - - outputs = [] - for socket in self.sockets: - outputs.append(pickle.loads(socket.recv())) - return outputs - - def check_health(self): - return - - -@ray.remote(num_cpus=1) -class AsyncvLLMServer(AsyncServerBase): - """ - AsyncvLLMServer is a wrapper for AsyncLLM, it uses ExternalRayDistributedExecutor to launch engines - in hybrid rollout workers, i.e AsyncActorRolloutRefWorker. - - AsyncvLLMServer works as follows: - 1. Start FastAPI server first. - 2. Initialize AsyncLLM with ExternalRayDistributedExecutor. - 3. AsyncLLM spawn EngineCore in subprocess. - 4. EngineCore initialize ExternalRayDistributedExecutor. - 5. ExternalRayDistributedExecutor lookup its corresponding actors by name. - 6. ExternalRayDistributedExecutor init executor: init_worker, init_device, load_model. - - For vLLM AsyncLLM design, see: https://github.com/vllm-project/vllm/pull/9826 - """ - - def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str): - """ - Args: - config: DictConfig. - vllm_dp_size: int, vllm data parallel size. - vllm_dp_rank: int, vllm data parallel rank. - wg_prefix: str, worker group prefix, used to lookup actors. - """ - super().__init__() - - self.config = config.actor_rollout_ref - self.vllm_dp_size = vllm_dp_size - self.vllm_dp_rank = vllm_dp_rank - self.wg_prefix = wg_prefix - self.engine: AsyncLLM = None - - async def init_engine(self): - """Init vLLM AsyncLLM engine.""" - config = self.config - model_path = config.model.path - model_name = "/".join(model_path.split("/")[-2:]) - local_path = copy_to_local(model_path) - trust_remote_code = config.model.get("trust_remote_code", False) - config = config.rollout - - tensor_parallel_size = config.get("tensor_model_parallel_size", 1) - max_num_batched_tokens = config.get("max_num_batched_tokens", 8192) - max_model_len = config.max_model_len if config.max_model_len else config.prompt_length + config.response_length - self.max_model_len = int(max_model_len) - - # Override default generation config from hugging face model config, - # user can still override them by passing kwargs in each request. - kwargs = dict( - n=1, - logprobs=0, - repetition_penalty=1.0, - max_new_tokens=config.response_length, - ) - for k in config.keys(): - if hasattr(SamplingParams(), str(k)): - kwargs[k] = config.get(k) - print(f"override_generation_config: {kwargs}") - - backend = os.environ.get("VERL_VLLM_DISTRIBUTED_BACKEND", "zeromq") - if backend == "zeromq": - distributed_executor_backend = ExternalZeroMQDistributedExecutor - elif backend == "ray": - distributed_executor_backend = ExternalRayDistributedExecutor - else: - distributed_executor_backend = None - - engine_args = AsyncEngineArgs( - model=local_path, - enable_sleep_mode=config.free_cache_engine, - override_generation_config=kwargs, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - dtype=config.dtype, - enforce_eager=config.enforce_eager, - gpu_memory_utilization=config.gpu_memory_utilization, - disable_custom_all_reduce=True, - skip_tokenizer_init=False, - max_model_len=self.max_model_len, - load_format="auto", - disable_log_stats=config.disable_log_stats, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=config.enable_chunked_prefill, - enable_prefix_caching=True, - trust_remote_code=trust_remote_code, - seed=config.get("seed", 0), - ) - - # init async llm engine - vllm_config = self._create_engine_config(engine_args) - self.engine = AsyncLLM.from_vllm_config(vllm_config) - - # build serving chat - model_config = self.engine.model_config - BASE_MODEL_PATHS = [BaseModelPath(name=model_name, model_path=model_path)] - models = OpenAIServingModels(self.engine, model_config, BASE_MODEL_PATHS) - self.openai_serving_chat = OpenAIServingChat( - self.engine, - model_config, - models, - "assistant", - request_logger=RequestLogger(max_log_len=4096), - chat_template=None, - chat_template_content_format="auto", - enable_auto_tools=config.multi_turn.tool_config_path is not None, - tool_parser=config.multi_turn.format, # hermes, llama3_json, ... - ) - - def _create_engine_config(self, engine_args: AsyncEngineArgs): - vllm_config = engine_args.create_engine_config() - namespace = ray.get_runtime_context().namespace - vllm_config.instance_id = f"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}" - - # VERL_VLLM_ZMQ_ADDRESSES - if engine_args.distributed_executor_backend == ExternalZeroMQDistributedExecutor: - workers = _get_model_runner_workers(vllm_config=vllm_config, init_ray=False) - zmq_addresses = ray.get([worker.get_zeromq_address.remote() for worker in workers]) - print(f"VERL_VLLM_ZMQ_ADDRESSES: {zmq_addresses}") - os.environ["VERL_VLLM_ZMQ_ADDRESSES"] = ",".join(zmq_addresses) - - return vllm_config - - async def chat_completion(self, raw_request: Request): - """OpenAI-compatible HTTP endpoint. - - API reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html - """ - request_json = await raw_request.json() - request = ChatCompletionRequest(**request_json) - generator = await self.openai_serving_chat.create_chat_completion(request, raw_request) - - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, media_type="text/event-stream") - else: - assert isinstance(generator, ChatCompletionResponse) - return JSONResponse(content=generator.model_dump()) - - async def generate(self, prompt_ids: list[int], sampling_params: dict[str, Any], request_id: str) -> list[int]: - max_tokens = self.max_model_len - len(prompt_ids) - sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params) - prompt = TokensPrompt(prompt_token_ids=prompt_ids) - generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id) - - # Get final response - final_res: Optional[RequestOutput] = None - async for output in generator: - final_res = output - assert final_res is not None - - return final_res.outputs[0].token_ids - - async def wake_up(self): - if self.config.rollout.free_cache_engine: - await self.engine.wake_up() - - async def sleep(self): - # TODO: https://github.com/vllm-project/vllm/issues/17103 - await self.engine.reset_prefix_cache() - if self.config.rollout.free_cache_engine: - await self.engine.sleep() diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py deleted file mode 100644 index 46b6051ee..000000000 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ /dev/null @@ -1,513 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -The vllm_rollout that can be applied in different backend -When working with FSDP: -- Use DTensor weight loader (recommended) or HF weight loader -- Utilize state_dict from the FSDP to synchronize the weights among tp ranks in vLLM -When working with Megatron: -- Use Megatron weight loader -- During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank - to all other pp ranks (all pp ranks holds all the parameters) -- Bind the parameters to the inference engine -- Do inference in tp. pp is treated as additional dp -- After inference, all the parameters that doesn't belong to this pp rank is freed. -""" - -import logging -import os -import pickle -import socket -import threading -from contextlib import contextmanager -from copy import deepcopy -from types import MethodType -from typing import Any - -import numpy as np -import ray -import torch -import torch.distributed -import zmq -from filelock import FileLock -from omegaconf import DictConfig, OmegaConf -from tensordict import TensorDict -from vllm import LLM, SamplingParams -from vllm.distributed import parallel_state as vllm_ps -from vllm.lora.request import LoRARequest -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.worker.worker_base import WorkerWrapperBase - -from verl import DataProto -from verl.utils.profiler import GPUMemoryLogger -from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length -from verl.workers.rollout.base import BaseRollout - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -# TODO -# 1. support pp in vllm -# 2. passing tokenizer is not necessary? no encoding/decoding is happending here -# 3. simplify init logics - - -# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. -def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]: - # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id - # is not None else self.llm_engine.tokenizer.eos_token_id - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - token_ids = prompt_token_ids[non_pad_index:].tolist() - return token_ids - - -class vLLMRollout(BaseRollout): - def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): - """A vLLM rollout. It requires the module is supported by the vllm. - - Args: - module: module here follows huggingface APIs - config: DictConfig - tokenizer: the task/model tokenizer - model_hf_config: the huggingface config to initiallize the generating model in vllm - **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group - """ - super().__init__() - self.config = config - - tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert tensor_parallel_size <= torch.distributed.get_world_size(), ( - "tensor parallel size should be less than or equal to the world size" - ) - max_num_batched_tokens = self.config.get("max_num_batched_tokens", 8192) - - if kwargs.get("train_tp") is not None: - # deployed with megatron - # NOTE: import os removed by Reasoning360. Definitely a bug of the official code. - os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" - os.environ["MEGATRON_IMPORT_TIMERS"] = "0" - vllm_ps.initialize_model_parallel(tensor_model_parallel_size=tensor_parallel_size) - - rope_scaling_config = getattr(model_hf_config, "rope_scaling", None) - if not rope_scaling_config: - max_position_embeddings = None - if hasattr(model_hf_config, "max_position_embeddings"): - max_position_embeddings = model_hf_config.max_position_embeddings - elif hasattr(model_hf_config, "llm_config") and hasattr( - model_hf_config.llm_config, "max_position_embeddings" - ): - max_position_embeddings = model_hf_config.llm_config.max_position_embeddings - elif hasattr(model_hf_config, "text_config") and hasattr( - model_hf_config.text_config, "max_position_embeddings" - ): - max_position_embeddings = model_hf_config.text_config.max_position_embeddings - if max_position_embeddings is None: - raise ValueError("max_position_embeddings not found in model_hf_config") - assert max_position_embeddings >= config.prompt_length + config.response_length, ( - "model context length should be greater than total sequence length" - ) - else: - # handle type where there's a length extend factor - # see https://qwen.readthedocs.io/en/latest/deployment/vllm.html#extended-context-support - # for using yarn as an example - rope_scaling_factor = rope_scaling_config.get("factor", 1.0) - - assert ( - model_hf_config.max_position_embeddings * rope_scaling_factor - >= config.prompt_length + config.response_length - ), ( - "model context length should be greater than total sequence length, " - + f"got rope_scaling_factor={rope_scaling_factor} and " - + f"max_position_embeddings={model_hf_config.max_position_embeddings}" - ) - - max_model_len = int(config.max_model_len or config.prompt_length + config.response_length) - - if max_num_batched_tokens < max_model_len and self.config.enable_chunked_prefill: - raise ValueError( - "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ - please increase max_num_batched_tokens or disable chunked prefill" - ) - - trust_remote_code = kwargs.get("trust_remote_code", False) - load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format - - lora_kwargs = kwargs.pop("lora_kwargs", {}) - self.lora_kwargs = lora_kwargs - # copy it to avoid secretly modifying the engine config - engine_kwargs = ( - {} - if "engine_kwargs" not in config or "vllm" not in config.engine_kwargs - else OmegaConf.to_container(deepcopy(config.engine_kwargs.vllm)) - ) - # For each vLLM engine parameter, - # - `None` means not setting it, so we pop it, and leave it to vLLM default value - # (which can vary across different vLLM versions); - # - Otherwise it's the desired value we want to explicitly set. - engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} - if config.get("limit_images", None): # support for multi-image data - engine_kwargs["limit_mm_per_prompt"] = {"image": config.get("limit_images")} - - self.inference_engine = LLM( - model=model_path, - enable_sleep_mode=config.free_cache_engine, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend="external_launcher", - dtype=config.dtype, - enforce_eager=config.enforce_eager, - gpu_memory_utilization=config.gpu_memory_utilization, - disable_custom_all_reduce=True, - skip_tokenizer_init=False, - max_model_len=max_model_len, - load_format=load_format, - disable_log_stats=config.disable_log_stats, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=config.enable_chunked_prefill, - enable_prefix_caching=True, - trust_remote_code=trust_remote_code, - seed=int(os.getenv("RANK", "0")) - // tensor_parallel_size, # NOTE: modified by Reasoning360. Originally config.get("seed", 0) - **lora_kwargs, - **engine_kwargs, - ) - # NOTE: added by Reasoning360 - # self._monkey_patch_vllm_engine_v0() - - # Offload vllm model to reduce peak memory usage - if config.free_cache_engine: - self.inference_engine.sleep(level=1) - - kwargs = dict( - n=1, - logprobs=0, # can be set to 0 and let actor to recompute - max_tokens=config.response_length, - ) - - kwargs["detokenize"] = False - - # supporting adding any sampling params from the config file - for k in config.keys(): - if hasattr(SamplingParams(), str(k)): - kwargs[k] = config.get(k) - kwargs["n"] = 1 # already repeat in ray_trainer - print(f"kwargs: {kwargs}") - self.sampling_params = SamplingParams(**kwargs) - - self.pad_token_id = tokenizer.pad_token_id - - @contextmanager - def update_sampling_params(self, **kwargs): - # update sampling params - old_sampling_params_args = {} - if kwargs: - for key, value in kwargs.items(): - if hasattr(self.sampling_params, key): - old_value = getattr(self.sampling_params, key) - old_sampling_params_args[key] = old_value - setattr(self.sampling_params, key, value) - yield - # roll back to previous sampling params - # if len(old_sampling_params_args): - for key, value in old_sampling_params_args.items(): - setattr(self.sampling_params, key, value) - - # NOTE: added by Reasoning360. timer for precise logging - @staticmethod - @contextmanager - def timer(): - import time - - start = end = time.perf_counter() - yield lambda: end - start - end = time.perf_counter() - - @GPUMemoryLogger(role="vllm rollout spmd", logger=logger) - @torch.no_grad() - def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - """Generate sequences for a batch of prompts. - - Args: - batch (DataProto): Input batch. - - Returns: - DataProto: Output batch. - - prompts: [bsz, prompt_length], prompt token ids from dataset. - - responses: [bsz, response_length], output token ids include response tokens - from LLM generation and observation tokens from tool_calls. - - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens. - - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens - and response tokens. - - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens. - - position_ids: [bsz, prompt_length + response_length], incremental position ids. - - For multi-turn conversations: - responses: |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->| - response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0| - """ - idx = prompts.batch["input_ids"] # (bs, prompt_length) - # left-padded attention_mask - attention_mask = prompts.batch["attention_mask"] - position_ids = prompts.batch["position_ids"] - - # used to construct attention_mask - eos_token_id = prompts.meta_info["eos_token_id"] - - batch_size = idx.size(0) - - non_tensor_batch = prompts.non_tensor_batch - if "raw_prompt_ids" not in non_tensor_batch: - non_tensor_batch["raw_prompt_ids"] = np.array( - [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object - ) - - if batch_size != len(non_tensor_batch["raw_prompt_ids"]): - raise RuntimeError("vllm sharding manager is not work properly.") - - if "multi_modal_data" in non_tensor_batch: - vllm_inputs = [] - for raw_prompt_ids, multi_modal_data in zip( - non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data"), strict=True - ): - vllm_inputs.append({"prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data}) - else: - vllm_inputs = [ - {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids") - ] - - # ensure the type of `prompt_token_ids` passed to vllm is list[int] - # https://github.com/volcengine/verl/pull/772 - for input_data in vllm_inputs: - if isinstance(input_data["prompt_token_ids"], np.ndarray): - input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() - elif not isinstance(input_data["prompt_token_ids"], list): - raise TypeError( - f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}" - ) - - do_sample = prompts.meta_info.get("do_sample", True) - is_validate = prompts.meta_info.get("validate", False) - if not do_sample: - kwargs = { - "best_of": 1, - "top_p": 1.0, - "top_k": -1, - "min_p": 0.0, - "temperature": 0, - "n": 1, # if greedy, only 1 response - } - elif is_validate: - # TODO: try ** - kwargs = { - "top_k": self.config.val_kwargs.top_k, - "top_p": self.config.val_kwargs.top_p, - "temperature": self.config.val_kwargs.temperature, - "n": 1, # if validate, already repeat in ray_trainer - } - - lora_requests = None - if self.lora_kwargs: - lora_int_ids = list(self.inference_engine.llm_engine.list_loras()) - if len(lora_int_ids) > 0: - lora_int_id = lora_int_ids[0] - lora_requests = [ - LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/simon-stub-path") - ] * batch_size - - # users can customize different sampling_params at different run - with self.update_sampling_params(**kwargs), self.timer(): - outputs = self.inference_engine.generate( - prompts=vllm_inputs, # because we have already convert it to prompt token id - sampling_params=self.sampling_params, - lora_request=lora_requests, - use_tqdm=False, - ) - - # TODO(sgm): disable logprob when recompute_log_prob is enable - # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) - - response = [] - rollout_log_probs = [] - for output in outputs: - for sample_id in range(len(output.outputs)): - response_ids = output.outputs[sample_id].token_ids - response.append(response_ids) - if self.config.calculate_log_probs: - curr_log_prob = [] - for i, logprob in enumerate(output.outputs[sample_id].logprobs): - curr_log_prob.append(logprob[response_ids[i]].logprob) - rollout_log_probs.append(curr_log_prob) - - response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to( - idx.device - ) - if self.config.calculate_log_probs: - rollout_log_probs = pad_2d_list_to_length( - rollout_log_probs, -1, max_length=self.config.response_length - ).to(idx.device) - rollout_log_probs = rollout_log_probs.to(torch.float32) - - seq = torch.cat([idx, response], dim=-1) - - response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1) - if position_ids.dim() == 3: # qwen2vl mrope - delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1) - - # TODO(sgm): fix position_ids on right_pad - # prompt: left pad + response: right pad - # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - response_position_ids = position_ids[..., -1:] + delta_position_id - position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask( - response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype - ) - attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) - - # all the tp ranks should contain the same data here. data in all ranks are valid - batch = TensorDict( - { - "prompts": idx, - "responses": response, - "input_ids": seq, # here input_ids become the whole sentences - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=batch_size, - ) - if self.config.calculate_log_probs: - # we will recompute old log prob with actor - batch["rollout_log_probs"] = rollout_log_probs - - return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) - - -# https://github.com/vllm-project/vllm/issues/13175 -def _monkey_patch_compute_logits(model, vocab_size: int): - original_compute_logits = model.compute_logits - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - logits = original_compute_logits(hidden_states, sampling_metadata) - logits[..., vocab_size:] = float("-inf") - return logits - - model.compute_logits = MethodType(compute_logits, model) - - -class vLLMAsyncRollout: - """vLLMAsyncRollout is a thin wrapper of WorkerWrapperBase, - which is engine in single worker process. - """ - - def __init__(self, model_path: str, config: DictConfig, tokenizer, model_hf_config, **kwargs): - self.tokenizer = tokenizer - - # Engine is deferred to be initialized in init_worker - self.config = config - self.inference_engine: WorkerWrapperBase = None - self.sharding_manager = None - self.is_sleep = False - self.address = self._init_zeromq() - - def _init_zeromq(self) -> str: - tensor_parallel_size = self.config.tensor_model_parallel_size - - # single node: ipc, multi nodes: tcp - local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"]) - socket_type = "ipc" if tensor_parallel_size <= local_world_size else "tcp" - - # File lock to prevent multiple workers listen to same port - with FileLock("/tmp/verl_vllm_zmq.lock"): - if socket_type == "ipc": - pid = os.getpid() - address = f"ipc:///tmp/verl_vllm_zmq_{pid}.ipc" - else: - ip, port = self._get_free_port() - address = f"tcp://{ip}:{port}" - context = zmq.Context() - self.socket = context.socket(zmq.REP) - self.socket.bind(address) - - self.loop_thread = threading.Thread(target=self._loop_forever) - self.loop_thread.start() - - return address - - def _get_free_port(self): - ip = ray.util.get_node_ip_address() - with socket.socket() as sock: - sock.bind(("", 0)) - port = sock.getsockname()[1] - return ip, port - - def _loop_forever(self): - while True: - message = self.socket.recv() - method, args, kwargs = pickle.loads(message) - result = self.execute_method(method, *args, **kwargs) - self.socket.send(pickle.dumps(result)) - - def get_zeromq_address(self): - return self.address - - def init_worker(self, all_kwargs: list[dict[str, Any]]): - """Initialize worker engine.""" - all_kwargs[0]["rank"] = int(os.environ["RANK"]) - all_kwargs[0]["local_rank"] = 0 - - self.vllm_config = all_kwargs[0]["vllm_config"] - self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config) - self.inference_engine.init_worker(all_kwargs) - - def load_model(self, *args, **kwargs): - self.inference_engine.load_model(*args, **kwargs) - - # inference engine is initialized now, update sharding manager - self.sharding_manager.inference_engine = self.inference_engine - self.sharding_manager.model_runner = self.inference_engine.worker.model_runner - - _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer)) - - def sleep(self, *args, **kwargs): - """Offload model weights and discard kv cache.""" - if self.is_sleep: - return - self.sharding_manager.__exit__(None, None, None) - self.is_sleep = True - - def wake_up(self, *args, **kwargs): - """Load model weights and build kv cache.""" - if not self.is_sleep: - return - self.sharding_manager.__enter__() # pylint: disable=C2801 - self.is_sleep = False - - def execute_method(self, method: str | bytes, *args, **kwargs): - if method == "init_worker": - return self.init_worker(*args, **kwargs) - elif method == "load_model": - return self.load_model(*args, **kwargs) - elif method == "sleep": - return self.sleep(*args, **kwargs) - elif method == "wake_up": - return self.wake_up(*args, **kwargs) - else: - return self.inference_engine.execute_method(method, *args, **kwargs) diff --git a/verl/workers/sharding_manager/__init__.py b/verl/workers/sharding_manager/__init__.py deleted file mode 100644 index 1ce90c5eb..000000000 --- a/verl/workers/sharding_manager/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/verl/workers/sharding_manager/base.py b/verl/workers/sharding_manager/base.py deleted file mode 100644 index 59537be64..000000000 --- a/verl/workers/sharding_manager/base.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Sharding manager to implement HybridEngine -""" - -from verl import DataProto - - -class BaseShardingManager: - def __init__(self): - self.timing = {} - - def __enter__(self): - pass - - def __exit__(self, exc_type, exc_value, traceback): - pass - - def preprocess_data(self, data: DataProto) -> DataProto: - return data - - def postprocess_data(self, data: DataProto) -> DataProto: - return data diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py deleted file mode 100644 index 80201dc56..000000000 --- a/verl/workers/sharding_manager/fsdp_sglang.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import logging -import os - -import torch -import torch.distributed as dist -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.model_executor.model_runner import LocalSerializedTensor -from sglang.srt.utils import MultiprocessingSerializer -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP -from torch.distributed.tensor import DTensor - -from verl import DataProto -from verl.protocol import all_gather_data_proto -from verl.utils.device import get_device_id, get_torch_device -from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu -from verl.utils.model import convert_weight_keys -from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer -from verl.utils.torch_functional import check_device_is_available -from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets - -from .base import BaseShardingManager - -# from vllm.distributed import parallel_state as sglang_ps -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -def _preprocess_tensor_for_update_weights(tensor: torch.Tensor): - if isinstance(tensor, DTensor): - return tensor.full_tensor() - return tensor - - -class FSDPSGLangShardingManager(BaseShardingManager): - @check_device_is_available() - def __init__( - self, - module: FSDP, - inference_engine: Engine, - model_config, - rollout_config, - full_params: bool = False, - device_mesh: DeviceMesh = None, - offload_param: bool = False, - multi_stage_wake_up: bool = False, - ): - self.module = module - self.inference_engine = inference_engine - self.model_config = model_config - self.rollout_config = rollout_config - self.device_mesh = device_mesh - self.offload_param = offload_param - self.multi_stage_wake_up = multi_stage_wake_up - - # Full params - self.full_params = full_params - if full_params and fsdp_version(self.module) == 1: - FSDP.set_state_dict_type( - self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig() - ) - elif fsdp_version(self.module) == 1: - FSDP.set_state_dict_type( - self.module, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig(), - ) - - self.tp_size = self.device_mesh["infer_tp"].size() - self.tp_rank = self.device_mesh["infer_tp"].get_local_rank() - - # Note that torch_random_states may be different on each dp rank - self.torch_random_states = get_torch_device().get_rng_state() - # get a random rng states - if self.device_mesh is not None: - gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - else: - self.gen_random_states = None - - @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger) - def __enter__(self): - self.timing = {} - with simple_timer("reshard", self.timing): - loop = asyncio.get_event_loop() - loop.run_until_complete(self.wake_up()) - - @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger) - def __exit__(self, exc_type, exc_value, traceback): - loop = asyncio.get_event_loop() - loop.run_until_complete(self.sleep()) - - async def update_weights(self, params): - # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update - named_tensors = [(k, v) for k, v in params.items()] - load_format = None - # convert megabytes to bytes - update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20 - for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes): - # On each rank, serialize a batch of (name, tensor) tuples. - # named_tensors_batch will be a list like: - # [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...] - named_tensors_batch = [ - (name, MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor))) - for name, tensor in batch - ] - - if self.device_mesh["infer_tp"].get_local_rank() == 0: - # On rank 0, prepare a list to hold the gathered batches from all ranks. - gathered_serialized_batches = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])] - else: - gathered_serialized_batches = None - - # Gather the named_tensors_batch from all ranks to rank 0. - # After this, on rank 0, gathered_serialized_batches will be a list of lists: - # [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ], # batch from TP rank 0 - # [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ], # batch from TP rank 1 - # ... ] - # On other ranks, gathered_serialized_batches will be None. - dist.gather_object( - obj=named_tensors_batch, - object_gather_list=gathered_serialized_batches, - dst=self.device_mesh["infer_tp"].mesh.tolist()[0], - group=self.device_mesh["infer_tp"].get_group(), - ) - - if self.device_mesh["infer_tp"].get_local_rank() == 0: - # Use zip(*) to "transpose" the data structure. - # This groups the serialized parts for each individual tensor across all TP ranks. - # Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]] - # to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ] - logical_tensors = zip(*gathered_serialized_batches, strict=True) - - await self.inference_engine.update_weights_from_tensor( - named_tensors=[ - # 'tensor_group' represents a single logical tensor's data from all ranks. - ( - tensor_group[0][0], # Get the name from the first rank's data. - LocalSerializedTensor( - # 'rank_part' is the (name, serialized_tensor) tuple from one specific rank. - values=[rank_part[1] for rank_part in tensor_group] - ), - ) - for tensor_group in logical_tensors - # each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) ) - ], - load_format=load_format, - flush_cache=False, - ) - - if self.device_mesh["infer_tp"].get_local_rank() == 0: - await self.inference_engine.flush_cache() - - async def release_memory(self): - if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: - await self.inference_engine.release_memory_occupation() - - @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger) - async def wake_up(self): - get_torch_device().empty_cache() - - if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: - if self.multi_stage_wake_up: - await self.inference_engine.resume_memory_occupation(tags=["weights"]) - log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger) - else: - await self.inference_engine.resume_memory_occupation() - log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger) - - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) - if self.offload_param: - load_fsdp_model_to_gpu(self.module) - params = self.module.state_dict() - log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) - device = get_device_id() # used when fsdp2 set cpu_offload_policy - params = { - k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items() - } - - # convert weight keys to match the model config - params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module)) - - # Copy, not share memory - await self.update_weights(params) - log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) - - del params - if self.offload_param: - offload_fsdp_model_to_cpu(self.module) - get_torch_device().empty_cache() - log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) - - if ( - self.multi_stage_wake_up - and self.rollout_config.free_cache_engine - and self.device_mesh["infer_tp"].get_local_rank() == 0 - ): - await self.inference_engine.resume_memory_occupation(tags=["kv_cache"]) - log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger) - - # important: need to manually set the random states of each tp to be identical. - if self.device_mesh is not None: - self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) - - @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger) - async def sleep(self): - if self.rollout_config.free_cache_engine: - log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) - await self.release_memory() - log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) - - self.module.train() - - # add empty cache after each compute - get_torch_device().empty_cache() - - # restore random states - if self.device_mesh is not None: - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - - def preprocess_data(self, data: DataProto) -> DataProto: - """All gather across tp group to make each rank has identical input.""" - if self.tp_size == 1: - return data - - # TODO: Current impl doesn't consider FSDP with torch micro-dp - group = self.device_mesh["infer_tp"].get_group() - - all_gather_data_proto(data=data, process_group=group) - return data - - def postprocess_data(self, data: DataProto) -> DataProto: - """Get chunk data of this tp rank since we do all gather in preprocess.""" - if self.tp_size == 1: - return data - - return data.chunk(chunks=self.tp_size)[self.tp_rank] diff --git a/verl/workers/sharding_manager/fsdp_ulysses.py b/verl/workers/sharding_manager/fsdp_ulysses.py deleted file mode 100644 index 39ccb77cc..000000000 --- a/verl/workers/sharding_manager/fsdp_ulysses.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Contains a resharding manager that binds weights from FSDP zero3 to XPerfGPT -""" - -from torch.distributed.device_mesh import DeviceMesh - -from verl import DataProto -from verl.protocol import all_gather_data_proto -from verl.utils.ulysses import get_ulysses_sequence_parallel_group, set_ulysses_sequence_parallel_group - -from .base import BaseShardingManager - - -class FSDPUlyssesShardingManager(BaseShardingManager): - """ - Sharding manager to support data resharding when using FSDP + Ulysses - """ - - def __init__(self, device_mesh: DeviceMesh): - super().__init__() - self.device_mesh = device_mesh - self.seed_offset = 12345 - - def __enter__(self): - if self.device_mesh is not None: - # We have a global SP group - # so we have to change to use model-specific sp group - self.prev_sp_group = get_ulysses_sequence_parallel_group() - set_ulysses_sequence_parallel_group(self.device_mesh["sp"].get_group()) - # TODO: check how to set seed for each model - - def __exit__(self, exc_type, exc_value, traceback): - # restore random states - if self.device_mesh is not None: - # revert to previous sp group - set_ulysses_sequence_parallel_group(self.prev_sp_group) - # TODO: check how to set seed for each model - - def preprocess_data(self, data: DataProto) -> DataProto: - """ - AllGather data from sp region - This is because the data is first sharded along the FSDP dimension as we utilize the DP_COMPUTE - In Ulysses, we need to make sure the same data is used across a SP group - """ - if self.device_mesh is not None: - group = self.device_mesh["sp"].get_group() - - all_gather_data_proto(data=data, process_group=group) - return data - - def postprocess_data(self, data: DataProto) -> DataProto: - """ - Split the data to follow FSDP partition - """ - if self.device_mesh is not None: - sp_size = self.device_mesh["sp"].size() - sp_rank = self.device_mesh["sp"].get_local_rank() - data = data.chunk(chunks=sp_size)[sp_rank] - return data diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py deleted file mode 100644 index 1a9677df5..000000000 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ /dev/null @@ -1,342 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import logging -import os -import time -from collections import OrderedDict - -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType -from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP - -try: - # for torch 2.5+ - from torch.distributed.tensor import DTensor -except ImportError: - from torch.distributed._tensor import DTensor - -from dataclasses import asdict - -from verl import DataProto -from verl.protocol import all_gather_data_proto -from verl.third_party.vllm import LLM -from verl.third_party.vllm import parallel_state as vllm_ps -from verl.utils.device import get_device_id, get_device_name, get_torch_device -from verl.utils.fsdp_utils import ( - fsdp_version, - layered_summon_lora_params, - load_fsdp_model_to_gpu, - offload_fsdp_model_to_cpu, -) -from verl.utils.model import check_exclude_modules, check_target_modules, convert_weight_keys -from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer -from verl.utils.torch_functional import check_device_is_available -from verl.utils.vllm_utils import TensorLoRARequest, VLLMHijack, is_version_ge, patch_vllm_moe_model_weight_loader - -from .base import BaseShardingManager - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -class FSDPVLLMShardingManager(BaseShardingManager): - """Sharding manager for FSDP models with vLLM inference engine integration. - - Manages parameter synchronization between FSDP training models and vLLM - inference engines, handling both full parameters and LoRA adapters with - efficient memory management and device placement. - """ - - @check_device_is_available() - def __init__( - self, - module: FSDP, - inference_engine: LLM, - model_config, - rollout_config, - full_params: bool = False, - device_mesh: DeviceMesh = None, - offload_param: bool = False, - load_format: str = "dummy_hf", - layered_summon: bool = True, - ): - self.module = module - # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model - self.inference_engine = inference_engine - # self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if - # inference_engine else None - - self.model_runner = ( - self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner - if self.inference_engine - else None - ) - - self.model_config = model_config - self.rollout_config = rollout_config - self.device_mesh = device_mesh - self.offload_param = offload_param - self.load_format = load_format - self.layered_summon = layered_summon - - # Full params - self.full_params = full_params - if full_params and fsdp_version(self.module) == 1: - FSDP.set_state_dict_type( - self.module, state_dict_type=StateDictType.FULL_STATE_DICT, state_dict_config=FullStateDictConfig() - ) - elif fsdp_version(self.module) == 1: - FSDP.set_state_dict_type( - self.module, - state_dict_type=StateDictType.SHARDED_STATE_DICT, - state_dict_config=ShardedStateDictConfig(), - ) - - self.tp_size = self.device_mesh["infer_tp"].size() - self.tp_rank = self.device_mesh["infer_tp"].get_local_rank() - - # Note that torch_random_states may be different on each dp rank - self.torch_random_states = get_torch_device().get_rng_state() - # get a random rng states - if self.device_mesh is not None: - gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - else: - self.gen_random_states = None - - self.base_sync_done: bool = "dummy" not in load_format - if is_version_ge(pkg="vllm", minver="0.7.3"): - VLLMHijack.hijack() - - @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) - def __enter__(self): - def __collect_lora_params() -> OrderedDict: - """ - collect lora params or full params if base model is not ready in vllm - work with if isinstance(self.module._fsdp_wrapped_module, PeftModel) - """ - from peft.utils.save_and_load import get_peft_model_state_dict - - lora_params = OrderedDict() - peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module) - if fsdp_version(self.module) > 0: - if self.layered_summon: - if not self.base_sync_done: - raise ValueError( - "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let " - "rollout.load_format=safetensors" - ) - lora_params = layered_summon_lora_params(self.module) - else: - with FSDP.summon_full_params(self.module, writeback=False): - if self.base_sync_done: - lora_params = get_peft_model_state_dict(peft_model) - lora_params = { - name: param.full_tensor().detach().cpu() - if hasattr(param, "full_tensor") - else param.detach().cpu() - for name, param in lora_params.items() - } - else: - model = peft_model.base_model.model - orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() - model = model.to("cpu") - for name, param in model.state_dict().items(): - if any(x in name for x in ["_flat_param", "lora_"]): - continue - name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") - lora_params[name] = ( - param.full_tensor().detach().cpu() - if hasattr(param, "full_tensor") - else param.detach().cpu() - ) - model = model.to(orig_dev) - get_torch_device().empty_cache() - else: - if self.base_sync_done: - lora_params = get_peft_model_state_dict(peft_model) - else: - model = peft_model.base_model.model - orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name() - model = model.to("cpu") - for name, param in model.state_dict().items(): - if any(x in name for x in ["_flat_param", "lora_"]): - continue - name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "") - lora_params[name] = param.detach().cpu() - model = model.to(orig_dev) - return lora_params - - # NOTE: Basically, we only need `get_torch_device().empty_cache()` before vllm wake_up and - # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator. - # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory - # to speed up memory allocations. - # - # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management - # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 - self.timing = {} - with simple_timer("reshard", self.timing): - get_torch_device().empty_cache() - - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) - if self.offload_param: - load_fsdp_model_to_gpu(self.module) - - peft_config = None - peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module) - if hasattr(peft_model, "peft_config"): - peft_config = peft_model.peft_config.get("default", None) - params = __collect_lora_params() - else: - params = self.module.state_dict() - params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module)) - log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) - - if self.rollout_config.free_cache_engine: - if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: - self.inference_engine.wake_up(tags=["weights"]) - else: - self.inference_engine.wake_up() - - # update model params - self.update_params(params, peft_config=peft_config) - log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) - del params - if self.offload_param: - offload_fsdp_model_to_cpu(self.module) - get_torch_device().empty_cache() - - if ( - self.rollout_config.free_cache_engine - and "tags" in inspect.signature(self.inference_engine.wake_up).parameters - ): - self.inference_engine.wake_up(tags=["kv_cache"]) - - log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) - - # important: need to manually set the random states of each tp to be identical. - if self.device_mesh is not None: - self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) - - @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) - def __exit__(self, exc_type, exc_value, traceback): - if self.rollout_config.free_cache_engine: - self.inference_engine.sleep(level=1) - - self.module.train() - - # add empty cache after each compute - get_torch_device().empty_cache() - - # restore random states - if self.device_mesh is not None: - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - - @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) - def preprocess_data(self, data: DataProto) -> DataProto: - """All gather across tp group to make each rank has identical input.""" - if self.tp_size == 1: - return data - - # TODO: Current impl doesn't consider FSDP with torch micro-dp - group = vllm_ps.get_tensor_model_parallel_group().device_group - - all_gather_data_proto(data=data, process_group=group) - return data - - @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) - def postprocess_data(self, data: DataProto) -> DataProto: - """Get chunk data of this tp rank since we do all gather in preprocess.""" - if self.tp_size == 1: - return data - - return data.chunk(chunks=self.tp_size)[self.tp_rank] - - def update_params(self, updated_params, peft_config=None): - """Update model parameters in the vLLM inference engine. - - Synchronizes parameters from the FSDP training model to the vLLM inference - engine, handling both full model parameters and LoRA adapters with proper - device placement and memory management. - - Args: - updated_params (dict): Dictionary of parameter names to tensor values. - peft_config (optional): PEFT configuration for LoRA adapters. - """ - model = self.model_runner.model - if peft_config: - if self.base_sync_done: - lora_int_id = int(time.time_ns() % 0x7FFFFFFF) - lora_reqest = TensorLoRARequest( - lora_name=f"{lora_int_id}", - lora_int_id=lora_int_id, - lora_path="simon_lora_path", - peft_config=asdict(peft_config), - lora_tensors=updated_params, - ) - self.inference_engine.llm_engine.add_lora(lora_reqest) - logger.info(f"vLLM load weights, loaded_params: {len(updated_params)}") - return - else: - - def replace_lora_wrapper(k): - """Replace LoRA parameter keys with base layer equivalents. - - Transforms LoRA parameter names to their corresponding base layer - names for proper weight loading in vLLM when base model sync is not done. - - Args: - k (str): Original parameter key name. - - Returns: - str: Transformed parameter key for base layer. - """ - stacked_params = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] - if k.endswith(".weight"): - module_k = k[: -len(".weight")] - if check_exclude_modules(peft_config, module_k): - return k - elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules( - peft_config, module_k - ): - return f"{module_k}.base_layer.weight" - if k.endswith(".bias"): - module_k = k[: -len(".bias")] - if check_exclude_modules(peft_config, module_k): - return k - elif any([module_k.endswith(s) for s in stacked_params]) or check_target_modules( - peft_config, module_k - ): - return f"{module_k}.base_layer.bias" - return k - - updated_params = {replace_lora_wrapper(k): v for k, v in updated_params.items()} - - patch_vllm_moe_model_weight_loader(model) - device = get_device_id() # used when fsdp2 set cpu_offload_policy - loaded_params = model.load_weights( - ( - (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param) - for name, param in updated_params.items() - ) - ) - - self.base_sync_done = True - logger.info(f"vLLM load weights, loaded_params: {len(loaded_params) if loaded_params else -1}") diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py deleted file mode 100644 index d353c70e8..000000000 --- a/verl/workers/sharding_manager/megatron_sglang.py +++ /dev/null @@ -1,262 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine. -""" - -import asyncio -import logging -import os - -import torch.distributed as dist -from omegaconf import DictConfig -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.model_executor.model_runner import LocalSerializedTensor -from sglang.srt.utils import MultiprocessingSerializer -from torch import nn -from torch.distributed.device_mesh import DeviceMesh - -from verl.protocol import DataProto, all_gather_data_proto -from verl.utils.device import get_torch_device -from verl.utils.megatron_utils import ( - load_megatron_model_to_gpu, - offload_megatron_model_to_cpu, - per_tensor_generator, -) -from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer -from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets - -from .base import BaseShardingManager - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_PPO_LOGGING_LEVEL", "WARN")) - - -""" -Megatron Hybrid Engine: -- During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all - the parameters) -- Bind the parameters to the inference engine -- Do inference in tp. pp is treated as additional dp -- After inference, all the parameters that doesn't belong to this pp rank is freed. -""" - - -class MegatronSGLangShardingManager(BaseShardingManager): - """A sharding manager for Megatron-style training & inference with SGLang. - - This class manages the sharding of model parameters between training and inference - phases in a Megatron-style parallel setup. It handles: - - Loading/offloading parameters between CPU/GPU - - Updating inference engine weights - - Managing random states for reproducibility - - Data preprocessing for distributed inference - - Args: - actor_module (nn.ModuleList): The actor model modules - inference_engine (Engine): The SGLang inference engine - model_config: Configuration for the actor's model - rollout_config: Configuration for rollout generation - transformer_config: Transformer-specific configuration - layer_name_mapping: Mapping between layer names and parameters - weight_converter: Utility for converting weights between formats - device_mesh (DeviceMesh | None): PyTorch device mesh for distributed training - offload_param (bool): Whether to offload parameters to CPU when not in use - """ - - def __init__( - self, - actor_module: nn.ModuleList, - inference_engine: Engine, - model_config: DictConfig, - rollout_config: DictConfig, - transformer_config, - layer_name_mapping, - weight_converter, - device_mesh: DeviceMesh | None = None, - offload_param: bool = False, - bridge=None, - ): - self.actor_module = actor_module - self.inference_engine = inference_engine - self.model_config = model_config - self.rollout_config = rollout_config - self.transformer_config = transformer_config - self.layer_name_mapping = layer_name_mapping - self.weight_converter = weight_converter - self.device_mesh = device_mesh - self.bridge = bridge - self.offload_param = offload_param - - if self.device_mesh is not None: - self.infer_tp_size = self.device_mesh["tp"].mesh.size()[0] - else: - self.infer_tp_size = self.inference_engine._tp_size - - # Note that torch_random_states may be different on each dp rank - self.torch_random_states = get_torch_device().get_rng_state() - # get a random rng states - if self.device_mesh is not None: - gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - else: - self.gen_random_states = None - - @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) - def __enter__(self): - self.timing = {} - with simple_timer("reshard", self.timing): - loop = asyncio.get_event_loop() - loop.run_until_complete(self.wake_up()) - - @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) - def __exit__(self, exc_type, exc_value, traceback): - loop = asyncio.get_event_loop() - loop.run_until_complete(self.sleep()) - - async def update_weights(self, params): - """ - Update model weights using tensor buckets, similar to THUDM/slime's implementation. - - Notes: - - For the best performance of `rebuild_cuda_tensor`, it is recommended to: - 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES`. - 2. Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` - when using Tensor Parallelism (TP >= 8). - - See reference implementations in SLIME: - - Main logic: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L452 - - runtime envs: https://github.com/THUDM/slime/blob/fb7605cc5fb09af0f9369d37f7192f12bddee577/slime/ray/ppo_actor.py#L39 - """ - if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: - await self.inference_engine.resume_memory_occupation() - named_tensors = params - load_format = None - - update_weights_bucket_bytes = int(self.rollout_config.update_weights_bucket_megabytes) << 20 - for batch in get_named_tensor_buckets(named_tensors, update_weights_bucket_bytes): - # On each rank, serialize a batch of (name, tensor) tuples. - # named_tensors_batch will be a list like: - # [(name0, serialized_tensor0_tp0), (name1, serialized_tensor1_tp0), ...] - named_tensors_batch = [ - (name, MultiprocessingSerializer.serialize(tensor.detach())) for name, tensor in batch - ] - - if self.device_mesh["tp"].get_local_rank() == 0: - # On rank 0, prepare a list to hold the gathered batches from all ranks. - gathered_serialized_batches = [None for _ in range(self.device_mesh["tp"].mesh.size()[0])] - else: - gathered_serialized_batches = None - - # Gather the named_tensors_batch from all ranks to rank 0. - # After this, on rank 0, gathered_serialized_batches will be a list of lists: - # [ [ (name0, s_t0_tp0), (name1, s_t1_tp0), ... ], # batch from TP rank 0 - # [ (name0, s_t0_tp1), (name1, s_t1_tp1), ... ], # batch from TP rank 1 - # ... ] - # On other ranks, gathered_serialized_batches will be None. - dist.gather_object( - obj=named_tensors_batch, - object_gather_list=gathered_serialized_batches, - dst=self.device_mesh["tp"].mesh.tolist()[0], - group=self.device_mesh["tp"].get_group(), - ) - - if self.device_mesh["tp"].get_local_rank() == 0: - # Use zip(*) to "transpose" the data structure. - # This groups the serialized parts for each individual tensor across all TP ranks. - # Example: from [[(n0, t0_tp0), (n1, t1_tp0)], [(n0, t0_tp1), (n1, t1_tp1)]] - # to [ ( (n0, t0_tp0), (n0, t0_tp1) ), ( (n1, t1_tp0), (n1, t1_tp1) ) ] - logical_tensors = zip(*gathered_serialized_batches, strict=False) - await self.inference_engine.update_weights_from_tensor( - named_tensors=[ - # 'tensor_group' represents a single logical tensor's data from all ranks. - ( - tensor_group[0][0], # Get the name from the first rank's data. - LocalSerializedTensor( - # 'rank_part' is the (name, serialized_tensor) tuple from one specific rank. - values=[rank_part[1] for rank_part in tensor_group] - ), - ) - for tensor_group in logical_tensors - # each tensor_group is like ( (n0, t0_tp0), (n0, t0_tp1) ) - ], - load_format=load_format, - flush_cache=False, - ) - - if self.device_mesh["tp"].get_local_rank() == 0: - await self.inference_engine.flush_cache() - - async def release_memory(self): - if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine: - await self.inference_engine.release_memory_occupation() - - @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger) - async def wake_up(self): - if self.offload_param: - load_megatron_model_to_gpu(self.actor_module) - if self.bridge is not None: - per_tensor_param = self.bridge.export_weights(self.actor_module) - else: - per_tensor_param = per_tensor_generator( - self.actor_module, - self.model_config, - self.weight_converter, - self.transformer_config, - self.layer_name_mapping, - ) - await self.update_weights(per_tensor_param) - if self.offload_param: - offload_megatron_model_to_cpu(self.actor_module) - get_torch_device().empty_cache() - # important: need to manually set the random states of each tp to be identical. - if self.device_mesh is not None: - self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) - - @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger) - async def sleep(self): - if self.rollout_config.free_cache_engine: - log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger) - await self.release_memory() - log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger) - - for model in self.actor_module: - model.train() - # add empty cache after each compute - get_torch_device().empty_cache() - - # restore random states - if self.device_mesh is not None: - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - - @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) - def preprocess_data(self, data: DataProto) -> DataProto: - # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp - if self.infer_tp_size == 1: - return data - all_gather_data_proto(data, self.device_mesh["tp"].get_group()) - return data - - @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) - def postprocess_data(self, data: DataProto) -> DataProto: - # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp - if self.infer_tp_size == 1: - return data - return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()] diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py deleted file mode 100644 index b04352c24..000000000 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This file contains a Megatron style Hybrid Engine that shares the weights of the actor with the inference engine. -""" - -import inspect -import logging -import os - -import torch -import torch.distributed -from megatron.core import parallel_state as mpu -from omegaconf import DictConfig -from torch import nn - -from verl import DataProto -from verl.models.mcore.weight_converter import McoreToHFWeightConverterBase -from verl.protocol import all_gather_data_proto -from verl.third_party.vllm import LLM -from verl.third_party.vllm import parallel_state as vllm_ps -from verl.utils.device import get_torch_device -from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator -from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage -from verl.utils.profiler.performance import simple_timer -from verl.utils.torch_functional import check_device_is_available -from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader - -from .base import BaseShardingManager - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -""" -Megatron Hybrid Engine: -- During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank - to all other pp ranks (all pp ranks holds all the parameters) -- Bind the parameters to the inference engine -- Do inference in tp. pp is treated as additional dp -- After inference, all the parameters that doesn't belong to this pp rank is freed. -""" - - -class MegatronVLLMShardingManager(BaseShardingManager): - """A sharding manager that bridges Megatron-LM training with vLLM inference. - - This class handles the parameter sharding and communication between: - - Megatron-LM's tensor/expert parallel training setup - - vLLM's tensor parallel inference setup - - Key responsibilities: - - Manages parameter broadcasting between training and inference configurations - - Handles weight conversion between Megatron and HuggingFace formats - - Coordinates memory management between training and inference phases - - Maintains random state consistency across different parallel groups - - Args: - actor_module (nn.ModuleList): The Megatron-LM model being trained - inference_engine (LLM): The vLLM inference engine - model_config: Configuration for the actor's model - transformer_config: Transformer-specific configuration for the model - rollout_config: Configuration for rollout - layer_name_mapping: Mapping between Megatron and HF layer names - weight_converter (McoreToHFWeightConverterBase): Converts weights between formats - device_mesh: Device mesh for parallel operations - offload_param (bool): Whether to offload parameters when not in use - """ - - @check_device_is_available() - def __init__( - self, - actor_module: nn.ModuleList, - inference_engine: LLM, - model_config: DictConfig, - transformer_config, - rollout_config: DictConfig, - layer_name_mapping, - weight_converter: McoreToHFWeightConverterBase, - device_mesh, - offload_param: bool = True, - bridge=None, - ): - self.actor_module = actor_module - self.inference_engine = inference_engine - self.offload_param = offload_param - - # For AsyncLLM, inference_engine and model_runner are defer initialized in vLLMAsyncRollout.load_model - self.model_runner = ( - self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner - if self.inference_engine - else None - ) - - self.model_config = model_config - self.transformer_config = transformer_config - self.rollout_config = rollout_config - self.layer_name_mapping = layer_name_mapping - self.weight_converter = weight_converter - self.bridge = bridge - # initialize groups for vllm inference - self.rank = torch.distributed.get_rank() - self.world_size = torch.distributed.get_world_size() - - self.device_mesh = device_mesh - self.infer_tp_size = self.device_mesh["infer_tp"].size() - self.infer_tp_rank = self.device_mesh["infer_tp"].get_local_rank() - - self.train_tp_size = mpu.get_tensor_model_parallel_world_size() - self.train_tp_rank = mpu.get_tensor_model_parallel_rank() - self.train_tp_group = mpu.get_tensor_model_parallel_group() - self.train_ep_size = mpu.get_expert_model_parallel_world_size() - self.train_ep_rank = mpu.get_expert_model_parallel_rank() - self.train_ep_group = mpu.get_expert_model_parallel_group() - self.train_etp_size = mpu.get_expert_tensor_parallel_world_size() - self.train_etp_rank = mpu.get_expert_tensor_parallel_rank() - self.train_etp_group = mpu.get_expert_tensor_parallel_group() - self.need_tp_reshard = self.train_tp_size != self.infer_tp_size - self.train_tp_larger = self.train_tp_size > self.infer_tp_size - - self.torch_random_states = get_torch_device().get_rng_state() - if self.device_mesh is not None: - gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - else: - self.gen_random_states = None - - @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) - def __enter__(self): - self.timing = {} - with simple_timer("reshard", self.timing): - get_torch_device().empty_cache() - - log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) - if self.offload_param: - load_megatron_model_to_gpu(self.actor_module) - - if self.rollout_config.free_cache_engine: - if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: - self.inference_engine.wake_up(tags=["weights"]) - else: - self.inference_engine.wake_up() - if self.bridge is not None: - per_tensor_param = self.bridge.export_weights(self.actor_module) - else: - per_tensor_param = per_tensor_generator( - self.actor_module, - self.model_config, - self.weight_converter, - self.transformer_config, - self.layer_name_mapping, - ) - model = self.model_runner.model - patch_vllm_moe_model_weight_loader(model) - loaded_params = model.load_weights(per_tensor_param) - info = f"vLLM load weights, loaded_params: {len(loaded_params)}" - logger.info(info) - - if self.offload_param: - offload_megatron_model_to_cpu(self.actor_module) - get_torch_device().empty_cache() - - if ( - self.rollout_config.free_cache_engine - and "tags" in inspect.signature(self.inference_engine.wake_up).parameters - ): - self.inference_engine.wake_up(tags=["kv_cache"]) - - # important: need to manually set the random states of each tp to be identical. - if self.device_mesh is not None: - self.torch_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.gen_random_states) - - @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) - def __exit__(self, exc_type, exc_value, traceback): - if self.rollout_config.free_cache_engine: - self.inference_engine.sleep(level=1) - for model in self.actor_module: - model.train() - - get_torch_device().empty_cache() - - # restore random states - if self.device_mesh is not None: - self.gen_random_states = get_torch_device().get_rng_state() - get_torch_device().set_rng_state(self.torch_random_states) - - @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) - def preprocess_data(self, data: DataProto) -> DataProto: - # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp - if self.infer_tp_size == 1: - return data - - # TODO: Current impl doesn't consider FSDP with torch micro-dp - group = vllm_ps.get_tensor_model_parallel_group().device_group - - all_gather_data_proto(data=data, process_group=group) - return data - - @GPUMemoryLogger(role="megatron vllm sharding_manager", logger=logger) - def postprocess_data(self, data: DataProto) -> DataProto: - # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp - if self.infer_tp_size == 1: - return data - return data.chunk(chunks=self.infer_tp_size)[self.infer_tp_rank]