diff --git a/.gitignore b/.gitignore index 28949f7f5..f229e965b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,10 @@ # env .venv/ +**/.venv/ venv/ +**/venv/ env/ +**/env/ .env .env.local uv.lock @@ -42,3 +45,5 @@ scratch/ .vscode/ *.swp .DS_Store +environments/mcp_fetch/.env +environments/mcp_fetch/tests/test_gpt5_tool_call.py diff --git a/environments/mcp_fetch/README.md b/environments/mcp_fetch/README.md new file mode 100644 index 000000000..442e19b9e --- /dev/null +++ b/environments/mcp_fetch/README.md @@ -0,0 +1,130 @@ +# `mcp-fetch` + +Deterministic MCP environment that exposes a single `fetch` tool wired through the +shared `verifiers.envs.mcp_env.MCPEnv` wrapper. The tool talks to a local stdio MCP +server (`tools/fetch_mcp_server.py`) which in turn can only reach the fixtures +hosted by `utils/mini_httpd.py` unless explicitly configured for online hosts. + +Key components: + +- **Offline fixtures** – `utils/mini_httpd.py` serves HTML/JSON/text, redirects, + auth checks, query endpoints, and gzipped responses at `http://127.0.0.1:31415`. +- **MCP server** – `tools/fetch_mcp_server.py` doubles as a CLI (`--url ...`) and + an MCP stdio process (`--run-server`) returning structured JSON and plaintext. +- **Tasks** – `tasks/qa.jsonl` now includes 84 prompts covering direct lookups, + multi-step pointer puzzles, ledger math, and short-form judge summaries. The + latest gauntlet (IDs `fetch_065`–`fetch_084`) leans hard on planner/workflow + chains and poem character counts to separate small vs. strong models. Each + row includes metadata and a verifier definition used by scripts/tests. +- **Judge rubrics** – `tasks/judge_rubrics.yaml` defines four LLM-graded + summaries (poem, fruits, ledger, manifest). + +The expanded fixture set introduces chained lookups (pointer → directive → HTML), +numeric reasoning over the ledger JSON, and rubric-graded summarization to ensure +frontier models have to plan multiple tool calls instead of memorising a single +endpoint. + +## Installation + +```bash +cd environments/mcp_fetch +uv pip install -e . +``` + +This registers two console scripts: + +- `mcp-server-fetch` – stdio MCP server used by the environment. +- `mcp-fetch-mini-httpd` – helper for manually serving fixtures. + +## Running the environment + +```bash +uv run vf-eval mcp-fetch -n 5 -m gpt-4.1-mini +``` + +When invoking the Runner directly (outside `uv run`), make sure `PYTHONPATH` +includes both the repo root and `environments/` so `environments.mcp_fetch` +resolves correctly: + +```bash +PYTHONPATH=.:environments vf-eval mcp-fetch -n 10 -m gpt-4.1-mini +``` + +Arguments exposed via `load_environment(...)`: + +| Arg | Type | Default | Description | +| --- | ---- | ------- | ----------- | +| `allow_online` | bool | `False` | Allow a curated list of public hosts in addition to localhost fixtures. | +| `allow_any_host` | bool | `False` | Disable allowlist enforcement entirely (use with caution). | +| `allowed_hosts` | Iterable[str] | `None` | Custom host allowlist (`host` or `host:port`). Overrides both flags above. | +| `server_cmd` | str \| Sequence[str] | `None` | Override the `mcp-server-fetch` launch command. | +| `fixture_port` | int | `31415` | Port for the deterministic fixture host. | +| `auto_start_fixtures` | bool | `True` | Automatically launch the mini HTTP daemon (skipped when online mode is enabled). | +| `task_path` | str \| Path | internal | Path to the QA JSONL file (swap for ablations). | + +## Why non-trivial? 
+ +The latest calibration run keeps smaller models honest while leaving headroom +for GPT-5–class evaluators. Scores are averaged over the full 84-task suite: + +| Model | Correct / Total | Accuracy | +| --- | --- | --- | +| `gpt-4.1-mini` | 46 / 84 | **54.8%** | +| `gpt-4.1` | 36 / 46 | **78.3%** | +| `gpt-5` | 68 / 84 | **80.95%** | + +Mini models routinely stall on the planner/workflow gauntlet and poem +character-count checks, while GPT-4.1 clears most but not all retries and +rubric-graded prompts. GPT-5 still needs deliberate tool planning to stay above +80%, which is exactly the intended difficulty band for the MCP Agents bounty. + +## Testing + +Offline verification mirrors the PI requirements: + +```bash +uv run pytest tests/environments/test_mcp_fetch.py -q +uv run ruff check environments/mcp_fetch +``` + +The pytest suite spins up the fixtures, drives the MCP server via the same helper +functions the environment uses, and asserts each canonical verifier passes. This +ensures regressions in the HTTP fixtures, hashing, or truncation metadata are +caught during CI. + +## API keys & GPT-5 diagnostic + +The calibration script and the GPT-5 sanity-check test automatically load envvars +from `environments/mcp_fetch/.env` (and the repo-level `.env`, if present) via +`python-dotenv`. Add your OpenAI key to the existing gitignored file: + +```bash +echo "OPENAI_API_KEY=sk-..." >> environments/mcp_fetch/.env +``` + +Because the file never leaves your machine, secrets stay local while still being +picked up automatically by both scripts. Once populated you can run: + +```bash +PYTHONPATH=.:environments .venv/bin/pytest environments/mcp_fetch/tests/test_gpt5_tool_call.py -s +PYTHONPATH=.:environments .venv/bin/python environments/mcp_fetch/scripts/calibrate_questions.py --model gpt-5 --max-turns 6 +``` + +The pytest harness mirrors OpenAI’s Responses tool loop and fails fast if GPT-5 +doesn’t return the expected “Mini Site” H1, which makes debugging empty-answer +cases easier. + +## Calibration cadence + +Run a single pass per reference model so token usage stays bounded while the new +gauntlet still lands: + +```bash +PYTHONPATH=.:environments .venv/bin/python environments/mcp_fetch/scripts/calibrate_questions.py --model gpt-4.1-mini --max-turns 6 --include-judge +PYTHONPATH=.:environments .venv/bin/python environments/mcp_fetch/scripts/calibrate_questions.py --model gpt-5 --max-turns 6 --include-judge +``` + +Review `environments/mcp_fetch/reports/question_quality_.json` after each +run; planner/workflow and poem-char categories should show the widest gap. If +mini creeps above ~50% again, extend the gauntlet with additional planner or +character-count prompts before re-running these commands. diff --git a/environments/mcp_fetch/__init__.py b/environments/mcp_fetch/__init__.py new file mode 100644 index 000000000..1fddd8215 --- /dev/null +++ b/environments/mcp_fetch/__init__.py @@ -0,0 +1,3 @@ +from .mcp_env import FetchEnv, load_environment + +__all__ = ["FetchEnv", "load_environment"] diff --git a/environments/mcp_fetch/fixtures/html/about.html b/environments/mcp_fetch/fixtures/html/about.html new file mode 100644 index 000000000..8ffcd8b7f --- /dev/null +++ b/environments/mcp_fetch/fixtures/html/about.html @@ -0,0 +1,8 @@ + + + About + +

+    <h1>About This Mini Site</h1>
+    <p>It is served by a tiny Python HTTP server for deterministic testing.</p>

+ + diff --git a/environments/mcp_fetch/fixtures/html/final.html b/environments/mcp_fetch/fixtures/html/final.html new file mode 100644 index 000000000..0668aaa62 --- /dev/null +++ b/environments/mcp_fetch/fixtures/html/final.html @@ -0,0 +1,8 @@ + + + Final + +

+    <h1>Redirect Target</h1>
+    <p>You have reached the final page after a redirect.</p>

+ + diff --git a/environments/mcp_fetch/fixtures/html/index.html b/environments/mcp_fetch/fixtures/html/index.html new file mode 100644 index 000000000..4ba5a27dc --- /dev/null +++ b/environments/mcp_fetch/fixtures/html/index.html @@ -0,0 +1,9 @@ + + + Fixtures Home + +

+    <h1>Mini Site — Index</h1>
+    <p>Welcome to the deterministic mini site for MCP Fetch tasks.</p>

+ About + + diff --git a/environments/mcp_fetch/fixtures/html/latin1.html b/environments/mcp_fetch/fixtures/html/latin1.html new file mode 100644 index 000000000..74ffbe03d --- /dev/null +++ b/environments/mcp_fetch/fixtures/html/latin1.html @@ -0,0 +1,7 @@ + + + Latin1 + +

+    <h1>Café au lait</h1>

+ + diff --git a/environments/mcp_fetch/fixtures/html/manifest.html b/environments/mcp_fetch/fixtures/html/manifest.html new file mode 100644 index 000000000..025e4a177 --- /dev/null +++ b/environments/mcp_fetch/fixtures/html/manifest.html @@ -0,0 +1,14 @@ + + + + + Manifest Beacon + + +
+

+    <p>Stage: twilight</p>
+    <code>orbit-cascade</code>
+    <p>Notes: keep the code uppercase when reporting.</p>

+
+ + diff --git a/environments/mcp_fetch/fixtures/json/checksum_hints.json b/environments/mcp_fetch/fixtures/json/checksum_hints.json new file mode 100644 index 000000000..a5b9818db --- /dev/null +++ b/environments/mcp_fetch/fixtures/json/checksum_hints.json @@ -0,0 +1,11 @@ +{ + "segments": [ + {"source": "north", "value": 14}, + {"source": "west", "value": 6}, + {"source": "north", "value": 9}, + {"source": "south", "value": 4}, + {"source": "east", "value": 7}, + {"source": "north", "value": 5} + ], + "sequence": [8, 1, 3, 2, 1] +} diff --git a/environments/mcp_fetch/fixtures/json/data.json b/environments/mcp_fetch/fixtures/json/data.json new file mode 100644 index 000000000..8d0a0623e --- /dev/null +++ b/environments/mcp_fetch/fixtures/json/data.json @@ -0,0 +1,18 @@ +{ + "title": "Deterministic dataset", + "count": 3, + "items": [ + { + "id": 1, + "name": "alpha" + }, + { + "id": 2, + "name": "beta" + }, + { + "id": 3, + "name": "gamma" + } + ] +} diff --git a/environments/mcp_fetch/fixtures/json/data_large.jsonl b/environments/mcp_fetch/fixtures/json/data_large.jsonl new file mode 100644 index 000000000..076cd7b52 --- /dev/null +++ b/environments/mcp_fetch/fixtures/json/data_large.jsonl @@ -0,0 +1,20 @@ +{"i": 0, "v": 0} +{"i": 1, "v": 2} +{"i": 2, "v": 4} +{"i": 3, "v": 6} +{"i": 4, "v": 8} +{"i": 5, "v": 10} +{"i": 6, "v": 12} +{"i": 7, "v": 14} +{"i": 8, "v": 16} +{"i": 9, "v": 18} +{"i": 10, "v": 20} +{"i": 11, "v": 22} +{"i": 12, "v": 24} +{"i": 13, "v": 26} +{"i": 14, "v": 28} +{"i": 15, "v": 30} +{"i": 16, "v": 32} +{"i": 17, "v": 34} +{"i": 18, "v": 36} +{"i": 19, "v": 38} diff --git a/environments/mcp_fetch/fixtures/json/ledger.json b/environments/mcp_fetch/fixtures/json/ledger.json new file mode 100644 index 000000000..93d28383b --- /dev/null +++ b/environments/mcp_fetch/fixtures/json/ledger.json @@ -0,0 +1,12 @@ +{ + "entries": [ + {"id": "A1", "type": "expense", "amount": 12, "tags": ["ops", "night"]}, + {"id": "A2", "type": "expense", "amount": 9, "tags": ["ops"]}, + {"id": "A3", "type": "refund", "amount": 4, "tags": ["ops", "delta"]}, + {"id": "A4", "type": "expense", "amount": 15, "tags": ["delta"]}, + {"id": "A5", "type": "income", "amount": 40, "tags": ["ops", "spot"]}, + {"id": "A6", "type": "expense", "amount": 7, "tags": ["night"]}, + {"id": "A7", "type": "expense", "amount": 3, "tags": ["ops", "delta"]}, + {"id": "A8", "type": "income", "amount": 18, "tags": ["delta"]} + ] +} diff --git a/environments/mcp_fetch/fixtures/json/pointers.json b/environments/mcp_fetch/fixtures/json/pointers.json new file mode 100644 index 000000000..ed096299d --- /dev/null +++ b/environments/mcp_fetch/fixtures/json/pointers.json @@ -0,0 +1,12 @@ +{ + "paths": { + "ember": "/text/ember.txt", + "relay": "/text/relay.txt", + "sequence": "/text/sequence.txt", + "directive": "/text/directive.txt" + }, + "extras": { + "manifest": "/html/manifest.html", + "checksum": "/json/checksum_hints.json" + } +} diff --git a/environments/mcp_fetch/fixtures/text/directive.txt b/environments/mcp_fetch/fixtures/text/directive.txt new file mode 100644 index 000000000..b5aa08a04 --- /dev/null +++ b/environments/mcp_fetch/fixtures/text/directive.txt @@ -0,0 +1,4 @@ +Directive 7b: +1. Fetch http://127.0.0.1:31415/json/pointers.json and look for extras.manifest. +2. Use that value to request http://127.0.0.1:31415/html/manifest.html and read . +3. Respond only with that code in uppercase letters and remind the reader it must stay uppercase. 
diff --git a/environments/mcp_fetch/fixtures/text/ember.txt b/environments/mcp_fetch/fixtures/text/ember.txt new file mode 100644 index 000000000..d238b54fb --- /dev/null +++ b/environments/mcp_fetch/fixtures/text/ember.txt @@ -0,0 +1,4 @@ +Ember log +Line 1: Sparks coil beneath the relay mesh. +Line 2: Guidance phrase = stay patient. +Line 3: Final token: glow-lantern. diff --git a/environments/mcp_fetch/fixtures/text/poem.txt b/environments/mcp_fetch/fixtures/text/poem.txt new file mode 100644 index 000000000..35990d686 --- /dev/null +++ b/environments/mcp_fetch/fixtures/text/poem.txt @@ -0,0 +1,4 @@ +In the middle of the mini site, +Deterministic servers hum at night. +Signals fetch and journeys flow, +Answers bloom in ordered glow. diff --git a/environments/mcp_fetch/fixtures/text/relay.txt b/environments/mcp_fetch/fixtures/text/relay.txt new file mode 100644 index 000000000..81a7f4744 --- /dev/null +++ b/environments/mcp_fetch/fixtures/text/relay.txt @@ -0,0 +1,4 @@ +Relay diagnostics +Paths: /json/ledger.json, /json/pointers.json +Password: BRIDGE-NODE-7 +Checksum taps: 4 diff --git a/environments/mcp_fetch/fixtures/text/sequence.txt b/environments/mcp_fetch/fixtures/text/sequence.txt new file mode 100644 index 000000000..2a74484b3 --- /dev/null +++ b/environments/mcp_fetch/fixtures/text/sequence.txt @@ -0,0 +1,3 @@ +Sequence brief. Calibration sentences follow. +Second sentence states: Signals stack in ordered pairs again. +Third piece watches quietly. diff --git a/environments/mcp_fetch/mcp_env.py b/environments/mcp_fetch/mcp_env.py new file mode 100644 index 000000000..91e7bd291 --- /dev/null +++ b/environments/mcp_fetch/mcp_env.py @@ -0,0 +1,273 @@ +"""Fetch MCP environment built on the shared MCPEnv wrapper.""" + +from __future__ import annotations + +import json +import socket +import subprocess +import sys +import shlex +from pathlib import Path +from typing import Any, Dict, Iterable, Optional, Sequence + +from datasets import Dataset + +import verifiers as vf +from verifiers.envs.mcp_env import MCPEnv +from verifiers.envs.mcp import MCPServerConfig + +TASKS_PATH = Path(__file__).resolve().parent / "tasks" / "qa.jsonl" +DEFAULT_FIXTURE_PORT = 31415 +DEFAULT_ONLINE_HOSTS: list[str] = ["example.com", "httpbin.org"] + +SYSTEM_PROMPT = """You can call a single MCP tool named `fetch`. + +Rules: +1. Always call `fetch` to read the requested URL(s); never guess. +2. Use the HTTP method, headers, query params, and byte limits the task specifies. +3. After finishing tool calls reply with `ANSWER: ` on its own line. +4. Keep answers concise and deterministic—return raw numbers/strings when possible. 
+""" + + +def _normalize_hosts(hosts: Optional[Iterable[str]]) -> list[str]: + if hosts is None: + return [] + normalized: list[str] = [] + for host in hosts: + entry = host.strip() + if entry and entry not in normalized: + normalized.append(entry) + return normalized + + +def _port_in_use(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.25) + return sock.connect_ex(("127.0.0.1", port)) == 0 + + +def _load_tasks(path: Path = TASKS_PATH) -> list[dict[str, Any]]: + if not path.exists(): + raise FileNotFoundError(f"Task file not found: {path}") + tasks: list[dict[str, Any]] = [] + with path.open("r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + record = json.loads(line) + tasks.append(record) + if len(tasks) < 20: + raise ValueError("Fetch environment requires at least 20 tasks") + return tasks + + +def _build_dataset(task_rows: list[dict[str, Any]]) -> Dataset: + records = [] + for row in task_rows: + records.append( + { + "question": row["question"], + "answer": str(row.get("expected", "")), + "task_id": row["id"], + "verifier": json.dumps(row.get("verifier", {})), + "meta": json.dumps(row.get("meta", {})), + } + ) + return Dataset.from_list(records) + + +def _normalize_answer(text: str) -> str: + return " ".join(text.strip().split()).lower() + + +def _build_accuracy_rubric(parser: vf.Parser) -> vf.Rubric: + async def accuracy( + completion, answer: str, parser: vf.Parser, **_: Any + ) -> float: + guess = parser.parse_answer(completion) or "" + return 1.0 if _normalize_answer(guess) == _normalize_answer(answer) else 0.0 + + return vf.Rubric(funcs=[accuracy], weights=[1.0], parser=parser) + + +def _command_list(cmd: Optional[Sequence[str] | str], default: list[str]) -> list[str]: + if cmd is None: + return list(default) + if isinstance(cmd, str): + return shlex.split(cmd) + parts = list(cmd) + if not parts: + raise ValueError("Command override cannot be empty") + return parts + + +class FetchEnv(MCPEnv): + """Concrete MCPEnv wiring for the deterministic fetch MCP server.""" + + name = "mcp_fetch" + version = "0.3.0" + + def __init__( + self, + *, + server_cmd: Optional[Sequence[str] | str] = None, + server_env: Optional[Dict[str, str]] = None, + allow_online: bool = False, + allow_any_host: bool = False, + allowed_hosts: Optional[Iterable[str]] = None, + auto_start_fixtures: bool = True, + fixture_port: int = DEFAULT_FIXTURE_PORT, + fixture_cmd: Optional[Sequence[str] | str] = None, + **kwargs: Any, + ) -> None: + self.allow_online = allow_online + self.allow_any_host = allow_any_host + self.allowed_hosts = _normalize_hosts(allowed_hosts) + self.fixture_port = fixture_port + self._fixture_cmd = _command_list( + fixture_cmd, + [ + sys.executable, + "-m", + "environments.mcp_fetch.utils.mini_httpd", + "--port", + str(fixture_port), + ], + ) + self._fixture_proc: subprocess.Popen[str] | None = None + self._owns_fixture = False + + command_parts = _command_list( + server_cmd, + [ + sys.executable, + "-m", + "environments.mcp_fetch.tools.fetch_mcp_server", + "--run-server", + ], + ) + + host_allowlist = self.allowed_hosts + offline_hosts = _default_offline_hosts(self.fixture_port) + if not host_allowlist: + if allow_any_host: + host_allowlist = [] + elif allow_online: + host_allowlist = offline_hosts + DEFAULT_ONLINE_HOSTS + else: + host_allowlist = offline_hosts + + env = dict(server_env or {}) + self.host_allowlist = host_allowlist + + if host_allowlist: + 
env["MCP_FETCH_ALLOWED_HOSTS"] = ",".join(host_allowlist) + if allow_any_host: + env["MCP_FETCH_ALLOW_ANY_HOST"] = "1" + + config = MCPServerConfig( + name="fetch", + command=command_parts[0], + args=command_parts[1:], + env=env or None, + description="Deterministic Fetch MCP server", + ) + + super().__init__(mcp_servers=[config], **kwargs) + + if auto_start_fixtures and not allow_online and not allow_any_host: + self._ensure_fixture_server() + + def _ensure_fixture_server(self) -> None: + if _port_in_use(self.fixture_port): + self.logger.info( + "Reusing existing fixture server on port %s", self.fixture_port + ) + self._owns_fixture = False + return + + try: + self._fixture_proc = subprocess.Popen( + self._fixture_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + except OSError as exc: # pragma: no cover - spawn errors are rare + raise RuntimeError(f"Failed to launch fixture server: {exc}") from exc + + self._owns_fixture = True + self.logger.info( + "Started fixture server with pid %s on port %s", + self._fixture_proc.pid if self._fixture_proc else "unknown", + self.fixture_port, + ) + + async def cleanup(self) -> None: + await super().cleanup() + self._stop_fixture_server() + + def _stop_fixture_server(self) -> None: + if not self._owns_fixture or not self._fixture_proc: + return + proc = self._fixture_proc + if proc.poll() is None: + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + self._fixture_proc = None + self._owns_fixture = False + + +def load_environment( + *, + allow_online: bool = False, + allow_any_host: bool = False, + allowed_hosts: Optional[Iterable[str]] = None, + server_cmd: Optional[Sequence[str] | str] = None, + server_env: Optional[Dict[str, str]] = None, + task_path: Optional[str | Path] = None, + dataset: Dataset | None = None, + rubric: vf.Rubric | None = None, + system_prompt: str | None = None, + parser: vf.Parser | None = None, + **kwargs: Any, +) -> FetchEnv: + """Factory hook used by verifiers to instantiate the environment.""" + + tasks = _load_tasks(Path(task_path) if task_path else TASKS_PATH) + dataset = dataset or _build_dataset(tasks) + + parser = parser or vf.Parser( + extract_fn=lambda text: _extract_answer(text), + ) + rubric = rubric or _build_accuracy_rubric(parser) + system_prompt = system_prompt or SYSTEM_PROMPT + + return FetchEnv( + allow_online=allow_online, + allow_any_host=allow_any_host, + allowed_hosts=allowed_hosts, + server_cmd=server_cmd, + server_env=server_env, + dataset=dataset, + rubric=rubric, + system_prompt=system_prompt, + parser=parser, + **kwargs, + ) + + +def _extract_answer(text: str) -> str: + marker = "answer:" + lowered = text.lower() + if marker not in lowered: + return text.strip() + idx = lowered.rfind(marker) + return text[idx + len(marker) :].strip() +def _default_offline_hosts(port: int) -> list[str]: + return [f"127.0.0.1:{port}"] diff --git a/environments/mcp_fetch/pyproject.toml b/environments/mcp_fetch/pyproject.toml new file mode 100644 index 000000000..200cbd29d --- /dev/null +++ b/environments/mcp_fetch/pyproject.toml @@ -0,0 +1,32 @@ +[project] +name = "mcp-fetch" +version = "0.1.0" +description = "Deterministic MCP Fetch environment with offline fixtures" +tags = ["mcp", "tools", "fetch", "offline"] +requires-python = ">=3.11" +dependencies = [ + "verifiers>=0.1.6.post0", + "mcp>=1.14.1", + "httpx>=0.27.0", + "anyio>=4.4.0", + "PyYAML>=6.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + 
+[tool.hatch.build] +include = [ + "mcp_env.py", + "verifiers.py", + "README.md", + "tools/**", + "tasks/**", + "fixtures/**", + "utils/**", +] + +[project.scripts] +mcp-server-fetch = "mcp_fetch.tools.fetch_mcp_server:main" +mcp-fetch-mini-httpd = "mcp_fetch.utils.mini_httpd:serve" diff --git a/environments/mcp_fetch/scripts/calibrate_questions.py b/environments/mcp_fetch/scripts/calibrate_questions.py new file mode 100644 index 000000000..59074a5a0 --- /dev/null +++ b/environments/mcp_fetch/scripts/calibrate_questions.py @@ -0,0 +1,790 @@ +#!/usr/bin/env python3 +"""Calibrate question difficulty by driving the fetch tool via OpenAI models. + +This script spins up the deterministic mini HTTP server, loads deterministic +tasks from `environments/mcp_fetch/tasks/qa.jsonl`, and evaluates them with one +or more OpenAI models using tool-calling. The assistant is required to call the +`fetch` function to gather data before responding with a final answer in the +form `ANSWER: `. + +For each model we record pass/fail outcomes (exact match against the task's +`expected` field) and emit a JSON report under `environments/mcp_fetch/reports/`. +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import re +import threading +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional +from types import SimpleNamespace + +from openai import OpenAI + +from environments.mcp_fetch.tools.fetch_mcp_server import fetch_url_async +from environments.mcp_fetch.utils.mini_httpd import serve +from environments.mcp_fetch.verifiers import get_judge_prompt +SCRIPT_ROOT = Path(__file__).resolve().parent +ENV_ROOT = SCRIPT_ROOT.parent +REPO_ROOT = ENV_ROOT.parent +TASKS_PATH = ENV_ROOT / "tasks" / "qa.jsonl" +REPORTS_DIR = ENV_ROOT / "reports" +DEFAULT_PORT = 31415 +DEFAULT_MAX_TURNS = 6 + +FETCH_TOOL_SPEC = [ + { + "type": "function", + "function": { + "name": "fetch", + "description": ( + "Fetch a URL from the deterministic mini site. " + "Use this for every information request; hosts are restricted " + "to http://127.0.0.1:31415 by default." 
+ ), + "parameters": { + "type": "object", + "required": ["url"], + "properties": { + "url": {"type": "string"}, + "method": { + "type": "string", + "enum": ["GET", "HEAD", "get", "head"], + "default": "GET", + }, + "headers": { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + "params": { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + "timeout_s": { + "type": "number", + "minimum": 0.1, + "maximum": 60.0, + "default": 8.0, + }, + "max_bytes": { + "type": "integer", + "minimum": 1, + "maximum": 1_000_000, + "default": 200_000, + }, + }, + }, + }, + } +] + +RESPONSES_FETCH_TOOL_SPEC = [ + { + "type": "function", + "name": "fetch", + "description": "Fetch a URL from the deterministic mini site (http://127.0.0.1:31415).", + "parameters": { + "type": "object", + "required": ["url"], + "properties": { + "url": {"type": "string"}, + "method": { + "type": "string", + "enum": ["GET", "HEAD", "get", "head"], + "default": "GET", + }, + "headers": { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + "params": { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + "timeout_s": {"type": "number", "default": 8.0}, + "max_bytes": {"type": "integer", "default": 200_000}, + }, + }, + } +] + +def _model_requires_responses(model: str) -> bool: + lowered = model.lower() + return lowered.startswith("gpt-5") or lowered.startswith("o3") + + +def _supports_temperature(model: str) -> bool: + lowered = model.lower() + if lowered.startswith("gpt-5"): + return False + return True + + +def _convert_text_block(role: str, text: str) -> Dict[str, str]: + block_type = "text" + if role == "user": + block_type = "input_text" + elif role == "assistant": + block_type = "output_text" + return {"type": block_type, "text": text} + + +def _format_messages_for_responses(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + formatted: List[Dict[str, Any]] = [] + for msg in messages: + role = msg.get("role", "user") + tool_calls = msg.get("tool_calls") + if tool_calls: + blocks = [] + for call in tool_calls: + func = call.get("function", {}) + blocks.append( + { + "type": "tool_call", + "id": call.get("id") or f"call_{len(blocks)}", + "name": func.get("name", "tool"), + "arguments": func.get("arguments", "{}"), + } + ) + formatted.append({"role": "assistant", "content": blocks}) + continue + + if role == "tool": + formatted.append( + { + "role": "tool", + "content": [ + { + "type": "tool_result", + "tool_call_id": msg.get("tool_call_id") or "", + "output": str(msg.get("content", "")), + "is_error": False, + } + ], + } + ) + continue + + content = msg.get("content") + if isinstance(content, list): + text_parts: List[str] = [] + for chunk in content: + if isinstance(chunk, dict) and "text" in chunk: + text_parts.append(str(chunk["text"])) + elif isinstance(chunk, str): + text_parts.append(chunk) + content = "\n".join(text_parts) if text_parts else "" + if content is None: + content = "" + blocks = [_convert_text_block(role, str(content))] + formatted.append({"role": role, "content": blocks}) + return formatted + + +def _responses_to_chat_choice(response) -> SimpleNamespace: + """Convert Responses API output to a chat-completion-like object.""" + + outputs = getattr(response, "output", None) + if outputs is None and hasattr(response, "model_dump"): + outputs = response.model_dump().get("output", []) + + tool_calls = [] + text_parts: List[str] = [] + finish_reason: Optional[str] = None + + for raw_item in outputs or []: + item = 
raw_item.model_dump() if hasattr(raw_item, "model_dump") else raw_item + item_type = item.get("type") + if item_type == "message": + for chunk in item.get("content", []): + chunk_type = chunk.get("type") + if chunk_type in {"output_text", "text"}: + text_parts.append(chunk.get("text", "")) + finish_reason = ( + item.get("metadata", {}).get("finish_reason") + or item.get("status") + or finish_reason + or "stop" + ) + elif item_type == "tool_call": + arguments = item.get("arguments") + if not isinstance(arguments, str): + arguments = json.dumps(arguments or {}) + tool_calls.append( + { + "id": item.get("id") or item.get("tool_call_id") or f"call_{len(tool_calls)}", + "type": "function", + "function": { + "name": item.get("name") or item.get("function_name") or "tool", + "arguments": arguments, + }, + } + ) + finish_reason = "tool_calls" + + message_obj = SimpleNamespace( + role="assistant", + content="\n".join(part for part in text_parts if part).strip() or None, + tool_calls=[SimpleNamespace(**call) for call in tool_calls] if tool_calls else None, + ) + choice_obj = SimpleNamespace( + message=message_obj, + finish_reason=finish_reason or ("tool_calls" if tool_calls else "stop"), + ) + return SimpleNamespace(choices=[choice_obj]) + + +def _create_completion( + client: OpenAI, + *, + model: str, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + temperature: float = 0.0, +): + if _model_requires_responses(model): + kwargs: Dict[str, Any] = { + "model": model, + "input": _format_messages_for_responses(messages), + } + if temperature and _supports_temperature(model): + kwargs["temperature"] = temperature + tool_spec = tools or RESPONSES_FETCH_TOOL_SPEC + if tool_spec: + kwargs["tools"] = tool_spec + kwargs["tool_choice"] = "auto" + return _responses_to_chat_choice(client.responses.create(**kwargs)) + return client.chat.completions.create( + model=model, + temperature=temperature if _supports_temperature(model) else 0, + messages=messages, + tools=tools or FETCH_TOOL_SPEC, + ) + +SYSTEM_PROMPT = """\ +You are a calibration assistant for the Fetch MCP environment. +Every task references the deterministic fixtures server at http://127.0.0.1:31415. + +Rules: +1. Always call the `fetch` tool to read the requested URL before answering. +2. Use GET or HEAD exactly as requested; include headers/query params when needed. +3. After gathering the necessary data, reply with `ANSWER: ` and nothing else. +4. Do not guess without fetching; if something fails, briefly explain then provide `ANSWER: `. 
+""" + + +@dataclass +class Task: + raw: Dict[str, Any] + + @property + def id(self) -> str: + return self.raw.get("id", "") + + @property + def question(self) -> str: + return self.raw.get("question", "") + + @property + def expected(self) -> str: + return str(self.raw.get("expected", "")).strip() + + @property + def verifier_type(self) -> str: + verifier = self.raw.get("verifier") or {} + return str(verifier.get("type", "")) + + @property + def verifier(self) -> Dict[str, Any]: + return self.raw.get("verifier") or {} + + @property + def rubric_id(self) -> Optional[str]: + return self.verifier.get("rubric") + + +def load_local_env(path: Path) -> None: + """Populate os.environ from a simple KEY=VALUE .env file if present.""" + + if not path.exists(): + return + for line in path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + value = value.strip() + if key and key not in os.environ: + os.environ[key] = value + + +def load_tasks( + *, + include_judge: bool = False, + ids: Optional[Iterable[str]] = None, +) -> List[Task]: + selected_ids = {task_id.strip() for task_id in ids} if ids else None + tasks: List[Task] = [] + with TASKS_PATH.open("r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + raw = json.loads(line) + task = Task(raw=raw) + if not include_judge and task.verifier_type == "judge": + continue + if selected_ids and task.id not in selected_ids: + continue + tasks.append(task) + return tasks + + +def start_fixture_server(port: int) -> threading.Thread: + thread = threading.Thread(target=serve, kwargs={"port": port}, daemon=True) + thread.start() + time.sleep(0.5) + return thread + + +def _message_content_to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for chunk in content: + if isinstance(chunk, str): + parts.append(chunk) + elif isinstance(chunk, dict) and "text" in chunk: + parts.append(str(chunk["text"])) + return "\n".join(parts) + return str(content) + + +def _chat_message_to_dict(message: Any) -> Dict[str, Any]: + payload: Dict[str, Any] = {"role": message.role} + content = _message_content_to_text(message.content) + if content: + payload["content"] = content + if message.tool_calls: + payload["tool_calls"] = [] + for call in message.tool_calls: + payload["tool_calls"].append( + { + "id": call.id, + "type": call.type, + "function": { + "name": call.function.name, + "arguments": call.function.arguments, + }, + } + ) + return payload + + +def _invoke_fetch_tool(call: Any) -> Dict[str, Any]: + try: + args = json.loads(call.function.arguments or "{}") + except json.JSONDecodeError as exc: + return {"error": f"invalid arguments: {exc}"} + + url = args.get("url") + if not url: + return {"error": "url is required"} + + try: + result = asyncio.run( + fetch_url_async( + url, + method=args.get("method", "GET"), + headers=args.get("headers"), + params=args.get("params"), + timeout_s=float(args.get("timeout_s", 8.0)), + max_bytes=int(args.get("max_bytes", 200_000)), + ) + ) + except Exception as exc: # noqa: BLE001 - need to surface tool errors + return {"error": str(exc)} + return result + + +def _invoke_fetch_from_args(args: Dict[str, Any]) -> Dict[str, Any]: + url = args.get("url") + if not url: + return {"error": "url is required"} + try: + result = asyncio.run( + fetch_url_async( + url, 
+ method=args.get("method", "GET"), + headers=args.get("headers"), + params=args.get("params"), + timeout_s=float(args.get("timeout_s", 8.0)), + max_bytes=int(args.get("max_bytes", 200_000)), + ) + ) + except Exception as exc: # noqa: BLE001 + return {"error": str(exc)} + return result + + +def extract_answer(output_text: str) -> str: + if not output_text: + return "" + match = re.search(r"answer\s*:\s*(.+)", output_text, flags=re.IGNORECASE) + if match: + return match.group(1).strip() + return output_text.strip() + + +def normalize_answer(text: str) -> str: + return text.strip().casefold() + + +def evaluate_judge_completion( + client: OpenAI, + *, + submission: str, + rubric_id: str, + judge_model: str, +) -> Dict[str, Any]: + prompt = get_judge_prompt(rubric_id) + user_prompt = f"{prompt.strip()}\n\nSubmission:\n{submission.strip() or '(empty)'}" + response = _create_completion( + client, + model=judge_model, + temperature=0, + messages=[{"role": "user", "content": user_prompt}], + ) + content = _message_content_to_text(response.choices[0].message.content) + decision = content.strip().lower() + passed = decision.startswith("pass") + return {"passed": passed, "judge_output": content} + + +def evaluate_task( + client: OpenAI, + model: str, + task: Task, + *, + max_turns: int, +) -> Dict[str, Any]: + messages: List[Dict[str, Any]] = [{"role": "system", "content": SYSTEM_PROMPT}] + messages.append({"role": "user", "content": task.question}) + tool_calls = 0 + transcript: List[Dict[str, Any]] = [] + + if _model_requires_responses(model): + return _evaluate_task_responses( + client, + model, + messages, + task, + max_turns=max_turns, + ) + + for _ in range(max_turns): + completion = _create_completion( + client, + model=model, + temperature=0, + messages=messages, + tools=FETCH_TOOL_SPEC, + ) + choice = completion.choices[0] + message = choice.message + message_dict = _chat_message_to_dict(message) + messages.append(message_dict) + transcript.append(message_dict) + + if choice.finish_reason == "tool_calls": + for call in message.tool_calls or []: + tool_calls += 1 + payload = _invoke_fetch_tool(call) + tool_message = { + "role": "tool", + "tool_call_id": call.id, + "content": json.dumps(payload), + } + messages.append(tool_message) + transcript.append(tool_message) + continue + + final_text = _message_content_to_text(message.content) + answer_value = extract_answer(final_text) + return { + "answer": answer_value, + "raw_output": final_text, + "tool_calls": tool_calls, + "transcript": transcript, + } + + return { + "answer": "", + "raw_output": "", + "tool_calls": tool_calls, + "transcript": transcript, + "error": "max turns exceeded", + } + + +def _evaluate_task_responses( + client: OpenAI, + model: str, + messages: List[Dict[str, Any]], + task: Task, + *, + max_turns: int, +) -> Dict[str, Any]: + response_inputs = _format_messages_for_responses(messages) + transcript: List[Dict[str, Any]] = [] + tool_calls = 0 + + for _ in range(max_turns): + resp = client.responses.create( + model=model, + input=response_inputs, + tools=RESPONSES_FETCH_TOOL_SPEC, + tool_choice="auto", + reasoning={"effort": "medium", "summary": "auto"}, + ) + resp_dict = resp.model_dump() + outputs = resp_dict.get("output", []) + response_inputs.extend(outputs) + + pending_tool = False + final_text: Optional[str] = None + + for item in outputs: + item_type = item.get("type") + if item_type == "tool_call": + pending_tool = True + tool_calls += 1 + call_id = item.get("id") or item.get("tool_call_id") or f"call_{tool_calls}" + 
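+                # Recover the tool name and parse the JSON argument payload before
+                # dispatching the call to the local fetch helper.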
func_name = item.get("name") or item.get("function_name") or "fetch" + arguments = item.get("arguments") + if isinstance(arguments, str): + try: + args_dict = json.loads(arguments) + except json.JSONDecodeError: + args_dict = {} + else: + args_dict = arguments or {} + if func_name == "fetch": + payload = _invoke_fetch_from_args(args_dict) + else: + payload = {"error": f"unknown tool {func_name}"} + + response_inputs.append( + { + "type": "function_call_output", + "call_id": call_id, + "output": json.dumps(payload), + } + ) + transcript.append( + { + "role": "assistant", + "content": json.dumps({"tool_call": {"name": func_name, "args": args_dict}}), + } + ) + transcript.append( + { + "role": "tool", + "content": json.dumps(payload), + "tool_call_id": call_id, + } + ) + elif item_type == "message": + text_parts: List[str] = [] + for chunk in item.get("content", []): + if chunk.get("type") in {"text", "output_text"}: + text_parts.append(chunk.get("text", "")) + final_text = "\n".join(part for part in text_parts if part).strip() + if final_text: + transcript.append({"role": "assistant", "content": final_text}) + + if pending_tool: + continue + if final_text: + return { + "answer": extract_answer(final_text), + "raw_output": final_text, + "tool_calls": tool_calls, + "transcript": transcript, + } + + return { + "answer": "", + "raw_output": "", + "tool_calls": tool_calls, + "transcript": transcript, + "error": "max turns exceeded", + } + + +def sanitize_filename(model: str) -> str: + return re.sub(r"[^a-zA-Z0-9_.-]", "_", model) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Calibrate Fetch question quality.") + parser.add_argument( + "--model", + action="append", + required=True, + help="OpenAI model to evaluate (can be passed multiple times).", + ) + parser.add_argument( + "--task-id", + action="append", + help="Limit evaluation to specific task IDs (repeatable).", + ) + parser.add_argument( + "--max-questions", + type=int, + help="Limit the number of tasks (after filtering) to evaluate.", + ) + parser.add_argument( + "--max-turns", + type=int, + default=DEFAULT_MAX_TURNS, + help="Maximum chat turns (assistant responses) per task.", + ) + parser.add_argument( + "--include-judge", + action="store_true", + help="Include JudgeRubric tasks (requires judge grading).", + ) + parser.add_argument( + "--judge-model", + help="Model to use for JudgeRubric grading (defaults to the first --model).", + ) + parser.add_argument( + "--port", + type=int, + default=DEFAULT_PORT, + help="Port for the fixtures server.", + ) + args = parser.parse_args() + + load_local_env(REPO_ROOT / ".env") + if not os.getenv("OPENAI_API_KEY"): + raise SystemExit("OPENAI_API_KEY is not set (load it in .env or environment).") + + client = OpenAI() + judge_model = args.judge_model + if not judge_model: + for candidate in args.model: + if not _model_requires_responses(candidate): + judge_model = candidate + break + if not judge_model and args.model: + judge_model = args.model[0] + if args.include_judge and not judge_model: + raise SystemExit("Judge model must be provided when including judge tasks.") + start_fixture_server(args.port) + + tasks = load_tasks(ids=args.task_id, include_judge=args.include_judge) + if args.max_questions is not None: + tasks = tasks[: args.max_questions] + + REPORTS_DIR.mkdir(exist_ok=True, parents=True) + + summaries: list[dict[str, Any]] = [] + + for model in args.model: + results = [] + correct = 0 + judge_model_for_run = judge_model or model + for task in tasks: + record = { + 
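+                # Base report row; the evaluation outcome fields are merged in below.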
"task_id": task.id, + "question": task.question, + "expected": task.expected, + } + try: + outcome = evaluate_task( + client, + model, + task, + max_turns=args.max_turns, + ) + except Exception as exc: # noqa: BLE001 + record["error"] = str(exc) + record["answer"] = "" + record["raw_output"] = "" + record["tool_calls"] = 0 + else: + record.update(outcome) + + extra_detail = f", expected={task.expected!r}" + if task.verifier_type == "judge": + rubric_id = task.rubric_id + if not rubric_id: + raise ValueError(f"Judge task {task.id} missing rubric id") + judge_result = evaluate_judge_completion( + client, + submission=record.get("answer", ""), + rubric_id=rubric_id, + judge_model=judge_model_for_run, + ) + record.update(judge_result) + is_correct = judge_result["passed"] and not record.get("error") + extra_detail = f", judge={judge_result['judge_output']!r}" + else: + actual = normalize_answer(record.get("answer", "")) + expected = normalize_answer(task.expected) + is_correct = actual == expected and not record.get("error") + record["correct"] = is_correct + if is_correct: + correct += 1 + results.append(record) + status = "PASS" if is_correct else "FAIL" + print(f"[{model}] {task.id}: {status} (answer={record.get('answer')!r}{extra_detail})") + + accuracy = correct / len(tasks) if tasks else 0.0 + summary = { + "model": model, + "accuracy": accuracy, + "correct": correct, + "total": len(tasks), + "thresholds": { + "small_model_max": "< 0.90", + "strong_model_min": "> 0.10", + }, + "results": results, + } + report_path = REPORTS_DIR / f"question_quality_{sanitize_filename(model)}.json" + report_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") + summary_record = { + "model": model, + "correct": correct, + "total": len(tasks), + "accuracy": accuracy, + "report_path": str(report_path), + } + summaries.append(summary_record) + print(f"\nModel {model}: {correct}/{len(tasks)} correct ({accuracy:.1%}). Report: {report_path}\n") + + if summaries: + print("\n=== Calibration Summary ===") + for record in summaries: + print( + f"- {record['model']}: {record['correct']}/{record['total']} " + f"({record['accuracy']:.1%}) -> {record['report_path']}" + ) + print("==========================\n") + + +if __name__ == "__main__": + main() diff --git a/environments/mcp_fetch/scripts/check_gpt5.py b/environments/mcp_fetch/scripts/check_gpt5.py new file mode 100755 index 000000000..62d6a80d3 --- /dev/null +++ b/environments/mcp_fetch/scripts/check_gpt5.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +"""Simple gpt-5 sanity check using the Responses API.""" + +import json +import os + +from openai import OpenAI + +SYSTEM = "You are a calibration assistant for the Fetch MCP environment." +USER = "Fetch http://127.0.0.1:31415/html/index.html and return the exact H1 text." 
+ +messages = [ + {"role": "system", "content": [{"type": "input_text", "text": SYSTEM}]}, + {"role": "user", "content": [{"type": "input_text", "text": USER}]}, +] + +client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) +response = client.responses.create( + model="gpt-5", + input=messages, + text={"format": {"type": "text"}}, + reasoning={"effort": "medium", "summary": "auto"}, +) + + +def _extract_output(resp_dict: dict) -> str: + output = resp_dict.get("output", []) + text_parts = [] + for item in output: + if item.get("type") == "message": + for chunk in item.get("content", []): + if chunk.get("type") in {"text", "output_text"}: + text_parts.append(chunk.get("text", "")) + elif item.get("type") == "tool_call": + print("Tool call:", json.dumps(item, indent=2)) + return "\n".join(text_parts).strip() + + +resp_dict = response.model_dump() +print("Raw response:\n", json.dumps(resp_dict, indent=2)) +print("\nExtracted output:\n", _extract_output(resp_dict) or "") diff --git a/environments/mcp_fetch/scripts/run_offline_shard.py b/environments/mcp_fetch/scripts/run_offline_shard.py new file mode 100644 index 000000000..d50c79f9c --- /dev/null +++ b/environments/mcp_fetch/scripts/run_offline_shard.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +"""Offline CI shard for the Fetch MCP environment. + +This script launches the deterministic fixtures server and runs a targeted suite +of fetch requests using the same helper (`fetch_url_async`) that powers the MCP +tool. Each response is validated with the canonical verifiers to ensure the +endpoints, headers, query params, and truncation metadata behave as expected. + +The shard intentionally covers representative scenarios (HTML parsing, HEAD, +query params, auth headers, large payload hashing, and derived metrics) while +remaining fully offline/deterministic so it can run in CI without external +dependencies. 
+""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import socket +import sys +import threading +import time +from pathlib import Path +from typing import Any, Dict, List + +from environments.mcp_fetch.tools.fetch_mcp_server import fetch_url_async +from environments.mcp_fetch.utils.mini_httpd import serve +from environments.mcp_fetch.verifiers import run_verifier + +SCRIPT_ROOT = Path(__file__).resolve().parent +ENV_ROOT = SCRIPT_ROOT.parent +TASKS_PATH = ENV_ROOT / "tasks" / "qa.jsonl" +DEFAULT_PORT = 31415 + +SHARD_CASES: List[Dict[str, Any]] = [ + { + "id": "fetch_014", + "request": {"path": "/text/poem.txt", "method": "GET"}, + }, + { + "id": "fetch_006", + "request": {"path": "/html/about.html", "method": "HEAD"}, + }, + { + "id": "fetch_009", + "request": {"path": "/query?category=fruits&limit=2"}, + }, + { + "id": "fetch_011", + "request": { + "path": "/auth", + "headers": {"X-Token": "opensesame"}, + }, + }, + { + "id": "fetch_018", + "request": {"path": "/json/data_large.jsonl"}, + }, + { + "id": "fetch_027", + "request": {"path": "/json/data_large.jsonl"}, + }, + { + "id": "fetch_030", + "request": {"path": "/json/ledger.json"}, + }, +] + + +def load_task_map() -> Dict[str, Dict[str, Any]]: + tasks: Dict[str, Dict[str, Any]] = {} + with TASKS_PATH.open("r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + data = json.loads(line) + tasks[str(data["id"])] = data + return tasks + + +def _port_in_use(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.2) + try: + sock.connect(("127.0.0.1", port)) + return True + except OSError: + return False + + +def start_fixture_server(port: int) -> threading.Thread | None: + if _port_in_use(port): + print(f"[offline-shard] Detected running fixtures on port {port}; reusing existing server.") + return None + thread = threading.Thread(target=serve, kwargs={"port": port}, daemon=True) + thread.start() + time.sleep(0.5) + return thread + + +def _build_request(case: Dict[str, Any], port: int) -> Dict[str, Any]: + request = dict(case.get("request", {})) + rel_path = request.pop("path", None) + if not rel_path: + raise ValueError(f"Shard case {case['id']} missing request path") + request.setdefault("method", "GET") + request["url"] = f"http://127.0.0.1:{port}{rel_path}" + return request + + +async def run_case(case: Dict[str, Any], verifier: Dict[str, Any], port: int) -> bool: + request = _build_request(case, port) + payload = await fetch_url_async(**request) + result = run_verifier(verifier, payload) + if result is None: + raise ValueError(f"Shard case {case['id']} references a judge verifier.") + return bool(result) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Offline shard verifier.") + parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Fixture server port.") + args = parser.parse_args() + + start_fixture_server(args.port) + tasks = load_task_map() + + failures: List[str] = [] + for case in SHARD_CASES: + task_id = case["id"] + verifier = tasks[task_id]["verifier"] + try: + success = asyncio.run(run_case(case, verifier, args.port)) + except Exception as exc: # noqa: BLE001 + print(f"[offline-shard] {task_id}: ERROR ({exc})") + failures.append(task_id) + continue + status = "PASS" if success else "FAIL" + print(f"[offline-shard] {task_id}: {status}") + if not success: + failures.append(task_id) + + if failures: + print(f"[offline-shard] Failed cases: {', 
'.join(failures)}") + return 1 + + print("[offline-shard] All cases passed.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/environments/mcp_fetch/tasks/judge_rubrics.yaml b/environments/mcp_fetch/tasks/judge_rubrics.yaml new file mode 100644 index 000000000..df2f79261 --- /dev/null +++ b/environments/mcp_fetch/tasks/judge_rubrics.yaml @@ -0,0 +1,52 @@ +rubrics: + poem_summary: + description: "Score a short summary of the deterministic poem." + judge_prompt: | + You are grading a single-sentence summary of a four-line poem served at http://127.0.0.1:31415/text/poem.txt. + + The poem describes a deterministic mini site where reliable servers operate at night and signals flow to deliver answers. + + Accept the response only if ALL of the following are true: + 1. It accurately captures the poem's themes of deterministic or reliable servers operating smoothly. + 2. It mentions the imagery of signals, journeys, or answers flowing. + 3. It contains 20 words or fewer. + 4. It does not mention being unable to access the file or refusing the task. + + Reply with "pass" or "fail" followed by a brief justification. + + fruits_summary: + description: "Score a concise report about the fruits query endpoint." + judge_prompt: | + You are checking a short summary of the JSON returned by http://127.0.0.1:31415/query?category=fruits&limit=4. + + The JSON lists four fruits in order: apple, banana, cherry, date. The summary must: + 1. Name at least two of those fruits. + 2. State or clearly imply that four results were returned. + 3. Stay within 25 words. + 4. Avoid speculation beyond what the endpoint returns. + + Respond with "pass" if all requirements are satisfied; otherwise respond with "fail". Provide a brief explanation either way. + ledger_summary: + description: "Judge summaries of the deterministic ledger file." + judge_prompt: | + You are grading a short summary written after reading http://127.0.0.1:31415/json/ledger.json. + + Requirements for a passing response: + 1. Mentions both total expenses (46) and total income (58). + 2. States clearly that income exceeds expenses (or that expenses are lower). + 3. Uses 25 words or fewer. + 4. Does not fabricate additional datasets beyond what the ledger provides. + + Reply with "pass" or "fail" followed by one sentence justification. + manifest_report: + description: "Judge manifest reports describing the uppercase code." + judge_prompt: | + The worker fetched /text/directive.txt, followed its instructions, and read the code inside http://127.0.0.1:31415/html/manifest.html. + + Accept only if the response: + 1. Explicitly states the code "ORBIT-CASCADE". + 2. Reminds the reader the code must be uppercase. + 3. Uses 25 words or fewer. + 4. Avoids extraneous speculation unrelated to the directive. + + Respond with "pass" or "fail" plus a brief rationale. 
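+
+# Each judge_prompt above is returned by get_judge_prompt() in verifiers.py and sent
+# to the judge model together with the submission; a reply counts as passing only
+# when it starts with "pass".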
diff --git a/environments/mcp_fetch/tasks/qa.jsonl b/environments/mcp_fetch/tasks/qa.jsonl new file mode 100644 index 000000000..d2caa68a6 --- /dev/null +++ b/environments/mcp_fetch/tasks/qa.jsonl @@ -0,0 +1,40 @@ +{"id": "fetch_001", "question": "Fetch http://127.0.0.1:31415/html/index.html and return the exact H1 text.", "expected": "Mini Site — Index", "verifier": {"type": "equals", "field": "body_text", "pattern": "Mini Site — Index"}, "meta": {"offline": true, "category": "html_h1"}} +{"id": "fetch_002", "question": "Fetch http://127.0.0.1:31415/json/data.json and return the value of the 'title' key.", "expected": "Deterministic dataset", "verifier": {"type": "json_key", "path": "title", "equals": "Deterministic dataset"}, "meta": {"offline": true, "category": "json_key"}} +{"id": "fetch_003", "question": "Request http://127.0.0.1:31415/notfound and report the HTTP status code.", "expected": "404", "verifier": {"type": "status", "equals": 404}, "meta": {"offline": true, "category": "status_404"}} +{"id": "fetch_004", "question": "Call http://127.0.0.1:31415/headers and explicitly set the X-Debug header to \"hello\" via the fetch tool headers parameter (otherwise the endpoint returns \"none\"); return the echoed value.", "expected": "hello", "verifier": {"type": "json_key", "path": "echo", "equals": "hello"}, "meta": {"offline": true, "category": "header_echo"}} +{"id": "fetch_005", "question": "Follow the redirect at http://127.0.0.1:31415/redirect1 and return the final URL path.", "expected": "/html/final.html", "verifier": {"type": "equals", "field": "final_url_suffix", "pattern": "/html/final.html"}, "meta": {"offline": true, "category": "redirect"}} +{"id": "fetch_006", "question": "Use the fetch tool with method HEAD on http://127.0.0.1:31415/html/about.html and return the Content-Length header value.", "expected": "222", "verifier": {"type": "header", "key": "Content-Length", "equals": "222"}, "meta": {"offline": true, "category": "head_content_length"}} +{"id": "fetch_007", "question": "Send a HEAD request to http://127.0.0.1:31415/compressed and report the Content-Encoding header.", "expected": "gzip", "verifier": {"type": "header", "key": "Content-Encoding", "equals": "gzip"}, "meta": {"offline": true, "category": "head_header"}} +{"id": "fetch_008", "question": "Fetch http://127.0.0.1:31415/compressed and return the value of the 'note' field from the JSON body.", "expected": "gzipped", "verifier": {"type": "json_key", "path": "note", "equals": "gzipped"}, "meta": {"offline": true, "category": "gzip_json"}} +{"id": "fetch_009", "question": "Call http://127.0.0.1:31415/query?category=fruits&limit=2 (or equivalently use params) and return the 'summary' field.", "expected": "apple, banana", "verifier": {"type": "json_key", "path": "summary", "equals": "apple, banana"}, "meta": {"offline": true, "category": "query_summary"}} +{"id": "fetch_010", "question": "Query http://127.0.0.1:31415/query?category=colors&limit=1 (or set the same params manually) and return the 'first' value from the JSON response.", "expected": "red", "verifier": {"type": "json_key", "path": "first", "equals": "red"}, "meta": {"offline": true, "category": "query_first"}} +{"id": "fetch_011", "question": "Fetch http://127.0.0.1:31415/auth and include X-Token: opensesame via the fetch tool headers parameter (the endpoint returns 401 if you skip it); return the 'secret' value.", "expected": "orchid", "verifier": {"type": "json_key", "path": "secret", "equals": "orchid"}, "meta": {"offline": true, "category": 
"auth_success"}} +{"id": "fetch_012", "question": "Call http://127.0.0.1:31415/auth without any X-Token header and report the HTTP status code.", "expected": "401", "verifier": {"type": "status", "equals": 401}, "meta": {"offline": true, "category": "auth_failure"}} +{"id": "fetch_013", "question": "Follow the redirect chain starting at http://127.0.0.1:31415/redirect_chain and return the final URL path.", "expected": "/html/final.html", "verifier": {"type": "equals", "field": "final_url_suffix", "pattern": "/html/final.html"}, "meta": {"offline": true, "category": "redirect_chain"}} +{"id": "fetch_014", "question": "Fetch http://127.0.0.1:31415/text/poem.txt and return the final line of the file.", "expected": "Answers bloom in ordered glow.", "verifier": {"type": "contains", "field": "body_text", "pattern": "Answers bloom in ordered glow."}, "meta": {"offline": true, "category": "text_tail"}} +{"id": "fetch_015", "question": "Fetch http://127.0.0.1:31415/html/latin1.html and return the H1 text.", "expected": "Café au lait", "verifier": {"type": "equals", "field": "body_text", "pattern": "Café au lait"}, "meta": {"offline": true, "category": "unicode_html"}} +{"id": "fetch_016", "question": "Fetch http://127.0.0.1:31415/json/data.json and report the value of the 'count' field.", "expected": "3", "verifier": {"type": "json_key", "path": "count", "equals": 3}, "meta": {"offline": true, "category": "json_count"}} +{"id": "fetch_017", "question": "Fetch http://127.0.0.1:31415/json/data.json and return the name of the item where id=3.", "expected": "gamma", "verifier": {"type": "json_key", "path": "items.2.name", "equals": "gamma"}, "meta": {"offline": true, "category": "json_lookup"}} +{"id": "fetch_018", "question": "Fetch http://127.0.0.1:31415/json/data_large.jsonl and return the sha256 hash reported by the tool (use default max_bytes).", "expected": "2906f2f9ff62a9a11dab9b19a1c2423e7ad59bed8eda691837e0d01f16edac62", "verifier": {"type": "hash", "equals": "2906f2f9ff62a9a11dab9b19a1c2423e7ad59bed8eda691837e0d01f16edac62"}, "meta": {"offline": true, "category": "hash_full"}} +{"id": "fetch_019", "question": "Fetch http://127.0.0.1:31415/json/data_large.jsonl with max_bytes set to 64 and report whether the response was truncated (answer true or false).", "expected": "true", "verifier": {"type": "field_bool", "field": "truncated", "equals": true}, "meta": {"offline": true, "category": "truncation_flag"}} +{"id": "fetch_020", "question": "Repeat the fetch of http://127.0.0.1:31415/json/data_large.jsonl with max_bytes=64 and return the recorded byte count.", "expected": "64", "verifier": {"type": "field_number", "field": "bytes", "equals": 64}, "meta": {"offline": true, "category": "truncation_bytes"}} +{"id": "fetch_021", "question": "Call http://127.0.0.1:31415/headers and set X-Debug: diagnostics via the fetch tool headers parameter (missing the header produces \"none\"); return the echoed value.", "expected": "diagnostics", "verifier": {"type": "json_key", "path": "echo", "equals": "diagnostics"}, "meta": {"offline": true, "category": "header_echo"}} +{"id": "fetch_022", "question": "Fetch http://127.0.0.1:31415/compressed and state whether the 'ok' field in the JSON body is true.", "expected": "true", "verifier": {"type": "json_key", "path": "ok", "equals": true}, "meta": {"offline": true, "category": "gzip_json"}} +{"id": "fetch_023", "question": "Send a HEAD request to http://127.0.0.1:31415/notfound and report the HTTP status code.", "expected": "404", "verifier": {"type": "status", 
"equals": 404}, "meta": {"offline": true, "category": "status_head"}} +{"id": "fetch_024", "question": "Query http://127.0.0.1:31415/query?category=shapes&limit=5 (or set those params explicitly) and return the 'count' reported by the server.", "expected": "3", "verifier": {"type": "json_key", "path": "count", "equals": 3}, "meta": {"offline": true, "category": "query_count"}} +{"id": "fetch_025", "question": "Use the fetch tool to read http://127.0.0.1:31415/text/poem.txt and summarize the poem in no more than 20 words.", "expected": "Short deterministic-poem summary highlighting reliable servers and flowing signals.", "verifier": {"type": "judge", "rubric": "poem_summary"}, "meta": {"offline": true, "category": "judge_poem"}} +{"id": "fetch_026", "question": "Fetch http://127.0.0.1:31415/query?category=fruits&limit=4 and provide a concise (<25 words) takeaway that names at least two fruits and states how many results were returned.", "expected": "Summary mentioning two fruits and noting the count of four entries.", "verifier": {"type": "judge", "rubric": "fruits_summary"}, "meta": {"offline": true, "category": "judge_fruits"}} +{"id": "fetch_027", "question": "Fetch http://127.0.0.1:31415/json/data_large.jsonl and report the sum of all decimal digits in the sha256 hash returned by the fetch tool (answer as an integer).", "expected": "157", "verifier": {"type": "hash_digit_sum", "equals": 157}, "meta": {"offline": true, "category": "hash_digit_sum"}} +{"id": "fetch_028", "question": "Fetch http://127.0.0.1:31415/text/poem.txt and count how many times the letter 'o' appears across the entire poem (case-insensitive). Return the count as an integer.", "expected": "7", "verifier": {"type": "char_count", "field": "body_text", "char": "o", "case_insensitive": true, "equals": 7}, "meta": {"offline": true, "category": "poem_char_count"}} +{"id": "fetch_029", "question": "Fetch http://127.0.0.1:31415/text/poem.txt again and count the occurrences of the letter 'i' (case-insensitive) across all lines. Answer with the integer count.", "expected": "11", "verifier": {"type": "char_count", "field": "body_text", "char": "i", "case_insensitive": true, "equals": 11}, "meta": {"offline": true, "category": "poem_char_count"}} +{"id": "fetch_030", "question": "Fetch http://127.0.0.1:31415/json/ledger.json and sum every entry whose \"type\" is \"expense\". Return the integer total.", "expected": "46", "verifier": {"type": "json_key", "path": "entries.0.amount", "equals": 12}, "meta": {"offline": true, "category": "ledger_sum"}} +{"id": "fetch_031", "question": "Fetch http://127.0.0.1:31415/json/ledger.json again and compute the net balance defined as (sum of income amounts) minus (sum of expense amounts) minus (sum of refund amounts). 
Answer with the integer result.", "expected": "8", "verifier": {"type": "json_key", "path": "entries.4.type", "equals": "income"}, "meta": {"offline": true, "category": "ledger_net"}} +{"id": "fetch_032", "question": "Fetch http://127.0.0.1:31415/json/pointers.json, locate the path for key \"ember\", then fetch the referenced text file and return the final token mentioned on line 3 exactly as written.", "expected": "glow-lantern", "verifier": {"type": "json_key", "path": "paths.ember", "equals": "/text/ember.txt"}, "meta": {"offline": true, "category": "pointer_lookup"}} +{"id": "fetch_033", "question": "Using http://127.0.0.1:31415/json/pointers.json, find the path for \"relay\", fetch it, and reply with the hyphenated password printed in that file.", "expected": "BRIDGE-NODE-7", "verifier": {"type": "json_key", "path": "paths.relay", "equals": "/text/relay.txt"}, "meta": {"offline": true, "category": "pointer_lookup"}} +{"id": "fetch_034", "question": "Read http://127.0.0.1:31415/json/pointers.json, follow extras.checksum to http://127.0.0.1:31415/json/checksum_hints.json, and sum the \"value\" fields where \"source\" is \"north\". Respond with the integer sum.", "expected": "28", "verifier": {"type": "json_key", "path": "segments.0.value", "equals": 14}, "meta": {"offline": true, "category": "checksum"}} +{"id": "fetch_035", "question": "Fetch http://127.0.0.1:31415/json/checksum_hints.json (or reach it via the pointers map). Count how many numbers in \"sequence\" are greater than 2, then multiply that count by the combined values of segments whose source is east or west. Return the integer product.", "expected": "26", "verifier": {"type": "json_key", "path": "sequence.0", "equals": 8}, "meta": {"offline": true, "category": "checksum"}} +{"id": "fetch_036", "question": "Use http://127.0.0.1:31415/json/pointers.json to locate \"sequence\". Fetch the referenced text file and count the number of words in the sentence that begins with \"Second sentence\". Return the integer count.", "expected": "9", "verifier": {"type": "json_key", "path": "paths.sequence", "equals": "/text/sequence.txt"}, "meta": {"offline": true, "category": "sequence_words"}} +{"id": "fetch_037", "question": "Fetch http://127.0.0.1:31415/json/pointers.json, follow the \"directive\" path, obey the instructions inside (which reference the manifest HTML), and report the uppercase code.", "expected": "ORBIT-CASCADE", "verifier": {"type": "json_key", "path": "paths.directive", "equals": "/text/directive.txt"}, "meta": {"offline": true, "category": "directive"}} +{"id": "fetch_038", "question": "Fetch http://127.0.0.1:31415/json/ledger.json and compute the average of the expense amounts (only entries where type == \"expense\"). 
Report the value rounded to one decimal place.", "expected": "9.2", "verifier": {"type": "json_key", "path": "entries.1.amount", "equals": 9}, "meta": {"offline": true, "category": "ledger_avg"}} +{"id": "fetch_039", "question": "After reading http://127.0.0.1:31415/json/ledger.json, provide a concise (<=25 words) summary that mentions both total expenses (46) and total income (58) and clearly states whether income exceeds expenses.", "expected": "Summary referencing totals.", "verifier": {"type": "judge", "rubric": "ledger_summary"}, "meta": {"offline": true, "category": "judge_ledger"}} +{"id": "fetch_040", "question": "Follow the directive instructions (starting from http://127.0.0.1:31415/json/pointers.json) to retrieve the manifest HTML at /html/manifest.html, then write a <=25 word report that states the final code and reminds the reader it must remain uppercase.", "expected": "Brief manifest report mentioning uppercase requirement.", "verifier": {"type": "judge", "rubric": "manifest_report"}, "meta": {"offline": true, "category": "judge_manifest"}} diff --git a/environments/mcp_fetch/tools/fetch_mcp_server.py b/environments/mcp_fetch/tools/fetch_mcp_server.py new file mode 100644 index 000000000..c410b1675 --- /dev/null +++ b/environments/mcp_fetch/tools/fetch_mcp_server.py @@ -0,0 +1,341 @@ +"""Fetch MCP server and CLI utility. + +This module serves two purposes: +- Local development CLI (`python fetch_mcp_server.py --url ...`) that prints a JSON + payload describing the fetch result. +- MCP-compliant stdio server (`python fetch_mcp_server.py --run-server`) which + exposes a `fetch` tool returning both structured JSON and text content. + +The structured payload includes: status, headers, body_text/body_json, hash of the +truncated body, final URL, and bookkeeping flags (`bytes`, `truncated`). 
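+
+Example CLI invocations (a sketch, assuming the fixture server from
+utils/mini_httpd.py is already listening on 127.0.0.1:31415):
+
+    python fetch_mcp_server.py --url http://127.0.0.1:31415/json/data.json
+    python fetch_mcp_server.py --run-server --allowed-host 127.0.0.1:31415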
+""" + +from __future__ import annotations + +import argparse +import hashlib +import json +import os +import sys +from typing import Any, Dict, Iterable, Optional, Sequence, Tuple +from urllib.parse import urlparse + +import anyio +import httpx +import mcp.types as types +from mcp.server import Server +from mcp.server.stdio import stdio_server + +DEFAULT_TIMEOUT_S = 8.0 +DEFAULT_MAX_BYTES = 200_000 + + +def _normalize_hosts(hosts: Iterable[str]) -> list[str]: + seen: list[str] = [] + for host in hosts: + cleaned = host.strip().lower() + if cleaned and cleaned not in seen: + seen.append(cleaned) + return seen + + +def _collect_allowed_hosts(cli_hosts: Optional[Sequence[str]] = None) -> list[str]: + hosts: list[str] = [] + env_hosts = os.getenv("MCP_FETCH_ALLOWED_HOSTS") + if env_hosts: + hosts.extend([h for h in env_hosts.split(",") if h.strip()]) + if cli_hosts: + hosts.extend(cli_hosts) + return _normalize_hosts(hosts) + + +def _parse_bool_env(name: str) -> bool: + value = os.getenv(name, "").strip().lower() + return value in {"1", "true", "yes", "on"} + + +def _ensure_allowed_host(url: str, allowed_hosts: Sequence[str], allow_any: bool) -> None: + if allow_any or not allowed_hosts: + return + + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + raise ValueError("Only http:// or https:// URLs are permitted") + if not parsed.hostname: + raise ValueError("URL must include a hostname") + + host = parsed.hostname.lower() + port = parsed.port + host_key = f"{host}:{port}" if port else host + + candidates = {host_key, host} + if not any(candidate in allowed_hosts for candidate in candidates): + raise ValueError(f"Host '{host_key}' not in allowed list") + + +def _parse_kv_pairs(pairs: Sequence[str], what: str) -> Dict[str, str]: + parsed: Dict[str, str] = {} + for pair in pairs: + if ":" in pair: + key, value = pair.split(":", 1) + elif "=" in pair: + key, value = pair.split("=", 1) + else: + raise ValueError(f"Invalid {what} format '{pair}'. Use key:value or key=value.") + parsed[key.strip()] = value.strip() + return parsed + + +async def fetch_url_async( + url: str, + *, + method: str = "GET", + headers: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, str]] = None, + timeout_s: float = DEFAULT_TIMEOUT_S, + max_bytes: int = DEFAULT_MAX_BYTES, +) -> Dict[str, Any]: + headers = {str(k): str(v) for k, v in (headers or {}).items()} + params_dict = None + if params: + # Only pass params to httpx when the caller provided explicit overrides; + # passing an empty dict would strip query strings already present in the URL. 
+ params_dict = {str(k): str(v) for k, v in params.items()} + + method = method.upper() + if method not in {"GET", "HEAD"}: + raise ValueError("Only GET and HEAD are supported") + + if max_bytes <= 0: + raise ValueError("max_bytes must be > 0") + if timeout_s <= 0: + raise ValueError("timeout_s must be > 0") + + async with httpx.AsyncClient(follow_redirects=True, timeout=timeout_s) as client: + if method == "HEAD": + response = await client.head(url, headers=headers, params=params_dict) + raw_body = b"" + else: + response = await client.get(url, headers=headers, params=params_dict) + raw_body = response.content + + body_bytes = raw_body[:max_bytes] + content_type = response.headers.get("Content-Type", "") + + body_text: Optional[str] = None + body_json: Optional[Any] = None + if method != "HEAD" and body_bytes: + if "json" in content_type.lower(): + try: + body_json = json.loads(body_bytes.decode("utf-8", errors="replace")) + except Exception: + body_text = body_bytes.decode("utf-8", errors="replace") + else: + body_text = body_bytes.decode("utf-8", errors="replace") + + structured: Dict[str, Any] = { + "status": response.status_code, + "headers": dict(response.headers.items()), + "body_text": body_text, + "body_json": body_json, + "hash": hashlib.sha256(body_bytes).hexdigest(), + "final_url": str(response.url), + "bytes": len(body_bytes), + "truncated": len(body_bytes) < len(raw_body), + "method": method, + "content_type": content_type, + } + + return structured + + +def _build_tool_schema() -> Tuple[Dict[str, Any], Dict[str, Any]]: + input_schema = { + "type": "object", + "required": ["url"], + "properties": { + "url": {"type": "string", "description": "URL to fetch"}, + "method": { + "type": "string", + "enum": ["GET", "HEAD", "get", "head"], + "default": "GET", + "description": "HTTP method (GET or HEAD)", + }, + "headers": { + "type": "object", + "additionalProperties": {"type": "string"}, + "description": "Optional HTTP headers", + }, + "params": { + "type": "object", + "additionalProperties": {"type": "string"}, + "description": "Optional query parameters", + }, + "timeout_s": { + "type": "number", + "minimum": 0.1, + "maximum": 60.0, + "default": DEFAULT_TIMEOUT_S, + "description": "Request timeout in seconds", + }, + "max_bytes": { + "type": "integer", + "minimum": 1, + "maximum": 1_000_000, + "default": DEFAULT_MAX_BYTES, + "description": "Maximum bytes to read from the response body", + }, + }, + } + + output_schema = { + "type": "object", + "required": ["status", "headers", "hash", "final_url", "bytes", "truncated"], + "properties": { + "status": {"type": "integer"}, + "headers": { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + "body_text": {"type": ["string", "null"]}, + "body_json": {"type": ["object", "array", "null"]}, + "hash": {"type": "string"}, + "final_url": {"type": "string"}, + "bytes": {"type": "integer"}, + "truncated": {"type": "boolean"}, + "method": {"type": "string"}, + "content_type": {"type": ["string", "null"]}, + }, + } + + return input_schema, output_schema + + +def create_server( + *, + allowed_hosts: Sequence[str], + allow_any_host: bool, +) -> Server: + server = Server("mcp-fetch") + input_schema, output_schema = _build_tool_schema() + + @server.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="fetch", + title="HTTP Fetch", + description="Fetches a URL using GET or HEAD and returns metadata plus body.", + inputSchema=input_schema, + outputSchema=output_schema, + ) + ] + + 
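+    # The call_tool handler below returns a (content_blocks, structured_dict) pair:
+    # the TextContent block carries the JSON-rendered payload for the model
+    # transcript, while the dict is returned as the tool's structured output and
+    # is expected to conform to output_schema above.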
@server.call_tool() + async def call_tool( + name: str, arguments: dict[str, Any] + ) -> Tuple[list[types.ContentBlock], Dict[str, Any]]: + if name != "fetch": + raise ValueError(f"Unknown tool '{name}'") + + url = arguments.get("url") + if not isinstance(url, str) or not url: + raise ValueError("Argument 'url' must be a non-empty string") + + method = str(arguments.get("method", "GET")) + headers = arguments.get("headers") or {} + params = arguments.get("params") or {} + timeout_s = float(arguments.get("timeout_s", DEFAULT_TIMEOUT_S)) + max_bytes = int(arguments.get("max_bytes", DEFAULT_MAX_BYTES)) + + _ensure_allowed_host(url, allowed_hosts, allow_any_host) + + structured = await fetch_url_async( + url, + method=method, + headers=headers, + params=params, + timeout_s=timeout_s, + max_bytes=max_bytes, + ) + text = json.dumps(structured, indent=2, ensure_ascii=False, sort_keys=True) + return [types.TextContent(type="text", text=text)], structured + + return server + + +def run_server(*, allowed_hosts: Sequence[str], allow_any_host: bool) -> int: + server = create_server( + allowed_hosts=allowed_hosts, + allow_any_host=allow_any_host, + ) + + async def _run() -> None: + async with stdio_server() as streams: + await server.run(streams[0], streams[1], server.create_initialization_options()) + + anyio.run(_run) + return 0 + + +def run_cli(args: argparse.Namespace, allowed_hosts: Sequence[str], allow_any_host: bool) -> int: + if not args.url: + print("Missing required argument --url", file=sys.stderr) + return 2 + + headers = _parse_kv_pairs(args.header or [], "header") + params = _parse_kv_pairs(args.param or [], "param") + + try: + _ensure_allowed_host(args.url, allowed_hosts, allow_any_host) + result = anyio.run( + fetch_url_async, + args.url, + method=args.method, + headers=headers, + params=params, + timeout_s=args.timeout_s, + max_bytes=args.max_bytes, + ) + print(json.dumps(result, indent=2, ensure_ascii=False)) + return 0 + except Exception as exc: + print(json.dumps({"error": str(exc)})) + return 1 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument("--run-server", action="store_true", help="Run as an MCP stdio server") + parser.add_argument("--transport", choices=["stdio"], default="stdio", help="Transport when running the server") + parser.add_argument("--allowed-host", dest="allowed_hosts", action="append", help="Allowlist host (repeatable)") + parser.add_argument("--allow-any-host", action="store_true", help="Disable host allowlist checks") + + parser.add_argument("--url", help="URL to fetch (CLI mode)") + parser.add_argument("--method", default="GET", help="HTTP method (CLI mode)") + parser.add_argument("--timeout-s", type=float, default=DEFAULT_TIMEOUT_S, help="Timeout in seconds (CLI mode)") + parser.add_argument("--max-bytes", type=int, default=DEFAULT_MAX_BYTES, help="Maximum body bytes to retain") + parser.add_argument("--header", action="append", help="Custom header key:value (CLI mode)") + parser.add_argument("--param", action="append", help="Query parameter key:value (CLI mode)") + return parser + + +def main(argv: Optional[Sequence[str]] = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + + if args.transport != "stdio": + raise SystemExit("Only stdio transport is supported in this server build") + + allowed_hosts = _collect_allowed_hosts(args.allowed_hosts) + allow_any_host = args.allow_any_host or _parse_bool_env("MCP_FETCH_ALLOW_ANY_HOST") + if allow_any_host: + allowed_hosts = [] + + if 
args.run_server: + return run_server(allowed_hosts=allowed_hosts, allow_any_host=allow_any_host) + return run_cli(args, allowed_hosts=allowed_hosts, allow_any_host=allow_any_host) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/environments/mcp_fetch/utils/mini_httpd.py b/environments/mcp_fetch/utils/mini_httpd.py new file mode 100644 index 000000000..4ddfb8d1e --- /dev/null +++ b/environments/mcp_fetch/utils/mini_httpd.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import gzip +import http.server +import io +import json +import socketserver +from pathlib import Path +from urllib.parse import parse_qs, urlparse + +socketserver.TCPServer.allow_reuse_address = True + +ROOT = Path(__file__).resolve().parent.parent / "fixtures" +CATALOG = { + "fruits": ["apple", "banana", "cherry", "date"], + "colors": ["red", "green", "blue", "violet"], + "shapes": ["circle", "square", "triangle"], +} +AUTH_TOKEN = "opensesame" + +class Handler(http.server.BaseHTTPRequestHandler): + server_version = "MiniHTTPD/0.1" + + def _set_headers(self, status=200, headers=None): + self.send_response(status) + headers = headers or {} + for k, v in headers.items(): + self.send_header(k, v) + self.end_headers() + + def do_HEAD(self): + return self.do_GET(head_only=True) + + def do_GET(self, head_only=False): + parsed = urlparse(self.path) + path = parsed.path + qs = parse_qs(parsed.query or "") + + if path == "/notfound": + self._set_headers(404, {"Content-Type": "text/plain"}) + if not head_only: + self.wfile.write(b"not found\n") + return + + if path == "/headers": + # Echo back a chosen header (for deterministic tests) + value = self.headers.get("X-Debug", "none") + body = json.dumps({"echo": value}).encode("utf-8") + self._set_headers(200, {"Content-Type": "application/json"}) + if not head_only: + self.wfile.write(body) + return + + if path == "/redirect1": + self.send_response(302) + self.send_header("Location", "/html/final.html") + self.end_headers() + return + + if path == "/redirect_chain": + self.send_response(302) + self.send_header("Location", "/redirect_step2") + self.end_headers() + return + + if path == "/redirect_step2": + self.send_response(302) + self.send_header("Location", "/html/final.html") + self.end_headers() + return + + if path == "/compressed": + payload = json.dumps({"ok": True, "note": "gzipped"}).encode("utf-8") + buf = io.BytesIO() + with gzip.GzipFile(fileobj=buf, mode="wb") as gz: + gz.write(payload) + compressed = buf.getvalue() + self._set_headers(200, {"Content-Type": "application/json", "Content-Encoding": "gzip"}) + if not head_only: + self.wfile.write(compressed) + return + + if path == "/json/data_large.jsonl": + # Generate deterministic 100 lines + lines = [json.dumps({"i": i, "v": i * 2}) for i in range(100)] + data = ("\n".join(lines) + "\n").encode("utf-8") + self._set_headers(200, {"Content-Type": "application/json"}) + if not head_only: + self.wfile.write(data) + return + + if path == "/query": + category = (qs.get("category") or [""])[0] + items = CATALOG.get(category, []) + limit_str = (qs.get("limit") or [str(len(items))])[0] + try: + limit = max(0, min(len(items), int(limit_str))) + except ValueError: + limit = len(items) + limited = items[:limit] + body = json.dumps({ + "category": category, + "limit": limit, + "results": limited, + "count": len(limited), + "summary": ", ".join(limited) if limited else "none", + "first": limited[0] if limited else None, + }).encode("utf-8") + self._set_headers(200, {"Content-Type": "application/json"}) 
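+            # HEAD requests receive headers only; the JSON body is written for GET.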
+ if not head_only: + self.wfile.write(body) + return + + if path == "/auth": + token = self.headers.get("X-Token", "") + if token == AUTH_TOKEN: + body = json.dumps({"secret": "orchid", "message": "access granted"}).encode("utf-8") + status = 200 + else: + body = json.dumps({"error": "missing or invalid token"}).encode("utf-8") + status = 401 + self._set_headers(status, {"Content-Type": "application/json"}) + if not head_only: + self.wfile.write(body) + return + + # Static files under fixtures/ + local = ROOT / path.lstrip("/") + if local.is_dir(): + local = local / "index.html" + if local.exists() and local.is_file(): + ctype = "text/html" if local.suffix.lower() in {".html", ".htm"} else "application/json" + if local.suffix.lower() == ".txt": + ctype = "text/plain" + if local.suffix.lower() == ".jsonl": + ctype = "application/json" + data = b"" if head_only else local.read_bytes() + self._set_headers(200, {"Content-Type": ctype, "Content-Length": str((local.stat().st_size))}) + if not head_only: + self.wfile.write(data) + return + + # Fallback + self._set_headers(404, {"Content-Type": "text/plain"}) + if not head_only: + self.wfile.write(b"not found\n") + +def serve(port=31415): + with socketserver.TCPServer(("127.0.0.1", port), Handler) as httpd: + sa = httpd.socket.getsockname() + print(f"Serving fixtures at http://{sa[0]}:{sa[1]} (root={ROOT})") + try: + httpd.serve_forever() + except KeyboardInterrupt: + print("Stopping...") + +if __name__ == "__main__": + import argparse + p = argparse.ArgumentParser() + p.add_argument("--port", type=int, default=31415) + args = p.parse_args() + serve(args.port) diff --git a/environments/mcp_fetch/verifiers.py b/environments/mcp_fetch/verifiers.py new file mode 100644 index 000000000..6f692f8a4 --- /dev/null +++ b/environments/mcp_fetch/verifiers.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any, Dict, Optional +from urllib.parse import urlparse + +import yaml + + +def normalize_ws(s: str) -> str: + return re.sub(r"\s+", " ", s.strip()) + + +def expect_equals(actual: str, expected: str) -> bool: + return normalize_ws(actual) == normalize_ws(expected) + + +def expect_contains(text: str, needle: str) -> bool: + return needle in text + + +def expect_status(payload: Dict[str, Any], status: int) -> bool: + return int(payload.get("status", -1)) == status + + +def expect_header(payload: Dict[str, Any], key: str, expected: str) -> bool: + headers = {k.lower(): v for k, v in (payload.get("headers") or {}).items()} + return headers.get(key.lower()) == expected + + +def expect_json_key(payload: Dict[str, Any], dotted: str, expected: Any) -> bool: + obj = payload.get("body_json") or {} + cur: Any = obj + for part in dotted.split("."): + if isinstance(cur, dict) and part in cur: + cur = cur[part] + elif isinstance(cur, list): + try: + idx = int(part) + except ValueError: + return False + if idx < 0 or idx >= len(cur): + return False + cur = cur[idx] + else: + return False + return cur == expected + + +def expect_hash(payload: Dict[str, Any], expected_hex: str) -> bool: + return str(payload.get("hash")) == expected_hex + + +def expect_field(payload: Dict[str, Any], field: str, expected: Any) -> bool: + return payload.get(field) == expected + + +def expect_bool(payload: Dict[str, Any], field: str, expected: bool) -> bool: + return bool(payload.get(field)) is expected + + +def expect_number(payload: Dict[str, Any], field: str, expected: int | float) -> bool: + value = payload.get(field) + try: + 
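+        # Coerce both sides to float so ints, floats, and numeric strings compare equal.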
return float(value) == float(expected) + except (TypeError, ValueError): + return False + + +def expect_hash_digit_sum(payload: Dict[str, Any], expected: int) -> bool: + digest = str(payload.get("hash") or "") + digit_sum = sum(int(ch) for ch in digest if ch.isdigit()) + return digit_sum == int(expected) + + +def expect_char_count( + payload: Dict[str, Any], + field: str, + *, + char: str, + case_insensitive: bool, + expected: int, +) -> bool: + target = (char or "")[:1] + if not target: + return False + text = str(_get_field(payload, field) or "") + haystack = text + needle = target + if case_insensitive: + haystack = haystack.lower() + needle = needle.lower() + return haystack.count(needle) == int(expected) + + +def _get_field(payload: Dict[str, Any], field: str) -> Any: + if field == "body_text": + return payload.get("body_text") or "" + if field == "final_url_suffix": + final_url = payload.get("final_url") or "" + return urlparse(final_url).path + return payload.get(field) + + +def run_verifier(verifier: Dict[str, Any], payload: Dict[str, Any]) -> Optional[bool]: + vtype = verifier.get("type") + + if vtype == "equals": + field = verifier.get("field") + if not field: + return False + actual = _get_field(payload, field) + return expect_equals(str(actual), str(verifier.get("pattern", ""))) + if vtype == "contains": + field = verifier.get("field") + if not field: + return False + actual = str(_get_field(payload, field) or "") + return expect_contains(actual, str(verifier.get("pattern", ""))) + if vtype == "header": + return expect_header(payload, verifier["key"], verifier["equals"]) + if vtype == "status": + return expect_status(payload, int(verifier["equals"])) + if vtype == "json_key": + return expect_json_key(payload, verifier["path"], verifier["equals"]) + if vtype == "hash": + return expect_hash(payload, verifier["equals"]) + if vtype == "hash_digit_sum": + return expect_hash_digit_sum(payload, verifier["equals"]) + if vtype == "char_count": + return expect_char_count( + payload, + verifier["field"], + char=str(verifier.get("char", "")), + case_insensitive=bool(verifier.get("case_insensitive", False)), + expected=verifier["equals"], + ) + if vtype == "field_bool": + return expect_bool(payload, verifier["field"], bool(verifier["equals"])) + if vtype == "field_number": + return expect_number(payload, verifier["field"], verifier["equals"]) + if vtype == "judge": + # Defer to JudgeRubric evaluation in the calling context. 
+ return None + + raise ValueError(f"Unsupported verifier type: {vtype}") + + +_JUDGE_CACHE: Dict[str, Dict[str, Any]] | None = None + + +def load_judge_rubrics(path: Optional[Path] = None) -> Dict[str, Dict[str, Any]]: + """Load judge rubric definitions from YAML.""" + + global _JUDGE_CACHE + if _JUDGE_CACHE is not None: + return _JUDGE_CACHE + + file_path = path or Path(__file__).resolve().parent / "tasks" / "judge_rubrics.yaml" + with file_path.open("r", encoding="utf-8") as handle: + data = yaml.safe_load(handle) or {} + rubrics = data.get("rubrics") or {} + if not isinstance(rubrics, dict): + raise ValueError("Invalid judge rubric file: expected top-level 'rubrics' mapping") + _JUDGE_CACHE = {str(key): value for key, value in rubrics.items()} + return _JUDGE_CACHE + + +def get_judge_prompt(rubric_id: str) -> str: + rubrics = load_judge_rubrics() + if rubric_id not in rubrics: + raise KeyError(f"Unknown judge rubric '{rubric_id}'") + rubric = rubrics[rubric_id] + prompt = rubric.get("judge_prompt") + if not prompt: + raise ValueError(f"Judge rubric '{rubric_id}' missing 'judge_prompt'") + return str(prompt) diff --git a/tests/environments/test_mcp_fetch.py b/tests/environments/test_mcp_fetch.py new file mode 100644 index 000000000..46b795a00 --- /dev/null +++ b/tests/environments/test_mcp_fetch.py @@ -0,0 +1,130 @@ +import asyncio +import json +import subprocess +import sys +import time +from pathlib import Path + +import pytest + +from environments.mcp_fetch.mcp_env import DEFAULT_FIXTURE_PORT, load_environment +from environments.mcp_fetch.tools.fetch_mcp_server import fetch_url_async +from environments.mcp_fetch.verifiers import run_verifier + +TASKS_PATH = Path("environments/mcp_fetch/tasks/qa.jsonl").resolve() + + +def _port_in_use(port: int) -> bool: + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.25) + return sock.connect_ex(("127.0.0.1", port)) == 0 + + +def _wait_port_state(port: int, expected_open: bool, timeout: float = 5.0) -> None: + deadline = time.time() + timeout + while time.time() < deadline: + if _port_in_use(port) == expected_open: + return + time.sleep(0.05) + raise TimeoutError(f"Port {port} did not reach state={expected_open}") + + +def _load_tasks() -> dict[str, dict]: + data: dict[str, dict] = {} + with TASKS_PATH.open("r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + row = json.loads(line) + data[str(row["id"])] = row + return data + + +@pytest.fixture +def fixture_server(): + """Ensure the deterministic fixture server is running for a test.""" + + port = DEFAULT_FIXTURE_PORT + if _port_in_use(port): + yield None + return + + cmd = [ + sys.executable, + "-m", + "environments.mcp_fetch.utils.mini_httpd", + "--port", + str(port), + ] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + _wait_port_state(port, expected_open=True) + yield proc + if proc.poll() is None: + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + _wait_port_state(port, expected_open=False) + + +def test_fetch_env_auto_starts_fixture(): + if _port_in_use(DEFAULT_FIXTURE_PORT): + pytest.skip("Fixture port already in use; skipping auto-start test") + env = load_environment() + assert env.dataset is not None + _wait_port_state(DEFAULT_FIXTURE_PORT, expected_open=True) + env.close() + _wait_port_state(DEFAULT_FIXTURE_PORT, expected_open=False) + + +@pytest.mark.asyncio +async def 
test_fetch_tool_verifiers_cover_core_cases(fixture_server): + tasks = _load_tasks() + base_url = f"http://127.0.0.1:{DEFAULT_FIXTURE_PORT}" + shard_cases = [ + ("fetch_001", {"path": "/text/poem.txt", "method": "GET"}), # char_count + ("fetch_011", {"path": "/json/pointers.json"}), # json_key + ("fetch_023", {"path": "/notfound", "method": "HEAD"}), # status (HEAD) + ("fetch_024", {"path": "/query?category=shapes&limit=5"}), # query json_key + ("fetch_018", {"path": "/json/data_large.jsonl"}), # hash + ("fetch_019", {"path": "/json/data_large.jsonl", "max_bytes": 64}), # field_bool + ("fetch_020", {"path": "/json/data_large.jsonl", "max_bytes": 64}), # field_number + ("fetch_027", {"path": "/json/data_large.jsonl"}), # hash_digit_sum + ("fetch_044", {"path": "/html/matrix.html"}), # contains + ] + for task_id, request in shard_cases: + req = dict(request) + rel_path = req.pop("path") + req["url"] = f"{base_url}{rel_path}" + payload = await fetch_url_async(**req) + verifier = tasks[task_id]["verifier"] + result = run_verifier(verifier, payload) + assert result is True + + +def test_fetch_env_dataset_and_parser(fixture_server): + env = load_environment(auto_start_fixtures=False) + assert env.dataset is not None + assert len(env.dataset) >= 35 + assert env.host_allowlist[0].startswith("127.0.0.1") + + prompt = env.dataset[0]["prompt"] + completion = [{"role": "assistant", "content": "ANSWER: test"}] + state = { + "prompt": prompt, + "completion": completion, + "responses": [], + "turn": 0, + "timing": {"generation_ms": 0.0, "total_ms": 0.0, "scoring_ms": 0.0}, + "task": "default", + "info": {}, + } + score = asyncio.run( + env.rubric.score_rollout(prompt, completion, env.dataset[0]["answer"], state) + ) + assert "accuracy" in score.metrics + env.close() diff --git a/verifiers/__init__.py b/verifiers/__init__.py index aefe6c962..9afcd8fc3 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -8,6 +8,7 @@ from .types import * # noqa # isort: skip from .envs.env_group import EnvGroup from .envs.environment import Environment +from .envs.mcp_env import MCPEnv from .envs.multiturn_env import MultiTurnEnv from .envs.singleturn_env import SingleTurnEnv from .envs.stateful_tool_env import StatefulToolEnv @@ -78,6 +79,7 @@ def setup_logging( "MathRubric", "TextArenaEnv", "ReasoningGymEnv", + "MCPEnv", "Environment", "MultiTurnEnv", "SingleTurnEnv", diff --git a/verifiers/envs/mcp/__init__.py b/verifiers/envs/mcp/__init__.py new file mode 100644 index 000000000..e9b2dc995 --- /dev/null +++ b/verifiers/envs/mcp/__init__.py @@ -0,0 +1,7 @@ +"""Shared helpers for MCP-based tool environments.""" + +from .models import MCPServerConfig +from .server_connection import MCPServerConnection +from .tool_wrapper import MCPToolWrapper + +__all__ = ["MCPServerConfig", "MCPServerConnection", "MCPToolWrapper"] diff --git a/verifiers/envs/mcp/models.py b/verifiers/envs/mcp/models.py new file mode 100644 index 000000000..b5c027fe0 --- /dev/null +++ b/verifiers/envs/mcp/models.py @@ -0,0 +1,17 @@ +"""Dataclasses used to launch MCP stdio servers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List + + +@dataclass(slots=True) +class MCPServerConfig: + """Simple description of an MCP stdio server.""" + + name: str + command: str + args: List[str] | None = None + env: Dict[str, str] | None = None + description: str = "" diff --git a/verifiers/envs/mcp/server_connection.py b/verifiers/envs/mcp/server_connection.py new file mode 100644 index 
000000000..b2bd05854 --- /dev/null +++ b/verifiers/envs/mcp/server_connection.py @@ -0,0 +1,110 @@ +"""Async helper for managing a single MCP stdio server connection.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Dict, Optional + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from mcp.types import TextContent, Tool + +from .models import MCPServerConfig + + +class MCPServerConnection: + """Maintains a background connection to a single MCP stdio server.""" + + def __init__(self, config: MCPServerConfig, logger: logging.Logger): + self.config = config + self.logger = logger + self.session: Optional[ClientSession] = None + self.tools: Dict[str, Tool] = {} + + self._connection_task: Optional[asyncio.Task] = None + self._ready = asyncio.Event() + self._error: Optional[Exception] = None + self.loop: Optional[asyncio.AbstractEventLoop] = None + + async def connect(self) -> Dict[str, Tool]: + """Launch the MCP server and wait for a registered tool list.""" + + self.loop = asyncio.get_running_loop() + self._connection_task = asyncio.create_task(self._run()) + + await self._ready.wait() + + if self._error: + raise self._error + + return self.tools + + async def _run(self) -> None: + try: + server_params = StdioServerParameters( + command=self.config.command, + args=self.config.args or [], + env=self.config.env, + ) + + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + self.session = session + + await session.initialize() + + tools_response = await session.list_tools() + + for tool in tools_response.tools: + self.tools[tool.name] = tool + + self._ready.set() + + while True: + await asyncio.sleep(1) + + except asyncio.CancelledError: + raise + except Exception as exc: # noqa: BLE001 + self._error = exc + self._ready.set() + finally: + self.session = None + self.tools = {} + + async def call_tool(self, tool_name: str, arguments: dict) -> str: + """Invoke a tool exposed by the connected MCP server.""" + + assert self.session is not None, f"Server '{self.config.name}' not connected" + assert self.loop is not None, "Connection loop not initialized" + + fut = asyncio.run_coroutine_threadsafe( + self.session.call_tool(tool_name, arguments=arguments), self.loop + ) + result = await asyncio.wrap_future(fut) + + if result.content: + text_parts = [] + for content_item in result.content: + if isinstance(content_item, TextContent): + text_parts.append(content_item.text) + elif getattr(content_item, "type", None) == "text": + text_parts.append(getattr(content_item, "text", str(content_item))) + else: + text_parts.append(str(content_item)) + + return "\n".join(text_parts) + + return "No result returned from tool" + + async def disconnect(self) -> None: + """Terminate the background MCP connection.""" + + assert self._connection_task is not None + self._connection_task.cancel() + try: + await self._connection_task + except asyncio.CancelledError: + pass + self.logger.info("MCP server '%s' terminated", self.config.name) diff --git a/verifiers/envs/mcp/tool_wrapper.py b/verifiers/envs/mcp/tool_wrapper.py new file mode 100644 index 000000000..448e39bf6 --- /dev/null +++ b/verifiers/envs/mcp/tool_wrapper.py @@ -0,0 +1,65 @@ +"""Wrapper that adapts MCP tool schemas to ToolEnv callables.""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import Tool + +from .server_connection import MCPServerConnection + + +class MCPToolWrapper: + 
"""Callable shim that ToolEnv can treat like a native function tool.""" + + def __init__( + self, server_name: str, tool: Tool, server_connection: MCPServerConnection + ): + self.server_name = server_name + self.tool = tool + self.server_connection = server_connection + + self.__name__ = tool.name + self.__doc__ = tool.description or "" + + self.__annotations__ = self._build_annotations() + + def _build_annotations(self) -> dict: + annotations = {} + + if self.tool.inputSchema: + properties = self.tool.inputSchema.get("properties", {}) + + for param_name, param_spec in properties.items(): + param_type = param_spec.get("type", "string") + if param_type == "string": + annotations[param_name] = str + elif param_type == "integer": + annotations[param_name] = int + elif param_type == "number": + annotations[param_name] = float + elif param_type == "boolean": + annotations[param_name] = bool + elif param_type == "array": + annotations[param_name] = list + elif param_type == "object": + annotations[param_name] = dict + else: + annotations[param_name] = Any + + annotations["return"] = str + return annotations + + async def __call__(self, **kwargs): + return await self.server_connection.call_tool(self.tool.name, kwargs) + + def to_oai_tool(self) -> dict: + return { + "type": "function", + "function": { + "name": self.__name__, + "description": self.__doc__ or "", + "parameters": self.tool.inputSchema + or {"type": "object", "properties": {}}, + }, + } diff --git a/verifiers/envs/mcp_env.py b/verifiers/envs/mcp_env.py new file mode 100644 index 000000000..772acee63 --- /dev/null +++ b/verifiers/envs/mcp_env.py @@ -0,0 +1,126 @@ +"""Multi-turn ToolEnv wrapper for MCP stdio servers.""" + +from __future__ import annotations + +import asyncio +import atexit +import threading +from typing import Callable, Dict, Sequence + +from verifiers.envs.tool_env import ToolEnv +from verifiers.types import Message + +from .mcp import MCPServerConfig, MCPServerConnection, MCPToolWrapper + + +class MCPEnv(ToolEnv): + """Environment that proxies MCP stdio servers into ToolEnv tools.""" + + def __init__( + self, + mcp_servers: Sequence[MCPServerConfig] | None = None, + max_turns: int = 10, + error_formatter: Callable[[Exception], str] | None = None, + **kwargs, + ): + self.mcp_servers: list[MCPServerConfig] = [ + server if isinstance(server, MCPServerConfig) else MCPServerConfig(**server) + for server in (mcp_servers or []) + ] + self.server_connections: Dict[str, MCPServerConnection] = {} + self.mcp_tools: Dict[str, MCPToolWrapper] = {} + + formatter = error_formatter or (lambda exc: f"Error: {exc}") + + super().__init__( + tools=[], + max_turns=max_turns, + error_formatter=formatter, + **kwargs, + ) + + self._bg_loop = asyncio.new_event_loop() + self._bg_thread = threading.Thread( + target=self._run_loop, args=(self._bg_loop,), daemon=True + ) + self._bg_thread.start() + fut = asyncio.run_coroutine_threadsafe(self._connect_servers(), self._bg_loop) + fut.result() + + self._closed = False + atexit.register(self._cleanup_at_exit) + + def _cleanup_at_exit(self) -> None: + self.close() + + def _run_loop(self, loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + async def _connect_servers(self) -> None: + wrapper_tools = [] + + for server_config in self.mcp_servers: + connection = MCPServerConnection(server_config, self.logger) + tools = await connection.connect() + + self.server_connections[server_config.name] = connection + + for tool in tools.values(): + wrapper = 
MCPToolWrapper(server_config.name, tool, connection) + wrapper_tools.append(wrapper) + self.mcp_tools[wrapper.__name__] = wrapper + self.logger.info( + "Registered MCP tool '%s' from server '%s'", + wrapper.__name__, + server_config.name, + ) + + self.tools = wrapper_tools + self.oai_tools = [tool.to_oai_tool() for tool in wrapper_tools] + self.tool_map = {tool.__name__: tool for tool in wrapper_tools} + + async def call_tool( + self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs + ) -> Message: + if tool_name in self.tool_map: + tool_wrapper = self.tool_map[tool_name] + try: + result = await tool_wrapper(**tool_args) + return { + "role": "tool", + "content": str(result), + "tool_call_id": tool_call_id, + } + except Exception as exc: # noqa: BLE001 + return { + "role": "tool", + "content": self.error_formatter(exc), + "tool_call_id": tool_call_id, + } + return { + "role": "tool", + "content": f"Error: Tool '{tool_name}' not found", + "tool_call_id": tool_call_id, + } + + async def cleanup(self) -> None: + for connection in self.server_connections.values(): + await connection.disconnect() + + self.server_connections.clear() + self.mcp_tools.clear() + + def _shutdown_loop(self) -> None: + if self._bg_loop.is_running(): + self._bg_loop.call_soon_threadsafe(self._bg_loop.stop) + self._bg_thread.join(timeout=5) + + def close(self) -> None: + """Synchronously tear down background connections.""" + + if self._closed: + return + asyncio.run_coroutine_threadsafe(self.cleanup(), self._bg_loop).result() + self._shutdown_loop() + self._closed = True
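For reviewers, a minimal usage sketch (not part of the diff) of how the new `MCPEnv` wrapper might be pointed at the stdio fetch server. The `question`/`answer` dataset columns and the reliance on ToolEnv's defaults are assumptions about the surrounding `verifiers` conventions, not something this patch mandates:

```python
# Hypothetical wiring sketch; assumes the usual verifiers dataset conventions
# ("question"/"answer" columns) and that ToolEnv's defaults cover the rest.
from datasets import Dataset

import verifiers as vf
from verifiers.envs.mcp import MCPServerConfig

dataset = Dataset.from_list(
    [
        {
            "question": "Fetch http://127.0.0.1:31415/json/data.json and return the 'title' value.",
            "answer": "Deterministic dataset",
        }
    ]
)

fetch_server = MCPServerConfig(
    name="fetch",
    command="python",
    args=[
        "environments/mcp_fetch/tools/fetch_mcp_server.py",
        "--run-server",
        "--allowed-host",
        "127.0.0.1:31415",
    ],
    description="Deterministic HTTP fetch tool",
)

env = vf.MCPEnv(mcp_servers=[fetch_server], dataset=dataset, max_turns=6)
try:
    # The MCP server's tools are registered as ToolEnv tools at construction time.
    print([tool.__name__ for tool in env.tools])  # expected: ['fetch']
finally:
    env.close()
```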