#!/usr/bin/env python
"""Download, process, and save the OOD ifbench dataset.

Usage:
    python data_preprocess/ood/ifbench.py
"""
import argparse
import os
from typing import Any, Callable, Dict

import datasets
import transformers
from datasets import load_dataset
from tqdm import tqdm

from verl.utils.data_process.filter import LengthFilter
from verl.utils.data_process.utils import sample_dataset, save_dataset, set_seed

# Identity template: the user prompt replaces "{{context}}" verbatim.
PromptTemplate = """{{context}}"""


def get_datasets(cache_dir: str):
    """Load the source dataset from the Hugging Face Hub.

    Args:
        cache_dir: Directory handed to `load_dataset` as its cache location.

    Returns:
        A 2-tuple whose first slot is always None and whose second slot is the
        "train" split of `allenai/IF_multi_constraints_upto5`, or None when
        loading failed (the error is printed, not raised).
    """
    try:
        dataset = load_dataset("allenai/IF_multi_constraints_upto5", cache_dir=cache_dir)["train"]
    except Exception as e:  # best effort: the caller turns a None dataset into a hard error
        print(f"Error loading dataset: {e}")
        return None, None
    print(f"ifbench dataset: {len(dataset)} examples")
    return None, dataset


def make_map_fn(split: str, data_source: str) -> Callable[[Dict[str, Any], int], Dict[str, Any]]:
    """Build the per-example converter used with `Dataset.map(..., with_indices=True)`.

    Args:
        split: Split name; only used in the debug banner printed for the first rows.
        data_source: Value stored under the "data_source" key of every record.

    Returns:
        A function `(example, idx) -> dict` producing the record layout built below.
    """

    def process_fn(example, idx):
        # Extract the user prompt from the first message. A missing, None, or
        # empty `messages` field yields an empty prompt instead of raising.
        messages = example.get("messages") or [{}]
        prompt = (messages[0] or {}).get("content") or ""

        # Preserve original ground_truth (may be None; such rows are filtered out later).
        original_gt = example.get("ground_truth")

        data = {
            "data_source": data_source,
            "prompt": [
                {
                    "role": "user",
                    "content": PromptTemplate.replace("{{context}}", prompt),
                }
            ],
            "ability": "ood",
            "apply_chat_template": True,
            "reward_model": {
                "style": "rule",
                "ground_truth": original_gt,
            },
            "extra_info": None,
        }

        # Debug print for first two examples
        if idx < 2:
            print("\n" + "=" * 10 + f"{data_source} {split} {idx}" + "=" * 10)
            print(data)

        return data

    return process_fn


def main() -> None:
    """Parse CLI arguments, then download, convert, filter, sample, and save the dataset."""
    parser = argparse.ArgumentParser(
        description="Download, process, and save OOD ifbench dataset."
    )
    parser.add_argument(
        "--data-dir",
        default="data",
        help="Base directory to save the processed data files.",
    )
    parser.add_argument("--domain", default="ood", help="Domain of the dataset.")
    parser.add_argument("--name", default="ifbench", help="Name of the dataset.")
    parser.add_argument(
        "--sample-size",
        type=int,
        default=None,
        help="Number of samples to use from dataset. If None, use all samples.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )

    args = parser.parse_args()

    # Config
    set_seed(args.seed)
    data_source = f"{args.domain}__{args.name}"
    test_output_dir = os.path.join(args.data_dir, "test")

    # Download dataset
    cache_dir = datasets.config.HF_DATASETS_CACHE
    _, dataset = get_datasets(cache_dir)

    if dataset is None:
        raise RuntimeError("Failed to load ifbench dataset")

    # Process dataset
    process_fn = make_map_fn("test", data_source)
    dataset = dataset.map(function=process_fn, with_indices=True)

    # Drop examples without ground_truth. This runs outside the best-effort block
    # below so that a tokenizer failure cannot silently let unscorable rows through.
    dataset = dataset.filter(lambda x: x["reward_model"]["ground_truth"] is not None)

    # Length filtering is best effort: it needs the tokenizer to be downloadable.
    try:
        tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
        length_filter = LengthFilter(tokenizer=tokenizer, max_length=4096)
        dataset = dataset.filter(lambda x: length_filter.check(x))
    except Exception as e:
        print(f"Warning: Could not perform length filtering. Error: {e}")
        print("Proceeding without length filtering.")

    # Sample the dataset
    dataset = sample_dataset(dataset, args.sample_size)

    # Save dataset
    test_output_path = save_dataset(
        dataset=dataset,
        output_dir=test_output_dir,
        filename_prefix=data_source,
        sample_size=len(dataset),
    )

    print(
        f"\nDone!\n"
        f"Data source: {data_source}\n"
        f"Test data saved to {test_output_path} ({len(dataset)} samples)"
    )


if __name__ == "__main__":
    main()
export VLLM_USE_V1=0

# =================== Data Mixture ===================
# All dataset paths are resolved relative to SHARED_DATA_PATH; only the files
# listed in train_files / test_files below are actually used by this run.
SHARED_DATA_PATH=./data
TRAIN_DATA_DIR=${SHARED_DATA_PATH}/train
TEST_DATA_DIR=${SHARED_DATA_PATH}/test

# synlogic
synlogic_train_path=${TRAIN_DATA_DIR}/synlogic_object_counting_train.parquet
synlogic_test_path=${TEST_DATA_DIR}/synlogic_object_counting_test.parquet

# Math (train)
math_train_path=${TRAIN_DATA_DIR}/math__combined_54.4k.parquet
# Math (test)
math_test_path=${TEST_DATA_DIR}/math__math_500.parquet
aime_test_path=${TEST_DATA_DIR}/math__aime_repeated_8x_240.parquet
amc_test_path=${TEST_DATA_DIR}/math__amc_repeated_4x_332.parquet

# Code (train)
leetcode_train_path=${TRAIN_DATA_DIR}/codegen__leetcode2k_1.3k.parquet
livecodebench_train_path=${TRAIN_DATA_DIR}/codegen__livecodebench_440.parquet
primeintellect_train_path=${TRAIN_DATA_DIR}/codegen__primeintellect_7.5k.parquet
taco_train_path=${TRAIN_DATA_DIR}/codegen__taco_8.8k.parquet
# Code (test)
humaneval_test_path=${TEST_DATA_DIR}/codegen__humaneval_164.parquet
mbpp_test_path=${TEST_DATA_DIR}/codegen__mbpp_500.parquet
livecodebench_test_path=${TEST_DATA_DIR}/codegen__livecodebench_279.parquet

# Logic (train)
arcagi1_train_path=${TRAIN_DATA_DIR}/logic__arcagi1_111.parquet
arcagi2_train_path=${TRAIN_DATA_DIR}/logic__arcagi2_190.parquet
barc_train_path=${TRAIN_DATA_DIR}/logic__barc_1.6k.parquet
graph_train_path=${TRAIN_DATA_DIR}/logic__graph_logical_1.2k.parquet
ordering_train_path=${TRAIN_DATA_DIR}/logic__ordering_puzzle_1.9k.parquet
zebra_train_path=${TRAIN_DATA_DIR}/logic__zebra_puzzle_1.3k.parquet
# Logic (test)
zebralogic_test_path=${TEST_DATA_DIR}/logic__zebra_puzzle_dataset_300.parquet
ordering_puzzle_test_path=${TEST_DATA_DIR}/logic__ordering_puzzle_dataset_150.parquet

# Simulation (train)
codeio_train_path=${TRAIN_DATA_DIR}/simulation__codeio_3.7k.parquet
# Simulation (test)
codeio_test_path=${TEST_DATA_DIR}/simulation__codeio_500.parquet
arcagi1_test_path=${TEST_DATA_DIR}/simulation__arcagi1_200.parquet

# Table (train)
hitab_train_path=${TRAIN_DATA_DIR}/table__hitab_4.3k.parquet
multihier_train_path=${TRAIN_DATA_DIR}/table__multihier_1.5k.parquet
# Table (test)
multihier_test_path=${TEST_DATA_DIR}/table__multihier_300.parquet
hitab_test_path=${TEST_DATA_DIR}/table__hitab_300.parquet

# Stem (train)
webinstruct_train_path=${TRAIN_DATA_DIR}/stem__web_3.6k.parquet
# Stem (test)
gpqa_diamond_test_path=${TEST_DATA_DIR}/stem__gpqa_diamond_198.parquet
supergpqa_test_path=${TEST_DATA_DIR}/stem__supergpqa_200.parquet

train_files="['${synlogic_train_path}']"  # Use synlogic as example, add more tasks as needed
test_files="['${synlogic_test_path}']"  # Use synlogic as example, add more tasks as needed

# =================== Model ===================
# BASE_MODEL=Qwen/Qwen2.5-32B  # Note: This is the original Qwen32B-Base model. In training, we add 'think' system prompt to it (see README).
# BASE_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
BASE_MODEL=Qwen/Qwen2.5-7B-instruct
# BASE_MODEL=Qwen/Qwen3-4B-Thinking-2507
# =================== Logging ===================
WANDB_PROJECT=Reasoning360
WANDB_EXPERIMENT_NAME=${SLURM_JOB_ID}-${SLURM_JOB_NAME}-${BASE_MODEL##*/}

# If RESUME_CKPT_DIR_NAME is not empty, reuse it as the experiment name so the
# run resumes from that checkpoint directory.
if [[ -n "$RESUME_CKPT_DIR_NAME" ]]; then
    WANDB_EXPERIMENT_NAME="$RESUME_CKPT_DIR_NAME"
fi


# =================== Ray start ===================
# ray stop at all nodes
srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 ray stop

sleep 10
# Remove existing Ray cluster
srun --nodes=$worker_num --ntasks=$worker_num --ntasks-per-node=1 rm -rf /tmp/ray/ray_current_cluster

# Start Ray head node
srun --nodes=1 --ntasks=1 -w "$head_node" --export=ALL \
    ${CONDA_BIN_PATH}ray start --head --node-ip-address="$head_node_ip" --port=$port \
    --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --include-dashboard=True --block &

sleep 10

# Start Ray worker nodes (index 0 is the head node, so start from 1)
for ((i = 1; i < worker_num; i++)); do
    node_i=${nodes[$i]}
    echo "Starting WORKER $i at $node_i"
    srun --nodes=1 --ntasks=1 -w "$node_i" --export=ALL \
        ${CONDA_BIN_PATH}ray start --address "$address_head" \
        --num-cpus "${SLURM_CPUS_PER_TASK}" --num-gpus 8 --block &
done
sleep 10


# =================== RL Config ===================
# Note, we borrowed the config format from DAPO while here disabled all DAPO features to run the naive RL baseline.

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.2

max_prompt_length=$((1024 * 4))
max_response_length=$((1024 * 8))
enable_overlong_buffer=False
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"

enable_filter_groups=False
filter_groups_metric=acc
max_num_gen_batches=10
train_prompt_bsz=128  # on-policy model update batchsize: train_prompt_bsz * rollout.n
gen_prompt_bsz=$((train_prompt_bsz * 1))
n_resp_per_prompt=16
train_prompt_mini_bsz=16  # model grad update batchsize

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout

# Mathematically equivalent
sp_size=4
gen_tp=4
infer_micro_batch_size=null
train_micro_batch_size=null
use_dynamic_bsz=True
actor_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2))  # increase this to speed up model forward & backward but note memory overflow
infer_ppo_max_token_len=$(( (max_prompt_length + max_response_length) * 2))  # increase this to speed up model forward, but note memory overflow
offload=True

# =================== Start RL training ===================
# trainer.nnodes is passed exactly once, from the SLURM-derived worker count.
"${CONDA_BIN_PATH}python" -m verl.recipe.dapo.src.main_dapo \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.kl_ctrl.kl_coef=${kl_coef} \
    algorithm.filter_groups.enable=${enable_filter_groups} \
    algorithm.filter_groups.metric=${filter_groups_metric} \
    algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches} \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.prompt_key=prompt \
    data.truncation='right' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_prompt_bsz} \
    data.gen_batch_size=${gen_prompt_bsz} \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.clip_ratio_c=10.0 \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.actor.strategy="fsdp" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.optim.warmup_style=constant \
    actor_rollout_ref.actor.optim.min_lr_ratio=0. \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.ref.log_prob_micro_batch_size=${infer_micro_batch_size} \
    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=${infer_micro_batch_size} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=${infer_ppo_max_token_len} \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.top_p=${top_p} \
    actor_rollout_ref.rollout.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
    actor_rollout_ref.rollout.val_kwargs.top_p=${top_p} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.model.path=$BASE_MODEL \
    actor_rollout_ref.model.use_remove_padding=True \
    +actor_rollout_ref.rollout.multi_turn.enable=False \
    +actor_rollout_ref.rollout.mode="sync" \
    +actor_rollout_ref.model.override_config.attention_dropout=0. \
    +actor_rollout_ref.model.override_config.embd_pdrop=0. \
    +actor_rollout_ref.model.override_config.resid_pdrop=0. \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    reward_model.reward_manager=async_dapo \
    reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
    reward_model.overlong_buffer.len=${overlong_buffer_len} \
    reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
    trainer.logger=['console','wandb'] \
    trainer.project_name=${WANDB_PROJECT} \
    trainer.experiment_name=${WANDB_EXPERIMENT_NAME} \
    trainer.val_before_train=True \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=$worker_num \
    trainer.save_freq=5 \
    trainer.test_freq=5 \
    trainer.total_epochs=5 \
    +trainer.val_generations_to_log_to_wandb=30 \
    trainer.resume_mode=auto
import stem_llm_judge res = stem_llm_judge.compute_score(data_source=data_source, model_output=solution_str, ground_truth=ground_truth, extra_info=extra_info) + elif data_source in ["reasoning_gym"]: + from . import reasoning_gym + res = reasoning_gym.compute_score(solution_str, ground_truth, extra_info=extra_info) elif data_source in ["ood__ifeval"]: from . import ifeval res = ifeval.compute_score(solution_str, ground_truth, extra_info=extra_info) elif data_source in ["ood__livebench"]: from . import livebench res = livebench.compute_score(solution_str, ground_truth, extra_info=extra_info) + elif data_source in ["ood__ifbench"]: + from . import ifbench + res = ifbench.compute_score(solution_str, ground_truth, extra_info=extra_info) # NOTE: above is added by Reasoning360 elif data_source == "openai/gsm8k": from . import gsm8k - res = gsm8k.compute_score(solution_str, ground_truth) elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval"]: from . import math @@ -148,6 +156,24 @@ def default_compute_score(data_source, solution_str, ground_truth, extra_info=No from . 
def compute_score(solution_str, ground_truth, extra_info=None):
    """
    Compute the reward score for IFBench tasks based on ground truth constraints.

    Args:
        solution_str (str): Model's full output, may include a '<think>...</think>' section.
        ground_truth (str or list): Original ground_truth, either a Python-literal string
            (JSON is accepted as a fallback) or a list of dicts.
        extra_info (dict, optional): Ignored for IFBench since constraints are in ground_truth.

    Returns:
        dict: {"score": float, "acc": bool}. The score is 1.0 only when every evaluated
        instruction is satisfied; a ground truth that is not a non-empty list scores 0.0.

    Raises:
        json.JSONDecodeError: If a string ground_truth is neither a Python literal nor JSON.
        KeyError: If an instruction id is not present in INSTRUCTION_DICT.
    """
    # Strip off any thinking section: only the text after the first end-of-thinking
    # marker is graded. Without the marker the whole output is the answer.
    think_end = "</think>"
    _, marker, tail = solution_str.partition(think_end)
    answer = (tail if marker else solution_str).strip()

    # Parse ground_truth if it's a string
    if isinstance(ground_truth, str):
        try:
            gt_list = ast.literal_eval(ground_truth)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a Python literal (e.g. JSON `null`/`true`); fall back to JSON.
            gt_list = json.loads(ground_truth)
    else:
        gt_list = ground_truth

    # Take the first set of constraints
    if not isinstance(gt_list, list) or not gt_list:
        return {"score": 0.0, "acc": False}
    first_item = gt_list[0]
    instruction_ids = first_item.get("instruction_id", [])
    kwargs_list = first_item.get("kwargs", [])

    # Evaluate each instruction
    results = []
    for instr_id, raw_args in zip(instruction_ids, kwargs_list):
        # Prepare args dict
        args = {} if raw_args is None else raw_args
        # Convert numpy arrays to lists and floats to ints.
        # NOTE(review): presumably integer arguments arrive as floats after dataset
        # serialization; a genuinely fractional value would be truncated here — confirm.
        clean_args = {}
        for key, val in args.items():
            if isinstance(val, float):
                clean_args[key] = int(val)
            elif isinstance(val, np.ndarray):
                clean_args[key] = val.tolist()
            else:
                clean_args[key] = val

        # Build and check instruction
        instr_cls = INSTRUCTION_DICT[instr_id]
        instr = instr_cls(instr_id)
        instr.build_description(**clean_args)
        # An empty answer never satisfies an instruction.
        passed = bool(answer and instr.check_following(answer))
        results.append(passed)

    # Return 1.0 if all constraints are satisfied, 0.0 otherwise
    score = 1.0 if all(results) else 0.0
    return {"score": score, "acc": score == 1.0}
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Library of instructions.""" + + +import collections +import json +import random +import re +import string +from typing import Dict, Optional, Sequence, Union + +import langdetect +import logging + +from . import instructions_util + +logger = logging.getLogger(__name__) + +_InstructionArgsDtype = Optional[Dict[str, Union[int, str, Sequence[str]]]] + +_LANGUAGES = instructions_util.LANGUAGE_CODES + +# The relational operation for comparison. +_COMPARISON_RELATION = ("less than", "at least") + +# The maximum number of sentences. +_MAX_NUM_SENTENCES = 20 + +# The number of placeholders. +_NUM_PLACEHOLDERS = 4 + +# The number of bullet lists. +_NUM_BULLETS = 5 + +# The options of constrained response. +_CONSTRAINED_RESPONSE_OPTIONS = ("My answer is yes.", "My answer is no.", "My answer is maybe.") + +# The options of starter keywords. +_STARTER_OPTIONS = ( + "I would say", + "My answer is", + "I believe", + "In my opinion", + "I think", + "I reckon", + "I feel", + "From my perspective", + "As I see it", + "According to me", + "As far as I'm concerned", + "To my understanding", + "In my view", + "My take on it is", + "As per my perception", +) + +# The options of ending keywords. +# TODO(jeffreyzhou) add more ending options +_ENDING_OPTIONS = ("Any other questions?", "Is there anything else I can help with?") + +# The number of highlighted sections. +_NUM_HIGHLIGHTED_SECTIONS = 4 + +# The section spliter. +_SECTION_SPLITER = ("Section", "SECTION") + +# The number of sections. +_NUM_SECTIONS = 5 + +# The number of paragraphs. 
class Instruction:
    """Base template that every concrete instruction checker derives from.

    Subclasses are expected to override all four hooks below; the base versions
    only raise `NotImplementedError`.
    """

    def __init__(self, instruction_id):
        # Identifier of this instruction, stored as given.
        self.id = instruction_id

    def build_description(self, **kwargs):
        """Configure the instruction from keyword arguments (subclass hook)."""
        raise NotImplementedError("`build_description` not implemented.")

    def get_instruction_args(self):
        """Return the keyword arguments of `build_description` (subclass hook)."""
        raise NotImplementedError("`get_instruction_args` not implemented.")

    def get_instruction_args_keys(self):
        """Return the argument names of `build_description` (subclass hook)."""
        raise NotImplementedError("`get_instruction_args_keys` not implemented.")

    def check_following(self, value):
        """Decide whether the response `value` follows the instruction (subclass hook)."""
        raise NotImplementedError("`check_following` not implemented.")
class NumberOfSentences(Instruction):
    """Check the number of sentences."""

    def build_description(self, *, num_sentences=None, relation=None):
        """Build the instruction description.

        Args:
          num_sentences: An integer specifying the number of sentences as a
            threshold. When None or negative, a random threshold between 1 and
            `_MAX_NUM_SENTENCES` is drawn instead.
          relation: A string in (`less than`, `at least`), defining the relational
            operator for comparison. When None, one of the two is chosen at random.
            Two relational comparisons are supported for now:
            if 'less than', the actual number of sentences < the threshold;
            if 'at least', the actual number of sentences >= the threshold.

        Returns:
          A string representing the instruction description.

        Raises:
          ValueError: If `relation` is given but is not a supported relation.
        """
        # The number of sentences as a threshold for comparison.
        self._num_sentences_threshold = num_sentences
        if self._num_sentences_threshold is None or self._num_sentences_threshold < 0:
            self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES)

        if relation is None:
            self._comparison_relation = random.choice(_COMPARISON_RELATION)
        elif relation not in _COMPARISON_RELATION:
            raise ValueError(
                f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given."
            )
        else:
            self._comparison_relation = relation

        self._description_pattern = "Your response should contain {relation} {num_sentences} sentences."
        return self._description_pattern.format(
            relation=self._comparison_relation, num_sentences=self._num_sentences_threshold
        )

    def get_instruction_args(self):
        """Returns the keyword args of `build_description`."""
        return {"num_sentences": self._num_sentences_threshold, "relation": self._comparison_relation}

    def get_instruction_args_keys(self):
        """Returns the args keys of `build_description`."""
        return ["num_sentences", "relation"]

    def check_following(self, value):
        """Check if the number of sentences follows the instruction.

        Args:
          value: A string representing the response.

        Returns:
          True if the response follows the instruction; otherwise False.

        Raises:
          ValueError: If the stored relation is not one of the supported
            relations (`less than`, `at least`).
        """
        num_sentences = instructions_util.count_sentences(value)
        if self._comparison_relation == _COMPARISON_RELATION[0]:
            return num_sentences < self._num_sentences_threshold
        if self._comparison_relation == _COMPARISON_RELATION[1]:
            return num_sentences >= self._num_sentences_threshold  # pytype: disable=bad-return-type
        # Only reachable when the relation was set without going through
        # `build_description`; fail loudly instead of silently returning None.
        raise ValueError(
            f"The supported relation for comparison must be in {_COMPARISON_RELATION}, "
            f"but {self._comparison_relation} is given."
        )
class BulletListChecker(Instruction):
    """Verifies that a response contains an exact number of markdown bullets."""

    def build_description(self, *, num_bullets=None):
        """Configure the required bullet count and return the instruction text.

        Args:
          num_bullets: Exact number of bullet points the response must contain.
            When None or negative, a random count between 1 and `_NUM_BULLETS`
            is drawn instead.

        Returns:
          The instruction description as a string.
        """
        requested = num_bullets
        if requested is None or requested < 0:
            requested = random.randint(1, _NUM_BULLETS)
        self._num_bullets = requested

        self._description_pattern = (
            "Your answer must contain exactly {num_bullets} bullet points. "
            "Use the markdown bullet points such as:\n"
            "* This is point 1. \n"
            "* This is point 2"
        )
        return self._description_pattern.format(num_bullets=self._num_bullets)

    def get_instruction_args(self):
        """Return the keyword arguments that reproduce this instruction."""
        return {"num_bullets": self._num_bullets}

    def get_instruction_args_keys(self):
        """Return the names of the arguments accepted by `build_description`."""
        return ["num_bullets"]

    def check_following(self, value):
        r"""Tell whether the response has exactly the required number of bullets.

        Args:
          value: The response text. A bullet is a line that, after optional
            leading whitespace, starts with `\*` followed by a non-`\*`
            character, or starts with `-`.

        Returns:
          True if the number of bullet lines equals the configured count;
          otherwise False.
        """
        star_pattern = r"^\s*\*[^\*].*$"
        dash_pattern = r"^\s*-.*$"
        found = sum(
            len(re.findall(pattern, value, flags=re.MULTILINE))
            for pattern in (star_pattern, dash_pattern)
        )
        return found == self._num_bullets
class ConstrainedStartChecker(Instruction):
    """Checks that the response opens with a required starter phrase."""

    def build_description(self, *, starter=None):
        """Build the instruction description.

        Args:
          starter: A string representing the keyword or phrase that the
            response should start with. Surrounding whitespace is stripped.
            When None, a starter is chosen at random from `_STARTER_OPTIONS`.

        Returns:
          A string representing the instruction description.
        """
        self._starter = starter.strip() if isinstance(starter, str) else starter
        if self._starter is None:
            self._starter = random.choice(_STARTER_OPTIONS)
        self._description_pattern = (
            "During the conversation, when it is your turn, "
            "please always start with {starter}"
        )
        return self._description_pattern.format(starter=self._starter)

    def get_instruction_args(self):
        """Returns the keyword args of `build_description`."""
        return {"starter": self._starter}

    def get_instruction_args_keys(self):
        """Returns the args keys of `build_description`."""
        return ["starter"]

    def check_following(self, value):
        """Checks if the response starts with the constrained keyword or phrase.

        Args:
          value: A string representing the response.

        Returns:
          True if a line of the response begins (after optional whitespace)
          with the starter given in `instruction_args`; otherwise, False.
        """
        # The starter is literal text, not a pattern: escape it so characters
        # such as "[", "(", "?" or "." in the phrase are matched verbatim
        # instead of being interpreted as regex syntax.
        response_pattern = r"^\s*" + re.escape(self._starter) + r".*$"
        # NOTE(review): with re.MULTILINE, "^" anchors at the start of every
        # line, so a starter on any line satisfies the check. Kept as-is to
        # preserve existing scoring; confirm whether only the first line
        # should count.
        return bool(re.search(response_pattern, value, flags=re.MULTILINE))
class SectionChecker(Instruction):
    """Checks that the response is divided into enough marked sections."""

    def build_description(self, *, section_spliter=None, num_sections=None):
        """Build the instruction description.

        Args:
          section_spliter: A string with the keyword that marks a new section,
            e.g. `Section` or `SECTION`. Surrounding whitespace is stripped.
            When None, one is chosen at random from `_SECTION_SPLITER`.
          num_sections: An integer specifying the number of sections. When
            None or negative, a value is drawn uniformly from 1 to
            `_NUM_SECTIONS` (inclusive).

        Returns:
          A string representing the instruction description.
        """
        self._section_spliter = section_spliter.strip() if isinstance(section_spliter, str) else section_spliter
        if self._section_spliter is None:
            self._section_spliter = random.choice(_SECTION_SPLITER)

        self._num_sections = num_sections
        if self._num_sections is None or self._num_sections < 0:
            self._num_sections = random.randint(1, _NUM_SECTIONS)

        self._description_pattern = (
            "Your response must have {num_sections} sections. Mark the beginning "
            "of each section with {section_spliter} X, such as:\n"
            "{section_spliter} 1\n"
            "[content of section 1]\n"
            "{section_spliter} 2\n"
            "[content of section 2]"
        )

        return self._description_pattern.format(
            num_sections=self._num_sections, section_spliter=self._section_spliter
        )

    def get_instruction_args(self):
        """Returns the keyword args of `build_description`."""
        return {"section_spliter": self._section_spliter, "num_sections": self._num_sections}

    def get_instruction_args_keys(self):
        """Returns the args keys of `build_description`."""
        return ["section_spliter", "num_sections"]

    def check_following(self, value):
        """Checks the response contains enough sections.

        A section header is the splitter keyword followed by optional
        whitespace and a number, e.g. `Section 1`.

        Args:
          value: A string representing the response.

        Returns:
          True if the number of section headers found in the response is
          greater than or equal to the required number; otherwise, False.
        """
        # The splitter is a literal keyword; escape it so punctuation in it
        # (e.g. the dot in "Sec.") is not treated as a regex metacharacter.
        section_splitter_pattern = r"\s?" + re.escape(self._section_spliter) + r"\s?\d+\s?"
        # Splitting on N headers yields N + 1 pieces.
        num_sections = len(re.split(section_splitter_pattern, value)) - 1
        return num_sections >= self._num_sections
class RephraseChecker(Instruction):
    """Checks that a rephrase only alters the text between asterisks."""

    def build_description(self, *, original_message):
        """Build the instruction description.

        Args:
          original_message: A string representing the original message. The
            rephrased response should only change its words/sentences in
            between its two asterisks, for example, *change me*. Both original
            and rephrased messages should contain the changes in the form of
            *change me*.

        Returns:
          A string representing the instruction description.

        Raises:
          ValueError: If `original_message` contains no `*...*` span.
        """
        if not self.is_change(original_message):
            raise ValueError(f"Message {original_message} does not contain changes in the form of *change me*.")

        self._reference_without_change = original_message
        # Each fragment ends with a space: without it the concatenation read
        # "onlychange" and "asteriskssuch".
        self._description = (
            "Rephrasing: Your rephrased response should only "
            "change the words/sentences in between two asterisks "
            "such as *change me*."
        )
        return self._description

    def get_instruction_args(self):
        """Returns the keyword args of `build_description`."""
        return {"original_message": self._reference_without_change}

    def get_instruction_args_keys(self):
        """Returns the args keys of `build_description`."""
        return ["original_message"]

    def check_following(self, value):
        r"""Checks if the rephrasing follows the instruction.

        Args:
          value: A string representing the response, which is expected to
            rephrase the string of `instruction_args`.

        Returns:
          True if `value` and `instruction_args` only differ by the
          words/sentences in between two asterisks such as *change me*;
          otherwise, False.

        Raises:
          ValueError: If `value` contains no `*...*` span.
        """
        if not self.is_change(value):
            raise ValueError(f"value {value} does not contain changes in the form of *change me*.")

        response_without_changes = self.strip_changes(value)
        reference_without_changes = self.strip_changes(self._reference_without_change)

        return response_without_changes == reference_without_changes

    def is_change(self, response):
        """Check if there is change in the response in the form of *change me*.

        Returns the `re.search` result: a match object (truthy) when a
        `*...*` span exists on some line, otherwise None.
        """
        return re.search(r"\*.*\*", response)

    def strip_changes(self, response):
        """Strips off the changes.

        The pattern is greedy and "." does not cross newlines, so on each line
        everything from the first asterisk to the last one is removed.
        """
        return re.sub(r"\*.*\*", "", response)
+ + Args: + keyword: A string representing a keyword that is expected in the response. + frequency: An integer specifying the number of times `keyword` is expected + to appear in the response. + relation: A string in (`less than`, `at least`), defining the relational + operator for comparison. + Two relational comparisons are supported for now: + if 'less than', the actual number of occurrences < frequency; + if 'at least', the actual number of occurrences >= frequency. + + Returns: + A string representing the instruction description. + """ + if not keyword: + self._keyword = instructions_util.generate_keywords(num_keywords=1)[0] + else: + self._keyword = keyword.strip() + + self._frequency = frequency + if self._frequency is None or self._frequency < 0: + self._frequency = random.randint(1, _KEYWORD_FREQUENCY) + + if relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif relation not in _COMPARISON_RELATION: + raise ValueError( + f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." + ) + else: + self._comparison_relation = relation + + self._description_pattern = ( + "In your response, the word {keyword} should appear {relation} " + "{frequency} times." 
+ ) + + return self._description_pattern.format( + keyword=self._keyword, relation=self._comparison_relation, frequency=self._frequency + ) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"keyword": self._keyword, "frequency": self._frequency, "relation": self._comparison_relation} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["keyword", "frequency", "relation"] + + def check_following(self, value): + """Checks if the response contain the keyword with required frequency.""" + actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return actual_occurrences < self._frequency + elif self._comparison_relation == _COMPARISON_RELATION[1]: + return actual_occurrences >= self._frequency # pytype: disable=bad-return-type + + +class NumberOfWords(Instruction): + """Checks the number of words.""" + + def build_description(self, *, num_words=None, relation=None): + """Build the instruction description. + + Args: + num_words: An integer specifying the number of words contained in the + response. + relation: A string in (`less than`, `at least`), defining the relational + operator for comparison. + Two relational comparisons are supported for now: + if 'less than', the actual number of words < num_words; + if 'at least', the actual number of words >= num_words. + + Returns: + A string representing the instruction description. + """ + + self._num_words = num_words + if self._num_words is None or self._num_words < 0: + self._num_words = random.randint(_NUM_WORDS_LOWER_LIMIT, _NUM_WORDS_UPPER_LIMIT) + + if relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif relation not in _COMPARISON_RELATION: + raise ValueError( + f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." 
class ParagraphFirstWordCheck(Instruction):
    """Check the paragraph count and the first word of the nth paragraph."""

    def build_description(self, num_paragraphs=None, nth_paragraph=None, first_word=None):
        r"""Build the instruction description.

        Args:
          num_paragraphs: An integer indicating the number of paragraphs
            expected in the response. A paragraph is a subset of the string
            that is expected to be separated by '\n\n'. When None or negative,
            a value is drawn uniformly from 1 to `_NUM_PARAGRAPHS` (inclusive).
          nth_paragraph: An integer indicating the paragraph number that we
            look at. Note that n starts from 1. When None or outside
            1..num_paragraphs, a valid index is drawn at random.
          first_word: A string that represents the first word of the nth
            paragraph. When None, a keyword is generated. Stored lowercased.

        Returns:
          A string representing the instruction description.
        """
        self._num_paragraphs = num_paragraphs
        if self._num_paragraphs is None or self._num_paragraphs < 0:
            self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS)

        self._nth_paragraph = nth_paragraph
        if self._nth_paragraph is None or self._nth_paragraph <= 0 or self._nth_paragraph > self._num_paragraphs:
            # random.randint includes both endpoints, so the upper bound must be
            # num_paragraphs itself; "+ 1" could select a paragraph that does
            # not exist. max(1, ...) keeps the call valid if num_paragraphs is 0.
            self._nth_paragraph = random.randint(1, max(1, self._num_paragraphs))

        self._first_word = first_word
        if self._first_word is None:
            self._first_word = instructions_util.generate_keywords(num_keywords=1)[0]
        self._first_word = self._first_word.lower()

        self._description_pattern = (
            "There should be {num_paragraphs} paragraphs. "
            "Paragraphs and only paragraphs are separated with each other by two "
            "new lines as if it was '\\n\\n' in python. "
            "Paragraph {nth_paragraph} must start with word {first_word}."
        )

        return self._description_pattern.format(
            num_paragraphs=self._num_paragraphs, nth_paragraph=self._nth_paragraph, first_word=self._first_word
        )

    def get_instruction_args(self):
        """Returns the keyword args of `build_description`."""
        return {
            "num_paragraphs": self._num_paragraphs,
            "nth_paragraph": self._nth_paragraph,
            "first_word": self._first_word,
        }

    def get_instruction_args_keys(self):
        """Returns the args keys of `build_description`."""
        return ["num_paragraphs", "nth_paragraph", "first_word"]

    def check_following(self, value):
        """Checks for required number of paragraphs and correct first word.

        Args:
          value: a string representing the response. The response may contain
            paragraphs that are separated by two new lines and the first word
            of the nth paragraph will have to match a specified word.

        Returns:
          True if the number of non-empty paragraphs is the same as required
          and the first word of the specified paragraph is the same as
          required. Otherwise, False.
        """
        paragraphs = re.split(r"\n\n", value)
        # Whitespace-only chunks do not count as paragraphs.
        num_paragraphs = sum(1 for chunk in paragraphs if chunk.strip())

        # Guard against indexing past the paragraphs actually present.
        if self._nth_paragraph > num_paragraphs:
            return False
        paragraph = paragraphs[self._nth_paragraph - 1].strip()
        if not paragraph:
            return False

        punctuation = {".", ",", "?", "!", "'", '"'}

        # Take the first whitespace-delimited token, drop leading quotes, and
        # keep characters up to the first punctuation mark.
        word = paragraph.split()[0].strip()
        word = word.lstrip("'")
        word = word.lstrip('"')

        first_word = ""
        for letter in word:
            if letter in punctuation:
                break
            first_word += letter.lower()

        return num_paragraphs == self._num_paragraphs and first_word == self._first_word
class ForbiddenWords(Instruction):
    """Checks that specified words are not used in response."""

    def build_description(self, forbidden_words=None):
        """Build the instruction description.

        Args:
          forbidden_words: A sequence of strings representing words that are
            not allowed in the response. Duplicates are removed and the result
            is stored sorted. When empty or None, `_NUM_KEYWORDS` keywords are
            generated instead.

        Returns:
          A string representing the instruction description.
        """
        if not forbidden_words:
            self._forbidden_words = instructions_util.generate_keywords(num_keywords=_NUM_KEYWORDS)
        else:
            self._forbidden_words = list(set(forbidden_words))
        self._forbidden_words = sorted(self._forbidden_words)
        self._description_pattern = "Do not include keywords {forbidden_words} in the response."

        return self._description_pattern.format(forbidden_words=self._forbidden_words)

    def get_instruction_args(self):
        """Returns the keyword args of `build_description`."""
        return {"forbidden_words": self._forbidden_words}

    def get_instruction_args_keys(self):
        """Returns the args keys of `build_description`."""
        return ["forbidden_words"]

    def check_following(self, value):
        """Check that the response contains none of the forbidden words.

        Matching is case-insensitive and whole-word: an occurrence only counts
        when it is not directly preceded or followed by a word character.

        Args:
          value: A string representing the response.

        Returns:
          True if no forbidden word occurs in the response; otherwise, False.
        """
        for word in self._forbidden_words:
            # Escape the word so it is matched literally (an unescaped "c++"
            # is an invalid pattern and raises re.error). The lookarounds are
            # equivalent to \b...\b for ordinary words, and still behave as
            # whole-word anchors when the word starts or ends with punctuation.
            pattern = r"(?<!\w)" + re.escape(word) + r"(?!\w)"
            if re.search(pattern, value, flags=re.IGNORECASE):
                return False
        return True
+ ) + + return self._description.format(original_paragraph=original_paragraph, low=self._low, high=self._high) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"original_paragraph": self._original_paragraph, "low": self._low, "high": self._high} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["original_paragraph", "low", "high"] + + def check_following(self, value): + val_words = re.findall(r"\w+", value.lower()) + original_words = re.findall(r"\w+", self._original_paragraph.lower()) + similar_words = 0 + + dict_val = collections.Counter(val_words) + dict_original = collections.Counter(original_words) + + for word in dict_original: + similar_words += min(dict_original[word], dict_val[word]) + + return similar_words >= self._low and similar_words <= self._high + + +class TwoResponsesChecker(Instruction): + """Check that two responses were given.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Give two different responses. Responses and only responses should" + " be separated by 6 asterisk symbols: ******." + ) + return self._description_pattern + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response has two different answers. + + Args: + value: A string representing the response. + + Returns: + True if two responses are detected and false otherwise. 
class EndChecker(Instruction):
    """Checks that the response ends with a given phrase."""

    def build_description(self, *, end_phrase=None):
        """Build the instruction description.

        Args:
          end_phrase: A string representing the phrase the response should end
            with. Surrounding whitespace is stripped. When None, a phrase is
            chosen at random from `_ENDING_OPTIONS`.

        Returns:
          A string representing the instruction description.
        """
        self._end_phrase = end_phrase.strip() if isinstance(end_phrase, str) else end_phrase
        if self._end_phrase is None:
            self._end_phrase = random.choice(_ENDING_OPTIONS)
        self._description_pattern = (
            "Finish your response with this exact phrase {ender}. No other words should follow this phrase."
        )
        return self._description_pattern.format(ender=self._end_phrase)

    def get_instruction_args(self):
        """Returns the keyword args of `build_description`."""
        return {"end_phrase": self._end_phrase}

    def get_instruction_args_keys(self):
        """Returns the args keys of `build_description`."""
        return ["end_phrase"]

    def check_following(self, value):
        """Checks if the response ends with the expected phrase.

        The comparison is case-insensitive and ignores surrounding whitespace
        and double quotes wrapped around the response.

        Args:
          value: A string representing the response.

        Returns:
          True if the normalized response ends with the normalized end phrase;
          otherwise, False.
        """
        normalized_value = value.strip().strip('"').lower()
        # Normalize into a local: writing the lowercased phrase back to
        # self._end_phrase would change what get_instruction_args() reports
        # after the first check.
        expected_ending = self._end_phrase.strip().lower()
        return normalized_value.endswith(expected_ending)
+ let_relation: A string in (`less than`, `at least`), defining the + relational operator for comparison. Two relational comparisons are + supported for now; if 'less than', the actual number of + occurrences < frequency; if 'at least', the actual number of + occurrences >= frequency. + + Returns: + A string representing the instruction description. + """ + if not letter or len(letter) > 1 or ord(letter.lower()) < 97 or ord(letter.lower()) > 122: + self._letter = random.choice(list(string.ascii_letters)) + else: + self._letter = letter.strip() + self._letter = self._letter.lower() + + self._frequency = let_frequency + if self._frequency is None or self._frequency < 0: + self._frequency = random.randint(1, _LETTER_FREQUENCY) + + if let_relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif let_relation not in _COMPARISON_RELATION: + raise ValueError( + "The supported relation for comparison must be in " + f"{_COMPARISON_RELATION}, but {let_relation} is given." + ) + else: + self._comparison_relation = let_relation + + self._description_pattern = ( + "In your response, the letter {letter} should appear {let_relation} {let_frequency} times." 
+ ) + + return self._description_pattern.format( + letter=self._letter, let_frequency=self._frequency, let_relation=self._comparison_relation + ) + + def get_instruction_args(self): + """Returns the keyword args of build description.""" + return {"letter": self._letter, "let_frequency": self._frequency, "let_relation": self._comparison_relation} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["letter", "let_frequency", "let_relation"] + + def check_following(self, value): + """Checks that the response contains the letter at the right frequency.""" + value = value.lower() + letters = collections.Counter(value) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return letters[self._letter] < self._frequency + else: + return letters[self._letter] >= self._frequency + + +class CapitalLettersEnglishChecker(Instruction): + """Checks that the response is in english and is in all capital letters.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = "Your entire response should be in English, and in all capital letters." + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks that the response is in English and in all capital letters.""" + assert isinstance(value, str) + + try: + return value.isupper() and langdetect.detect(value) == "en" + except langdetect.LangDetectException as e: + # Count as instruction is followed. 
def check_following(self, value):
    """Check that the response is free of commas.

    Args:
        value: A string representing the response.

    Returns:
        True if no ``,`` character occurs in the response; otherwise False.
    """
    return "," not in value
def check_following(self, value):
    """Check how many all-capital tokens the response contains.

    The response is tokenized with ``instructions_util.nltk.word_tokenize``
    and a token is counted when ``str.isupper`` holds for it.

    Args:
        value: A string representing the response.

    Returns:
        When the configured relation equals the first entry of
        ``_COMPARISON_RELATION``, True if the count is strictly below
        ``self._frequency``; for any other relation, True if the count is
        at least ``self._frequency``.
    """
    # Original note: hyphenated words count as one word.
    tokens = instructions_util.nltk.word_tokenize(value)
    capital_count = sum(1 for token in tokens if token.isupper())

    if self._comparison_relation == _COMPARISON_RELATION[0]:
        return capital_count < self._frequency
    return capital_count >= self._frequency
"Wrap your entire response with double quotation marks." + return self._description_pattern + + def get_instruction_args(self): + """Returns the keyword args of build description.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response is wrapped with double quotation marks.""" + value = value.strip() + return len(value) > 1 and value[0] == '"' and value[-1] == '"' + + +class RepeatPhraseChecker(Instruction): + "Repeat the phrase {phrase} exactly {small_n} times, transforming it slightly each time by replacing only one word in the center of the phrase." + + def build_description(self, phrase=None, small_n=None): + """Build the instruction description. + + Args: + phrase: A string representing the phrase to be repeated. + N: An integer representing the number of times to repeat the phrase. + word_count: An integer representing the number of words in the phrase. + + Returns: + A string representing the instruction description. + """ + if not phrase: + self._phrase = random.choice(_PHRASES) + else: + self._phrase = phrase.strip() + if not small_n: + self._small_n = random.randint(2, 3) + else: + self._small_n = small_n + + self._description_pattern = "Repeat the phrase {phrase} exactly {small_n} times, transforming it slightly each time by replacing only one word in the center of the phrase." 
def check_following(self, value):
    """Check that the phrase is repeated N times with one word changed each time.

    Candidate repetitions are the spans of the response that start with the
    phrase's first word and end with its last word. The response passes only
    when there are exactly ``self._small_n`` such spans and every one of them
    has the same number of words as the reference phrase and differs from it
    in exactly one word.

    Args:
        value: A string representing the response.

    Returns:
        True if every repetition is a one-word variation of the phrase and
        the number of repetitions matches; otherwise False.
    """
    ref_words = self._phrase.split()
    if not ref_words:
        # Nothing to anchor on; an empty phrase cannot be repeated.
        return False

    # Escape the anchor words so punctuation such as "?" or "." in the phrase
    # is matched literally instead of being interpreted as regex syntax.
    pattern = rf"{re.escape(ref_words[0])} .*? {re.escape(ref_words[-1])}"
    found_phrases = re.findall(pattern, value)
    if len(found_phrases) != self._small_n:
        return False

    for found in found_phrases:
        words = found.split()
        if len(words) != len(ref_words):
            return False
        differences = sum(1 for got, want in zip(words, ref_words) if got != want)
        # Every repetition must change exactly one word: an unchanged copy
        # (0 differences) is as much a violation as a heavier rewrite.
        if differences != 1:
            return False
    return True
def check_following(self, value):
    """Check that the response is a copy of the stored prompt.

    The comparison ignores letter case and surrounding whitespace on both
    the response and ``self._prompt_to_repeat``.

    Args:
        value: A string representing the response.

    Returns:
        True if the normalized response equals the normalized prompt;
        otherwise False.
    """
    expected = self._prompt_to_repeat.strip().lower()
    return value.strip().lower() == expected
def check_following(self, value):
    """Check that the response is hyphen-joined and contains no spaces.

    Args:
        value: A string representing the response.

    Returns:
        True if the response contains at least one hyphen and no
        hyphen-separated segment has leading or trailing whitespace or an
        interior space; otherwise False.
    """
    # Without any hyphen nothing can have been joined.
    if "-" not in value:
        return False

    # Splitting on "-" always yields two or more segments at this point, so
    # only the whitespace rules remain to be verified for each segment.
    for segment in value.split("-"):
        if segment != segment.strip() or " " in segment:
            return False
    return True
def check_following(self, value):
    """Check that no word is followed by one starting with the next letter.

    Words are the whitespace-separated tokens of the response. For each
    adjacent pair the lowercased first characters are compared by code
    point; a pair fails when the second is exactly one above the first.

    Args:
        value: A string representing the response.

    Returns:
        True if no adjacent pair of words starts with consecutive
        characters; otherwise False. A pair whose lowercased initial is not
        a single character also yields False.
    """
    initials = [word[0].lower() for word in value.split()]
    for current, following in zip(initials, initials[1:]):
        # Lowercasing can expand one character into several; such initials
        # cannot be passed to ord() and are rejected.
        if len(current) != 1 or len(following) != 1:
            return False
        if ord(following) - ord(current) == 1:
            return False
    return True
def check_following(self, value):
    """Check that the keyword occurs exactly once in the response.

    The keyword is matched literally and case-insensitively as a substring.

    Args:
        value: A string representing the response.

    Returns:
        True if the keyword occurs exactly once; otherwise False.
    """
    # Escape the keyword so characters such as "+" or "." are treated as
    # plain text instead of regex operators (which could raise re.error or
    # match unrelated text).
    occurrences = re.findall(re.escape(self._keyword), value, flags=re.IGNORECASE)
    return len(occurrences) == 1
def check_following(self, value):
    """Check that the keyword occurs with the required frequency.

    The keyword is matched literally and case-insensitively as a substring.

    Args:
        value: A string representing the response.

    Returns:
        For the first entry of ``_COMPARISON_RELATION``, True if the number
        of occurrences is strictly below ``self._frequency``; for the second
        entry, True if it is at least ``self._frequency``. False for any
        other relation.
    """
    # Escape the keyword so regex metacharacters in it are matched literally.
    actual_occurrences = len(re.findall(re.escape(self._keyword), value, flags=re.IGNORECASE))

    if self._comparison_relation == _COMPARISON_RELATION[0]:
        return actual_occurrences < self._frequency
    if self._comparison_relation == _COMPARISON_RELATION[1]:
        return actual_occurrences >= self._frequency
    # An unrecognized relation cannot be satisfied; return an explicit bool
    # instead of falling through to None.
    return False
def check_following(self, value):
    """Check that the response does not contain the forbidden keyword.

    The keyword is matched literally and case-sensitively as a standalone
    token: it counts as present when it is not directly preceded or followed
    by a word character, so occurrences at the start or end of the response
    or next to punctuation are detected as well as space-delimited ones.

    Args:
        value: A string representing the response.

    Returns:
        True if the keyword does not occur as a standalone token;
        otherwise False.
    """
    if not self._keyword:
        # An empty keyword cannot be "included"; nothing to forbid.
        return True
    pattern = rf"(?<!\w){re.escape(self._keyword)}(?!\w)"
    return re.search(pattern, value) is None
def build_description(self, first_word=None):
    """Build the instruction description.

    Args:
        first_word: The word every sentence must start with. Either a string
            or a non-string sequence whose first item is that string. When
            falsy, a keyword is drawn from
            ``instructions_util.generate_keywords``.

    Returns:
        A string representing the instruction description.
    """
    if not first_word:
        self._first_word = instructions_util.generate_keywords(num_keywords=1)[0]
    elif isinstance(first_word, str):
        self._first_word = first_word.strip()
    else:
        # Non-string input: take its first item. This must be an assignment;
        # a comparison here would leave ``_first_word`` unset.
        self._first_word = first_word[0].strip()

    self._description_pattern = "The first word of each sentence should be the word {first_word}."

    return self._description_pattern.format(first_word=self._first_word)
def check_following(self, value):
    """Check that the response starts with the expected word.

    Args:
        value: A string representing the response.

    Returns:
        True if the first whitespace-separated token of the response equals
        ``self._first_word`` ignoring case; False otherwise, including when
        the response is empty or only whitespace.
    """
    tokens = value.split()
    if not tokens:
        return False
    return tokens[0].lower() == self._first_word.lower()
def check_following(self, value):
    """Check that the response ends with the expected word.

    The last whitespace-separated token of the response is stripped of
    punctuation and compared with ``self._last_word`` ignoring case.

    Args:
        value: A string representing the response.

    Returns:
        True if the last word of the response is the expected word;
        False otherwise, including when the response is empty or only
        whitespace.
    """
    tokens = value.split()
    if not tokens:
        # An empty response has no last word; fail instead of raising.
        return False
    # remove any punctuation from the last token
    last_word = re.sub(r"[^\w\s]", "", tokens[-1])
    return last_word.lower() == self._last_word.lower()
def check_following(self, value):
    """Check that the response repeats the stored request and nothing else.

    Both the response and ``self._prompt_to_repeat`` are stripped of
    surrounding whitespace and lowercased before being compared.

    Args:
        value: A string representing the response.

    Returns:
        True if the normalized response equals the normalized request;
        otherwise False.
    """
    expected = self._prompt_to_repeat.strip().lower()
    return value.strip().lower() == expected
def check_following(self, value):
    """Check that the response is N copies of the request joined by ``******``.

    The response is split on six asterisks; each piece is stripped and
    lowercased before being compared with the equally normalized
    ``self._prompt_to_repeat``.

    Args:
        value: A string representing the response.

    Returns:
        True if there are exactly ``self._N`` pieces and each one matches
        the request; otherwise False.
    """
    copies = value.split("******")
    if len(copies) != self._N:
        return False
    expected = self._prompt_to_repeat.strip().lower()
    return all(copy.strip().lower() == expected for copy in copies)
def check_following(self, value):
    """Check that the response has at most N all-lowercase words.

    A lowercase word is a maximal run of the letters ``a``-``z`` delimited
    by word boundaries on both sides.

    Args:
        value: A string representing the response.

    Returns:
        True if the number of such words does not exceed ``self._N``;
        otherwise False.
    """
    return len(re.findall(r"\b[a-z]+\b", value)) <= self._N
+ return self._description_pattern.format(N=self._N, relation=self._relation) + + def get_instruction_args(self): + return {"N": self._N, "relation": self._relation} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["N", "relation"] + + def check_following(self, value): + """Checks that the response does not contain lowercase words more than N times.""" + letters = re.findall(r"[a-zA-Z]", value) + if self._relation == "at least": + if len(letters) >= self._N: + return True + else: + return False + elif self._relation == "less than": + if len(letters) < self._N: + return True + else: + return False + + +class CountingCompositionChecker(Instruction): + "Write 3 paragraphs, delimited by the markdown divider: * * *, with exactly {n_sent} sentences each, with exactly {n_words} words in each sentence." + + def build_description(self, n_sent=None, n_words=None): + """Build the instruction description. + + Args: + n_sent: An integer representing the number of sentences in each paragraph. + n_words: An integer representing the number of words in each sentence. + + Returns: + A string representing the instruction description. + """ + if not n_sent: + self._n_sent = random.randint(2, 3) + else: + self._n_sent = n_sent + if not n_words: + self._n_words = random.randint(2, 3) + else: + self._n_words = n_words + self._description_pattern = "Write 3 paragraphs, delimited by the markdown divider: * * *, with exactly {n_sent} sentences each, with exactly {n_words} words in each sentence." + return self._description_pattern.format(n_sent=self._n_sent, n_words=self._n_words) + + def get_instruction_args(self): + return {"n_sent": self._n_sent, "n_words": self._n_words} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["n_sent", "n_words"] + + def check_following(self, value): + """Checks that the response contains the expected number of paragraphs, sentences, and words. 
+ + Args: + value: A string representing the response. + + Returns: + True if the response meets the requirements; otherwise, False. + """ + paragraphs = re.split(r"\s?\*\*\*\s?", value) + num_paragraphs = len(paragraphs) + + for index, paragraph in enumerate(paragraphs): + if not paragraph.strip(): + if index == 0 or index == len(paragraphs) - 1: + num_paragraphs -= 1 + else: + return False + + sentences = instructions_util.split_into_sentences(paragraph) + num_sentences = len(sentences) + + if num_sentences != self._n_sent: + return False + + for sentence in sentences: + words = instructions_util.nltk.word_tokenize(sentence) + num_words = len(words) + + if num_words != self._n_words: + return False + + return num_paragraphs == 3 + + +class CountUniqueChecker(Instruction): + "Only use unique words in your response, no word should be repeated!" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = "Only use unique words in your response, no word should be repeated!" + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks that the response contains unique words.""" + words = instructions_util.nltk.word_tokenize(value) + unique_words = set(words) + return len(words) == len(unique_words) + + +class CountIncrementWordChecker(Instruction): + "Include keyword {keyword1} once in your response, keyword {keyword2} twice in your response." + + def build_description(self, keyword1=None, keyword2=None): + """Build the instruction description. + + Args: + keyword1: A string representing a keyword that is expected in the response. + keyword2: A string representing a keyword that is expected in the response. + + Returns: + A string representing the instruction description. 
+ """ + if not keyword1: + self._keyword1 = instructions_util.generate_keywords(num_keywords=1) + else: + self._keyword1 = keyword1.strip() + if not keyword2: + self._keyword2 = instructions_util.generate_keywords(num_keywords=1) + else: + self._keyword2 = keyword2.strip() + + self._description_pattern = ( + "Include keyword {keyword1} once in your response, keyword {keyword2} twice in your response." + ) + + return self._description_pattern.format(keyword1=self._keyword1, keyword2=self._keyword2) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"keyword1": self._keyword1, "keyword2": self._keyword2} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["keyword1", "keyword2"] + + def check_following(self, value): + """Checks if the response contains the expected number of keywords. + + Args: + value: A string representing the response. + + Returns: + True if the response contains the expected number of keywords; + otherwise, False. + """ + actual_occurrences1 = len(re.findall(self._keyword1, value, flags=re.IGNORECASE)) + actual_occurrences2 = len(re.findall(self._keyword2, value, flags=re.IGNORECASE)) + + if actual_occurrences1 == 1 and actual_occurrences2 == 2: + return True + else: + return False + + +class PalindromeBasicChecker(Instruction): + "Include a palindrome in your response." + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = "Include a palindrome in your response." + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response contains a palindrome. + + Args: + value: A string representing the response. + + Returns: + True if the response contains a palindrome; otherwise, False. 
+ """ + palindromes = [word for word in value.split() if word == word[::-1]] + return len(palindromes) > 0 + + +class KeywordSpecificPositionChecker(Instruction): + "Include keyword {keyword1} in the {n}-th sentence, as the {m}-th word of that sentence." + + def build_description(self, keyword=None, n=None, m=None): + """Build the instruction description. + + Args: + keyword: A string representing a keyword that is expected in the response. + n: An integer representing the sentence number. + m: An integer representing the word number. + + Returns: + A string representing the instruction description. + """ + if not keyword: + self._keyword = instructions_util.generate_keywords(num_keywords=1)[0] + else: + if not isinstance(keyword, str): + self._keyword = keyword[0].strip() + else: + self._keyword = keyword.strip() + if not n: + self._n = random.randint(1, 20) + else: + self._n = n + if not m: + self._m = random.randint(1, 30) + else: + self._m = m + + self._description_pattern = ( + "Include keyword {keyword} in the {n}-th sentence, as the {m}-th word of that sentence." + ) + + return self._description_pattern.format(keyword=self._keyword, n=self._n, m=self._m) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"keyword": self._keyword, "n": self._n, "m": self._m} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["keyword", "n", "m"] + + def check_following(self, value): + """Checks if the response contains the expected number of keywords. + + Args: + value: A string representing the response. + + Returns: + True if the response contains the expected number of keywords; + otherwise, False. 
+ """ + sentences = instructions_util.split_into_sentences(value) + if len(sentences) < self._n: + return False + words = instructions_util.nltk.word_tokenize(sentences[self._n - 1]) + if len(words) < self._m: + return False + if words[self._m - 1] == self._keyword: + return True + else: + return False + + +class StartEndChecker(Instruction): + "Start and end your response with the same word (do not write anything after the last word, not even punctuation)." + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = "Start and end your response with the same word (do not write anything after the last word, not even punctuation)." + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response starts and ends with the same word. + + Args: + value: A string representing the response. + + Returns: + True if the response starts and ends with the same word; + otherwise, False. + """ + # 对于连字符格式,使用连字符分割 + if "-" in value: + words = value.split("-") + else: + # 对于普通格式,使用NLTK分词 + words = instructions_util.nltk.word_tokenize(value) + + if len(words) < 2: + return False + if words[0].lower() == words[-1].lower(): + return True + else: + return False \ No newline at end of file diff --git a/verl/utils/reward_score/ifbench/instructions_registry.py b/verl/utils/reward_score/ifbench/instructions_registry.py new file mode 100644 index 000000000..6a5322ebc --- /dev/null +++ b/verl/utils/reward_score/ifbench/instructions_registry.py @@ -0,0 +1,313 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registry of all instructions."""

from verl.utils.reward_score.ifbench import instructions

# Instruction-id group prefixes. A full instruction id is "<prefix><name>",
# e.g. "keywords:existence".
_PARAGRAPH = "paragraphs:"

_KEYWORD = "keywords:"

_LETTER = "letters:"

_LANGUAGE = "language:"

_LENGTH = "length_constraints:"

_CONTENT = "detectable_content:"

_FORMAT = "detectable_format:"

_MULTITURN = "multi-turn:"

_COMBINATION = "combination:"

_STARTEND = "startend:"

_CHANGE_CASES = "change_case:"

_PUNCTUATION = "punctuation:"

_NEW = "new:"

_COPY = "copy:"

_BASIC = "basic:"

_FIRSTWORD = "first_word:"

_LASTWORD = "last_word:"

_COUNT = "count:"


# Maps every instruction id to the checker class that implements it.
FUNCTION_DICT = {
    # IFEval Constraints
    _KEYWORD + "existence": instructions.KeywordChecker,
    _KEYWORD + "frequency": instructions.KeywordFrequencyChecker,
    # TODO(jeffreyzhou): make a proper set of sentences to choose from
    # _KEYWORD + "key_sentences": instructions.KeySentenceChecker,
    _KEYWORD + "forbidden_words": instructions.ForbiddenWords,
    _KEYWORD + "letter_frequency": instructions.LetterFrequencyChecker,
    _LANGUAGE + "response_language": instructions.ResponseLanguageChecker,
    _LENGTH + "number_sentences": instructions.NumberOfSentences,
    _LENGTH + "number_paragraphs": instructions.ParagraphChecker,
    _LENGTH + "number_words": instructions.NumberOfWords,
    _LENGTH + "nth_paragraph_first_word": instructions.ParagraphFirstWordCheck,
    _CONTENT + "number_placeholders": instructions.PlaceholderChecker,
    _CONTENT + "postscript": instructions.PostscriptChecker,
    _FORMAT + "number_bullet_lists": instructions.BulletListChecker,
    # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace
    # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph,
    _FORMAT + "constrained_response": instructions.ConstrainedResponseChecker,
    _FORMAT + "number_highlighted_sections": (instructions.HighlightSectionChecker),
    _FORMAT + "multiple_sections": instructions.SectionChecker,
    # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
    # _FORMAT + "rephrase": instructions.RephraseChecker,
    _FORMAT + "json_format": instructions.JsonFormat,
    _FORMAT + "title": instructions.TitleChecker,
    # TODO(tianjianlu): Re-enable with specific prompts.
    # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker,
    _COMBINATION + "two_responses": instructions.TwoResponsesChecker,
    _COMBINATION + "repeat_prompt": instructions.RepeatPromptThenAnswer,
    _STARTEND + "end_checker": instructions.EndChecker,
    _CHANGE_CASES + "capital_word_frequency": instructions.CapitalWordFrequencyChecker,
    _CHANGE_CASES + "english_capital": instructions.CapitalLettersEnglishChecker,
    _CHANGE_CASES + "english_lowercase": instructions.LowercaseLettersEnglishChecker,
    _PUNCTUATION + "no_comma": instructions.CommaChecker,
    _STARTEND + "quotation": instructions.QuotationChecker,
    # New Constraints!
    _COPY + "repeat_phrase": instructions.RepeatPhraseChecker,
    _COPY + "copy": instructions.CopyChecker,
    _NEW + "copy_span_idx": instructions.CopySpanIdxChecker,
    _FORMAT + "sentence_hyphens": instructions.SentenceHyphenChecker,
    _KEYWORD + "no_adjacent_consecutive": instructions.AdjacentLetterChecker,
    _FORMAT + "square_brackets": instructions.SquareBracketChecker,
    _KEYWORD + "word_once": instructions.KeywordFrequencyOnceChecker,
    _KEYWORD + "word_count_different_numbers": instructions.KeywordFrequencyCheckerDifferent,
    _KEYWORD + "exclude_word_harder": instructions.ExcludeWordHarderChecker,
    _PARAGRAPH + "paragraphs": instructions.ParagraphBasicChecker,
    _PARAGRAPH + "paragraphs2": instructions.ParagraphBasicChecker2,
    _FIRSTWORD + "first_word_sent": instructions.FirstWordSentChecker,
    _FIRSTWORD + "first_word_answer": instructions.FirstWordAnswerChecker,
    _LASTWORD + "last_word_sent": instructions.LastWordSentChecker,
    _LASTWORD + "last_word_answer": instructions.LastWordAnswerChecker,
    _FORMAT + "bigram_wrapping": instructions.BiGramWrappingChecker,
    _COPY + "copying_simple": instructions.CopyingSimpleChecker,
    _COPY + "copying_multiple": instructions.CopyingMultipleChecker,
    _PUNCTUATION + "punctuation_dot": instructions.PunctuationDotChecker,
    _PUNCTUATION + "punctuation_exclamation": instructions.PunctuationExclamationChecker,
    _COUNT + "lowercase_counting": instructions.LowercaseCountingChecker,
    _LETTER + "letter_counting": instructions.LetterCountingChecker,
    _LETTER + "letter_counting2": instructions.LetterFrequencyChecker,
    _COUNT + "counting_composition": instructions.CountingCompositionChecker,
    _COUNT + "count_unique": instructions.CountUniqueChecker,
    _COUNT + "count_increment_word": instructions.CountIncrementWordChecker,
    _KEYWORD + "palindrome": instructions.PalindromeBasicChecker,
    _KEYWORD + "keyword_specific_position": instructions.KeywordSpecificPositionChecker,
    _KEYWORD + "start_end": instructions.StartEndChecker,
}

# Same id -> checker mapping under its second public name. It used to be a
# second, entry-for-entry identical literal; building it from FUNCTION_DICT
# keeps the two from drifting apart while still giving callers a distinct
# dict object.
INSTRUCTION_DICT = dict(FUNCTION_DICT)

# For each instruction id, the set of ids it conflicts with. Every value is
# its own set object (never shared), because `conflict_make` mutates the sets
# in place.
INSTRUCTION_CONFLICTS = {
    _KEYWORD + "existence": {_KEYWORD + "existence"},
    _KEYWORD + "frequency": {_KEYWORD + "frequency"},
    # TODO(jeffreyzhou): make a proper set of sentences to choose from
    # _KEYWORD + "key_sentences": instructions.KeySentenceChecker,
    _KEYWORD + "forbidden_words": {_KEYWORD + "forbidden_words"},
    _KEYWORD + "letter_frequency": {_KEYWORD + "letter_frequency"},
    _LANGUAGE + "response_language": {
        _LANGUAGE + "response_language",
        _FORMAT + "multiple_sections",
        _KEYWORD + "existence",
        _KEYWORD + "frequency",
        _KEYWORD + "forbidden_words",
        _STARTEND + "end_checker",
        _CHANGE_CASES + "english_capital",
        _CHANGE_CASES + "english_lowercase",
    },
    _LENGTH + "number_sentences": {_LENGTH + "number_sentences"},
    _LENGTH + "number_paragraphs": {
        _LENGTH + "number_paragraphs",
        _LENGTH + "nth_paragraph_first_word",
        _LENGTH + "number_sentences",
    },
    _LENGTH + "number_words": {_LENGTH + "number_words"},
    _LENGTH + "nth_paragraph_first_word": {_LENGTH + "nth_paragraph_first_word", _LENGTH + "number_paragraphs"},
    _CONTENT + "number_placeholders": {_CONTENT + "number_placeholders"},
    _CONTENT + "postscript": {_CONTENT + "postscript"},
    _FORMAT + "number_bullet_lists": {_FORMAT + "number_bullet_lists"},
    # TODO(jeffreyzhou): Pre-create paragraph or use prompt to replace
    # _CONTENT + "rephrase_paragraph": instructions.RephraseParagraph,
    _FORMAT + "constrained_response": set(INSTRUCTION_DICT.keys()),
    _FORMAT + "number_highlighted_sections": {_FORMAT + "number_highlighted_sections"},
    _FORMAT + "multiple_sections": {
        _FORMAT + "multiple_sections",
        _LANGUAGE + "response_language",
        _FORMAT + "number_highlighted_sections",
    },
    # TODO(tianjianlu): Re-enable rephrasing with preprocessing the message.
    # _FORMAT + "rephrase": instructions.RephraseChecker,
    _FORMAT + "json_format": set(INSTRUCTION_DICT.keys()).difference(
        {_KEYWORD + "forbidden_words", _KEYWORD + "existence"}
    ),
    _FORMAT + "title": {_FORMAT + "title"},
    # TODO(tianjianlu): Re-enable with specific prompts.
    # _MULTITURN + "constrained_start": instructions.ConstrainedStartChecker,
    _COMBINATION + "two_responses": set(INSTRUCTION_DICT.keys()).difference(
        {
            _KEYWORD + "forbidden_words",
            _KEYWORD + "existence",
            _LANGUAGE + "response_language",
            _FORMAT + "title",
            _PUNCTUATION + "no_comma",
        }
    ),
    _COMBINATION + "repeat_prompt": set(INSTRUCTION_DICT.keys()).difference(
        {_KEYWORD + "existence", _FORMAT + "title", _PUNCTUATION + "no_comma"}
    ),
    _STARTEND + "end_checker": {_STARTEND + "end_checker"},
    _CHANGE_CASES + "capital_word_frequency": {
        _CHANGE_CASES + "capital_word_frequency",
        _CHANGE_CASES + "english_lowercase",
        _CHANGE_CASES + "english_capital",
    },
    _CHANGE_CASES + "english_capital": {_CHANGE_CASES + "english_capital"},
    _CHANGE_CASES + "english_lowercase": {_CHANGE_CASES + "english_lowercase", _CHANGE_CASES + "english_capital"},
    _PUNCTUATION + "no_comma": {_PUNCTUATION + "no_comma"},
    _STARTEND + "quotation": {_STARTEND + "quotation", _FORMAT + "title"},
    _COPY + "repeat_phrase": {_COPY + "repeat_phrase"},
    _COPY + "copy": set(INSTRUCTION_DICT.keys()),
    _NEW + "copy_span_idx": set(INSTRUCTION_DICT.keys()),
    _FORMAT + "sentence_hyphens": {_FORMAT + "sentence_hyphens"},
    _KEYWORD + "no_adjacent_consecutive": {_KEYWORD + "no_adjacent_consecutive"},
    _FORMAT + "square_brackets": {_FORMAT + "square_brackets"},
    _KEYWORD + "word_once": {_KEYWORD + "word_once"},
    _KEYWORD + "word_count_different_numbers": {_KEYWORD + "word_count_different_numbers"},
    _KEYWORD + "exclude_word_harder": {_KEYWORD + "exclude_word_harder"},
    _PARAGRAPH + "paragraphs": {_PARAGRAPH + "paragraphs", _PARAGRAPH + "paragraphs2"},
    _PARAGRAPH + "paragraphs2": {_PARAGRAPH + "paragraphs", _PARAGRAPH + "paragraphs2"},
    _FIRSTWORD + "first_word_sent": {_FIRSTWORD + "first_word_sent", _FIRSTWORD + "first_word_answer"},
    _FIRSTWORD + "first_word_answer": {_FIRSTWORD + "first_word_sent", _FIRSTWORD + "first_word_answer"},
    _LASTWORD + "last_word_sent": {_LASTWORD + "last_word_sent"},
    _LASTWORD + "last_word_answer": {_LASTWORD + "last_word_answer"},
    _FORMAT + "bigram_wrapping": {_FORMAT + "bigram_wrapping"},
    _COPY + "copying_simple": set(INSTRUCTION_DICT.keys()),
    _COPY + "copying_multiple": set(INSTRUCTION_DICT.keys()),
    _PUNCTUATION + "punctuation_dot": {_PUNCTUATION + "punctuation_dot"},
    _PUNCTUATION + "punctuation_exclamation": {_PUNCTUATION + "punctuation_exclamation"},
    _COUNT + "lowercase_counting": {_COUNT + "lowercase_counting"},
    _LETTER + "letter_counting": {_LETTER + "letter_counting"},
    _LETTER + "letter_counting2": {_LETTER + "letter_counting2"},
    _COUNT + "counting_composition": {
        _COUNT + "counting_composition",
        _COUNT + "count_unique",
        _COUNT + "count_increment_word",
        _PARAGRAPH + "paragraphs",
        _PARAGRAPH + "paragraphs2",
        _KEYWORD + "letter_frequency",
        _KEYWORD + "frequency",
    },
    _COUNT + "count_unique": {_COUNT + "count_unique"},
    _COUNT + "count_increment_word": {_COUNT + "count_increment_word"},
    _KEYWORD + "palindrome": {_KEYWORD + "palindrome"},
    _KEYWORD + "keyword_specific_position": {_KEYWORD + "keyword_specific_position"},
    _KEYWORD + "start_end": {_KEYWORD + "start_end"},
}
"start_end"}, +} + + +def conflict_make(conflicts): + """Makes sure if A conflicts with B, B will conflict with A. + + Args: + conflicts: Dictionary of potential conflicts where key is instruction id + and value is set of instruction ids that it conflicts with. + + Returns: + Revised version of the dictionary. All instructions conflict with + themselves. If A conflicts with B, B will conflict with A. + """ + for key in conflicts: + for k in conflicts[key]: + conflicts[k].add(key) + conflicts[key].add(key) + return conflicts \ No newline at end of file diff --git a/verl/utils/reward_score/ifbench/instructions_util.py b/verl/utils/reward_score/ifbench/instructions_util.py new file mode 100644 index 000000000..6ae323585 --- /dev/null +++ b/verl/utils/reward_score/ifbench/instructions_util.py @@ -0,0 +1,1672 @@ +# coding=utf-8 +# Copyright 2024 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility library of instructions.""" + + +import functools +import os +import random +import re +from importlib.metadata import version + +import immutabledict +import nltk +from packaging.version import parse as parse_version + + + +WORD_LIST = [ + "western", + "sentence", + "signal", + "dump", + "spot", + "opposite", + "bottom", + "potato", + "administration", + "working", + "welcome", + "morning", + "good", + "agency", + "primary", + "wish", + "responsibility", + "press", + "problem", + "president", + "steal", + "brush", + "read", + "type", + "beat", + "trainer", + "growth", + "lock", + "bone", + "case", + "equal", + "comfortable", + "region", + "replacement", + "performance", + "mate", + "walk", + "medicine", + "film", + "thing", + "rock", + "tap", + "total", + "competition", + "ease", + "south", + "establishment", + "gather", + "parking", + "world", + "plenty", + "breath", + "claim", + "alcohol", + "trade", + "dear", + "highlight", + "street", + "matter", + "decision", + "mess", + "agreement", + "studio", + "coach", + "assist", + "brain", + "wing", + "style", + "private", + "top", + "brown", + "leg", + "buy", + "procedure", + "method", + "speed", + "high", + "company", + "valuable", + "pie", + "analyst", + "session", + "pattern", + "district", + "pleasure", + "dinner", + "swimming", + "joke", + "order", + "plate", + "department", + "motor", + "cell", + "spend", + "cabinet", + "difference", + "power", + "examination", + "engine", + "horse", + "dimension", + "pay", + "toe", + "curve", + "literature", + "bother", + "fire", + "possibility", + "debate", + "activity", + "passage", + "hello", + "cycle", + "background", + "quiet", + "author", + "effect", + "actor", + "page", + "bicycle", + "error", + "throat", + "attack", + "character", + "phone", + "tea", + "increase", + "outcome", + "file", + "specific", + "inspector", + "internal", + "potential", + "staff", + "building", + "employer", + "shoe", + "hand", + "direction", + "garden", + "purchase", + "interview", 
+ "study", + "recognition", + "member", + "spiritual", + "oven", + "sandwich", + "weird", + "passenger", + "particular", + "response", + "reaction", + "size", + "variation", + "a", + "cancel", + "candy", + "exit", + "guest", + "condition", + "fly", + "price", + "weakness", + "convert", + "hotel", + "great", + "mouth", + "mind", + "song", + "sugar", + "suspect", + "telephone", + "ear", + "roof", + "paint", + "refrigerator", + "organization", + "jury", + "reward", + "engineering", + "day", + "possession", + "crew", + "bar", + "road", + "description", + "celebration", + "score", + "mark", + "letter", + "shower", + "suggestion", + "sir", + "luck", + "national", + "progress", + "hall", + "stroke", + "theory", + "offer", + "story", + "tax", + "definition", + "history", + "ride", + "medium", + "opening", + "glass", + "elevator", + "stomach", + "question", + "ability", + "leading", + "village", + "computer", + "city", + "grand", + "confidence", + "candle", + "priest", + "recommendation", + "point", + "necessary", + "body", + "desk", + "secret", + "horror", + "noise", + "culture", + "warning", + "water", + "round", + "diet", + "flower", + "bus", + "tough", + "permission", + "week", + "prompt", + "connection", + "abuse", + "height", + "save", + "corner", + "border", + "stress", + "drive", + "stop", + "rip", + "meal", + "listen", + "confusion", + "girlfriend", + "living", + "relation", + "significance", + "plan", + "creative", + "atmosphere", + "blame", + "invite", + "housing", + "paper", + "drink", + "roll", + "silver", + "drunk", + "age", + "damage", + "smoke", + "environment", + "pack", + "savings", + "influence", + "tourist", + "rain", + "post", + "sign", + "grandmother", + "run", + "profit", + "push", + "clerk", + "final", + "wine", + "swim", + "pause", + "stuff", + "singer", + "funeral", + "average", + "source", + "scene", + "tradition", + "personal", + "snow", + "nobody", + "distance", + "sort", + "sensitive", + "animal", + "major", + "negotiation", + "click", + 
"mood", + "period", + "arrival", + "expression", + "holiday", + "repeat", + "dust", + "closet", + "gold", + "bad", + "sail", + "combination", + "clothes", + "emphasis", + "duty", + "black", + "step", + "school", + "jump", + "document", + "professional", + "lip", + "chemical", + "front", + "wake", + "while", + "inside", + "watch", + "row", + "subject", + "penalty", + "balance", + "possible", + "adult", + "aside", + "sample", + "appeal", + "wedding", + "depth", + "king", + "award", + "wife", + "blow", + "site", + "camp", + "music", + "safe", + "gift", + "fault", + "guess", + "act", + "shame", + "drama", + "capital", + "exam", + "stupid", + "record", + "sound", + "swing", + "novel", + "minimum", + "ratio", + "machine", + "shape", + "lead", + "operation", + "salary", + "cloud", + "affair", + "hit", + "chapter", + "stage", + "quantity", + "access", + "army", + "chain", + "traffic", + "kick", + "analysis", + "airport", + "time", + "vacation", + "philosophy", + "ball", + "chest", + "thanks", + "place", + "mountain", + "advertising", + "red", + "past", + "rent", + "return", + "tour", + "house", + "construction", + "net", + "native", + "war", + "figure", + "fee", + "spray", + "user", + "dirt", + "shot", + "task", + "stick", + "friend", + "software", + "promotion", + "interaction", + "surround", + "block", + "purpose", + "practice", + "conflict", + "routine", + "requirement", + "bonus", + "hole", + "state", + "junior", + "sweet", + "catch", + "tear", + "fold", + "wall", + "editor", + "life", + "position", + "pound", + "respect", + "bathroom", + "coat", + "script", + "job", + "teach", + "birth", + "view", + "resolve", + "theme", + "employee", + "doubt", + "market", + "education", + "serve", + "recover", + "tone", + "harm", + "miss", + "union", + "understanding", + "cow", + "river", + "association", + "concept", + "training", + "recipe", + "relationship", + "reserve", + "depression", + "proof", + "hair", + "revenue", + "independent", + "lift", + "assignment", + "temporary", + 
"amount", + "loss", + "edge", + "track", + "check", + "rope", + "estimate", + "pollution", + "stable", + "message", + "delivery", + "perspective", + "mirror", + "assistant", + "representative", + "witness", + "nature", + "judge", + "fruit", + "tip", + "devil", + "town", + "emergency", + "upper", + "drop", + "stay", + "human", + "neck", + "speaker", + "network", + "sing", + "resist", + "league", + "trip", + "signature", + "lawyer", + "importance", + "gas", + "choice", + "engineer", + "success", + "part", + "external", + "worker", + "simple", + "quarter", + "student", + "heart", + "pass", + "spite", + "shift", + "rough", + "lady", + "grass", + "community", + "garage", + "youth", + "standard", + "skirt", + "promise", + "blind", + "television", + "disease", + "commission", + "positive", + "energy", + "calm", + "presence", + "tune", + "basis", + "preference", + "head", + "common", + "cut", + "somewhere", + "presentation", + "current", + "thought", + "revolution", + "effort", + "master", + "implement", + "republic", + "floor", + "principle", + "stranger", + "shoulder", + "grade", + "button", + "tennis", + "police", + "collection", + "account", + "register", + "glove", + "divide", + "professor", + "chair", + "priority", + "combine", + "peace", + "extension", + "maybe", + "evening", + "frame", + "sister", + "wave", + "code", + "application", + "mouse", + "match", + "counter", + "bottle", + "half", + "cheek", + "resolution", + "back", + "knowledge", + "make", + "discussion", + "screw", + "length", + "accident", + "battle", + "dress", + "knee", + "log", + "package", + "it", + "turn", + "hearing", + "newspaper", + "layer", + "wealth", + "profile", + "imagination", + "answer", + "weekend", + "teacher", + "appearance", + "meet", + "bike", + "rise", + "belt", + "crash", + "bowl", + "equivalent", + "support", + "image", + "poem", + "risk", + "excitement", + "remote", + "secretary", + "public", + "produce", + "plane", + "display", + "money", + "sand", + "situation", + "punch", + 
"customer", + "title", + "shake", + "mortgage", + "option", + "number", + "pop", + "window", + "extent", + "nothing", + "experience", + "opinion", + "departure", + "dance", + "indication", + "boy", + "material", + "band", + "leader", + "sun", + "beautiful", + "muscle", + "farmer", + "variety", + "fat", + "handle", + "director", + "opportunity", + "calendar", + "outside", + "pace", + "bath", + "fish", + "consequence", + "put", + "owner", + "go", + "doctor", + "information", + "share", + "hurt", + "protection", + "career", + "finance", + "force", + "golf", + "garbage", + "aspect", + "kid", + "food", + "boot", + "milk", + "respond", + "objective", + "reality", + "raw", + "ring", + "mall", + "one", + "impact", + "area", + "news", + "international", + "series", + "impress", + "mother", + "shelter", + "strike", + "loan", + "month", + "seat", + "anything", + "entertainment", + "familiar", + "clue", + "year", + "glad", + "supermarket", + "natural", + "god", + "cost", + "conversation", + "tie", + "ruin", + "comfort", + "earth", + "storm", + "percentage", + "assistance", + "budget", + "strength", + "beginning", + "sleep", + "other", + "young", + "unit", + "fill", + "store", + "desire", + "hide", + "value", + "cup", + "maintenance", + "nurse", + "function", + "tower", + "role", + "class", + "camera", + "database", + "panic", + "nation", + "basket", + "ice", + "art", + "spirit", + "chart", + "exchange", + "feedback", + "statement", + "reputation", + "search", + "hunt", + "exercise", + "nasty", + "notice", + "male", + "yard", + "annual", + "collar", + "date", + "platform", + "plant", + "fortune", + "passion", + "friendship", + "spread", + "cancer", + "ticket", + "attitude", + "island", + "active", + "object", + "service", + "buyer", + "bite", + "card", + "face", + "steak", + "proposal", + "patient", + "heat", + "rule", + "resident", + "broad", + "politics", + "west", + "knife", + "expert", + "girl", + "design", + "salt", + "baseball", + "grab", + "inspection", + "cousin", + 
"couple", + "magazine", + "cook", + "dependent", + "security", + "chicken", + "version", + "currency", + "ladder", + "scheme", + "kitchen", + "employment", + "local", + "attention", + "manager", + "fact", + "cover", + "sad", + "guard", + "relative", + "county", + "rate", + "lunch", + "program", + "initiative", + "gear", + "bridge", + "breast", + "talk", + "dish", + "guarantee", + "beer", + "vehicle", + "reception", + "woman", + "substance", + "copy", + "lecture", + "advantage", + "park", + "cold", + "death", + "mix", + "hold", + "scale", + "tomorrow", + "blood", + "request", + "green", + "cookie", + "church", + "strip", + "forever", + "beyond", + "debt", + "tackle", + "wash", + "following", + "feel", + "maximum", + "sector", + "sea", + "property", + "economics", + "menu", + "bench", + "try", + "language", + "start", + "call", + "solid", + "address", + "income", + "foot", + "senior", + "honey", + "few", + "mixture", + "cash", + "grocery", + "link", + "map", + "form", + "factor", + "pot", + "model", + "writer", + "farm", + "winter", + "skill", + "anywhere", + "birthday", + "policy", + "release", + "husband", + "lab", + "hurry", + "mail", + "equipment", + "sink", + "pair", + "driver", + "consideration", + "leather", + "skin", + "blue", + "boat", + "sale", + "brick", + "two", + "feed", + "square", + "dot", + "rush", + "dream", + "location", + "afternoon", + "manufacturer", + "control", + "occasion", + "trouble", + "introduction", + "advice", + "bet", + "eat", + "kill", + "category", + "manner", + "office", + "estate", + "pride", + "awareness", + "slip", + "crack", + "client", + "nail", + "shoot", + "membership", + "soft", + "anybody", + "web", + "official", + "individual", + "pizza", + "interest", + "bag", + "spell", + "profession", + "queen", + "deal", + "resource", + "ship", + "guy", + "chocolate", + "joint", + "formal", + "upstairs", + "car", + "resort", + "abroad", + "dealer", + "associate", + "finger", + "surgery", + "comment", + "team", + "detail", + "crazy", + 
"path", + "tale", + "initial", + "arm", + "radio", + "demand", + "single", + "draw", + "yellow", + "contest", + "piece", + "quote", + "pull", + "commercial", + "shirt", + "contribution", + "cream", + "channel", + "suit", + "discipline", + "instruction", + "concert", + "speech", + "low", + "effective", + "hang", + "scratch", + "industry", + "breakfast", + "lay", + "join", + "metal", + "bedroom", + "minute", + "product", + "rest", + "temperature", + "many", + "give", + "argument", + "print", + "purple", + "laugh", + "health", + "credit", + "investment", + "sell", + "setting", + "lesson", + "egg", + "middle", + "marriage", + "level", + "evidence", + "phrase", + "love", + "self", + "benefit", + "guidance", + "affect", + "you", + "dad", + "anxiety", + "special", + "boyfriend", + "test", + "blank", + "payment", + "soup", + "obligation", + "reply", + "smile", + "deep", + "complaint", + "addition", + "review", + "box", + "towel", + "minor", + "fun", + "soil", + "issue", + "cigarette", + "internet", + "gain", + "tell", + "entry", + "spare", + "incident", + "family", + "refuse", + "branch", + "can", + "pen", + "grandfather", + "constant", + "tank", + "uncle", + "climate", + "ground", + "volume", + "communication", + "kind", + "poet", + "child", + "screen", + "mine", + "quit", + "gene", + "lack", + "charity", + "memory", + "tooth", + "fear", + "mention", + "marketing", + "reveal", + "reason", + "court", + "season", + "freedom", + "land", + "sport", + "audience", + "classroom", + "law", + "hook", + "win", + "carry", + "eye", + "smell", + "distribution", + "research", + "country", + "dare", + "hope", + "whereas", + "stretch", + "library", + "if", + "delay", + "college", + "plastic", + "book", + "present", + "use", + "worry", + "champion", + "goal", + "economy", + "march", + "election", + "reflection", + "midnight", + "slide", + "inflation", + "action", + "challenge", + "guitar", + "coast", + "apple", + "campaign", + "field", + "jacket", + "sense", + "way", + "visual", + 
"remove", + "weather", + "trash", + "cable", + "regret", + "buddy", + "beach", + "historian", + "courage", + "sympathy", + "truck", + "tension", + "permit", + "nose", + "bed", + "son", + "person", + "base", + "meat", + "usual", + "air", + "meeting", + "worth", + "game", + "independence", + "physical", + "brief", + "play", + "raise", + "board", + "she", + "key", + "writing", + "pick", + "command", + "party", + "yesterday", + "spring", + "candidate", + "physics", + "university", + "concern", + "development", + "change", + "string", + "target", + "instance", + "room", + "bitter", + "bird", + "football", + "normal", + "split", + "impression", + "wood", + "long", + "meaning", + "stock", + "cap", + "leadership", + "media", + "ambition", + "fishing", + "essay", + "salad", + "repair", + "today", + "designer", + "night", + "bank", + "drawing", + "inevitable", + "phase", + "vast", + "chip", + "anger", + "switch", + "cry", + "twist", + "personality", + "attempt", + "storage", + "being", + "preparation", + "bat", + "selection", + "white", + "technology", + "contract", + "side", + "section", + "station", + "till", + "structure", + "tongue", + "taste", + "truth", + "difficulty", + "group", + "limit", + "main", + "move", + "feeling", + "light", + "example", + "mission", + "might", + "wait", + "wheel", + "shop", + "host", + "classic", + "alternative", + "cause", + "agent", + "consist", + "table", + "airline", + "text", + "pool", + "craft", + "range", + "fuel", + "tool", + "partner", + "load", + "entrance", + "deposit", + "hate", + "article", + "video", + "summer", + "feature", + "extreme", + "mobile", + "hospital", + "flight", + "fall", + "pension", + "piano", + "fail", + "result", + "rub", + "gap", + "system", + "report", + "suck", + "ordinary", + "wind", + "nerve", + "ask", + "shine", + "note", + "line", + "mom", + "perception", + "brother", + "reference", + "bend", + "charge", + "treat", + "trick", + "term", + "homework", + "bake", + "bid", + "status", + "project", + 
"strategy", + "orange", + "let", + "enthusiasm", + "parent", + "concentrate", + "device", + "travel", + "poetry", + "business", + "society", + "kiss", + "end", + "vegetable", + "employ", + "schedule", + "hour", + "brave", + "focus", + "process", + "movie", + "illegal", + "general", + "coffee", + "ad", + "highway", + "chemistry", + "psychology", + "hire", + "bell", + "conference", + "relief", + "show", + "neat", + "funny", + "weight", + "quality", + "club", + "daughter", + "zone", + "touch", + "tonight", + "shock", + "burn", + "excuse", + "name", + "survey", + "landscape", + "advance", + "satisfaction", + "bread", + "disaster", + "item", + "hat", + "prior", + "shopping", + "visit", + "east", + "photo", + "home", + "idea", + "father", + "comparison", + "cat", + "pipe", + "winner", + "count", + "lake", + "fight", + "prize", + "foundation", + "dog", + "keep", + "ideal", + "fan", + "struggle", + "peak", + "safety", + "solution", + "hell", + "conclusion", + "population", + "strain", + "alarm", + "measurement", + "second", + "train", + "race", + "due", + "insurance", + "boss", + "tree", + "monitor", + "sick", + "course", + "drag", + "appointment", + "slice", + "still", + "care", + "patience", + "rich", + "escape", + "emotion", + "royal", + "female", + "childhood", + "government", + "picture", + "will", + "sock", + "big", + "gate", + "oil", + "cross", + "pin", + "improvement", + "championship", + "silly", + "help", + "sky", + "pitch", + "man", + "diamond", + "most", + "transition", + "work", + "science", + "committee", + "moment", + "fix", + "teaching", + "dig", + "specialist", + "complex", + "guide", + "people", + "dead", + "voice", + "original", + "break", + "topic", + "data", + "degree", + "reading", + "recording", + "bunch", + "reach", + "judgment", + "lie", + "regular", + "set", + "painting", + "mode", + "list", + "player", + "bear", + "north", + "wonder", + "carpet", + "heavy", + "officer", + "negative", + "clock", + "unique", + "baby", + "pain", + "assumption", + 
"disk", + "iron", + "bill", + "drawer", + "look", + "double", + "mistake", + "finish", + "future", + "brilliant", + "contact", + "math", + "rice", + "leave", + "restaurant", + "discount", + "sex", + "virus", + "bit", + "trust", + "event", + "wear", + "juice", + "failure", + "bug", + "context", + "mud", + "whole", + "wrap", + "intention", + "draft", + "pressure", + "cake", + "dark", + "explanation", + "space", + "angle", + "word", + "efficiency", + "management", + "habit", + "star", + "chance", + "finding", + "transportation", + "stand", + "criticism", + "flow", + "door", + "injury", + "insect", + "surprise", + "apartment", +] # pylint: disable=line-too-long + +# ISO 639-1 codes to language names. +LANGUAGE_CODES = immutabledict.immutabledict( + { + "en": "English", + "es": "Spanish", + "pt": "Portuguese", + "ar": "Arabic", + "hi": "Hindi", + "fr": "French", + "ru": "Russian", + "de": "German", + "ja": "Japanese", + "it": "Italian", + "bn": "Bengali", + "uk": "Ukrainian", + "th": "Thai", + "ur": "Urdu", + "ta": "Tamil", + "te": "Telugu", + "bg": "Bulgarian", + "ko": "Korean", + "pl": "Polish", + "he": "Hebrew", + "fa": "Persian", + "vi": "Vietnamese", + "ne": "Nepali", + "sw": "Swahili", + "kn": "Kannada", + "mr": "Marathi", + "gu": "Gujarati", + "pa": "Punjabi", + "ml": "Malayalam", + "fi": "Finnish", + } +) + +_ALPHABETS = "([A-Za-z])" +_PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]" +_SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)" +_STARTERS = ( + r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)" +) +_ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)" +_WEBSITES = "[.](com|net|org|io|gov|edu|me)" +_DIGITS = "([0-9])" +_MULTIPLE_DOTS = r"\.{2,}" + + +def split_into_sentences(text): + """Split the text into sentences. + + Args: + text: A string that consists of more than or equal to one sentences. + + Returns: + A list of strings where each string is a sentence. 
def split_into_sentences(text):
    """Split the text into sentences.

    The splitter is rule based.  Periods that do NOT end a sentence (titles,
    web sites, decimals, acronyms, initials, company suffixes) are first
    rewritten to the sentinel ``<prd>``; every remaining sentence terminator
    gets the sentinel ``<stop>`` appended; finally ``<prd>`` is turned back
    into ``.`` and the text is cut at ``<stop>``.

    Args:
      text: A string that consists of more than or equal to one sentences.

    Returns:
      A list of strings where each string is a sentence.
    """
    # Padding lets the " X. " style patterns match at both ends of the text.
    text = " " + text + "  "
    text = text.replace("\n", " ")

    # --- Protect periods that are not sentence boundaries -----------------
    text = re.sub(_PREFIXES, "\\1<prd>", text)
    text = re.sub(_WEBSITES, "<prd>\\1", text)
    text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1<prd>\\2", text)
    # An ellipsis keeps its dots but also terminates the sentence.
    text = re.sub(
        _MULTIPLE_DOTS,
        lambda match: "<prd>" * len(match.group(0)) + "<stop>",
        text,
    )
    if "Ph.D" in text:
        text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    text = re.sub(r"\s" + _ALPHABETS + "[.] ", " \\1<prd> ", text)
    # An acronym followed by a typical sentence starter does end a sentence.
    text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1<stop> \\2", text)
    text = re.sub(
        _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]",
        "\\1<prd>\\2<prd>\\3<prd>",
        text,
    )
    text = re.sub(_ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1<prd>\\2<prd>", text)
    text = re.sub(" " + _SUFFIXES + "[.] " + _STARTERS, " \\1<stop> \\2", text)
    text = re.sub(" " + _SUFFIXES + "[.]", " \\1<prd>", text)
    text = re.sub(" " + _ALPHABETS + "[.]", " \\1<prd>", text)

    # --- Move terminators outside closing quotes --------------------------
    if "”" in text:
        text = text.replace(".”", "”.")
    if '"' in text:
        text = text.replace('."', '".')
    if "!" in text:
        text = text.replace('!"', '"!')
    if "?" in text:
        text = text.replace('?"', '"?')

    # --- Mark real boundaries, restore protected periods, split -----------
    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")
    sentences = [s.strip() for s in text.split("<stop>")]
    # The padding added above leaves one empty trailing chunk; drop it.
    if sentences and not sentences[-1]:
        sentences = sentences[:-1]
    return sentences


def count_words(text):
    """Counts the number of words (maximal runs of ``\\w`` characters)."""
    tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+")
    tokens = tokenizer.tokenize(text)
    num_words = len(tokens)
    return num_words


@functools.lru_cache(maxsize=None)
def _get_sentence_tokenizer():
    """Load the English punkt tokenizer once and reuse it on later calls."""
    return nltk.data.load("nltk:tokenizers/punkt/english.pickle")


def count_sentences(text):
    """Count the number of sentences using the cached punkt tokenizer."""
    tokenizer = _get_sentence_tokenizer()
    tokenized_sentences = tokenizer.tokenize(text)
    return len(tokenized_sentences)


def generate_keywords(num_keywords):
    """Randomly generates `num_keywords` distinct keywords from WORD_LIST."""
    return random.sample(WORD_LIST, k=num_keywords)
class ArrowMazeVerifier(Verifier):
    """
    Verifier for the arrow-maze puzzle.

    An answer grid is accepted only if all of the following hold:
    1. It has the same dimensions as the question grid.
    2. Every number cell of the question is unchanged in the answer.
    3. Every blank cell ("X") of the question has been filled in.
    4. Every non-number, non-"X" cell is one of the eight legal arrows
       (up, down, left, right and the four diagonals).
    5. Arrows that were pre-filled in the question are unchanged.
    6. Every arrow is covered by a ray that starts at some number cell.
    7. For each number cell, the total length of the rays leaving it equals
       the number.
    """

    # Legal arrow symbols.
    VALID_ARROWS = {"↑", "↓", "←", "→", "↖", "↗", "↘", "↙"}

    # Arrow symbol -> (row delta, column delta) it points towards.
    ARROWS_DIRECTIONS = {
        "↑": (-1, 0),   # up
        "↓": (1, 0),    # down
        "←": (0, -1),   # left
        "→": (0, 1),    # right
        "↖": (-1, -1),  # up-left
        "↗": (-1, 1),   # up-right
        "↘": (1, 1),    # down-right
        "↙": (1, -1),   # down-left
    }

    def verify(self, data: Data, test_solution_str: str) -> bool:
        """
        Check whether a model response solves the arrow maze.

        @param data: puzzle instance; the question grid is read from
                     data.metadata["maze"]
        @param test_solution_str: raw model response containing the answer grid
        @return: True if the extracted grid passes every check, else False
        """
        test_answer_str = self.extract_answer(test_solution_str)
        if not test_answer_str:
            return False

        try:
            test_answer = json.loads(test_answer_str)
            question_grid = data.metadata["maze"]

            # Short-circuit evaluation keeps the original check order; the
            # size check must come first because later checks index both grids.
            return (
                self._verify_grid_size(test_answer, question_grid)
                and self._verify_number_positions(test_answer, question_grid)
                and self._verify_all_blanks_filled(test_answer, question_grid)
                and self._verify_arrow_symbols(test_answer)
                and self._verify_prefilled_arrows(test_answer, question_grid)
                and self._verify_arrow_rays(test_answer)
                and self._verify_number_rays(test_answer)
            )
        except Exception:
            # Malformed answers (e.g. non-string cells, missing metadata) are
            # scored as wrong rather than crashing the reward computation.
            return False

    def _verify_grid_size(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool:
        """
        Check that the answer grid has the same shape as the question grid.

        @param test_answer: answer grid
        @param question_grid: question grid
        @return: True if the row count and every row length match
        """
        if len(test_answer) != len(question_grid):
            return False
        return all(
            len(answer_row) == len(question_row)
            for answer_row, question_row in zip(test_answer, question_grid)
        )

    def _verify_number_positions(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool:
        """
        Check that every number cell of the question is unchanged in the answer.

        @param test_answer: answer grid
        @param question_grid: question grid
        @return: True if all number cells match
        """
        for i, row in enumerate(question_grid):
            for j, cell in enumerate(row):
                if cell.isdigit() and test_answer[i][j] != cell:
                    return False
        return True

    def _verify_all_blanks_filled(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool:
        """
        Check that no blank ("X") of the question is still "X" in the answer.

        @param test_answer: answer grid
        @param question_grid: question grid
        @return: True if every blank has been filled
        """
        for i, row in enumerate(question_grid):
            for j, cell in enumerate(row):
                if cell == "X" and test_answer[i][j] == "X":
                    return False
        return True

    def _verify_arrow_symbols(self, test_answer: List[List[str]]) -> bool:
        """
        Check that every cell is a number, "X", or a legal arrow.

        @param test_answer: answer grid
        @return: True if no illegal symbol is present
        """
        for row in test_answer:
            for cell in row:
                if not cell.isdigit() and cell != "X" and cell not in self.VALID_ARROWS:
                    return False
        return True

    def _verify_prefilled_arrows(self, test_answer: List[List[str]], question_grid: List[List[str]]) -> bool:
        """
        Check that arrows pre-filled in the question are unchanged.

        @param test_answer: answer grid
        @param question_grid: question grid
        @return: True if every pre-filled (non-number, non-"X") cell matches
        """
        for i, row in enumerate(question_grid):
            for j, cell in enumerate(row):
                if not cell.isdigit() and cell != "X" and test_answer[i][j] != cell:
                    return False
        return True

    def _ray_cells(self, grid: List[List[str]], i: int, j: int, arrow_symbol: str, di: int, dj: int):
        """
        Yield the cells of the ray leaving (i, j) in direction (di, dj).

        The ray consists of the consecutive cells, starting next to (i, j),
        that all hold `arrow_symbol`; it stops at the grid border or at the
        first cell holding anything else.
        """
        n = len(grid)
        m = len(grid[0]) if n > 0 else 0
        ni, nj = i + di, j + dj
        while 0 <= ni < n and 0 <= nj < m and grid[ni][nj] == arrow_symbol:
            yield ni, nj
            ni += di
            nj += dj

    def _verify_arrow_rays(self, test_answer: List[List[str]]) -> bool:
        """
        Check that every arrow lies on a ray that starts at a number cell.

        @param test_answer: answer grid
        @return: True if no arrow is left uncovered
        """
        n = len(test_answer)
        m = len(test_answer[0]) if n > 0 else 0

        covered = [[False] * m for _ in range(n)]

        # Mark every arrow reachable from some number cell along a ray.
        for i in range(n):
            for j in range(m):
                if not test_answer[i][j].isdigit():
                    continue
                covered[i][j] = True
                for arrow_symbol, (di, dj) in self.ARROWS_DIRECTIONS.items():
                    for ni, nj in self._ray_cells(test_answer, i, j, arrow_symbol, di, dj):
                        covered[ni][nj] = True

        for i in range(n):
            for j in range(m):
                if test_answer[i][j] in self.VALID_ARROWS and not covered[i][j]:
                    return False
        return True

    def _verify_number_rays(self, test_answer: List[List[str]]) -> bool:
        """
        Check that each number equals the total length of the rays leaving it.

        @param test_answer: answer grid
        @return: True if every number cell satisfies the constraint
        """
        n = len(test_answer)
        m = len(test_answer[0]) if n > 0 else 0

        for i in range(n):
            for j in range(m):
                cell = test_answer[i][j]
                if cell.isdigit() and self._count_arrow_rays(test_answer, i, j) != int(cell):
                    return False
        return True

    def _count_arrow_rays(self, grid: List[List[str]], i: int, j: int) -> int:
        """
        Count the arrows on all rays leaving the number cell at (i, j).

        @param grid: grid
        @param i: row index of the number cell
        @param j: column index of the number cell
        @return: total number of arrows over the eight directions
        """
        return sum(
            sum(1 for _ in self._ray_cells(grid, i, j, arrow_symbol, di, dj))
            for arrow_symbol, (di, dj) in self.ARROWS_DIRECTIONS.items()
        )

    def _grid_literal_to_json(self, literal: str) -> str:
        """
        Parse `literal` as a Python list-of-lists and return it as JSON.

        Returns "" when the text is not a literal, is not a 2-D list, or
        cannot be serialised.
        """
        # Local import: only the top-of-file imports visible here are
        # json/typing/re, and `ast` is needed nowhere else in the class.
        import ast

        try:
            # SECURITY: the text comes from a model response, i.e. untrusted
            # input.  literal_eval only accepts Python literals, unlike
            # eval(), which would execute arbitrary expressions.
            grid = ast.literal_eval(literal)
            if isinstance(grid, list) and all(isinstance(row, list) for row in grid):
                return json.dumps(grid)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            pass
        return ""

    def extract_answer(self, test_solution: str) -> str:
        """
        Extract the answer grid from a model response.

        Fenced code blocks are tried first (last match of each pattern), then
        the last bare ``[[...]]`` list found anywhere in the text.

        @param test_solution: the model's full response
        @return: the grid as a JSON string, or "" if none could be parsed
        """
        if not test_solution:
            return ""

        code_block_patterns = [
            r'```python\s*\n(.*?\[.*?\].*?)\n```',  # fenced block tagged python
            r'```\s*\n(.*?\[.*?\].*?)\n```',        # fenced block without a tag
            r'```(.*?\[.*?\].*?)```'                # fenced block without newlines
        ]

        for pattern in code_block_patterns:
            matches = re.findall(pattern, test_solution, re.DOTALL)
            if matches:
                # The last block is the most likely to hold the final answer.
                parsed = self._grid_literal_to_json(matches[-1].strip())
                if parsed:
                    return parsed

        # No usable code block: fall back to the last bare 2-D list.
        list_pattern = r'\[\s*\[.*?\]\s*\]'
        matches = re.findall(list_pattern, test_solution, re.DOTALL)
        if matches:
            parsed = self._grid_literal_to_json(matches[-1])
            if parsed:
                return parsed

        return ""
class BooleanExpressionsVerifier(Verifier):
    """
    Verifier for the boolean-expressions game.

    The answer is taken from the last ``\\boxed{...}`` of the response and is
    compared with the ground truth as a case-insensitive set of letters.
    """

    def verify(self, data: Data, test_answer: str):
        """Return True when the boxed answer uses the same letters as data.answer."""
        try:
            boxed = self.extract_answer(test_answer)
            if boxed is None:
                return False
            # Only ASCII letters matter; order and repetition are ignored.
            predicted = set(self.lower(re.findall(r'[a-zA-Z]', boxed)))
            expected = set(self.lower(re.findall(r'[a-zA-Z]', data.answer)))
            return predicted == expected
        except Exception as e:
            return False

    def lower(self, answer_list):
        """Return a new list with every string of `answer_list` lower-cased."""
        return [item.lower() for item in answer_list]

    def extract_answer(self, answer_str):
        """
        Return the stripped content of the last ``\\boxed{...}``.

        Nested braces inside the box are supported.  Returns None when there is
        no ``\\boxed{`` or when its braces are never closed.
        """
        marker = "\\boxed{"
        marker_pos = answer_str.rfind(marker)
        if marker_pos == -1:
            return None

        content_start = marker_pos + len(marker)
        depth = 1  # the brace that belongs to \boxed{ is already open
        cursor = content_start
        total = len(answer_str)
        while cursor < total and depth > 0:
            char = answer_str[cursor]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
            cursor += 1

        if depth != 0:  # unbalanced braces
            return None

        # `cursor` is one past the closing brace, so drop that brace.
        return answer_str[content_start:cursor - 1].strip()
class CampsiteVerifier(Verifier):
    """
    Verifier for the Campsite (tents and trees) game.

    Cells use 'T' for a tree and 'C' for a tent.
    """

    # Exceptions ast.literal_eval can raise on text that is not a valid literal.
    _PARSE_ERRORS = (ValueError, SyntaxError, TypeError, MemoryError, RecursionError)

    # Row/column offsets of the four orthogonal neighbours.
    _ORTHOGONAL = ((0, 1), (1, 0), (0, -1), (-1, 0))

    def verify(self, data: Data, test_solution: str):
        """
        Check a model response against the puzzle stored in `data.metadata`.

        Reads the keys "grid", "row_constraints", "col_constraints", "n" and
        "m" from the metadata.  Returns True only if the extracted grid has
        shape n x m, keeps every tree in place, satisfies the row and column
        tent counts, has no two tents touching (diagonals included) and
        admits a one-to-one tent/tree matching.  Any exception yields False.
        """
        try:
            test_answer = self.extract_answer(test_solution)
            original_grid = data.metadata["grid"]
            row_constraints = data.metadata["row_constraints"]
            col_constraints = data.metadata["col_constraints"]
            n = data.metadata["n"]
            m = data.metadata["m"]

            if not test_answer:
                return False

            if len(test_answer) != n or any(len(row) != m for row in test_answer):
                return False

            return (
                self._check_trees_unchanged(original_grid, test_answer)
                and self._check_row_constraints(test_answer, row_constraints)
                and self._check_col_constraints(test_answer, col_constraints)
                and self._check_tents_not_adjacent(test_answer)
                and self._check_tent_tree_matching(test_answer)
            )
        except Exception:
            # Malformed answers or metadata are scored as wrong.
            return False

    def _parse_first_grid(self, text: str):
        """Return the first ``[[...]]`` literal found in `text`, or None."""
        match = re.search(r'\[\s*\[.*?\]\s*\]', text, re.DOTALL)
        if match is None:
            return None
        try:
            return ast.literal_eval(match.group(0))
        except self._PARSE_ERRORS:
            return None

    def _extract_grid(self, test_answer: str) -> List[List[str]]:
        """Extract the grid from a response; returns None when none parses."""
        return self._parse_first_grid(test_answer)

    def _check_trees_unchanged(self, original_grid: List[List[str]], test_answer: List[List[str]]) -> bool:
        """Check that the answer has a tree exactly where the original grid has one."""
        for i in range(len(original_grid)):
            for j in range(len(original_grid[0])):
                # Fails both when a tree was removed and when one was added.
                if (original_grid[i][j] == 'T') != (test_answer[i][j] == 'T'):
                    return False
        return True

    def _check_row_constraints(self, grid: List[List[str]], row_constraints: List[int]) -> bool:
        """Check that each row holds exactly the required number of tents."""
        for i, row in enumerate(grid):
            if row.count('C') != row_constraints[i]:
                return False
        return True

    def _check_col_constraints(self, grid: List[List[str]], col_constraints: List[int]) -> bool:
        """Check that each column holds exactly the required number of tents."""
        for j in range(len(grid[0])):
            tent_count = sum(1 for row in grid if row[j] == 'C')
            if tent_count != col_constraints[j]:
                return False
        return True

    def _check_tents_not_adjacent(self, grid: List[List[str]]) -> bool:
        """Check that no two tents touch, not even diagonally."""
        n = len(grid)
        m = len(grid[0]) if n > 0 else 0

        for i in range(n):
            for j in range(m):
                if grid[i][j] != 'C':
                    continue
                # Look at the eight surrounding cells.
                for di in (-1, 0, 1):
                    for dj in (-1, 0, 1):
                        if di == 0 and dj == 0:
                            continue
                        ni, nj = i + di, j + dj
                        if 0 <= ni < n and 0 <= nj < m and grid[ni][nj] == 'C':
                            return False
        return True

    def _check_tent_tree_matching(self, grid: List[List[str]]) -> bool:
        """
        Check that tents and trees can be paired one-to-one.

        1. Every tent must be orthogonally adjacent to at least one tree.
        2. Each tree is paired with exactly one tent.
        3. Each tent is paired with exactly one tree.
        4. The number of tents equals the number of trees.

        The pairing is searched with augmenting paths over the tent/tree
        adjacency relation.
        """
        n = len(grid)
        m = len(grid[0]) if n > 0 else 0

        tents = [(i, j) for i in range(n) for j in range(m) if grid[i][j] == 'C']
        trees = [(i, j) for i in range(n) for j in range(m) if grid[i][j] == 'T']

        if len(tents) != len(trees):
            return False

        # For every tent, the trees it could be paired with.
        tent_to_trees = {}
        for tent_i, tent_j in tents:
            candidates = []
            for di, dj in self._ORTHOGONAL:
                tree_i, tree_j = tent_i + di, tent_j + dj
                if 0 <= tree_i < n and 0 <= tree_j < m and grid[tree_i][tree_j] == 'T':
                    candidates.append((tree_i, tree_j))
            if not candidates:
                # A tent with no neighbouring tree can never be paired.
                return False
            tent_to_trees[(tent_i, tent_j)] = candidates

        tent_matched = {}
        tree_matched = {}

        def augment(tent, visited):
            """Try to pair `tent`, re-pairing earlier tents when necessary."""
            for tree in tent_to_trees[tent]:
                if tree in visited:
                    continue
                visited.add(tree)
                # Take a free tree, or evict its tent if that one can move on.
                if tree not in tree_matched or augment(tree_matched[tree], visited):
                    tent_matched[tent] = tree
                    tree_matched[tree] = tent
                    return True
            return False

        for tent in tents:
            if tent not in tent_matched and not augment(tent, set()):
                return False

        return len(tent_matched) == len(tents) and len(tree_matched) == len(trees)

    def extract_answer(self, test_solution: str):
        """Extract the solution grid from a response; returns "" when none parses."""
        grid = self._parse_first_grid(test_solution)
        return "" if grid is None else grid
class Data:
    """
    Data class for game/corpus

    @param question: question of the game/corpus
    @param answer: answer of the game/corpus
    @param difficulty: difficulty of the game/corpus, from 1 to 10
    @param metadata: optional dict with puzzle-specific information
    @param kwargs: extra keys are accepted and ignored, so that a dict
                   produced by `to_json` can be passed back as keyword
                   arguments
    """
    def __init__(self, question: str, answer: str, difficulty: int = 1, metadata: dict = None, **kwargs):
        self.question = question
        self.answer = answer
        self.difficulty = difficulty
        self.metadata = metadata
        # Not taken from kwargs here; the from_* constructors restore it.
        self.gpt_response = ""

    def to_json(self):
        """Return the instance as a plain dict (JSON-serialisable if the fields are)."""
        return {
            "question": self.question,
            "answer": self.answer,
            "difficulty": self.difficulty,
            "metadata": self.metadata,
            "gpt_response": self.gpt_response
        }

    def to_json_str(self):
        """Return the instance as a JSON string, keeping non-ASCII text readable."""
        return json.dumps(self.to_json(), ensure_ascii=False)

    @classmethod
    def from_json_str(cls, json_str):
        """Build an instance from a JSON string such as `to_json_str` produces."""
        # Delegating keeps `gpt_response` across a to_json_str/from_json_str
        # round trip, exactly as the other two constructors do.
        return cls.from_json_dict(json.loads(json_str))

    @classmethod
    def from_json_dict(cls, json_dict):
        """Build an instance from a dict, restoring `gpt_response` when present."""
        instance = cls(**json_dict)
        if 'gpt_response' in json_dict:
            instance.gpt_response = json_dict['gpt_response']
        return instance

    @classmethod
    def from_jsonl_file(cls, file_path):
        """
        Load one instance per non-blank line of a JSON Lines file.

        The file is read as UTF-8 because `to_json_str` emits raw non-ASCII
        characters; blank lines (e.g. a trailing newline) are skipped.
        """
        data_list = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data_list.append(cls.from_json_dict(json.loads(line)))
        return data_list
class DyckLanguageErrorsVerifier(Verifier):
    """
    Verifier for the "find the first bracket-closing error" game.

    The expected answer is "-1" for a valid sequence and the position of the
    first error otherwise.
    """

    # Words by which an answer declares the sequence valid.  The lookbehinds
    # reject the negated forms "invalid", "incorrect" and "不合法", which a
    # plain substring test would also match.
    _VALID_KEYWORD = re.compile(r"(?<!in)valid|(?<!in)correct|(?<!不)合法")

    def verify(self, data: Data, test_answer: str):
        """
        Check whether the model's answer is correct.

        @param data: Data object; reads metadata["is_valid"] and, for invalid
                     sequences, metadata["first_error_pos"]
        @param test_answer: the answer string given by the model
        @return: True if the answer is correct, otherwise False
        """
        try:
            answer = self.extract_answer(test_solution=test_answer).strip()

            if data.metadata["is_valid"]:
                correct_answer = "-1"  # a valid sequence is answered with -1
            else:
                correct_answer = str(data.metadata["first_error_pos"])

            if correct_answer == "-1":
                # Only the literal "-1" is accepted for a valid sequence.
                return answer == "-1"

            # Otherwise the answer must be the same integer position.
            try:
                return int(answer) == int(correct_answer)
            except (ValueError, TypeError):
                return False
        except Exception:
            return False

    def extract_answer(self, test_solution: str):
        """
        Normalise the model's answer.

        @param test_solution: the model's answer text
        @return: "" for None, "-1" when the text declares the sequence valid
                 ("valid", "correct" or "合法", not negated), otherwise the
                 text unchanged
        """
        if test_solution is None:
            return ""
        if self._VALID_KEYWORD.search(test_solution.lower()):
            return "-1"
        return test_solution
class DyckLanguageReasoningErrorsVerifier(Verifier):
    """
    Verifier for the "find the wrong reasoning steps" Dyck-language game.

    The expected answer is the set of step indices stored in
    data.metadata["error_indices"]; an empty set means "no errors".
    """

    # Answers that mean "there is no error".
    _NO_ERROR_ANSWERS = ("无问题", "no", "无错误", "no error", "no errors", "no mistakes", "all correct")

    # Full-width comma (U+FF0C), written as an escape so it cannot be confused
    # with, or normalised into, the ASCII comma.
    _FULLWIDTH_COMMA = "\uff0c"

    def verify(self, data: Data, test_answer: str):
        """
        Check whether the model's answer is correct.

        @param data: Data object; reads metadata["error_indices"]
        @param test_answer: the answer string given by the model
        @return: True if the answer names exactly the expected indices
        """
        try:
            test_answer = self.extract_answer(test_solution=test_answer)
            correct_indices = data.metadata["error_indices"]

            # An answer that expresses uncertainty is always wrong.
            if "不确定" in test_answer or "不知道" in test_answer or "unclear" in test_answer.lower():
                return False

            # Reduce the answer to digits and ASCII commas ("" = no errors).
            cleaned_test_answer = self._standardize_answer(test_answer)

            if not correct_indices and cleaned_test_answer == "":
                return True

            # Compare as sets so order and duplicates do not matter.
            return self._extract_error_indices(cleaned_test_answer) == set(correct_indices)
        except Exception:
            return False

    def _standardize_answer(self, answer: str) -> str:
        """
        Normalise an answer string.

        @param answer: raw answer string
        @return: "" for an empty or "no error" answer; otherwise the answer
                 with full-width commas converted to ASCII commas and every
                 character other than digits and commas removed
        """
        if not answer or answer.strip() == "":
            return ""

        if answer.lower() in self._NO_ERROR_ANSWERS:
            return ""

        # Convert full-width commas BEFORE stripping: otherwise the regex
        # below deletes them and "1，2" collapses into the single number "12".
        answer = answer.replace(self._FULLWIDTH_COMMA, ',')
        return re.sub(r'[^0-9,]', '', answer)

    def _extract_error_indices(self, answer: str) -> set:
        """
        Parse a comma-separated index string into a set of ints.

        @param answer: answer string
        @return: set of indices; empty for an empty or unparsable string
        """
        if not answer or answer.strip() == "":
            return set()

        answer = answer.replace(self._FULLWIDTH_COMMA, ',')

        try:
            return set(int(idx) for idx in answer.split(',') if idx and idx.strip().isdigit())
        except ValueError:
            return set()

    def _format_answer(self, error_indices: list) -> str:
        """
        Format a list of indices as the canonical answer string.

        @param error_indices: indices of the wrong steps
        @return: indices in ascending order joined by ASCII commas, or ""
        """
        if not error_indices:
            return ""
        return ",".join(map(str, sorted(error_indices)))

    def extract_answer(self, test_solution: str) -> str:
        """
        Extract the answer from the model response.

        @param test_solution: the model's response
        @return: the response unchanged
        """
        return test_solution
So this thought is wrong because it adds an extra '[' that wasn't there yet. Wait, no, actually, the input sequence starts with '{', so maybe there's a mistake in the thought. Let me recheck. Oh, looking again, no, the input sequence is {}[][][]{}{}([)]}, so the first two characters are '{' and '}'. So after processing '{', the stack should be '{', not '{['. So Thought 3 is wrong because it mistakenly pushes '[' instead of just '{'.\n\nMoving to Thought 4: Next character is '}' which correctly matches the top '{' in the stack, so the stack becomes empty. This part is correct.\n\nThought 5: The next character is '[', so stack becomes '['. Correct.\n\nThought 6: Next is ']', which pops the '[', so stack is empty. Correct.\n\nThought 7: Next is '[', stack becomes '['. Correct.\n\nThought 8: Next is ']', so stack is empty. Correct.\n\nThought 9: Next is '[', stack becomes '['. Correct.\n\nThought 10: Next is ']' which pops the '[', so stack should be empty again. However, in the thought, it says the stack becomes ']', which is incorrect. So Thought 10 is wrong because it added a ']' instead of popping it off. Wait no, the character is ']', which should pop the top '[', so the stack should go from '[' to empty. But the thought says the stack becomes ']' which is wrong. So Thought 10 is incorrect.\n\nThought 11: Next is '{', stack is now '{'. Correct.\n\nThought 12: Next is '}', popping the '{', stack is empty. Correct.\n\nThought 13: Next is '{', stack becomes '{'. Correct.\n\nThought 14: Next is '}', stack is empty. Correct.\n\nThought 15: Next is '(', stack becomes '('. Correct.\n\nThought 16: Next is '[', so stack becomes '(['. Correct.\n\nThought 17: Next is ']', which pops the '[', so stack becomes '('. Correct.\n\nThought 18: Next is ')', which pops the '(', so stack is empty. Correct.\n\nThought 19: After processing all characters, stack is empty. So the sequence is balanced.\n\nWait, but the input sequence actually ends with ']})', right? 
The sequence given is {}[][][]{}{}([]). So after processing '(' and '[', we have '(', then '['. Then when we process ']', it pops the '[', stack becomes '('. Then next is ')', which pops '(', stack is empty. So everything is correct up to that point.\n\nBut in the earlier steps, the issue was in Thought 10. Because when processing the fifth character, which is '[', followed by ']', the stack should go back to empty. So Thought 10 says the stack becomes ']', which is incorrect. It should go back to empty.\n\nWait, let me recount the steps:\n\nThe input is {}[][][]{}{}([]). So let's break it down:\n\n1. '{' : stack becomes '{'\n2. '}' : stack empty\n3. '[' : stack becomes '['\n4. ']' : stack empty\n5. '[' : '['\n6. ']' : stack empty\n7. '[' : '['\n8. ']' : stack empty\n9. '{' : '{'\n10. '}' : stack empty\n11. '{' : '{'\n12. '}' : stack empty\n13. '(' : '('\n14. '[' : '(['\n15. ']' : '('\n16. ')' : stack empty.\n\nWait, but in the original thought process, after step 10, the character is ']' which pops the '[', making the stack empty, but the thought says stack is ']'. So that's definitely incorrect.\n\nAlso, in step 3, the thought says after '{' the stack is '{[', which is incorrect because the next character is '}', not another '['. 
# verl/utils/reward_score/synlogic/dyck_language_verifier.py
from .data import Data
from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END
import re


class DyckLanguageVerifier(Verifier):
    """Verifier for the Dyck Language (bracket completion) game."""

    def verify(self, data: Data, test_answer: str) -> bool:
        """Return True when the cleaned answer equals the full sequence.

        @param data: Data object; reads ``metadata["full_sequence"]``.
        @param test_answer: answer string produced by the model.
        @return: bool, False on mismatch or on any error while checking.
        """
        try:
            expected_sequence = data.metadata["full_sequence"]
            return self.extract_answer(test_answer) == expected_sequence
        except Exception:
            return False

    def extract_answer(self, test_solution: str) -> str:
        """Clean the model's answer down to the bare bracket sequence.

        Removes all whitespace, drops literal ``\\n`` / ``\\t`` / ``\\r``
        escape sequences, collapses doubled backslashes and strips one pair of
        matching outer quotes.

        @param test_solution: the model's answer text.
        @return: the cleaned text, "" for an empty/None input.
        """
        if not test_solution:
            return ""

        # Drop every whitespace character (spaces, newlines, tabs, ...).
        cleaned = "".join(test_solution.split())

        # Literal escape sequences that survived as two-character text.
        for escaped, replacement in (
            ('\\n', ''),
            ('\\t', ''),
            ('\\r', ''),
            ('\\\\', '\\'),
        ):
            cleaned = cleaned.replace(escaped, replacement)

        # Outer quotes are never part of a bracket sequence.
        if len(cleaned) >= 2 and cleaned[0] == cleaned[-1] and cleaned[0] in ('"', "'"):
            cleaned = cleaned[1:-1]

        return cleaned


if __name__ == "__main__":
    test_response = '''填写后的完整序列应为“([])({})([()])”。\n\n检查一下长度是否正确:\n\n原序列长度为11字符,补充3个字符,总长度14。\n\n这样,整个序列是合法的。\n\n\n([])({})([()])'''
    metadata = {"trace_id": "38aeede4-d5d7-4863-91d2-df1fd99f491b", "full_sequence": "([])({})([()])", "question_sequence": "([])({})([(", "n_types": 3, "total_length": 14, "fill_length": 3, "nesting_depth": 0}
    test_data = Data(question="", answer="", metadata=metadata)
    test_verifier = DyckLanguageVerifier()
    extracted_answer = test_verifier.extract_answer(test_response)
    print(extracted_answer)
    print(test_verifier.verify(data=test_data, test_answer=test_response))
# verl/utils/reward_score/synlogic/game_of_buggy_tables_verifier.py
from .data import Data
from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END
import re


class BuggyTableVerifier(Verifier):
    """
    Verifier for the BuggyTable game.
    Checks if the submitted answer matches the expected answer.
    """

    # Plain number: optional sign, digits, optional fractional part.
    _PLAIN_NUMBER = re.compile(r'^-?\d+(\.\d+)?$')
    # Number with exactly two fractional digits (integer part may be empty,
    # as the original split-based check also allowed ".25").
    _TWO_DECIMAL_NUMBER = re.compile(r'-?\d*\.\d{2}')
    # Two-decimal number embedded in free text; \b stops "1.234" matching.
    _TWO_DECIMAL_IN_TEXT = re.compile(r'-?\d+\.\d{2}\b')

    def extract_answer(self, answer: str) -> str:
        """
        Public method to extract and normalize an answer string from LLM output.
        Delegates to the private _extract_answer method.

        @param answer: The answer string to normalize
        @return: The normalized answer string
        """
        return self._extract_answer(answer)

    def verify(self, data: Data, test_answer: str) -> bool:
        """
        Verify whether the test answer is consistent with the expected answer
        for the buggy table query.

        @param data: Data object containing the expected answer
        @param test_answer: The answer provided by the LLM to verify
        @return: bool indicating whether the answer is correct
        """
        expected_answer = data.answer if data and hasattr(data, 'answer') else ""

        # Two empty answers agree trivially.
        if not expected_answer and not test_answer:
            return True

        # Compare the normalized forms as strings.
        return self._extract_answer(expected_answer) == self._extract_answer(test_answer)

    def _is_raw_numeric_answer(self, value: str) -> bool:
        """
        Check if a string represents a plain numeric answer without additional context.

        @param value: The string to check
        @return: True if the string is a simple numeric value
        """
        return bool(self._PLAIN_NUMBER.match(value.strip()))

    def _raw_has_exactly_two_decimals(self, value: str) -> bool:
        """
        Check if a raw numeric string has exactly 2 decimal places.

        Unlike the previous split-based check, non-numeric text such as
        "ab.cd" is rejected.

        @param value: The string to check
        @return: True if the string is a number with exactly 2 decimal places
        """
        return self._TWO_DECIMAL_NUMBER.fullmatch(value.strip()) is not None

    def _is_numeric(self, value: str) -> bool:
        """
        Check if a string represents a valid number (including negative numbers and decimals).

        @param value: The string to check
        @return: True if the string represents a valid number
        """
        value = value.strip()
        if value.startswith('-'):
            value = value[1:]
        # Removing at most one decimal point must leave only digits.
        return value.replace('.', '', 1).isdigit()

    def _has_exactly_two_decimals(self, value: str) -> bool:
        """
        Check if a number string has exactly 2 decimal places.

        Non-numeric text such as "ab.cd" is rejected.

        @param value: The number string to check
        @return: True if the number has exactly 2 decimal places
        """
        return self._TWO_DECIMAL_NUMBER.fullmatch(value.strip()) is not None

    def _extract_answer(self, answer: str) -> str:
        """
        Extract and normalize an answer string from LLM output.
        Only finds values with exactly two decimal places.

        @param answer: The answer string to normalize
        @return: The last two-decimal number found, or the stripped input
                 when there is none ("" for None)
        """
        normalized = str(answer).strip() if answer is not None else ""

        exact_matches = self._TWO_DECIMAL_IN_TEXT.findall(normalized)
        if exact_matches:
            # The last number is taken to be the final answer.
            return exact_matches[-1]

        return normalized
# verl/utils/reward_score/synlogic/goods_exchange_verifier.py
import ast
import re
from .data import Data
from .verifier import Verifier


class GoodsExchangeVerifier(Verifier):
    """Verifier for the goods-exchange game (who owns what at the end)."""

    def verify(self, data: Data, test_solution: str):
        """Return True when the model's ownership mapping equals the expected one.

        @param data: Data object; reads ``metadata["owns_after"]``.
        @param test_solution: the model's full response.
        @return: bool, False on mismatch or on any error while checking.
        """
        try:
            test_answer = self.extract_answer(test_solution)
            correct_answer = data.metadata["owns_after"]

            model_ownership = self._parse_answer(test_answer)
            correct_ownership = self._parse_answer(correct_answer)

            return self._compare_answers(model_ownership, correct_ownership)
        except Exception:
            # Reward functions must never raise: malformed data counts as wrong.
            return False

    def _parse_answer(self, answer_str):
        """Parse an answer into an ownership dict ``{person: item}``.

        Accepted forms: "(('p1','i1'),('p2','i2'),...)" (a Python literal),
        the unquoted variants "((p1,i1),(p2,i2))" / "(p1,i1),(p2,i2)", or an
        already-parsed tuple/list of pairs.

        @param answer_str: answer text (or a tuple/list of pairs).
        @return: dict mapping person to item; {} when nothing could be parsed.
        """
        if not answer_str:
            return {}
        if isinstance(answer_str, (tuple, list)):
            return self._ownership_from_pairs(answer_str)

        answer_str = answer_str.strip()
        try:
            # SECURITY: this text comes from a model, so it must never reach
            # eval(); literal_eval only accepts Python literals.
            pairs = ast.literal_eval(answer_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return self._parse_unquoted_format(answer_str)

        if not isinstance(pairs, (tuple, list)):
            return {}
        try:
            return self._ownership_from_pairs(pairs)
        except AttributeError:
            # A pair held non-string members (e.g. numbers): parse as text.
            return self._parse_unquoted_format(answer_str)

    @staticmethod
    def _ownership_from_pairs(pairs):
        """Build ``{person: item}`` from 2-element tuples/lists of strings."""
        ownership = {}
        for pair in pairs:
            if isinstance(pair, (tuple, list)) and len(pair) == 2:
                person, item = pair
                ownership[person.strip()] = item.strip()
        return ownership

    @staticmethod
    def _strip_outer_wrapper(text):
        """Remove the parentheses that wrap the whole list of pairs, if any.

        "((a,b),(c,d))" becomes "(a,b),(c,d)".  The flat form "(a,b),(c,d)" is
        returned unchanged: its first "(" closes before the end, so it is not
        a wrapper (stripping it unconditionally used to break this format).
        """
        starts, ends = text.startswith('('), text.endswith(')')
        if starts and ends:
            depth = 0
            for index, char in enumerate(text):
                if char == '(':
                    depth += 1
                elif char == ')':
                    depth -= 1
                    if depth == 0 and index != len(text) - 1:
                        return text  # first group closes early: no wrapper
            inner = text[1:-1]
            # A lone "(a,b)" is itself a pair, not a wrapper.
            return inner if '(' in inner else text
        if starts:
            return text[1:]
        if ends:
            return text[:-1]
        return text

    @staticmethod
    def _split_outside_quotes(text):
        """Split *text* on commas that are not inside a quoted string."""
        parts = []
        in_quotes = False
        current = ""
        for char in text:
            # An unescaped quote character toggles the quoted state.
            if char in "\"'" and (len(current) == 0 or current[-1] != '\\'):
                in_quotes = not in_quotes
            if char == ',' and not in_quotes:
                parts.append(current.strip())
                current = ""
            else:
                current += char
        if current:
            parts.append(current.strip())
        return parts

    def _parse_unquoted_format(self, answer_str):
        """Parse "(p1,i1),(p2,i2)" style text that is not a Python literal."""
        answer_str = self._strip_outer_wrapper(answer_str)

        # Collect every top-level "(...)" group; commas between groups are skipped.
        groups = []
        current = ""
        depth = 0
        for char in answer_str:
            if char == '(':
                depth += 1
                current += char
            elif char == ')':
                depth -= 1
                current += char
                if depth == 0:
                    groups.append(current)
                    current = ""
            elif char == ',' and depth == 0:
                continue
            else:
                current += char

        ownership = {}
        for group in groups:
            group = group.strip()
            if group.startswith('('):
                group = group[1:]
            if group.endswith(')'):
                group = group[:-1]

            parts = self._split_outside_quotes(group)
            if len(parts) >= 2:
                person = parts[0].strip().strip("'\"")
                item = parts[1].strip().strip("'\"")
                ownership[person] = item
        return ownership

    def _compare_answers(self, model_ownership, correct_ownership):
        """Compare two ownership dicts, ignoring case of names and items.

        @param model_ownership: mapping parsed from the model's answer.
        @param correct_ownership: expected mapping.
        @return: True when both contain the same people with the same items.
        """
        if len(model_ownership) != len(correct_ownership):
            return False

        model_by_lower_name = {person.lower(): person for person in model_ownership}

        for person, expected_item in correct_ownership.items():
            model_person = model_by_lower_name.get(person.lower())
            if model_person is None:
                return False
            if model_ownership[model_person].lower() != expected_item.lower():
                return False
        return True

    def _print_difference(self, model_ownership, correct_ownership):
        """Debug helper: print how the model's mapping differs from the expected one.

        @param model_ownership: mapping parsed from the model's answer.
        @param correct_ownership: expected mapping.
        """
        print("\n差异详情:")

        model_by_lower_name = {person.lower(): person for person in model_ownership}
        correct_lower_names = {person.lower() for person in correct_ownership}

        # People that are missing or own the wrong item.
        for person in correct_ownership:
            model_person = model_by_lower_name.get(person.lower())
            if model_person is None:
                print(f" - 模型答案中缺少: {person}")
            elif model_ownership[model_person].lower() != correct_ownership[person].lower():
                print(f" - {person}: 模型答案={model_ownership[model_person]}, 正确答案={correct_ownership[person]}")

        # People the model added on its own.
        for person in model_ownership:
            if person.lower() not in correct_lower_names:
                print(f" - 模型答案中多余: {person}")

    def extract_answer(self, text):
        """Extract the answer from the last ```python fenced block.

        Args:
            text (str): the model's full response

        Returns:
            str: the block's content when it starts with "(" and ends with
            ")", e.g. "(('p1','i1'),('p2','i2'))"; otherwise "".
        """
        if not text:
            return ""

        code_block_pattern = r'```python\s*\n(.*?)\n```'
        code_blocks = re.findall(code_block_pattern, text, re.DOTALL)
        if code_blocks:
            last_block = code_blocks[-1].strip()
            if last_block.startswith("(") and last_block.endswith(")"):
                return last_block
        return ""
# verl/utils/reward_score/synlogic/math_path_verifier.py
import ast
import operator
import re
import json
import numpy as np
from .data import Data
from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END


# Arithmetic the game allows; anything else in the answer is rejected.
_BINARY_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
}
_UNARY_OPERATORS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _evaluate_arithmetic(node):
    """Evaluate an AST made only of numbers and + - * / % (else ValueError)."""
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPERATORS:
        apply = _BINARY_OPERATORS[type(node.op)]
        return apply(_evaluate_arithmetic(node.left), _evaluate_arithmetic(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPERATORS:
        return _UNARY_OPERATORS[type(node.op)](_evaluate_arithmetic(node.operand))
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


def _equation_holds(expression):
    """Return True when *expression* is a true ``a == b [== c ...]`` equation.

    SECURITY: the expression comes from a model, so it is parsed and walked
    instead of being passed to eval().  Text that is not an equality between
    arithmetic expressions (including an expression with no "==" at all, which
    eval() would have accepted whenever it was non-zero) yields False.
    """
    try:
        comparison = ast.parse(expression, mode="eval").body
        if not isinstance(comparison, ast.Compare):
            return False
        if not all(isinstance(op, ast.Eq) for op in comparison.ops):
            return False
        operands = [comparison.left, *comparison.comparators]
        values = [_evaluate_arithmetic(operand) for operand in operands]
    except (SyntaxError, ValueError, ArithmeticError, RecursionError, MemoryError):
        # Unparseable text, disallowed syntax, division by zero, ...
        return False
    return all(left == right for left, right in zip(values, values[1:]))


class MathPathVerifier(Verifier):
    """Verifier for the math_path fill-in-the-digits game."""

    _OPERATOR_CHARS = ('+', '-', '*', '/', '%')

    def verify(self, data: Data, test_answer: str):
        """Return True when the filled-in expression is a valid solution.

        The answer must keep every given digit and operator of
        ``metadata["query_expr"]`` in order (only "?" may be replaced) and the
        resulting equation must hold.

        @param data: Data object; reads ``metadata["query_expr"]``.
        @param test_answer: the model's full response.
        @return: bool, False on any violation or error while checking.
        """
        try:
            test_answer = self.extract_answer(test_solution=test_answer)
            query_expr = data.metadata["query_expr"]

            test_tmp = test_answer.replace(' ', '').strip()
            query_tmp = query_expr.replace(' ', '').strip()

            # Digits: same count, and every non-"?" digit unchanged.
            query_nums = [c for c in query_tmp if '0' <= c <= '9' or c == '?']
            test_nums = [c for c in test_tmp if '0' <= c <= '9']
            if len(query_nums) != len(test_nums):
                return False
            if any(q != '?' and q != t for q, t in zip(query_nums, test_nums)):
                return False

            # Operators: identical sequence.
            query_symbols = [c for c in query_tmp if c in self._OPERATOR_CHARS]
            test_symbols = [c for c in test_tmp if c in self._OPERATOR_CHARS]
            if query_symbols != test_symbols:
                return False

            # The equation itself must be true ("=" is written "==" in Python).
            return _equation_holds(test_tmp.replace('=', '=='))
        except Exception:
            # Reward functions must never raise: malformed data counts as wrong.
            return False

    def extract_answer(self, test_solution: str):
        """Extract the expression from the last ``[[ ... ]]`` in the response.

        @param test_solution: the model's full response.
        @return: the stripped expression text, "" when none is found.
        """
        if not test_solution:
            return ""

        code_block_pattern = r'\[\[(.*?)\]\]'
        code_matches = re.findall(code_block_pattern, test_solution)
        if code_matches:
            # The last occurrence is taken to be the final answer.
            return code_matches[-1].strip()
        return ""
# verl/utils/reward_score/synlogic/minesweeper_verifier.py
from .data import Data
from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END
import re
import json
from typing import List, Tuple


class MinesweeperVerifier(Verifier):
    """Verifier for the Minesweeper puzzle."""

    def verify(self, data: Data, test_solution: str, **kwargs):
        """Return True when the predicted mines equal ``metadata["current_mines"]``.

        Order and duplicates are ignored; any error while checking yields False.
        """
        try:
            predicted = {tuple(mine) for mine in self.extract_answer(test_solution)}
            expected = {tuple(mine) for mine in data.metadata["current_mines"]}
            return predicted == expected
        except Exception:
            return False

    def extract_answer(self, response: str) -> List[Tuple[int, int]]:
        """Extract mine coordinates from the model's response.

        Three list notations are tried in turn; the first one with any match
        supplies the result, collecting the coordinates of all its matches.
        If none matches, every pair of numbers in the text is used.
        """
        list_notations = (
            r'\[\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)(?:\s*,\s*\(\s*\d+\s*,\s*\d+\s*\))*\s*\]',  # [(0,1),(2,3)]
            r'\[\s*\[\s*(\d+)\s*,\s*(\d+)\s*\](?:\s*,\s*\[\s*\d+\s*,\s*\d+\s*\])*\s*\]',  # [[0,1],[2,3]]
            r'\(\s*(\d+)\s*,\s*(\d+)\s*\)(?:\s*,\s*\(\s*\d+\s*,\s*\d+\s*\))*',  # (0,1),(2,3)
        )
        single_coordinate = re.compile(r'(?:\(|\[)\s*(\d+)\s*,\s*(\d+)\s*(?:\)|\])')

        for notation in list_notations:
            found = []
            for list_match in re.finditer(notation, response):
                try:
                    # Pull every (row, col) pair out of the matched list text.
                    for pair in single_coordinate.finditer(list_match.group(0)):
                        found.append((int(pair.group(1)), int(pair.group(2))))
                except Exception:
                    continue
            if found:
                return found

        # Last resort: treat any two numbers separated by non-digits as a pair.
        loose_pairs = re.findall(r'(\d+)[^\d]+(\d+)', response)
        if loose_pairs:
            return [(int(row), int(col)) for row, col in loose_pairs]

        return []
# verl/utils/reward_score/synlogic/norinori_verifier.py
from .data import Data
from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END
import re
from collections import defaultdict


class NorinoriVerifier(Verifier):
    """Verifier for the Norinori puzzle.

    Checks that a submitted list of dominoes obeys the rules against
    ``metadata["region_grid"]``.
    """

    def __init__(self):
        super().__init__()

    def verify(self, data: Data, test_solution: str):
        """Validate a Norinori answer.

        Args:
            data: game data; reads ``metadata["region_grid"]`` (square grid
                of region labels, with "X" marking cells that must be covered).
            test_solution: answer text containing dominoes as pairs of
                1-indexed coordinates.

        Returns:
            bool: True when every rule holds; False otherwise or on any error.
        """
        try:
            region_grid = data.metadata["region_grid"]
            n = len(region_grid)

            dominoes = self._parse_answer(test_solution)
            if dominoes is None:
                return False

            if not self._check_domino_shapes(dominoes):
                return False

            # Mark covered cells, rejecting out-of-range or doubly covered ones.
            covered = [[False] * n for _ in range(n)]
            for domino in dominoes:
                for i, j in domino:
                    row, col = i - 1, j - 1  # answers are 1-indexed
                    if not (0 <= row < n and 0 <= col < n):
                        return False
                    if covered[row][col]:
                        return False
                    covered[row][col] = True

            if not self._check_domino_adjacency(dominoes, n):
                return False

            # Every region must contain exactly two covered cells.  Regions
            # are seeded with 0 so that a region with no covered cell fails;
            # counting only covered cells let an empty answer pass.
            region_coverage = defaultdict(int)
            for i in range(n):
                for j in range(n):
                    label = region_grid[i][j]
                    if label == "X":
                        # Pre-shaded cells must be covered.
                        if not covered[i][j]:
                            return False
                        continue
                    region_coverage[label] += 1 if covered[i][j] else 0

            return all(count == 2 for count in region_coverage.values())
        except Exception:
            # Reward functions must never raise: malformed data counts as wrong.
            return False

    def _parse_answer(self, test_solution: str):
        """Extract the dominoes from the answer text.

        Args:
            test_solution: answer text.

        Returns:
            list: ``[[(i1, j1), (i2, j2)], ...]``; None when parsing fails.
        """
        try:
            pattern = r'\[\((\d+),\s*(\d+)\),\s*\((\d+),\s*(\d+)\)\]'
            matches = re.findall(pattern, test_solution)

            if not matches:
                # Looser form: "(i1, j1), (i2, j2)" without the brackets.
                pattern = r'\(\s*(\d+)\s*,\s*(\d+)\s*\)\s*,\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)'
                matches = re.findall(pattern, test_solution)

            dominoes = []
            for match in matches:
                i1, j1, i2, j2 = map(int, match)
                dominoes.append([(i1, j1), (i2, j2)])
            return dominoes
        except Exception:
            return None

    def _check_domino_shapes(self, dominoes):
        """Return True when every domino is two orthogonally adjacent cells.

        Args:
            dominoes: list of dominoes.
        """
        for domino in dominoes:
            if len(domino) != 2:
                return False
            (i1, j1), (i2, j2) = domino
            # Two cells are adjacent exactly when their Manhattan distance is 1.
            if abs(i1 - i2) + abs(j1 - j2) != 1:
                return False
        return True

    def _check_domino_adjacency(self, dominoes, n):
        """Return True when no two different dominoes touch orthogonally.

        Args:
            dominoes: list of dominoes (1-indexed coordinates).
            n: grid size; neighbours outside 1..n are ignored.
        """
        # Map each cell to the index of the domino occupying it.  A dict (not
        # a fixed-size grid) keeps this safe for out-of-range coordinates.
        owner = {}
        for idx, domino in enumerate(dominoes):
            for cell in domino:
                owner[tuple(cell)] = idx

        for (i, j), idx in owner.items():
            for di, dj in ((0, 1), (1, 0), (0, -1), (-1, 0)):
                ni, nj = i + di, j + dj
                if 1 <= ni <= n and 1 <= nj <= n:
                    neighbour = owner.get((ni, nj))
                    if neighbour is not None and neighbour != idx:
                        return False
        return True

    def extract_answer(self, test_solution: str, strict=False):
        """Extract the answer part of a response.

        Args:
            test_solution: the model's response.
            strict: unused.

        Returns:
            str: the last match of the first pattern that matches, or the
            whole response when no pattern matches.
        """
        answer_patterns = [
            r'\[\s*\[\s*\(\s*\d+\s*,\s*\d+\s*\)\s*,\s*\(\s*\d+\s*,\s*\d+\s*\)\s*\]',  # [[(1,2), (1,3)], ...]
            r'答案是\s*(.*?)\s*$',
            r'answer is\s*(.*?)\s*$',
            r'solution is\s*(.*?)\s*$',
        ]

        for pattern in answer_patterns:
            matches = re.findall(pattern, test_solution, re.IGNORECASE | re.DOTALL)
            if matches:
                # The last match is usually the final answer.
                return matches[-1]

        return test_solution
# verl/utils/reward_score/synlogic/number_wall_verifier.py
from .data import Data
from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END
import ast
import re
import json
from collections import deque


class NumberWallVerifier(Verifier):
    """Verifier for the Number Wall puzzle.

    A solution grid holds ints (clues), "X" (island cell) and "A" (wall).
    """

    _NEIGHBOUR_STEPS = ((0, 1), (1, 0), (0, -1), (-1, 0))

    # One cell: "X"/"A" or a number, optionally quoted with ' or ".
    _CELL = r'''(?:"[XA]"|'[XA]'|[0-9]+|"[0-9]+"|'[0-9]+')'''
    _ROW = r'\[(?:\s*' + _CELL + r'\s*,\s*)*\s*' + _CELL + r'\s*\]'
    _GRID_PATTERN = r'\[\s*' + _ROW + r'\s*(?:,\s*' + _ROW + r'\s*)*\]'

    def verify(self, data: Data, test_solution: str, **kwargs):
        """Return True when the extracted grid is a valid solution.

        Reads ``metadata["grid"]`` (the puzzle) and ``metadata["n"]`` (its
        size).  Any error while checking yields False.
        """
        try:
            solution_grid = self.extract_answer(test_solution)
            if not solution_grid:
                return False

            original_grid = data.metadata["grid"]
            n = data.metadata["n"]

            # Shape and cell contents.
            if len(solution_grid) != n:
                return False
            for row in solution_grid:
                if len(row) != n:
                    return False
                for cell in row:
                    if not (isinstance(cell, int) or cell in ["X", "A"]):
                        return False

            return (
                self._check_original_numbers(original_grid, solution_grid)
                and self._check_wall_layout(solution_grid)
                and self._check_islands(solution_grid)
                and self._check_diagonal_borders(solution_grid)
            )
        except Exception:
            # Reward functions must never raise: malformed data counts as wrong.
            return False

    @staticmethod
    def _clue_value(cell):
        """Return the clue number held by *cell*, or None when it holds none."""
        if isinstance(cell, int):
            return cell
        if isinstance(cell, str) and cell.isdigit():
            return int(cell)
        return None

    def _check_original_numbers(self, original_grid, solution_grid):
        """Check that the solution has exactly the puzzle's clue numbers.

        Every clue must be kept unchanged, and the solution must not place a
        number where the puzzle had none; otherwise a solver could invent its
        own clues and satisfy the island checks trivially.
        """
        for i, original_row in enumerate(original_grid):
            for j, original_cell in enumerate(original_row):
                if self._clue_value(original_cell) != self._clue_value(solution_grid[i][j]):
                    return False
        return True

    def _check_wall_layout(self, grid):
        """Check that no 2x2 block consists entirely of wall cells."""
        n = len(grid)
        for i in range(n - 1):
            for j in range(n - 1):
                if (grid[i][j] == "A" and grid[i][j + 1] == "A"
                        and grid[i + 1][j] == "A" and grid[i + 1][j + 1] == "A"):
                    return False
        return True

    def _collect_island(self, grid, start, visited):
        """Breadth-first collect the non-wall cells connected to *start*.

        Adds every reached cell to *visited* and returns them as a list.
        """
        n = len(grid)
        visited.add(start)
        queue = deque([start])
        cells = []
        while queue:
            r, c = queue.popleft()
            cells.append((r, c))
            for dr, dc in self._NEIGHBOUR_STEPS:
                nr, nc = r + dr, c + dc
                if (0 <= nr < n and 0 <= nc < n
                        and (nr, nc) not in visited and grid[nr][nc] != "A"):
                    visited.add((nr, nc))
                    queue.append((nr, nc))
        return cells

    def _check_islands(self, grid):
        """Check each island has exactly one number, equal to its size."""
        n = len(grid)
        visited = set()
        for i in range(n):
            for j in range(n):
                if (i, j) in visited or grid[i][j] == "A":
                    continue
                island_cells = self._collect_island(grid, (i, j), visited)
                numbers = [grid[r][c] for r, c in island_cells
                           if isinstance(grid[r][c], int)]
                if len(numbers) != 1:
                    return False  # no number, or more than one
                if len(island_cells) != numbers[0]:
                    return False  # size does not match the number
        return True

    def _check_diagonal_borders(self, grid):
        """Check that different islands never touch only at a corner."""
        n = len(grid)

        # Label every island cell with its island id.
        island_map = {}
        visited = set()
        island_id = 0
        for i in range(n):
            for j in range(n):
                if grid[i][j] != "A" and (i, j) not in visited:
                    for cell in self._collect_island(grid, (i, j), visited):
                        island_map[cell] = island_id
                    island_id += 1

        # In each 2x2 window, two diagonal island cells separated by two walls
        # must belong to the same island.
        for i in range(n - 1):
            for j in range(n - 1):
                if (grid[i][j] != "A" and grid[i + 1][j + 1] != "A"
                        and grid[i][j + 1] == "A" and grid[i + 1][j] == "A"):
                    if island_map.get((i, j)) != island_map.get((i + 1, j + 1)):
                        return False
                if (grid[i][j + 1] != "A" and grid[i + 1][j] != "A"
                        and grid[i][j] == "A" and grid[i + 1][j + 1] == "A"):
                    if island_map.get((i, j + 1)) != island_map.get((i + 1, j)):
                        return False
        return True

    def extract_answer(self, response: str):
        """Extract the last grid literal from the response.

        @return: list of rows with digit strings converted to int, or None
                 when no grid is found or it cannot be parsed.
        """
        matches = re.findall(self._GRID_PATTERN, response)
        if not matches:
            return None

        grid_text = matches[-1].replace('\n', '').replace('\r', '').strip()
        try:
            grid = json.loads(grid_text)
        except json.JSONDecodeError:
            # Single-quoted cells are not JSON; parse them as a Python literal.
            try:
                grid = ast.literal_eval(grid_text)
            except Exception:
                return None

        # Normalise quoted numbers ("3") to ints.
        for row in grid:
            for j, cell in enumerate(row):
                if isinstance(cell, str) and cell.isdigit():
                    row[j] = int(cell)
        return grid
获取原始谜题和网格大小 + original_grid = data.metadata["grid"] + n = len(original_grid) + n_squared = n * n + + # 检查网格大小是否正确 + if len(test_grid) != n or any(len(row) != n for row in test_grid): + return False + + # 检查是否包含所有数字 1 到 n² + flattened_grid = [cell for row in test_grid for cell in row] + if sorted(flattened_grid) != list(range(1, n_squared + 1)): + return False + + # 检查是否保留了原始提示数字 + for i in range(n): + for j in range(n): + if original_grid[i][j] != "X" and test_grid[i][j] != original_grid[i][j]: + return False + + # 检查连续数字是否正交相邻 + for num in range(1, n_squared): + # 找到当前数字的位置 + current_pos = None + next_pos = None + for i in range(n): + for j in range(n): + if test_grid[i][j] == num: + current_pos = (i, j) + elif test_grid[i][j] == num + 1: + next_pos = (i, j) + + if current_pos is None or next_pos is None: + return False + + # 检查是否正交相邻(曼哈顿距离为1) + i1, j1 = current_pos + i2, j2 = next_pos + manhattan_distance = abs(i1 - i2) + abs(j1 - j2) + if manhattan_distance != 1: + return False + + return True + except Exception as e: + return False + + def extract_answer(self, test_solution: str, strict=False): + """从模型回答中提取网格""" + try: + import ast + import re + # 尝试找到 Python 列表格式的答案 + # 寻找形如 [[1, 2, 3], [4, 5, 6], [7, 8, 9]] 的模式 + pattern = r'\[\s*\[\s*\d+.*?\]\s*\]' + matches = re.finditer(pattern, test_solution, re.DOTALL) + match = None + + # 获取最后一个匹配项 + for m in matches: + match = m + if not match: + return None + + # 提取匹配的文本并尝试解析为 Python 对象 + grid_text = match.group(0) + + # 清理文本,确保它是有效的 Python 列表 + # 移除可能导致解析错误的字符 + grid_text = grid_text.replace("'", "").replace('"', "") + + # 解析为 Python 对象 + grid = ast.literal_eval(grid_text) + + # 确保是二维列表且所有元素都是整数 + if not isinstance(grid, list) or not all(isinstance(row, list) for row in grid): + return None + + if not all(isinstance(cell, int) for row in grid for cell in row): + return None + + return grid + except Exception as e: + return None \ No newline at end of file diff --git 
a/verl/utils/reward_score/synlogic/object_counting_verifier.py b/verl/utils/reward_score/synlogic/object_counting_verifier.py new file mode 100644 index 000000000..389cafce2 --- /dev/null +++ b/verl/utils/reward_score/synlogic/object_counting_verifier.py @@ -0,0 +1,44 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + + +class ObjectCountingVerifier(Verifier): + """ + 验证器用于物品计数游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = int(data.answer) + parsed_answer = self.extract_answer(test_answer) + with open("solution_str_OC.txt", "a") as f: + f.write("data.answer: " + data.answer + '\n') + f.write("test_answer: " + test_answer + '\n') + f.write("parsed_answer" + parsed_answer + '\n') + f.write('-'*32 + '\n') + + if parsed_answer is None: + return False + return int(parsed_answer) == ground_truth + + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return answer_str + \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/object_properties_verifier.py b/verl/utils/reward_score/synlogic/object_properties_verifier.py new file mode 100644 index 000000000..6640a9baf --- /dev/null +++ b/verl/utils/reward_score/synlogic/object_properties_verifier.py @@ -0,0 +1,39 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + + +class ObjectPropertiesVerifier(Verifier): + """ + 验证器用于物品拥有游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = int(data.answer) + parsed_answer = 
int(self.extract_answer(test_answer)) + + if parsed_answer is None: + return False + return int(parsed_answer) == ground_truth + + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\Box{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return None + \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/operation_verifier.py b/verl/utils/reward_score/synlogic/operation_verifier.py new file mode 100644 index 000000000..25a3000b7 --- /dev/null +++ b/verl/utils/reward_score/synlogic/operation_verifier.py @@ -0,0 +1,46 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + + +class OperationVerifier(Verifier): + """ + 验证器用于物品计数游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = math_verify.parse(data.answer) + parsed_answer = math_verify.parse(test_answer) + + if parsed_answer is None: + return False + return math_verify.verify(parsed_answer, ground_truth) + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return 
latex_content \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py new file mode 100644 index 000000000..05b010561 --- /dev/null +++ b/verl/utils/reward_score/synlogic/skyscraper_puzzle_verifier.py @@ -0,0 +1,167 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +import ast + + +class SkyscraperPuzzleVerifier(Verifier): + """ + 摩天楼游戏验证器,用于验证模型提供的解答是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否符合摩天楼游戏的规则 + + @param data: 包含游戏信息的Data对象 + @param test_answer: 游戏类提取的网格数据 + @return: 回答是否正确的布尔值 + """ + try: + # 获取游戏元数据 + metadata = data.metadata + n = metadata['n'] + top = metadata['top'] + bottom = metadata['bottom'] + left = metadata['left'] + right = metadata['right'] + + self.n = n + test_answer = self.extract_answer(test_solution) + + # print(f"验证: 游戏规模 {n}×{n}") + # print(f"上方提示: {top}") + # print(f"下方提示: {bottom}") + # print(f"左侧提示: {left}") + # print(f"右侧提示: {right}") + + # 使用提取好的网格数据 + grid = test_answer + + # 检查网格是否是字符串,如果是,说明提取失败 + if isinstance(grid, str): + # print("无法提取有效网格") + return False + + # print("提取的网格:") + # for row in grid: + # print(row) + + # 检查网格规模 + if len(grid) != n or any(len(row) != n for row in grid): + # print(f"网格规模不正确,应为 {n}×{n}") + return False + + # 检查数字范围 (1 到 n) + for i in range(n): + for j in range(n): + if not isinstance(grid[i][j], int) or grid[i][j] < 1 or grid[i][j] > n: + # print(f"位置 ({i+1},{j+1}) 的值 {grid[i][j]} 不在有效范围内 (1-{n})") + return False + + # 检查每行唯一性 + for i in range(n): + if len(set(grid[i])) != n: + # print(f"第 {i+1} 行包含重复数字") + return False + + # 检查每列唯一性 + for j in range(n): + column = [grid[i][j] for i in range(n)] + if len(set(column)) != n: + # print(f"第 {j+1} 列包含重复数字") + return False + + # 检查从上方观察 + for j in range(n): + visible_count = self._count_visible_skyscrapers([grid[i][j] for i 
in range(n)]) + if visible_count != top[j]: + # print(f"从上方看第 {j+1} 列可见楼数为 {visible_count},应为 {top[j]}") + return False + + # 检查从下方观察 + for j in range(n): + visible_count = self._count_visible_skyscrapers([grid[i][j] for i in range(n-1, -1, -1)]) + if visible_count != bottom[j]: + # print(f"从下方看第 {j+1} 列可见楼数为 {visible_count},应为 {bottom[j]}") + return False + + # 检查从左侧观察 + for i in range(n): + visible_count = self._count_visible_skyscrapers(grid[i]) + if visible_count != left[i]: + # print(f"从左侧看第 {i+1} 行可见楼数为 {visible_count},应为 {left[i]}") + return False + + # 检查从右侧观察 + for i in range(n): + visible_count = self._count_visible_skyscrapers(grid[i][::-1]) + if visible_count != right[i]: + # print(f"从右侧看第 {i+1} 行可见楼数为 {visible_count},应为 {right[i]}") + return False + + # 所有检查通过 + # print("所有验证规则通过!") + return True + + except Exception as e: + return False + + def _count_visible_skyscrapers(self, heights): + """ + 计算从一个方向看过去能看到的摩天楼数量 + + @param heights: 从观察方向依次排列的摩天楼高度列表 + @return: 可见的摩天楼数量 + """ + visible_count = 0 + max_height = 0 + + for height in heights: + if height > max_height: + visible_count += 1 + max_height = height + + return visible_count + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取网格数据 + + @param test_solution: 模型的完整回答 + @return: 提取的解答网格数据 + """ + try: + n = self.n + + # 从 ```python 代码块中提取 + code_block_pattern = r"```python\s*\n([\s\S]*?)\n\s*```" + code_blocks = re.findall(code_block_pattern, test_solution) + + if code_blocks: + # 取第一个代码块(通常只有一个) + code_block = code_blocks[0].strip() + try: + # 直接解析代码块 + grid = ast.literal_eval(code_block) + # 验证是否为有效的n×n网格 + if (isinstance(grid, list) and + len(grid) == n and + all(isinstance(row, list) and len(row) == n for row in grid)): + return grid + except Exception: + # 如果直接解析失败,尝试移除注释后再解析 + code_without_comments = re.sub(r'#.*$', '', code_block, flags=re.MULTILINE) + try: + grid = ast.literal_eval(code_without_comments.strip()) + if (isinstance(grid, list) and + len(grid) == n and + 
all(isinstance(row, list) and len(row) == n for row in grid)): + return grid + except Exception: + pass + + # 如果提取失败,返回原始答案 + return test_solution + except Exception as e: + return test_solution \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py new file mode 100644 index 000000000..abc165d5c --- /dev/null +++ b/verl/utils/reward_score/synlogic/space_reasoning_tree_verifier.py @@ -0,0 +1,44 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + +class SpaceReasoningTreeVerifier(Verifier): + """ + 验证器用于空间推理树游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + test_answer = self.extract_answer(test_answer) + if test_answer is None: + return False + test_answer = test_answer.replace(",", ",").replace(" ", "") + ground_truth = data.answer.replace(",", ",").replace(" ", "") + test_set = set(test_answer.split(",")) + ground_truth_set = set(ground_truth.split(",")) + return test_set == ground_truth_set + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/space_reasoning_verifier.py b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py new file mode 100644 index 000000000..249f2dc08 --- 
/dev/null +++ b/verl/utils/reward_score/synlogic/space_reasoning_verifier.py @@ -0,0 +1,41 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import math_verify + + +class SpaceReasoningVerifier(Verifier): + """ + 验证器用于空间推理游戏的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + test_answer = self.extract_answer(test_answer) + if test_answer is None: + return False + return test_answer.lower() == data.answer.lower() + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从\boxed{开始截取到正确的闭合位置,处理嵌套括号 + start_index = last_box_index + len("\\boxed{") + bracket_stack = 1 # 已经遇到了一个左括号 + end_index = start_index + + while end_index < len(answer_str) and bracket_stack > 0: + if answer_str[end_index] == '{': + bracket_stack += 1 + elif answer_str[end_index] == '}': + bracket_stack -= 1 + end_index += 1 + + if bracket_stack != 0: # 括号不匹配 + return None + + # 提取\boxed{}内的内容 + latex_content = answer_str[start_index:end_index-1].strip() + return latex_content \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py new file mode 100644 index 000000000..2ea42446f --- /dev/null +++ b/verl/utils/reward_score/synlogic/star_placement_puzzle_verifier.py @@ -0,0 +1,158 @@ +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END +import re +import json +import ast + +import re + +class StarPlacementPuzzleVerifier(Verifier): + """ + 星星放置游戏验证器,用于验证模型提供的解答是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否符合星星放置游戏的规则 + + @param data: 包含游戏信息的Data对象 + @param star_coords: 通过extract_answer提取的星星坐标字典 {区域: [(行,列), ...]} + @return: 回答是否正确的布尔值 + """ + try: + star_coords = self.extract_answer(test_solution) + # 获取游戏元数据 + 
metadata = data.metadata + n = metadata['n'] + k = metadata['k'] + region_grid = metadata['region_grid'] + + # print(f"验证: 游戏规模 {n}×{n}, 每行/列/区域星星数量: {k}") + + # 检查是否有有效的星星坐标 + if not star_coords: + # print("无法从回答中提取有效的星星坐标") + return False + + # 创建一个表示星星位置的网格 + star_grid = [[0 for _ in range(n)] for _ in range(n)] + for region, coords in star_coords.items(): + for coord in coords: + row, col = coord + if row < 0 or row >= n or col < 0 or col >= n: + # print(f"无效坐标: ({row},{col}) - 超出网格范围") + return False + star_grid[row][col] = 1 + + # 打印星星网格以便调试 + # print("星星网格:") + # for row in star_grid: + # print(''.join(['* ' if cell == 1 else '. ' for cell in row])) + + # 1. 检查每行是否有k颗星星 + for i in range(n): + stars_in_row = sum(star_grid[i]) + if stars_in_row != k: + # print(f"行 {i+1} 有 {stars_in_row} 颗星星,应该有 {k} 颗") + return False + + # 2. 检查每列是否有k颗星星 + for j in range(n): + stars_in_col = sum(star_grid[i][j] for i in range(n)) + if stars_in_col != k: + # print(f"列 {j+1} 有 {stars_in_col} 颗星星,应该有 {k} 颗") + return False + + # 3. 检查每个区域是否有k颗星星 + regions = {} + for i in range(n): + for j in range(n): + region = region_grid[i][j] + if region not in regions: + regions[region] = [] + regions[region].append((i, j)) + + for region, cells in regions.items(): + stars_in_region = sum(star_grid[i][j] for i, j in cells) + if stars_in_region != k: + # print(f"区域 {region} 有 {stars_in_region} 颗星星,应该有 {k} 颗") + return False + + # 4. 
检查星星是否互不相邻(水平、垂直、对角线) + for i in range(n): + for j in range(n): + if star_grid[i][j] == 1: + # 检查周围8个方向 + for di in [-1, 0, 1]: + for dj in [-1, 0, 1]: + if di == 0 and dj == 0: + continue # 跳过自身 + ni, nj = i + di, j + dj + if 0 <= ni < n and 0 <= nj < n and star_grid[ni][nj] == 1: + # print(f"星星在 ({i},{j}) 与星星在 ({ni},{nj}) 相邻") + return False + + # 所有检查通过 + # print("所有验证规则通过!") + return True + + except Exception as e: + return False + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取星星坐标 + + @param test_solution: 模型的完整回答 + @return: 提取的星星坐标字典 {区域: [(行,列), ...]} + """ + try: + # 从Python代码块中提取 + python_match = re.search(r'```python\s*\n(.*?)\n\s*```', test_solution, re.DOTALL) + if not python_match: + # print("回答中没有找到```python代码块") + return None + + code_content = python_match.group(1) + + # 尝试从Python代码中提取字典 + try: + # 先尝试直接提取字典内容 + dict_match = re.search(r'\{[^{}]*\}', code_content, re.DOTALL) + if dict_match: + dict_str = dict_match.group(0) + try: + # 将字符串转换为字典 + coords_dict = ast.literal_eval(dict_str) + # 如果成功且是字典类型,继续处理 + if isinstance(coords_dict, dict): + # 将坐标减1(因为用户输入的坐标是1-索引) + result = {} + for region, coords in coords_dict.items(): + result[region] = [(row-1, col-1) for row, col in coords] + return result + except (ValueError, SyntaxError) as e: + pass + + # 如果上面的方法失败,尝试解析变量赋值 + assign_match = re.search(r'(\w+)\s*=\s*(\{[^{}]*\})', code_content, re.DOTALL) + if assign_match: + dict_str = assign_match.group(2) + try: + # 将字符串转换为字典 + coords_dict = ast.literal_eval(dict_str) + # 如果成功且是字典类型,继续处理 + if isinstance(coords_dict, dict): + # 将坐标减1(因为用户输入的坐标是1-索引) + result = {} + for region, coords in coords_dict.items(): + result[region] = [(row-1, col-1) for row, col in coords] + return result + except (ValueError, SyntaxError) as e: + pass + except Exception as e: + pass + + return None + + except Exception as e: + return None \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/synlogic.py 
b/verl/utils/reward_score/synlogic/synlogic.py new file mode 100644 index 000000000..29b08b122 --- /dev/null +++ b/verl/utils/reward_score/synlogic/synlogic.py @@ -0,0 +1,92 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +# from .game_of_24.scripts.game_of_24_verifier import GameOf24Verifier +# from .cryptarithm.scripts.cryptarithm_verifier import CryptarithmVerifier +# from .survo.scripts.survo_verifier import SurvoVerifier +from .campsite_verifier import CampsiteVerifier +from .skyscraper_puzzle_verifier import SkyscraperPuzzleVerifier +from .web_of_lies_verifier import WebOfLiesVerifier +from .goods_exchange_verifier import GoodsExchangeVerifier +# from .sudoku.scripts.sudoku_verifier import SudokuVerifier +# from corpus.misc.tasks.zebra_puzzle.scripts.zebra_puzzle_verifier import ZebraPuzzleVerifier +# from corpus.misc.tasks.bbeh.scripts.bbeh_verifier import BBEHVerifier +# from corpus.misc.tasks.arc_agi.scripts.arc_agi_verifier import ArcAGIVerifier +from .object_properties_verifier import ObjectPropertiesVerifier +from .object_counting_verifier import ObjectCountingVerifier +from .star_placement_puzzle_verifier import StarPlacementPuzzleVerifier +from .arrow_maze_verifier import ArrowMazeVerifier +# from .kukurasu.scripts.kukurasu_verifier import KukurasuVerifier +from .number_wall_verifier import NumberWallVerifier +from .numbrix_verifier import NumbrixVerifier +from .norinori_verifier import NorinoriVerifier +from .minesweeper_verifier import MinesweeperVerifier +from .operation_verifier import OperationVerifier +from .word_sorting_mistake_verifier import WordSortingMistakeVerifier +from .math_path_verifier import MathPathVerifier +from .boolean_expressions_verifier import BooleanExpressionsVerifier +from .space_reasoning_verifier import SpaceReasoningVerifier +from .space_reasoning_tree_verifier import SpaceReasoningTreeVerifier +from .word_sorting_verifier import WordSortingVerifier +# from 
corpus.misc.tasks.gpqa.scripts.gpqa_verifier import GPQAVerifier +# from .cipher.scripts.cipher_verifier import CipherVerifier +from .time_sequence_verifier import TimeSequenceVerifier +from .wordscapes_verifier import WordscapesVerifier +# from corpus.misc.tasks.bbh.scripts.boolean_expressions_verifier import BBHBooleanExpressionsVerifier +# from corpus.misc.tasks.bbh.scripts.causal_judgement_verifier import BBHCausalJudgementVerifier # yes no +# from corpus.misc.tasks.bbh.scripts.date_understanding_verifier import BBHDateUnderstandingVerifier # multi-choice +# from corpus.misc.tasks.bbh.scripts.dyck_languages_verifier import BBHDyckLanguagesVerifier +# from corpus.misc.tasks.bbh.scripts.formal_fallacies_verifier import BBHFormalFallaciesVerifier +# from corpus.misc.tasks.bbh.scripts.multistep_arithmetic_two_verifier import BBHMultistepArithmeticVerifier # number +# from corpus.misc.tasks.bbh.scripts.sports_understanding_verifier import BBHSportsUnderstandingVerifier +# from corpus.misc.tasks.bbh.scripts.web_of_lies_verifier import BBHWebOfLiesVerifier +# from corpus.misc.tasks.bbh.scripts.word_sorting_verifier import BBHWordSortingVerifier +from .game_of_buggy_tables_verifier import BuggyTableVerifier +# from .calcudoko.scripts.calcudoko_verifier import CalcudokoVerifier +from .dyck_language_verifier import DyckLanguageVerifier +from .dyck_language_errors_verifier import DyckLanguageErrorsVerifier +from .dyck_language_reasoning_errors_verifier import DyckLanguageReasoningErrorsVerifier +# from .futoshiki.scripts.futoshiki_verifier import FutoshikiVerifier + +# NOTE: Add new tasks in alphabetical order +verifier_classes = { + "arrow_maze": ArrowMazeVerifier, + "boolean_expressions": BooleanExpressionsVerifier, + "buggy_tables": BuggyTableVerifier, + # "calcudoko": CalcudokoVerifier, + "campsite": CampsiteVerifier, + # "cipher": CipherVerifier, + # "cryptarithm": CryptarithmVerifier, + "dyck_language": DyckLanguageVerifier, + "dyck_language_errors": 
DyckLanguageErrorsVerifier, + "dyck_language_reasoning_errors": DyckLanguageReasoningErrorsVerifier, + # "futoshiki": FutoshikiVerifier, + "goods_exchange": GoodsExchangeVerifier, + # "gpqa_diamond": GPQAVerifier, + # "kukurasu": KukurasuVerifier, + "math_path": MathPathVerifier, + # "arc_agi": ArcAGIVerifier, + # "arc_agi_2": ArcAGIVerifier, + # "mathador": GameOf24Verifier, + "minesweeper": MinesweeperVerifier, + "norinori": NorinoriVerifier, + "number_wall": NumberWallVerifier, + "numbrix": NumbrixVerifier, + "object_counting": ObjectCountingVerifier, + "object_properties": ObjectPropertiesVerifier, + "operation": OperationVerifier, + "skyscraper_puzzle": SkyscraperPuzzleVerifier, + "space_reasoning": SpaceReasoningVerifier, + "space_reasoning_tree": SpaceReasoningTreeVerifier, + "star_placement_puzzle": StarPlacementPuzzleVerifier, + # "sudoku": SudokuVerifier, + # "survo": SurvoVerifier, + "time_sequence": TimeSequenceVerifier, + "web_of_lies": WebOfLiesVerifier, + "word_sorting": WordSortingVerifier, + "word_sorting_mistake": WordSortingMistakeVerifier, + "wordscapes": WordscapesVerifier, + # "zebra_puzzle": ZebraPuzzleVerifier, + # ** bbeh_classes, + # ** bbh_classes, +} \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/time_sequence_verifier.py b/verl/utils/reward_score/synlogic/time_sequence_verifier.py new file mode 100644 index 000000000..8ae6aa161 --- /dev/null +++ b/verl/utils/reward_score/synlogic/time_sequence_verifier.py @@ -0,0 +1,66 @@ +import json +import numpy as np +from .data import Data +from .verifier import Verifier +import re + +class TimeSequenceVerifier(Verifier): + """ + 验证器用于验证 time sequence 的答案是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的答案,格式为数字列表 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution) + # 解析元数据 + metadata = data.metadata + true_answers = 
metadata['records']['answers'] + + # 解析模型给出的列表 + try: + test_list = json.loads(test_answer.replace(",", ",")) + except: + return False + + try: + if test_list[0]!=true_answers['answer_maxLen']: + # print(f"最长会议时间不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + return False + if test_list[1]!=true_answers['answer_nums']: + # print(f"可选会议数量不正确。model:{test_answer} *** true:[{true_answers['answer_maxLen']}, {true_answers['answer_nums']}]") + return False + except: + return False + + # 所有检查都通过 + # print("验证结果: 正确") + return True + except Exception as e: + return False + + def extract_answer(self, test_solution: str): + """ + 从模型的回答中提取答案(矩阵) + + @param test_solution: 模型的完整回答 + @return: 提取答案列表 + """ + if not test_solution: + return "" + + # 尝试提取列表 + matrix_pattern = r'\[.*?\]' + matrix_matches = re.findall(matrix_pattern, test_solution, re.DOTALL) + if matrix_matches: + # 使用最后一个匹配的列表 + # print(matrix_matches) + return matrix_matches[-1].strip() + + # 如果失败,返回空字符串 + return "" \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/verifier.py b/verl/utils/reward_score/synlogic/verifier.py new file mode 100644 index 000000000..498e87a82 --- /dev/null +++ b/verl/utils/reward_score/synlogic/verifier.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from .data import Data + +class Verifier(ABC): + """ + Base class for verifier + """ + def __init__(self): + pass + + @abstractmethod + def verify(self, data: Data, test_answer: str): + """ + Verify whether the test answer is consistent with the gold answer + @param data: Data + @param test_answer: str + @return: bool + """ + raise NotImplementedError("Verifier.verify() is not implemented") + + @abstractmethod + def extract_answer(self, test_solution: str): + """ + Extract the answer from the test solution + @param test_solution: str + @return: str + """ + raise NotImplementedError("Verifier.extract_answer() is not implemented") + +import re + 
+THOUGHT_DELIMITER_START = "" +THOUGHT_DELIMITER_END = "" + +def _extract_answer(text): + # 定义正则表达式模式,匹配 之间的内容 + pattern = r'(.*?)' + + # 使用 re.search 查找第一个匹配项 + match = re.search(pattern, text, re.DOTALL) + + # 如果找到匹配项,返回匹配的内容 + if match: + return match.group(1).strip() + else: + return None + +def _extract_solution_with_thought(solution_str): + + model_output = solution_str + + if THOUGHT_DELIMITER_END in solution_str: + model_output = solution_str.split(THOUGHT_DELIMITER_END)[1] + + predict_answer = _extract_answer(model_output) + + + if predict_answer is not None: + return predict_answer + else: + return "" + + +class ExactMatchVerifier(Verifier): + """ + Verifier for Exact Match + """ + def verify(self, data: Data, test_solution: str): + try: + test_answer = self.extract_answer(test_solution) + ground_truth = data.answer + correct = test_answer == ground_truth + if correct: + acc_score = 1.0 + else: + acc_score = 0 + + return acc_score + except: + return False + + def extract_answer(self, test_solution: str): + return _extract_solution_with_thought(solution_str=test_solution) \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/web_of_lies_verifier.py b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py new file mode 100644 index 000000000..e301d2cea --- /dev/null +++ b/verl/utils/reward_score/synlogic/web_of_lies_verifier.py @@ -0,0 +1,134 @@ +import re +from .data import Data +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + +class WebOfLiesVerifier(Verifier): + """ + 验证器用于检查谎言之网游戏的答案是否正确 + """ + def verify(self, data: Data, test_solution: str): + """ + 验证模型的回答是否正确 + + @param data: 包含问题、元数据等信息的Data对象 + @param test_answer: 模型给出的回答字符串 + @return: 回答是否正确的布尔值 + """ + try: + test_answer = self.extract_answer(test_solution) + # 获取预期答案和测试答案 + expected_answer = data.answer.lower() + + # 清理测试答案 + test_answer = test_answer.lower() + + # 提取预期答案中的真假值 + expected_truths = self._parse_answer(expected_answer) + + # 
提取测试答案中的真假值 + test_truths = self._parse_answer(test_answer) + + # print(f"验证: 预期答案={expected_truths}, 模型答案={test_truths}") + + # 检查答案列表长度是否匹配 + if len(expected_truths) != len(test_truths): + # print(f"验证失败: 答案长度不匹配,预期 {len(expected_truths)},实际 {len(test_truths)}") + return False + + # 检查每个位置的答案是否匹配 + for i, (expected, actual) in enumerate(zip(expected_truths, test_truths)): + if expected != actual: + # print(f"验证失败: 第 {i+1} 个答案不匹配,预期 {expected},实际 {actual}") + return False + + # print("验证成功: 所有答案匹配") + return True + + except Exception as e: + return False + + def _parse_answer(self, answer_str): + """ + 从答案字符串中解析出真假值列表 + + @param answer_str: 答案字符串 + @return: 真假值列表,True表示说真话,False表示说谎话 + """ + # 尝试匹配英文答案格式 (yes/no) + yes_pattern = r'yes|true|truth' + no_pattern = r'no|false|lie' + + # 尝试匹配中文答案格式 (是/否) + cn_yes_pattern = r'是|真话|真' + cn_no_pattern = r'否|假话|假|谎' + + # 组合模式 + yes_patterns = f'({yes_pattern}|{cn_yes_pattern})' + no_patterns = f'({no_pattern}|{cn_no_pattern})' + + # 根据答案字符串中的关键词确定真假值 + truths = [] + + # 寻找所有可能的yes/no或是/否答案 + all_answers = re.findall(rf'{yes_patterns}|{no_patterns}', answer_str) + + for match in all_answers: + # match是一个元组,需要找到非空的元素 + match_str = next((m for m in match if m), '') + + if re.search(yes_pattern, match_str) or re.search(cn_yes_pattern, match_str): + truths.append(True) + elif re.search(no_pattern, match_str) or re.search(cn_no_pattern, match_str): + truths.append(False) + + return truths + + def extract_answer(self, test_solution: str) -> str: + """ + 从模型的回答中提取答案 + + @param test_solution: 模型的完整回答 + @return: 提取的答案 + """ + if not test_solution: + return "" + # 中文模式 + cn_patterns = [ + r'答案是[::]\s*\*\*([^*]+)\*\*[.。]*$', # 匹配"答案是:**是,否,是**"格式 + ] + + # 英文模式 + en_patterns = [ + r'[Tt]he answer is[::=]\s*\*\*([^*]+)\*\*[.。]*$', # 匹配"The answer is: **yes, no, yes**"格式 + ] + + # 尝试匹配所有模式 + patterns = cn_patterns + en_patterns + + for pattern in patterns: + matches = re.findall(pattern, test_solution, re.DOTALL) + if matches: + return 
matches[-1].strip() + + # 如果上面的模式都没匹配到,尝试更宽松的匹配 + # 查找最后一行中的加粗文本 + lines = test_solution.strip().split('\n') + if lines: + last_line = lines[-1].strip() + bold_match = re.search(r'\*\*([^*]+)\*\*', last_line) + if bold_match: + return bold_match.group(1).strip() + + # 尝试匹配"答案是"或"The answer is"后面的文本 + answer_match = re.search(r'(?:答案是|[Tt]he answer is)[::=]?\s*(.*?)(?:[.。]|$)', last_line) + if answer_match: + return answer_match.group(1).strip() + + # 如果没有找到格式化的答案,尝试直接匹配yes/no或是/否序列 + yes_no_pattern = r'(?:\b(?:yes|no|是|否)\b[,,\s]*)+' + matches = re.findall(yes_no_pattern, test_solution.lower()) + if matches: + return matches[-1].strip() + + # 如果没有匹配到任何模式,返回空字符串 + return "" \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py new file mode 100644 index 000000000..f2ee2109a --- /dev/null +++ b/verl/utils/reward_score/synlogic/word_sorting_mistake_verifier.py @@ -0,0 +1,44 @@ +import re +from .data import Data +from .verifier import Verifier + +class WordSortingMistakeVerifier(Verifier): + """ + 验证器用于word sorting mistake的答案是否正确 + """ + def verify(self, data: Data, test_answer: str): + try: + ground_truth = data.answer if data.answer is not None else "No" + parsed_answer = self.extract_answer(test_answer) + + if parsed_answer is None: + return False + + if parsed_answer.isdigit(): + try: + return int(parsed_answer) == int(ground_truth) + except Exception as e: + return False + else: + return parsed_answer.lower() == ground_truth.lower() + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\boxed{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\boxed{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return 
match.group(1).strip() + return None + \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/word_sorting_verifier.py b/verl/utils/reward_score/synlogic/word_sorting_verifier.py new file mode 100644 index 000000000..4032ac21c --- /dev/null +++ b/verl/utils/reward_score/synlogic/word_sorting_verifier.py @@ -0,0 +1,42 @@ +import re +from .data import Data +from .verifier import Verifier + +class WordSortingVerifier(Verifier): + """ + 验证器用于单词排序游戏的答案是否正确 + """ + def str2list(self, answer_str): + # 替换中文逗号为英文逗号,并删除所有空格 + answer_str = answer_str.replace(",", ",").replace(" ", "") + return [w.strip() for w in answer_str.split(",")] + + def verify(self, data: Data, test_answer: str): + try: + ground_truth = self.str2list(data.answer) + parsed_answer = self.str2list(self.extract_answer(test_answer)) + + if parsed_answer is None: + return False + return parsed_answer == ground_truth + + except Exception as e: + return False + + def extract_answer(self, answer_str): + # 先找到最后一个\Box{的位置 + last_box_index = answer_str.rfind("\\boxed{") + + if last_box_index == -1: + return None + + # 从最后一个\Box{开始截取字符串 + last_box_substring = answer_str[last_box_index:] + + # 在截取的子字符串中进行正则匹配 + box_pattern = r'\\boxed\{([^}]*)\}' + match = re.search(box_pattern, last_box_substring) + + if match: + return match.group(1).strip() + return None \ No newline at end of file diff --git a/verl/utils/reward_score/synlogic/wordscapes_verifier.py b/verl/utils/reward_score/synlogic/wordscapes_verifier.py new file mode 100644 index 000000000..1efeb5e21 --- /dev/null +++ b/verl/utils/reward_score/synlogic/wordscapes_verifier.py @@ -0,0 +1,157 @@ +""" +Wordscapes verifier module for the reasonreason framework. 
+""" + +import json +import re +from .verifier import Verifier, THOUGHT_DELIMITER_START, THOUGHT_DELIMITER_END + +debug_mode = False + +class WordscapesVerifier(Verifier): + """ + Verifier for Wordscapes game + """ + def verify(self, data, test_solution: str): + """ + Verify whether the test answer is consistent with the gold answer + + Args: + data: WordscapesData + test_solution: str containing the solution + + Returns: + float: Score between 0 and 1 + """ + try: + extracted_answer = self.extract_answer(test_solution) + if not extracted_answer: + return False + + if debug_mode: + for row in extracted_answer: + print(" ".join(cell if cell != " " else "_" for cell in row)) + + # Get grid, across_words, and down_words from data + grid = data.metadata["grid"] + across_words = data.metadata["across_words"] + down_words = data.metadata["down_words"] + + # Validate grid dimensions + if len(extracted_answer) != len(grid): + # print(f"Grid height mismatch: expected {len(grid)}, got {len(extracted_answer)}") + return False + + for i in range(len(grid)): + if len(extracted_answer[i]) != len(grid[i]): + # print(f"Grid width mismatch at row {i}: expected {len(grid[i])}, got {len(extracted_answer[i])}") + return False + + # Check if the answer respects the grid layout (X for letters, 0 for empty) + for i in range(len(grid)): + for j in range(len(grid[i])): + if grid[i][j] == "0" and extracted_answer[i][j].strip(): + # print(f"Expected empty space at position ({i},{j}), got '{extracted_answer[i][j]}'") + return False + if grid[i][j] == "X" and not extracted_answer[i][j].strip(): + # print(f"Expected letter at position ({i},{j}), got empty space") + return False + + # Verify across words + for word in across_words: + found = False + for i in range(len(extracted_answer)): + row_str = ''.join(extracted_answer[i]).replace(' ', '').lower() + if word.lower() in row_str: + found = True + break + if not found and word: + # print(f"Across word '{word}' not found in the grid") + return 0 
+ + # Verify down words + for word in down_words: + found = False + for j in range(len(extracted_answer[0])): + col = [] + for i in range(len(extracted_answer)): + if j < len(extracted_answer[i]): + col.append(extracted_answer[i][j]) + col_str = ''.join(col).replace(' ', '').lower() + if word.lower() in col_str: + found = True + break + if not found and word: # Only check if word is not empty + # print(f"Down word '{word}' not found in the grid") + return False + + # All checks passed + return True + except Exception as e: + return False + + def extract_answer(self, test_solution: str): + """ + Extract the answer from the test solution + + Args: + test_solution: str + + Returns: + list: 2D grid of the answer or None if extraction fails + """ + try: + # Remove thoughts if present + if THOUGHT_DELIMITER_START in test_solution and THOUGHT_DELIMITER_END in test_solution: + # Extract only the part after the thoughts + thought_end_pos = test_solution.rfind(THOUGHT_DELIMITER_END) + if thought_end_pos >= 0: + test_solution = test_solution[thought_end_pos + len(THOUGHT_DELIMITER_END):] + + # Clean up the response and find the grid pattern + # Look for a pattern like [[...]] or [[[...]]] + grid_pattern = re.search(r'\[\s*\[(?:\s*\[)?(.+?)(?:\]\s*)?\]\s*\]', test_solution, re.DOTALL) + if not grid_pattern: + return None + + grid_text = grid_pattern.group(1) + + # Handle various formats + rows = [] + + # Check if rows are separated by commas + split_rows = re.split(r'\],\s*\[', grid_text) + + for row_text in split_rows: + # Clean the row text and extract characters + row_text = row_text.strip().strip('[],') + + # Extract quoted characters: "X" or 'X' or just X + chars = [] + + # Look for quoted strings or standalone characters + char_matches = re.findall(r'\"([^\"]*)\"|\'([^\']*)\'|([^,\s]+)', row_text) + + for match in char_matches: + # Take the first non-empty group from each match + char = next((x for x in match if x), "") + + # Handle numeric or empty values (0, "", '') + 
if char == "0" or char == "": + char = " " + + chars.append(char) + + if chars: # Only add non-empty rows + rows.append(chars) + + # Make sure we have a valid grid + if not rows or not all(rows): + return None + + return rows + + except Exception as e: + print(f"NOTE!!! parse error!!!! (Wordscapes): {e}") + return None + \ No newline at end of file