-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_selector.py
More file actions
91 lines (79 loc) · 2.98 KB
/
test_selector.py
File metadata and controls
91 lines (79 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Manual verification tests for the ToolSelector.
Run with: python3 test_selector.py
"""
from tool_selector import ToolRegistry, ToolSelector, ToolEntry
# Ranking cases: each pairs a natural-language task with the tool name the
# selector is expected to rank first for it.
CASES = [
    # (task description, expected top tool name)
    ("convert 100 USD to Japanese yen", "currency_converter"),
    ("calculate 17 percent of 340", "calculator"),
    ("what is the latest news about AI?", "web_search"),
    ("run this python snippet and show output", "code_executor"),
    ("summarize this long article", "summarizer"),
    ("convert 5 miles to kilometers", "unit_converter"),
    ("how many days between Jan 1 and March 15?", "calendar_tool"),
    ("read the CSV file and show column names", "file_reader"),
    ("send an email to john@example.com", "email_sender"),
    ("count users who signed up last month from the database", "database_query"),
]

# Running tallies for the ranking cases; FAIL is read again by the summary at
# the end of the script.
PASS = 0
FAIL = 0

# Shared fixtures: the built-in registry (also reused by the persistence test
# further down) and a selector that returns up to 3 recommendations per task.
registry = ToolRegistry.default()
selector = ToolSelector(registry, top_n=3)

# Maximum number of task characters shown in each report line.
TASK_PREVIEW_WIDTH = 45

print("=" * 60)
print("ToolSelector — test suite")
print("=" * 60)

for task, expected_top in CASES:
    recs = selector.recommend(task)
    # An empty recommendation list is reported as "(none)" and counts as FAIL.
    top_name = recs[0].tool.name if recs else "(none)"
    passed = top_name == expected_top
    status = "PASS" if passed else "FAIL"
    if passed:
        PASS += 1
    else:
        FAIL += 1
    score = recs[0].score if recs else 0.0
    # Append an ellipsis only when the task text was actually cut, so short
    # tasks are not misreported as truncated.
    if len(task) > TASK_PREVIEW_WIDTH:
        preview = task[:TASK_PREVIEW_WIDTH] + "..."
    else:
        preview = task
    print(f"[{status}] task='{preview}' | expected={expected_top} | got={top_name} (score={score:.3f})")

print()
print(f"Results: {PASS}/{PASS+FAIL} passed")
# --- Custom registry test ---
# Builds a registry holding a single user-defined tool and checks that the
# selector recommends it for a matching task.
print()
print("--- Custom registry test ---")
custom = ToolRegistry()
custom.register(ToolEntry(
    name="my_custom_tool",
    description="Handles invoice parsing and extraction.",
    capabilities=["invoice", "parsing", "extraction", "ocr"],
    use_cases=["extract invoice", "parse invoice", "invoice data"],
    tags=["finance", "document"],
))
sel2 = ToolSelector(custom, top_n=1)
recs2 = sel2.recommend("extract data from this invoice PDF")
# Explicit check rather than `assert`: `python -O` strips assert statements,
# which would let this test print [PASS] unconditionally. The message also
# reports what was actually returned.
custom_top = recs2[0].tool.name if recs2 else "(none)"
if custom_top != "my_custom_tool":
    raise AssertionError(f"Custom tool test FAILED: got {custom_top}")
print("[PASS] Custom tool registry test")
# --- Persistence round-trip test ---
# Saves the default registry to a temporary JSON file, loads it back, and
# checks that the number of tools survives the round trip. The temporary file
# is removed afterwards whether or not save/load succeed.
import pathlib
import tempfile
import json  # not used in this section; later sections of the script use it
print()
print("--- Persistence round-trip test ---")
# delete=False keeps the file on disk after the handle is closed, so save()
# and load() can reopen it by path; it must therefore be removed explicitly.
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
    tmp_path = f.name
try:
    registry.save(tmp_path)
    reloaded = ToolRegistry.load(tmp_path)
finally:
    try:
        pathlib.Path(tmp_path).unlink()
    except FileNotFoundError:
        pass
# Explicit check rather than `assert`, which `python -O` strips.
if len(reloaded) != len(registry):
    raise AssertionError(
        "Persistence test FAILED: length mismatch "
        f"({len(reloaded)} reloaded vs {len(registry)} saved)"
    )
print(f"[PASS] Saved and reloaded {len(reloaded)} tools from JSON")
# --- CLI JSON output test ---
# Invokes the CLI entry point with --json, captures everything it prints to
# stdout, parses that as JSON, and checks that currency_converter ranks first.
print()
print("--- CLI JSON output test ---")
from tool_selector.cli import main as cli_main
import contextlib
import io
import json  # imported here so this section does not depend on earlier ones
buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    cli_main(["--json", "convert 50 EUR to dollars"])
output = json.loads(buf.getvalue())
# An empty or missing recommendation list is reported as a test failure rather
# than surfacing as an IndexError/KeyError. Explicit raise instead of `assert`,
# which `python -O` strips.
cli_recs = output.get("recommendations") or []
cli_top = cli_recs[0]["name"] if cli_recs else "(none)"
if cli_top != "currency_converter":
    raise AssertionError(f"CLI test FAILED: got {cli_top}")
print(f"[PASS] CLI JSON output correct: top={cli_top}")
# Final summary. Ranking-case failures are only counted in FAIL (they do not
# raise), so exit with a non-zero status when any occurred; otherwise shells
# and CI would see a successful run even though cases failed.
import sys
print()
if FAIL == 0:
    print("All tests passed.")
else:
    print(f"WARNING: {FAIL} test(s) failed.")
    sys.exit(1)