-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathinference_code.py
More file actions
55 lines (48 loc) · 1.76 KB
/
inference_code.py
File metadata and controls
55 lines (48 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from matformer.tokenizers import MatformerTokenizer
from matformer.model_config import ModelConfig
def setup_device():
    """Return the torch device to run on: CUDA when available, else CPU."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)
def load_and_prepare_model(model_path, model_class, device, tokenizer=None, **kwargs):
    """Load a checkpoint and prepare the model for inference.

    Args:
        model_path: Path to the checkpoint file.
        model_class: Class exposing ``load_from_checkpoint``, which must
            return a ``(model, config)`` pair.
        device: Target ``torch.device`` used both as ``map_location`` and
            for the final ``.to()`` move.
        tokenizer: Optional tokenizer; forwarded to the loader only when
            truthy (preserves the original branch behavior).
        **kwargs: Extra keyword arguments forwarded to ``load_from_checkpoint``.

    Returns:
        Tuple ``(model, config)`` with the model moved to ``device`` in
        bfloat16 and switched to eval mode.
    """
    # Fold the optional tokenizer into kwargs instead of duplicating the
    # entire load_from_checkpoint call in two nearly identical branches.
    if tokenizer:
        kwargs["tokenizer"] = tokenizer
    model, config = model_class.load_from_checkpoint(
        model_path,
        inference_fix=True,
        map_location=device,
        **kwargs,
    )
    # bfloat16 halves memory vs float32; eval() puts the model in inference mode.
    model = model.to(device, torch.bfloat16)
    model.eval()
    return model, config
def interactive_loop(inference_function):
    """Read prompts from stdin forever and feed each non-empty one to
    ``inference_function``; Ctrl-C prints a farewell and exits the loop."""
    try:
        while True:
            prompt = input(">>> ")
            if prompt:
                inference_function(prompt)
    except KeyboardInterrupt:
        print("\nFine.")
def batch_process_file(input_file, output_file, model, inference_method, **kwargs):
    """Run one of ``model``'s inference methods over every line of a text file.

    For each stripped input line, writes the line, the inference result,
    and a blank separator line to ``output_file``.

    Args:
        input_file: Path to the input text file (one prompt per line).
        output_file: Path the results are written to (overwritten).
        model: Object exposing the method named by ``inference_method``.
        inference_method: Method name to call on ``model``. The special name
            ``'inference_testing'`` is expected to return an
            ``(accuracy, predicted_tokens)`` pair; the tokens are written as
            ``[tok]`` markers and the accuracy is discarded.
        **kwargs: Extra keyword arguments forwarded to the method.

    Raises:
        AttributeError: If ``model`` has no method ``inference_method``.
            (The original code re-checked ``hasattr`` per line and silently
            produced an empty output file in that case.)
    """
    # Resolve the method once, up front, instead of hasattr()-checking on
    # every line — and fail loudly rather than writing an empty file.
    if not hasattr(model, inference_method):
        raise AttributeError(
            f"model has no inference method named {inference_method!r}"
        )
    run_inference = getattr(model, inference_method)
    is_testing = inference_method == 'inference_testing'
    # Stream the input instead of materializing it with readlines().
    with open(input_file, "r") as fin, open(output_file, "w") as fout:
        for raw in fin:
            line = raw.strip()
            if is_testing:
                # Only the predicted tokens are written; accuracy is discarded.
                _acc, predicted = run_inference(line, **kwargs)
                out = " ".join(f"[{tok}]" for tok in predicted)
                fout.write(f"{line}\n{out}\n\n")
            else:
                result = run_inference(line, **kwargs)
                fout.write(f"{line}\n{result}\n\n")