Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion llmsql/_cli/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,29 @@
from llmsql.evaluation.evaluate import evaluate


def parse_limit(value: str | None) -> int | float | None:
if value is None:
return None

try:
if "." in value:
limit = float(value)
if not (0.0 < limit <= 1.0):
raise ValueError
return limit
else:
limit = int(value)
if limit <= 0:
raise ValueError
return limit
except ValueError as e:
raise argparse.ArgumentTypeError(
"--limit must be a positive integer or a float between 0.0 and 1.0"
) from e


class EvaluationCommand:
"""CLI wrapper for the `evaluate()` function."""

@staticmethod
def register(subparsers: argparse._SubParsersAction) -> None:
eval_parser = subparsers.add_parser(
Expand All @@ -28,6 +48,18 @@ def register(subparsers: argparse._SubParsersAction) -> None:
),
)

eval_parser.add_argument(
"--limit",
required=False,
type=parse_limit,
default=None,
help=(
"Optional. Limit the number of evaluated samples.\n"
"Accepts an integer (e.g. 100) or a float between 0.0 and 1.0 (e.g. 0.1 for 10%).\n"
"Useful for debugging."
)
)

eval_parser.add_argument(
"--workdir-path",
type=str,
Expand Down Expand Up @@ -93,6 +125,8 @@ def execute(args: argparse.Namespace) -> None:
except Exception:
outputs = args.outputs

limit = args.limit

result = evaluate(
outputs=outputs,
workdir_path=args.workdir_path,
Expand All @@ -101,6 +135,7 @@ def execute(args: argparse.Namespace) -> None:
save_report=args.save_report,
show_mismatches=args.show_mismatches,
max_mismatches=args.max_mismatches,
limit=limit
)

print(json.dumps(result, indent=2))
Loading