From a640e0c4aec0fa512cced0c78d3b41a9bb538795 Mon Sep 17 00:00:00 2001
From: BlankCheng <913501223@qq.com>
Date: Mon, 7 Jul 2025 02:24:53 +0000
Subject: [PATCH 1/2] Implement multiprocess reward manager without async

---
 .../reward_manager/mp_reward_manager.py | 189 ++++++++++++++++++
 1 file changed, 189 insertions(+)
 create mode 100644 verl/workers/reward_manager/mp_reward_manager.py

diff --git a/verl/workers/reward_manager/mp_reward_manager.py b/verl/workers/reward_manager/mp_reward_manager.py
new file mode 100644
index 000000000..c8179a9c3
--- /dev/null
+++ b/verl/workers/reward_manager/mp_reward_manager.py
@@ -0,0 +1,189 @@
+from collections import defaultdict
+from concurrent.futures import ProcessPoolExecutor, TimeoutError
+
+import torch
+
+from verl import DataProto
+from verl.utils.reward_score import _default_compute_score
+
+# --- Helper Function for Parallel Computation (with Per-Job Timeout) ---
+
+
+def parallel_compute_score(compute_score_fn, data_sources, solutions, ground_truths, extra_infos, num_processes=64, timeout_per_job=180):
+    """
+    Computes scores in parallel with a timeout for each individual job.
+
+    Args:
+        compute_score_fn: The function to compute the score for a single item.
+        ... (other args)
+        num_processes: The number of worker processes.
+        timeout_per_job: A timeout in seconds for each individual call to
+            compute_score_fn. If a single job takes longer than this, its
+            result is discarded and recorded as None (the worker process is not killed).
+
+    Returns:
+        A list of results, with None for timed-out or failed jobs.
+    """
+    results = [None] * len(solutions)
+
+    with ProcessPoolExecutor(max_workers=num_processes) as executor:
+        # 1. Submit all jobs to the pool. executor.submit returns a Future object.
+        # A Future represents a computation that may or may not have completed.
+        futures = {executor.submit(compute_score_fn, data_source=ds, solution_str=sol, ground_truth=gt, extra_info=ei): i for i, (ds, sol, gt, ei) in enumerate(zip(data_sources, solutions, ground_truths, extra_infos))}
+
+        # 2. Retrieve results as they complete, with a per-job timeout.
+        for future in futures:
+            original_index = futures[future]
+            try:
+                # future.result() waits for the single job to finish.
+                # The timeout here applies ONLY to this one job.
+                result = future.result(timeout=timeout_per_job)
+                results[original_index] = result
+            except TimeoutError:
+                print(f"Job {original_index} timed out after {timeout_per_job} seconds. Result is None.")
+                # The result for this index remains None
+            except Exception as e:
+                print(f"Job {original_index} failed with an error: {e}. Result is None.")
+                # The result for this index remains None
+
+    return results
+
+
+# --- Synchronous Reward Manager (Updated to use Per-Job Timeout) ---
+
+
+class MultiProcessRewardManager:
+    """
+    The reward manager, now robust to individual process timeouts.
+ """ + + def __init__( + self, + tokenizer, + num_examine, + compute_score=None, + reward_fn_key="data_source", + max_resp_len=None, + overlong_buffer_cfg=None, + shuffle_data=True, + num_processes=64, + timeout_per_job=180, # Changed from timeout_per_batch + **kwargs, + ) -> None: + self.tokenizer = tokenizer + self.num_examine = num_examine + self.compute_score = compute_score or _default_compute_score + self.reward_fn_key = reward_fn_key + self.overlong_buffer_cfg = overlong_buffer_cfg + self.max_resp_len = max_resp_len + self.shuffle_data = shuffle_data + self.num_processes = num_processes + self.timeout_per_job = timeout_per_job # Timeout for each individual job + + if self.overlong_buffer_cfg is not None: + assert self.max_resp_len is not None, f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None" + + def __call__(self, data: DataProto, return_dict: bool = False): + """Computes rewards for a batch of data synchronously.""" + + if "rm_scores" in data.batch.keys(): + return {"reward_tensor": data.batch["rm_scores"]} if return_dict else data.batch["rm_scores"] + + reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32) + reward_extra_info = defaultdict(list) + + # --- Data Preparation (largely unchanged) --- + # NOTE: Shuffling is now handled slightly differently because results + # from executor.submit do not arrive in order. We map results + # back to their original positions. + data_sources, solutions, ground_truths, extra_infos = [], [], [], [] + valid_response_lengths, prompt_strs, response_strs = [], [], [] + + for i in range(len(data)): + data_item = data[i] + prompt_ids = data_item.batch["prompts"] + prompt_length = prompt_ids.shape[-1] + valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + response_ids = data_item.batch["responses"] + valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum() + valid_response_lengths.append(valid_response_length) + valid_response_ids = response_ids[:valid_response_length] + prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True) + prompt_strs.append(prompt_str) + response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True) + if response_str.endswith(self.tokenizer.eos_token): + response_str = response_str[: -len(self.tokenizer.eos_token)] + response_strs.append(response_str) + ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"] + data_source = data_item.non_tensor_batch[self.reward_fn_key] + extra_info = data_item.non_tensor_batch.get("extra_info", None) + data_sources.append(data_source) + solutions.append(response_str) + ground_truths.append(ground_truth) + extra_infos.append(extra_info) + + # --- Run Synchronous Parallel Computation --- + # The new function is called here with the per-job timeout. + results = parallel_compute_score( + self.compute_score, + data_sources, + solutions, + ground_truths, + extra_infos, + num_processes=self.num_processes, + timeout_per_job=self.timeout_per_job, # Pass the per-job timeout + ) + + # --- Process Results --- + # No need to handle shuffling here anymore, as results are mapped back to their original index. 
+ already_print_data_sources = {} + for i in range(len(data)): + result = results[i] # Results are now in the correct original order + + # Get other data using the original index + data_source = data_sources[i] + response_str = response_strs[i] + prompt_str = prompt_strs[i] + ground_truth = ground_truths[i] + valid_response_length = valid_response_lengths[i] + + score = 0.0 + if result is None: + # This handles individual failures and timeouts + result = {"score": 0.0, "acc": 0.0, "reason": "timeout_or_error"} + elif not isinstance(result, dict): + result = {"score": result, "acc": result} + + score = result.get("score", 0.0) + for key, value in result.items(): + reward_extra_info[key].append(value) + + reward = score + if self.overlong_buffer_cfg is not None and self.overlong_buffer_cfg.enable: + # ... (overlong penalty logic is unchanged) ... + overlong_buffer_len = self.overlong_buffer_cfg.len + expected_len = self.max_resp_len - overlong_buffer_len + exceed_len = valid_response_length - expected_len + overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor + overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0) + reward += overlong_reward + if self.overlong_buffer_cfg.log: + reward_extra_info["overlong_reward"].append(overlong_reward) + reward_extra_info["overlong"].append(overlong_reward < 0) + + # Assign reward to the correct position in the final tensor + reward_tensor[i, valid_response_length - 1] = reward + + if already_print_data_sources.get(data_source, 0) < self.num_examine: + already_print_data_sources[data_source] = already_print_data_sources.get(data_source, 0) + 1 + print("[prompt]", prompt_str) + print("[response]", response_str) + print("[ground_truth]", ground_truth) + for key, value in result.items(): + print(f"[{key}]", value) + + if return_dict: + return {"reward_tensor": reward_tensor, "reward_extra_info": reward_extra_info} + else: + return reward_tensor From a9c1cf7656089f85b0a9a7a5cc44fa20fb360eb6 Mon Sep 17 00:00:00 2001 From: BlankCheng <913501223@qq.com> Date: Mon, 7 Jul 2025 02:27:13 +0000 Subject: [PATCH 2/2] Remove unused timeout class --- verl/utils/reward_score/naive_dapo.py | 104 ++++++++++---------------- 1 file changed, 40 insertions(+), 64 deletions(-) diff --git a/verl/utils/reward_score/naive_dapo.py b/verl/utils/reward_score/naive_dapo.py index d26a1dd72..8c1acb779 100644 --- a/verl/utils/reward_score/naive_dapo.py +++ b/verl/utils/reward_score/naive_dapo.py @@ -13,35 +13,17 @@ # limitations under the License. 
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py +import math +import os import re -import signal -from typing import Optional import sympy from pylatexenc import latex2text from sympy.parsing import sympy_parser -import os + from .prime_math import math_normalize from .prime_math.grader import math_equal - -class timeout: - - def __init__(self, seconds=1, error_message="Timeout"): - self.seconds = seconds - self.error_message = error_message - - def handle_timeout(self, signum, frame): - raise TimeoutError(self.error_message) - - def __enter__(self): - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - - def __exit__(self, type, value, traceback): - signal.alarm(0) - - # Constants for normalization SUBSTITUTIONS = [ ("an ", ""), @@ -103,10 +85,10 @@ def __exit__(self, type, value, traceback): def normalize_final_answer(final_answer: str) -> str: """Normalize a final answer to a quantitative reasoning question. - + Args: final_answer: The answer string to normalize - + Returns: Normalized answer string """ @@ -153,7 +135,6 @@ def timeout(timeout_seconds: int = 8): import signal def decorator(func): - def handler(signum, frame): raise TimeoutError("Operation timed out!") @@ -213,7 +194,7 @@ def _is_float(num: str) -> bool: def _is_int(x: float) -> bool: try: return abs(x - int(round(x))) <= 1e-7 - except: + except Exception: return False @@ -226,7 +207,7 @@ def _str_is_int(x: str) -> bool: x = _strip_properly_formatted_commas(x) x = float(x) return abs(x - int(round(x))) <= 1e-7 - except: + except Exception: return False @@ -279,26 +260,26 @@ def _normalize(expr: str) -> str: expr = expr.replace("trillion", "*10^12") for unit in [ - "degree", - "cm", - "centimeter", - "meter", - "mile", - "second", - "minute", - "hour", - "day", - "week", - "month", - "year", - "foot", - "feet", - "inch", - "yard", - "liter", + "degree", + "cm", + "centimeter", + "meter", + "mile", + "second", + "minute", + "hour", + "day", + "week", + "month", + "year", + "foot", + "feet", + "inch", + "yard", + "liter", ]: expr = re.sub(f"{unit}(es)?(s)? 
*(\^[0-9]+)?", "", expr) - expr = re.sub(f"\^ *\\\\circ", "", expr) + expr = re.sub("\^ *\\\\circ", "", expr) if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": expr = expr[1:-1] @@ -309,7 +290,7 @@ def _normalize(expr: str) -> str: if "\\" in expr: try: expr = _parse_latex(expr) - except: + except Exception: pass # edge case with mixed numbers and negative signs @@ -359,7 +340,7 @@ def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): simplified = sympy.simplify(sympy_diff) if simplified == 0: are_equal = True - except: + except Exception: pass return are_equal @@ -371,8 +352,7 @@ def split_tuple(expr: str): expr = _strip_properly_formatted_commas(expr) if len(expr) == 0: return [] - if (len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and - all([ch not in expr[1:-1] for ch in TUPLE_CHARS])): + if len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]): elems = [elem.strip() for elem in expr[1:-1].split(",")] else: elems = [expr] @@ -411,8 +391,7 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: ground_truth_elems = split_tuple(ground_truth_normalized) given_elems = split_tuple(given_normalized) - if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or - ground_truth_normalized[-1] != given_normalized[-1]): + if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]): is_correct = False elif len(ground_truth_elems) != len(given_elems): is_correct = False @@ -432,6 +411,7 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: return is_correct, given_normalized + def _last_boxed_only_string(string): idx = string.rfind("\\boxed") if idx < 0: @@ -459,7 +439,7 @@ def _last_boxed_only_string(string): if left_brace_idx is None or right_brace_idx is None: return None - return string[left_brace_idx + 1:right_brace_idx].strip() + return string[left_brace_idx + 1 : right_brace_idx].strip() def match_answer(response): @@ -471,21 +451,18 @@ def match_answer(response): if ans_boxed: is_matched = True response = ans_boxed - + return is_matched, response -import math -def compute_score(solution_str: str, - ground_truth: str, - extra_info: dict) -> float: +def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> float: """Compute the reward score for a solution. This draws heavily from the LLM-as-judge and PRIME reward functions - + Args: solution_str: The solution string ground_truth: The ground truth answer extra_info: dict with additional info for the score computation - + Returns: Reward score (1.0 for correct, -1.0 for incorrect) """ @@ -495,13 +472,13 @@ def compute_score(solution_str: str, # Extract answer from generated output is_matched, extracted_model_output = match_answer(model_output) - + # TWK NOTE: WE REMOVED THE RESPONSE TRUNCATION FROM math_dapo.compute_score # Verify the solution, first check simple comparisons. correct, pred = grade_answer(extracted_model_output, ground_truth) - if not correct: + if not correct: try: if "\\pi" in extracted_model_output or "\\pi" in ground_truth: equivs = [] @@ -510,15 +487,14 @@ def compute_score(solution_str: str, correct = any(equivs) else: correct = math_equal(extracted_model_output, ground_truth, timeout=True) - except: + except Exception: correct = False - # reward = 1.0 if correct else -1.0 - reward = 1.0 if correct else 0. 
+ reward = 1.0 if correct else 0.0 acc = correct return { "score": reward, "acc": acc, - } \ No newline at end of file + }
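
Usage note (appended for reference, not part of either patch): a minimal sketch of how the new manager might be wired up, assuming only the constructor and __call__ signatures added in PATCH 1/2. The tokenizer checkpoint and the `batch` object are hypothetical placeholders that the verl rollout pipeline would normally supply.

    # Sketch only: "Qwen/Qwen2.5-7B-Instruct" and `batch` are placeholders, not part of this patch.
    from transformers import AutoTokenizer

    from verl.workers.reward_manager.mp_reward_manager import MultiProcessRewardManager

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    reward_manager = MultiProcessRewardManager(
        tokenizer=tokenizer,
        num_examine=1,        # print one decoded sample per data source
        num_processes=32,     # worker processes used by parallel_compute_score
        timeout_per_job=180,  # seconds to wait on each individual score job
    )

    # `batch` is a DataProto produced by the rollout worker.
    result = reward_manager(batch, return_dict=True)
    reward_tensor = result["reward_tensor"]          # token-level rewards, same shape as data.batch["responses"]
    reward_extra_info = result["reward_extra_info"]  # per-sample score/acc lists (plus overlong stats if enabled)

One caveat on the timeout semantics: future.result(timeout=...) only bounds how long the caller waits for each job, in submission order; a timed-out job keeps running in its worker process, and its entry simply falls back to a score of 0.0 in the manager.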