-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathdata_loading_utils.py
More file actions
129 lines (109 loc) · 4.24 KB
/
data_loading_utils.py
File metadata and controls
129 lines (109 loc) · 4.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import json
import random
import torch as t
import torch.nn.functional as F
from nnsight import LanguageModel
def load_examples(
    dataset: str,
    n_examples: int,
    model: LanguageModel,
    use_min_length_only: bool = False,
    max_length: "int | None" = None,
    seed: int = 12
):
    """Load up to ``n_examples`` clean/patch example pairs from a JSONL file.

    Each line of ``dataset`` must be a JSON object with keys
    ``clean_prefix``, ``patch_prefix``, ``clean_answer``, ``patch_answer``.
    Lines are shuffled deterministically by ``seed``, then filtered:
    answers must tokenize to exactly one token, clean/patch prefixes must
    have the same token length, prefix+answer must tokenize identically to
    their concatenation (no tokenization mismatch), and prefixes longer
    than ``max_length`` tokens (when given) are skipped.

    When ``use_min_length_only`` is set, only examples whose clean prefix
    has the minimum token length seen so far are kept; finding a new,
    shorter example restarts the collection.

    Returns a list of at most ``n_examples`` raw example dicts. Previously
    this function fell off the end of the loop and returned ``None`` when
    fewer than ``n_examples`` examples survived filtering; it now always
    returns the (possibly shorter) list.
    """
    with open(dataset, "r") as f:
        dataset_items = f.readlines()
    random.Random(seed).shuffle(dataset_items)

    examples = []
    if use_min_length_only:
        min_length = float("inf")
    for line in dataset_items:
        data = json.loads(line)
        clean_prefix = model.tokenizer(data["clean_prefix"]).input_ids
        patch_prefix = model.tokenizer(data["patch_prefix"]).input_ids
        clean_answer = model.tokenizer(data["clean_answer"]).input_ids
        patch_answer = model.tokenizer(data["patch_answer"]).input_ids
        clean_full = model.tokenizer(data["clean_prefix"] + data["clean_answer"]).input_ids
        patch_full = model.tokenizer(data["patch_prefix"] + data["patch_answer"]).input_ids

        # Strip BOS token from the answers if the tokenizer prepended one.
        if clean_answer[0] == model.tokenizer.bos_token_id:
            clean_answer = clean_answer[1:]
        if patch_answer[0] == model.tokenizer.bos_token_id:
            patch_answer = patch_answer[1:]

        # Answers must be exactly one token.
        if len(clean_answer) != 1 or len(patch_answer) != 1:
            continue
        # Clean and patch prefixes must align token-for-token.
        if len(clean_prefix) != len(patch_prefix):
            continue
        # Reject tokenization mismatches: prefix + answer must tokenize the
        # same as their concatenated string.
        if clean_prefix + clean_answer != clean_full:
            continue
        if patch_prefix + patch_answer != patch_full:
            continue
        if max_length is not None and len(clean_prefix) > max_length:
            continue

        if use_min_length_only:
            if (length := len(clean_prefix)) < min_length:
                # Found a new shortest example: restart collection.
                examples = []
                min_length = length
            elif length > min_length:
                # Longer than the current minimum: skip.
                continue
        examples.append(data)
        if len(examples) >= n_examples:
            break
    # Bug fix: always return the collected examples, even when fewer than
    # n_examples were found (previously returned None in that case).
    return examples
def load_examples_nopair(dataset, num_examples, model):
    """Collect up to ``num_examples`` single-prompt (no patch pair) examples.

    ``dataset`` is either a path to a JSON file or an already-loaded dict
    mapping context ids to ``{"context": [...], "answer": str}`` entries.
    Entries whose answer does not tokenize to exactly one token are
    skipped. Returns a list of ``{"clean_prefix", "clean_answer"}`` dicts.
    """
    if isinstance(dataset, str):
        with open(dataset, "r") as f:
            dataset = json.load(f)
    elif not isinstance(dataset, dict):
        raise ValueError(f"Invalid dataset type: {type(dataset)}")

    examples = []
    for entry in dataset.values():
        prompt = "".join(entry["context"])
        answer = entry["answer"]
        # Only keep answers that are a single token.
        if model.tokenizer(answer, return_tensors="pt").input_ids.shape[1] != 1:
            continue
        examples.append({"clean_prefix": prompt, "clean_answer": answer})
        if len(examples) >= num_examples:
            break
    return examples
def get_annotation(dataset, model, data):
    """Map template slot names to (start, end) token spans in the clean prefix.

    The sentence structure is inferred from substrings of the ``dataset``
    filename; unrecognized datasets yield an empty dict. Token spans are
    inclusive on both ends and computed word-by-word, prefixing each
    non-sentence-initial word with a space before tokenizing.
    """
    # Ordered marker table: "within_rc" must be tested before "rc_",
    # since e.g. "within_rc_..." also contains "rc_".
    recognized = [
        (("within_rc",), "the_subj subj_main that the_dist subj_dist"),
        (("rc.json", "rc_"), "the_subj subj_main that the_dist subj_dist verb_dist"),
        (("simple.json", "simple_"), "the_subj subj_main"),
        (("nounpp.json", "nounpp_"), "the_subj subj_main prep the_dist subj_dist"),
    ]
    template = None
    for markers, candidate in recognized:
        if any(marker in dataset for marker in markers):
            template = candidate
            break
    if template is None:
        return {}

    # Walk template slots and prefix words in lockstep, accumulating spans.
    annotations = {}
    token_pos = 0
    for slot, word in zip(template.split(), data["clean_prefix"].split()):
        piece = word if word == "The" else " " + word
        n_tok = model.tokenizer(piece, return_tensors="pt", padding=False).input_ids.shape[1]
        annotations[slot] = (token_pos, token_pos + n_tok - 1)
        token_pos += n_tok
    return annotations