-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNLI.py
More file actions
119 lines (91 loc) · 4.6 KB
/
NLI.py
File metadata and controls
119 lines (91 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import torch
import torch.nn.functional as F
import argparse
model_name = "ruanchaves/mdeberta-v3-base-assin2-entailment"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
# Enable GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def generate_hypothesis(subject, relation, obj):
relation_text = {
"instância de": "{} é um(a) {}.",
"ocupação": "{} trabalha como {}.",
"local de nascimento": "{} nasceu em {}.",
"subclasse de": "{} é uma subclasse de {}.",
"país": "{} está localizado no país {}.",
"sede": "A sede de {} está localizada em {}.",
"tema(s) principal(is)": "{} aborda o tema de {}.",
"categoria taxonómica": "{} pertence à categoria taxonômica {}.",
"género": "{} pertence ao género {}.",
"autor": "{} tem como autor {}.",
"desporto": "{} está relacionado ao desporto {}.",
"desenvolvedor": "{} foi desenvolvido por {}.",
"realizador": "{} foi dirigido ou realizado por {}.",
"banhado por": "{} é localizado em ou perto de {}.",
"compositor": "{} foi composto por {}.",
"empresa de produção": "{} foi produzido pela empresa {}.",
"religião": "{} pertence à religião {}.",
"pai": "{} é filho(a) de {}.",
#add more if needed
}
# Default template for unknown relations
default_template = "{} tem a relação '{}' com {}."
template = relation_text.get(relation, default_template)
return template.format(subject, relation, obj) if template == default_template else template.format(subject, obj)
# Function to process a batch of inputs
def process_batch(batch):
premises = [item['sentence'] for item in batch]
#print(premises)
hypotheses = [generate_hypothesis(item['subject'], item['relation'], item['object']) for item in batch]
#print(hypotheses)
# Tokenize the batch
inputs = tokenizer(premises, hypotheses, return_tensors="pt", truncation=True, padding=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()} # Move inputs to GPU
# Perform inference
with torch.no_grad():
outputs = model(**inputs)
# Get logits and convert to probabilities
logits = outputs.logits
probs = F.softmax(logits, dim=-1) # Apply softmax to get probabilities
# Get predicted classes and corresponding labels
predicted_classes = torch.argmax(logits, dim=-1).tolist()
predicted_labels = [model.config.id2label[cls] for cls in predicted_classes]
return predicted_labels, probs.tolist()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="NLI Step")
parser.add_argument("input_file", type=str, help="Path to the input jsonl data file.")
parser.add_argument("-o", "--output_file", type=str, default="dataset_NLI.jsonl", help="output dataset file path after the NLI step")
args = parser.parse_args()
# Batch size (adjust based on your GPU memory)
batch_size = 4
# Open the file and process in batches
batch = []
count=0
with open(args.input_file, 'r', encoding='utf-8') as file, open(args.output_file, 'w', encoding='utf-8') as output_file:
for line in file:
# Parse the JSON object from the line
data = json.loads(line.strip())
batch.append(data)
# Process the batch when it reaches the desired size
if len(batch) == batch_size:
predicted_labels, probs = process_batch(batch)
# Add NLI results to each item in the batch
for i, item in enumerate(batch):
item['nli_prediction'] = predicted_labels[i]
item['nli_probabilities'] = probs[i]
# Write the updated item to the output JSONL file
output_file.write(json.dumps(item, ensure_ascii=False) + '\n')
# Clear the batch
batch = []
# Process any remaining items in the last batch
if batch:
predicted_labels, probs = process_batch(batch)
for i, item in enumerate(batch):
item['nli_prediction'] = predicted_labels[i]
item['nli_probabilities'] = probs[i]
# Write the updated item to the output JSONL file
output_file.write(json.dumps(item, ensure_ascii=False) + '\n')