Skip to content

Commit a411a3f

Browse files
iofu728, QianhuiWu, pzs19, XufangLuo, and mydmdm
authored
Prerelease(LLMLingua): fix the chunk issue and prepare for v0.2.2 (#130)
Co-authored-by: Qianhui Wu <wuqh_thu@foxmail.com> Co-authored-by: panzs <915933979@qq.com> Co-authored-by: Xufang Luo <34053802+XufangLuo@users.noreply.github.com> Co-authored-by: Yuqing Yang <justin.yqyang@gmail.com>
1 parent 309392a commit a411a3f

File tree

4 files changed

+109
-84
lines changed

4 files changed

+109
-84
lines changed

experiments/llmlingua2/evaluation/eval_meetingbank_qa.py

Lines changed: 92 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -32,89 +32,104 @@
3232

3333
args = parser.parse_args()
3434
os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
35-
data = json.load(open(args.load_prompt_from))
36-
data = data.values() if isinstance(data, dict) else data
37-
38-
print(f"num data: {len(data)}")
39-
40-
model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
41-
42-
results = defaultdict(dict)
43-
results_list = defaultdict(list)
44-
if os.path.exists(args.save_path):
45-
prev_results = json.load(open(args.save_path))
46-
results.update(prev_results)
47-
if os.path.exists(
48-
os.path.join(
49-
os.path.dirname(args.save_path),
50-
os.path.basename(args.save_path).replace("answer", "answer_list"),
51-
)
52-
):
53-
results_list = json.load(
54-
open(
55-
os.path.join(
56-
os.path.dirname(args.save_path),
57-
os.path.basename(args.save_path).replace("answer", "answer_list"),
35+
36+
37+
def predict():
38+
data = json.load(open(args.load_prompt_from))
39+
data = data.values() if isinstance(data, dict) else data
40+
41+
print(f"num data: {len(data)}")
42+
43+
model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
44+
45+
results = defaultdict(dict)
46+
results_list = defaultdict(list)
47+
if os.path.exists(args.save_path):
48+
prev_results = json.load(open(args.save_path))
49+
results.update(prev_results)
50+
if os.path.exists(
51+
os.path.join(
52+
os.path.dirname(args.save_path),
53+
os.path.basename(args.save_path).replace("answer", "answer_list"),
54+
)
55+
):
56+
results_list = json.load(
57+
open(
58+
os.path.join(
59+
os.path.dirname(args.save_path),
60+
os.path.basename(args.save_path).replace("answer", "answer_list"),
61+
)
5862
)
5963
)
60-
)
6164

62-
prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
63-
for sample in tqdm(data):
64-
sample_idx = int(sample["idx"])
65-
if sample_idx in results or str(sample_idx) in results:
66-
print(f"{sample_idx}-th already processed.")
67-
continue
68-
if args.num_sample > 0 and int(sample_idx) > args.num_sample:
69-
break
70-
transcript = sample[args.load_key]
71-
token_ids = tokenizer.encode(transcript)
72-
if len(token_ids) > args.n_max_token - args.n_max_token_ans:
73-
transcript = tokenizer.decode(
74-
token_ids[: args.n_max_token - args.n_max_token_ans]
65+
prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
66+
for sample in tqdm(data):
67+
sample_idx = int(sample["idx"])
68+
if sample_idx in results or str(sample_idx) in results:
69+
print(f"{sample_idx}-th already processed.")
70+
continue
71+
if args.num_sample > 0 and int(sample_idx) > args.num_sample:
72+
break
73+
transcript = sample[args.load_key]
74+
token_ids = tokenizer.encode(transcript)
75+
if len(token_ids) > args.n_max_token - args.n_max_token_ans:
76+
transcript = tokenizer.decode(
77+
token_ids[: args.n_max_token - args.n_max_token_ans]
78+
)
79+
qa_list = sample["QA_pairs"]
80+
q_list = []
81+
a_list = []
82+
a_list_model = []
83+
for qa in qa_list:
84+
q = qa["question"]
85+
a = qa["answer"]
86+
query = prompt.format(transcript=transcript, question=q)
87+
answer = query_llm(
88+
query,
89+
model,
90+
args.model_name_or_path,
91+
args.n_max_token_ans,
92+
tokenizer=tokenizer,
93+
)
94+
q_list.append(q)
95+
a_list.append(a)
96+
a_list_model.append(answer)
97+
98+
results[sample_idx]["transcript"] = transcript
99+
results[sample_idx]["questions"] = q_list[:]
100+
results[sample_idx]["answers"] = a_list[:]
101+
results[sample_idx]["model_answers"] = a_list_model[:]
102+
103+
results_list["questions"].extend(q_list[:])
104+
results_list["answers"].extend(a_list[:])
105+
results_list["model_answers"].extend(a_list_model[:])
106+
107+
json.dump(results, open(args.save_path, "w"), indent=4)
108+
json.dump(
109+
results_list,
110+
open(
111+
os.path.join(
112+
os.path.dirname(args.save_path),
113+
os.path.basename(args.save_path).replace("answer", "answer_list"),
114+
),
115+
"w",
116+
),
117+
indent=4,
75118
)
76-
qa_list = sample["QA_pairs"]
77-
q_list = []
78-
a_list = []
79-
a_list_model = []
80-
for qa in qa_list:
81-
q = qa["question"]
82-
a = qa["answer"]
83-
query = prompt.format(transcript=transcript, question=q)
84-
answer = query_llm(
85-
query,
86-
model,
87-
args.model_name_or_path,
88-
args.n_max_token_ans,
89-
tokenizer=tokenizer,
119+
120+
121+
predict()
122+
results_list = json.load(
123+
open(
124+
os.path.join(
125+
os.path.dirname(args.save_path),
126+
os.path.basename(args.save_path).replace("answer", "answer_list"),
90127
)
91-
q_list.append(q)
92-
a_list.append(a)
93-
a_list_model.append(answer)
94-
95-
results[sample_idx]["transcript"] = transcript
96-
results[sample_idx]["questions"] = q_list[:]
97-
results[sample_idx]["answers"] = a_list[:]
98-
results[sample_idx]["model_answers"] = a_list_model[:]
99-
100-
results_list["questions"].extend(q_list[:])
101-
results_list["answers"].extend(a_list[:])
102-
results_list["model_answers"].extend(a_list_model[:])
103-
104-
json.dump(results, open(args.save_path, "w"), indent=4)
105-
json.dump(
106-
results_list,
107-
open(
108-
os.path.join(
109-
os.path.dirname(args.save_path),
110-
os.path.basename(args.save_path).replace("answer", "answer_list"),
111-
),
112-
"w",
113-
),
114-
indent=4,
115128
)
116-
117-
score_dict = evaluate_with_gt(results_list["answers"], results_list["model_answers"])
129+
)
130+
for i, ans in enumerate(results_list["answers"]):
131+
results_list["answers"][i] = [results_list["answers"][i]]
132+
score_dict = evaluate_with_gt(results_list["model_answers"], results_list["answers"])
118133
json.dump(
119134
score_dict,
120135
open(

experiments/llmlingua2/evaluation/metrics.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,16 @@ def qa_f1_zh_score(prediction, ground_truth, **kwargs):
156156
return f1_score(prediction_tokens, ground_truth_tokens)
157157

158158

159+
def qa_score(prediction, ground_truths):
160+
normalized_prediction = normalize_answer2(prediction)
161+
162+
for ground_truth in ground_truths:
163+
normalized_ground_truth = normalize_answer2(ground_truth)
164+
if normalized_ground_truth.lower() in normalized_prediction.lower():
165+
return 1.0
166+
return 0.0
167+
168+
159169
import regex
160170

161171

@@ -207,12 +217,10 @@ def eval_qa_f1_score(pred, ground_truths):
207217
pred_list = pred_list_truncated
208218

209219
metrics = {
210-
"qa_f1_score": 0.0,
211-
"best_subspan_em": 0.0,
220+
"qa_score": 0.0,
212221
}
213222
for pred, gts in zip(pred_list, gt_list):
214-
metrics["qa_f1_score"] += eval_qa_f1_score(pred, gts)
215-
metrics["best_subspan_em"] += best_subspan_em(pred, gts)
223+
metrics["qa_score"] += qa_score(pred, gts)
216224
# average
217225
for metric_name, score in metrics.items():
218226
metrics[metric_name] = score * 100 / len(pred_list)

llmlingua/prompt_compressor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2407,8 +2407,10 @@ def split_string_to_words(input_string):
24072407
keep_words = []
24082408
word_labels = []
24092409
assert len(words) == len(word_probs)
2410-
for word, word_porb in zip(words, word_probs):
2411-
if word_porb > threshold:
2410+
for word, word_prob in zip(words, word_probs):
2411+
if word_prob > threshold or (
2412+
threshold == 1.0 and word_prob == threshold
2413+
):
24122414
if (
24132415
drop_consecutive
24142416
and word in force_tokens

llmlingua/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
_MINOR = "2"
66
# On master and in a nightly release the patch should be one ahead of the last
77
# released build.
8-
_PATCH = "1"
8+
_PATCH = "2"
99
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
1010
# https://semver.org/#is-v123-a-semantic-version for the semantics.
1111
_SUFFIX = ""

0 commit comments

Comments (0)