diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index cffb342e5..2e1441300 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -251,9 +251,10 @@ def fit(self): batch = new_batch if batch is None else DataProto.concat([batch, new_batch]) prompt_bsz = self.config.data.train_batch_size - if num_prompt_in_batch < prompt_bsz: + max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + if num_prompt_in_batch < prompt_bsz and max_num_gen_batches > 1: # Added by Reasoning360 TWK NOTE: second condition is to account for when we have zero-variance filtering but are not dynamically growing the batch... print(f"{num_prompt_in_batch=} < {prompt_bsz=}") - max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches + # max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches: print(f"{num_gen_batches=}. Keep generating...") progress_bar.update(1) @@ -267,9 +268,14 @@ def fit(self): + " You could also try set max_num_gen_batches=0 to enable endless trials." ) else: - # Align the batch - traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n - batch = batch[:traj_bsz] + # Added by Reasoning360, need to account for when our batch is smaller due to zero-variance filtering + if num_prompt_in_batch >= prompt_bsz: + # Align the batch + traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n + batch = batch[:traj_bsz] + else: + # TWK TODO!!!: RESCALE THIS SO THAT THE BATCH*N IS DIVISIBLE BY k_partitions (n_gpus...) + print(f"Final {num_prompt_in_batch=} < {prompt_bsz=} after {num_gen_batches=} generation batches. Proceeding with smaller batch...") # === Updating === diff --git a/scripts/train/k2p_hero_cispo.sh b/scripts/train/k2p_hero_cispo.sh new file mode 100644 index 000000000..f37efe39c --- /dev/null +++ b/scripts/train/k2p_hero_cispo.sh @@ -0,0 +1,326 @@ +#!/bin/bash +#SBATCH --job-name=cispo-focused-k2p-finalInstruct-temp1.0-wOmni-fix2 +#SBATCH --nodes=32 +#SBATCH --ntasks=32 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main +#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394 + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ 
+SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== +SHARED_DATA_PATH=/lustrefs/users/haonan.li/data/k2 +MATH_DATA_PATH=/lustrefs/users/zhuojun.cheng/vpim/guru_data/train/postprocessed_dedup_am_semantic_filtered_0.05_0.94_thresh_ratio0.5_sample1.0_balanced_step2 +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1_datamix_6 +TEST_DATA_DIR=${SHARED_DATA_PATH}/test_12k_len + +# Math (train) +math_train_path1=${MATH_DATA_PATH}/math__combined_118.2k.part1_scored.parquet +math_train_path2=${MATH_DATA_PATH}/math__combined_118.2k.part2_scored.parquet +# math_train_path1=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +# math_train_path2=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoninggym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +# codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# 
Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (train) +if_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet + +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${reasoninggym_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}','${nemotron_train_path}','${if_train_path}']" # '${synlogic_train_path}', +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # '${synlogic_test_path}', + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus 
"${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=2.0 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_mode="cispo" # Default is 'vanilla' which is equivalent to PPO; +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.0 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Reduced from 32 to reduce memory pressure +gen_tp=4 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 
\ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 \ No newline at end of file diff --git a/scripts/train/k2p_hero_grpo.sh b/scripts/train/k2p_hero_grpo.sh new file mode 100644 index 000000000..e0ec0988f --- /dev/null +++ b/scripts/train/k2p_hero_grpo.sh @@ -0,0 +1,328 @@ +#!/bin/bash +#SBATCH --job-name=grpo-k2p-finalInstruct-64k-temp1.2-focused +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main +#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394 + +# SBATCH --job-name=grpo-hero-k2p-finalInstruct-temp1.2-wOmni-fix2 + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="grpo-k2p-finalInstruct-64k-temp1.2-focused-404084" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== +SHARED_DATA_PATH=/lustrefs/users/haonan.li/data/k2 +MATH_DATA_PATH=/lustrefs/users/zhuojun.cheng/vpim/guru_data/train/postprocessed_dedup_am_semantic_filtered_0.05_0.94_thresh_ratio0.5_sample1.0_balanced_step2 +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train_scored_dedup_am_12k_len_rm_flipscore_score_method_5_1_datamix_6 
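# Reviewer note on the dapo_ray_trainer.py change above (hedged sketch, not part of this
# script): one way to resolve the "RESCALE THIS SO THAT THE BATCH*N IS DIVISIBLE BY
# k_partitions (n_gpus...)" TODO is to keep the largest prompt count whose trajectory total
# still divides the data-parallel world size, so whole rollout groups are dropped together:
#     step = world_size // math.gcd(world_size, rollout_n)
#     keep_prompts = (num_prompt_in_batch // step) * step
#     batch = batch[: keep_prompts * rollout_n]
# The names world_size and rollout_n are illustrative assumptions, not variables from the patch.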
+TEST_DATA_DIR=${SHARED_DATA_PATH}/test_12k_len + +# Math (train) +math_train_path1=${MATH_DATA_PATH}/math__combined_118.2k.part1_scored.parquet +math_train_path2=${MATH_DATA_PATH}/math__combined_118.2k.part2_scored.parquet +# math_train_path1=${TRAIN_DATA_DIR}/math__combined_118.2k.part1.parquet +# math_train_path2=${TRAIN_DATA_DIR}/math__combined_118.2k.part2.parquet +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (train) +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_2.4k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_599.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_9.6k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_11.1k.parquet +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (train) +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_297.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_653.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_3.4k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.4k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_2.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_5.0k.parquet +reasoninggym_train_path=${TRAIN_DATA_DIR}/logic__reasoning_gym_40.6k.parquet +synlogic_train_path=${TRAIN_DATA_DIR}/logic__synlogic_12.1k.parquet +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + + +# Simulation (train) +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k.parquet +# Simulation (test) +# codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet + +# Table (train) +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_7.4k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_2.9k.parquet +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (train) +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_31.7k.parquet +nemotron_train_path=${TRAIN_DATA_DIR}/stem__nemotron_13.3k.parquet +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (train) +if_train_path=${TRAIN_DATA_DIR}/ifbench__fixed_85.6k.parquet + +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) 
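# (Format note: each *_files variable is forwarded verbatim as a Hydra override, so the
#  launch command below receives a literal Python-style list string, e.g.
#      data.train_files="['/lustrefs/.../math__combined_118.2k.part1_scored.parquet','...']"
#  Keeping the single quotes around every element ensures each path stays a single list
#  item when Hydra parses the override.)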
+train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${reasoninggym_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}','${nemotron_train_path}','${if_train_path}']" # '${synlogic_train_path}', +# test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # '${synlogic_test_path}', + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906/global_step_300/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO 
features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Reduced from 32 to reduce memory pressure +gen_tp=4 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 \ No newline at end of file diff --git a/scripts/train/k2p_hero_grpo_newData.sh b/scripts/train/k2p_hero_grpo_newData.sh new file mode 100644 index 000000000..ade668870 --- /dev/null +++ b/scripts/train/k2p_hero_grpo_newData.sh @@ -0,0 +1,343 @@ +#!/bin/bash +#SBATCH --job-name=grpo-k2p-newFiltered-64k-fullData-finalInstruct +#SBATCH --nodes=64 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main +#SBATCH --exclude=azure-uk-hpc-H200-instance-114,azure-uk-hpc-H200-instance-394 + +# SBATCH --job-name=grpo-hero-k2p-finalInstruct-temp1.2-wOmni-fix2 + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-009:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain +export MATH_LLM_JUDGE_URL="http://azure-uk-hpc-H200-instance-033:8000" # Fill in the OmniMATH llm-as-judge hosted URL, only used to score OmniMATH dataset if not empty + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/lustrefs/users/taylor.killian/miniconda3/envs/sync-rl/bin/ +export ROCR_VISIBLE_DEVICES=None +export NCCL_TIMEOUT_SECONDS=4800000 +export OMPI_MCA_coll_hcoll_enable=0 \ +TORCH_NCCL_ENABLE_MONITORING=0 \ +CUDA_DEVICE_ORDER=PCI_BUS_ID \ +NCCL_SOCKET_IFNAME=eth0 \ +UCX_TLS=rc \ +UCX_NET_DEVICES=mlx5_ib0:1 \ +NCCL_DEBUG=WARN \ +NCCL_TOPO_FILE=/opt/microsoft/ndv5-topo.xml \ +NCCL_IB_PCI_RELAXED_ORDERING=1 \ +NCCL_IB_QPS_PER_CONNECTION=4 \ +NCCL_IGNORE_CPU_AFFINITY=1 \ +NCCL_P2P_NET_CHUNKSIZE=$((512 * 1024)) \ +NCCL_PXN_DISABLE=1 \ +NCCL_MIN_NCHANNELS=32 \ +SHARP_SMX_UCX_INTERFACE=mlx5_ib0:1 \ +SHARP_COLL_ENABLE_SAT=1 \ +SHARP_COLL_LOG_LEVEL=3 \ +SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1 \ +NCCL_COLLNET_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== + +# Training Data Configuration +DATA_MIX_DIR="/lustrefs/users/varad.pimpalkhute/data/k2/final/data_mix_1" +train_file_list=() + +# List of datasets to include 
(filename only) +# Comment out lines to exclude specific datasets +dataset_names=( + "codegen__deduped_leetcode2k_2.4k.parquet" + "codegen__deduped_livecodebench_599.parquet" + "codegen__deduped_primeintellect_9.6k.parquet" + "codegen__deduped_taco_11.1k.parquet" + "ifbench__fixed_85.6k.parquet" + "logic__arcagi1_297.parquet" + "logic__arcagi2_653.parquet" + "logic__barc_3.4k.parquet" + "logic__graph_logical_dataset_1.4k.parquet" + "logic__ordering_puzzle_dataset_2.9k.parquet" + "logic__reasoning_gym_40.6k.parquet" + "logic__synlogic_12.1k.parquet" + "logic__zebra_puzzle_dataset_5.0k.parquet" + "math__combined_118.2k.part1.parquet" + "math__combined_118.2k.part2.parquet" + "omni_math_4.43k_dedup.parquet" + "simulation__codeio_fixed_12.1k.parquet" + "stem__nemotron_13.3k.parquet" + "stem__web_31.7k.parquet" + "table__hitab_7.4k.parquet" + "table__multihier_2.9k.parquet" +) + +echo "Collecting training files from ${DATA_MIX_DIR}..." + +# Search for each dataset in all subdirectories +for dataset in "${dataset_names[@]}"; do + for subdir in "impossible_questions" "131k_context_questions" "main_questions"; do + file_path="${DATA_MIX_DIR}/${subdir}/${dataset}" + if [ -f "$file_path" ]; then + echo "Adding: $file_path" + train_file_list+=("'$file_path'") + fi + done +done + +# Join with comma to form Python list string +IFS=, +train_files="[${train_file_list[*]}]" +unset IFS + +echo "Total training files found: ${#train_file_list[@]}" + +# Test Data Configuration +TEST_DATA_DIR=/lustrefs/users/haonan.li/data/k2/test_12k_len +# Math (test) +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +aime25_test_path2=${TEST_DATA_DIR}/math__aime2025_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# Code (test) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# Logic (test) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet +reasoninggym_test_path=${TEST_DATA_DIR}/logic__reasoning_gym_425.parquet +synlogic_test_path=${TEST_DATA_DIR}/logic__synlogic_217.parquet +arcagi1_test_path=${TEST_DATA_DIR}/logic__arcagi1_400.parquet +# graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +# ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet + +# Table (test) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_336.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_1k.parquet + +# Stem (test) +nemotron_test_path=${TEST_DATA_DIR}/stem__nemotron_100.parquet +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_1k.parquet + +# Instruction follow (test) +if_test_path=${TEST_DATA_DIR}/ood__ifeval_100.parquet +if_bench_test_path=${TEST_DATA_DIR}/ifbench_800.parquet + +# Focused data mixture (math, code, stem) +# train_files="['${math_train_path1}','${math_train_path2}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}','${nemotron_train_path}']" +# 
test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Full data mixture (uncomment to use) +test_files="['${math_test_path}','${aime_test_path}','${aime25_test_path2}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${synlogic_test_path}','${reasoninggym_test_path}','${arcagi1_test_path}','${multihier_test_path}','${hitab_test_path}','${nemotron_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}','${if_test_path}','${if_bench_test_path}']" # + + +# =================== Model =================== +# BASE_MODEL=/lustrefs/users/runner/workspace/checkpoints/huggingface/sft/mid4_rope_sft_reasoning_am_251117/checkpoints/checkpoint_0002250 # AM-Think SFT +BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Oss-Instruct-mid4 # Final Instruct SFT (after stg4_iter 10k) +# BASE_MODEL=/lustrefs/users/varad.pimpalkhute/data_process/K2-Plus-Instruct-mid4 # Instruct SFT, after stg4_iter 7k +# BASE_MODEL=/lustrefs/users/taylor.killian/Reasoning360/checkpoints/k2plus_rl/grpo-focused-k2p-finalInstruct-temp1.2-wOmni-fix2-403906/global_step_300/actor/huggingface + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${SLURM_JOB_ID} #-${BASE_MODEL##*/} + +# If RESUME_CKPT_DIR is not empty, resume from the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. 
+ +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 64)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 12)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=256 # model grad update batchsize + +# Algorithm +temperature=1.2 +val_temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=16 # Reduced from 32 to reduce memory pressure +gen_tp=4 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra','hf_model'] \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. 
\ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. 
\ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + actor_rollout_ref.model.use_liger=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=5 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=3 \ No newline at end of file diff --git a/scripts/train/test_k2p_cispo_m2.sh b/scripts/train/test_k2p_cispo_m2.sh new file mode 100644 index 000000000..01cf6ae81 --- /dev/null +++ b/scripts/train/test_k2p_cispo_m2.sh @@ -0,0 +1,304 @@ +#!/bin/bash +#SBATCH --job-name=cispo-focused-fixed +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --account=iq +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="cispo-focused-fixed-Qwen2.5-7B-1016971" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/mnt/weka/home/taylor.killian/miniconda3/envs/sync-rl/bin/ +export NCCL_TIMEOUT_SECONDS=4800 +export TORCH_NCCL_ENABLE_MONITORING=0 +export NCCL_DEBUG=warn +export NCCL_NET=IB +export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7" +export NCCL_CROSS_NIC=1 +export NCCL_IB_TC=136 +export NCCL_SOCKET_IFNAME="^lo,docker,virbr" +export CUDA_DEVICE_MAX_CONNECTIONS=8 +export NCCL_NVLS_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== +SHARED_DATA_PATH=/mnt/sharefs/users/zhuojun.cheng +SHARED_MODEL_PATH=/mnt/sharefs/users/haonan.li/models +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/guru_data/train/guru92k_release_0603 +TEST_DATA_DIR=${SHARED_DATA_PATH}/guru_data/test/online # ← unchanged + +# ---------- Math ---------- +# train +math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet +# test +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# ---------- Code ---------- +# train +leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_440.parquet 
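# Optional sketch (assumptions flagged): STEM prompts in this run are scored through the
# llm-as-judge endpoint configured above, so a cheap reachability probe before launch avoids
# silently training against a dead judge. The /v1/models route is an assumption based on the
# usual OpenAI-compatible serving layout, not something this repo guarantees.
if ! curl -sf --max-time 10 "${STEM_LLM_JUDGE_URL}/v1/models" > /dev/null; then
    echo "WARNING: STEM_LLM_JUDGE_URL (${STEM_LLM_JUDGE_URL}) is not reachable." >&2
fi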
+primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_8.8k.parquet +# test (unchanged) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500_sampled_200.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# ---------- Logic ---------- +# train +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_1.3k.parquet +# test (unchanged) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300_sampled_200.parquet +graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet +arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet + +# ---------- Simulation ---------- +# train +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k_3.7k.parquet +# test (unchanged) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet + +# ---------- Table ---------- +# train +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# test (unchanged) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_300_sampled_200.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_300_sampled_200.parquet + +# ---------- Stem ---------- +# train +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# test (unchanged) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet + +# Full Guru92k mixture +# train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${graph_test_path}','${ordering_puzzle_test_path}','${arcagi1_test_path}','${codeio_test_path}','${multihier_test_path}','${hitab_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Focused Guru92k mixture (Math + Code + STEM) +train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}']" +test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# =================== Model =================== +BASE_MODEL=Qwen/Qwen2.5-7B + +# =================== Logging =================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${BASE_MODEL##*/}-${SLURM_JOB_ID} + +# If RESUME_CKPT_DIR is not empty, resume from 
the checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# ray stop at all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=2.0 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_mode="cispo" # Default is "vanilla" which is equivalent to PPO; +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batchsize: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batchsize + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=2 +gen_max_num_seqs=1024 # Reduced from 1024 to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward but note memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up modelforward, but note memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + 
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.dtype=${rollout_dtype} \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + 
actor_rollout_ref.rollout.dtype=${rollout_dtype} \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=1 \ No newline at end of file diff --git a/scripts/train/test_k2p_grpo_m2.sh b/scripts/train/test_k2p_grpo_m2.sh new file mode 100644 index 000000000..aefbf34fc --- /dev/null +++ b/scripts/train/test_k2p_grpo_m2.sh @@ -0,0 +1,304 @@ +#!/bin/bash +#SBATCH --job-name=grpo-focused +#SBATCH --nodes=4 +#SBATCH --ntasks=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=96 +#SBATCH --account=iq +#SBATCH --mem=0 +#SBATCH --output=slurm/%x-%j.log +#SBATCH --error=slurm/%x-%j.log +#SBATCH --exclusive +#SBATCH --time=720:00:00 +#SBATCH --partition=main + + +# =================== Frequently Used Variables =================== +RESUME_CKPT_DIR_NAME="grpo-focused-Qwen2.5-7B-994853" # Fill in the checkpoint directory name to resume from, otherwise from scratch +export STEM_LLM_JUDGE_URL="http://10.24.2.1:8000" # Fill in the llm-as-judge hosted URL, currently used only in 'STEM' domain + +# =================== Cluster Environment =================== +export CONDA_BIN_PATH=/mnt/weka/home/taylor.killian/miniconda3/envs/sync-rl/bin/ +export NCCL_TIMEOUT_SECONDS=4800 +export TORCH_NCCL_ENABLE_MONITORING=0 +export NCCL_DEBUG=warn +export NCCL_NET=IB +export NCCL_IB_HCA="mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7" +export NCCL_CROSS_NIC=1 +export NCCL_IB_TC=136 +export NCCL_SOCKET_IFNAME="^lo,docker,virbr" +export CUDA_DEVICE_MAX_CONNECTIONS=8 +export NCCL_NVLS_ENABLE=1 + +# Get the list of allocated nodes +nodes=( $(scontrol show hostnames "$SLURM_JOB_NODELIST") ) +echo "Nodes to check: ${nodes[@]}" + +# We'll track PIDs so we can wait on them and detect errors +declare -A pids +export head_node=${nodes[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +port=6379 +address_head=$head_node_ip:$port + +export worker_num=$SLURM_NNODES +export HYDRA_FULL_ERROR=1 +export VLLM_USE_V1=1 + +# =================== Data Mixture =================== +SHARED_DATA_PATH=/mnt/sharefs/users/zhuojun.cheng +SHARED_MODEL_PATH=/mnt/sharefs/users/haonan.li/models +TRAIN_DATA_DIR=${SHARED_DATA_PATH}/guru_data/train/guru92k_release_0603 +TEST_DATA_DIR=${SHARED_DATA_PATH}/guru_data/test/online # ← unchanged + +# ---------- Math ---------- +# train +math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet +# test +math_test_path=${TEST_DATA_DIR}/math__math_500.parquet +aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet +amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet + +# ---------- Code ---------- +# train 
+leetcode_train_path=${TRAIN_DATA_DIR}/codegen__deduped_leetcode2k_1.3k.parquet +livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__deduped_livecodebench_440.parquet +primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__deduped_primeintellect_7.5k.parquet +taco_train_path=${TRAIN_DATA_DIR}/codegen__deduped_taco_8.8k.parquet +# test (unchanged) +humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet +mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500_sampled_200.parquet +livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet + +# ---------- Logic ---------- +# train +arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet +arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet +barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet +graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_dataset_1.2k.parquet +ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_dataset_1.9k.parquet +zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_dataset_1.3k.parquet +# test (unchanged) +zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300_sampled_200.parquet +graph_test_path=${TEST_DATA_DIR}/logic__graph_logical_dataset_150_sampled_77.parquet +ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150_sampled_100.parquet +arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet + +# ---------- Simulation ---------- +# train +codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_fixed_12.1k_3.7k.parquet +# test (unchanged) +codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500_sampled_200.parquet + +# ---------- Table ---------- +# train +hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet +multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet +# test (unchanged) +multihier_test_path=${TEST_DATA_DIR}/table__multihier_300_sampled_200.parquet +hitab_test_path=${TEST_DATA_DIR}/table__hitab_300_sampled_200.parquet + +# ---------- Stem ---------- +# train +webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet +# test (unchanged) +gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet +supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet + +# Full Guru92k mixture +# train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${arcagi1_train_path}','${arcagi2_train_path}','${barc_train_path}','${graph_train_path}','${ordering_train_path}','${zebra_train_path}','${codeio_train_path}','${hitab_train_path}','${multihier_train_path}','${webinstruct_train_path}']" +# test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${zebralogic_test_path}','${graph_test_path}','${ordering_puzzle_test_path}','${arcagi1_test_path}','${codeio_test_path}','${multihier_test_path}','${hitab_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# Focused Guru92k mixture (Math + Code + STEM) +train_files="['${math_train_path}','${leetcode_train_path}','${livecodebench_train_path}','${primeintellect_train_path}','${taco_train_path}','${webinstruct_train_path}']" +test_files="['${math_test_path}','${aime_test_path}','${amc_test_path}','${humaneval_test_path}','${mbpp_test_path}','${livecodebench_test_path}','${gpqa_diamond_test_path}','${supergpqa_test_path}']" + +# =================== Model =================== +BASE_MODEL=Qwen/Qwen2.5-7B + +# =================== Logging 
=================== +WANDB_PROJECT=k2plus_rl +WANDB_EXPERIMENT_NAME=${SLURM_JOB_NAME}-${BASE_MODEL##*/}-${SLURM_JOB_ID} + +# If RESUME_CKPT_DIR_NAME is not empty, resume from that checkpoint +if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then + WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME" +fi + + +# =================== Ray start =================== +# Stop Ray on all nodes +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ${CONDA_BIN_PATH}ray stop + +sleep 10 +# Remove existing Ray cluster +srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster + +# Start Ray head node +srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block & + +sleep 10 + +# Start Ray worker nodes +for ((i = 1; i < worker_num; i++)); do + node_i=${nodes[$i]} + echo "Starting WORKER $i at $node_i" + srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \ + env -u ROCR_VISIBLE_DEVICES -u HIP_VISIBLE_DEVICES \ + ${CONDA_BIN_PATH}ray start --address "$address_head" \ + --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block & +done +sleep 10 + + +# =================== RL Config =================== +# Note: we borrow the config format from DAPO but disable all DAPO-specific features here to run the naive RL baseline. + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.2 + +max_prompt_length=$((1024 * 4)) +max_response_length=$((1024 * 32)) +enable_overlong_buffer=False +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_mode="vanilla" # Default is "vanilla", which is equivalent to PPO +loss_agg_mode="token-mean" +rollout_dtype="float16" + +enable_filter_groups=False +filter_groups_metric=acc +max_num_gen_batches=10 +train_prompt_bsz=256 # on-policy model update batch size: train_prompt_bsz * rollout.n +gen_prompt_bsz=$((train_prompt_bsz * 1)) +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 # model grad update batch size + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout + +# Training config +sp_size=1 +gen_tp=2 +gen_max_num_seqs=1024 # max concurrent sequences during rollout; lower this to reduce memory pressure +infer_micro_batch_size=null +train_micro_batch_size=null +use_dynamic_bsz=True +actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward & backward, but watch for memory overflow +infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 1)) # increase this to speed up model forward, but watch for memory overflow +offload=True + +# =================== Start RL training =================== +"${CONDA_BIN_PATH}python" -m recipe.dapo.main_dapo \ + --config-path=config \ + --config-name="dapo_fsdp_config.yaml" \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + algorithm.filter_groups.enable=${enable_filter_groups} \ + algorithm.filter_groups.metric=${filter_groups_metric} \ + algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \ + data.train_files="$train_files" \ + data.val_files="$test_files" \ + data.prompt_key=prompt \ + data.truncation='right' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + 
data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + actor_rollout_ref.nccl_timeout=${NCCL_TIMEOUT_SECONDS} \ + actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \ + actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.policy_loss.cispo_clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.actor.strategy="fsdp2" \ + actor_rollout_ref.actor.dtype=${rollout_dtype} \ + actor_rollout_ref.actor.optim.lr=5e-7 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.optim.warmup_style=constant \ + actor_rollout_ref.actor.optim.min_lr_ratio=0. \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \ + actor_rollout_ref.actor.entropy_checkpointing=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.ref.entropy_from_logits_with_chunking=True \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.max_num_seqs=${gen_max_num_seqs} \ + actor_rollout_ref.rollout.disable_log_stats=False \ + actor_rollout_ref.rollout.enforce_eager=False \ + actor_rollout_ref.rollout.enable_prefix_caching=True \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${top_p}\ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.model.path=$BASE_MODEL \ + 
actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.rollout.multi_turn.enable=False \ + actor_rollout_ref.rollout.mode="sync" \ + actor_rollout_ref.rollout.dtype=${rollout_dtype} \ + +actor_rollout_ref.model.override_config.attention_dropout=0. \ + +actor_rollout_ref.model.override_config.embd_pdrop=0. \ + +actor_rollout_ref.model.override_config.resid_pdrop=0. \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.model.enable_activation_offload=True \ + reward_model.reward_manager=async_multi_process \ + reward_model.overlong_buffer.enable=${enable_overlong_buffer} \ + reward_model.overlong_buffer.len=${overlong_buffer_len} \ + reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \ + trainer.logger=['console','wandb'] \ + trainer.project_name=${WANDB_PROJECT} \ + trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \ + trainer.val_before_train=True \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=$worker_num \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.log_val_generations=50 \ + trainer.resume_mode=auto \ + trainer.max_actor_ckpt_to_keep=1 \ No newline at end of file diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 444595c76..fe94e7a49 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -56,6 +56,8 @@ actor_rollout_ref: clip_cov_ub: 5.0 kl_cov_ratio: 0.0002 ppo_kl_coef: 0.1 + cispo_clip_ratio_high: 0.2 + cispo_clip_ratio_low: 0.2 clip_ratio_c: 3.0 loss_agg_mode: token-mean entropy_coeff: 0 diff --git a/verl/trainer/config/actor/actor.yaml b/verl/trainer/config/actor/actor.yaml index 7c733ed60..e8c082e63 100644 --- a/verl/trainer/config/actor/actor.yaml +++ b/verl/trainer/config/actor/actor.yaml @@ -49,7 +49,7 @@ policy_loss: # # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs _target_: verl.workers.config.PolicyLossConfig - # Loss function mode: vanilla / clip-cov / kl-cov /gpg from https://arxiv.org/abs/2505.22617 + # Loss function mode: vanilla / clip-cov / kl-cov / cispo (from https://arxiv.org/abs/2506.13585) / gpg (from https://arxiv.org/abs/2505.22617) loss_mode: "vanilla" # Ratio of tokens to be clipped for clip-cov loss @@ -67,6 +67,12 @@ policy_loss: # KL divergence penalty coefficient ppo_kl_coef: 0.1 + # Upper bound for CISPO importance ratio clipping + cispo_clip_ratio_high: 0.2 + + # Lower bound for CISPO importance ratio clipping + cispo_clip_ratio_low: 0.2 + # Constant C in Dual-clip PPO; clips when advantage < 0 and ratio > C clip_ratio_c: 3.0 diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 7a9103c4d..425f095b0 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -1164,6 +1164,77 @@ def compute_policy_loss_kl_cov( return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0) +@register_policy_loss("cispo") +def compute_policy_loss_cispo( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + response_mask: torch.Tensor, + loss_agg_mode: str = "token-mean", + config: Optional[DictConfig | AlgoConfig] = None, + rollout_log_probs: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute the CISPO policy objective and related metrics. 
+ CISPO (Clipped Importance Sampling Policy Optimization) clips the importance sampling weights + instead of dropping tokens, which preserves gradient signal from sparse but critical tokens + and benefits long-context reasoning in RL. + Reference: https://arxiv.org/abs/2506.13585 + Args: + old_log_prob (torch.Tensor): + Log-probabilities of actions under the old policy, shape (batch_size, response_length). + log_prob (torch.Tensor): + Log-probabilities of actions under the current policy, shape (batch_size, response_length). + advantages (torch.Tensor): + Advantage estimates for each action, shape (batch_size, response_length). + response_mask (torch.Tensor): + Mask indicating which tokens to include in the loss, shape (batch_size, response_length). + loss_agg_mode (str, optional): + Aggregation mode for loss computation. + config (AlgoConfig): + Algorithm configuration containing the CISPO clipping parameters. + rollout_log_probs (torch.Tensor, optional): + Log-probabilities of actions under the rollout policy, shape (batch_size, response_length). + Returns: + tuple: (pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower) + """ + # Setup CISPO configuration + assert config.policy_loss.loss_mode == "cispo", "CISPO loss mode not set in config" + cispo_clip_ratio_high = config.policy_loss.cispo_clip_ratio_high + cispo_clip_ratio_low = config.policy_loss.cispo_clip_ratio_low # NOTE: read for config parity but not applied below; only the upper bound is clipped + clip_ratio_c = config.get("clip_ratio_c", 3.0) + + # Same code as compute_policy_loss + assert clip_ratio_c > 1.0, ( + "The lower bound of the clip_ratio_c for dual-clip PPO should be greater than 1.0," + + f" but get the value: {clip_ratio_c}." + ) + + negative_approx_kl = log_prob - old_log_prob + # Clamp negative_approx_kl for stability + negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0) + ratio = torch.exp(negative_approx_kl) + ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask) + + # CISPO specific loss + ratio = ratio.detach() # Stop gradient on IS ratio + importance_sampling_weight = torch.clamp(ratio, max=1+cispo_clip_ratio_high) + pg_losses = -advantages * log_prob * importance_sampling_weight + + if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None: + # Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl + tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs) + tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap) + pg_losses = pg_losses * tis_imp_ratio + + pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode) + # For compatibility, return zero for pg_clipfrac_lower and pg_clipfrac (not used in CISPO) + pg_clipfrac = torch.tensor(0.0, device=pg_loss.device) + pg_clipfrac_lower = torch.tensor(0.0, device=pg_loss.device) + + return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower + + @register_policy_loss("geo_mean") +def compute_policy_loss_geo_mean( + old_log_prob: torch.Tensor, diff --git a/verl/utils/reward_score/math_llm_judge/__init__.py b/verl/utils/reward_score/math_llm_judge/__init__.py index 7d2b20cae..3d78e1c67 100644 --- a/verl/utils/reward_score/math_llm_judge/__init__.py +++ b/verl/utils/reward_score/math_llm_judge/__init__.py @@ -398,14 +398,15 @@ def llm_check_answer(model_output: str, ground_truth: str, question: str) -> boo # use llm to check if the answer is correct # url = "http://176.56.200.81:30000/v1/chat/completions" - url = os.getenv("MATH_LLM_JUDGE_URL") - if not url: + url_base = os.getenv("MATH_LLM_JUDGE_URL") + if not url_base: raise ValueError("MATH_LLM_JUDGE_URL is not set") + url = 
url_base.rstrip("/") + "/v1/chat/completions" prompt = input_template.format(QUESTION=question, STUDENT_ANSWER=model_output, REFERENCE_ANSWER=ground_truth) data = { - "model": "Qwen/Qwen2.5-32B-Instruct", + "model": "openai/gpt-oss-120b", "messages": [{"role": "user", "content": prompt}], } response = requests.post(url, json=data) @@ -423,7 +424,7 @@ def llm_check_answer(model_output: str, ground_truth: str, question: str) -> boo def compute_score(model_output: str, ground_truth: str, extra_info: dict) -> bool: - question = extra_info["question"] + question = extra_info["original_question"] model_output = str(model_output) ground_truth = str(ground_truth) @@ -447,5 +448,4 @@ def compute_score(model_output: str, if is_matched and not is_correct: # use llm to check if the answer is correct is_correct = llm_check_answer(extracted_model_output, ground_truth, question) - - return is_correct, 1, extracted_model_output + return is_correct, 1, extracted_model_output \ No newline at end of file diff --git a/verl/utils/reward_score/naive_dapo.py b/verl/utils/reward_score/naive_dapo.py index d26a1dd72..f819048a6 100644 --- a/verl/utils/reward_score/naive_dapo.py +++ b/verl/utils/reward_score/naive_dapo.py @@ -189,7 +189,7 @@ def _parse_latex(expr: str) -> str: expr = expr.replace("\\tfrac", "\\frac") expr = expr.replace("\\dfrac", "\\frac") expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. - expr = latex2text.LatexNodes2Text().latex_to_text(expr) + # expr = latex2text.LatexNodes2Text().latex_to_text(expr) # Replace the specific characters that this parser uses. expr = expr.replace("√", "sqrt") @@ -426,7 +426,8 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) is_correct = False else: - is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + # is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) + is_correct = False if not is_correct: break diff --git a/verl/utils/reward_score/prime_math/__init__.py b/verl/utils/reward_score/prime_math/__init__.py index 8d9d273e3..d0ea47ac9 100644 --- a/verl/utils/reward_score/prime_math/__init__.py +++ b/verl/utils/reward_score/prime_math/__init__.py @@ -55,7 +55,7 @@ def _parse_latex(expr: str) -> str: expr = expr.replace("\\tfrac", "\\frac") expr = expr.replace("\\dfrac", "\\frac") expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. - expr = latex2text.LatexNodes2Text().latex_to_text(expr) + # expr = latex2text.LatexNodes2Text().latex_to_text(expr) # Replace the specific characters that this parser uses. expr = expr.replace("√", "sqrt") diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py index af6199732..011d2ebc5 100644 --- a/verl/workers/config/actor.py +++ b/verl/workers/config/actor.py @@ -49,6 +49,8 @@ class PolicyLossConfig(BaseConfig): clip_cov_ub: float = 5.0 kl_cov_ratio: float = 0.0002 ppo_kl_coef: float = 0.1 + cispo_clip_ratio_high: float = 0.2 + cispo_clip_ratio_low: float = 0.2 @dataclass @@ -225,6 +227,7 @@ class FSDPActorConfig(ActorConfig): """ strategy: str = "fsdp" + dtype: str = "bfloat16" grad_clip: float = 1.0 ulysses_sequence_parallel_size: int = 1 entropy_from_logits_with_chunking: bool = False