-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathblob.py
More file actions
26 lines (24 loc) · 1.01 KB
/
blob.py
File metadata and controls
26 lines (24 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
class Blob:
    """Score candidate texts against reference texts with a BLEURT checkpoint.

    The tokenizer and the sequence-classification model are loaded once in
    ``__init__`` and reused by every ``predict`` call.
    """

    def __init__(self, model_name: str = "Elron/bleurt-large-512", device=None):
        """Prepare the tokenizer and model; intended to run once at deployment.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model. Defaults to the checkpoint this
                class was originally written for.
            device: Torch device (e.g. ``"cuda"`` or ``"cpu"``) the model and
                the tokenized inputs are moved to. When ``None``, CUDA is
                used if it is available, otherwise the CPU, so the class also
                works on hosts without a GPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Remembered so predict() puts its inputs on the same device as the model.
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()
        self.model.to(self.device)

    def predict(self, model_inputs: dict) -> dict:
        """Score each (reference, candidate) pair; called for every prediction.

        Args:
            model_inputs: Mapping with the keys ``"references"`` and
                ``"candidates"``. Both values are handed to the tokenizer as
                a text / text-pair batch, or as a single pair when they are
                plain strings.

        Returns:
            ``{"scores": [...]}`` with one score per pair when lists are
            given (also for a one-element list, and ``[]`` for empty lists),
            or ``{"scores": <float>}`` when a single string pair is given.

        Raises:
            KeyError: If ``"references"`` or ``"candidates"`` is missing.
        """
        references = model_inputs["references"]
        candidates = model_inputs["candidates"]

        single_pair = isinstance(references, str)
        # Nothing to score: skip the tokenizer and model for an empty batch.
        if not single_pair and not references and not candidates:
            return {"scores": []}

        # NOTE(review): inputs are not truncated here; the checkpoint name
        # suggests a 512-token limit -- confirm whether truncation is wanted.
        encoded = self.tokenizer(
            references, candidates, return_tensors="pt", padding=True
        ).to(self.device)

        with torch.no_grad():
            # The first model output holds the logits, presumably shaped
            # (batch, 1) for this regression-style head.
            logits = self.model(**encoded)[0]

        # Drop only the trailing dimension. A bare squeeze() would also drop
        # the batch dimension for a one-pair batch and yield a scalar instead
        # of a one-element list.
        scores = logits.squeeze(-1).tolist()
        return {"scores": scores[0] if single_pair else scores}