|
32 | 32 |
|
33 | 33 | args = parser.parse_args() |
34 | 34 | os.makedirs(os.path.dirname(args.save_path), exist_ok=True) |
def predict():
    """Generate model answers for every QA pair in the prompt file.

    Reads samples from ``args.load_prompt_from``, asks the model each
    question against the (possibly truncated) meeting transcript, and
    checkpoints two JSON files after every sample: per-sample results at
    ``args.save_path`` and flat question/answer lists at the sibling
    ``*answer_list*`` path. Samples already present in a previous run's
    output are skipped, so an interrupted job is resumable.

    Relies on module-level ``args`` (parsed CLI options) and the helpers
    ``load_model_and_tokenizer`` / ``query_llm`` imported elsewhere in
    this file.
    """
    with open(args.load_prompt_from) as f:
        data = json.load(f)
    # The prompt file may be either a dict keyed by sample id or a list.
    data = data.values() if isinstance(data, dict) else data
    print(f"num data: {len(data)}")

    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)

    # Derive the sibling "answer_list" checkpoint path once instead of
    # rebuilding the same os.path.join expression at every use.
    list_save_path = os.path.join(
        os.path.dirname(args.save_path),
        os.path.basename(args.save_path).replace("answer", "answer_list"),
    )

    results = defaultdict(dict)
    results_list = defaultdict(list)
    # Resume from previous partial runs if checkpoints exist.
    if os.path.exists(args.save_path):
        with open(args.save_path) as f:
            results.update(json.load(f))
    if os.path.exists(list_save_path):
        with open(list_save_path) as f:
            results_list = json.load(f)

    prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
    for sample in tqdm(data):
        sample_idx = int(sample["idx"])
        # Results restored from JSON have string keys; check both forms.
        if sample_idx in results or str(sample_idx) in results:
            print(f"{sample_idx}-th already processed.")
            continue
        if args.num_sample > 0 and int(sample_idx) > args.num_sample:
            break

        # Truncate the transcript so prompt + answer fit in the model's
        # context window.
        transcript = sample[args.load_key]
        token_ids = tokenizer.encode(transcript)
        max_prompt_tokens = args.n_max_token - args.n_max_token_ans
        if len(token_ids) > max_prompt_tokens:
            transcript = tokenizer.decode(token_ids[:max_prompt_tokens])

        q_list = []
        a_list = []
        a_list_model = []
        for qa in sample["QA_pairs"]:
            q = qa["question"]
            query = prompt.format(transcript=transcript, question=q)
            answer = query_llm(
                query,
                model,
                args.model_name_or_path,
                args.n_max_token_ans,
                tokenizer=tokenizer,
            )
            q_list.append(q)
            a_list.append(qa["answer"])
            a_list_model.append(answer)

        results[sample_idx]["transcript"] = transcript
        results[sample_idx]["questions"] = q_list[:]
        results[sample_idx]["answers"] = a_list[:]
        results[sample_idx]["model_answers"] = a_list_model[:]

        results_list["questions"].extend(q_list)
        results_list["answers"].extend(a_list)
        results_list["model_answers"].extend(a_list_model)

        # Checkpoint after every sample so progress survives interruption;
        # context managers guarantee the handles are flushed and closed.
        with open(args.save_path, "w") as f:
            json.dump(results, f, indent=4)
        with open(list_save_path, "w") as f:
            json.dump(results_list, f, indent=4)
# Run generation, then reload the flat answer lists from disk for scoring.
predict()
answer_list_path = os.path.join(
    os.path.dirname(args.save_path),
    os.path.basename(args.save_path).replace("answer", "answer_list"),
)
with open(answer_list_path) as f:
    results_list = json.load(f)
# evaluate_with_gt expects each ground-truth entry to be a list of
# reference answers, so wrap every single answer in a one-element list.
results_list["answers"] = [[answer] for answer in results_list["answers"]]
score_dict = evaluate_with_gt(results_list["model_answers"], results_list["answers"])
118 | 133 | json.dump( |
119 | 134 | score_dict, |
120 | 135 | open( |
|
0 commit comments