# infer.py (forked from HJYao00/Mulberry)
import argparse

import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

PROMPT = """Generate an image description based on the question.
Then, provide a rationale to analyze the question.
Next, generate a step-by-step reasoning process to solve the problem. Ensure the steps are logical and concise.
Finally, provide a concise summary of the final answer in the following format: 'The final answer is: xxx'. If the question is multiple-choice, provide the options along with their content. If it is free-form, directly present the final result. Do not provide any explanation.
Format your response with the following sections, separated by ###:
### Image Description:
### Rationales:
### Let's think step by step.
### Step 1:
### Step 2:
...
### The final answer is:
{question}"""


def llava_infer(model_path, question, img_path, only_output_final_answer=False):
    def output_process(answer):
        # Strip the chat-template preamble so only the assistant reply
        # remains; the markers differ across chat templates.
        if "<s>" in answer:
            answer = answer.replace("<s>", "").strip()
        if "[/INST]" in answer:
            answer = answer.split("[/INST]")[1].strip()
        elif "ASSISTANT:" in answer:
            answer = answer.split("ASSISTANT:")[1].strip()
        elif "assistant\n" in answer:
            answer = answer.split("assistant\n")[1].strip()
        elif "<|end_header_id|>\n\n" in answer:
            answer = answer.split("<|end_header_id|>\n\n")[2].strip()

        # Drop everything after the end-of-sequence marker, if one survived.
        if "</s>" in answer:
            answer = answer.split("</s>")[0].strip()
        elif "<|im_end|>" in answer:
            answer = answer.split("<|im_end|>")[0].strip()
        elif "<|eot_id|>" in answer:
            answer = answer.split("<|eot_id|>")[0].strip()
        return answer

    processor = LlavaNextProcessor.from_pretrained(model_path)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        use_flash_attention_2=True,  # requires the flash-attn package
    )
    model = model.eval().cuda()

    # Greedy decoding for deterministic outputs.
    kwargs = dict(
        do_sample=False, temperature=0, max_new_tokens=1024, top_p=None, num_beams=1
    )

    images = [Image.open(img_path).convert("RGB")]
    content = [
        {"type": "text", "text": PROMPT.format(question=question)},
        {"type": "image"},
    ]
    conversation = [
        {
            "role": "user",
            "content": content,
        }
    ]
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True
    )
    inputs = processor(text=prompt, images=images, return_tensors="pt").to(
        "cuda", torch.float16
    )
    output = model.generate(**inputs, **kwargs)

    # `skip_special_tokens` removes markers such as <s>/</s>;
    # output_process() then strips any template text that remains.
    answer = processor.decode(output[0], skip_special_tokens=True)
    answer = output_process(answer)

    if only_output_final_answer and len(answer.split("### The final answer is:")) == 2:
        answer = answer.split("### The final answer is:")[-1].strip()
    return answer
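
# Hypothetical programmatic usage (model path, image, and question are
# illustrative, not from the original repo):
#   full = llava_infer("./Mulberry_llava_8b", "What is shown in the image?", "./cat.jpg")
#   sections = parse_sections(full)   # e.g. sections.get("The final answer is")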

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Mulberry_llava_8b")
    parser.add_argument("--question", type=str, default=None)
    parser.add_argument("--img_path", type=str, default=None)
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--only_output_final_answer", action="store_true")
    args = parser.parse_args()

    if args.model == "Mulberry_llava_8b":
        answer = llava_infer(args.model_path, args.question, args.img_path, args.only_output_final_answer)
    else:
        raise NotImplementedError()
    print(answer)
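
# Example invocation (paths are placeholders):
#   python infer.py --model_path ./Mulberry_llava_8b \
#       --img_path ./example.jpg \
#       --question "What is shown in the image?" \
#       --only_output_final_answer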