diff --git a/scripts/eval-runner.py b/scripts/eval-runner.py
index 4d42214..cc8a279 100755
--- a/scripts/eval-runner.py
+++ b/scripts/eval-runner.py
@@ -1,15 +1,16 @@
 #!/usr/bin/env python3
 import argparse
 import asyncio
+import fnmatch
+import importlib.util
 import json
 import os
+import re
 import socket
 import sys
 import traceback
 from dataclasses import dataclass
 from typing import Any, Callable
-import importlib.util
-import fnmatch
 from pathlib import PurePosixPath
 
 try:
@@ -20,6 +21,7 @@
     _evals,
     _set_lazy_load,
     run_evaluator,
+    set_thread_pool_max_workers,
 )
 from braintrust.logger import Dataset
 from braintrust.util import eprint
@@ -43,6 +45,21 @@
 )
 
 
+@dataclass(frozen=True)
+class EvalFilter:
+    path: list[str]
+    pattern: re.Pattern[str]
+
+
+@dataclass(frozen=True)
+class RunnerConfig:
+    jsonl: bool
+    list_only: bool
+    terminate_on_failure: bool
+    num_workers: int | None
+    filters: list[EvalFilter]
+
+
 @dataclass
 class SseWriter:
     sock: socket.socket
@@ -92,6 +109,79 @@ def env_flag(name: str) -> bool:
     return value.lower() not in {"0", "false", "no", "off", ""}
 
 
+def parse_serialized_filters(serialized: str | None) -> list[EvalFilter]:
+    if not serialized:
+        return []
+
+    parsed = json.loads(serialized)
+    if not isinstance(parsed, list):
+        raise ValueError("BT_EVAL_FILTER_PARSED must be a JSON array")
+
+    filters: list[EvalFilter] = []
+    for i, entry in enumerate(parsed):
+        if not isinstance(entry, dict):
+            raise ValueError("BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}")
+        key_path = entry.get("path")
+        pattern = entry.get("pattern")
+        if not isinstance(key_path, list) or not all(isinstance(part, str) for part in key_path):
+            raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} path must be an array of strings")
+        if not isinstance(pattern, str):
+            raise ValueError(f"BT_EVAL_FILTER_PARSED entry {i} pattern must be a string")
+        filters.append(EvalFilter(path=key_path, pattern=re.compile(pattern)))
+    return filters
+
+
+def read_runner_config() -> RunnerConfig:
+    num_workers_value = os.getenv("BT_EVAL_NUM_WORKERS")
+    num_workers = int(num_workers_value) if num_workers_value else None
+    return RunnerConfig(
+        jsonl=env_flag("BT_EVAL_JSONL"),
+        list_only=env_flag("BT_EVAL_LIST"),
+        terminate_on_failure=env_flag("BT_EVAL_TERMINATE_ON_FAILURE"),
+        num_workers=num_workers,
+        filters=parse_serialized_filters(os.getenv("BT_EVAL_FILTER_PARSED")),
+    )
+
+
+def _to_mapping(value: Any) -> Any:
+    if isinstance(value, dict):
+        return {k: _to_mapping(v) for k, v in value.items()}
+    if isinstance(value, list):
+        return [_to_mapping(v) for v in value]
+    if hasattr(value, "__dict__"):
+        return {
+            key: _to_mapping(val)
+            for key, val in vars(value).items()
+            if not key.startswith("_")
+        }
+    return value
+
+
+def serialize_json_with_plain_string(value: Any) -> str:
+    if isinstance(value, str):
+        return value
+    return json.dumps(value)
+
+
+def evaluate_filter(value: Any, filt: EvalFilter) -> bool:
+    current = _to_mapping(value)
+    for part in filt.path:
+        if not isinstance(current, dict) or part not in current:
+            return False
+        current = current[part]
+    return bool(filt.pattern.search(serialize_json_with_plain_string(current)))
+
+
+def filter_evaluators(evaluators: list[EvaluatorInstance], filters: list[EvalFilter]) -> list[EvaluatorInstance]:
+    if not filters:
+        return evaluators
+    return [
+        evaluator
+        for evaluator in evaluators
+        if all(evaluate_filter(evaluator.evaluator, filt) for filt in filters)
+    ]
+
+
 def snake_to_camel(value: str) -> str:
     parts = value.split("_")
     if not parts:
@@ -218,8 +308,12 @@ def resolve_module_info(in_file: str) -> tuple[str, list[str]]:
     return module_name, extra_paths
 
 
-def load_evaluators(files: list[str]) -> list[EvaluatorInstance]:
+def load_evaluators(files: list[str]) -> tuple[list[EvaluatorInstance], dict[str, Any]]:
     evaluator_instances: list[EvaluatorInstance] = []
+    reporters: dict[str, Any] = {}
+    cwd = os.getcwd()
+    if cwd not in sys.path:
+        sys.path.insert(0, cwd)
     unique_files: set[str] = set()
     for file_path in files:
         for candidate in collect_files(file_path):
@@ -250,10 +344,34 @@ def load_evaluators(files: list[str]) -> list[EvaluatorInstance]:
                     if isinstance(instance, EvaluatorInstance)
                 ]
             )
+            for reporter_name, reporter in _evals.reporters.items():
+                if reporter_name not in reporters:
+                    reporters[reporter_name] = reporter
         finally:
             _evals.clear()
 
-    return evaluator_instances
+    return evaluator_instances, reporters
+
+
+def resolve_reporter(
+    reporter: Any,
+    reporters: dict[str, Any],
+) -> Any | None:
+    if isinstance(reporter, str):
+        if reporter not in reporters:
+            raise ValueError(f"Reporter {reporter} not found")
+        return reporters[reporter]
+    if reporter is not None:
+        return reporter
+
+    if len(reporters) == 0:
+        return None
+    if len(reporters) == 1:
+        return next(iter(reporters.values()))
+    reporter_names = ", ".join(reporters.keys())
+    raise ValueError(
+        f"Multiple reporters found ({reporter_names}). Please specify a reporter explicitly."
+    )
 
 
 def _init_experiment_for_eval(evaluator):
@@ -319,38 +437,74 @@ async def run_evaluator_task(
     if experiment:
         experiment.flush()
 
 
-async def run_once(files: list[str], no_send_logs: bool, sse: SseWriter | None) -> bool:
-    evaluators = load_evaluators(files)
-    if not evaluators:
+async def run_once(
+    files: list[str],
+    no_send_logs: bool,
+    sse: SseWriter | None,
+    config: RunnerConfig,
+) -> bool:
+    evaluators, reporters = load_evaluators(files)
+    if not evaluators and not config.list_only:
         message = "No evaluators found. Did you call Eval() in the file?"
         if sse:
-            sse.send("error", serialize_error(message))
+            sse.send("console", {"stream": "stderr", "message": message})
         else:
             eprint(message)
-        return False
+        return True
+
+    evaluators = filter_evaluators(evaluators, config.filters)
+    if config.list_only:
+        for evaluator_instance in evaluators:
+            print(evaluator_instance.evaluator.eval_name)
+        return True
 
     supports_progress = run_evaluator_supports_progress()
-    tasks = []
-    progress_callbacks = []
-    for idx, evaluator_instance in enumerate(evaluators):
-        progress_cb = create_progress_reporter(sse, evaluator_instance.evaluator.eval_name)
-        progress_callbacks.append(progress_cb)
-        tasks.append(
-            asyncio.create_task(
-                run_evaluator_task(
-                    evaluator_instance.evaluator, idx, no_send_logs, progress_cb, supports_progress
-                )
+    async def run_single_evaluator(
+        idx: int, evaluator_instance: EvaluatorInstance
+    ) -> tuple[EvaluatorInstance, Any | None, Any | None, dict[str, Any] | None]:
+        try:
+            resolved_reporter = resolve_reporter(
+                getattr(evaluator_instance, "reporter", None),
+                reporters,
             )
-        )
+        except Exception as exc:
+            err = serialize_error(str(exc), traceback.format_exc())
+            return evaluator_instance, None, None, err
 
-    all_success = True
-    for evaluator_instance, task, progress_cb in zip(evaluators, tasks, progress_callbacks):
+        progress_cb = create_progress_reporter(sse, evaluator_instance.evaluator.eval_name)
         try:
-            result = await task
+            result = await run_evaluator_task(
+                evaluator_instance.evaluator,
+                idx,
+                no_send_logs,
+                progress_cb,
+                supports_progress,
+            )
        except Exception as exc:
-            all_success = False
            err = serialize_error(str(exc), traceback.format_exc())
+            return evaluator_instance, resolved_reporter, None, err
+
+        return evaluator_instance, resolved_reporter, result, None
+
+    execution_results: list[tuple[EvaluatorInstance, Any | None, Any | None, dict[str, Any] | None]] = []
+    if config.terminate_on_failure:
+        for idx, evaluator_instance in enumerate(evaluators):
+            run_result = await run_single_evaluator(idx, evaluator_instance)
+            execution_results.append(run_result)
+            if run_result[3] is not None:
+                break
+    else:
+        tasks = [
+            asyncio.create_task(run_single_evaluator(idx, evaluator_instance))
+            for idx, evaluator_instance in enumerate(evaluators)
+        ]
+        execution_results = list(await asyncio.gather(*tasks))
+
+    all_success = True
+    for evaluator_instance, resolved_reporter, result, err in execution_results:
+        if err is not None:
+            all_success = False
             if sse:
                 sse.send("error", err)
             else:
@@ -359,11 +513,13 @@ async def run_once(files: list[str], no_send_logs: bool, sse: SseWriter | None)
 
         if sse:
             sse.send("summary", format_summary(result.summary.as_dict()))
+        elif config.jsonl:
+            print(json.dumps(format_summary(result.summary.as_dict())))
         else:
            print(result.summary)
 
         failures = [row for row in result.results if row.error]
-        if failures:
+        if failures and resolved_reporter is None:
            all_success = False
            first_error = failures[0]
            message = (
@@ -374,6 +530,8 @@ async def run_once(files: list[str], no_send_logs: bool, sse: SseWriter | None)
                 sse.send("error", serialize_error(message, stack))
             else:
                 eprint(message)
+            if config.terminate_on_failure:
+                break
 
     return all_success
 
@@ -392,8 +550,11 @@ def main(argv: list[str] | None = None) -> int:
 
     parser = build_parser()
     args = parser.parse_args(argv)
+    config = read_runner_config()
     local = args.local or env_flag("BT_EVAL_LOCAL") or env_flag("BT_EVAL_NO_SEND_LOGS")
     files = args.files or ["."]
+    if config.num_workers is not None:
+        set_thread_pool_max_workers(config.num_workers)
     if not local:
         login(api_key=args.api_key, org_name=args.org_name, app_url=args.app_url)
 
@@ -401,7 +562,7 @@ def main(argv: list[str] | None = None) -> int:
     sse = create_sse_writer()
     cwd = os.path.abspath(os.getcwd())
     try:
-        success = asyncio.run(run_once(files, local, sse))
+        success = asyncio.run(run_once(files, local, sse, config))
         if sse:
             sse.send("dependencies", {"files": collect_dependency_files(cwd, files)})
             sse.send("done", {"success": success})
diff --git a/scripts/eval-runner.ts b/scripts/eval-runner.ts
index 294e1d8..03e4206 100644
--- a/scripts/eval-runner.ts
+++ b/scripts/eval-runner.ts
@@ -75,11 +75,30 @@ type SseWriter = {
   close: () => void;
 };
 
+type EvalFilter = {
+  path: string[];
+  pattern: RegExp;
+};
+
+type SerializedEvalFilter = {
+  path: string[];
+  pattern: string;
+};
+
+type RunnerConfig = {
+  jsonl: boolean;
+  list: boolean;
+  terminateOnFailure: boolean;
+  filters: EvalFilter[];
+};
+
 declare global {
   // eslint-disable-next-line no-var
   var _evals: GlobalEvals | undefined;
   // eslint-disable-next-line no-var
   var _lazy_load: boolean | undefined;
+  // eslint-disable-next-line no-var
+  var __inherited_braintrust_state: unknown;
 }
 
 function isObject(value: unknown): value is Record<string, unknown> {
@@ -104,6 +123,74 @@ function normalizeFiles(files: string[]): string[] {
   return files.map((file) => path.resolve(process.cwd(), file));
 }
 
+function envFlag(name: string): boolean {
+  const value = process.env[name];
+  if (!value) {
+    return false;
+  }
+  const normalized = value.toLowerCase();
+  return !["0", "false", "no", "off", ""].includes(normalized);
+}
+
+function serializeJSONWithPlainString(value: unknown): string {
+  if (typeof value === "string") {
+    return value;
+  }
+  return JSON.stringify(value);
+}
+
+function parseSerializedFilters(serialized: string | undefined): EvalFilter[] {
+  if (!serialized) {
+    return [];
+  }
+
+  try {
+    const parsed = JSON.parse(serialized);
+    if (!Array.isArray(parsed)) {
+      throw new Error("BT_EVAL_FILTER_PARSED must be a JSON array.");
+    }
+    return parsed.map((value) => {
+      if (!isObject(value)) {
+        throw new Error(
+          "BT_EVAL_FILTER_PARSED entries must be objects with {path, pattern}.",
+        );
+      }
+      const { path: rawPath, pattern: rawPattern } =
+        value as SerializedEvalFilter;
+      if (
+        !Array.isArray(rawPath) ||
+        !rawPath.every((part) => typeof part === "string")
+      ) {
+        throw new Error(
+          "BT_EVAL_FILTER_PARSED entry path must be an array of strings.",
+        );
+      }
+      if (typeof rawPattern !== "string") {
+        throw new Error(
+          "BT_EVAL_FILTER_PARSED entry pattern must be a string.",
+        );
+      }
+      return {
+        path: rawPath,
+        pattern: new RegExp(rawPattern),
+      };
+    });
+  } catch (err) {
+    throw new Error(
+      `Invalid BT_EVAL_FILTER_PARSED value: ${err instanceof Error ? err.message : String(err)}`,
+    );
+  }
+}
+
+function readRunnerConfig(): RunnerConfig {
+  return {
+    jsonl: envFlag("BT_EVAL_JSONL"),
+    list: envFlag("BT_EVAL_LIST"),
+    terminateOnFailure: envFlag("BT_EVAL_TERMINATE_ON_FAILURE"),
+    filters: parseSerializedFilters(process.env.BT_EVAL_FILTER_PARSED),
+  };
+}
+
 const runtimeRequire = createRequire(
   process.argv[1] ?? path.join(process.cwd(), "package.json"),
 );
@@ -650,27 +737,62 @@ async function loadBraintrust() {
   return normalizeBraintrustModule(mod);
 }
 
+function propagateInheritedBraintrustState(braintrust: BraintrustModule) {
+  const getter = (braintrust as Record<string, unknown>)
+    ._internalGetGlobalState;
+  if (typeof getter !== "function") {
+    return;
+  }
+  const state = getter();
+  if (state !== undefined && state !== null) {
+    globalThis.__inherited_braintrust_state = state;
+  }
+}
+
 async function loadFiles(files: string[]): Promise<unknown[]> {
   const modules: unknown[] = [];
   for (const file of files) {
     const fileUrl = pathToFileURL(file).href;
-    try {
-      const mod = await import(fileUrl);
-      modules.push(mod);
-    } catch (err) {
-      if (shouldTryRequire(file, err)) {
+    const preferRequire =
+      file.endsWith(".ts") || file.endsWith(".tsx") || file.endsWith(".cjs");
+
+    if (preferRequire) {
+      try {
+        const require = createRequire(fileUrl);
+        const mod = require(file);
+        modules.push(mod);
+        continue;
+      } catch (requireErr) {
         try {
-          const require = createRequire(fileUrl);
-          const mod = require(file);
+          const mod = await import(fileUrl);
           modules.push(mod);
           continue;
-        } catch (requireErr) {
+        } catch (esmErr) {
           throw new Error(
-            `Failed to load ${file} as ESM (${formatError(err)}) or CJS (${formatError(requireErr)}).`,
+            `Failed to load ${file} as CJS (${formatError(requireErr)}) or ESM (${formatError(esmErr)}).`,
           );
         }
       }
-      throw err;
+    }
+
+    try {
+      const mod = await import(fileUrl);
+      modules.push(mod);
+      continue;
+    } catch (err) {
+      if (!shouldTryRequire(file, err)) {
+        throw err;
+      }
+      try {
+        const require = createRequire(fileUrl);
+        const mod = require(file);
+        modules.push(mod);
+        continue;
+      } catch (requireErr) {
+        throw new Error(
+          `Failed to load ${file} as ESM (${formatError(err)}) or CJS (${formatError(requireErr)}).`,
+        );
+      }
     }
   }
   return modules;
@@ -791,6 +913,69 @@ function getEvaluators(): EvaluatorEntry[] {
   return Object.values(evals.evaluators) as EvaluatorEntry[];
 }
 
+function getReporters(): Record<string, unknown> {
+  const evals = globalThis._evals;
+  if (!evals || !evals.reporters) {
+    return {};
+  }
+  return evals.reporters as Record<string, unknown>;
+}
+
+function resolveReporter(
+  reporter: unknown,
+  reporters: Record<string, unknown>,
+): unknown | undefined {
+  if (typeof reporter === "string") {
+    if (!(reporter in reporters)) {
+      throw new Error(`Reporter ${reporter} not found`);
+    }
+    return reporters[reporter];
+  }
+  if (reporter !== undefined && reporter !== null) {
+    return reporter;
+  }
+
+  const values = Object.values(reporters);
+  if (values.length === 0) {
+    return undefined;
+  }
+  if (values.length === 1) {
+    return values[0];
+  }
+  const names = Object.keys(reporters).join(", ");
+  throw new Error(
+    `Multiple reporters found (${names}). Please specify a reporter explicitly.`,
+  );
+}
+
+function evaluateFilter(
+  object: Record<string, unknown>,
+  filter: EvalFilter,
+): boolean {
+  const key = filter.path.reduce((acc, part) => {
+    if (!isObject(acc)) {
+      return undefined;
+    }
+    return acc[part];
+  }, object);
+  if (key === undefined) {
+    return false;
+  }
+  return filter.pattern.test(serializeJSONWithPlainString(key));
+}
+
+function filterEvaluators(
+  evaluators: EvaluatorEntry[],
+  filters: EvalFilter[],
+): EvaluatorEntry[] {
+  if (filters.length === 0) {
+    return evaluators;
+  }
+  return evaluators.filter((entry) =>
+    filters.every((filter) => evaluateFilter(entry.evaluator, filter)),
+  );
+}
+
 function extractBtEvalMain(mod: unknown): BtEvalMain | null {
   if (!mod || typeof mod !== "object") {
     return null;
@@ -927,7 +1112,7 @@ function mergeProgress(
   };
 }
 
-async function createEvalRunner() {
+async function createEvalRunner(config: RunnerConfig) {
   const braintrust = await loadBraintrust();
   const Eval = braintrust.Eval;
   if (typeof Eval !== "function") {
@@ -990,35 +1175,55 @@
     }
     if (sse) {
       sse.send("summary", result.summary);
+    } else if (config.jsonl) {
+      console.log(JSON.stringify(result.summary));
     }
     return result;
   };
 
   const runRegisteredEvals = async (evaluators: EvaluatorEntry[]) => {
-    const results = await Promise.all(
-      evaluators.map(async (entry) => {
-        try {
-          const options = entry.reporter
-            ? { reporter: entry.reporter }
+    const reporters = getReporters();
+    const runEntry = async (entry: EvaluatorEntry): Promise<boolean> => {
+      try {
+        const resolvedReporter = resolveReporter(entry.reporter, reporters);
+        const options =
+          resolvedReporter !== undefined
+            ? { reporter: resolvedReporter }
            : undefined;
-          const result = await runEval(
-            entry.evaluator.projectName,
-            entry.evaluator,
-            options,
-          );
-          const failingResults = result.results.filter(
-            (r: { error?: unknown }) => r.error !== undefined,
-          );
-          return failingResults.length === 0;
-        } catch (err) {
-          if (sse) {
-            sse.send("error", serializeError(err));
-          } else {
-            console.error(err);
-          }
+        const result = await runEval(
+          entry.evaluator.projectName,
+          entry.evaluator,
+          options,
+        );
+        const failingResults = result.results.filter(
+          (r: { error?: unknown }) => r.error !== undefined,
+        );
+        if (failingResults.length > 0 && resolvedReporter === undefined) {
          return false;
        }
-      }),
+        return true;
+      } catch (err) {
+        if (sse) {
+          sse.send("error", serializeError(err));
+        } else {
+          console.error(err);
+        }
+        return false;
+      }
+    };
+
+    if (config.terminateOnFailure) {
+      for (const entry of evaluators) {
+        const ok = await runEntry(entry);
+        if (!ok) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    const results = await Promise.all(
+      evaluators.map((entry) => runEntry(entry)),
    );
    return results.every(Boolean);
  };
@@ -1047,6 +1252,7 @@
 }
 
 async function main() {
+  const config = readRunnerConfig();
   const files = process.argv.slice(2);
   if (files.length === 0) {
     console.error("No eval files provided.");
@@ -1060,12 +1266,13 @@
   }
   collectStaticLocalDependencies(normalized);
   ensureBraintrustAvailable();
-  await loadBraintrust();
+  const braintrust = await loadBraintrust();
+  propagateInheritedBraintrustState(braintrust);
   initRegistry();
   const modules = await loadFiles(normalized);
   const btEvalMains = collectBtEvalMains(modules);
 
-  const runner = await createEvalRunner();
+  const runner = await createEvalRunner(config);
   if (!runner.noSendLogs && typeof runner.login === "function") {
     try {
       await runner.login({});
@@ -1082,7 +1289,10 @@ async function main() {
   const context: BtEvalContext = {
     Eval: runner.Eval,
     runEval: runner.runEval,
-    runRegisteredEvals: () => runner.runRegisteredEvals(getEvaluators()),
+    runRegisteredEvals: () =>
+      runner.runRegisteredEvals(
+        filterEvaluators(getEvaluators(), config.filters),
+      ),
     makeEvalOptions: runner.makeEvalOptions,
     sendConsole: (message: string, stream?: "stdout" | "stderr") => {
       sendConsole(runner.sse, message, stream);
@@ -1096,6 +1306,18 @@ async function main() {
 
   let ok = true;
   try {
+    const discoveredEvaluators = getEvaluators();
+    const filteredEvaluators = filterEvaluators(
+      discoveredEvaluators,
+      config.filters,
+    );
+    if (config.list) {
+      for (const entry of filteredEvaluators) {
+        console.log(entry.evaluator.evalName);
+      }
+      return;
+    }
+
     if (btEvalMains.length > 0) {
       globalThis._lazy_load = false;
       for (const main of btEvalMains) {
@@ -1111,12 +1333,11 @@ async function main() {
         }
       }
     } else {
-      const evaluators = getEvaluators();
-      if (evaluators.length === 0) {
+      if (discoveredEvaluators.length === 0) {
         console.error("No evaluators found. Did you call Eval() in the file?");
         process.exit(1);
       }
-      ok = await runner.runRegisteredEvals(evaluators);
+      ok = await runner.runRegisteredEvals(filteredEvaluators);
     }
   } finally {
     collectRequireCacheDependencies();
diff --git a/src/eval.rs b/src/eval.rs
index 998cce7..06ba1ce 100644
--- a/src/eval.rs
+++ b/src/eval.rs
@@ -14,7 +14,7 @@ use crossterm::style::{
     Stylize,
 };
 use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use strip_ansi_escapes::strip;
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::net::UnixListener;
@@ -38,6 +38,13 @@ struct EvalRunOutput {
     status: ExitStatus,
     dependencies: Vec<String>,
 }
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+struct RunnerFilter {
+    path: Vec<String>,
+    pattern: String,
+}
+
 const JS_RUNNER_FILE: &str = "eval-runner.ts";
 const PY_RUNNER_FILE: &str = "eval-runner.py";
 const JS_RUNNER_SOURCE: &str = include_str!("../scripts/eval-runner.ts");
@@ -96,12 +103,49 @@ pub struct EvalArgs {
     )]
     pub no_send_logs: bool,
 
+    /// Output one JSON summary per evaluator.
+    #[arg(long)]
+    pub jsonl: bool,
+
+    /// Stop after the first failing evaluator.
+    #[arg(long)]
+    pub terminate_on_failure: bool,
+
+    /// Number of worker threads for Python eval execution.
+    #[arg(long, value_name = "COUNT")]
+    pub num_workers: Option<usize>,
+
+    /// List evaluators without executing them.
+    #[arg(long)]
+    pub list: bool,
+
+    /// Filter expression(s) used to select which evaluators to run.
+    #[arg(long, value_name = "FILTER")]
+    pub filter: Vec<String>,
+
     /// Re-run evals when input files change.
     #[arg(long, short = 'w')]
     pub watch: bool,
 }
 
+#[derive(Debug, Clone)]
+struct EvalRunOptions {
+    jsonl: bool,
+    terminate_on_failure: bool,
+    num_workers: Option<usize>,
+    list: bool,
+    filter: Vec<String>,
+}
+
 pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> {
+    let options = EvalRunOptions {
+        jsonl: args.jsonl,
+        terminate_on_failure: args.terminate_on_failure,
+        num_workers: args.num_workers,
+        list: args.list,
+        filter: args.filter,
+    };
+
     if args.watch {
         run_eval_files_watch(
             &base,
@@ -109,6 +153,7 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> {
             args.runner.clone(),
             args.files.clone(),
             args.no_send_logs,
+            options,
         )
         .await
     } else {
@@ -118,6 +163,7 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> {
             args.runner.clone(),
             args.files.clone(),
             args.no_send_logs,
+            options,
         )
         .await?;
         if !output.status.success() {
@@ -133,6 +179,7 @@ async fn run_eval_files_watch(
     runner_override: Option<PathBuf>,
     files: Vec<String>,
     no_send_logs: bool,
+    options: EvalRunOptions,
 ) -> Result<()> {
     let input_watch_paths = resolve_watch_paths(&files)?;
     let mut active_watch_paths = input_watch_paths.clone();
@@ -150,6 +197,7 @@ async fn run_eval_files_watch(
             runner_override.clone(),
             files.clone(),
             no_send_logs,
+            options.clone(),
         )
         .await
         {
@@ -192,8 +240,12 @@ async fn run_eval_files_once(
     runner_override: Option<PathBuf>,
     files: Vec<String>,
     no_send_logs: bool,
+    options: EvalRunOptions,
 ) -> Result<EvalRunOutput> {
     let language = detect_eval_language(&files, language_override)?;
+    if language != EvalLanguage::Python && options.num_workers.is_some() {
+        anyhow::bail!("--num-workers is only supported for Python evals.");
+    }
     let show_js_runner_hint_on_failure =
         language == EvalLanguage::JavaScript && runner_override.is_none();
     let (js_runner, py_runner) = prepare_eval_runners()?;
@@ -237,6 +289,24 @@ async fn run_eval_files_once(
         cmd.env("BT_EVAL_NO_SEND_LOGS", "1");
         cmd.env("BT_EVAL_LOCAL", "1");
     }
+    if options.jsonl {
+        cmd.env("BT_EVAL_JSONL", "1");
+    }
+    if options.terminate_on_failure {
+        cmd.env("BT_EVAL_TERMINATE_ON_FAILURE", "1");
+    }
+    if options.list {
+        cmd.env("BT_EVAL_LIST", "1");
+    }
+    if let Some(num_workers) = options.num_workers {
+        cmd.env("BT_EVAL_NUM_WORKERS", num_workers.to_string());
+    }
+    if !options.filter.is_empty() {
+        let parsed = parse_eval_filter_expressions(&options.filter)?;
+        let serialized =
+            serde_json::to_string(&parsed).context("failed to serialize eval filters")?;
+        cmd.env("BT_EVAL_FILTER_PARSED", serialized);
+    }
     cmd.env(
         "BT_EVAL_SSE_SOCK",
         socket_path.to_string_lossy().to_string(),
@@ -267,7 +337,7 @@ async fn run_eval_files_once(
         });
     }
 
-    let mut ui = EvalUi::new();
+    let mut ui = EvalUi::new(options.jsonl, options.list);
     let mut status = None;
     let mut dependency_files: Vec<String> = Vec::new();
 
@@ -339,6 +409,27 @@ fn resolve_watch_paths(files: &[String]) -> Result<Vec<PathBuf>> {
     normalize_watch_paths(files.iter().map(PathBuf::from))
 }
 
+fn parse_eval_filter_expression(expression: &str) -> Result<RunnerFilter> {
+    let (path, pattern) = expression
+        .split_once('=')
+        .ok_or_else(|| anyhow::anyhow!("Invalid filter {expression}"))?;
+    let path = path.trim();
+    if path.is_empty() {
+        anyhow::bail!("Invalid filter {expression}");
+    }
+    Ok(RunnerFilter {
+        path: path.split('.').map(str::to_string).collect(),
+        pattern: pattern.to_string(),
+    })
+}
+
+fn parse_eval_filter_expressions(filters: &[String]) -> Result<Vec<RunnerFilter>> {
+    filters
+        .iter()
+        .map(|filter| parse_eval_filter_expression(filter))
+        .collect()
+}
+
 fn normalize_watch_paths(paths: impl IntoIterator<Item = PathBuf>) -> Result<Vec<PathBuf>> {
     let cwd = std::env::current_dir().context("failed to read current directory")?;
     let mut deduped = BTreeSet::new();
@@ -781,6 +872,12 @@ fn is_ts_node_runner(runner_command: &Path) -> bool {
 }
 
 fn find_python_binary() -> Option<PathBuf> {
+    if let Some(venv_root) = std::env::var_os("VIRTUAL_ENV") {
+        let candidate = PathBuf::from(venv_root).join("bin").join("python");
+        if candidate.is_file() {
+            return Some(candidate);
+        }
+    }
     find_binary_in_path(&["python3", "python"])
 }
 
@@ -882,13 +979,13 @@ enum EvalEvent {
         stack: Option<String>,
     },
     Console {
-        _stream: String,
+        stream: String,
         message: String,
     },
 }
 
 #[allow(dead_code)]
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 struct ExperimentSummary {
     project_name: String,
@@ -902,7 +999,7 @@ struct ExperimentSummary {
     metrics: Option<HashMap<String, MetricSummary>>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Serialize)]
 struct ScoreSummary {
     name: String,
     score: f64,
@@ -917,7 +1014,7 @@ struct EvalErrorPayload {
     stack: Option<String>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Serialize)]
 struct MetricSummary {
     name: String,
     metric: f64,
@@ -970,7 +1067,7 @@ where
     let mut lines = BufReader::new(stream).lines();
     while let Some(line) = lines.next_line().await? {
         let _ = tx.send(EvalEvent::Console {
-            _stream: name.to_string(),
+            stream: name.to_string(),
             message: line,
         });
     }
@@ -1031,7 +1128,7 @@ fn handle_sse_event(event: Option<String>, data: String, tx: &mpsc::UnboundedSen
         "console" => {
             if let Ok(console) = serde_json::from_str::(&data) {
                 let _ = tx.send(EvalEvent::Console {
-                    _stream: console.stream,
+                    stream: console.stream,
                     message: console.message,
                 });
             }
         }
@@ -1068,10 +1165,12 @@ struct EvalUi {
     bars: HashMap<String, ProgressBar>,
     bar_style: ProgressStyle,
     spinner_style: ProgressStyle,
+    jsonl: bool,
+    list: bool,
 }
 
 impl EvalUi {
-    fn new() -> Self {
+    fn new(jsonl: bool, list: bool) -> Self {
         let progress = MultiProgress::with_draw_target(ProgressDrawTarget::stderr_with_hz(10));
         let bar_style =
             ProgressStyle::with_template("{bar:10.blue} {msg} {percent}% {pos}/{len} {eta}")
@@ -1082,6 +1181,8 @@ impl EvalUi {
             bars: HashMap::new(),
             bar_style,
             spinner_style,
+            jsonl,
+            list,
         }
     }
 
@@ -1098,17 +1199,27 @@ impl EvalUi {
                 let _ = self.progress.println(line);
             }
             EvalEvent::Summary(summary) => {
-                let rendered = format_experiment_summary(&summary);
-                for line in rendered.lines() {
-                    let _ = self.progress.println(line);
+                if self.jsonl {
+                    if let Ok(line) = serde_json::to_string(&summary) {
+                        println!("{line}");
+                    }
+                } else {
+                    let rendered = format_experiment_summary(&summary);
+                    for line in rendered.lines() {
+                        let _ = self.progress.println(line);
+                    }
                 }
             }
             EvalEvent::Progress(progress) => {
                 self.handle_progress(progress);
             }
             EvalEvent::Dependencies { .. } => {}
-            EvalEvent::Console { message, .. } => {
-                let _ = self.progress.println(message);
+            EvalEvent::Console { stream, message } => {
+                if stream == "stdout" && (self.list || self.jsonl) {
+                    println!("{message}");
+                } else {
+                    let _ = self.progress.println(message);
+                }
             }
             EvalEvent::Error { message, stack } => {
                 let show_hint = message.contains("Please specify an api key");
@@ -1771,4 +1882,36 @@ mod tests {
             ]
         );
     }
+
+    #[test]
+    fn parse_eval_filter_expression_splits_path_and_pattern() {
+        let parsed =
+            parse_eval_filter_expression("metadata.case=smoke.*").expect("parse should succeed");
+        assert_eq!(
+            parsed,
+            RunnerFilter {
+                path: vec!["metadata".to_string(), "case".to_string()],
+                pattern: "smoke.*".to_string(),
+            }
+        );
+    }
+
+    #[test]
+    fn parse_eval_filter_expression_rejects_missing_equals() {
+        let err =
+            parse_eval_filter_expression("metadata.case").expect_err("missing equals should fail");
+        assert!(
+            err.to_string().contains("Invalid filter"),
+            "unexpected error: {err}"
+        );
+    }
+
+    #[test]
+    fn parse_eval_filter_expression_rejects_empty_path() {
+        let err = parse_eval_filter_expression("=foo").expect_err("empty path should fail");
+        assert!(
+            err.to_string().contains("Invalid filter"),
+            "unexpected error: {err}"
+        );
+    }
 }