Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions metaclaw/data_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,19 @@ def compute_advantages(batch: list[ConversationSample]) -> list[float]:
Centre-and-scale rewards within the batch (GRPO style: (r - mean) / (std + eps)).

Returns a list of float advantages, one per sample.

When the batch contains only one sample, std is 0 and the normalised
advantage collapses to 0, producing no gradient signal; in that case the
raw reward is returned for the singleton. When the batch has multiple
samples whose rewards are all identical, the centred rewards are returned
instead — note these are themselves all ~0, since identical rewards carry
no relative preference signal under GRPO-style normalisation.
"""
if not batch:
return []
rewards = [s.reward for s in batch]
mean_r = sum(rewards) / len(rewards)
variance = sum((r - mean_r) ** 2 for r in rewards) / len(rewards)
std_r = variance ** 0.5
eps = 1e-8
return [(r - mean_r) / (std_r + eps) for r in rewards]
if std_r < 1e-8:
return [r - mean_r if len(batch) > 1 else r for r in rewards]
return [(r - mean_r) / std_r for r in rewards]
20 changes: 14 additions & 6 deletions metaclaw/prm_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
temperature: float = 0.6,
max_new_tokens: int = 1024,
llm_client=None,
timeout: float = 120.0,
):
if llm_client is not None:
self._client = llm_client
Expand All @@ -180,6 +181,7 @@ def __init__(
self.prm_m = prm_m
self.temperature = temperature
self.max_new_tokens = max_new_tokens
self.timeout = timeout

async def evaluate(
self,
Expand Down Expand Up @@ -225,15 +227,21 @@ async def _query_once(
self, messages: list[dict], vote_id: int
) -> tuple[Optional[int], str]:
try:
completion = await asyncio.to_thread(
self._client.chat.completions.create,
model=self.prm_model,
messages=messages,
temperature=self.temperature,
max_completion_tokens=self.max_new_tokens,
completion = await asyncio.wait_for(
asyncio.to_thread(
self._client.chat.completions.create,
model=self.prm_model,
messages=messages,
temperature=self.temperature,
max_completion_tokens=self.max_new_tokens,
),
timeout=self.timeout,
)
content = completion.choices[0].message.content or ""
return _parse_prm_score(content), content
except asyncio.TimeoutError:
logger.warning("[PRMScorer] query timed out after %.0fs (vote %d)", self.timeout, vote_id)
return None, ""
except Exception as e:
logger.warning("[PRMScorer] query failed (vote %d): %s", vote_id, e)
return None, ""
2 changes: 1 addition & 1 deletion metaclaw/skill_evolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
self._openai_client = None
else:
self._custom_client = None
api_key = os.environ.get("OPENAI_API_KEY", "<REDACTED-LEAKED-KEY>")
api_key = os.environ.get("OPENAI_API_KEY", "")
base_url = os.environ.get(
"OPENAI_BASE_URL",
"https://openai-api.shenmishajing.workers.dev/v1",
Expand Down