-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_dictation.py
More file actions
138 lines (115 loc) · 3.78 KB
/
eval_dictation.py
File metadata and controls
138 lines (115 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python3
"""
Small evaluation harness for Bolo transcript quality.
Usage:
python3 eval_dictation.py init
python3 eval_dictation.py prompts
python3 eval_dictation.py score eval_results.json
`init` prints a JSON template you can save and fill with observed transcripts.
`prompts` prints the phrase list in a speakable format.
`score` compares results against expected phrases and reports exact/normalized matches.
"""
import argparse
import json
import re
import sys
from pathlib import Path
BASE_DIR = Path(__file__).resolve().parent
PHRASES_FILE = BASE_DIR / "eval_phrases.json"
def load_phrases():
    """Read and parse the phrase definitions from eval_phrases.json."""
    raw = PHRASES_FILE.read_text(encoding="utf-8")
    return json.loads(raw)
def normalize(text: str) -> str:
    """Normalize a transcript for lenient comparison.

    Casefolds the text, expands common English contractions, replaces
    punctuation with spaces, and collapses whitespace runs, so that
    cosmetic differences do not count as mismatches.

    Args:
        text: Raw transcript text; ``None`` is treated as empty.

    Returns:
        The normalized string (possibly empty).
    """
    text = (text or "").strip().casefold()
    contractions = {
        "i'm": "i am",
        "it's": "it is",
        "that's": "that is",
        "what's": "what is",
        "can't": "cannot",
        "won't": "will not",
        "don't": "do not",
        "didn't": "did not",
        "i've": "i have",
        "you're": "you are",
    }
    # Expand contractions at word boundaries only. A plain str.replace()
    # would also rewrite substrings inside larger words, e.g.
    # "bit's" -> "bit is" via the "it's" rule.
    for src, dst in contractions.items():
        text = re.sub(rf"\b{re.escape(src)}\b", dst, text)
    text = re.sub(r"[^\w\s]", " ", text)  # punctuation -> space
    text = re.sub(r"\s+", " ", text)      # collapse whitespace runs
    return text.strip()
def init_template():
    """Print a JSON results template, one empty entry per known phrase."""
    entries = [
        {
            "id": phrase["id"],
            "expected": phrase["expected"],
            "actual": "",
            "notes": "",
        }
        for phrase in load_phrases()
    ]
    print(json.dumps(entries, indent=2))
def print_prompts():
    """Print each phrase numbered, with its category, ready to read aloud."""
    for number, phrase in enumerate(load_phrases(), start=1):
        print(f"{number}. [{phrase['category']}] {phrase['expected']}")
def score_results(results_path: Path):
    """Score a filled-in results file against the expected phrases.

    Reads the JSON at *results_path*, counts exact and normalized matches,
    prints summary totals (plus a long_form breakdown when that category is
    present), then a per-phrase OK/MISS listing with expected/actual text
    for every miss. Results whose id is not in the phrase set are skipped.
    """
    expected_by_id = {entry["id"]: entry for entry in load_phrases()}
    recorded = json.loads(results_path.read_text(encoding="utf-8"))

    scored = []
    exact_total = 0
    normalized_total = 0
    for record in recorded:
        reference = expected_by_id.get(record["id"])
        if reference is None:
            # Unknown id: not part of the phrase set, ignore it.
            continue
        want = reference["expected"]
        got = record.get("actual", "")
        is_exact = got.strip() == want.strip()
        is_normalized = normalize(got) == normalize(want)
        exact_total += int(is_exact)
        normalized_total += int(is_normalized)
        scored.append(
            {
                "id": reference["id"],
                "category": reference["category"],
                "exact": is_exact,
                "normalized": is_normalized,
                "expected": want,
                "actual": got,
            }
        )

    print(f"phrases: {len(scored)}")
    print(f"exact_match: {exact_total}/{len(scored)}")
    print(f"normalized_match: {normalized_total}/{len(scored)}")

    long_form = [row for row in scored if row["category"] == "long_form"]
    if long_form:
        long_form_hits = sum(int(row["normalized"]) for row in long_form)
        print(f"long_form_match: {long_form_hits}/{len(long_form)}")

    print("")
    for row in scored:
        marker = "OK" if row["normalized"] else "MISS"
        print(f"{marker} {row['id']} [{row['category']}]")
        if not row["normalized"]:
            print(f" expected: {row['expected']}")
            print(f" actual: {row['actual']}")
def main():
    """CLI entry point: dispatch to init, prompts, or score."""
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="cmd", required=True)
    sub.add_parser("init")
    sub.add_parser("prompts")
    score_parser = sub.add_parser("score")
    score_parser.add_argument("results")
    args = parser.parse_args()

    if args.cmd == "init":
        init_template()
    elif args.cmd == "prompts":
        print_prompts()
    elif args.cmd == "score":
        score_results(Path(args.results))
    else:
        # Unreachable while required=True, kept as a safety net.
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()