-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathentailmenttree.py
More file actions
232 lines (175 loc) · 5.73 KB
/
entailmenttree.py
File metadata and controls
232 lines (175 loc) · 5.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import torch
from treenode import TreeNode
from embed import sentence_to_vec
class EntailmentTree:
    """
    Represents an entailment tree.

    Sentences and tree structure are decoupled: sentence text and
    sentence embeddings are stored in dicts keyed by sentence ID,
    while the proof structure is a tree of TreeNode objects parsed
    from the dataset's lisp-style proof string.
    """
    def __init__(self, tree_json):
        """
        Builds the tree from one dataset example.

        :param tree_json: dict whose ["meta"] entry contains
            "triples" (leaf sentences), "intermediate_conclusions"
            (generated sentences), and "lisp_proof" (the structure,
            e.g. "((sent1 sent2) -> int1)").
        """
        # Sentences: ID -> text and ID -> embedding vector.
        self.id_to_sentence = self._parse_sentences(tree_json)
        self.id_to_embedding = self._parse_embedding(self.id_to_sentence)
        # Tree structure: pad parentheses with spaces so split()
        # yields one token per parenthesis / arrow / ID.
        tree_string = tree_json["meta"]["lisp_proof"]
        tokens = tree_string.replace('(', ' ( ').replace(')', ' ) ').split()
        self.root = self._parse_root(tokens, 0, len(tokens) - 1)
    def __str__(self):
        sentences = '\n'.join(f'{k}: {v}' for k, v in self.id_to_sentence.items())
        return sentences + '\n' + str(self.root)
    def generated_premises(self):
        """
        Returns the premises generated at each timestep based on the
        tree, in post-order (children left to right, then the node),
        i.e. the order in which conclusions are generated.

        :return: list of T IDs, for T iterations.
        """
        def generated_helper(node):
            # Leaves are given, never generated.
            if not node.children:
                return []
            # Recurse, flatten, then append this node's conclusion.
            generated = [gid for child in node.children
                         for gid in generated_helper(child)]
            generated.append(node.id)
            return generated
        return generated_helper(self.root)
    def available_premises(self):
        """
        Recursively fetches the IDs of the premises available at each
        iteration: iteration 0 sees the leaves; each later iteration
        sees the previous pool minus the premises retrieved (consumed)
        at the previous step, plus the conclusion generated there.

        :return: list of T lists of IDs.
        """
        def initial_premises(node):
            # The initial pool is exactly the leaf nodes.
            if not node.children:
                return [node.id]
            # Recurse and flatten.
            return [pid for child in node.children
                    for pid in initial_premises(child)]
        generated = self.generated_premises()
        retrieved = self.retrieved_premises()
        # Construct the available-premise pool per iteration.
        available = [initial_premises(self.root)]
        for t in range(len(generated) - 1):
            pool = available[-1].copy()
            # Premises consumed at step t leave the pool...
            for pid in retrieved[t]:
                if pid in pool:
                    pool.remove(pid)
            # ...and the conclusion generated at step t joins it.
            pool.append(generated[t])
            available.append(pool)
        return available
    def retrieved_premises(self):
        """
        Recursively fetches the IDs of the premises retrieved at each
        iteration.  The retrieved premises are simply the list of
        children of each internal node, ordered left to right in the
        tree (post-order: subtrees first, then this node's retrieval).

        :return: list[T][K_t], where K_t is the number of premises
            retrieved at iteration t.
        """
        def retrieved_helper(node):
            # Leaves retrieve nothing.  Returning [] (not None, as the
            # old code did) keeps a degenerate single-node tree
            # well-formed for the zip() in
            # get_indices_of_retrieved_premises.
            if not node.children:
                return []
            # Retrievals performed inside the subtrees first...
            retrieved = [group for child in node.children
                         for group in retrieved_helper(child)]
            # ...then the retrieval performed at this node.
            retrieved.append([child.id for child in node.children])
            return retrieved
        return retrieved_helper(self.root)
    def get_indices_of_retrieved_premises(self):
        """
        Returns the index of each retrieved premise w.r.t. the
        available premises of the same iteration.

        e.g. if ids_r = [[1, 4, 6], [1, 2]]
        and ids_a = [[1, 2, 3, 4, 5, 6], [0, 1, 2]]
        then indices = [[0, 3, 5], [1, 2]].

        Retrieved IDs absent from the available pool are silently
        skipped, as before.

        :return: T-long list of 1-D torch tensors,
            where T is the number of iterations.
        """
        all_retrieved = self.retrieved_premises()
        all_available = self.available_premises()
        indices = []
        for step_retrieved, step_available in zip(all_retrieved, all_available):
            step_indices = [step_available.index(pid)
                            for pid in step_retrieved
                            if pid in step_available]
            indices.append(torch.tensor(step_indices))
        return indices
    def to_embedding(self, ids):
        """
        Maps lists of IDs to stacked embedding tensors.

        Given a list of m lists of m_i IDs, returns a list of m
        torch tensors, the i-th of shape m_i x d, where the embedding
        for an ID comes from self.id_to_embedding.
        e.g. [[id1 id2] [id3 id4 id5]] -> [tensor1 tensor2]
        where tensor1 is 2 x d and tensor2 is 3 x d.

        NOTE(review): torch.stack raises on an empty inner list —
        callers are assumed to pass non-empty ID lists only.
        """
        return [
            torch.stack([self.id_to_embedding[pid] for pid in id_list])
            for id_list in ids
        ]
    def _parse_sentences(self, tree_json):
        """
        Extracts sentences from the dataset's tree JSON.
        Each sentence corresponds to an ID; leaf sentences (triples)
        and intermediate conclusions share one ID namespace.
        """
        return {
            **tree_json["meta"]["triples"],
            **tree_json["meta"]["intermediate_conclusions"],
        }
    def _parse_embedding(self, id_to_sentence):
        """
        Returns an embedding for each sentence, keyed by the same IDs.
        """
        return {
            sid: sentence_to_vec(sentence)
            for sid, sentence in id_to_sentence.items()
        }
    def _parse_root(self, tokens, start, end):
        """
        Parses a single node from tokens[start..end], which must span
        one parenthesized expression "( ( child1 child2 ... ) -> id )":
        tokens[end] is ')', tokens[end - 1] the node ID,
        tokens[end - 2] the '->', and the children group ends at
        tokens[end - 3].
        """
        node_id = tokens[end - 1]
        children = self._parse_children(tokens, start + 1, end - 3)
        return TreeNode(node_id, children)
    def _parse_children(self, tokens, start, end):
        """
        Parses the children group tokens[start..end] ("( c1 c2 ... )")
        into a list of TreeNodes; tokens[start] is the group's '('.
        """
        children = []
        i = start + 1  # skip the group's opening parenthesis
        while i < end:
            if tokens[i] != "(":
                # Case 1: base case — a bare ID token is a leaf node.
                children.append(TreeNode(tokens[i], []))
            else:
                # Case 2: recursive case — a parenthesized subtree.
                j = self._find_matching_parenthesis(tokens, i)
                children.append(self._parse_root(tokens, i, j))
                i = j  # jump to the subtree's closing parenthesis
            i += 1
        return children
    def _find_matching_parenthesis(self, tokens, left_index):
        """
        Returns the index of the ')' matching the '(' at left_index,
        or -1 if no matching right parenthesis is found.
        (If tokens[left_index] is not '(', returns left_index — the
        original behavior, preserved for safety.)
        """
        depth = 0
        for i in range(left_index, len(tokens)):
            if tokens[i] == '(':
                depth += 1
            elif tokens[i] == ')':
                depth -= 1
            if depth == 0:
                return i
        return -1  # no matching right parenthesis found