-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrephrasingModule.py
More file actions
28 lines (20 loc) · 885 Bytes
/
rephrasingModule.py
File metadata and controls
28 lines (20 loc) · 885 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def rephrase(sentence, tokenizer, model):
    """Return two sampled paraphrases of ``sentence``.

    The sentence is wrapped as ``"paraphrase: <sentence> </s>"``, tokenized,
    and handed to ``model.generate`` with top-k / top-p sampling. Every
    returned sequence is decoded back to text with special tokens removed.

    Args:
        sentence: The text to paraphrase.
        tokenizer: A callable tokenizer whose result can be indexed with
            ``"input_ids"`` and ``"attention_mask"``, and which provides
            ``decode``. Presumably the Hugging Face tokenizer paired with
            ``model`` -- confirm against the caller.
        model: An object providing ``generate``. If it exposes a ``device``
            attribute, the encoded inputs are moved to that device first.

    Returns:
        A list with one decoded string per sequence returned by
        ``model.generate`` (``num_return_sequences=2`` is requested).
    """
    text = "paraphrase: " + sentence + " </s>"
    encoding = tokenizer(text, padding=True, return_tensors="pt")
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    # Follow the model to whatever device it lives on instead of hard-coding
    # "cuda": CPU-only setups keep working and GPU-resident models no longer
    # receive CPU tensors.
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    # NOTE(review): early_stopping is documented as a beam-search control;
    # with pure sampling it is probably a no-op. Kept to preserve the call.
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=128,
        do_sample=True,
        top_k=120,
        top_p=0.95,
        early_stopping=True,
        num_return_sequences=2,
    )

    return [
        tokenizer.decode(
            output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for output in outputs
    ]