diff --git a/metaclaw/data_formatter.py b/metaclaw/data_formatter.py
index 13eae0c..908969b 100644
--- a/metaclaw/data_formatter.py
+++ b/metaclaw/data_formatter.py
@@ -219,6 +219,12 @@ def compute_advantages(batch: list[ConversationSample]) -> list[float]:
 
     Centre-and-scale rewards within the batch (GRPO style: (r - mean) / (std + eps)).
     Returns a list of float advantages, one per sample.
+
+    When the batch contains only one sample (or all rewards are identical),
+    std is 0 and the normalised advantage collapses to 0 for every sample,
+    producing no gradient signal at all. In that case we fall back to
+    returning the raw (centred) rewards so the training step still gets a
+    meaningful update.
     """
     if not batch:
         return []
@@ -226,5 +232,6 @@ def compute_advantages(batch: list[ConversationSample]) -> list[float]:
     mean_r = sum(rewards) / len(rewards)
     variance = sum((r - mean_r) ** 2 for r in rewards) / len(rewards)
     std_r = variance ** 0.5
-    eps = 1e-8
-    return [(r - mean_r) / (std_r + eps) for r in rewards]
+    if std_r < 1e-8:
+        return [r - mean_r if len(batch) > 1 else r for r in rewards]
+    return [(r - mean_r) / std_r for r in rewards]
diff --git a/metaclaw/prm_scorer.py b/metaclaw/prm_scorer.py
index cab01c6..abb737d 100644
--- a/metaclaw/prm_scorer.py
+++ b/metaclaw/prm_scorer.py
@@ -167,6 +167,7 @@ def __init__(
         temperature: float = 0.6,
         max_new_tokens: int = 1024,
         llm_client=None,
+        timeout: float = 120.0,
     ):
         if llm_client is not None:
             self._client = llm_client
@@ -180,6 +181,7 @@
         self.prm_m = prm_m
         self.temperature = temperature
         self.max_new_tokens = max_new_tokens
+        self.timeout = timeout
 
     async def evaluate(
         self,
@@ -225,15 +227,21 @@ async def _query_once(
         self, messages: list[dict], vote_id: int
     ) -> tuple[Optional[int], str]:
         try:
-            completion = await asyncio.to_thread(
-                self._client.chat.completions.create,
-                model=self.prm_model,
-                messages=messages,
-                temperature=self.temperature,
-                max_completion_tokens=self.max_new_tokens,
+            completion = await asyncio.wait_for(
+                asyncio.to_thread(
+                    self._client.chat.completions.create,
+                    model=self.prm_model,
+                    messages=messages,
+                    temperature=self.temperature,
+                    max_completion_tokens=self.max_new_tokens,
+                ),
+                timeout=self.timeout,
             )
             content = completion.choices[0].message.content or ""
             return _parse_prm_score(content), content
+        except asyncio.TimeoutError:
+            logger.warning("[PRMScorer] query timed out after %.0fs (vote %d)", self.timeout, vote_id)
+            return None, ""
         except Exception as e:
             logger.warning("[PRMScorer] query failed (vote %d): %s", vote_id, e)
             return None, ""
diff --git a/metaclaw/skill_evolver.py b/metaclaw/skill_evolver.py
index 2c073b0..dd98736 100644
--- a/metaclaw/skill_evolver.py
+++ b/metaclaw/skill_evolver.py
@@ -82,7 +82,7 @@ def __init__(
             self._openai_client = None
         else:
             self._custom_client = None
-            api_key = os.environ.get("OPENAI_API_KEY", "aB7cD9eF2gH5iJ8kL1mN4oP6qR3sT0uV")
+            api_key = os.environ.get("OPENAI_API_KEY", "")
             base_url = os.environ.get(
                 "OPENAI_BASE_URL",
                 "https://openai-api.shenmishajing.workers.dev/v1",