213 changes: 187 additions & 26 deletions scripts/eval-runner.py
@@ -1,15 +1,16 @@
#!/usr/bin/env python3
import argparse
import asyncio
import fnmatch
import importlib.util
import json
import os
import re
import socket
import sys
import traceback
from dataclasses import dataclass
from typing import Any, Callable
from pathlib import PurePosixPath

try:
@@ -20,6 +21,7 @@
        _evals,
        _set_lazy_load,
        run_evaluator,
        set_thread_pool_max_workers,
    )
    from braintrust.logger import Dataset
    from braintrust.util import eprint
@@ -43,6 +45,21 @@
)


@dataclass(frozen=True)
class EvalFilter:
    path: list[str]
    pattern: re.Pattern[str]


@dataclass(frozen=True)
class RunnerConfig:
    jsonl: bool
    list_only: bool
    terminate_on_failure: bool
    num_workers: int | None
    filters: list[EvalFilter]


@dataclass
class SseWriter:
    sock: socket.socket
@@ -92,6 +109,79 @@ def env_flag(name: str) -> bool:
    return value.lower() not in {"0", "false", "no", "off", ""}


def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]:
    if not serialized:
        return []

    parsed = json.loads(serialized)
    if not isinstance(parsed, list):
        raise ValueError("BT_EVAL_FILTER_PARSED must be a JSON array")

    filters: list[EvalFilter] = []
    for i, entry in enumerate(parsed):
        if not isinstance(entry, dict):
            raise ValueError("BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}")
        key_path = entry.get("path")
        pattern = entry.get("pattern")
        if not isinstance(key_path, list) or not all(isinstance(part, str) for part in key_path):
            raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} path must be an array of strings")
        if not isinstance(pattern, str):
            raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} pattern must be a string")
        filters.append(EvalFilter(path=key_path, pattern=re.compile(pattern)))
    return filters
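
# A minimal sketch of the wire format, inferred from the validation above
# (not quoted from the PR): the parent process is assumed to serialize
# filters as a JSON array of {path, pattern} objects, e.g.
#
#   BT_EVAL_FILTER_PARSED='[{"path": ["metadata", "suite"], "pattern": "^smoke"}]'
#
# which parses to [EvalFilter(path=["metadata", "suite"],
# pattern=re.compile("^smoke"))]; anything else raises ValueError.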


def read_runner_config() -> RunnerConfig:
    num_workers_value = os.getenv("BT_EVAL_NUM_WORKERS")
    num_workers = int(num_workers_value) if num_workers_value else None
    return RunnerConfig(
        jsonl=env_flag("BT_EVAL_JSONL"),
        list_only=env_flag("BT_EVAL_LIST"),
        terminate_on_failure=env_flag("BT_EVAL_TERMINATE_ON_FAILURE"),
        num_workers=num_workers,
        filters=parse_serialized_filters(os.getenv("BT_EVAL_FILTER_PARSED")),
    )
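
# Hypothetical invocation showing how the env vars map onto RunnerConfig
# (values invented for illustration):
#
#   BT_EVAL_JSONL=1 BT_EVAL_NUM_WORKERS=4 BT_EVAL_TERMINATE_ON_FAILURE=true \
#       python scripts/eval-runner.py .
#
# yields RunnerConfig(jsonl=True, list_only=False, terminate_on_failure=True,
# num_workers=4, filters=[]); env_flag treats "0", "false", "no", "off", and
# an empty string as falsy.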


def _to_mapping(value: Any) -> Any:
    if isinstance(value, dict):
        return {k: _to_mapping(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_to_mapping(v) for v in value]
    if hasattr(value, "__dict__"):
        return {
            key: _to_mapping(val)
            for key, val in vars(value).items()
            if not key.startswith("_")
        }
    return value


def serialize_json_with_plain_string(value: Any) -> str:
    if isinstance(value, str):
        return value
    return json.dumps(value)
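
# Rationale (as read from the call site in evaluate_filter): plain strings
# are matched raw so a pattern like ^smoke$ need not account for the quotes
# json.dumps would add ('"smoke"'); non-string leaves are matched against
# their JSON serialization instead.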


def evaluate_filter(value: Any, filt: EvalFilter) -> bool:
    current = _to_mapping(value)
    for part in filt.path:
        if not isinstance(current, dict) or part not in current:
            return False
        current = current[part]
    return bool(filt.pattern.search(serialize_json_with_plain_string(current)))


def filter_evaluators(evaluators: list[EvaluatorInstance], filters: list[EvalFilter]) -> list[EvaluatorInstance]:
    if not filters:
        return evaluators
    return [
        evaluator
        for evaluator in evaluators
        if all(evaluate_filter(evaluator.evaluator, filt) for filt in filters)
    ]
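
# Sketch of how a filter matches (the "metadata"/"suite" keys are
# illustrative, not taken from the PR): evaluate_filter converts the
# evaluator to plain dicts via _to_mapping, walks the key path, and
# regex-searches the leaf; filter_evaluators keeps an evaluator only if
# every filter matches.
#
#   filt = EvalFilter(path=["metadata", "suite"], pattern=re.compile("smoke"))
#   evaluate_filter({"metadata": {"suite": "smoke-fast"}}, filt)  # True
#   evaluate_filter({"metadata": {}}, filt)                       # False: path missing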


def snake_to_camel(value: str) -> str:
    parts = value.split("_")
    if not parts:
@@ -218,8 +308,12 @@ def resolve_module_info(in_file: str) -> tuple[str, list[str]]:
    return module_name, extra_paths


def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str, Any]]:
    evaluator_instances: list[EvaluatorInstance] = []
    reporters: dict[str, Any] = {}
    cwd = os.getcwd()
    if cwd not in sys.path:
        sys.path.insert(0, cwd)
    unique_files: set[str] = set()
    for file_path in files:
        for candidate in collect_files(file_path):
@@ -250,10 +344,34 @@ def load_evaluators(files: list[str]) -> list[EvaluatorInstance]:
                        if isinstance(instance, EvaluatorInstance)
                    ]
                )
                for reporter_name, reporter in _evals.reporters.items():
                    if reporter_name not in reporters:
                        reporters[reporter_name] = reporter
            finally:
                _evals.clear()

    return evaluator_instances, reporters


def resolve_reporter(
    reporter: Any,
    reporters: dict[str, Any],
) -> Any | None:
    if isinstance(reporter, str):
        if reporter not in reporters:
            raise ValueError(f"Reporter {reporter} not found")
        return reporters[reporter]
    if reporter is not None:
        return reporter

    if len(reporters) == 0:
        return None
    if len(reporters) == 1:
        return next(iter(reporters.values()))
    reporter_names = ", ".join(reporters.keys())
    raise ValueError(
        f"Multiple reporters found ({reporter_names}). Please specify a reporter explicitly."
    )
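
# Resolution rules, restated: a string reporter must name a loaded reporter
# (e.g. with reporters == {"json": r1, "text": r2}, "json" returns r1 and
# "xml" raises); a non-string reporter object is returned unchanged; with no
# reporter given, zero loaded reporters yields None, exactly one becomes the
# default, and two or more raise the "specify a reporter explicitly" error.
# (The names "json"/"text" here are illustrative only.)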


def _init_experiment_for_eval(evaluator):
@@ -319,38 +437,74 @@ async def run_evaluator_task(
    if experiment:
        experiment.flush()


async def run_once(
    files: list[str],
    no_send_logs: bool,
    sse: SseWriter | None,
    config: RunnerConfig,
) -> bool:
    evaluators, reporters = load_evaluators(files)
    if not evaluators and not config.list_only:
        message = "No evaluators found. Did you call Eval() in the file?"
        if sse:
            sse.send("error", serialize_error(message))
            sse.send("console", {"stream": "stderr", "message": message})
        else:
            eprint(message)
        return True

    evaluators = filter_evaluators(evaluators, config.filters)
    if config.list_only:
        for evaluator_instance in evaluators:
            print(evaluator_instance.evaluator.eval_name)
        return True

    supports_progress = run_evaluator_supports_progress()

    async def run_single_evaluator(
        idx: int, evaluator_instance: EvaluatorInstance
    ) -> tuple[EvaluatorInstance, Any | None, Any | None, dict[str, Any] | None]:
        try:
            resolved_reporter = resolve_reporter(
                getattr(evaluator_instance, "reporter", None),
                reporters,
            )
        except Exception as exc:
            err = serialize_error(str(exc), traceback.format_exc())
            return evaluator_instance, None, None, err

        progress_cb = create_progress_reporter(sse, evaluator_instance.evaluator.eval_name)
        try:
            result = await run_evaluator_task(
                evaluator_instance.evaluator,
                idx,
                no_send_logs,
                progress_cb,
                supports_progress,
            )
        except Exception as exc:
            err = serialize_error(str(exc), traceback.format_exc())
            return evaluator_instance, resolved_reporter, None, err

        return evaluator_instance, resolved_reporter, result, None

    execution_results: list[tuple[EvaluatorInstance, Any | None, Any | None, dict[str, Any] | None]] = []
    if config.terminate_on_failure:
        for idx, evaluator_instance in enumerate(evaluators):
            run_result = await run_single_evaluator(idx, evaluator_instance)
            execution_results.append(run_result)
            if run_result[3] is not None:
                break
    else:
        tasks = [
            asyncio.create_task(run_single_evaluator(idx, evaluator_instance))
            for idx, evaluator_instance in enumerate(evaluators)
        ]
        execution_results = list(await asyncio.gather(*tasks))

    all_success = True
    for evaluator_instance, resolved_reporter, result, err in execution_results:
        if err is not None:
            all_success = False
            if sse:
                sse.send("error", err)
            else:
@@ -359,11 +513,13 @@ async def run_once(files: list[str], no_send_logs: bool, sse: SseWriter | None)

        if sse:
            sse.send("summary", format_summary(result.summary.as_dict()))
        elif config.jsonl:
            print(json.dumps(format_summary(result.summary.as_dict())))
        else:
            print(result.summary)

        failures = [row for row in result.results if row.error]
        if failures and resolved_reporter is None:
            all_success = False
            first_error = failures[0]
            message = (
@@ -374,6 +530,8 @@
                sse.send("error", serialize_error(message, stack))
            else:
                eprint(message)
            if config.terminate_on_failure:
                break

    return all_success

@@ -392,16 +550,19 @@ def main(argv: list[str] | None = None) -> int:
    parser = build_parser()
    args = parser.parse_args(argv)

    config = read_runner_config()
    local = args.local or env_flag("BT_EVAL_LOCAL") or env_flag("BT_EVAL_NO_SEND_LOGS")
    files = args.files or ["."]
    if config.num_workers is not None:
        set_thread_pool_max_workers(config.num_workers)

    if not local:
        login(api_key=args.api_key, org_name=args.org_name, app_url=args.app_url)

    sse = create_sse_writer()
    cwd = os.path.abspath(os.getcwd())
    try:
        success = asyncio.run(run_once(files, local, sse, config))
        if sse:
            sse.send("dependencies", {"files": collect_dependency_files(cwd, files)})
            sse.send("done", {"success": success})