-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
532 lines (488 loc) · 28.3 KB
/
agent.py
File metadata and controls
532 lines (488 loc) · 28.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
import os
import json
import math
import re
from typing import Tuple, Set, Dict
from abc import ABC, abstractmethod
import networkx as nx
from matplotlib import pyplot as plt
# Maps the relation tokens used in the GBRL query/response syntax to the verbose
# relation names stored on graph edges ("¬" marks a negated relation).
VERBOSE_RELATION = {
    "<ent>": "ent",
    "<con>": "con",
    "<neu>": "neu",
    "<nneu>": "¬neu",
    "<nent>": "¬ent",
    "<ncon>": "¬con",
}
# Reverse lookup: verbose edge relation -> query token.
VERBOSE_RELATION_SWAP = {v: k for k, v in VERBOSE_RELATION.items()}
class Agent(ABC):
    '''
    The agent defines the interface.
    Inheriting agents decide which backend LLM (if any) answers the queries,
    e.g. Chat-GPT or a finetuned LLAMA-2-7b (see _get_response).
    The agent works on a directed (networkx) graph.
    The interface allows adding nodes and edges in the common sense.
    Graph operations like matching are treated differently, though:
    instead of finding nodes or edges, the job of the LLM is to understand the graph
    and generate new nodes or edges that fit into the given topology.
    '''

    # Layout names accepted by display_Graph; each one maps to networkx.<name>_layout.
    _LAYOUTS = ("spring", "shell", "spiral", "spectral", "planar", "circular")

    def __init__(self, log_file=None, feedback_file=None, layout="planar", font_size=12, node_size=300):
        '''
        The constructor allows general settings, like log or feedback files and display properties.
        Parameters
        __________
        log_file: Path/str
            The path to save the log file at (default None).
            If the file already exists, its content is loaded and appended to.
        feedback_file: Path/str
            The path to load the feedback file from (default None).
            Feedback files allow the collection of datapoints for training the model.
        layout: str
            Networkx layout used to display the graph.
            One of ("spring", "shell", "spiral", "spectral", "planar", "circular") (default "planar").
        font_size: int > 0
            Font size of the graph labels (default 12).
        node_size: int > 0
            Node size of the graph nodes (default 300).
        '''
        self.G = nx.DiGraph()  # Directed graph holding all nodes and edges.
        self.log_file = log_file
        self.logs = None
        if self.log_file:
            self._init_logs()  # Loads an existing log file into self.logs.
        self.feedback_file = feedback_file
        self.feedback = None  # Always defined, even when no feedback file is given.
        if self.feedback_file:
            self._init_feedback()  # Loads an existing feedback file into self.feedback.
        # Cosmetic settings used by display_Graph.
        self.layout = layout
        self.font_size = font_size
        self.node_size = node_size

    @staticmethod
    def _load_json_list(path):
        '''
        Return the parsed JSON content of path, or an empty list if the file does not exist.
        Unreadable or malformed files are reported (printed) and also yield an empty list,
        so a broken file does not prevent the agent from starting.
        '''
        if not os.path.isfile(path):
            return []
        try:
            with open(path, 'rb') as fp:  # Context manager closes the file even on parse errors.
                return json.load(fp)
        except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError.
            print(e)
            return []

    def _init_logs(self):
        '''
        Load logs into memory if the log file already exists, otherwise start with an empty list.
        '''
        self.logs = self._load_json_list(self.log_file)

    def _init_feedback(self):
        '''
        Load feedback into memory if the feedback file already exists, otherwise start with an empty list.
        '''
        self.feedback = self._load_json_list(self.feedback_file)

    def _write_logs(self):
        '''
        Persist the in-memory logs to log_file, overwriting its previous content.
        '''
        with open(self.log_file, 'w', encoding='utf-8') as f:
            json.dump(self.logs, f, indent=4)

    def _log_add_node(self, node_name, text, query=None, generated=-1):
        '''
        Append the new node to the logs and rewrite the log file, if a log file was set in init.
        '''
        if self.log_file and isinstance(self.logs, list):
            self.logs.append({"query": query, "text": text, "node": node_name, "generated": generated})
            self._write_logs()

    def _log_add_edge(self, edge, relation="ent", query=None, generated=None):
        '''
        Append the new edge to the logs and rewrite the log file, if a log file was set in init.
        The edge tuple is stored as a JSON list [source, target].
        '''
        if self.log_file and isinstance(self.logs, list):
            self.logs.append({"query": query, "relation": relation, "edge": edge, "generated": generated})
            self._write_logs()

    def display_Graph(self, changed_nodes=None, changed_edges=None):
        '''
        Displays the graph with the following color scheme:
        Nodes and edge labels that have been generated by the model are colored red.
        Nodes and edge labels that have been inserted by the user are colored green.
        Generic nodes that are used for unification (e.g. AND-nodes) are drawn transparent
        and labelled with their text instead of their name.
        Nodes and edges passed as changed are highlighted yellow.
        Afterwards, the text of every node is printed as JSON (truncated to 200 characters).
        Parameters
        __________
        changed_nodes: list[str] (optional)
            Names of the nodes to highlight.
        changed_edges: list[tuple[str, str]] (optional)
            Edges to highlight, given as (source, target) pairs.
        Raises
        ______
        ValueError
            If self.layout is not one of the supported layout names.
        '''
        if self.layout not in self._LAYOUTS:
            raise ValueError(f"Unknown layout '{self.layout}', expected one of {self._LAYOUTS}.")
        # The figure size grows with the number of nodes; at least 1 to avoid a zero-sized figure.
        width = max(len(self.G.nodes), 1)
        plt.figure(3, figsize=(width, math.ceil(width / 2)))
        pos = getattr(nx, f"{self.layout}_layout")(self.G)

        node_data = list(self.G.nodes(data=True))
        nodes_not_generated = [x for x, y in node_data if y['generated'] == -1]  # Inserted by the user.
        nodes_generated = [x for x, y in node_data if y['generated'] == 1]  # Generated by the model.
        generic_nodes = [x for x, y in node_data if y['generated'] == 0]  # Generic unification nodes.
        # Generic nodes are labelled with their text (e.g. "&"), all others with their name.
        labels = {x: (y['text'] if y['generated'] == 0 else x) for x, y in node_data}

        if changed_nodes:  # Highlight changed nodes if given.
            nx.draw_networkx_nodes(self.G, pos, nodelist=changed_nodes, node_color='yellow', node_size=self.node_size * 2.0, alpha=0.3)
        node_groups = (
            (nodes_not_generated, "tab:green"),
            (nodes_generated, "tab:red"),
            (generic_nodes, 'none'),
        )
        for nodelist, color in node_groups:
            nx.draw(
                self.G,
                pos,
                nodelist=nodelist,
                labels=labels,
                with_labels=True,
                node_color=color,
                font_size=self.font_size,
                node_size=self.node_size,
            )

        edge_relations = nx.get_edge_attributes(self.G, 'relation')
        edge_generated = nx.get_edge_attributes(self.G, 'generated')
        edge_relations_not_generated = {k: v for k, v in edge_relations.items() if edge_generated[k] == -1}  # Inserted by the user.
        edge_relations_generated = {k: v for k, v in edge_relations.items() if edge_generated[k] == 1}  # Generated by the model.
        if changed_edges:  # Highlight changed edges if given.
            nx.draw_networkx_edges(self.G, pos, edgelist=changed_edges, edge_color="yellow", width=3.0, alpha=0.3, arrows=False)
        for edge_labels, color in ((edge_relations_not_generated, "tab:green"), (edge_relations_generated, "tab:red")):
            nx.draw_networkx_edge_labels(
                self.G,
                pos,
                edge_labels=edge_labels,
                font_color=color,
                font_size=self.font_size,
            )
        plt.show()

        # Print all nodes in a human readable form (texts longer than 200 characters are truncated).
        prty_print = {}
        for x, y in node_data:
            text = y['text']
            prty_print[x] = text if len(text) <= 200 else f"{text[:200]}..."
        print(json.dumps(prty_print, indent=4))

    def reset(self):
        '''
        Resets the graph and clears the in-memory logs (the log file itself is not rewritten).
        '''
        self.G.clear()
        self.logs = [] if self.log_file else None

    def add_node(self, name, text, query=None, generated=-1):
        '''
        Adds a node to the graph and logs the process if log_file was set on init.
        Parameters
        __________
        name: str
            The name of the node to be added. If it already exists, its attributes are updated.
        text: str
            The text that the node represents.
        query: str (optional)
            If a query resulted in this node, the query will be logged as well.
        generated: int (default -1)
            Marks the node as inserted by the user (-1), generic node (0) or generated by the model (1).
        '''
        self._log_add_node(node_name=name, text=text, query=query, generated=generated)
        self.G.add_node(name, text=text, generated=generated, query=query, subset=query)

    def add_edge(self, source, target, relation="ent", query=None, generated=-1):
        '''
        Adds an edge to the graph and logs the process if log_file was set on init.
        Parameters
        __________
        source: str
            The name of the source node.
        target: str
            The name of the target node.
        relation: str (default ent)
            The textual entailment relation or unification this edge represents.
            The entailment operators use the verbose names:
            ent for entailment
            neu for neutral
            con for contradiction
            ¬ent for not entailment
            ¬neu for not neutral
            ¬con for not contradiction
            & for AND unification
        query: str (optional)
            If a query resulted in this edge, the query will be logged as well.
        generated: int (default -1)
            Marks the edge as inserted by the user (-1), generic edge (0) or generated by the model (1).
        '''
        self._log_add_edge(edge=(source, target), relation=relation, query=query, generated=generated)
        self.G.add_edge(source, target, relation=relation, generated=generated, query=query)

    # Entailment Operations
    # _____________________
    # The entailment operations are the core of this agent. They allow the agent to predict relations
    # between nodes or to generate new nodes that explain or explore a given context.
    # This class defines how queries are formulated; inheriting agents implement how they are
    # processed (see _get_response).
    # All operators return the same response signature (EntailmentOperatorReturnType):
    #   Dict[str, str]        with the fields
    #                           query: the query that resulted in the response,
    #                           actual_response: the response returned by the model without cleaning,
    #                           clean_response: the response cleaned to match the expected GBRL syntax (forgiving).
    #   Set[str]              names of the nodes that were added or changed.
    #   Set[Tuple[str, str]]  (source, target) pairs of the edges that were added or changed.
    #   bool                  whether the response matched the expected format.
    EntailmentOperatorReturnType = Tuple[Dict[str, str], Set[str], Set[Tuple[str, str]], bool]

    def _process_missmatch(self, query, regex, actual_response, clean_response):
        '''
        Callback for agent responses that do not match the expected format.
        Prints the mismatch and returns an unsuccessful OperatorResponse without changed elements.
        '''
        print(f"Expected response of form '{regex}' but got '{clean_response}'.")
        return self._operator_info(query, actual_response, clean_response), set(), set(), False

    @abstractmethod
    def _get_response(self, query, **kwargs) -> Tuple[str, str]:
        '''
        This method implements the way the agent processes queries.
        Parameters
        __________
        query: str
            The query to be processed.
        **kwargs key-value
            A dict of additional parameters, that some inheriting subclasses need.
        Returns
        _______
        str A cleaned (forgiving) version of the response.
        str The actual response.
        '''
        pass

    @staticmethod
    def _operator_info(query, actual_response, clean_response):
        '''
        Build the information dict that is the first element of every OperatorResponse.
        '''
        return {"query": query, "actual_response": actual_response, "clean_response": clean_response}

    @staticmethod
    def _ensure_prefix(clean_response, prefix):
        '''
        Forgiving cleanup: prepend prefix to the response unless it already starts with it.
        '''
        return clean_response if clean_response.startswith(prefix) else f"{prefix}{clean_response}"

    @staticmethod
    def _normalize_premisses(premisses):
        '''
        Return the premisses as a new sorted list, so swapped premisses yield the same query.
        A single node name becomes a one-element list; the caller's list is not modified.
        Raises ValueError if no premiss is given.
        '''
        if isinstance(premisses, (list, tuple)):
            premisses = sorted(premisses)
        else:
            premisses = [premisses]
        if not premisses:
            raise ValueError("At least one premiss is required.")
        return premisses

    def _statements_query(self, node_names, offset=0):
        '''
        Render "<s{i}><:>{text}<;>" for each node, numbering the statements from offset + 1.
        '''
        return "".join(
            f"<s{offset + idx + 1}><:>{self.G.nodes[name]['text']}<;>"
            for idx, name in enumerate(node_names)
        )

    def _unify_premisses(self, premisses, query, changed_nodes, changed_edges):
        '''
        Return the name of the node that represents the conjunction of the premisses.
        A single premiss is returned unchanged. Otherwise a generic AND-node named "p1&p2&..."
        is added together with "&"-edges from every premiss to it; the added elements are
        recorded in changed_nodes / changed_edges.
        '''
        if len(premisses) == 1:
            return premisses[0]
        and_node_name = "&".join(str(premis) for premis in premisses)
        self.add_node(and_node_name, "&", generated=0, query=query)  # Generic AND-node.
        changed_nodes.add(and_node_name)
        for premis in premisses:
            self.add_edge(premis, and_node_name, relation="&", query=query, generated=0)  # Generic AND-relation.
            changed_edges.add((premis, and_node_name))
        return and_node_name

    def relation_between(self, premisses, hypothesis, **kwargs) -> EntailmentOperatorReturnType:
        '''
        Predicts the relation between the AND-node of premisses and hypothesis and adds it as an edge.
        If premisses is a list of more than one node, an AND-unification of those nodes is added first.
        example query: "<s1><:>some premis text<;><s2><:>some other premis text<;><s3><:>Some hypothesis text.<;><e1><:><rel1><;><s1><&><s2><e1><s3><;>"
        Parameters
        __________
        premisses: str|list[str]
            A single node name or list of node names (the list itself is not modified).
            If the list has more than one node name, the premisses are first unified with an AND-node.
        hypothesis: str
            The hypothesis and target of the textual entailment relationship.
        **kwargs key-value
            A dict of additional parameters, that some inheriting subclasses need.
        Returns
        _______
        OperatorResponses (see Entailment Operations)
        Raises
        ______
        ValueError
            If premisses is an empty list.
        '''
        premisses = self._normalize_premisses(premisses)
        hypothesis_idx = len(premisses) + 1
        query = self._statements_query(premisses)
        # Add the hypothesis and the relation <e1> that has to be solved for.
        query += f"<s{hypothesis_idx}><:>{self.G.nodes[hypothesis]['text']}<;><e1><:><rel1><;>"
        premiss_refs = "<&>".join(f"<s{idx + 1}>" for idx in range(len(premisses)))
        query += f"{premiss_refs}<e1><s{hypothesis_idx}><;>"

        clean_response, actual_response = self._get_response(query, **kwargs)
        clean_response = self._ensure_prefix(clean_response, "<rel1><:>")
        response_regex = r"<rel1><:>(<ent>|<con>|<neu>|<nent>|<ncon>|<nneu>)<;>"  # Expected response.
        match = re.match(response_regex, clean_response)
        if not match:  # The model did not answer in the expected syntax.
            return self._process_missmatch(query, response_regex, actual_response, clean_response)

        changed_nodes = set()
        changed_edges = set()
        premis = self._unify_premisses(premisses, query, changed_nodes, changed_edges)
        self.add_edge(  # The edge predicted by the model.
            premis,
            hypothesis,
            relation=VERBOSE_RELATION[match[1]],
            generated=1,
            query=query,
        )
        changed_edges.add((premis, hypothesis))
        return self._operator_info(query, actual_response, clean_response), changed_nodes, changed_edges, True

    def generate_entailment(self, premisses, node_name, relation="ent", previous_answers=None, prompt=None, **kwargs) -> EntailmentOperatorReturnType:
        '''
        Generates a new node that stands in the given relation to the AND-node of premisses.
        If premisses is a list of more than one node, an AND-unification of those nodes is added first.
        Parameters
        __________
        premisses: str|list[str]
            A single node name or list of node names (the list itself is not modified).
            If the list has more than one node name, the premisses are first unified with an AND-node.
        node_name: str
            The name of the new node that is to be generated.
        relation: str (default ent)
            The verbose relation between the premisses and the new node, i.e. one of
            "ent", "con", "neu", "¬ent", "¬con", "¬neu".
        previous_answers: list[str] (default None, meaning no previous answers)
            A list of nodes that have been generated by the same operator on the same nodes.
            This helps the model to generate nodes that are likely as well.
        prompt: str (default None)
            If given, the model is supposed to answer the prompt accordingly.
            The generated node text begins with the prompt followed by the answer.
        **kwargs key-value
            A dict of additional parameters, that some inheriting subclasses need.
        Returns
        _______
        OperatorResponses (see Entailment Operations)
        Raises
        ______
        ValueError
            If premisses is an empty list.
        KeyError
            If relation is not one of the verbose relation names above.
        '''
        premisses = self._normalize_premisses(premisses)
        previous_answers = [] if previous_answers is None else previous_answers
        n_premisses = len(premisses)
        n_previous = len(previous_answers)
        new_idx = n_premisses + n_previous + 1  # Statement index of the node to be generated.

        query = self._statements_query(premisses)
        query += self._statements_query(previous_answers, offset=n_premisses)
        query += f"<s{new_idx}><:>"  # Start the definition of the node to be generated.
        if prompt:
            query += f"{prompt} "  # The new node starts with the prompt.
        # The new node is the AND of all previous answers and the placeholder <t{new_idx}>.
        query += "".join(f"<s{n_premisses + idx + 1}><&>" for idx in range(n_previous))
        query += f"<t{new_idx}><;>"
        premiss_refs = "<&>".join(f"<s{idx + 1}>" for idx in range(n_premisses))
        query += f"{premiss_refs}{VERBOSE_RELATION_SWAP[relation]}<s{new_idx}><;>"

        clean_response, actual_response = self._get_response(query, **kwargs)
        clean_response = self._ensure_prefix(clean_response, f"<t{new_idx}><:>")
        response_regex = rf"<t{new_idx}><:>(.*|\s)<;>"
        match = re.match(response_regex, clean_response)
        if not match:  # The model did not answer in the expected syntax.
            return self._process_missmatch(query, response_regex, actual_response, clean_response)
        generated_text = match[1]
        if prompt:
            generated_text = f"{prompt} {generated_text}"

        changed_nodes = set()
        changed_edges = set()
        premis = self._unify_premisses(premisses, query, changed_nodes, changed_edges)
        self.add_node(node_name, generated_text, query=query, generated=1)
        changed_nodes.add(node_name)
        self.add_edge(  # The edge to the node generated by the model.
            premis,
            node_name,
            relation=relation,
            generated=1,
            query=query,
        )
        changed_edges.add((premis, node_name))
        return self._operator_info(query, actual_response, clean_response), changed_nodes, changed_edges, True

    def explain_relation(self, source, target, node_name, previous_answers=None, **kwargs) -> EntailmentOperatorReturnType:
        '''
        Generates a node that explains the relation between source and target.
        The explanation is inserted between them: source -ent-> explanation -relation-> target,
        where relation (and its generated flag) is copied from the existing source -> target edge.
        Parameters
        __________
        source: str
            The source node name. If it is a generic AND-node, all of its inputs are put into the query.
        target: str
            The target node name.
        node_name: str
            The name of the new node that is to be generated.
        previous_answers: list[str] (default None, meaning no previous answers)
            A list of nodes that have been generated by the same operator on the same nodes.
            This helps the model to generate nodes that are likely as well.
        **kwargs key-value
            A dict of additional parameters, that some inheriting subclasses need.
        Returns
        _______
        OperatorResponses (see Entailment Operations)
        Raises
        ______
        KeyError
            If there is no source -> target edge, or its relation has no query token
            (e.g. "&" unification edges).
        '''
        previous_answers = [] if previous_answers is None else previous_answers
        if self.G.nodes[source]['generated'] == 0:
            # Generic AND-node: explain using all of its inputs.
            source_node_names = list(self.G.predecessors(source))
        else:
            source_node_names = [source]
        target_text = self.G.nodes[target]['text']
        edge_data = self.G.edges[source, target]
        edge_relation = edge_data['relation']
        edge_generated = edge_data['generated']

        n_sources = len(source_node_names)
        n_previous = len(previous_answers)
        new_idx = n_sources + n_previous + 1  # Statement index of the explanation.
        target_idx = new_idx + 1
        # Every source statement is kept (not only the last one) so AND-inputs are all visible.
        query = self._statements_query(source_node_names)
        query += self._statements_query(previous_answers, offset=n_sources)
        query += f"<s{new_idx}><:>"
        # The explanation is the AND of all previous answers and the placeholder <t{new_idx}>.
        query += "".join(f"<s{n_sources + idx + 1}><&>" for idx in range(n_previous))
        query += f"<t{new_idx}><;><s{target_idx}><:>{target_text}<;>"
        source_refs = "<&>".join(f"<s{idx + 1}>" for idx in range(n_sources))
        # Sources entail the explanation, which has the original relation to the target.
        query += f"{source_refs}<ent><s{new_idx}><;><s{new_idx}>{VERBOSE_RELATION_SWAP[edge_relation]}<s{target_idx}><;>"

        clean_response, actual_response = self._get_response(query, **kwargs)
        clean_response = self._ensure_prefix(clean_response, f"<t{new_idx}><:>")
        response_regex = rf"<t{new_idx}><:>(.*|\s)<;>"
        match = re.match(response_regex, clean_response)
        if not match:  # The model did not answer in the expected syntax.
            return self._process_missmatch(query, response_regex, actual_response, clean_response)
        generated_text = match[1]

        changed_nodes = set()
        changed_edges = set()
        self.add_node(node_name, generated_text, generated=1, query=query)  # The explanation node.
        changed_nodes.add(node_name)
        # Source -> explanation is entailment by definition.
        self.add_edge(source, node_name, relation="ent", generated=-1, query=query)
        changed_edges.add((source, node_name))
        # Explanation -> target keeps the relation (and generated flag) of the original source -> target edge.
        self.add_edge(node_name, target, relation=edge_relation, generated=edge_generated, query=query)
        changed_edges.add((node_name, target))
        return self._operator_info(query, actual_response, clean_response), changed_nodes, changed_edges, True