-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathposterreward_scorer.py
More file actions
61 lines (48 loc) · 2.04 KB
/
posterreward_scorer.py
File metadata and controls
61 lines (48 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
"""
PosterReward 打分模型:读取分析结果,生成标量分数
"""
import json
import os
import argparse
from swift.llm import PtEngine, InferRequest
def main():
    """Score analysis records with the PosterReward model.

    Reads every record from the ``--input`` JSONL file (each line is a JSON
    object expected to carry ``original_prompt``, ``analysis`` and
    ``image_path``), scores each (image, prompt, analysis) triple with a
    sequence-classification reward model, and appends one JSON line per
    record (original fields plus ``reward_score``) to ``--output``.

    Raises:
        SystemExit: if the input file contains no records.
    """
    parser = argparse.ArgumentParser(description="PosterReward 打分模型推理")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--input", type=str, required=True, help="Path to analysis_output.jsonl")
    parser.add_argument("--output", type=str, default="./score_output.jsonl")
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--eval_prompt", type=str, default="{prompt}")
    args = parser.parse_args()

    # NOTE(review): swift.llm is imported at module load time, so if it (or
    # torch) has already touched CUDA, setting this here may be too late —
    # confirm, or export CUDA_VISIBLE_DEVICES before launching the script.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # BUG FIX: the original called f.readline() once and scored only the
    # FIRST record of the JSONL file, silently dropping the rest. Read and
    # score every non-blank line instead.
    records = []
    with open(args.input, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank/trailing lines in the JSONL
                records.append(json.loads(line))
    if not records:
        raise SystemExit(f"No records found in {args.input}")

    print(f"Loading scorer model: {args.model}")
    # seq_cls with num_labels=1: the model emits a single scalar reward.
    engine = PtEngine(args.model, max_batch_size=args.batch_size,
                      task_type='seq_cls', num_labels=1)

    # Build one inference request per record; the engine batches internally
    # up to max_batch_size.
    requests = []
    for rec in records:
        eval_content = args.eval_prompt.format(prompt=rec.get('original_prompt', ''))
        messages = [
            {"role": "user", "content": f"<image>{eval_content}"},
            {"role": "assistant", "content": rec.get('analysis', '')},
        ]
        requests.append(InferRequest(messages=messages, images=[rec.get('image_path', '')]))

    print("Running reward scoring...")
    resp_list = engine.infer(requests)

    os.makedirs(os.path.dirname(os.path.abspath(args.output)) or '.', exist_ok=True)
    with open(args.output, 'w', encoding='utf-8') as f:
        for rec, resp in zip(records, resp_list):
            score = resp.choices[0].message.content
            record = {
                'image_path': rec.get('image_path', ''),
                'original_prompt': rec.get('original_prompt', ''),
                'analysis': rec.get('analysis', ''),
                'reward_score': score,
            }
            f.write(json.dumps(record, ensure_ascii=False) + '\n')
            print(f"Score: {score}")
    print(f"Result saved to: {args.output}")
# Standard script entry point: run the scorer only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()