-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclient.py
More file actions
72 lines (63 loc) · 2.6 KB
/
client.py
File metadata and controls
72 lines (63 loc) · 2.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Typed OpenEnv client for the code review benchmark."""
from __future__ import annotations
from typing import Any
from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult
try:
from .models import (
ChangedFileSummary,
CodeReviewAction,
CodeReviewObservation,
CodeReviewState,
ReviewScorecard,
SearchHit,
)
except ImportError: # pragma: no cover
from models import (
ChangedFileSummary,
CodeReviewAction,
CodeReviewObservation,
CodeReviewState,
ReviewScorecard,
SearchHit,
)
class CodeReviewEnv(EnvClient[CodeReviewAction, CodeReviewObservation, CodeReviewState]):
    """Persistent WebSocket client for the code review environment.

    Provides the conversions between the typed models
    (``CodeReviewAction``, ``CodeReviewObservation``, ``CodeReviewState``)
    and plain ``dict`` payloads.

    NOTE(review): the three underscore-prefixed methods are presumably hooks
    invoked by the ``EnvClient`` base class -- confirm against openenv.
    """

    def _step_payload(self, action: CodeReviewAction) -> dict[str, Any]:
        """Serialize *action* to a dict, leaving out fields that are ``None``."""
        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[CodeReviewObservation]:
        """Build a typed step result from a raw *payload* dict.

        Args:
            payload: Mapping read for the keys ``"observation"`` (a mapping of
                observation fields), ``"reward"`` and ``"done"``.

        Returns:
            A ``StepResult`` wrapping a ``CodeReviewObservation``. ``reward``
            and ``done`` are set on both the observation and the result.
            Missing scalar fields fall back to the defaults written below;
            a missing or empty ``"scorecard"`` becomes ``None``.
        """
        # ``dict.get(key, default)`` only uses the default when the key is
        # absent. An explicit null for the observation (or for one of its
        # list fields) would come through as None and break the ``.get`` /
        # iteration below, so coalesce with ``or`` instead.
        obs_data = payload.get("observation") or {}
        scorecard_data = obs_data.get("scorecard")

        # Read once; the same values go to the observation and the result.
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = CodeReviewObservation(
            task_id=obs_data.get("task_id", ""),
            task_title=obs_data.get("task_title", ""),
            difficulty=obs_data.get("difficulty", ""),
            phase=obs_data.get("phase", "overview"),
            instructions=obs_data.get("instructions", ""),
            repo_name=obs_data.get("repo_name", ""),
            pr_title=obs_data.get("pr_title", ""),
            pr_description=obs_data.get("pr_description", ""),
            ci_summary=obs_data.get("ci_summary", ""),
            action_result=obs_data.get("action_result", ""),
            displayed_content=obs_data.get("displayed_content", ""),
            changed_files=[
                ChangedFileSummary.model_validate(item)
                for item in obs_data.get("changed_files") or []
            ],
            search_results=[
                SearchHit.model_validate(item)
                for item in obs_data.get("search_results") or []
            ],
            attempts_remaining=obs_data.get("attempts_remaining", 0),
            scorecard=(
                ReviewScorecard.model_validate(scorecard_data) if scorecard_data else None
            ),
            done=done,
            reward=reward,
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: dict[str, Any]) -> CodeReviewState:
        """Validate *payload* into a ``CodeReviewState`` model."""
        return CodeReviewState.model_validate(payload)