From 90e7ab61e827ce826fa0d2d2b8fbce7d34d3c926 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sun, 5 Oct 2025 13:40:41 -0400 Subject: [PATCH 01/32] Linting updates --- debug.py | 38 + .../annotation_functions.py | 6 +- pyreason/scripts/components/world.py | 11 +- .../scripts/interpretation/interpretation.py | 286 +-- .../interpretation/interpretation.py.bak | 1967 +++++++++++++++++ .../interpretation/interpretation_fp.py | 306 +-- .../interpretation/interpretation_parallel.py | 286 +-- pyreason/scripts/interval/interval.py | 10 +- .../numba_types/fact_edge_type.py | 10 +- .../numba_types/fact_node_type.py | 10 +- .../numba_types/interval_type.py | 8 +- .../numba_wrapper/numba_types/world_type.py | 20 +- pyreason/scripts/utils/filter.py | 36 +- pyreason/scripts/utils/graphml_parser.py | 56 +- pyreason/scripts/utils/output.py | 1 - pyreason/scripts/utils/plotter.py | 5 +- pyreason/scripts/utils/query_parser.py | 10 +- pyreason/scripts/utils/rule_parser.py | 22 +- pyreason/scripts/utils/visuals.py | 3 - pyreason/scripts/utils/yaml_parser.py | 28 +- 20 files changed, 2557 insertions(+), 562 deletions(-) create mode 100644 debug.py create mode 100755 pyreason/scripts/interpretation/interpretation.py.bak diff --git a/debug.py b/debug.py new file mode 100644 index 00000000..52e29773 --- /dev/null +++ b/debug.py @@ -0,0 +1,38 @@ +import pyreason as pr + +def test_anyBurl_rule_1_fp(): + graph_path = './tests/functional/knowledge_graph_test_subset.graphml' + pr.reset() + pr.reset_rules() + pr.reset_settings() + # Modify pyreason settings to make the run verbose and to save the rule trace + pr.settings.verbose = True + pr.settings.fp_version = True # Use the FP version of the reasoner + pr.settings.atom_trace = True + pr.settings.memory_profile = False + pr.settings.canonical = True + pr.settings.inconsistency_check = False + pr.settings.static_graph_facts = False + pr.settings.output_to_file = False + pr.settings.store_interpretation_changes = True + pr.settings.save_graph_attributes_to_trace = True + # Load all the files into pyreason + pr.load_graphml(graph_path) + pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True)) + + # Run the program for one timestep to see the diffusion take place + interpretation = pr.reason(timesteps=1) + # pr.save_rule_trace(interpretation) + + # Display the changes in the interpretation for each timestep + dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) + for t, df in enumerate(dataframes): + print(f'TIMESTEP - {t}') + print(df) + print() + assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' + assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' + assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' +

if __name__ == "__main__": + test_anyBurl_rule_1_fp() \ No newline at end of file diff --git a/pyreason/scripts/annotation_functions/annotation_functions.py b/pyreason/scripts/annotation_functions/annotation_functions.py index d04cf781..9ef577de 100755 --- a/pyreason/scripts/annotation_functions/annotation_functions.py +++ b/pyreason/scripts/annotation_functions/annotation_functions.py
@@ -30,9 +30,9 @@ def _check_bound(lower, upper): if lower > upper: return (0, 1) else: - l = min(lower, 1) - u = min(upper, 1) - return (l, u) + lower_bound = min(lower, 1) + upper_bound = min(upper, 1) + return (lower_bound, upper_bound) @numba.njit diff --git a/pyreason/scripts/components/world.py b/pyreason/scripts/components/world.py index a632a680..090b502b 100755 --- a/pyreason/scripts/components/world.py +++ b/pyreason/scripts/components/world.py @@ -8,8 +8,8 @@ class World: def __init__(self, labels): self._labels = labels self._world = numba.typed.Dict.empty(key_type=label.label_type, value_type=interval.interval_type) - for l in labels: - self._world[l] = interval.closed(0.0, 1.0) + for lbl in labels: + self._world[lbl] = interval.closed(0.0, 1.0) @property def world(self): @@ -29,9 +29,6 @@ def is_satisfied(self, label, interval): return result def update(self, label, interval): - lwanted = None - bwanted = None - current_bnd = self._world[label] new_bnd = current_bnd.intersection(interval) self._world[label] = new_bnd @@ -48,7 +45,7 @@ def get_world(self): def __str__(self): result = '' - for label in self._world.keys(): - result = result + label.get_value() + ',' + self._world[label].to_str() + '\n' + for lbl in self._world.keys(): + result = result + lbl.get_value() + ',' + self._world[lbl].to_str() + '\n' return result diff --git a/pyreason/scripts/interpretation/interpretation.py b/pyreason/scripts/interpretation/interpretation.py index 81bf6bb0..2693a219 100755 --- a/pyreason/scripts/interpretation/interpretation.py +++ b/pyreason/scripts/interpretation/interpretation.py @@ -106,9 +106,9 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, # Setup graph neighbors and reverse neighbors self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) for n in self.graph.nodes(): - l = numba.typed.List.empty_list(node_type) - [l.append(neigh) for neigh in self.graph.neighbors(n)] - self.neighbors[n] = l + neighbor_list = numba.typed.List.empty_list(node_type) + [neighbor_list.append(neigh) for neigh in self.graph.neighbors(n)] + self.neighbors[n] = neighbor_list self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) @@ -139,10 +139,10 @@ def _init_interpretations_node(nodes, specific_labels, num_ga): interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for l, ns in specific_labels.items(): - predicate_map[l] = numba.typed.List(ns) + for lbl, ns in specific_labels.items(): + predicate_map[lbl] = numba.typed.List(ns) for n in ns: - interpretations[n].world[l] = interval.closed(0.0, 1.0) + interpretations[n].world[lbl] = interval.closed(0.0, 1.0) num_ga[0] += 1 return interpretations, predicate_map @@ -158,10 +158,10 @@ def _init_interpretations_edge(edges, specific_labels, num_ga): interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for l, es in specific_labels.items(): - predicate_map[l] = numba.typed.List(es) + for lbl, es in specific_labels.items(): + predicate_map[lbl] = numba.typed.List(es) for e in es: - interpretations[e].world[l] = interval.closed(0.0, 1.0) + interpretations[e].world[lbl] = interval.closed(0.0, 1.0) num_ga[0] += 1 return interpretations, predicate_map @@ -246,16 +246,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Reset nodes (only if not static) for n in nodes: w = interpretations_node[n].world - for l in w: - if not w[l].is_static(): 
- w[l].reset() + for lbl in w: + if not w[lbl].is_static(): + w[lbl].reset() # Reset edges (only if not static) for e in edges: w = interpretations_edge[e].world - for l in w: - if not w[l].is_static(): - w[l].reset() + for lbl in w: + if not w[lbl].is_static(): + w[lbl].reset() # Convergence parameters changes_cnt = 0 @@ -269,36 +269,36 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi nodes_set = set(nodes) for i in range(len(facts_to_be_applied_node)): if facts_to_be_applied_node[i][0] == t: - comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] + comp, lbl, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] # If the component is not in the graph, add it if comp not in nodes_set: _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node) nodes_set.add(comp) # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well - if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static(): + if lbl in interpretations_node[comp].world and interpretations_node[comp].world[lbl].is_static(): # Check if we should even store any of the changes to the rule trace etc. # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd)) + rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, lbl, bnd)) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i]) for p1, p2 in ipl: - if p1==l: + if p1==lbl: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p2], facts_to_be_applied_node_trace[i]) - elif p2==l: + elif p2==lbl: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i]) else: # Check for inconsistencies (multiple facts) - if check_consistent_node(interpretations_node, comp, (l, bnd)): + if check_consistent_node(interpretations_node, comp, (lbl, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i,
facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) update = u or update # Update convergence params @@ -310,9 +310,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_node(interpretations_node, comp, (lbl, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) update = u or update # Update convergence params @@ -322,7 +322,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute)) + facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, lbl, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i]) @@ -345,34 +345,34 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi edges_set = set(edges) for i in range(len(facts_to_be_applied_edge)): if facts_to_be_applied_edge[i][0]==t: - comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] + comp, lbl, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] # If the component is not in the graph, add it if comp not in edges_set: _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) edges_set.add(comp) # Check if bnd is static.
Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well - if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static(): + if lbl in interpretations_edge[comp].world and interpretations_edge[comp].world[lbl].is_static(): # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[comp].world[l])) + rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, lbl, interpretations_edge[comp].world[lbl])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i]) for p1, p2 in ipl: - if p1==l: + if p1==lbl: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p2], facts_to_be_applied_edge_trace[i]) - elif p2==l: + elif p2==lbl: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p1], facts_to_be_applied_edge_trace[i]) else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (l, bnd)): + if check_consistent_edge(interpretations_edge, comp, (lbl, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) update = u or update # Update convergence params @@ -384,9 +384,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_edge(interpretations_edge, comp, (lbl, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace,
store_interpretation_changes, mode=mode) else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) update = u or update # Update convergence params @@ -396,7 +396,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute)) + facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, lbl, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i]) @@ -423,11 +423,11 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_node): if i[0] == t: - comp, l, bnd, set_static = i[1], i[2], i[3], i[4] + comp, lbl, bnd, set_static = i[1], i[2], i[3], i[4] # Check for inconsistencies - if check_consistent_node(interpretations_node, comp, (l, bnd)): + if check_consistent_node(interpretations_node, comp, (lbl, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) update = u or update # Update convergence params @@ -438,9 +438,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_node(interpretations_node, comp, (lbl, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms,
store_interpretation_changes, num_ga, mode='rule', override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) update = u or update # Update convergence params @@ -462,7 +462,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_edge): if i[0] == t: - comp, l, bnd, set_static = i[1], i[2], i[3], i[4] + comp, lbl, bnd, set_static = i[1], i[2], i[3], i[4] sources, targets, edge_l = edges_to_be_added_edge_rule[idx] edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) changes_cnt += changes @@ -500,9 +500,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (l, bnd)): + if check_consistent_edge(interpretations_edge, comp, (lbl, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) update = u or update # Update convergence params @@ -513,9 +513,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_edge(interpretations_edge, comp, (lbl, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) update = u or update # Update
convergence params @@ -644,16 +644,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi return fp_cnt, t - def add_edge(self, edge, l): + def add_edge(self, edge, label): # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) + _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, label, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) def add_node(self, node, labels): # This function is useful for pyreason gym, called externally if node not in self.nodes: _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) - for l in labels: - self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1) + for lbl in labels: + self.interpretations_node[node].world[label.Label(lbl)] = interval.closed(0, 1) def delete_edge(self, edge): # This function is useful for pyreason gym, called externally @@ -678,23 +678,23 @@ def get_dict(self): # Update interpretation nodes for change in self.rule_trace_node: - time, _, node, l, bnd = change - interpretations[time][node][l._value] = (bnd.lower, bnd.upper) + time, _, node, label, bnd = change + interpretations[time][node][label._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self. persistent: for t in range(time+1, self.time+1): - interpretations[t][node][l._value] = (bnd.lower, bnd.upper) + interpretations[t][node][label._value] = (bnd.lower, bnd.upper) # Update interpretation edges for change in self.rule_trace_edge: - time, _, edge, l, bnd, = change - interpretations[time][edge][l._value] = (bnd.lower, bnd.upper) + time, _, edge, label, bnd = change + interpretations[time][edge][label._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.
persistent: for t in range(time+1, self.time+1): - interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) + interpretations[t][edge][label._value] = (bnd.lower, bnd.upper) return interpretations @@ -706,10 +706,10 @@ def get_final_num_ground_atoms(self): ga_cnt = 0 for node in self.nodes: - for l in self.interpretations_node[node].world: + for lbl in self.interpretations_node[node].world: ga_cnt += 1 for edge in self.edges: - for l in self.interpretations_edge[edge].world: + for lbl in self.interpretations_edge[edge].world: ga_cnt += 1 return ga_cnt @@ -807,7 +807,7 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map clause_label = clause[1] clause_variables = clause[2] clause_bnd = clause[3] - clause_operator = clause[4] + # clause_operator = clause[4] # Currently unused # This is a node clause if clause_type == 'node': @@ -1303,17 +1303,17 @@ def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, @numba.njit(cache=True) -def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): +def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, label, nodes): # The groundings for a node clause can be either a previous grounding or all possible nodes - if l in predicate_map: - grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] + if label in predicate_map: + grounding = predicate_map[label] if clause_var_1 not in groundings else groundings[clause_var_1] else: grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] return grounding @numba.njit(cache=True) -def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges): +def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, label, edges): # There are 4 cases for predicate(Y,Z): # 1. Both predicate variables Y and Z have not been encountered before # 2. 
The source variable Y has not been encountered before but the target variable Z has @@ -1324,8 +1324,8 @@ def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groun # Case 1: # We replace Y by all nodes and Z by the neighbors of each of these nodes if clause_var_1 not in groundings and clause_var_2 not in groundings: - if l in predicate_map: - edge_groundings = predicate_map[l] + if label in predicate_map: + edge_groundings = predicate_map[label] else: edge_groundings = edges @@ -1419,34 +1419,34 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] - l, bnd = na + label, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if l not in world.world: - world.world[l] = interval.closed(0, 1) + if label not in world.world: + world.world[label] = interval.closed(0, 1) num_ga[t_cnt] += 1 - if l in predicate_map: - predicate_map[l].append(comp) + if label in predicate_map: + predicate_map[label].append(comp) else: - predicate_map[l] = numba.typed.List([comp]) + predicate_map[label] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[l].copy() + prev_bnd = world.world[label].copy() # override will not check for inconsistencies if override: - world.world[l].set_lower_upper(bnd.lower, bnd.upper) + world.world[label].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(l, bnd) - world.world[l].set_static(static) - if world.world[l]!=prev_bnd: + world.update(label, bnd) + world.world[label].set_static(static) + if world.world[label]!=prev_bnd: updated = True - updated_bnds.append(world.world[l]) + updated_bnds.append(world.world[label]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) if atom_trace: # Mode can be fact or rule, updation of trace will happen accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1462,7 +1462,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == l: + if p1 == label: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1470,7 +1470,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1479,7 +1479,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c 
updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == l: + if p2 == label: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1487,7 +1487,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1501,8 +1501,8 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[l] - prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) + current_bnd = world.world[label] + prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1515,7 +1515,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c return (updated, change) - except: + except Exception: return (False, 0) @@ -1525,34 +1525,34 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] - l, bnd = na + label, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if l not in world.world: - world.world[l] = interval.closed(0, 1) + if label not in world.world: + world.world[label] = interval.closed(0, 1) num_ga[t_cnt] += 1 - if l in predicate_map: - predicate_map[l].append(comp) + if label in predicate_map: + predicate_map[label].append(comp) else: - predicate_map[l] = numba.typed.List([comp]) + predicate_map[label] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[l].copy() + prev_bnd = world.world[label].copy() # override will not check for inconsistencies if override: - world.world[l].set_lower_upper(bnd.lower, bnd.upper) + world.world[label].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(l, bnd) - world.world[l].set_static(static) - if world.world[l]!=prev_bnd: + world.update(label, bnd) + world.world[label].set_static(static) + if world.world[label]!=prev_bnd: updated = True - updated_bnds.append(world.world[l]) + updated_bnds.append(world.world[label]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) if atom_trace: # Mode can be fact or rule, 
updation of trace will happen accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1568,7 +1568,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == l: + if p1 == label: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1576,7 +1576,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1585,7 +1585,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == l: + if p2 == label: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1593,7 +1593,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1607,8 +1607,8 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[l] - prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) + current_bnd = world.world[label] + prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1620,7 +1620,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 1 + ip_update_cnt return (updated, change) - except: + except Exception: return (False, 0) @@ -1632,8 +1632,8 @@ def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): @numba.njit(cache=True) def are_satisfied_node(interpretations, comp, nas): result = True - for (l, bnd) in nas: - result = result and is_satisfied_node(interpretations, comp, (l, bnd)) + for (lbl, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (lbl, bnd)) return result @@ -1645,7 +1645,7 @@ def is_satisfied_node(interpretations, comp, na): try: world = interpretations[comp] result = 
world.is_satisfied(na[0], na[1]) - except: + except Exception: result = False else: result = True @@ -1656,23 +1656,23 @@ def is_satisfied_node_comparison(interpretations, comp, na): result = False number = 0 - l, bnd = na - l_str = l.value + label, bnd = na + label_str = label.value - if not (l is None or bnd is None): + if not (label is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): world_l_str = world_l.value - if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): + if label_str in world_l_str and world_l_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world_l_str[len(l_str)+1:]) + number = str_to_float(world_l_str[len(label_str)+1:]) break - except: + except Exception: result = False else: result = True @@ -1682,8 +1682,8 @@ def are_satisfied_edge(interpretations, comp, nas): result = True - for (l, bnd) in nas: - result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) + for (lbl, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (lbl, bnd)) return result @@ -1695,7 +1695,7 @@ def is_satisfied_edge(interpretations, comp, na): try: world = interpretations[comp] result = world.is_satisfied(na[0], na[1]) - except: + except Exception: result = False else: result = True @@ -1706,23 +1706,23 @@ def is_satisfied_edge_comparison(interpretations, comp, na): result = False number = 0 - l, bnd = na - l_str = l.value + label, bnd = na + label_str = label.value - if not (l is None or bnd is None): + if not (label is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): world_l_str = world_l.value - if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): + if label_str in world_l_str and world_l_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world_l_str[len(l_str)+1:]) + number = str_to_float(world_l_str[len(label_str)+1:]) break - except: + except Exception: result = False else: result = True @@ -1846,7 +1846,7 @@ def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): @numba.njit(cache=True) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, num_ga, t): # If not a node, add to list of nodes and initialize neighbors if source not in nodes: _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) @@ -1855,8 +1855,8 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int _add_node(target, neighbors, reverse_neighbors, nodes, interpretations_node) # Make sure edge doesn't already exist - # Make sure, if l=='', not to add the label - # Make sure, if edge exists, that we
don't override the l label if it exists + # Make sure, if lbl=='', not to add the label + # Make sure, if edge exists, that we don't override the label if it already exists edge = (source, target) new_edge = False if edge not in edges: @@ -1864,36 +1864,36 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int edges.append(edge) neighbors[source].append(target) reverse_neighbors[target].append(source) - if l.value!='': - interpretations_edge[edge] = world.World(numba.typed.List([l])) + if lbl.value!='': + interpretations_edge[edge] = world.World(numba.typed.List([lbl])) num_ga[t] += 1 - if l in predicate_map: - predicate_map[l].append(edge) + if lbl in predicate_map: + predicate_map[lbl].append(edge) else: - predicate_map[l] = numba.typed.List([edge]) + predicate_map[lbl] = numba.typed.List([edge]) else: interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type)) else: - if l not in interpretations_edge[edge].world and l.value!='': + if lbl not in interpretations_edge[edge].world and lbl.value!='': new_edge = True - interpretations_edge[edge].world[l] = interval.closed(0, 1) + interpretations_edge[edge].world[lbl] = interval.closed(0, 1) num_ga[t] += 1 - if l in predicate_map: - predicate_map[l].append(edge) + if lbl in predicate_map: - predicate_map[l] = numba.typed.List([edge]) + predicate_map[lbl].append(edge) else: - predicate_map[l] = numba.typed.List([edge]) + predicate_map[lbl] = numba.typed.List([edge]) return edge, new_edge @numba.njit(cache=True) -def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, num_ga, t): changes = 0 edges_added = numba.typed.List.empty_list(edge_type) for source in sources: for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t) + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, num_ga, t) edges_added.append(edge) changes = changes+1 if new_edge else changes return edges_added, changes @@ -1905,9 +1905,9 @@ def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge edges.remove(edge) num_ga[-1] -= len(interpretations_edge[edge].world) del interpretations_edge[edge] - for l in predicate_map: - if edge in predicate_map[l]: - predicate_map[l].remove(edge) + for lbl in predicate_map: + if edge in predicate_map[lbl]: + predicate_map[lbl].remove(edge) neighbors[source].remove(target) reverse_neighbors[target].remove(source) @@ -1919,9 +1919,9 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node del interpretations_node[node] del neighbors[node] del reverse_neighbors[node] - for l in predicate_map: - if node in predicate_map[l]: - predicate_map[l].remove(node) + for lbl in predicate_map: + if node in predicate_map[lbl]: + predicate_map[lbl].remove(node) # Remove all occurrences of node in neighbors for n in neighbors.keys(): diff --git a/pyreason/scripts/interpretation/interpretation.py.bak b/pyreason/scripts/interpretation/interpretation.py.bak new file mode 100755 index 00000000..81bf6bb0 --- /dev/null +++ b/pyreason/scripts/interpretation/interpretation.py.bak @@ -0,0 +1,1967 @@ +from typing import Union,
Tuple + +import pyreason.scripts.numba_wrapper.numba_types.world_type as world +import pyreason.scripts.numba_wrapper.numba_types.label_type as label +import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval +from pyreason.scripts.interpretation.interpretation_dict import InterpretationDict + +import numba +from numba import objmode, prange + + +# Types for the dictionaries +node_type = numba.types.string +edge_type = numba.types.UniTuple(numba.types.string, 2) + +# Type for storing list of qualified nodes/edges +list_of_nodes = numba.types.ListType(node_type) +list_of_edges = numba.types.ListType(edge_type) + +# Type for storing clause data +clause_data = numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string))) + +# Type for storing refine clause data +refine_data = numba.types.Tuple((numba.types.string, numba.types.string, numba.types.int8)) + +# Type for facts to be applied +facts_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) +facts_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) + +# Type for returning list of applicable rules for a certain rule +# node/edge, annotations, qualified nodes, qualified edges, edges to be added +node_applicable_rule_type = numba.types.Tuple(( + node_type, + numba.types.ListType(numba.types.ListType(interval.interval_type)), + numba.types.ListType(numba.types.ListType(node_type)), + numba.types.ListType(numba.types.ListType(edge_type)), + numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) +)) + +edge_applicable_rule_type = numba.types.Tuple(( + edge_type, + numba.types.ListType(numba.types.ListType(interval.interval_type)), + numba.types.ListType(numba.types.ListType(node_type)), + numba.types.ListType(numba.types.ListType(edge_type)), + numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) +)) + +rules_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean)) +rules_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean)) +rules_to_be_applied_trace_type = numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string)) +edges_to_be_added_type = numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) + + +class Interpretation: + specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type)) + specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type)) + + def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules): + self.graph = graph + self.ipl = ipl + self.annotation_functions = annotation_functions + self.reverse_graph = reverse_graph + self.atom_trace = atom_trace + self.save_graph_attributes_to_rule_trace = save_graph_attributes_to_rule_trace + self.persistent = persistent + self.inconsistency_check = 
inconsistency_check + self.store_interpretation_changes = store_interpretation_changes + self.update_mode = update_mode + self.allow_ground_rules = allow_ground_rules + + # Counter for number of ground atoms for each timestep, start with zero for the zeroth timestep + self.num_ga = numba.typed.List.empty_list(numba.types.int64) + self.num_ga.append(0) + + # For reasoning and reasoning again (contains previous time and previous fp operation cnt) + self.time = 0 + self.prev_reasoning_data = numba.typed.List([0, 0]) + + # Initialize list of tuples for rules/facts to be applied, along with all the ground atoms that fired the rule. One to One correspondence between rules_to_be_applied_node and rules_to_be_applied_node_trace if atom_trace is true + self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type) + self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type) + self.facts_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.string) + self.facts_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.string) + self.rules_to_be_applied_node = numba.typed.List.empty_list(rules_to_be_applied_node_type) + self.rules_to_be_applied_edge = numba.typed.List.empty_list(rules_to_be_applied_edge_type) + self.facts_to_be_applied_node = numba.typed.List.empty_list(facts_to_be_applied_node_type) + self.facts_to_be_applied_edge = numba.typed.List.empty_list(facts_to_be_applied_edge_type) + self.edges_to_be_added_node_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))) + self.edges_to_be_added_edge_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))) + + # Keep track of all the rules that have affected each node/edge at each timestep/fp operation, and all ground atoms that have affected the rules as well. 
Keep track of previous bounds and name of the rule/fact here + self.rule_trace_node_atoms = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), interval.interval_type, numba.types.string))) + self.rule_trace_edge_atoms = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), interval.interval_type, numba.types.string))) + self.rule_trace_node = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, numba.types.uint16, node_type, label.label_type, interval.interval_type))) + self.rule_trace_edge = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, numba.types.uint16, edge_type, label.label_type, interval.interval_type))) + + # Nodes and edges of the graph + self.nodes = numba.typed.List.empty_list(node_type) + self.edges = numba.typed.List.empty_list(edge_type) + self.nodes.extend(numba.typed.List(self.graph.nodes())) + self.edges.extend(numba.typed.List(self.graph.edges())) + + self.interpretations_node, self.predicate_map_node = self._init_interpretations_node(self.nodes, self.specific_node_labels, self.num_ga) + self.interpretations_edge, self.predicate_map_edge = self._init_interpretations_edge(self.edges, self.specific_edge_labels, self.num_ga) + + # Setup graph neighbors and reverse neighbors + self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) + for n in self.graph.nodes(): + l = numba.typed.List.empty_list(node_type) + [l.append(neigh) for neigh in self.graph.neighbors(n)] + self.neighbors[n] = l + + self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) + + @staticmethod + @numba.njit(cache=True) + def _init_reverse_neighbors(neighbors): + reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) + for n, neighbor_nodes in neighbors.items(): + for neighbor_node in neighbor_nodes: + if neighbor_node in reverse_neighbors and n not in reverse_neighbors[neighbor_node]: + reverse_neighbors[neighbor_node].append(n) + else: + reverse_neighbors[neighbor_node] = numba.typed.List([n]) + # This makes sure each node has a value + if n not in reverse_neighbors: + reverse_neighbors[n] = numba.typed.List.empty_list(node_type) + + return reverse_neighbors + + @staticmethod + @numba.njit(cache=True) + def _init_interpretations_node(nodes, specific_labels, num_ga): + interpretations = numba.typed.Dict.empty(key_type=node_type, value_type=world.world_type) + predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_nodes) + + # Initialize nodes + for n in nodes: + interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) + + # Specific labels + for l, ns in specific_labels.items(): + predicate_map[l] = numba.typed.List(ns) + for n in ns: + interpretations[n].world[l] = interval.closed(0.0, 1.0) + num_ga[0] += 1 + + return interpretations, predicate_map + + @staticmethod + @numba.njit(cache=True) + def _init_interpretations_edge(edges, specific_labels, num_ga): + interpretations = numba.typed.Dict.empty(key_type=edge_type, value_type=world.world_type) + predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_edges) + + # Initialize edges + for n in edges: + interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) + + # Specific labels + for l, es in specific_labels.items(): 
+            predicate_map[l] = numba.typed.List(es)
+            for e in es:
+                interpretations[e].world[l] = interval.closed(0.0, 1.0)
+                num_ga[0] += 1
+
+        return interpretations, predicate_map
+
+    @staticmethod
+    @numba.njit(cache=True)
+    def _init_convergence(convergence_bound_threshold, convergence_threshold):
+        if convergence_bound_threshold==-1 and convergence_threshold==-1:
+            convergence_mode = 'perfect_convergence'
+            convergence_delta = 0
+        elif convergence_bound_threshold==-1:
+            convergence_mode = 'delta_interpretation'
+            convergence_delta = convergence_threshold
+        else:
+            convergence_mode = 'delta_bound'
+            convergence_delta = convergence_bound_threshold
+        return convergence_mode, convergence_delta
+
+    def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_threshold, convergence_bound_threshold, again=False, restart=True):
+        self.tmax = tmax
+        self._convergence_mode, self._convergence_delta = self._init_convergence(convergence_bound_threshold, convergence_threshold)
+        max_facts_time = self._init_facts(facts_node, facts_edge, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.atom_trace)
+        self._start_fp(rules, max_facts_time, verbose, again, restart)
+
+    @staticmethod
+    @numba.njit(cache=True)
+    def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, atom_trace):
+        max_time = 0
+        for fact in facts_node:
+            for t in range(fact.get_time_lower(), fact.get_time_upper() + 1):
+                max_time = max(max_time, t)
+                name = fact.get_name()
+                graph_attribute = True if name=='graph-attribute-fact' else False
+                facts_to_be_applied_node.append((numba.types.uint16(t), fact.get_component(), fact.get_label(), fact.get_bound(), fact.static, graph_attribute))
+                if atom_trace:
+                    facts_to_be_applied_node_trace.append(fact.get_name())
+        for fact in facts_edge:
+            for t in range(fact.get_time_lower(), fact.get_time_upper() + 1):
+                max_time = max(max_time, t)
+                name = fact.get_name()
+                graph_attribute = True if name=='graph-attribute-fact' else False
+                facts_to_be_applied_edge.append((numba.types.uint16(t), fact.get_component(), fact.get_label(), fact.get_bound(), fact.static, graph_attribute))
+                if atom_trace:
+                    facts_to_be_applied_edge_trace.append(fact.get_name())
+        return max_time
+
+    def _start_fp(self, rules, max_facts_time, verbose, again, restart):
+        if again:
+            self.num_ga.append(self.num_ga[-1])
+        if restart:
+            self.time = 0
+            self.prev_reasoning_data[0] = 0
+        fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.persistent, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, self.num_ga, verbose, again)
+        self.time = t - 1
+        # If we need to reason again, store the next timestep to start from
+        self.prev_reasoning_data[0] = t
+        self.prev_reasoning_data[1] = fp_cnt
+        if verbose:
+            print('Fixed Point iterations:', fp_cnt)
+
+    @staticmethod
+    @numba.njit(cache=True, parallel=False)
+    def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, num_ga, verbose, again):
+        t = prev_reasoning_data[0]
+        fp_cnt = prev_reasoning_data[1]
+        max_rules_time = 0
+        timestep_loop = True
+        facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type)
+        facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type)
+        facts_to_be_applied_node_trace_new = numba.typed.List.empty_list(numba.types.string)
+        facts_to_be_applied_edge_trace_new = numba.typed.List.empty_list(numba.types.string)
+        rules_to_remove_idx = set()
+        rules_to_remove_idx.add(-1)
+        while timestep_loop:
+            if t==tmax:
+                timestep_loop = False
+            if verbose:
+                with objmode():
+                    print('Timestep:', t, flush=True)
+            # Reset Interpretation at beginning of timestep if non-persistent
+            if t>0 and not persistent:
+                # Reset nodes (only if not static)
+                for n in nodes:
+                    w = interpretations_node[n].world
+                    for l in w:
+                        if not w[l].is_static():
+                            w[l].reset()
+
+                # Reset edges (only if not static)
+                for e in edges:
+                    w = interpretations_edge[e].world
+                    for l in w:
+                        if not w[l].is_static():
+                            w[l].reset()
+
+            # Convergence parameters
+            changes_cnt = 0
+            bound_delta = 0
+            update = False
+
+            # Start by applying facts
+            # Nodes
+            facts_to_be_applied_node_new.clear()
+            facts_to_be_applied_node_trace_new.clear()
+            nodes_set = set(nodes)
+            for i in range(len(facts_to_be_applied_node)):
+                if facts_to_be_applied_node[i][0] == t:
+                    comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5]
+                    # If the component is not in the graph, add it
+                    if comp not in nodes_set:
+                        _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node)
+                        nodes_set.add(comp)
+
+                    # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well
+                    if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static():
+                        # Check if we should even store any of the changes to the rule trace etc.
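+                        # For example (hypothetical values): a static graph-attribute fact is
+                        # skipped here when save_graph_attributes_to_rule_trace is False, but
+                        # recorded in rule_trace_node when the flag is True.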
+                        # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute
+                        if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes:
+                            rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd))
+                            if atom_trace:
+                                _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i])
+                            for p1, p2 in ipl:
+                                if p1==l:
+                                    rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[comp].world[p2]))
+                                    if atom_trace:
+                                        _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p2], facts_to_be_applied_node_trace[i])
+                                elif p2==l:
+                                    rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1]))
+                                    if atom_trace:
+                                        _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i])
+
+                    else:
+                        # Check for inconsistencies (multiple facts)
+                        if check_consistent_node(interpretations_node, comp, (l, bnd)):
+                            mode = 'graph-attribute-fact' if graph_attribute else 'fact'
+                            override = True if update_mode == 'override' else False
+                            u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override)
+
+                            update = u or update
+                            # Update convergence params
+                            if convergence_mode=='delta_bound':
+                                bound_delta = max(bound_delta, changes)
+                            else:
+                                changes_cnt += changes
+                        # Resolve inconsistency if necessary otherwise override bounds
+                        else:
+                            mode = 'graph-attribute-fact' if graph_attribute else 'fact'
+                            if inconsistency_check:
+                                resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode)
+                            else:
+                                u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True)
+
+                                update = u or update
+                                # Update convergence params
+                                if convergence_mode=='delta_bound':
+                                    bound_delta = max(bound_delta, changes)
+                                else:
+                                    changes_cnt += changes
+
+                    if static:
+                        facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute))
+                        if atom_trace:
+                            facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])
+
+                # If time doesn't match, fact to be applied later
+                else:
+                    facts_to_be_applied_node_new.append(facts_to_be_applied_node[i])
+                    if atom_trace:
+                        facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i])
+
+            # Update list of facts with ones that have not been applied yet (delete applied facts)
+            facts_to_be_applied_node[:] = facts_to_be_applied_node_new.copy()
+            if atom_trace:
+                facts_to_be_applied_node_trace[:] = facts_to_be_applied_node_trace_new.copy()
+            facts_to_be_applied_node_new.clear()
+            facts_to_be_applied_node_trace_new.clear()
+
+            # Edges
+            facts_to_be_applied_edge_new.clear()
+            facts_to_be_applied_edge_trace_new.clear()
+            edges_set = set(edges)
+            for i in range(len(facts_to_be_applied_edge)):
+                if facts_to_be_applied_edge[i][0]==t:
+                    comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5]
+                    # If the component is not in the graph, add it
+                    if comp not in edges_set:
+                        _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
+                        edges_set.add(comp)
+
+                    # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well
+                    if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static():
+                        # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute
+                        if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes:
+                            rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[comp].world[l]))
+                            if atom_trace:
+                                _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i])
+                            for p1, p2 in ipl:
+                                if p1==l:
+                                    rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[comp].world[p2]))
+                                    if atom_trace:
+                                        _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p2], facts_to_be_applied_edge_trace[i])
+                                elif p2==l:
+                                    rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[comp].world[p1]))
+                                    if atom_trace:
+                                        _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p1], facts_to_be_applied_edge_trace[i])
+                    else:
+                        # Check for inconsistencies
+                        if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
+                            mode = 'graph-attribute-fact' if graph_attribute else 'fact'
+                            override = True if update_mode == 'override' else False
+                            u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override)
+
+                            update = u or update
+                            # Update convergence params
+                            if convergence_mode=='delta_bound':
+                                bound_delta = max(bound_delta, changes)
+                            else:
+                                changes_cnt += changes
+                        # Resolve inconsistency
+                        else:
+                            mode = 'graph-attribute-fact' if graph_attribute else 'fact'
+                            if inconsistency_check:
+                                resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode)
+                            else:
+                                u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True)
+
+                                update = u or update
+                                # Update convergence params
+                                if convergence_mode=='delta_bound':
+                                    bound_delta = max(bound_delta, changes)
+                                else:
+                                    changes_cnt += changes
+
+                    if static:
+                        facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute))
+                        if atom_trace:
+                            facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])
+
+                # Time doesn't match, fact to be applied later
+                else:
+                    facts_to_be_applied_edge_new.append(facts_to_be_applied_edge[i])
+                    if atom_trace:
+                        facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i])
+
+            # Update list of facts with ones that have not been applied yet (delete applied facts)
+            facts_to_be_applied_edge[:] = facts_to_be_applied_edge_new.copy()
+            if atom_trace:
+                facts_to_be_applied_edge_trace[:] = facts_to_be_applied_edge_trace_new.copy()
+            facts_to_be_applied_edge_new.clear()
+            facts_to_be_applied_edge_trace_new.clear()
+
+            in_loop = True
+            while in_loop:
+                # This will become true only if delta_t = 0 for some rule, otherwise we go to the next timestep
+                in_loop = False
+
+                # Apply the rules that need to be applied at this timestep
+                # Nodes
+                rules_to_remove_idx.clear()
+                for idx, i in enumerate(rules_to_be_applied_node):
+                    if i[0] == t:
+                        comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
+                        # Check for inconsistencies
+                        if check_consistent_node(interpretations_node, comp, (l, bnd)):
+                            override = True if update_mode == 'override' else False
+                            u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
+
+                            update = u or update
+                            # Update convergence params
+                            if convergence_mode=='delta_bound':
+                                bound_delta = max(bound_delta, changes)
+                            else:
+                                changes_cnt += changes
+                        # Resolve inconsistency
+                        else:
+                            if inconsistency_check:
+                                resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule')
+                            else:
+                                u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
+
+                                update = u or update
+                                # Update convergence params
+                                if convergence_mode=='delta_bound':
+                                    bound_delta = max(bound_delta, changes)
+                                else:
+                                    changes_cnt += changes
+
+                        # Delete rules that have been applied from list by adding index to list
+                        rules_to_remove_idx.add(idx)
+
+                # Remove from rules to be applied and edges to be applied lists after coming out from loop
+                rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx])
+                edges_to_be_added_node_rule[:] = numba.typed.List([edges_to_be_added_node_rule[i] for i in range(len(edges_to_be_added_node_rule)) if i not in rules_to_remove_idx])
+                if atom_trace:
+                    rules_to_be_applied_node_trace[:] = numba.typed.List([rules_to_be_applied_node_trace[i] for i in range(len(rules_to_be_applied_node_trace)) if i not in rules_to_remove_idx])
+
+                # Edges
+                rules_to_remove_idx.clear()
+                for idx, i in enumerate(rules_to_be_applied_edge):
+                    if i[0] == t:
+                        comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
+                        sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
+                        edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
+                        changes_cnt += changes
+
+                        # Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
+                        if edge_l.value != '':
+                            for e in edges_added:
+                                if interpretations_edge[e].world[edge_l].is_static():
+                                    continue
+                                if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)):
+                                    override = True if update_mode == 'override' else False
+                                    u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
+
+                                    update = u or update
+
+                                    # Update convergence params
+                                    if convergence_mode=='delta_bound':
+                                        bound_delta = max(bound_delta, changes)
+                                    else:
+                                        changes_cnt += changes
+                                # Resolve inconsistency
+                                else:
+                                    if inconsistency_check:
+                                        resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
+                                    else:
+                                        u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
+
+                                        update = u or update
+
+                                        # Update convergence params
+                                        if convergence_mode=='delta_bound':
+                                            bound_delta = max(bound_delta, changes)
+                                        else:
+                                            changes_cnt += changes
+
+                        else:
+                            # Check for inconsistencies
+                            if check_consistent_edge(interpretations_edge, comp, (l, bnd)):
+                                override = True if update_mode == 'override' else False
+                                u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override)
+
+                                update = u or update
+                                # Update convergence params
+                                if convergence_mode=='delta_bound':
+                                    bound_delta = max(bound_delta, changes)
+                                else:
+                                    changes_cnt += changes
+                            # Resolve inconsistency
+                            else:
+                                if inconsistency_check:
+                                    resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
+                                else:
+                                    u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True)
+
+                                    update = u or update
+                                    # Update convergence params
+                                    if convergence_mode=='delta_bound':
+                                        bound_delta = max(bound_delta, changes)
+                                    else:
+                                        changes_cnt += changes
+
+                        # Delete rules that have been applied from list by adding the index to list
+                        rules_to_remove_idx.add(idx)
+
+                # Remove from rules to be applied and edges to be applied lists after coming out from loop
+                rules_to_be_applied_edge[:] = numba.typed.List([rules_to_be_applied_edge[i] for i in range(len(rules_to_be_applied_edge)) if i not in rules_to_remove_idx])
+                edges_to_be_added_edge_rule[:] = numba.typed.List([edges_to_be_added_edge_rule[i] for i in range(len(edges_to_be_added_edge_rule)) if i not in rules_to_remove_idx])
+                if atom_trace:
+                    rules_to_be_applied_edge_trace[:] = numba.typed.List([rules_to_be_applied_edge_trace[i] for i in range(len(rules_to_be_applied_edge_trace)) if i not in rules_to_remove_idx])
+
+                # Fixed point
+                if update:
+                    # Increase fp operator count
+                    fp_cnt += 1
+
+                    # Lists for threadsafe operations (when parallel is on)
+                    rules_to_be_applied_node_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_node_type) for _ in range(len(rules))])
+                    rules_to_be_applied_edge_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_edge_type) for _ in range(len(rules))])
+                    if atom_trace:
+                        rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
+                        rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))])
+                    edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))])
+
+                    for i in prange(len(rules)):
+                        rule = rules[i]
+
+                        # Only go through if the rule can be applied within the given timesteps, or we're running until convergence
+                        delta_t = rule.get_delta()
+                        if t + delta_t <= tmax or tmax == -1 or again:
+                            applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules, num_ga, t)
+
+                            # Loop through applicable rules and add them to the rules to be applied for later or next fp operation
+                            for applicable_rule in applicable_node_rules:
+                                n, annotations, qualified_nodes, qualified_edges, _ = applicable_rule
+                                # If there is an edge to add or the predicate doesn't exist or the interpretation is not static
+                                if rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static():
+                                    bnd = annotate(annotation_functions, rule, annotations, rule.get_weights())
+                                    # Bound annotations in between 0 and 1
+                                    bnd_l = min(max(bnd[0], 0), 1)
+                                    bnd_u = min(max(bnd[1], 0), 1)
+                                    bnd = interval.closed(bnd_l, bnd_u)
+                                    max_rules_time = max(max_rules_time, t + delta_t)
+                                    rules_to_be_applied_node_threadsafe[i].append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, rule.is_static_rule()))
+                                    if atom_trace:
+                                        rules_to_be_applied_node_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))
+
+                                    # If delta_t is zero we apply the rules and check if more are applicable
+                                    if delta_t == 0:
+                                        in_loop = True
+                                        update = False
+
+                            for applicable_rule in applicable_edge_rules:
+                                e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule
+                                # If there is an edge to add or the predicate doesn't exist or the interpretation is not static
+                                if len(edges_to_add[0]) > 0 or rule.get_target() not in interpretations_edge[e].world or not interpretations_edge[e].world[rule.get_target()].is_static():
+                                    bnd = annotate(annotation_functions, rule, annotations, rule.get_weights())
+                                    # Bound annotations in between 0 and 1
+                                    bnd_l = min(max(bnd[0], 0), 1)
+                                    bnd_u = min(max(bnd[1], 0), 1)
+                                    bnd = interval.closed(bnd_l, bnd_u)
+                                    max_rules_time = max(max_rules_time, t+delta_t)
+                                    # edges_to_be_added_edge_rule.append(edges_to_add)
+                                    edges_to_be_added_edge_rule_threadsafe[i].append(edges_to_add)
+                                    rules_to_be_applied_edge_threadsafe[i].append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, rule.is_static_rule()))
+                                    if atom_trace:
+                                        # rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name()))
+                                        rules_to_be_applied_edge_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name()))
+
+                                    # If delta_t is zero we apply the rules and check if more are applicable
+                                    if delta_t == 0:
+                                        in_loop = True
+                                        update = False
+
+                    # Update lists after parallel run
+                    for i in range(len(rules)):
+                        if len(rules_to_be_applied_node_threadsafe[i]) > 0:
+                            rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i])
+                        if len(rules_to_be_applied_edge_threadsafe[i]) > 0:
+                            rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i])
+                        if atom_trace:
+                            if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0:
+                                rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i])
+                            if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0:
+                                rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i])
+                        if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0:
+                            edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i])
+
+            # Check for convergence after each timestep (perfect convergence or convergence specified by user)
+            # Check number of changed interpretations or max bound change
+            # User specified convergence
+            if convergence_mode == 'delta_interpretation':
+                if changes_cnt <= convergence_delta:
+                    if verbose:
+                        print(f'\nConverged at time: {t} with {int(changes_cnt)} changes from the previous interpretation')
+                    # Be consistent with time returned when we don't converge
+                    t += 1
+                    break
+            elif convergence_mode == 'delta_bound':
+                if bound_delta <= convergence_delta:
+                    if verbose:
+                        print(f'\nConverged at time: {t} with {float_to_str(bound_delta)} as the maximum bound change from the previous interpretation')
+                    # Be consistent with time returned when we don't converge
+                    t += 1
+                    break
+            # Perfect convergence
+            # Make sure there are no rules to be applied, and no facts that will be applied in the future. We do this by checking the max time any rule/fact is applicable
+            # If no more rules/facts to be applied
+            elif convergence_mode == 'perfect_convergence':
+                if t>=max_facts_time and t >= max_rules_time:
+                    if verbose:
+                        print(f'\nConverged at time: {t}')
+                    # Be consistent with time returned when we don't converge
+                    t += 1
+                    break
+
+            # Increment t, update number of ground atoms
+            t += 1
+            num_ga.append(num_ga[-1])
+
+        return fp_cnt, t
+
+    def add_edge(self, edge, l):
+        # This function is useful for pyreason gym, called externally
+        _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1)
+
+    def add_node(self, node, labels):
+        # This function is useful for pyreason gym, called externally
+        if node not in self.nodes:
+            _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node)
+            for l in labels:
+                self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1)
+
+    def delete_edge(self, edge):
+        # This function is useful for pyreason gym, called externally
+        _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge, self.predicate_map_edge, self.num_ga)
+
+    def delete_node(self, node):
+        # This function is useful for pyreason gym, called externally
+        _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node, self.predicate_map_node, self.num_ga)
+
+    def get_dict(self):
+        # This function can be called externally to retrieve a dict of the interpretation values
+        # Only values in the rule trace will be added
+
+        # Initialize interpretations for each time and node and edge
+        interpretations = {}
+        for t in range(self.time+1):
+            interpretations[t] = {}
+            for node in self.nodes:
+                interpretations[t][node] = InterpretationDict()
+            for edge in self.edges:
+                interpretations[t][edge] = InterpretationDict()
+
+        # Update interpretation nodes
+        for change in self.rule_trace_node:
+            time, _, node, l, bnd = change
+            interpretations[time][node][l._value] = (bnd.lower, bnd.upper)
+
+            # If persistent, update all following timesteps as well
+            if self.persistent:
+                for t in range(time+1, self.time+1):
+                    interpretations[t][node][l._value] = (bnd.lower, bnd.upper)
+
+        # Update interpretation edges
+        for change in self.rule_trace_edge:
+            time, _, edge, l, bnd = change
+            interpretations[time][edge][l._value] = (bnd.lower, bnd.upper)
+
+            # If persistent, update all following timesteps as well
+            if self.persistent:
+                for t in range(time+1, self.time+1):
+                    interpretations[t][edge][l._value] = (bnd.lower, bnd.upper)
+
+        return interpretations
+
+    def get_final_num_ground_atoms(self):
+        """
+        This function returns the number of ground atoms after the reasoning process, for the final timestep
+        :return: int: Number of ground atoms in the interpretation after reasoning
+        """
+        ga_cnt = 0
+
+        for node in self.nodes:
+            for l in self.interpretations_node[node].world:
+                ga_cnt += 1
+        for edge in self.edges:
+            for l in self.interpretations_edge[edge].world:
+                ga_cnt += 1
+
+        return ga_cnt
+
+    def get_num_ground_atoms(self):
+        """
+        This function returns the number of ground atoms after the reasoning process, for each timestep
+        :return: list: Number of ground atoms in the interpretation after reasoning for each timestep
+        """
+        if self.num_ga[-1] == 0:
+            self.num_ga.pop()
+        return self.num_ga
+
+    def query(self, query, return_bool=True) -> Union[bool, Tuple[float, float]]:
+        """
+        This function is used to query the graph after reasoning
+        :param query: A PyReason query object
+        :param return_bool: If True, returns boolean of query, else the bounds associated with it
+        :return: bool, or bounds
+        """
+
+        comp_type = query.get_component_type()
+        component = query.get_component()
+        pred = query.get_predicate()
+        bnd = query.get_bounds()
+
+        # Check if the component exists
+        if comp_type == 'node':
+            if component not in self.nodes:
+                return False if return_bool else (0, 0)
+        else:
+            if component not in self.edges:
+                return False if return_bool else (0, 0)
+
+        # Check if the predicate exists
+        if comp_type == 'node':
+            if pred not in self.interpretations_node[component].world:
+                return False if return_bool else (0, 0)
+        else:
+            if pred not in self.interpretations_edge[component].world:
+                return False if return_bool else (0, 0)
+
+        # Check if the bounds are satisfied
+        if comp_type == 'node':
+            if self.interpretations_node[component].world[pred] in bnd:
+                return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper)
+            else:
+                return False if return_bool else (0, 0)
+        else:
+            if self.interpretations_edge[component].world[pred] in bnd:
+                return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper)
+            else:
+                return False if return_bool else (0, 0)
+
+
+@numba.njit(cache=True)
+def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules, num_ga, t):
+    # Extract rule params
+    rule_type = rule.get_type()
+    head_variables = rule.get_head_variables()
+    clauses = rule.get_clauses()
+    thresholds = rule.get_thresholds()
+    ann_fn = rule.get_annotation_function()
+    rule_edges = rule.get_edges()
+
+    if rule_type == 'node':
+        head_var_1 = head_variables[0]
+    else:
+        head_var_1, head_var_2 = head_variables[0], head_variables[1]
+
+    # We return a list of tuples which specify the target nodes/edges that have made the rule body true
+    applicable_rules_node = numba.typed.List.empty_list(node_applicable_rule_type)
+    applicable_rules_edge = numba.typed.List.empty_list(edge_applicable_rule_type)
+
+    # Grounding procedure
+    # 1. Go through each clause and check which variables have not been initialized in groundings
+    # 2. Check satisfaction of variables based on the predicate in the clause
+
+    # Grounding variable that maps variables in the body to a list of grounded nodes
+    # Grounding edges that maps edge variables to a list of edges
+    groundings = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes)
+    groundings_edges = numba.typed.Dict.empty(key_type=edge_type, value_type=list_of_edges)
+
+    # Dependency graph that keeps track of the connections between the variables in the body
+    dependency_graph_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+    dependency_graph_reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes)
+
+    nodes_set = set(nodes)
+    edges_set = set(edges)
+
+    satisfaction = True
+    for i, clause in enumerate(clauses):
+        # Unpack clause variables
+        clause_type = clause[0]
+        clause_label = clause[1]
+        clause_variables = clause[2]
+        clause_bnd = clause[3]
+        clause_operator = clause[4]
+
+        # This is a node clause
+        if clause_type == 'node':
+            clause_var_1 = clause_variables[0]
+
+            # Get subset of nodes that can be used to ground the variable
+            # If we allow ground atoms, we can use the nodes directly
+            if allow_ground_rules and clause_var_1 in nodes_set:
+                grounding = numba.typed.List([clause_var_1])
+            else:
+                grounding = get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map_node, clause_label, nodes)
+
+            # Narrow subset based on predicate
+            qualified_groundings = get_qualified_node_groundings(interpretations_node, grounding, clause_label, clause_bnd)
+            groundings[clause_var_1] = qualified_groundings
+            qualified_groundings_set = set(qualified_groundings)
+            for c1, c2 in groundings_edges:
+                if c1 == clause_var_1:
+                    groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[0] in qualified_groundings_set])
+                if c2 == clause_var_1:
+                    groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[1] in qualified_groundings_set])
+
+            # Check satisfaction of those nodes wrt the threshold
+            # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
+            # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
+            # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
+            satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction
+
+        # This is an edge clause
+        elif clause_type == 'edge':
+            clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+
+            # Get subset of edges that can be used to ground the variables
+            # If we allow ground atoms, we can use the nodes directly
+            if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set:
+                grounding = numba.typed.List([(clause_var_1, clause_var_2)])
+            else:
+                grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges)
+
+            # Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster)
+            qualified_groundings = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, clause_bnd)
+
+            # Check satisfaction of those edges wrt the threshold
+            # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule
+            # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges
+            # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0:
+            satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction
+
+            # Update the groundings
+            groundings[clause_var_1] = numba.typed.List.empty_list(node_type)
+            groundings[clause_var_2] = numba.typed.List.empty_list(node_type)
+            groundings_clause_1_set = set(groundings[clause_var_1])
+            groundings_clause_2_set = set(groundings[clause_var_2])
+            for e in qualified_groundings:
+                if e[0] not in groundings_clause_1_set:
+                    groundings[clause_var_1].append(e[0])
+                    groundings_clause_1_set.add(e[0])
+                if e[1] not in groundings_clause_2_set:
+                    groundings[clause_var_2].append(e[1])
+                    groundings_clause_2_set.add(e[1])
+
+            # Update the edge groundings (to use later for grounding other clauses with the same variables)
+            groundings_edges[(clause_var_1, clause_var_2)] = qualified_groundings
+
+            # Update dependency graph
+            # Add a connection between clause_var_1 -> clause_var_2 and vice versa
+            if clause_var_1 not in dependency_graph_neighbors:
+                dependency_graph_neighbors[clause_var_1] = numba.typed.List([clause_var_2])
+            elif clause_var_2 not in dependency_graph_neighbors[clause_var_1]:
+                dependency_graph_neighbors[clause_var_1].append(clause_var_2)
+            if clause_var_2 not in dependency_graph_reverse_neighbors:
+                dependency_graph_reverse_neighbors[clause_var_2] = numba.typed.List([clause_var_1])
+            elif clause_var_1 not in dependency_graph_reverse_neighbors[clause_var_2]:
+                dependency_graph_reverse_neighbors[clause_var_2].append(clause_var_1)
+
+        # This is a comparison clause
+        else:
+            pass
+
+        # Refine the subsets based on any updates
+        if satisfaction:
+            refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)
+
+        # If satisfaction is false, break
+        if not satisfaction:
+            break
+
+    # If satisfaction is still true, one final refinement to check if each edge pair is valid in edge rules
+    # Then continue to setup any edges to be added and annotations
+    # Fill out the rules to be applied lists
+    if satisfaction:
+        # Create temp grounding containers to verify if the head groundings are valid (only for edge rules)
+        # Setup edges to be added and fill rules to be applied
+        # Setup traces and inputs for annotation function
+        # Loop through the clause data and setup final annotations and trace variables
+        # Three cases: 1. node rule, 2. edge rule with infer edges, 3. edge rule
+        if rule_type == 'node':
+            # Loop through all the head variable groundings and add it to the rules to be applied
+            # Loop through the clauses and add appropriate trace data and annotations
+
+            # If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph
+            head_var_1_in_nodes = head_var_1 in nodes
+            add_head_var_node_to_graph = False
+            if allow_ground_rules and head_var_1_in_nodes:
+                groundings[head_var_1] = numba.typed.List([head_var_1])
+            elif head_var_1 not in groundings:
+                if not head_var_1_in_nodes:
+                    add_head_var_node_to_graph = True
+                groundings[head_var_1] = numba.typed.List([head_var_1])
+
+            for head_grounding in groundings[head_var_1]:
+                qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+                qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+                annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+                edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+
+                # Check for satisfaction one more time in case the refining process has changed the groundings
+                satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges)
+                if not satisfaction:
+                    continue
+
+                for i, clause in enumerate(clauses):
+                    clause_type = clause[0]
+                    clause_label = clause[1]
+                    clause_variables = clause[2]
+
+                    if clause_type == 'node':
+                        clause_var_1 = clause_variables[0]
+
+                        # 1.
+                        if atom_trace:
+                            if clause_var_1 == head_var_1:
+                                qualified_nodes.append(numba.typed.List([head_grounding]))
+                            else:
+                                qualified_nodes.append(numba.typed.List(groundings[clause_var_1]))
+                            qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                        # 2.
+                        if ann_fn != '':
+                            a = numba.typed.List.empty_list(interval.interval_type)
+                            if clause_var_1 == head_var_1:
+                                a.append(interpretations_node[head_grounding].world[clause_label])
+                            else:
+                                for qn in groundings[clause_var_1]:
+                                    a.append(interpretations_node[qn].world[clause_label])
+                            annotations.append(a)
+
+                    elif clause_type == 'edge':
+                        clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+                        # 1.
+                        if atom_trace:
+                            # Cases: Both equal, one equal, none equal
+                            qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                            if clause_var_1 == head_var_1:
+                                es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_grounding])
+                                qualified_edges.append(es)
+                            elif clause_var_2 == head_var_1:
+                                es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_grounding])
+                                qualified_edges.append(es)
+                            else:
+                                qualified_edges.append(numba.typed.List(groundings_edges[(clause_var_1, clause_var_2)]))
+                        # 2.
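+                        # Sketch of what case 2 collects (assumed example): if this edge
+                        # clause grounds to [('a', 'b'), ('a', 'c')] and neither clause
+                        # variable is the head variable, both edges' current intervals for
+                        # clause_label are appended so the annotation function can
+                        # aggregate them.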
+                        if ann_fn != '':
+                            a = numba.typed.List.empty_list(interval.interval_type)
+                            if clause_var_1 == head_var_1:
+                                for e in groundings_edges[(clause_var_1, clause_var_2)]:
+                                    if e[0] == head_grounding:
+                                        a.append(interpretations_edge[e].world[clause_label])
+                            elif clause_var_2 == head_var_1:
+                                for e in groundings_edges[(clause_var_1, clause_var_2)]:
+                                    if e[1] == head_grounding:
+                                        a.append(interpretations_edge[e].world[clause_label])
+                            else:
+                                for qe in groundings_edges[(clause_var_1, clause_var_2)]:
+                                    a.append(interpretations_edge[qe].world[clause_label])
+                            annotations.append(a)
+                    else:
+                        # Comparison clause (we do not handle for now)
+                        pass
+
+                # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+                if add_head_var_node_to_graph:
+                    _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+
+                # For each grounding add a rule to be applied
+                applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+        elif rule_type == 'edge':
+            head_var_1 = head_variables[0]
+            head_var_2 = head_variables[1]
+
+            # If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph
+            head_var_1_in_nodes = head_var_1 in nodes
+            head_var_2_in_nodes = head_var_2 in nodes
+            add_head_var_1_node_to_graph = False
+            add_head_var_2_node_to_graph = False
+            add_head_edge_to_graph = False
+            if allow_ground_rules and head_var_1_in_nodes:
+                groundings[head_var_1] = numba.typed.List([head_var_1])
+            if allow_ground_rules and head_var_2_in_nodes:
+                groundings[head_var_2] = numba.typed.List([head_var_2])
+
+            if head_var_1 not in groundings:
+                if not head_var_1_in_nodes:
+                    add_head_var_1_node_to_graph = True
+                groundings[head_var_1] = numba.typed.List([head_var_1])
+            if head_var_2 not in groundings:
+                if not head_var_2_in_nodes:
+                    add_head_var_2_node_to_graph = True
+                groundings[head_var_2] = numba.typed.List([head_var_2])
+
+            # Artificially connect the head variables with an edge if both of them were not in the graph
+            if not head_var_1_in_nodes and not head_var_2_in_nodes:
+                add_head_edge_to_graph = True
+
+            head_var_1_groundings = groundings[head_var_1]
+            head_var_2_groundings = groundings[head_var_2]
+
+            source, target, _ = rule_edges
+            infer_edges = True if source != '' and target != '' else False
+
+            # Prepare the edges that we will loop over.
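+            # (Illustrative sketch with assumed values: for head groundings {'a'} and
+            # {'b', 'c'}, infer_edges produces the pairs ('a', 'b') and ('a', 'c');
+            # without infer_edges only pairs already present in edges_set survive.)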
+            # For infer edges we loop over each combination pair
+            # Else we loop over the valid edges in the graph
+            valid_edge_groundings = numba.typed.List.empty_list(edge_type)
+            for g1 in head_var_1_groundings:
+                for g2 in head_var_2_groundings:
+                    if infer_edges:
+                        valid_edge_groundings.append((g1, g2))
+                    else:
+                        if (g1, g2) in edges_set:
+                            valid_edge_groundings.append((g1, g2))
+
+            # Loop through the head variable groundings
+            for valid_e in valid_edge_groundings:
+                head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1]
+                qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
+                qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
+                annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type))
+                edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1])
+
+                # Containers to keep track of groundings to make sure that the edge pair is valid
+                # We do this because we cannot know beforehand the edge matches from source groundings to target groundings
+                temp_groundings = groundings.copy()
+                temp_groundings_edges = groundings_edges.copy()
+
+                # Refine the temp groundings for the specific edge head grounding
+                # We update the edge collection as well depending on if there's a match between the clause variables and head variables
+                temp_groundings[head_var_1] = numba.typed.List([head_var_1_grounding])
+                temp_groundings[head_var_2] = numba.typed.List([head_var_2_grounding])
+                for c1, c2 in temp_groundings_edges.keys():
+                    if c1 == head_var_1 and c2 == head_var_2:
+                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_1_grounding, head_var_2_grounding)])
+                    elif c1 == head_var_2 and c2 == head_var_1:
+                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_2_grounding, head_var_1_grounding)])
+                    elif c1 == head_var_1:
+                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_1_grounding])
+                    elif c2 == head_var_1:
+                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_1_grounding])
+                    elif c1 == head_var_2:
+                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_2_grounding])
+                    elif c2 == head_var_2:
+                        temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_2_grounding])
+
+                refine_groundings(head_variables, temp_groundings, temp_groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors)
+
+                # Check if the thresholds are still satisfied
+                # Check if all clauses are satisfied again in case the refining process changed anything
+                satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, temp_groundings, temp_groundings_edges)
+
+                if not satisfaction:
+                    continue
+
+                if infer_edges:
+                    # Prevent self loops while inferring edges if the clause variables are not the same
+                    if source != target and head_var_1_grounding == head_var_2_grounding:
+                        continue
+                    edges_to_be_added[0].append(head_var_1_grounding)
+                    edges_to_be_added[1].append(head_var_2_grounding)
+
+                for i, clause in enumerate(clauses):
+                    clause_type = clause[0]
+                    clause_label = clause[1]
+                    clause_variables = clause[2]
+
+                    if clause_type == 'node':
+                        clause_var_1 = clause_variables[0]
+                        # 1.
+                        if atom_trace:
+                            if clause_var_1 == head_var_1:
+                                qualified_nodes.append(numba.typed.List([head_var_1_grounding]))
+                            elif clause_var_1 == head_var_2:
+                                qualified_nodes.append(numba.typed.List([head_var_2_grounding]))
+                            else:
+                                qualified_nodes.append(numba.typed.List(temp_groundings[clause_var_1]))
+                            qualified_edges.append(numba.typed.List.empty_list(edge_type))
+                        # 2.
+                        if ann_fn != '':
+                            a = numba.typed.List.empty_list(interval.interval_type)
+                            if clause_var_1 == head_var_1:
+                                a.append(interpretations_node[head_var_1_grounding].world[clause_label])
+                            elif clause_var_1 == head_var_2:
+                                a.append(interpretations_node[head_var_2_grounding].world[clause_label])
+                            else:
+                                for qn in temp_groundings[clause_var_1]:
+                                    a.append(interpretations_node[qn].world[clause_label])
+                            annotations.append(a)
+
+                    elif clause_type == 'edge':
+                        clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+                        # 1.
+                        if atom_trace:
+                            # Cases:
+                            # 1. Both equal (cv1 = hv1 and cv2 = hv2 or cv1 = hv2 and cv2 = hv1)
+                            # 2. One equal (cv1 = hv1 or cv2 = hv1 or cv1 = hv2 or cv2 = hv2)
+                            # 3. None equal
+                            qualified_nodes.append(numba.typed.List.empty_list(node_type))
+                            if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding])
+                                qualified_edges.append(es)
+                            elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding])
+                                qualified_edges.append(es)
+                            elif clause_var_1 == head_var_1:
+                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding])
+                                qualified_edges.append(es)
+                            elif clause_var_1 == head_var_2:
+                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding])
+                                qualified_edges.append(es)
+                            elif clause_var_2 == head_var_1:
+                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_1_grounding])
+                                qualified_edges.append(es)
+                            elif clause_var_2 == head_var_2:
+                                es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_2_grounding])
+                                qualified_edges.append(es)
+                            else:
+                                qualified_edges.append(numba.typed.List(temp_groundings_edges[(clause_var_1, clause_var_2)]))
+
+                        # 2.
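+                        # As in the node-rule case above, a sketch with assumed values:
+                        # the intervals gathered for the annotation function are filtered
+                        # by whichever clause variable coincides with a head variable,
+                        # e.g. only edges whose source equals head_var_1_grounding
+                        # contribute when clause_var_1 == head_var_1.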
+                        if ann_fn != '':
+                            a = numba.typed.List.empty_list(interval.interval_type)
+                            if clause_var_1 == head_var_1 and clause_var_2 == head_var_2:
+                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                    if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding:
+                                        a.append(interpretations_edge[e].world[clause_label])
+                            elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1:
+                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                    if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding:
+                                        a.append(interpretations_edge[e].world[clause_label])
+                            elif clause_var_1 == head_var_1:
+                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                    if e[0] == head_var_1_grounding:
+                                        a.append(interpretations_edge[e].world[clause_label])
+                            elif clause_var_1 == head_var_2:
+                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                    if e[0] == head_var_2_grounding:
+                                        a.append(interpretations_edge[e].world[clause_label])
+                            elif clause_var_2 == head_var_1:
+                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                    if e[1] == head_var_1_grounding:
+                                        a.append(interpretations_edge[e].world[clause_label])
+                            elif clause_var_2 == head_var_2:
+                                for e in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                    if e[1] == head_var_2_grounding:
+                                        a.append(interpretations_edge[e].world[clause_label])
+                            else:
+                                for qe in temp_groundings_edges[(clause_var_1, clause_var_2)]:
+                                    a.append(interpretations_edge[qe].world[clause_label])
+                            annotations.append(a)
+
+                # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules)
+                if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1:
+                    _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node)
+                if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2:
+                    _add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node)
+                if add_head_edge_to_graph and (head_var_1, head_var_2) == (head_var_1_grounding, head_var_2_grounding):
+                    _add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
+
+                # For each grounding combination add a rule to be applied
+                # Only if all the clauses have valid groundings
+                # if satisfaction:
+                e = (head_var_1_grounding, head_var_2_grounding)
+                applicable_rules_edge.append((e, annotations, qualified_nodes, qualified_edges, edges_to_be_added))
+
+    # Return the applicable rules
+    return applicable_rules_node, applicable_rules_edge
+
+
+@numba.njit(cache=True)
+def check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges):
+    # Check if the thresholds are satisfied for each clause
+    satisfaction = True
+    for i, clause in enumerate(clauses):
+        # Unpack clause variables
+        clause_type = clause[0]
+        clause_label = clause[1]
+        clause_variables = clause[2]
+
+        if clause_type == 'node':
+            clause_var_1 = clause_variables[0]
+            satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, groundings[clause_var_1], groundings[clause_var_1], clause_label, thresholds[i]) and satisfaction
+        elif clause_type == 'edge':
+            clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1]
+            satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, groundings_edges[(clause_var_1, clause_var_2)], groundings_edges[(clause_var_1, clause_var_2)], clause_label, thresholds[i]) and satisfaction
+    return satisfaction
+
+
+@numba.njit(cache=True)
+def refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors):
+    # Loop through the dependency graph and refine the groundings that have connections
+    all_variables_refined = numba.typed.List(clause_variables)
+    variables_just_refined = numba.typed.List(clause_variables)
+    new_variables_refined = numba.typed.List.empty_list(numba.types.string)
+    while len(variables_just_refined) > 0:
+        for refined_variable in variables_just_refined:
+            # Refine all the neighbors of the refined variable
+            if refined_variable in dependency_graph_neighbors:
+                for neighbor in dependency_graph_neighbors[refined_variable]:
+                    old_edge_groundings = groundings_edges[(refined_variable, neighbor)]
+                    new_node_groundings = groundings[refined_variable]
+
+                    # Delete old groundings for the variable being refined
+                    del groundings[neighbor]
+                    groundings[neighbor] = numba.typed.List.empty_list(node_type)
+
+                    # Update the edge groundings and node groundings
+                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[0] in new_node_groundings])
+                    groundings_neighbor_set = set(groundings[neighbor])
+                    for e in qualified_groundings:
+                        if e[1] not in groundings_neighbor_set:
+                            groundings[neighbor].append(e[1])
+                            groundings_neighbor_set.add(e[1])
+                    groundings_edges[(refined_variable, neighbor)] = qualified_groundings
+
+                    # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
+                    if neighbor not in all_variables_refined:
+                        new_variables_refined.append(neighbor)
+
+            if refined_variable in dependency_graph_reverse_neighbors:
+                for reverse_neighbor in dependency_graph_reverse_neighbors[refined_variable]:
+                    old_edge_groundings = groundings_edges[(reverse_neighbor, refined_variable)]
+                    new_node_groundings = groundings[refined_variable]
+
+                    # Delete old groundings for the variable being refined
+                    del groundings[reverse_neighbor]
+                    groundings[reverse_neighbor] = numba.typed.List.empty_list(node_type)
+
+                    # Update the edge groundings and node groundings
+                    qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[1] in new_node_groundings])
+                    groundings_reverse_neighbor_set = set(groundings[reverse_neighbor])
+                    for e in qualified_groundings:
+                        if e[0] not in groundings_reverse_neighbor_set:
+                            groundings[reverse_neighbor].append(e[0])
+                            groundings_reverse_neighbor_set.add(e[0])
+                    groundings_edges[(reverse_neighbor, refined_variable)] = qualified_groundings
+
+                    # Add the neighbor to the list of refined variables so that we can refine for all its neighbors
+                    if reverse_neighbor not in all_variables_refined:
+                        new_variables_refined.append(reverse_neighbor)
+
+        variables_just_refined = numba.typed.List(new_variables_refined)
+        all_variables_refined.extend(new_variables_refined)
+        new_variables_refined.clear()
+
+
+@numba.njit(cache=True)
+def check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_grounding, clause_label, threshold):
+    threshold_quantifier_type = threshold[1][1]
+    if threshold_quantifier_type == 'total':
+        neigh_len = len(grounding)
+
+    # Available is all neighbors that have a particular label with bound inside [0,1]
+    elif threshold_quantifier_type == 'available':
+        neigh_len = len(get_qualified_node_groundings(interpretations_node, grounding, clause_label, interval.closed(0, 1)))
+
+    qualified_neigh_len = len(qualified_grounding)
+    satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold)
+    return satisfaction
+
+
+@numba.njit(cache=True)
+def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_grounding, clause_label, threshold):
+    threshold_quantifier_type = threshold[1][1]
+    if threshold_quantifier_type == 'total':
+        neigh_len = len(grounding)
+
+    # Available is all neighbors that have a particular label with bound inside [0,1]
+    elif threshold_quantifier_type == 'available':
+        neigh_len = len(get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, interval.closed(0, 1)))
+
+    qualified_neigh_len = len(qualified_grounding)
+    satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold)
+    return satisfaction
+
+
+@numba.njit(cache=True)
+def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes):
+    # The groundings for a node clause can be either a previous grounding or all possible nodes
+    if l in predicate_map:
+        grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1]
+    else:
+        grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1]
+    return grounding
+
+
+@numba.njit(cache=True)
+def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges):
+    # There are 4 cases for predicate(Y,Z):
+    # 1. Both predicate variables Y and Z have not been encountered before
+    # 2. The source variable Y has not been encountered before but the target variable Z has
+    # 3. The target variable Z has not been encountered before but the source variable Y has
+    # 4. Both predicate variables Y and Z have been encountered before
+    edge_groundings = numba.typed.List.empty_list(edge_type)
+
+    # Case 1:
+    # We replace Y by all nodes and Z by the neighbors of each of these nodes
+    if clause_var_1 not in groundings and clause_var_2 not in groundings:
+        if l in predicate_map:
+            edge_groundings = predicate_map[l]
+        else:
+            edge_groundings = edges
+
+    # Case 2:
+    # We replace Y by the sources of Z
+    elif clause_var_1 not in groundings and clause_var_2 in groundings:
+        for n in groundings[clause_var_2]:
+            es = numba.typed.List([(nn, n) for nn in reverse_neighbors[n]])
+            edge_groundings.extend(es)
+
+    # Case 3:
+    # We replace Z by the neighbors of Y
+    elif clause_var_1 in groundings and clause_var_2 not in groundings:
+        for n in groundings[clause_var_1]:
+            es = numba.typed.List([(n, nn) for nn in neighbors[n]])
+            edge_groundings.extend(es)
+
+    # Case 4:
+    # We have seen both variables before
+    else:
+        # We have already seen these two variables in an edge clause
+        if (clause_var_1, clause_var_2) in groundings_edges:
+            edge_groundings = groundings_edges[(clause_var_1, clause_var_2)]
+        # We have seen both these variables but not in an edge clause together
+        else:
+            groundings_clause_var_2_set = set(groundings[clause_var_2])
+            for n in groundings[clause_var_1]:
+                es = numba.typed.List([(n, nn) for nn in neighbors[n] if nn in groundings_clause_var_2_set])
+                edge_groundings.extend(es)
+
+    return edge_groundings
+
+
+@numba.njit(cache=True)
+def get_qualified_node_groundings(interpretations_node, grounding, clause_l, clause_bnd):
+    # Filter the grounding by the predicate and bound of the clause
+    qualified_groundings = numba.typed.List.empty_list(node_type)
+    for n in grounding:
+        if is_satisfied_node(interpretations_node, n, (clause_l, clause_bnd)):
+            qualified_groundings.append(n)
+
+    return qualified_groundings
+
+
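+# A minimal usage sketch (hypothetical, kept as a comment so the module's
+# import-time behaviour is unchanged): both "qualified groundings" helpers filter
+# candidates by whether the component's current interval for the clause label
+# sits inside the clause bound.
+#
+#   candidates = numba.typed.List([('a', 'b'), ('a', 'c')])
+#   qualified = get_qualified_edge_groundings(
+#       interpretations_edge, candidates, clause_label, interval.closed(0.5, 1))
+#   # keeps only edges whose bound for clause_label lies within [0.5, 1]
+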
+@numba.njit(cache=True) +def get_qualified_edge_groundings(interpretations_edge, grounding, clause_l, clause_bnd): + # Filter the grounding by the predicate and bound of the clause + qualified_groundings = numba.typed.List.empty_list(edge_type) + for e in grounding: + if is_satisfied_edge(interpretations_edge, e, (clause_l, clause_bnd)): + qualified_groundings.append(e) + + return qualified_groundings + + +@numba.njit(cache=True) +def _satisfies_threshold(num_neigh, num_qualified_component, threshold): + # Checks if qualified neighbors satisfy threshold. This is for one clause + if threshold[1][0]=='number': + if threshold[0]=='greater_equal': + result = True if num_qualified_component >= threshold[2] else False + elif threshold[0]=='greater': + result = True if num_qualified_component > threshold[2] else False + elif threshold[0]=='less_equal': + result = True if num_qualified_component <= threshold[2] else False + elif threshold[0]=='less': + result = True if num_qualified_component < threshold[2] else False + elif threshold[0]=='equal': + result = True if num_qualified_component == threshold[2] else False + + elif threshold[1][0]=='percent': + if num_neigh==0: + result = False + elif threshold[0]=='greater_equal': + result = True if num_qualified_component/num_neigh >= threshold[2]*0.01 else False + elif threshold[0]=='greater': + result = True if num_qualified_component/num_neigh > threshold[2]*0.01 else False + elif threshold[0]=='less_equal': + result = True if num_qualified_component/num_neigh <= threshold[2]*0.01 else False + elif threshold[0]=='less': + result = True if num_qualified_component/num_neigh < threshold[2]*0.01 else False + elif threshold[0]=='equal': + result = True if num_qualified_component/num_neigh == threshold[2]*0.01 else False + + return result + + +@numba.njit(cache=True) +def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False): + updated = False + # This is to prevent a key error in case the label is a specific label + try: + world = interpretations[comp] + l, bnd = na + updated_bnds = numba.typed.List.empty_list(interval.interval_type) + + # Add label to world if it is not there + if l not in world.world: + world.world[l] = interval.closed(0, 1) + num_ga[t_cnt] += 1 + if l in predicate_map: + predicate_map[l].append(comp) + else: + predicate_map[l] = numba.typed.List([comp]) + + # Check if update is necessary with previous bnd + prev_bnd = world.world[l].copy() + + # override will not check for inconsistencies + if override: + world.world[l].set_lower_upper(bnd.lower, bnd.upper) + else: + world.update(l, bnd) + world.world[l].set_static(static) + if world.world[l]!=prev_bnd: + updated = True + updated_bnds.append(world.world[l]) + + # Add to rule trace if update happened and add to atom trace if necessary + if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) + if atom_trace: + # Mode can be fact or rule, updation of trace will happen accordingly + if mode=='fact' or mode=='graph-attribute-fact': + qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) + name 
= facts_to_be_applied_trace[idx] + _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name) + elif mode=='rule': + qn, qe, name = rules_to_be_applied_trace[idx] + _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name) + + # Update complement of predicate (if exists) based on new knowledge of predicate + if updated: + ip_update_cnt = 0 + for p1, p2 in ipl: + if p1 == l: + if p2 not in world.world: + world.world[p2] = interval.closed(0, 1) + if p2 in predicate_map: + predicate_map[p2].append(comp) + else: + predicate_map[p2] = numba.typed.List([comp]) + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') + lower = max(world.world[p2].lower, 1 - world.world[p1].upper) + upper = min(world.world[p2].upper, 1 - world.world[p1].lower) + world.world[p2].set_lower_upper(lower, upper) + world.world[p2].set_static(static) + ip_update_cnt += 1 + updated_bnds.append(world.world[p2]) + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) + if p2 == l: + if p1 not in world.world: + world.world[p1] = interval.closed(0, 1) + if p1 in predicate_map: + predicate_map[p1].append(comp) + else: + predicate_map[p1] = numba.typed.List([comp]) + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') + lower = max(world.world[p1].lower, 1 - world.world[p2].upper) + upper = min(world.world[p1].upper, 1 - world.world[p2].lower) + world.world[p1].set_lower_upper(lower, upper) + world.world[p1].set_static(static) + ip_update_cnt += 1 + updated_bnds.append(world.world[p1]) + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper))) + + # Gather convergence data + change = 0 + if updated: + # Find out if it has changed from previous interp + current_bnd = world.world[l] + prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) + if current_bnd != prev_t_bnd: + if convergence_mode=='delta_bound': + for i in updated_bnds: + lower_delta = abs(i.lower-prev_t_bnd.lower) + upper_delta = abs(i.upper-prev_t_bnd.upper) + max_delta = max(lower_delta, upper_delta) + change = max(change, max_delta) + else: + change = 1 + ip_update_cnt + + return (updated, change) + + except: + return (False, 0) + + +@numba.njit(cache=True) +def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False): + updated = False + # This is to prevent a key error in case the label is a specific label + try: + world = interpretations[comp] + l, bnd = na + updated_bnds = numba.typed.List.empty_list(interval.interval_type) + + # Add label to world if it is not there + if l not in world.world: + world.world[l] = interval.closed(0, 1) + num_ga[t_cnt] += 1 + if l in predicate_map: + predicate_map[l].append(comp) + else: + predicate_map[l] = numba.typed.List([comp]) + + # Check if update is necessary with previous bnd + prev_bnd = world.world[l].copy() 
+ + # override will not check for inconsistencies + if override: + world.world[l].set_lower_upper(bnd.lower, bnd.upper) + else: + world.update(l, bnd) + world.world[l].set_static(static) + if world.world[l]!=prev_bnd: + updated = True + updated_bnds.append(world.world[l]) + + # Add to rule trace if update happened and add to atom trace if necessary + if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) + if atom_trace: + # Mode can be fact or rule, updation of trace will happen accordingly + if mode=='fact' or mode=='graph-attribute-fact': + qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) + qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) + name = facts_to_be_applied_trace[idx] + _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name) + elif mode=='rule': + qn, qe, name = rules_to_be_applied_trace[idx] + _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name) + + # Update complement of predicate (if exists) based on new knowledge of predicate + if updated: + ip_update_cnt = 0 + for p1, p2 in ipl: + if p1 == l: + if p2 not in world.world: + world.world[p2] = interval.closed(0, 1) + if p2 in predicate_map: + predicate_map[p2].append(comp) + else: + predicate_map[p2] = numba.typed.List([comp]) + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') + lower = max(world.world[p2].lower, 1 - world.world[p1].upper) + upper = min(world.world[p2].upper, 1 - world.world[p1].lower) + world.world[p2].set_lower_upper(lower, upper) + world.world[p2].set_static(static) + ip_update_cnt += 1 + updated_bnds.append(world.world[p2]) + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) + if p2 == l: + if p1 not in world.world: + world.world[p1] = interval.closed(0, 1) + if p1 in predicate_map: + predicate_map[p1].append(comp) + else: + predicate_map[p1] = numba.typed.List([comp]) + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') + lower = max(world.world[p1].lower, 1 - world.world[p2].upper) + upper = min(world.world[p1].upper, 1 - world.world[p2].lower) + world.world[p1].set_lower_upper(lower, upper) + world.world[p1].set_static(static) + ip_update_cnt += 1 + updated_bnds.append(world.world[p1]) + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper))) + + # Gather convergence data + change = 0 + if updated: + # Find out if it has changed from previous interp + current_bnd = world.world[l] + prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) + if current_bnd != prev_t_bnd: + if convergence_mode=='delta_bound': + for i in updated_bnds: + lower_delta = abs(i.lower-prev_t_bnd.lower) + upper_delta = abs(i.upper-prev_t_bnd.upper) + max_delta = max(lower_delta, upper_delta) + change = max(change, max_delta) + else: + change = 1 + ip_update_cnt + + return (updated, change) + except: + return (False, 0) + + 
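A minimal sketch (illustration only, not part of this patch) of the inverse-predicate-list (IPL) bound tightening performed in the p1/p2 branches of _update_node and _update_edge above, assuming each (p1, p2) pair encodes the relationship p2 = 1 - p1:

# Given p1's refined bounds, tighten p2's bounds so the pair stays
# consistent with p2 = 1 - p1; this mirrors the lower/upper updates
# in the IPL blocks above.
def complement_bounds_sketch(p1_bounds, p2_bounds):
    p1_lower, p1_upper = p1_bounds
    p2_lower, p2_upper = p2_bounds
    lower = max(p2_lower, 1 - p1_upper)
    upper = min(p2_upper, 1 - p1_lower)
    return lower, upper

# e.g. if p1 is tightened to [0.8, 1.0] while p2 is still [0, 1]:
# complement_bounds_sketch((0.8, 1.0), (0.0, 1.0)) -> (0.0, 0.2)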
+@numba.njit(cache=True) +def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): + rule_trace.append((qn, qe, prev_bnd.copy(), name)) + + +@numba.njit(cache=True) +def are_satisfied_node(interpretations, comp, nas): + result = True + for (l, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (l, bnd)) + return result + + +@numba.njit(cache=True) +def is_satisfied_node(interpretations, comp, na): + result = False + if not (na[0] is None or na[1] is None): + # This is to prevent a key error in case the label is a specific label + try: + world = interpretations[comp] + result = world.is_satisfied(na[0], na[1]) + except: + result = False + else: + result = True + return result + + +@numba.njit(cache=True) +def is_satisfied_node_comparison(interpretations, comp, na): + result = False + number = 0 + l, bnd = na + l_str = l.value + + if not (l is None or bnd is None): + # This is to prevent a key error in case the label is a specific label + try: + world = interpretations[comp] + for world_l in world.world.keys(): + world_l_str = world_l.value + if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): + # The label is contained in the world + result = world.is_satisfied(world_l, na[1]) + # Find the suffix number + number = str_to_float(world_l_str[len(l_str)+1:]) + break + + except: + result = False + else: + result = True + return result, number + + +@numba.njit(cache=True) +def are_satisfied_edge(interpretations, comp, nas): + result = True + for (l, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) + return result + + +@numba.njit(cache=True) +def is_satisfied_edge(interpretations, comp, na): + result = False + if not (na[0] is None or na[1] is None): + # This is to prevent a key error in case the label is a specific label + try: + world = interpretations[comp] + result = world.is_satisfied(na[0], na[1]) + except: + result = False + else: + result = True + return result + + +@numba.njit(cache=True) +def is_satisfied_edge_comparison(interpretations, comp, na): + result = False + number = 0 + l, bnd = na + l_str = l.value + + if not (l is None or bnd is None): + # This is to prevent a key error in case the label is a specific label + try: + world = interpretations[comp] + for world_l in world.world.keys(): + world_l_str = world_l.value + if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): + # The label is contained in the world + result = world.is_satisfied(world_l, na[1]) + # Find the suffix number + number = str_to_float(world_l_str[len(l_str)+1:]) + break + + except: + result = False + else: + result = True + return result, number + + +@numba.njit(cache=True) +def annotate(annotation_functions, rule, annotations, weights): + func_name = rule.get_annotation_function() + if func_name == '': + return rule.get_bnd().lower, rule.get_bnd().upper + else: + with numba.objmode(annotation='Tuple((float64, float64))'): + for func in annotation_functions: + if func.__name__ == func_name: + annotation = func(annotations, weights) + return annotation + + +@numba.njit(cache=True) +def check_consistent_node(interpretations, comp, na): + world = interpretations[comp] + if na[0] in world.world: + bnd = world.world[na[0]] + else: + bnd = interval.closed(0, 1) + if (na[1].lower > bnd.upper) or (bnd.lower > na[1].upper): + return False + else: + return True + + +@numba.njit(cache=True) +def check_consistent_edge(interpretations, comp, na): + world = 
interpretations[comp] + if na[0] in world.world: + bnd = world.world[na[0]] + else: + bnd = interval.closed(0, 1) + if (na[1].lower > bnd.upper) or (bnd.lower > na[1].upper): + return False + else: + return True + + +@numba.njit(cache=True) +def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode): + world = interpretations[comp] + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1))) + if (mode == 'fact' or mode == 'graph-attribute-fact') and atom_trace: + name = facts_to_be_applied_trace[idx] + elif mode == 'rule' and atom_trace: + name = rules_to_be_applied_trace[idx][2] + else: + name = '-' + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], f'Inconsistency due to {name}') + # Resolve inconsistency and set static + world.world[na[0]].set_lower_upper(0, 1) + world.world[na[0]].set_static(True) + for p1, p2 in ipl: + if p1==na[0]: + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'Inconsistency due to {name}') + world.world[p2].set_lower_upper(0, 1) + world.world[p2].set_static(True) + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(0,1))) + + if p2==na[0]: + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'Inconsistency due to {name}') + world.world[p1].set_lower_upper(0, 1) + world.world[p1].set_static(True) + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1))) + # Add inconsistent predicates to a list + + +@numba.njit(cache=True) +def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode): + w = interpretations[comp] + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1))) + if (mode == 'fact' or mode == 'graph-attribute-fact') and atom_trace: + name = facts_to_be_applied_trace[idx] + elif mode == 'rule' and atom_trace: + name = rules_to_be_applied_trace[idx][2] + else: + name = '-' + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], f'Inconsistency due to {name}') + # Resolve inconsistency and set static + w.world[na[0]].set_lower_upper(0, 1) + w.world[na[0]].set_static(True) + for p1, p2 in ipl: + if p1==na[0]: + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], f'Inconsistency due to {name}') + w.world[p2].set_lower_upper(0, 1) + 
w.world[p2].set_static(True) + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(0,1))) + + if p2==na[0]: + if atom_trace: + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], f'Inconsistency due to {name}') + w.world[p1].set_lower_upper(0, 1) + w.world[p1].set_static(True) + if store_interpretation_changes: + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1))) + + +@numba.njit(cache=True) +def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): + nodes.append(node) + neighbors[node] = numba.typed.List.empty_list(node_type) + reverse_neighbors[node] = numba.typed.List.empty_list(node_type) + interpretations_node[node] = world.World(numba.typed.List.empty_list(label.label_type)) + + +@numba.njit(cache=True) +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): + # If not a node, add to list of nodes and initialize neighbors + if source not in nodes: + _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) + + if target not in nodes: + _add_node(target, neighbors, reverse_neighbors, nodes, interpretations_node) + + # Make sure edge doesn't already exist + # Make sure, if l=='', not to add the label + # Make sure, if edge exists, that we don't override the l label if it exists + edge = (source, target) + new_edge = False + if edge not in edges: + new_edge = True + edges.append(edge) + neighbors[source].append(target) + reverse_neighbors[target].append(source) + if l.value!='': + interpretations_edge[edge] = world.World(numba.typed.List([l])) + num_ga[t] += 1 + if l in predicate_map: + predicate_map[l].append(edge) + else: + predicate_map[l] = numba.typed.List([edge]) + else: + interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type)) + else: + if l not in interpretations_edge[edge].world and l.value!='': + new_edge = True + interpretations_edge[edge].world[l] = interval.closed(0, 1) + num_ga[t] += 1 + + if l in predicate_map: + predicate_map[l].append(edge) + else: + predicate_map[l] = numba.typed.List([edge]) + + return edge, new_edge + + +@numba.njit(cache=True) +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): + changes = 0 + edges_added = numba.typed.List.empty_list(edge_type) + for source in sources: + for target in targets: + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t) + edges_added.append(edge) + changes = changes+1 if new_edge else changes + return edges_added, changes + + +@numba.njit(cache=True) +def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge, predicate_map, num_ga): + source, target = edge + edges.remove(edge) + num_ga[-1] -= len(interpretations_edge[edge].world) + del interpretations_edge[edge] + for l in predicate_map: + if edge in predicate_map[l]: + predicate_map[l].remove(edge) + neighbors[source].remove(target) + reverse_neighbors[target].remove(source) + + +@numba.njit(cache=True) +def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node, predicate_map, num_ga): + 
nodes.remove(node) + num_ga[-1] -= len(interpretations_node[node].world) + del interpretations_node[node] + del neighbors[node] + del reverse_neighbors[node] + for l in predicate_map: + if node in predicate_map[l]: + predicate_map[l].remove(node) + + # Remove all occurrences of node in neighbors + for n in neighbors.keys(): + if node in neighbors[n]: + neighbors[n].remove(node) + for n in reverse_neighbors.keys(): + if node in reverse_neighbors[n]: + reverse_neighbors[n].remove(node) + + +@numba.njit(cache=True) +def float_to_str(value): + number = int(value) + decimal = int(value % 1 * 1000) + float_str = f'{number}.{decimal}' + return float_str + + +@numba.njit(cache=True) +def str_to_float(value): + decimal_pos = value.find('.') + if decimal_pos != -1: + after_decimal_len = len(value[decimal_pos+1:]) + else: + after_decimal_len = 0 + value = value.replace('.', '') + value = str_to_int(value) + value = value / 10**after_decimal_len + return value + + +@numba.njit(cache=True) +def str_to_int(value): + if value[0] == '-': + negative = True + value = value.replace('-','') + else: + negative = False + final_index, result = len(value) - 1, 0 + for i, v in enumerate(value): + result += (ord(v) - 48) * (10 ** (final_index - i)) + result = -result if negative else result + return result \ No newline at end of file diff --git a/pyreason/scripts/interpretation/interpretation_fp.py b/pyreason/scripts/interpretation/interpretation_fp.py index 20aa33c3..65a55113 100755 --- a/pyreason/scripts/interpretation/interpretation_fp.py +++ b/pyreason/scripts/interpretation/interpretation_fp.py @@ -110,9 +110,9 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, # Setup graph neighbors and reverse neighbors self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) for n in self.graph.nodes(): - l = numba.typed.List.empty_list(node_type) - [l.append(neigh) for neigh in self.graph.neighbors(n)] - self.neighbors[n] = l + neighbor_list = numba.typed.List.empty_list(node_type) + [neighbor_list.append(neigh) for neigh in self.graph.neighbors(n)] + self.neighbors[n] = neighbor_list self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) @@ -146,10 +146,10 @@ def _init_interpretations_node(nodes, specific_labels): # interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for l, ns in specific_labels.items(): - predicate_map[l] = numba.typed.List(ns) + for lbl, ns in specific_labels.items(): + predicate_map[lbl] = numba.typed.List(ns) # for n in ns: - # interpretations[n].world[l] = interval.closed(0.0, 1.0) + # interpretations[n].world[lbl] = interval.closed(0.0, 1.0) # num_ga[0] += 1 return interpretations, predicate_map @@ -168,10 +168,10 @@ def _init_interpretations_edge(edges, specific_labels): # interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for l, es in specific_labels.items(): - predicate_map[l] = numba.typed.List(es) + for lbl, es in specific_labels.items(): + predicate_map[lbl] = numba.typed.List(es) # for e in es: - # interpretations[e].world[l] = interval.closed(0.0, 1.0) + # interpretations[e].world[lbl] = interval.closed(0.0, 1.0) # num_ga[0] += 1 return interpretations, predicate_map @@ -281,10 +281,10 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi w = last_t_interp[n].world new_w = interpretations_node[t][n].world - for l in w: + for lbl in w: # Only copy if this is the first fp
operation (fp_cnt == 0) or if the label doesn't exist - if fp_cnt == 0 or l not in new_w: - new_w[l] = w[l].copy() + if fp_cnt == 0 or lbl not in new_w: + new_w[lbl] = w[lbl].copy() # If not persistent then copy only what is static elif t > 0 and not persistent: @@ -297,12 +297,12 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi w = last_t_interp[n].world new_w = interpretations_node[t][n].world - for l in w: - if w[l].is_static(): + for lbl in w: + if w[lbl].is_static(): # Only copy if this is the first fp operation (fp_cnt == 0) or if the label doesn't exist - if fp_cnt == 0 or l not in new_w: - print("Overwriting static label", l, "for node", n, "at time", t) - new_w[l] = w[l].copy() + if fp_cnt == 0 or lbl not in new_w: + print("Overwriting static label", lbl, "for node", n, "at time", t) + new_w[lbl] = w[lbl].copy() # Edges # Only create new interpretation if it doesn't exist or if this is the first fp operation @@ -318,10 +318,10 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi w = last_t_interp[e].world new_w = interpretations_edge[t][e].world - for l in w: + for lbl in w: # Only copy if this is the first fp operation (fp_cnt == 0) or if the label doesn't exist - if fp_cnt == 0 or l not in new_w: - new_w[l] = w[l].copy() + if fp_cnt == 0 or lbl not in new_w: + new_w[lbl] = w[lbl].copy() # If not persistent then copy only what is static elif t > 0 and not persistent: @@ -333,11 +333,11 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi w = last_t_interp[e].world new_w = interpretations_edge[t][e].world - for l in w: - if w[l].is_static(): + for lbl in w: + if w[lbl].is_static(): # Only copy if this is the first fp operation (fp_cnt == 0) or if the label doesn't exist - if fp_cnt == 0 or l not in new_w: - new_w[l] = w[l].copy() + if fp_cnt == 0 or lbl not in new_w: + new_w[lbl] = w[lbl].copy() # Convergence parameters changes_cnt = 0 @@ -351,7 +351,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi nodes_set = set(nodes) for i in range(len(facts_to_be_applied_node)): if facts_to_be_applied_node[i][0] == t: - comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] + comp, lbl, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] # If the component is not in the graph, add it if comp not in nodes_set: nodes_set.add(comp) @@ -359,34 +359,34 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi elif comp not in interpretations_node[t]: _add_node_to_interpretation(comp, interpretations_node[t]) - print("Applying fact for node:", comp, l, bnd, static, graph_attribute, "at", t, "fp", fp_cnt) + print("Applying fact for node:", comp, lbl, bnd, static, graph_attribute, "at", t, "fp", fp_cnt) # Check if bnd is static.
Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well - if l in interpretations_node[t][comp].world and interpretations_node[t][comp].world[l].is_static(): + if lbl in interpretations_node[t][comp].world and interpretations_node[t][comp].world[lbl].is_static(): print("should not be here") # Check if we should even store any of the changes to the rule trace etc. # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd)) + rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, lbl, bnd)) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i]) for p1, p2 in ipl: - if p1==l: + if p1==lbl: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[t][comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[t][comp].world[p2], facts_to_be_applied_node_trace[i]) - elif p2==l: + elif p2==lbl: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[t][comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[t][comp].world[p1], facts_to_be_applied_node_trace[i]) else: # Check for inconsistencies (multiple facts) - if check_consistent_node(interpretations_node[t], comp, (l, bnd)): + if check_consistent_node(interpretations_node[t], comp, (lbl, bnd)): print("should be here") mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override) + u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override) update = u or update # Update convergence params @@ -398,9 +398,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_node(interpretations_node[t], comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_node(interpretations_node[t], comp, (lbl, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node,
rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) else: - u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True) + u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True) update = u or update # Update convergence params @@ -410,7 +410,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute)) + facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, lbl, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i]) @@ -433,7 +433,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi edges_set = set(edges) for i in range(len(facts_to_be_applied_edge)): if facts_to_be_applied_edge[i][0] == t: - comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] + comp, lbl, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] # If the component is not in the graph, add it if comp not in edges_set: _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node[t], interpretations_edge[t], predicate_map_edge, t) @@ -442,27 +442,27 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi _add_edge_to_interpretation(comp, interpretations_edge[t]) # Check if bnd is static.
Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well - if l in interpretations_edge[t][comp].world and interpretations_edge[t][comp].world[l].is_static(): + if lbl in interpretations_edge[t][comp].world and interpretations_edge[t][comp].world[lbl].is_static(): # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[t][comp].world[l])) + rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, lbl, interpretations_edge[t][comp].world[lbl])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i]) for p1, p2 in ipl: - if p1 == l: + if p1 == lbl: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[t][comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[t][comp].world[p2], facts_to_be_applied_edge_trace[i]) - elif p2 == l: + elif p2 == lbl: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[t][comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[t][comp].world[p1], facts_to_be_applied_edge_trace[i]) else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge[t], comp, (l, bnd)): + if check_consistent_edge(interpretations_edge[t], comp, (lbl, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override) + u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override) update = u or update # Update convergence params @@ -474,9 +474,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge[t], comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_edge(interpretations_edge[t], comp, (lbl, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace,
facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) else: - u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True) + u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True) update = u or update # Update convergence params @@ -486,7 +486,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute)) + facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, lbl, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i]) @@ -587,16 +587,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() print("there are ", len(rules_to_be_applied_node), "rules to be applied for nodes") for idx, i in enumerate(rules_to_be_applied_node): - t, comp, l, bnd, set_static = i[0], i[1], i[2], i[3], i[4] + t, comp, lbl, bnd, set_static = i[0], i[1], i[2], i[3], i[4] # if node doesn't exist in interpretation, add it if comp not in interpretations_node[t]: _add_node_to_interpretation(comp, interpretations_node[t]) # Check for inconsistencies - if check_consistent_node(interpretations_node[t], comp, (l, bnd)): + if check_consistent_node(interpretations_node[t], comp, (lbl, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override) + u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override) update = u or update # Update convergence params @@ -607,9 +607,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_node(interpretations_node[t], comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_node(interpretations_node[t], comp, (lbl, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') else: - u, changes = 
_update_node(interpretations_node[t], predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True) + u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update # Update convergence params @@ -621,8 +621,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Delete rules that have been applied from list by adding index to list rules_to_remove_idx.add(idx) print("node rule to be applied") - print(t, comp, l, bnd, update) - print("interp change", interpretations_node[t][comp].world[l]) + print(t, comp, lbl, bnd, update) + print("interp change", interpretations_node[t][comp].world[lbl]) # Remove from rules to be applied and edges to be applied lists after coming out from loop rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx]) @@ -633,7 +633,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Edges rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_edge): - t, comp, l, bnd, set_static = i[0], i[1], i[2], i[3], i[4] + t, comp, lbl, bnd, set_static = i[0], i[1], i[2], i[3], i[4] sources, targets, edge_l = edges_to_be_added_edge_rule[idx] edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node[t], interpretations_edge[t], predicate_map_edge, t) changes_cnt += changes @@ -675,9 +675,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi _add_edge_to_interpretation(comp, interpretations_edge[t]) # Check for inconsistencies - if check_consistent_edge(interpretations_edge[t], comp, (l, bnd)): + if check_consistent_edge(interpretations_edge[t], comp, (lbl, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) update = u or update # Update convergence params @@ -688,9 +688,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge[t], comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule')
+ resolve_inconsistency_edge(interpretations_edge[t], comp, (lbl, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update # Update convergence params @@ -742,16 +742,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi return fp_cnt, max_t - def add_edge(self, edge, l): + def add_edge(self, edge, label): # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, -1) + _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, label, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, -1) def add_node(self, node, labels): # This function is useful for pyreason gym, called externally if node not in self.nodes: _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) - for l in labels: - self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1) + for lbl in labels: + self.interpretations_node[node].world[label.Label(lbl)] = interval.closed(0, 1) def delete_edge(self, edge): # This function is useful for pyreason gym, called externally @@ -776,23 +776,23 @@ def get_dict(self): # Update interpretation nodes for change in self.rule_trace_node: - time, _, node, l, bnd = change - interpretations[time][node][l._value] = (bnd.lower, bnd.upper) + time, _, node, label, bnd = change + interpretations[time][node][label._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.persistent: for t in range(time+1, self.time+1): - interpretations[t][node][l._value] = (bnd.lower, bnd.upper) + interpretations[t][node][label._value] = (bnd.lower, bnd.upper) # Update interpretation edges for change in self.rule_trace_edge: - time, _, edge, l, bnd, = change - interpretations[time][edge][l._value] = (bnd.lower, bnd.upper) + time, _, edge, label, bnd = change + interpretations[time][edge][label._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.persistent:
for t in range(time+1, self.time+1): - interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) + interpretations[t][edge][label._value] = (bnd.lower, bnd.upper) return interpretations @@ -804,10 +804,10 @@ def get_final_num_ground_atoms(self): ga_cnt = 0 for node in self.nodes: - for l in self.interpretations_node[node].world: + for lbl in self.interpretations_node[node].world: ga_cnt += 1 for edge in self.edges: - for l in self.interpretations_edge[edge].world: + for lbl in self.interpretations_edge[edge].world: ga_cnt += 1 return ga_cnt @@ -1401,17 +1401,17 @@ def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, @numba.njit(cache=True) -def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): +def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, label, nodes): # The groundings for a node clause can be either a previous grounding or all possible nodes - if l in predicate_map: - grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] + if label in predicate_map: + grounding = predicate_map[label] if clause_var_1 not in groundings else groundings[clause_var_1] else: grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] return grounding @numba.njit(cache=True) -def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges): +def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, label, edges): # There are 4 cases for predicate(Y,Z): # 1. Both predicate variables Y and Z have not been encountered before # 2. The source variable Y has not been encountered before but the target variable Z has # 3. The target variable Z has not been encountered before but the source variable Y has # 4. Both predicate variables Y and Z have been encountered before @@ -1422,8 +1422,8 @@ def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groun # Case 1: # We replace Y by all nodes and Z by the neighbors of each of these nodes if clause_var_1 not in groundings and clause_var_2 not in groundings: - if l in predicate_map: - edge_groundings = predicate_map[l] + if label in predicate_map: + edge_groundings = predicate_map[label] else: edge_groundings = edges @@ -1516,33 +1516,33 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated = False # This is to prevent a key error in case the label is a specific label world = interpretations[comp] - l, bnd = na + label, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if l not in world.world: - world.world[l] = interval.closed(0, 1) - if l in predicate_map: - predicate_map[l].append(comp) + if label not in world.world: + world.world[label] = interval.closed(0, 1) + if label in predicate_map: + predicate_map[label].append(comp) else: - predicate_map[l] = numba.typed.List([comp]) + predicate_map[label] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[l].copy() + prev_bnd = world.world[label].copy() # override will not check for inconsistencies if override: - world.world[l].set_lower_upper(bnd.lower, bnd.upper) + world.world[label].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(l, bnd) - world.world[l].set_static(static) - if world.world[l]!=prev_bnd: + world.update(label, bnd) + world.world[label].set_static(static) + if world.world[label]!=prev_bnd: updated = True - updated_bnds.append(world.world[l]) + 
updated_bnds.append(world.world[label]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) if atom_trace: # Mode can be fact or rule, updation of trace will happen accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1558,7 +1558,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == l: + if p1 == label: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1566,7 +1566,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1575,7 +1575,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == l: + if p2 == label: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1583,7 +1583,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1597,8 +1597,8 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[l] - prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) + current_bnd = world.world[label] + prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1616,33 +1616,33 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, 
convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): updated = False world = interpretations[comp] - l, bnd = na + label, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if l not in world.world: - world.world[l] = interval.closed(0, 1) - if l in predicate_map: - predicate_map[l].append(comp) + if label not in world.world: + world.world[label] = interval.closed(0, 1) + if label in predicate_map: + predicate_map[label].append(comp) else: - predicate_map[l] = numba.typed.List([comp]) + predicate_map[label] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[l].copy() + prev_bnd = world.world[label].copy() # override will not check for inconsistencies if override: - world.world[l].set_lower_upper(bnd.lower, bnd.upper) + world.world[label].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(l, bnd) - world.world[l].set_static(static) - if world.world[l]!=prev_bnd: + world.update(label, bnd) + world.world[label].set_static(static) + if world.world[label]!=prev_bnd: updated = True - updated_bnds.append(world.world[l]) + updated_bnds.append(world.world[label]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) if atom_trace: # Mode can be fact or rule, updation of trace will happen accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1658,7 +1658,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == l: + if p1 == label: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1666,7 +1666,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1675,7 +1675,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == l: + if p2 == label: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1683,7 +1683,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - 
_update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1697,8 +1697,8 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[l] - prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) + current_bnd = world.world[label] + prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1720,8 +1720,8 @@ def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): @numba.njit(cache=True) def are_satisfied_node(interpretations, comp, nas): result = True - for (l, bnd) in nas: - result = result and is_satisfied_node(interpretations, comp, (l, bnd)) + for (lbl, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (lbl, bnd)) return result @@ -1733,7 +1733,7 @@ def is_satisfied_node(interpretations, comp, na): try: world = interpretations[comp] result = world.is_satisfied(na[0], na[1]) - except: + except Exception: result = False else: result = True @@ -1744,23 +1744,23 @@ def is_satisfied_node_comparison(interpretations, comp, na): result = False number = 0 - l, bnd = na - l_str = l.value + label, bnd = na + label_str = label.value - if not (l is None or bnd is None): + if not (label is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): - world_l_str = world_l.value - if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): + world_label_str = world_l.value + if label_str in world_label_str and world_label_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world_l_str[len(l_str)+1:]) + number = str_to_float(world_label_str[len(label_str)+1:]) break - except: + except Exception: result = False else: result = True @@ -1770,8 +1770,8 @@ def are_satisfied_edge(interpretations, comp, nas): result = True - for (l, bnd) in nas: - result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) + for (lbl, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (lbl, bnd)) return result @@ -1783,7 +1783,7 @@ def is_satisfied_edge(interpretations, comp, na): try: world = interpretations[comp] result = world.is_satisfied(na[0], na[1]) - except: + except Exception: result = False else: result = True @@ -1794,23 +1794,23 @@ def is_satisfied_edge_comparison(interpretations, comp, na): result = False number = 0 - l, bnd = na - l_str = l.value + label, bnd = na + label_str = label.value - if not (l is None or bnd is None): + if not (label is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): - world_l_str = world_l.value - if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): + world_label_str = world_l.value + if label_str in world_label_str and world_label_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world_label_str[len(label_str)+1:]) + number = str_to_float(world_label_str[len(label_str)+1:]) break - except: + except Exception: result = False else: result = True
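The two *_comparison helpers above recover numeric attributes that PyReason stores as label suffixes. A plain-Python sketch of the matching they perform (illustrative names, not the jitted implementation): a stored label such as 'age-23' satisfies the clause predicate 'age' when the text after the hyphen parses as a number.

def match_comparison_label(clause_pred, stored_label):
    # Mirrors the check above: the predicate text must occur in the stored
    # label and the remainder after 'pred-' must be numeric.
    suffix = stored_label[len(clause_pred) + 1:]
    if clause_pred in stored_label and suffix.replace('.', '').replace('-', '').isdigit():
        return True, float(suffix)
    return False, 0.0

# match_comparison_label('age', 'age-23') -> (True, 23.0)
# match_comparison_label('age', 'agenda') -> (False, 0.0)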
@@ -1945,7 +1945,7 @@ def _add_edge_to_interpretation(edge, interpretations_edge): @numba.njit(cache=True) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, t): +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, t): # If not a node, add to list of nodes and initialize neighbors if source not in nodes: _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) @@ -1955,7 +1955,7 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int # Make sure edge doesn't already exist # Make sure, if l=='', not to add the label - # Make sure, if edge exists, that we don't override the l label if it exists + # Make sure, if edge exists, that we don't override the existing label edge = (source, target) new_edge = False if edge not in edges: @@ -1963,35 +1963,35 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int new_edge = True edges.append(edge) neighbors[source].append(target) reverse_neighbors[target].append(source) - if l.value!='': + if lbl.value!='': if edge not in interpretations_edge: - interpretations_edge[edge] = world.World(numba.typed.List([l])) - if l in predicate_map: - predicate_map[l].append(edge) + interpretations_edge[edge] = world.World(numba.typed.List([lbl])) + if lbl in predicate_map: + predicate_map[lbl].append(edge) else: - predicate_map[l] = numba.typed.List([edge]) + predicate_map[lbl] = numba.typed.List([edge]) else: interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type)) else: - if l not in interpretations_edge[edge].world and l.value!='': + if lbl not in interpretations_edge[edge].world and lbl.value!='': new_edge = True - interpretations_edge[edge].world[l] = interval.closed(0, 1) + interpretations_edge[edge].world[lbl] = interval.closed(0, 1) - if l in predicate_map: - predicate_map[l].append(edge) + if lbl in predicate_map: + predicate_map[lbl].append(edge) else: - predicate_map[l] = numba.typed.List([edge]) + predicate_map[lbl] = numba.typed.List([edge]) return edge, new_edge @numba.njit(cache=True) -def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, t): +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, t): changes = 0 edges_added = numba.typed.List.empty_list(edge_type) for source in sources: for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, t) + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, t) edges_added.append(edge) changes = changes+1 if new_edge else changes return edges_added, changes
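_add_edge and _add_edges keep three structures in sync: the edge list, the forward/reverse adjacency dicts, and predicate_map, which indexes components by label so rule grounding can avoid scanning every edge. A rough pure-Python sketch of that bookkeeping (plain dicts and lists stand in for the numba.typed containers; names are illustrative):

def add_edge_sketch(u, v, lbl, edges, neighbors, reverse_neighbors, predicate_map):
    edge = (u, v)
    if edge not in edges:
        edges.append(edge)
        neighbors.setdefault(u, []).append(v)
        reverse_neighbors.setdefault(v, []).append(u)
        if lbl:  # an empty label means: create the edge but attach no predicate
            predicate_map.setdefault(lbl, []).append(edge)
    return edge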
@@ -2002,9 +2002,9 @@ def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge source, target = edge edges.remove(edge) del interpretations_edge[edge] - for l in predicate_map: - if edge in predicate_map[l]: - predicate_map[l].remove(edge) + for lbl in predicate_map: + if edge in predicate_map[lbl]: + predicate_map[lbl].remove(edge) neighbors[source].remove(target) reverse_neighbors[target].remove(source) @@ -2015,9 +2015,9 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node del interpretations_node[node] del neighbors[node] del reverse_neighbors[node] - for l in predicate_map: - if node in predicate_map[l]: - predicate_map[l].remove(node) + for lbl in predicate_map: + if node in predicate_map[lbl]: + predicate_map[lbl].remove(node) # Remove all occurrences of node in neighbors for n in neighbors.keys(): diff --git a/pyreason/scripts/interpretation/interpretation_parallel.py b/pyreason/scripts/interpretation/interpretation_parallel.py index bbf541de..3c59026e 100644 --- a/pyreason/scripts/interpretation/interpretation_parallel.py +++ b/pyreason/scripts/interpretation/interpretation_parallel.py @@ -106,9 +106,9 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, # Setup graph neighbors and reverse neighbors self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) for n in self.graph.nodes(): - l = numba.typed.List.empty_list(node_type) - [l.append(neigh) for neigh in self.graph.neighbors(n)] - self.neighbors[n] = l + neighbor_list = numba.typed.List.empty_list(node_type) + [neighbor_list.append(neigh) for neigh in self.graph.neighbors(n)] + self.neighbors[n] = neighbor_list self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) @@ -139,10 +139,10 @@ def _init_interpretations_node(nodes, specific_labels, num_ga): interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for l, ns in specific_labels.items(): - predicate_map[l] = numba.typed.List(ns) + for lbl, ns in specific_labels.items(): + predicate_map[lbl] = numba.typed.List(ns) for n in ns: - interpretations[n].world[l] = interval.closed(0.0, 1.0) + interpretations[n].world[lbl] = interval.closed(0.0, 1.0) num_ga[0] += 1 return interpretations, predicate_map @@ -158,10 +158,10 @@ def _init_interpretations_edge(edges, specific_labels, num_ga): interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for l, es in specific_labels.items(): - predicate_map[l] = numba.typed.List(es) + for lbl, es in specific_labels.items(): + predicate_map[lbl] = numba.typed.List(es) for e in es: - interpretations[e].world[l] = interval.closed(0.0, 1.0) + interpretations[e].world[lbl] = interval.closed(0.0, 1.0) num_ga[0] += 1 return interpretations, predicate_map @@ -246,16 +246,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Reset nodes (only if not static) for n in nodes: w = interpretations_node[n].world - for l in w: - if not w[l].is_static(): - w[l].reset() + for lbl in w: + if not w[lbl].is_static(): + w[lbl].reset()
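The reset loops at the top of reason() are what give the non-persistent (canonical) mode its semantics: at the start of every timestep, each bound that has not been pinned static snaps back to total uncertainty [0, 1]. A minimal sketch under that assumption (illustrative types, not the jitted code):

from dataclasses import dataclass

@dataclass
class Bound:
    lower: float = 0.0
    upper: float = 1.0
    static: bool = False

def reset_worlds(worlds):
    # worlds: component -> {label: Bound}
    for world in worlds.values():
        for bnd in world.values():
            if not bnd.static:  # static bounds survive across timesteps
                bnd.lower, bnd.upper = 0.0, 1.0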
# Reset edges (only if not static) for e in edges: w = interpretations_edge[e].world - for l in w: - if not w[l].is_static(): - w[l].reset() + for lbl in w: + if not w[lbl].is_static(): + w[lbl].reset() # Convergence parameters changes_cnt = 0 @@ -269,36 +269,36 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi nodes_set = set(nodes) for i in range(len(facts_to_be_applied_node)): if facts_to_be_applied_node[i][0] == t: - comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] + comp, lbl, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] # If the component is not in the graph, add it if comp not in nodes_set: _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node) nodes_set.add(comp) # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well - if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static(): + if lbl in interpretations_node[comp].world and interpretations_node[comp].world[lbl].is_static(): # Check if we should even store any of the changes to the rule trace etc. # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd)) + rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, lbl, bnd)) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i]) for p1, p2 in ipl: - if p1==l: + if p1==lbl: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p2], facts_to_be_applied_node_trace[i]) - elif p2==l: + elif p2==lbl: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i]) else: # Check for inconsistencies (multiple facts) - if check_consistent_node(interpretations_node, comp, (l, bnd)): + if check_consistent_node(interpretations_node, comp, (lbl, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) update = u or update
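The ipl pairs threaded through these branches encode complementary predicates: whenever p1 holds with bounds [l1, u1], its inverse p2 is clamped against 1 - p1, which is exactly the max/min pattern used in _update_node and _update_edge. Worked example:

def ipl_complement(l1, u1, l2, u2):
    # p2 may only tighten: it stays inside its own bounds and inside 1 - p1
    return max(l2, 1.0 - u1), min(u2, 1.0 - l1)

# If p1 has bounds [0.75, 1.0] and its complement p2 starts at [0, 1],
# the IPL update tightens p2 to [0.0, 0.25]:
assert ipl_complement(0.75, 1.0, 0.0, 1.0) == (0.0, 0.25)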
# Update convergence params @@ -310,9 +310,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_node(interpretations_node, comp, (lbl, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) update = u or update # Update convergence params @@ -322,7 +322,7 @@ changes_cnt += changes if static: - facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute)) + facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, lbl, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i]) @@ -345,34 +345,34 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi edges_set = set(edges) for i in range(len(facts_to_be_applied_edge)): if facts_to_be_applied_edge[i][0]==t: - comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] + comp, lbl, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] # If the component is not in the graph, add it if comp not in edges_set: _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) edges_set.add(comp)
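Note the static branch just above: a static fact applied at timestep t is pushed back onto the queue with timestamp t + 1, so it is re-asserted every step instead of being consumed once. Sketch, assuming the same (time, component, label, bound, static, graph_attribute) tuple layout used in these hunks:

def requeue_static(facts, t):
    kept = []
    for (due, comp, lbl, bnd, static, graph_attr) in facts:
        if due == t and static:
            kept.append((t + 1, comp, lbl, bnd, static, graph_attr))  # re-arm for next step
        elif due != t:
            kept.append((due, comp, lbl, bnd, static, graph_attr))    # not yet due
    return kept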
# Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well - if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static(): + if lbl in interpretations_edge[comp].world and interpretations_edge[comp].world[lbl].is_static(): # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[comp].world[l])) + rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, lbl, interpretations_edge[comp].world[lbl])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i]) for p1, p2 in ipl: - if p1==l: + if p1==lbl: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p2], facts_to_be_applied_edge_trace[i]) - elif p2==l: + elif p2==lbl: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p1], facts_to_be_applied_edge_trace[i]) else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (l, bnd)): + if check_consistent_edge(interpretations_edge, comp, (lbl, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) update = u or update # Update convergence params @@ -384,9 +384,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_edge(interpretations_edge, comp, (lbl, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode)
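When check_consistent_edge fails and inconsistency_check is enabled, resolve_inconsistency_edge runs instead of _update_edge. Its internals are not part of this patch; to the best of my reading, PyReason widens the conflicting bound back to [0, 1] rather than letting the intersection go empty, roughly:

def resolve_inconsistency_sketch(world, lbl):
    # Approximation only: the real helper also records the conflict in the
    # rule trace and touches the IPL complements, which this sketch omits.
    world[lbl].lower, world[lbl].upper = 0.0, 1.0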
else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) update = u or update # Update convergence params @@ -396,7 +396,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute)) + facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, lbl, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i]) @@ -423,11 +423,11 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_node): if i[0] == t: - comp, l, bnd, set_static = i[1], i[2], i[3], i[4] + comp, lbl, bnd, set_static = i[1], i[2], i[3], i[4] # Check for inconsistencies - if check_consistent_node(interpretations_node, comp, (l, bnd)): + if check_consistent_node(interpretations_node, comp, (lbl, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) update = u or update # Update convergence params @@ -438,9 +438,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_node(interpretations_node, comp, (lbl, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (lbl, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) update = u or update
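The changes_cnt accumulation around every _update_* call feeds the convergence test: in 'delta_interpretation' mode the reasoner counts how many bounds changed during the fixpoint pass, while in 'delta_bound' mode it looks at how far any bound actually moved. A sketch of the distinction (illustrative, not the jitted code):

def convergence_contribution(prev, curr, mode):
    # prev, curr: (lower, upper) tuples for one bound
    if mode == 'delta_bound':
        return max(abs(curr[0] - prev[0]), abs(curr[1] - prev[1]))
    # 'delta_interpretation': one unit per changed bound
    return 1 if curr != prev else 0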
# Update convergence params @@ -462,7 +462,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_edge): if i[0] == t: - comp, l, bnd, set_static = i[1], i[2], i[3], i[4] + comp, lbl, bnd, set_static = i[1], i[2], i[3], i[4] sources, targets, edge_l = edges_to_be_added_edge_rule[idx] edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) changes_cnt += changes @@ -500,9 +500,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (l, bnd)): + if check_consistent_edge(interpretations_edge, comp, (lbl, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) update = u or update # Update convergence params @@ -513,9 +513,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_edge(interpretations_edge, comp, (lbl, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (lbl, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) update = u or update # Update
convergence params @@ -644,16 +644,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi return fp_cnt, t - def add_edge(self, edge, l): + def add_edge(self, edge, label): # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) + _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, label, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) def add_node(self, node, labels): # This function is useful for pyreason gym, called externally if node not in self.nodes: _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) - for l in labels: - self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1) + for lbl in labels: + self.interpretations_node[node].world[label.Label(lbl)] = interval.closed(0, 1) def delete_edge(self, edge): # This function is useful for pyreason gym, called externally @@ -678,23 +678,23 @@ def get_dict(self): # Update interpretation nodes for change in self.rule_trace_node: - time, _, node, l, bnd = change - interpretations[time][node][l._value] = (bnd.lower, bnd.upper) + time, _, node, label, bnd = change + interpretations[time][node][label._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self. persistent: for t in range(time+1, self.time+1): - interpretations[t][node][l._value] = (bnd.lower, bnd.upper) + interpretations[t][node][label._value] = (bnd.lower, bnd.upper) # Update interpretation edges for change in self.rule_trace_edge: - time, _, edge, l, bnd, = change - interpretations[time][edge][l._value] = (bnd.lower, bnd.upper) + time, _, edge, label, bnd, = change + interpretations[time][edge][label._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self. 
persistent: for t in range(time+1, self.time+1): - interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) + interpretations[t][edge][label._value] = (bnd.lower, bnd.upper) return interpretations @@ -706,10 +706,10 @@ def get_final_num_ground_atoms(self): ga_cnt = 0 for node in self.nodes: - for l in self.interpretations_node[node].world: + for lbl in self.interpretations_node[node].world: ga_cnt += 1 for edge in self.edges: - for l in self.interpretations_edge[edge].world: + for lbl in self.interpretations_edge[edge].world: ga_cnt += 1 return ga_cnt @@ -1303,17 +1303,17 @@ def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, @numba.njit(cache=True) -def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): +def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, label, nodes): # The groundings for a node clause can be either a previous grounding or all possible nodes - if l in predicate_map: - grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] + if label in predicate_map: + grounding = predicate_map[label] if clause_var_1 not in groundings else groundings[clause_var_1] else: grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] return grounding @numba.njit(cache=True) -def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges): +def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, label, edges): # There are 4 cases for predicate(Y,Z): # 1. Both predicate variables Y and Z have not been encountered before # 2. The source variable Y has not been encountered before but the target variable Z has @@ -1324,8 +1324,8 @@ def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groun # Case 1: # We replace Y by all nodes and Z by the neighbors of each of these nodes if clause_var_1 not in groundings and clause_var_2 not in groundings: - if l in predicate_map: - edge_groundings = predicate_map[l] + if label in predicate_map: + edge_groundings = predicate_map[label] else: edge_groundings = edges @@ -1419,34 +1419,34 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] - l, bnd = na + label, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if l not in world.world: - world.world[l] = interval.closed(0, 1) + if label not in world.world: + world.world[label] = interval.closed(0, 1) num_ga[t_cnt] += 1 - if l in predicate_map: - predicate_map[l].append(comp) + if label in predicate_map: + predicate_map[label].append(comp) else: - predicate_map[l] = numba.typed.List([comp]) + predicate_map[label] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[l].copy() + prev_bnd = world.world[label].copy() # override will not check for inconsistencies if override: - world.world[l].set_lower_upper(bnd.lower, bnd.upper) + world.world[label].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(l, bnd) - world.world[l].set_static(static) - if world.world[l]!=prev_bnd: + world.update(label, bnd) + world.world[label].set_static(static) + if world.world[label]!=prev_bnd: updated = True - updated_bnds.append(world.world[l]) + 
updated_bnds.append(world.world[label]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) if atom_trace: # Mode can be fact or rule, updation of trace will happen accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1462,7 +1462,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == l: + if p1 == label: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1470,7 +1470,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1479,7 +1479,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == l: + if p2 == label: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1487,7 +1487,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1501,8 +1501,8 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[l] - prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) + current_bnd = world.world[label] + prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1515,7 +1515,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c return (updated, change) - except: + except Exception: return (False, 0) @@ -1525,34 +1525,34 @@ def 
_update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] - l, bnd = na + label, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if l not in world.world: - world.world[l] = interval.closed(0, 1) + if label not in world.world: + world.world[label] = interval.closed(0, 1) num_ga[t_cnt] += 1 - if l in predicate_map: - predicate_map[l].append(comp) + if label in predicate_map: + predicate_map[label].append(comp) else: - predicate_map[l] = numba.typed.List([comp]) + predicate_map[label] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[l].copy() + prev_bnd = world.world[label].copy() # override will not check for inconsistencies if override: - world.world[l].set_lower_upper(bnd.lower, bnd.upper) + world.world[label].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(l, bnd) - world.world[l].set_static(static) - if world.world[l]!=prev_bnd: + world.update(label, bnd) + world.world[label].set_static(static) + if world.world[label]!=prev_bnd: updated = True - updated_bnds.append(world.world[l]) + updated_bnds.append(world.world[label]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) if atom_trace: # Mode can be fact or rule, updation of trace will happen accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1568,7 +1568,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == l: + if p1 == label: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1576,7 +1576,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1585,7 +1585,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == l: + if p2 == label: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1593,7 +1593,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, 
numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1607,8 +1607,8 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[l] - prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) + current_bnd = world.world[label] + prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1620,7 +1620,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 1 + ip_update_cnt return (updated, change) - except: + except Exception: return (False, 0) @@ -1632,8 +1632,8 @@ def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): @numba.njit(cache=True) def are_satisfied_node(interpretations, comp, nas): result = True - for (l, bnd) in nas: - result = result and is_satisfied_node(interpretations, comp, (l, bnd)) + for (lbl, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (lbl, bnd)) return result @@ -1645,7 +1645,7 @@ def is_satisfied_node(interpretations, comp, na): try: world = interpretations[comp] result = world.is_satisfied(na[0], na[1]) - except: + except Exception: result = False else: result = True @@ -1656,23 +1656,23 @@ def is_satisfied_node_comparison(interpretations, comp, na): result = False number = 0 - l, bnd = na - l_str = l.value + label, bnd = na + label_str = label.value - if not (l is None or bnd is None): + if not (label is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): - world_l_str = world_l.value - if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): + world_label_str = world_l.value + if label_str in world_label_str and world_label_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world_l_str[len(l_str)+1:]) + number = str_to_float(world_label_str[len(label_str)+1:]) break - except: + except Exception: result = False else: result = True @@ -1682,8 +1682,8 @@ def are_satisfied_edge(interpretations, comp, nas): result = True - for (l, bnd) in nas: - result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) + for (lbl, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (lbl, bnd)) return result @@ -1695,7 +1695,7 @@ def is_satisfied_edge(interpretations, comp, na): try: world = interpretations[comp] result = world.is_satisfied(na[0], na[1]) - except: + except Exception: result = False else: result = True @@ -1706,23 +1706,23 @@ def is_satisfied_edge(interpretations, comp, na):
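A note on the numba.typed.List / numba.typed.Dict boilerplate that dominates these hunks: njit-compiled code cannot build ordinary reflected Python lists or dicts, so every container must be constructed with an explicit element type. A minimal self-contained example of the same pattern used for the predicate maps (illustrative names):

import numba
from numba import types

str_type = types.unicode_type
str_list_type = types.ListType(str_type)

@numba.njit(cache=True)
def build_predicate_map():
    # label value -> list of component names carrying that label
    pmap = numba.typed.Dict.empty(key_type=str_type, value_type=str_list_type)
    members = numba.typed.List.empty_list(str_type)
    members.append('nodeA')
    pmap['connected'] = members
    return pmap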
def is_satisfied_edge_comparison(interpretations, comp, na): result = False number = 0 - l, bnd = na - l_str = l.value + label, bnd = na + label_str = label.value - if not (l is None or bnd is None): + if not (label is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): - world_l_str = world_l.value - if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): + world_label_str = world_l.value + if label_str in world_label_str and world_label_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world_l_str[len(l_str)+1:]) + number = str_to_float(world_label_str[len(label_str)+1:]) break - except: + except Exception: result = False else: result = True @@ -1846,7 +1846,7 @@ def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): @numba.njit(cache=True) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, num_ga, t): # If not a node, add to list of nodes and initialize neighbors if source not in nodes: _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) @@ -1856,7 +1856,7 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int # Make sure edge doesn't already exist # Make sure, if l=='', not to add the label - # Make sure, if edge exists, that we don't override the l label if it exists + # Make sure, if edge exists, that we don't override the existing label edge = (source, target) new_edge = False if edge not in edges: @@ -1864,36 +1864,36 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, int new_edge = True edges.append(edge) neighbors[source].append(target) reverse_neighbors[target].append(source) - if l.value!='': - interpretations_edge[edge] = world.World(numba.typed.List([l])) + if lbl.value!='': + interpretations_edge[edge] = world.World(numba.typed.List([lbl])) num_ga[t] += 1 - if l in predicate_map: - predicate_map[l].append(edge) + if lbl in predicate_map: + predicate_map[lbl].append(edge) else: - predicate_map[l] = numba.typed.List([edge]) + predicate_map[lbl] = numba.typed.List([edge]) else: interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type)) else: - if l not in interpretations_edge[edge].world and l.value!='': + if lbl not in interpretations_edge[edge].world and lbl.value!='': new_edge = True - interpretations_edge[edge].world[l] = interval.closed(0, 1) + interpretations_edge[edge].world[lbl] = interval.closed(0, 1) num_ga[t] += 1 - if l in predicate_map: - predicate_map[l].append(edge) + if lbl in predicate_map: + predicate_map[lbl].append(edge) else: - predicate_map[l] = numba.typed.List([edge]) + predicate_map[lbl] = numba.typed.List([edge]) return edge, new_edge @numba.njit(cache=True) -def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, num_ga, t):
changes = 0 edges_added = numba.typed.List.empty_list(edge_type) for source in sources: for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t) + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, lbl, interpretations_node, interpretations_edge, predicate_map, num_ga, t) edges_added.append(edge) changes = changes+1 if new_edge else changes return edges_added, changes @@ -1905,9 +1905,9 @@ def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge source, target = edge edges.remove(edge) num_ga[-1] -= len(interpretations_edge[edge].world) del interpretations_edge[edge] - for l in predicate_map: - if edge in predicate_map[l]: - predicate_map[l].remove(edge) + for lbl in predicate_map: + if edge in predicate_map[lbl]: + predicate_map[lbl].remove(edge) neighbors[source].remove(target) reverse_neighbors[target].remove(source) @@ -1919,9 +1919,9 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node del interpretations_node[node] del neighbors[node] del reverse_neighbors[node] - for l in predicate_map: - if node in predicate_map[l]: - predicate_map[l].remove(node) + for lbl in predicate_map: + if node in predicate_map[lbl]: + predicate_map[lbl].remove(node) # Remove all occurrences of node in neighbors for n in neighbors.keys(): diff --git a/pyreason/scripts/interval/interval.py b/pyreason/scripts/interval/interval.py index 56252274..1750b274 100755 --- a/pyreason/scripts/interval/interval.py +++ b/pyreason/scripts/interval/interval.py @@ -4,8 +4,8 @@ class Interval(structref.StructRefProxy): - def __new__(cls, l, u, s=False): - return structref.StructRefProxy.__new__(cls, l, u, s, l, u) + def __new__(cls, lower, upper, s=False): + return structref.StructRefProxy.__new__(cls, lower, upper, s, lower, upper) @property @njit @@ -33,9 +33,9 @@ def prev_upper(self): return self.prev_u @njit - def set_lower_upper(self, l, u): - self.l = l - self.u = u + def set_lower_upper(self, lower, upper): + self.l = lower + self.u = upper @njit def reset(self): diff --git a/pyreason/scripts/numba_wrapper/numba_types/fact_edge_type.py b/pyreason/scripts/numba_wrapper/numba_types/fact_edge_type.py index 88a51f01..0b9d952f 100755 --- a/pyreason/scripts/numba_wrapper/numba_types/fact_edge_type.py +++ b/pyreason/scripts/numba_wrapper/numba_types/fact_edge_type.py @@ -32,8 +32,8 @@ def typeof_fact(val, c): # Construct object from Numba functions @type_callable(Fact) def type_fact(context): - def typer(name, component, l, bnd, t_lower, t_upper, static): - if isinstance(name, types.UnicodeType) and isinstance(component, types.Tuple) and isinstance(l, label.LabelType) and isinstance(bnd, interval.IntervalType) and isinstance(t_lower, numba.types.Integer) and isinstance(t_upper, numba.types.Integer) and isinstance(static, numba.types.Boolean): + def typer(name, component, label_param, bnd, t_lower, t_upper, static): + if isinstance(name, types.UnicodeType) and isinstance(component, types.Tuple) and isinstance(label_param, label.LabelType) and isinstance(bnd, interval.IntervalType) and isinstance(t_lower, numba.types.Integer) and isinstance(t_upper, numba.types.Integer) and isinstance(static, numba.types.Boolean): return fact_type return typer @@ -68,11 +68,11 @@ def __init__(self, dmm, fe_type): @lower_builtin(Fact, numba.types.string, numba.types.Tuple((numba.types.string,
numba.types.string)), label.label_type, interval.interval_type, numba.types.uint16, numba.types.uint16, numba.types.boolean) def impl_fact(context, builder, sig, args): typ = sig.return_type - name, component, l, bnd, t_lower, t_upper, static = args + name, component, label_param, bnd, t_lower, t_upper, static = args fact = cgutils.create_struct_proxy(typ)(context, builder) fact.name = name fact.component = component - fact.l = l + fact.l = label_param fact.bnd = bnd fact.t_lower = t_lower fact.t_upper = t_upper @@ -117,7 +117,7 @@ def getter(fact): @overload_method(FactType, "get_time_upper") -def get_time_lower(fact): +def get_time_upper(fact): def getter(fact): return fact.t_upper return getter diff --git a/pyreason/scripts/numba_wrapper/numba_types/fact_node_type.py b/pyreason/scripts/numba_wrapper/numba_types/fact_node_type.py index 7f8b847d..2892a93e 100755 --- a/pyreason/scripts/numba_wrapper/numba_types/fact_node_type.py +++ b/pyreason/scripts/numba_wrapper/numba_types/fact_node_type.py @@ -32,8 +32,8 @@ def typeof_fact(val, c): # Construct object from Numba functions @type_callable(Fact) def type_fact(context): - def typer(name, component, l, bnd, t_lower, t_upper, static): - if isinstance(name, types.UnicodeType) and isinstance(component, types.UnicodeType) and isinstance(l, label.LabelType) and isinstance(bnd, interval.IntervalType) and isinstance(t_lower, numba.types.Integer) and isinstance(t_upper, numba.types.Integer) and isinstance(static, numba.types.Boolean): + def typer(name, component, label_param, bnd, t_lower, t_upper, static): + if isinstance(name, types.UnicodeType) and isinstance(component, types.UnicodeType) and isinstance(label_param, label.LabelType) and isinstance(bnd, interval.IntervalType) and isinstance(t_lower, numba.types.Integer) and isinstance(t_upper, numba.types.Integer) and isinstance(static, numba.types.Boolean): return fact_type return typer @@ -68,11 +68,11 @@ def __init__(self, dmm, fe_type): @lower_builtin(Fact, numba.types.string, numba.types.string, label.label_type, interval.interval_type, numba.types.uint16, numba.types.uint16, numba.types.boolean) def impl_fact(context, builder, sig, args): typ = sig.return_type - name, component, l, bnd, t_lower, t_upper, static = args + name, component, label_param, bnd, t_lower, t_upper, static = args fact = cgutils.create_struct_proxy(typ)(context, builder) fact.name = name fact.component = component - fact.l = l + fact.l = label_param fact.bnd = bnd fact.t_lower = t_lower fact.t_upper = t_upper @@ -117,7 +117,7 @@ def getter(fact): @overload_method(FactType, "get_time_upper") -def get_time_lower(fact): +def get_time_upper(fact): def getter(fact): return fact.t_upper return getter diff --git a/pyreason/scripts/numba_wrapper/numba_types/interval_type.py b/pyreason/scripts/numba_wrapper/numba_types/interval_type.py index 70439510..b6d2aa20 100755 --- a/pyreason/scripts/numba_wrapper/numba_types/interval_type.py +++ b/pyreason/scripts/numba_wrapper/numba_types/interval_type.py @@ -65,10 +65,10 @@ def impl(self, interval): return impl @overload_method(IntervalType, 'set_lower_upper') -def set_lower_upper(interval, l, u): - def impl(interval, l, u): - interval.l = np.float64(l) - interval.u = np.float64(u) +def set_lower_upper(interval, lower, upper): + def impl(interval, lower, upper): + interval.l = np.float64(lower) + interval.u = np.float64(upper) return impl @overload_method(IntervalType, 'reset') diff --git a/pyreason/scripts/numba_wrapper/numba_types/world_type.py 
b/pyreason/scripts/numba_wrapper/numba_types/world_type.py index 6b370a2b..e79748da 100755 --- a/pyreason/scripts/numba_wrapper/numba_types/world_type.py +++ b/pyreason/scripts/numba_wrapper/numba_types/world_type.py @@ -3,17 +3,15 @@ import pyreason.scripts.numba_wrapper.numba_types.label_type as label from pyreason.scripts.components.world import World -import operator from numba import types from numba.extending import typeof_impl from numba.extending import type_callable from numba.extending import models, register_model from numba.extending import make_attribute_wrapper -from numba.extending import overload_method, overload +from numba.extending import overload_method from numba.extending import lower_builtin from numba.core import cgutils from numba.extending import unbox, NativeValue, box -from numba.core.typing import signature # Create new numba type @@ -40,7 +38,7 @@ def typer(labels, world): return typer @type_callable(World) -def type_world(context): +def type_world_labels_only(context): def typer(labels): if isinstance(labels, types.ListType): return world_type @@ -68,21 +66,21 @@ def __init__(self, dmm, fe_type): def impl_world(context, builder, sig, args): # context.build_map(builder, ) typ = sig.return_type - l, wo = args + labels_arg, wo = args context.nrt.incref(builder, types.DictType(label.label_type, interval.interval_type), wo) - context.nrt.incref(builder, types.ListType(label.label_type), l) + context.nrt.incref(builder, types.ListType(label.label_type), labels_arg) w = cgutils.create_struct_proxy(typ)(context, builder) - w.labels = l + w.labels = labels_arg w.world = wo return w._getvalue() @lower_builtin(World, types.ListType(label.label_type)) -def impl_world(context, builder, sig, args): - def make_world(l): +def impl_world_labels_only(context, builder, sig, args): + def make_world(labels_arg): d = numba.typed.Dict.empty(key_type=label.label_type, value_type=interval.interval_type) - for lab in l: + for lab in labels_arg: d[lab] = interval.closed(0.0, 1.0) - w = World(l, d) + w = World(labels_arg, d) return w w = context.compile_internal(builder, make_world, sig, args) diff --git a/pyreason/scripts/utils/filter.py b/pyreason/scripts/utils/filter.py index 8f7498a0..c39a6b13 100755 --- a/pyreason/scripts/utils/filter.py +++ b/pyreason/scripts/utils/filter.py @@ -20,14 +20,14 @@ def filter_and_sort_nodes(self, interpretation, labels, bound, sort_by='lower', # change contains the timestep, fp operation, component, label and interval # Keep only the latest/most recent changes. Since list is sequencial, whatever was earlier will be overwritten for change in interpretation.rule_trace_node: - t, fp, comp, l, bnd = change - latest_changes[t][(comp, l)] = bnd + t, fp, comp, label, bnd = change + latest_changes[t][(comp, label)] = bnd # Create a list that needs to be sorted. 
This contains only the latest changes list_to_be_sorted = [] for t, d in latest_changes.items(): - for (comp, l), bnd in d.items(): - list_to_be_sorted.append((bnd, t, comp, l)) + for (comp, label), bnd in d.items(): + list_to_be_sorted.append((bnd, t, comp, label)) # Sort the list reverse = True if descending else False @@ -38,15 +38,15 @@ def filter_and_sort_nodes(self, interpretation, labels, bound, sort_by='lower', # Add sorted elements to df for i in list_to_be_sorted: - bnd, t, comp, l = i - df[t][(comp, l)] = bnd + bnd, t, comp, label = i + df[t][(comp, label)] = bnd for t, d in df.items(): - for (comp, l), bnd in d.items(): - if l.get_value() in labels and bnd in bound: + for (comp, label), bnd in d.items(): + if label.get_value() in labels and bnd in bound: if comp not in nodes[t]: nodes[t][comp] = {lab:[0,1] for lab in labels} - nodes[t][comp][l.get_value()] = [bnd.lower, bnd.upper] + nodes[t][comp][label.get_value()] = [bnd.lower, bnd.upper] dataframes = [] for t in range(self.tmax+1): @@ -74,14 +74,14 @@ def filter_and_sort_edges(self, interpretation, labels, bound, sort_by='lower', # change contains the timestep, fp operation, component, label and interval # Keep only the latest/most recent changes. Since list is sequential, whatever was earlier will be overwritten for change in interpretation.rule_trace_edge: - t, fp, comp, l, bnd = change - latest_changes[t][(comp, l)] = bnd + t, fp, comp, label, bnd = change + latest_changes[t][(comp, label)] = bnd # Create a list that needs to be sorted. This contains only the latest changes list_to_be_sorted = [] for t, d in latest_changes.items(): - for (comp, l), bnd in d.items(): - list_to_be_sorted.append((bnd, t, comp, l)) + for (comp, label), bnd in d.items(): + list_to_be_sorted.append((bnd, t, comp, label)) # Sort the list reverse = True if descending else False @@ -92,15 +92,15 @@ def filter_and_sort_edges(self, interpretation, labels, bound, sort_by='lower', # Add sorted elements to df for i in list_to_be_sorted: - bnd, t, comp, l = i - df[t][(comp, l)] = bnd + bnd, t, comp, label = i + df[t][(comp, label)] = bnd for t, d in df.items(): - for (comp, l), bnd in d.items(): - if l.get_value() in labels and bnd in bound: + for (comp, label), bnd in d.items(): + if label.get_value() in labels and bnd in bound: if comp not in edges[t]: edges[t][comp] = {lab: [0, 1] for lab in labels} - edges[t][comp][l.get_value()] = [bnd.lower, bnd.upper] + edges[t][comp][label.get_value()] = [bnd.lower, bnd.upper] dataframes = [] for t in range(self.tmax+1): diff --git a/pyreason/scripts/utils/graphml_parser.py b/pyreason/scripts/utils/graphml_parser.py index 967bd4dd..7b6ec87e 100755 --- a/pyreason/scripts/utils/graphml_parser.py +++ b/pyreason/scripts/utils/graphml_parser.py @@ -34,13 +34,13 @@ def parse_graph_attributes(self, static_facts): # IF attribute is a float or int and it is less than 1, then make it a bound, else make it a label if (isinstance(value, (float, int)) and 1 >= value >= 0) or ( isinstance(value, str) and value.replace('.', '').isdigit() and 1 >= float(value) >= 0): - l = str(key) - l_bnd = float(value) - u_bnd = 1 + label_str = str(key) + lower_bnd = float(value) + upper_bnd = 1 else: - l = f'{key}-{value}' - l_bnd = 1 - u_bnd = 1 + label_str = f'{key}-{value}' + lower_bnd = 1 + upper_bnd = 1 if isinstance(value, str): bnd_str = value.split(',') if len(bnd_str) == 2: @@ -48,29 +48,29 @@ def parse_graph_attributes(self, static_facts): low = int(bnd_str[0]) up = int(bnd_str[1]) if 1 >= low >= 0 and 1 >= up >= 0: - l_bnd = low - 
u_bnd = up - l = str(key) - except: + lower_bnd = low + upper_bnd = up + label_str = str(key) + except (ValueError, TypeError): pass - if label.Label(l) not in specific_node_labels.keys(): - specific_node_labels[label.Label(l)] = numba.typed.List.empty_list(numba.types.string) - specific_node_labels[label.Label(l)].append(n) - f = fact_node.Fact('graph-attribute-fact', n, label.Label(l), interval.closed(l_bnd, u_bnd), 0, 0, static=static_facts) + if label.Label(label_str) not in specific_node_labels.keys(): + specific_node_labels[label.Label(label_str)] = numba.typed.List.empty_list(numba.types.string) + specific_node_labels[label.Label(label_str)].append(n) + f = fact_node.Fact('graph-attribute-fact', n, label.Label(label_str), interval.closed(lower_bnd, upper_bnd), 0, 0, static=static_facts) facts_node.append(f) for e in self.graph.edges: for key, value in self.graph.edges[e].items(): # IF attribute is a float or int and it is less than 1, then make it a bound, else make it a label if (isinstance(value, (float, int)) and 1 >= value >= 0) or ( isinstance(value, str) and value.replace('.', '').isdigit() and 1 >= float(value) >= 0): - l = str(key) - l_bnd = float(value) - u_bnd = 1 + label_str = str(key) + lower_bnd = float(value) + upper_bnd = 1 else: - l = f'{key}-{value}' - l_bnd = 1 - u_bnd = 1 + label_str = f'{key}-{value}' + lower_bnd = 1 + upper_bnd = 1 if isinstance(value, str): bnd_str = value.split(',') if len(bnd_str) == 2: @@ -78,16 +78,16 @@ def parse_graph_attributes(self, static_facts): low = int(bnd_str[0]) up = int(bnd_str[1]) if 1 >= low >= 0 and 1 >= up >= 0: - l_bnd = low - u_bnd = up - l = str(key) - except: + lower_bnd = low + upper_bnd = up + label_str = str(key) + except (ValueError, TypeError): pass - if label.Label(l) not in specific_edge_labels.keys(): - specific_edge_labels[label.Label(l)] = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, numba.types.string))) - specific_edge_labels[label.Label(l)].append((e[0], e[1])) - f = fact_edge.Fact('graph-attribute-fact', (e[0], e[1]), label.Label(l), interval.closed(l_bnd, u_bnd), 0, 0, static=static_facts) + if label.Label(label_str) not in specific_edge_labels.keys(): + specific_edge_labels[label.Label(label_str)] = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, numba.types.string))) + specific_edge_labels[label.Label(label_str)].append((e[0], e[1])) + f = fact_edge.Fact('graph-attribute-fact', (e[0], e[1]), label.Label(label_str), interval.closed(lower_bnd, upper_bnd), 0, 0, static=static_facts) facts_edge.append(f) return facts_node, facts_edge, specific_node_labels, specific_edge_labels \ No newline at end of file diff --git a/pyreason/scripts/utils/output.py b/pyreason/scripts/utils/output.py index e680083d..614ade1d 100755 --- a/pyreason/scripts/utils/output.py +++ b/pyreason/scripts/utils/output.py @@ -1,4 +1,3 @@ -import csv import os import pandas as pd diff --git a/pyreason/scripts/utils/plotter.py b/pyreason/scripts/utils/plotter.py index 2ec182a6..ddc1e828 100755 --- a/pyreason/scripts/utils/plotter.py +++ b/pyreason/scripts/utils/plotter.py @@ -1,4 +1,3 @@ -import pandas as pd import matplotlib.pyplot as plt import seaborn as sns @@ -64,9 +63,9 @@ def main(): # ax.set_ylabel(y_axis_title, fontsize=13) # plt.show() if smooth: - plt.savefig(f'timesteps_vs_time_smooth.png') + plt.savefig('timesteps_vs_time_smooth.png') else: - plt.savefig(f'timesteps_vs_memory.png') + plt.savefig('timesteps_vs_memory.png') if __name__ == '__main__': diff --git 
a/pyreason/scripts/utils/query_parser.py b/pyreason/scripts/utils/query_parser.py index 7c5bfdb8..7dcc0759 100644 --- a/pyreason/scripts/utils/query_parser.py +++ b/pyreason/scripts/utils/query_parser.py @@ -8,17 +8,17 @@ def parse_query(query: str): if ':' in query: pred_comp, bounds = query.split(':') bounds = bounds.replace('[', '').replace(']', '') - l, u = bounds.split(',') - l, u = float(l), float(u) + lower, upper = bounds.split(',') + lower, upper = float(lower), float(upper) else: if query[0] == '~': pred_comp = query[1:] - l, u = 0, 0 + lower, upper = 0, 0 else: pred_comp = query - l, u = 1, 1 + lower, upper = 1, 1 - bnd = interval.closed(l, u) + bnd = interval.closed(lower, upper) # Split predicate and component idx = pred_comp.find('(') diff --git a/pyreason/scripts/utils/rule_parser.py b/pyreason/scripts/utils/rule_parser.py index dc1e7728..358976c7 100644 --- a/pyreason/scripts/utils/rule_parser.py +++ b/pyreason/scripts/utils/rule_parser.py @@ -81,8 +81,8 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, d # 7 for i in range(len(body_bounds)): bound = body_bounds[i] - l, u = _str_bound_to_bound(bound) - body_bounds[i] = [l, u] + lower, upper = _str_bound_to_bound(bound) + body_bounds[i] = [lower, upper] # Find the target predicate and bounds and annotation function if any. # Possible heads: @@ -181,9 +181,9 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, d clause_type = 'comparison' subset = numba.typed.List(variables) - l = label.Label(predicate) + label_obj = label.Label(predicate) bnd = interval.closed(bounds[0], bounds[1]) - clauses.append((clause_type, l, subset, bnd, op)) + clauses.append((clause_type, label_obj, subset, bnd, op)) # Assert that there are two variables in the head of the rule if we infer edges # Add edges between head variables if necessary @@ -208,22 +208,22 @@ def parse_rule(rule_text: str, name: str, custom_thresholds: Union[None, list, d def _str_bound_to_bound(str_bound): str_bound = str_bound.replace('[', '') str_bound = str_bound.replace(']', '') - l, u = str_bound.split(',') - return float(l), float(u) + lower, upper = str_bound.split(',') + return float(lower), float(upper) def _is_bound(str_bound): str_bound = str_bound.replace('[', '') str_bound = str_bound.replace(']', '') try: - l, u = str_bound.split(',') - l = l.replace('.', '') - u = u.replace('.', '') - if l.isdigit() and u.isdigit(): + lower, upper = str_bound.split(',') + lower = lower.replace('.', '') + upper = upper.replace('.', '') + if lower.isdigit() and upper.isdigit(): result = True else: result = False - except: + except (ValueError, AttributeError): result = False return result diff --git a/pyreason/scripts/utils/visuals.py b/pyreason/scripts/utils/visuals.py index c0133f7b..6ec4ac4d 100755 --- a/pyreason/scripts/utils/visuals.py +++ b/pyreason/scripts/utils/visuals.py @@ -6,10 +6,7 @@ ''' import networkx as nx -import matplotlib import matplotlib.pyplot as plt -import pandas as pd -from textwrap import wrap def get_subgraph(whole_graph, node_list): return nx.subgraph(whole_graph, node_list) diff --git a/pyreason/scripts/utils/yaml_parser.py b/pyreason/scripts/utils/yaml_parser.py index d93df9fc..da480942 100755 --- a/pyreason/scripts/utils/yaml_parser.py +++ b/pyreason/scripts/utils/yaml_parser.py @@ -39,9 +39,9 @@ def parse_rules(path): # Append clause clause_type = clause[0] subset = (clause[1][0], clause[1][0]) if clause_type=='node' else (clause[1][0], clause[1][1]) - l = label.Label(clause[2]) + label_obj 
= label.Label(clause[2]) bnd = interval.closed(clause[3][0], clause[3][1]) - neigh_criteria.append((clause_type, l, subset, bnd)) + neigh_criteria.append((clause_type, label_obj, subset, bnd)) # Append threshold corresponding to clause if specified in rule, else use default of greater equal to 1 if len(clause)>4: @@ -115,7 +115,7 @@ def parse_facts(path, reverse): if facts_yaml['nodes'] is not None: for fact_name, values in facts_yaml['nodes'].items(): n = str(values['node']) - l = label.Label(values['label']) + label_obj = label.Label(values['label']) bound = interval.closed(values['bound'][0], values['bound'][1]) if values['static']: static = True @@ -125,14 +125,14 @@ def parse_facts(path, reverse): static = False t_lower = values['t_lower'] t_upper = values['t_upper'] - f = fact_node.Fact(fact_name, n, l, bound, t_lower, t_upper, static) + f = fact_node.Fact(fact_name, n, label_obj, bound, t_lower, t_upper, static) facts_node.append(f) facts_edge = numba.typed.List.empty_list(fact_edge.fact_type) if facts_yaml['edges'] is not None: for fact_name, values in facts_yaml['edges'].items(): e = (str(values['source']), str(values['target'])) if not reverse else (str(values['target']), str(values['source'])) - l = label.Label(values['label']) + label_obj = label.Label(values['label']) bound = interval.closed(values['bound'][0], values['bound'][1]) if values['static']: static = True @@ -142,7 +142,7 @@ def parse_facts(path, reverse): static = False t_lower = values['t_lower'] t_upper = values['t_upper'] - f = fact_edge.Fact(fact_name, e, l, bound, t_lower, t_upper, static) + f = fact_edge.Fact(fact_name, e, label_obj, bound, t_lower, t_upper, static) facts_edge.append(f) return facts_node, facts_edge @@ -156,13 +156,13 @@ def parse_labels(path): edge_labels = numba.typed.List.empty_list(label.label_type) if labels_yaml['node_labels'] is not None: for label_name in labels_yaml['node_labels']: - l = label.Label(label_name) - node_labels.append(l) + label_obj = label.Label(label_name) + node_labels.append(label_obj) if labels_yaml['edge_labels'] is not None: for label_name in labels_yaml['edge_labels']: - l = label.Label(label_name) - edge_labels.append(l) + label_obj = label.Label(label_name) + edge_labels.append(label_obj) # Add an edge label for each edge edge_labels.append(label.Label('edge')) @@ -171,15 +171,15 @@ def parse_labels(path): if labels_yaml['node_specific_labels'] is not None: for entry in labels_yaml['node_specific_labels']: for label_name, nodes in entry.items(): - l = label.Label(str(label_name)) - specific_node_labels[l] = numba.typed.List([str(n) for n in nodes]) + label_obj = label.Label(str(label_name)) + specific_node_labels[label_obj] = numba.typed.List([str(n) for n in nodes]) specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(numba.types.Tuple((numba.types.string, numba.types.string)))) if labels_yaml['edge_specific_labels'] is not None: for entry in labels_yaml['edge_specific_labels']: for label_name, edges in entry.items(): - l = label.Label(str(label_name)) - specific_edge_labels[l] = numba.typed.List([(str(e[0]), str(e[1])) for e in edges]) + label_obj = label.Label(str(label_name)) + specific_edge_labels[label_obj] = numba.typed.List([(str(e[0]), str(e[1])) for e in edges]) return node_labels, edge_labels, specific_node_labels, specific_edge_labels From 0db987868ff3579c2df67a4125eec740bf8502eb Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sun, 5 Oct 2025 
14:37:33 -0400 Subject: [PATCH 02/32] Final updates and CI/CD --- .../workflows/python-package-version-test.yml | 5 +- pyproject.toml | 5 + .../scripts/interpretation/interpretation.py | 274 ++++++++-------- .../interpretation/interpretation_fp.py | 302 +++++++++--------- .../interpretation/interpretation_parallel.py | 278 ++++++++-------- 5 files changed, 435 insertions(+), 429 deletions(-) diff --git a/.github/workflows/python-package-version-test.yml b/.github/workflows/python-package-version-test.yml index 680029b3..edc07f83 100644 --- a/.github/workflows/python-package-version-test.yml +++ b/.github/workflows/python-package-version-test.yml @@ -27,9 +27,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest + python -m pip install ruff pytest pip install torch==2.6.0 if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with ruff + run: | + python -m ruff check pyreason/scripts/ - name: Pytest Unit Tests with JIT Disabled run: | pytest tests/unit/disable_jit diff --git a/pyproject.toml b/pyproject.toml index f1bfca4d..bf5379f8 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,8 @@ [build-system] requires = ['setuptools>=42'] build-backend = 'setuptools.build_meta' + +[tool.ruff.lint] +# Ignore ambiguous variable name errors (E741) in interpretation files +[tool.ruff.lint.per-file-ignores] +"pyreason/scripts/interpretation/*.py" = ["E741"] diff --git a/pyreason/scripts/interpretation/interpretation.py b/pyreason/scripts/interpretation/interpretation.py index 2693a219..5cadb8bd 100755 --- a/pyreason/scripts/interpretation/interpretation.py +++ b/pyreason/scripts/interpretation/interpretation.py @@ -106,9 +106,9 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, # Setup graph neighbors and reverse neighbors self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) for n in self.graph.nodes(): - neighbor_list = numba.typed.List.empty_list(node_type) - [neighbor_list.append(neigh) for neigh in self.graph.neighbors(n)] - self.neighbors[n] = neighbor_list + l = numba.typed.List.empty_list(node_type) + [l.append(neigh) for neigh in self.graph.neighbors(n)] + self.neighbors[n] = l self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) @@ -139,10 +139,10 @@ def _init_interpretations_node(nodes, specific_labels, num_ga): interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for lbl, ns in specific_labels.items(): - predicate_map[lbl] = numba.typed.List(ns) + for l, ns in specific_labels.items(): + predicate_map[l] = numba.typed.List(ns) for n in ns: - interpretations[n].world[lbl] = interval.closed(0.0, 1.0) + interpretations[n].world[l] = interval.closed(0.0, 1.0) num_ga[0] += 1 return interpretations, predicate_map @@ -158,10 +158,10 @@ def _init_interpretations_edge(edges, specific_labels, num_ga): interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for lbl, es in specific_labels.items(): - predicate_map[lbl] = numba.typed.List(es) + for l, es in specific_labels.items(): + predicate_map[l] = numba.typed.List(es) for e in es: - interpretations[e].world[lbl] = interval.closed(0.0, 1.0) + interpretations[e].world[l] = interval.closed(0.0, 1.0) num_ga[0] += 1 return interpretations, predicate_map @@ -246,16 +246,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Reset 
nodes (only if not static) for n in nodes: w = interpretations_node[n].world - for label in w: - if not w[label].is_static(): - w[label].reset() + for l in w: + if not w[l].is_static(): + w[l].reset() # Reset edges (only if not static) for e in edges: w = interpretations_edge[e].world - for label in w: - if not w[label].is_static(): - w[label].reset() + for l in w: + if not w[l].is_static(): + w[l].reset() # Convergence parameters changes_cnt = 0 @@ -269,36 +269,36 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi nodes_set = set(nodes) for i in range(len(facts_to_be_applied_node)): if facts_to_be_applied_node[i][0] == t: - comp, label, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] + comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] # If the component is not in the graph, add it if comp not in nodes_set: _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node) nodes_set.add(comp) # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well - if label in interpretations_node[comp].world and interpretations_node[comp].world[label].is_static(): + if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static(): # Check if we should even store any of the changes to the rule trace etc. # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, label, bnd)) + rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd)) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i]) for p1, p2 in ipl: - if p1==label: + if p1==l: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p2], facts_to_be_applied_node_trace[i]) - elif p2==label: + elif p2==l: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i]) else: # Check for inconsistencies (multiple facts) - if check_consistent_node(interpretations_node, comp, (label, bnd)): + if check_consistent_node(interpretations_node, comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, 
static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) update = u or update # Update convergence params @@ -310,9 +310,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (label, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) update = u or update # Update convergence params @@ -322,7 +322,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, label, bnd, static, graph_attribute)) + facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i]) @@ -345,34 +345,34 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi edges_set = set(edges) for i in range(len(facts_to_be_applied_edge)): if facts_to_be_applied_edge[i][0]==t: - comp, label, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] + comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] # If the component is not in the graph, add it if comp not in edges_set: _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) edges_set.add(comp) # Check if bnd is static. 
Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well - if label in interpretations_edge[comp].world and interpretations_edge[comp].world[label].is_static(): + if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static(): # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, label, interpretations_edge[comp].world[label])) + rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[comp].world[l])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i]) for p1, p2 in ipl: - if p1==label: + if p1==l: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p2], facts_to_be_applied_edge_trace[i]) - elif p2==label: + elif p2==l: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p1], facts_to_be_applied_edge_trace[i]) else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (label, bnd)): + if check_consistent_edge(interpretations_edge, comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) update = u or update # Update convergence params @@ -384,9 +384,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (label, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, 
store_interpretation_changes, mode=mode) else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) update = u or update # Update convergence params @@ -396,7 +396,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, label, bnd, static, graph_attribute)) + facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i]) @@ -423,11 +423,11 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_node): if i[0] == t: - comp, label, bnd, set_static = i[1], i[2], i[3], i[4] + comp, l, bnd, set_static = i[1], i[2], i[3], i[4] # Check for inconsistencies - if check_consistent_node(interpretations_node, comp, (label, bnd)): + if check_consistent_node(interpretations_node, comp, (l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) update = u or update # Update convergence params @@ -438,9 +438,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (label, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, 
store_interpretation_changes, num_ga, mode='rule', override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) update = u or update # Update convergence params @@ -462,7 +462,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_edge): if i[0] == t: - comp, label, bnd, set_static = i[1], i[2], i[3], i[4] + comp, l, bnd, set_static = i[1], i[2], i[3], i[4] sources, targets, edge_l = edges_to_be_added_edge_rule[idx] edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) changes_cnt += changes @@ -500,9 +500,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (label, bnd)): + if check_consistent_edge(interpretations_edge, comp, (l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) update = u or update # Update convergence params @@ -513,9 +513,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (label, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) update = u or update # Update convergence 
params @@ -644,16 +644,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi return fp_cnt, t - def add_edge(self, edge, label): + def add_edge(self, edge, l): # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, label, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) + _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) def add_node(self, node, labels): # This function is useful for pyreason gym, called externally if node not in self.nodes: _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) - for label in labels: - self.interpretations_node[node].world[label.Label(label)] = interval.closed(0, 1) + for l in labels: + self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1) def delete_edge(self, edge): # This function is useful for pyreason gym, called externally @@ -678,23 +678,23 @@ def get_dict(self): # Update interpretation nodes for change in self.rule_trace_node: - time, _, node, label, bnd = change - interpretations[time][node][label._value] = (bnd.lower, bnd.upper) + time, _, node, l, bnd = change + interpretations[time][node][l._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.persistent: for t in range(time+1, self.time+1): - interpretations[t][node][label._value] = (bnd.lower, bnd.upper) + interpretations[t][node][l._value] = (bnd.lower, bnd.upper) # Update interpretation edges for change in self.rule_trace_edge: - time, _, edge, label, bnd, = change - interpretations[time][edge][label._value] = (bnd.lower, bnd.upper) + time, _, edge, l, bnd, = change + interpretations[time][edge][l._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.persistent: for t in range(time+1, self.time+1): - interpretations[t][edge][label._value] = (bnd.lower, bnd.upper) + interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) return interpretations @@ -706,10 +706,10 @@ def get_final_num_ground_atoms(self): ga_cnt = 0 for node in self.nodes: - for lbl in self.interpretations_node[node].world: + for l in self.interpretations_node[node].world: ga_cnt += 1 for edge in self.edges: - for lbl in self.interpretations_edge[edge].world: + for l in self.interpretations_edge[edge].world: ga_cnt += 1 return ga_cnt @@ -807,7 +807,7 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map clause_label = clause[1] clause_variables = clause[2] clause_bnd = clause[3] - # clause_operator = clause[4] # Currently unused + _clause_operator = clause[4] # This is a node clause if clause_type == 'node': @@ -1303,17 +1303,17 @@ check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, @numba.njit(cache=True) -def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, label, nodes): +def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): # The groundings for a node clause can be either a previous grounding or all possible nodes - if label in predicate_map: - grounding = predicate_map[label] if clause_var_1 not in groundings else groundings[clause_var_1] + if l in predicate_map: + grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] else: grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] return grounding @numba.njit(cache=True) -def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, label, edges): +def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges): # There are 4 cases for predicate(Y,Z): # 1. Both predicate variables Y and Z have not been encountered before # 2.
The source variable Y has not been encountered before but the target variable Z has @@ -1324,8 +1324,8 @@ def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groun # Case 1: # We replace Y by all nodes and Z by the neighbors of each of these nodes if clause_var_1 not in groundings and clause_var_2 not in groundings: - if label in predicate_map: - edge_groundings = predicate_map[label] + if l in predicate_map: + edge_groundings = predicate_map[l] else: edge_groundings = edges @@ -1419,34 +1419,34 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] - label, bnd = na + l, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if label not in world.world: - world.world[label] = interval.closed(0, 1) + if l not in world.world: + world.world[l] = interval.closed(0, 1) num_ga[t_cnt] += 1 - if label in predicate_map: - predicate_map[label].append(comp) + if l in predicate_map: + predicate_map[l].append(comp) else: - predicate_map[label] = numba.typed.List([comp]) + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[label].copy() + prev_bnd = world.world[l].copy() # override will not check for inconsistencies if override: - world.world[label].set_lower_upper(bnd.lower, bnd.upper) + world.world[l].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(label, bnd) - world.world[label].set_static(static) - if world.world[label]!=prev_bnd: + world.update(l, bnd) + world.world[l].set_static(static) + if world.world[l]!=prev_bnd: updated = True - updated_bnds.append(world.world[label]) + updated_bnds.append(world.world[l]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) if atom_trace: # Mode can be fact or rule, updation of trace will happen accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1462,7 +1462,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == label: + if p1 == l: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1470,7 +1470,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1479,7 +1479,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c 
updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == label: + if p2 == l: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1487,7 +1487,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1501,8 +1501,8 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[label] - prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) + current_bnd = world.world[l] + prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1525,34 +1525,34 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] - label, bnd = na + l, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if label not in world.world: - world.world[label] = interval.closed(0, 1) + if l not in world.world: + world.world[l] = interval.closed(0, 1) num_ga[t_cnt] += 1 - if label in predicate_map: - predicate_map[label].append(comp) + if l in predicate_map: + predicate_map[l].append(comp) else: - predicate_map[label] = numba.typed.List([comp]) + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[label].copy() + prev_bnd = world.world[l].copy() # override will not check for inconsistencies if override: - world.world[label].set_lower_upper(bnd.lower, bnd.upper) + world.world[l].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(label, bnd) - world.world[label].set_static(static) - if world.world[label]!=prev_bnd: + world.update(l, bnd) + world.world[l].set_static(static) + if world.world[l]!=prev_bnd: updated = True - updated_bnds.append(world.world[label]) + updated_bnds.append(world.world[l]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) if atom_trace: # Mode can be fact or rule, updation of trace will happen accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1568,7 +1568,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, 
rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == label: + if p1 == l: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1576,7 +1576,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1585,7 +1585,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == label: + if p2 == l: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1593,7 +1593,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1607,8 +1607,8 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[label] - prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) + current_bnd = world.world[l] + prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1632,8 +1632,8 @@ def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): @numba.njit(cache=True) def are_satisfied_node(interpretations, comp, nas): result = True - for (lbl, bnd) in nas: - result = result and is_satisfied_node(interpretations, comp, (lbl, bnd)) + for (l, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (l, bnd)) return result @@ -1656,20 +1656,20 @@ def is_satisfied_node(interpretations, comp, na): def is_satisfied_node_comparison(interpretations, comp, na): result = False number = 0 - label, bnd = na - label_str = label.value + l, bnd = na + l_str = l.value - if not (label is None or bnd is None): + if not (l is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): world_l_str = world_l.value - if label_str in 
world_l_str and world_l_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): + if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world_l_str[len(label_str)+1:]) + number = str_to_float(world_l_str[len(l_str)+1:]) break except Exception: @@ -1682,8 +1682,8 @@ def is_satisfied_node_comparison(interpretations, comp, na): @numba.njit(cache=True) def are_satisfied_edge(interpretations, comp, nas): result = True - for (lbl, bnd) in nas: - result = result and is_satisfied_edge(interpretations, comp, (lbl, bnd)) + for (l, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) return result @@ -1706,20 +1706,20 @@ def is_satisfied_edge(interpretations, comp, na): def is_satisfied_edge_comparison(interpretations, comp, na): result = False number = 0 - label, bnd = na - label_str = label.value + l, bnd = na + l_str = l.value - if not (label is None or bnd is None): + if not (l is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): world_l_str = world_l.value - if label_str in world_l_str and world_l_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): + if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world_l_str[len(label_str)+1:]) + number = str_to_float(world_l_str[len(l_str)+1:]) break except Exception: @@ -1846,7 +1846,7 @@ def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): @numba.njit(cache=True) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, num_ga, t): +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): # If not a node, add to list of nodes and initialize neighbors if source not in nodes: _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) @@ -1855,8 +1855,8 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, _add_node(target, neighbors, reverse_neighbors, nodes, interpretations_node) # Make sure edge doesn't already exist - # Make sure, if label=='', not to add the label - # Make sure, if edge exists, that we don't override the label label if it exists + # Make sure, if l=='', not to add the label + # Make sure, if edge exists, that we don't override the l label if it exists edge = (source, target) new_edge = False if edge not in edges: @@ -1864,36 +1864,36 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, edges.append(edge) neighbors[source].append(target) reverse_neighbors[target].append(source) - if label.value!='': - interpretations_edge[edge] = world.World(numba.typed.List([label])) + if l.value!='': + interpretations_edge[edge] = world.World(numba.typed.List([l])) num_ga[t] += 1 - if label in predicate_map: - predicate_map[label].append(edge) + if l in predicate_map: + predicate_map[l].append(edge) else: - predicate_map[label] = numba.typed.List([edge]) + predicate_map[l] = numba.typed.List([edge]) else: interpretations_edge[edge] = 
world.World(numba.typed.List.empty_list(label.label_type)) else: - if label not in interpretations_edge[edge].world and label.value!='': + if l not in interpretations_edge[edge].world and l.value!='': new_edge = True - interpretations_edge[edge].world[label] = interval.closed(0, 1) + interpretations_edge[edge].world[l] = interval.closed(0, 1) num_ga[t] += 1 - if label in predicate_map: - predicate_map[label].append(edge) + if l in predicate_map: + predicate_map[l].append(edge) else: - predicate_map[label] = numba.typed.List([edge]) + predicate_map[l] = numba.typed.List([edge]) return edge, new_edge @numba.njit(cache=True) -def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, num_ga, t): +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): changes = 0 edges_added = numba.typed.List.empty_list(edge_type) for source in sources: for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, num_ga, t) + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t) edges_added.append(edge) changes = changes+1 if new_edge else changes return edges_added, changes @@ -1905,9 +1905,9 @@ def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge edges.remove(edge) num_ga[-1] -= len(interpretations_edge[edge].world) del interpretations_edge[edge] - for lbl in predicate_map: - if edge in predicate_map[lbl]: - predicate_map[lbl].remove(edge) + for l in predicate_map: + if edge in predicate_map[l]: + predicate_map[l].remove(edge) neighbors[source].remove(target) reverse_neighbors[target].remove(source) @@ -1919,9 +1919,9 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node del interpretations_node[node] del neighbors[node] del reverse_neighbors[node] - for lbl in predicate_map: - if node in predicate_map[lbl]: - predicate_map[lbl].remove(node) + for l in predicate_map: + if node in predicate_map[l]: + predicate_map[l].remove(node) # Remove all occurrences of node in neighbors for n in neighbors.keys(): diff --git a/pyreason/scripts/interpretation/interpretation_fp.py b/pyreason/scripts/interpretation/interpretation_fp.py index 65a55113..cfd7e4e4 100755 --- a/pyreason/scripts/interpretation/interpretation_fp.py +++ b/pyreason/scripts/interpretation/interpretation_fp.py @@ -110,9 +110,9 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, # Setup graph neighbors and reverse neighbors self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) for n in self.graph.nodes(): - neighbor_list = numba.typed.List.empty_list(node_type) - [neighbor_list.append(neigh) for neigh in self.graph.neighbors(n)] - self.neighbors[n] = neighbor_list + l = numba.typed.List.empty_list(node_type) + [l.append(neigh) for neigh in self.graph.neighbors(n)] + self.neighbors[n] = l self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) @@ -146,10 +146,10 @@ def _init_interpretations_node(nodes, specific_labels): # interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for lbl, ns in specific_labels.items(): - predicate_map[lbl] = numba.typed.List(ns) + for 
l, ns in specific_labels.items(): + predicate_map[l] = numba.typed.List(ns) # for n in ns: - # interpretations[n].world[lbl] = interval.closed(0.0, 1.0) + # interpretations[n].world[l] = interval.closed(0.0, 1.0) # num_ga[0] += 1 return interpretations, predicate_map @@ -168,10 +168,10 @@ def _init_interpretations_edge(edges, specific_labels): # interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for lbl, es in specific_labels.items(): - predicate_map[lbl] = numba.typed.List(es) + for l, es in specific_labels.items(): + predicate_map[l] = numba.typed.List(es) # for e in es: - # interpretations[e].world[lbl] = interval.closed(0.0, 1.0) + # interpretations[e].world[l] = interval.closed(0.0, 1.0) # num_ga[0] += 1 return interpretations, predicate_map @@ -281,10 +281,10 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi w = last_t_interp[n].world new_w = interpretations_node[t][n].world - for label in w: + for l in w: # Only copy if this is the first fp operation (fp_cnt == 0) or if the label doesn't exist - if fp_cnt == 0 or label not in new_w: - new_w[label] = w[label].copy() + if fp_cnt == 0 or l not in new_w: + new_w[l] = w[l].copy() # If not persistent then copy only what is static elif t > 0 and not persistent: @@ -297,12 +297,12 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi w = last_t_interp[n].world new_w = interpretations_node[t][n].world - for label in w: - if w[label].is_static(): + for l in w: + if w[l].is_static(): # Only copy if this is the first fp operation (fp_cnt == 0) or if the label doesn't exist - if fp_cnt == 0 or label not in new_w: - print("Overwriting static label", label, "for node", n, "at time", t) - new_w[label] = w[label].copy() + if fp_cnt == 0 or l not in new_w: + print("Overwriting static label", l, "for node", n, "at time", t) + new_w[l] = w[l].copy() # Edges # Only create new interpretation if it doesn't exist or if this is the first fp operation @@ -318,10 +318,10 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi w = last_t_interp[e].world new_w = interpretations_edge[t][e].world - for label in w: + for l in w: # Only copy if this is the first fp operation (fp_cnt == 0) or if the label doesn't exist - if fp_cnt == 0 or label not in new_w: - new_w[label] = w[label].copy() + if fp_cnt == 0 or l not in new_w: + new_w[l] = w[l].copy() # If not persistent then copy only what is static elif t > 0 and not persistent: @@ -333,11 +333,11 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi w = last_t_interp[e].world new_w = interpretations_edge[t][e].world - for label in w: - if w[label].is_static(): + for l in w: + if w[l].is_static(): # Only copy if this is the first fp operation (fp_cnt == 0) or if the label doesn't exist - if fp_cnt == 0 or label not in new_w: - new_w[label] = w[label].copy() + if fp_cnt == 0 or l not in new_w: + new_w[l] = w[l].copy() # Convergence parameters changes_cnt = 0 @@ -351,7 +351,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi nodes_set = set(nodes) for i in range(len(facts_to_be_applied_node)): if facts_to_be_applied_node[i][0] == t: - comp, label, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] + comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], 
facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] # If the component is not in the graph, add it if comp not in nodes_set: nodes_set.add(comp) @@ -359,34 +359,34 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi elif comp not in interpretations_node[t]: _add_node_to_interpretation(comp, interpretations_node[t]) - print("Applying fact for node:", comp, label, bnd, static, graph_attribute, "at", t, "fp", fp_cnt) + print("Applying fact for node:", comp, l, bnd, static, graph_attribute, "at", t, "fp", fp_cnt) # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well - if label in interpretations_node[t][comp].world and interpretations_node[t][comp].world[label].is_static(): + if l in interpretations_node[t][comp].world and interpretations_node[t][comp].world[l].is_static(): print("should not be here") # Check if we should even store any of the changes to the rule trace etc. # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, label, bnd)) + rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd)) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i]) for p1, p2 in ipl: - if p1==label: + if p1==l: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[t][comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[t][comp].world[p2], facts_to_be_applied_node_trace[i]) - elif p2==label: + elif p2==l: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[t][comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[t][comp].world[p1], facts_to_be_applied_node_trace[i]) else: # Check for inconsistencies (multiple facts) - if check_consistent_node(interpretations_node[t], comp, (label, bnd)): + if check_consistent_node(interpretations_node[t], comp, (l, bnd)): print("should be here") mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=override) + u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, 
store_interpretation_changes, mode=mode, override=override) update = u or update # Update convergence params @@ -398,9 +398,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_node(interpretations_node[t], comp, (label, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_node(interpretations_node[t], comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) else: - u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True) + u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode=mode, override=True) update = u or update # Update convergence params @@ -410,7 +410,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, label, bnd, static, graph_attribute)) + facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i]) @@ -433,7 +433,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi edges_set = set(edges) for i in range(len(facts_to_be_applied_edge)): if facts_to_be_applied_edge[i][0] == t: - comp, label, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] + comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] # If the component is not in the graph, add it if comp not in edges_set: _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node[t], interpretations_edge[t], predicate_map_edge, t) @@ -442,27 +442,27 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi _add_edge_to_interpretation(comp, interpretations_edge[t]) # Check if bnd is static. 
Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well - if label in interpretations_edge[t][comp].world and interpretations_edge[t][comp].world[label].is_static(): + if l in interpretations_edge[t][comp].world and interpretations_edge[t][comp].world[l].is_static(): # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, label, interpretations_edge[t][comp].world[label])) + rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[t][comp].world[l])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i]) for p1, p2 in ipl: - if p1 == label: + if p1 == l: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[t][comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[t][comp].world[p2], facts_to_be_applied_edge_trace[i]) - elif p2 == label: + elif p2 == l: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[t][comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[t][comp].world[p1], facts_to_be_applied_edge_trace[i]) else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge[t], comp, (label, bnd)): + if check_consistent_edge(interpretations_edge[t], comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override) + u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=override) update = u or update # Update convergence params @@ -474,9 +474,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge[t], comp, (label, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_edge(interpretations_edge[t], comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, 
facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) else: - u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True) + u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode=mode, override=True) update = u or update # Update convergence params @@ -486,7 +486,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, label, bnd, static, graph_attribute)) + facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i]) @@ -538,7 +538,6 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # If delta_t is zero we apply the rules and check if more are applicable if delta_t == 0: - in_loop = True update = False for applicable_rule in applicable_edge_rules: @@ -560,7 +559,6 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # If delta_t is zero we apply the rules and check if more are applicable if delta_t == 0: - in_loop = True update = False # Update lists after parallel run @@ -587,16 +585,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() print("there are ", len(rules_to_be_applied_node), "rules to be applied for nodes") for idx, i in enumerate(rules_to_be_applied_node): - t, comp, label, bnd, set_static = i[0], i[1], i[2], i[3], i[4] + t, comp, l, bnd, set_static = i[0], i[1], i[2], i[3], i[4] # if node doesn't exist in interpretation, add it if comp not in interpretations_node[t]: _add_node_to_interpretation(comp, interpretations_node[t]) # Check for inconsistencies - if check_consistent_node(interpretations_node[t], comp, (label, bnd)): + if check_consistent_node(interpretations_node[t], comp, (l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override) + u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=override) update = u or update # Update convergence params @@ -607,9 +605,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - 
resolve_inconsistency_node(interpretations_node[t], comp, (label, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_node(interpretations_node[t], comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True) + u, changes = _update_node(interpretations_node[t], predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update # Update convergence params @@ -621,8 +619,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Delete rules that have been applied from list by adding index to list rules_to_remove_idx.add(idx) print("node rule to be applied") - print(t, comp, label, bnd, update) - print("interp change", interpretations_node[t][comp].world[label]) + print(t, comp, l, bnd, update) + print("interp change", interpretations_node[t][comp].world[l]) # Remove from rules to be applied and edges to be applied lists after coming out from loop rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx]) @@ -633,7 +631,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Edges rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_edge): - t, comp, label, bnd, set_static = i[0], i[1], i[2], i[3], i[4] + t, comp, l, bnd, set_static = i[0], i[1], i[2], i[3], i[4] sources, targets, edge_l = edges_to_be_added_edge_rule[idx] edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node[t], interpretations_edge[t], predicate_map_edge, t) changes_cnt += changes @@ -675,9 +673,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi _add_edge_to_interpretation(comp, interpretations_edge[t]) # Check for inconsistencies - if check_consistent_edge(interpretations_edge[t], comp, (label, bnd)): + if check_consistent_edge(interpretations_edge[t], comp, (l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, 
rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=override) update = u or update # Update convergence params @@ -688,9 +686,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge[t], comp, (label, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_edge(interpretations_edge[t], comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge[t], predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, mode='rule', override=True) update = u or update # Update convergence params @@ -742,16 +740,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi return fp_cnt, max_t - def add_edge(self, edge, label): + def add_edge(self, edge, l): # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, label, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, -1) + _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, -1) def add_node(self, node, labels): # This function is useful for pyreason gym, called externally if node not in self.nodes: _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) - for lbl in labels: - self.interpretations_node[node].world[label.Label(lbl)] = interval.closed(0, 1) + for l in labels: + self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1) def delete_edge(self, edge): # This function is useful for pyreason gym, called externally
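# The renames above all follow one pattern: PyReason imports its label module
# under the name `label` and calls it via `label.Label('')` and
# `label.label_type`, so a local variable named `label` shadows the module and
# breaks those calls. A minimal, self-contained sketch of the failure mode
# (the stand-in class below is illustrative, not PyReason's actual module):
class _LabelModule:
    class Label:
        def __init__(self, value):
            self.value = value

label = _LabelModule()  # plays the role of `import ... as label`

def add_node_shadowed(labels):
    for label in labels:       # the loop variable hides the module here
        label.Label(label)     # AttributeError: 'str' object has no attribute 'Label'

def add_node_renamed(labels):
    return [label.Label(l).value for l in labels]  # `l` keeps the module visible

print(add_node_renamed(['pred_a', 'pred_b']))  # ['pred_a', 'pred_b']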
@@ -776,23 +774,23 @@ def get_dict(self): # Update interpretation nodes for change in self.rule_trace_node: - time, _, node, label, bnd = change - interpretations[time][node][label._value] = (bnd.lower, bnd.upper) + time, _, node, l, bnd = change + interpretations[time][node][l._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.persistent: for t in range(time+1, self.time+1): - interpretations[t][node][label._value] = (bnd.lower, bnd.upper) + interpretations[t][node][l._value] = (bnd.lower, bnd.upper) # Update interpretation edges for change in self.rule_trace_edge: - time, _, edge, label, bnd, = change - interpretations[time][edge][label._value] = (bnd.lower, bnd.upper) + time, _, edge, l, bnd, = change + interpretations[time][edge][l._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.persistent: for t in range(time+1, self.time+1): - interpretations[t][edge][label._value] = (bnd.lower, bnd.upper) + interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) return interpretations
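# The get_dict hunk above carries every recorded bound forward in time when
# the interpretation is persistent. A plain-dict sketch of that carry-forward,
# with illustrative names (`time_horizon` plays the role of self.time):
def carry_forward(rule_trace, time_horizon, persistent):
    interpretations = {t: {} for t in range(time_horizon + 1)}
    for time, comp, lbl, lower, upper in rule_trace:
        interpretations[time].setdefault(comp, {})[lbl] = (lower, upper)
        if persistent:
            # A bound recorded at `time` stays in force at every later
            # timestep unless a later trace entry overwrites it.
            for t in range(time + 1, time_horizon + 1):
                interpretations[t].setdefault(comp, {})[lbl] = (lower, upper)
    return interpretations

print(carry_forward([(0, 'n1', 'pred', 0.7, 1.0)], 2, True)[2]['n1']['pred'])  # (0.7, 1.0)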
@@ -804,10 +802,10 @@ def get_final_num_ground_atoms(self): ga_cnt = 0 for node in self.nodes: - for lbl in self.interpretations_node[node].world: + for l in self.interpretations_node[node].world: ga_cnt += 1 for edge in self.edges: - for lbl in self.interpretations_edge[edge].world: + for l in self.interpretations_edge[edge].world: ga_cnt += 1 return ga_cnt @@ -896,7 +894,7 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map clause_label = clause[1] clause_variables = clause[2] clause_bnd = clause[3] - clause_operator = clause[4] + _clause_operator = clause[4] # This is a node clause if clause_type == 'node': @@ -1401,17 +1399,17 @@ def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, @numba.njit(cache=True) -def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, label, nodes): +def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): # The groundings for a node clause can be either a previous grounding or all possible nodes - if label in predicate_map: - grounding = predicate_map[label] if clause_var_1 not in groundings else groundings[clause_var_1] + if l in predicate_map: + grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] else: grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] return grounding @numba.njit(cache=True) -def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, label, edges): +def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges): # There are 4 cases for predicate(Y,Z): # 1. Both predicate variables Y and Z have not been encountered before # 2. The source variable Y has not been encountered before but the target variable Z has @@ -1422,8 +1420,8 @@ def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groun # Case 1: # We replace Y by all nodes and Z by the neighbors of each of these nodes if clause_var_1 not in groundings and clause_var_2 not in groundings: - if label in predicate_map: - edge_groundings = predicate_map[label] + if l in predicate_map: + edge_groundings = predicate_map[l] else: edge_groundings = edges @@ -1516,33 +1514,33 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated = False # This is to prevent a key error in case the label is a specific label world = interpretations[comp] - label, bnd = na + l, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if label not in world.world: - world.world[label] = interval.closed(0, 1) - if label in predicate_map: - predicate_map[label].append(comp) + if l not in world.world: + world.world[l] = interval.closed(0, 1) + if l in predicate_map: + predicate_map[l].append(comp) else: - predicate_map[label] = numba.typed.List([comp]) + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[label].copy() + prev_bnd = world.world[l].copy() # override will not check for inconsistencies if override: - world.world[label].set_lower_upper(bnd.lower, bnd.upper) + world.world[l].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(label, bnd) - world.world[label].set_static(static) - if world.world[label]!=prev_bnd: + world.update(l, bnd) + world.world[l].set_static(static) + if world.world[l]!=prev_bnd: updated = True - updated_bnds.append(world.world[label]) + updated_bnds.append(world.world[l]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) if atom_trace: # Mode can be fact or rule; the trace will be updated accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1558,7 +1556,7 @@ if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == label: + if p1 == l: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1566,7 +1564,7 @@ else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1575,7 +1573,7 @@ +
updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == label: + if p2 == l: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1583,7 +1581,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1597,8 +1595,8 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[label] - prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) + current_bnd = world.world[l] + prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1616,33 +1614,33 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, mode, override=False): updated = False world = interpretations[comp] - label, bnd = na + l, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if label not in world.world: - world.world[label] = interval.closed(0, 1) - if label in predicate_map: - predicate_map[label].append(comp) + if l not in world.world: + world.world[l] = interval.closed(0, 1) + if l in predicate_map: + predicate_map[l].append(comp) else: - predicate_map[label] = numba.typed.List([comp]) + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[label].copy() + prev_bnd = world.world[l].copy() # override will not check for inconsistencies if override: - world.world[label].set_lower_upper(bnd.lower, bnd.upper) + world.world[l].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(label, bnd) - world.world[label].set_static(static) - if world.world[label]!=prev_bnd: + world.update(l, bnd) + world.world[l].set_static(static) + if world.world[l]!=prev_bnd: updated = True - updated_bnds.append(world.world[label]) + updated_bnds.append(world.world[l]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) if 
atom_trace: # Mode can be fact or rule; the trace will be updated accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1658,7 +1656,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == label: + if p1 == l: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1666,7 +1664,7 @@ else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1675,7 +1673,7 @@ updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == label: + if p2 == l: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1683,7 +1681,7 @@ else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper)
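# Both IPL branches above tighten the inverse predicate with the same
# arithmetic: p2 is intersected with [1 - upper(p1), 1 - lower(p1)].
# A standalone sketch of that step (the function name is illustrative):
def ipl_complement(p1_bounds, p2_bounds):
    p1_lower, p1_upper = p1_bounds
    p2_lower, p2_upper = p2_bounds
    lower = max(p2_lower, 1 - p1_upper)
    upper = min(p2_upper, 1 - p1_lower)
    return lower, upper

# If p1 settles at [0.75, 1.0], its inverse is forced into [0.0, 0.25]:
print(ipl_complement((0.75, 1.0), (0.0, 1.0)))  # (0.0, 0.25)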
@@ -1697,8 +1695,8 @@ change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[label] - prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) + current_bnd = world.world[l] + prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1720,8 +1718,8 @@ def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): @numba.njit(cache=True) def are_satisfied_node(interpretations, comp, nas): result = True - for (lbl, bnd) in nas: - result = result and is_satisfied_node(interpretations, comp, (lbl, bnd)) + for (l, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (l, bnd)) return result @@ -1744,20 +1742,20 @@ def is_satisfied_node_comparison(interpretations, comp, na): result = False number = 0 - label, bnd = na - label_str = label.value + l, bnd = na + l_str = l.value - if not (label is None or bnd is None): + if not (l is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): - world.label_str = world.label.value - if label_str in world.label_str and world.label_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): + world_l_str = world_l.value + if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world.label_str[len(label_str)+1:]) + number = str_to_float(world_l_str[len(l_str)+1:]) break except Exception: @@ -1770,8 +1768,8 @@ def are_satisfied_edge(interpretations, comp, nas): result = True - for (lbl, bnd) in nas: - result = result and is_satisfied_edge(interpretations, comp, (lbl, bnd)) + for (l, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) return result @@ -1794,20 +1792,20 @@ def is_satisfied_edge_comparison(interpretations, comp, na): result = False number = 0 - label, bnd = na - label_str = label.value + l, bnd = na + l_str = l.value - if not (label is None or bnd is None): + if not (l is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): - world.label_str = world.label.value - if label_str in world.label_str and world.label_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): + world_l_str = world_l.value + if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world.label_str[len(label_str)+1:]) + number = str_to_float(world_l_str[len(l_str)+1:]) break except Exception: @@ -1945,7 +1943,7 @@ def _add_edge_to_interpretation(edge, interpretations_edge): @numba.njit(cache=True) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, t): +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, t): # If not a node, add to list of nodes and initialize neighbors if source not in nodes: _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) @@ -1955,7 +1953,7 @@ # Make sure edge doesn't already exist # Make sure, if l=='', not to add the label - # Make sure, if edge exists, that we don't override the label label if it exists + # Make sure, if the edge exists, that we don't override the existing label l edge = (source, target) new_edge = False if edge not in edges: @@ -1963,35 +1961,35 @@ edges.append(edge) neighbors[source].append(target) reverse_neighbors[target].append(source) - if label.value!='': + if l.value!='': if edge not in interpretations_edge: - interpretations_edge[edge] = world.World(numba.typed.List([label])) + interpretations_edge[edge] = world.World(numba.typed.List([l])) - if label in predicate_map: - predicate_map[label].append(edge) + if l in predicate_map: +
predicate_map[l].append(edge) else: - predicate_map[label] = numba.typed.List([edge]) + predicate_map[l] = numba.typed.List([edge]) else: interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type)) else: - if label not in interpretations_edge[edge].world and label.value!='': + if l not in interpretations_edge[edge].world and l.value!='': new_edge = True - interpretations_edge[edge].world[label] = interval.closed(0, 1) + interpretations_edge[edge].world[l] = interval.closed(0, 1) - if label in predicate_map: - predicate_map[label].append(edge) + if l in predicate_map: + predicate_map[l].append(edge) else: - predicate_map[label] = numba.typed.List([edge]) + predicate_map[l] = numba.typed.List([edge]) return edge, new_edge @numba.njit(cache=True) -def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, t): +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, t): changes = 0 edges_added = numba.typed.List.empty_list(edge_type) for source in sources: for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, t) + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, t) edges_added.append(edge) changes = changes+1 if new_edge else changes return edges_added, changes @@ -2002,9 +2000,9 @@ def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge source, target = edge edges.remove(edge) del interpretations_edge[edge] - for lbl in predicate_map: - if edge in predicate_map[lbl]: - predicate_map[lbl].remove(edge) + for l in predicate_map: + if edge in predicate_map[l]: + predicate_map[l].remove(edge) neighbors[source].remove(target) reverse_neighbors[target].remove(source) @@ -2015,9 +2013,9 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node del interpretations_node[node] del neighbors[node] del reverse_neighbors[node] - for lbl in predicate_map: - if node in predicate_map[lbl]: - predicate_map[lbl].remove(node) + for l in predicate_map: + if node in predicate_map[l]: + predicate_map[l].remove(node) # Remove all occurrences of node in neighbors for n in neighbors.keys(): diff --git a/pyreason/scripts/interpretation/interpretation_parallel.py b/pyreason/scripts/interpretation/interpretation_parallel.py index 3c59026e..3f2dcdcb 100644 --- a/pyreason/scripts/interpretation/interpretation_parallel.py +++ b/pyreason/scripts/interpretation/interpretation_parallel.py @@ -106,9 +106,9 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, # Setup graph neighbors and reverse neighbors self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) for n in self.graph.nodes(): - neighbor_list = numba.typed.List.empty_list(node_type) - [neighbor_list.append(neigh) for neigh in self.graph.neighbors(n)] - self.neighbors[n] = neighbor_list + l = numba.typed.List.empty_list(node_type) + [l.append(neigh) for neigh in self.graph.neighbors(n)] + self.neighbors[n] = l self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) @@ -139,10 +139,10 @@ def _init_interpretations_node(nodes, specific_labels, num_ga): interpretations[n] = 
world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for lbl, ns in specific_labels.items(): - predicate_map[lbl] = numba.typed.List(ns) + for l, ns in specific_labels.items(): + predicate_map[l] = numba.typed.List(ns) for n in ns: - interpretations[n].world[lbl] = interval.closed(0.0, 1.0) + interpretations[n].world[l] = interval.closed(0.0, 1.0) num_ga[0] += 1 return interpretations, predicate_map @@ -158,10 +158,10 @@ def _init_interpretations_edge(edges, specific_labels, num_ga): interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) # Specific labels - for lbl, es in specific_labels.items(): - predicate_map[lbl] = numba.typed.List(es) + for l, es in specific_labels.items(): + predicate_map[l] = numba.typed.List(es) for e in es: - interpretations[e].world[lbl] = interval.closed(0.0, 1.0) + interpretations[e].world[l] = interval.closed(0.0, 1.0) num_ga[0] += 1 return interpretations, predicate_map @@ -246,16 +246,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Reset nodes (only if not static) for n in nodes: w = interpretations_node[n].world - for label in w: - if not w[label].is_static(): - w[label].reset() + for l in w: + if not w[l].is_static(): + w[l].reset() # Reset edges (only if not static) for e in edges: w = interpretations_edge[e].world - for label in w: - if not w[label].is_static(): - w[label].reset() + for l in w: + if not w[l].is_static(): + w[l].reset() # Convergence parameters changes_cnt = 0 @@ -269,36 +269,36 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi nodes_set = set(nodes) for i in range(len(facts_to_be_applied_node)): if facts_to_be_applied_node[i][0] == t: - comp, label, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] + comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] # If the component is not in the graph, add it if comp not in nodes_set: _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node) nodes_set.add(comp) # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well - if label in interpretations_node[comp].world and interpretations_node[comp].world[label].is_static(): + if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static(): # Check if we should even store any of the changes to the rule trace etc. 
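# The guard used with these rule-trace appends, pulled out as a standalone
# sketch (the flag names mirror the reasoner's parameters; the function
# itself is illustrative):
def should_store(save_graph_attributes_to_rule_trace, graph_attribute, store_interpretation_changes):
    # Record a change unless it came from a graph attribute that the caller
    # chose not to trace, and only while interpretation changes are stored.
    return (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes

assert should_store(True, True, True)        # graph attributes are being traced
assert should_store(False, False, True)      # not a graph attribute: always traced
assert not should_store(False, True, True)   # untraced graph attribute: skipped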
# Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, label, bnd)) + rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd)) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i]) for p1, p2 in ipl: - if p1==label: + if p1==l: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p2], facts_to_be_applied_node_trace[i]) - elif p2==label: + elif p2==l: rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i]) else: # Check for inconsistencies (multiple facts) - if check_consistent_node(interpretations_node, comp, (label, bnd)): + if check_consistent_node(interpretations_node, comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) update = u or update # Update convergence params @@ -310,9 +310,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (label, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode=mode) else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, 
override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) update = u or update # Update convergence params @@ -322,7 +322,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, label, bnd, static, graph_attribute)) + facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i]) @@ -345,34 +345,34 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi edges_set = set(edges) for i in range(len(facts_to_be_applied_edge)): if facts_to_be_applied_edge[i][0]==t: - comp, label, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] + comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] # If the component is not in the graph, add it if comp not in edges_set: _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) edges_set.add(comp) # Check if bnd is static. 
Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well - if label in interpretations_edge[comp].world and interpretations_edge[comp].world[label].is_static(): + if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static(): # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, label, interpretations_edge[comp].world[label])) + rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[comp].world[l])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i]) for p1, p2 in ipl: - if p1==label: + if p1==l: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[comp].world[p2])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p2], facts_to_be_applied_edge_trace[i]) - elif p2==label: + elif p2==l: rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[comp].world[p1])) if atom_trace: _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p1], facts_to_be_applied_edge_trace[i]) else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (label, bnd)): + if check_consistent_edge(interpretations_edge, comp, (l, bnd)): mode = 'graph-attribute-fact' if graph_attribute else 'fact' override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) update = u or update # Update convergence params @@ -384,9 +384,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: mode = 'graph-attribute-fact' if graph_attribute else 'fact' if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (label, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) + resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, 
store_interpretation_changes, mode=mode) else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) update = u or update # Update convergence params @@ -396,7 +396,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi changes_cnt += changes if static: - facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, label, bnd, static, graph_attribute)) + facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute)) if atom_trace: facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i]) @@ -423,11 +423,11 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_node): if i[0] == t: - comp, label, bnd, set_static = i[1], i[2], i[3], i[4] + comp, l, bnd, set_static = i[1], i[2], i[3], i[4] # Check for inconsistencies - if check_consistent_node(interpretations_node, comp, (label, bnd)): + if check_consistent_node(interpretations_node, comp, (l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) update = u or update # Update convergence params @@ -438,9 +438,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (label, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (label, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, 
store_interpretation_changes, num_ga, mode='rule', override=True) + u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) update = u or update # Update convergence params @@ -462,7 +462,7 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi rules_to_remove_idx.clear() for idx, i in enumerate(rules_to_be_applied_edge): if i[0] == t: - comp, label, bnd, set_static = i[1], i[2], i[3], i[4] + comp, l, bnd, set_static = i[1], i[2], i[3], i[4] sources, targets, edge_l = edges_to_be_added_edge_rule[idx] edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) changes_cnt += changes @@ -500,9 +500,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi else: # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (label, bnd)): + if check_consistent_edge(interpretations_edge, comp, (l, bnd)): override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) update = u or update # Update convergence params @@ -513,9 +513,9 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi # Resolve inconsistency else: if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (label, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') + resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (label, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) + u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) update = u or update # Update convergence 
params @@ -644,16 +644,16 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi return fp_cnt, t - def add_edge(self, edge, label): + def add_edge(self, edge, l): # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, label, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) + _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) def add_node(self, node, labels): # This function is useful for pyreason gym, called externally if node not in self.nodes: _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) - for lbl in labels: - self.interpretations_node[node].world[label.Label(lbl)] = interval.closed(0, 1) + for l in labels: + self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1) def delete_edge(self, edge): # This function is useful for pyreason gym, called externally @@ -678,23 +678,23 @@ def get_dict(self): # Update interpretation nodes for change in self.rule_trace_node: - time, _, node, label, bnd = change - interpretations[time][node][label._value] = (bnd.lower, bnd.upper) + time, _, node, l, bnd = change + interpretations[time][node][l._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.persistent: for t in range(time+1, self.time+1): - interpretations[t][node][label._value] = (bnd.lower, bnd.upper) + interpretations[t][node][l._value] = (bnd.lower, bnd.upper) # Update interpretation edges for change in self.rule_trace_edge: - time, _, edge, label, bnd, = change - interpretations[time][edge][label._value] = (bnd.lower, bnd.upper) + time, _, edge, l, bnd, = change + interpretations[time][edge][l._value] = (bnd.lower, bnd.upper) # If persistent, update all following timesteps as well if self.persistent:
for t in range(time+1, self.time+1): - interpretations[t][edge][label._value] = (bnd.lower, bnd.upper) + interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) return interpretations @@ -706,10 +706,10 @@ def get_final_num_ground_atoms(self): ga_cnt = 0 for node in self.nodes: - for lbl in self.interpretations_node[node].world: + for l in self.interpretations_node[node].world: ga_cnt += 1 for edge in self.edges: - for lbl in self.interpretations_edge[edge].world: + for l in self.interpretations_edge[edge].world: ga_cnt += 1 return ga_cnt @@ -807,7 +807,7 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map clause_label = clause[1] clause_variables = clause[2] clause_bnd = clause[3] - clause_operator = clause[4] + _clause_operator = clause[4] # This is a node clause if clause_type == 'node': @@ -1303,17 +1303,17 @@ def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, @numba.njit(cache=True) -def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, label, nodes): +def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): # The groundings for a node clause can be either a previous grounding or all possible nodes - if label in predicate_map: - grounding = predicate_map[label] if clause_var_1 not in groundings else groundings[clause_var_1] + if l in predicate_map: + grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] else: grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] return grounding @numba.njit(cache=True) -def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, label, edges): +def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges): # There are 4 cases for predicate(Y,Z): # 1. Both predicate variables Y and Z have not been encountered before # 2.
The source variable Y has not been encountered before but the target variable Z has @@ -1324,8 +1324,8 @@ def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groun # Case 1: # We replace Y by all nodes and Z by the neighbors of each of these nodes if clause_var_1 not in groundings and clause_var_2 not in groundings: - if label in predicate_map: - edge_groundings = predicate_map[label] + if l in predicate_map: + edge_groundings = predicate_map[l] else: edge_groundings = edges @@ -1419,34 +1419,34 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] - label, bnd = na + l, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if label not in world.world: - world.world[label] = interval.closed(0, 1) + if l not in world.world: + world.world[l] = interval.closed(0, 1) num_ga[t_cnt] += 1 - if label in predicate_map: - predicate_map[label].append(comp) + if l in predicate_map: + predicate_map[l].append(comp) else: - predicate_map[label] = numba.typed.List([comp]) + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[label].copy() + prev_bnd = world.world[l].copy() # override will not check for inconsistencies if override: - world.world[label].set_lower_upper(bnd.lower, bnd.upper) + world.world[l].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(label, bnd) - world.world[label].set_static(static) - if world.world[label]!=prev_bnd: + world.update(l, bnd) + world.world[l].set_static(static) + if world.world[l]!=prev_bnd: updated = True - updated_bnds.append(world.world[label]) + updated_bnds.append(world.world[l]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) if atom_trace: # Mode can be fact or rule; the trace will be updated accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1462,7 +1462,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == label: + if p1 == l: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1470,7 +1470,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1479,7 +1479,7 @@
updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == label: + if p2 == l: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1487,7 +1487,7 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1501,8 +1501,8 @@ def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[label] - prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) + current_bnd = world.world[l] + prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1525,34 +1525,34 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] - label, bnd = na + l, bnd = na updated_bnds = numba.typed.List.empty_list(interval.interval_type) # Add label to world if it is not there - if label not in world.world: - world.world[label] = interval.closed(0, 1) + if l not in world.world: + world.world[l] = interval.closed(0, 1) num_ga[t_cnt] += 1 - if label in predicate_map: - predicate_map[label].append(comp) + if l in predicate_map: + predicate_map[l].append(comp) else: - predicate_map[label] = numba.typed.List([comp]) + predicate_map[l] = numba.typed.List([comp]) # Check if update is necessary with previous bnd - prev_bnd = world.world[label].copy() + prev_bnd = world.world[l].copy() # override will not check for inconsistencies if override: - world.world[label].set_lower_upper(bnd.lower, bnd.upper) + world.world[l].set_lower_upper(bnd.lower, bnd.upper) else: - world.update(label, bnd) - world.world[label].set_static(static) - if world.world[label]!=prev_bnd: + world.update(l, bnd) + world.world[l].set_static(static) + if world.world[l]!=prev_bnd: updated = True - updated_bnds.append(world.world[label]) + updated_bnds.append(world.world[l]) # Add to rule trace if update happened and add to atom trace if necessary if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, label, world.world[label].copy())) + rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) if atom_trace: # Mode can be fact or rule; the trace will be updated accordingly if mode=='fact' or mode=='graph-attribute-fact': @@ -1568,7 +1568,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl,
rule_trace, fp_c if updated: ip_update_cnt = 0 for p1, p2 in ipl: - if p1 == label: + if p1 == l: if p2 not in world.world: world.world[p2] = interval.closed(0, 1) if p2 in predicate_map: @@ -1576,7 +1576,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p2] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') lower = max(world.world[p2].lower, 1 - world.world[p1].upper) upper = min(world.world[p2].upper, 1 - world.world[p1].lower) world.world[p2].set_lower_upper(lower, upper) @@ -1585,7 +1585,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c updated_bnds.append(world.world[p2]) if store_interpretation_changes: rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == label: + if p2 == l: if p1 not in world.world: world.world[p1] = interval.closed(0, 1) if p1 in predicate_map: @@ -1593,7 +1593,7 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c else: predicate_map[p1] = numba.typed.List([comp]) if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {label.get_value()}') + _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') lower = max(world.world[p1].lower, 1 - world.world[p2].upper) upper = min(world.world[p1].upper, 1 - world.world[p2].lower) world.world[p1].set_lower_upper(lower, upper) @@ -1607,8 +1607,8 @@ def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_c change = 0 if updated: # Find out if it has changed from previous interp - current_bnd = world.world[label] - prev_t_bnd = interval.closed(world.world[label].prev_lower, world.world[label].prev_upper) + current_bnd = world.world[l] + prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) if current_bnd != prev_t_bnd: if convergence_mode=='delta_bound': for i in updated_bnds: @@ -1632,8 +1632,8 @@ def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): @numba.njit(cache=True) def are_satisfied_node(interpretations, comp, nas): result = True - for (lbl, bnd) in nas: - result = result and is_satisfied_node(interpretations, comp, (lbl, bnd)) + for (l, bnd) in nas: + result = result and is_satisfied_node(interpretations, comp, (l, bnd)) return result @@ -1656,20 +1656,20 @@ def is_satisfied_node(interpretations, comp, na): def is_satisfied_node_comparison(interpretations, comp, na): result = False number = 0 - label, bnd = na - label_str = label.value + l, bnd = na + l_str = l.value - if not (label is None or bnd is None): + if not (l is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): - world.label_str = world.label.value - if 
label_str in world.label_str and world.label_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): + world_l_str = world_l.value + if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world.label_str[len(label_str)+1:]) + number = str_to_float(world_l_str[len(l_str)+1:]) break except Exception: @@ -1682,8 +1682,8 @@ def is_satisfied_node_comparison(interpretations, comp, na): @numba.njit(cache=True) def are_satisfied_edge(interpretations, comp, nas): result = True - for (lbl, bnd) in nas: - result = result and is_satisfied_edge(interpretations, comp, (lbl, bnd)) + for (l, bnd) in nas: + result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) return result @@ -1706,20 +1706,20 @@ def is_satisfied_edge(interpretations, comp, na): def is_satisfied_edge_comparison(interpretations, comp, na): result = False number = 0 - label, bnd = na - label_str = label.value + l, bnd = na + l_str = l.value - if not (label is None or bnd is None): + if not (l is None or bnd is None): # This is to prevent a key error in case the label is a specific label try: world = interpretations[comp] for world_l in world.world.keys(): - world.label_str = world.label.value - if label_str in world.label_str and world.label_str[len(label_str)+1:].replace('.', '').replace('-', '').isdigit(): + world_l_str = world_l.value + if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): # The label is contained in the world result = world.is_satisfied(world_l, na[1]) # Find the suffix number - number = str_to_float(world.label_str[len(label_str)+1:]) + number = str_to_float(world_l_str[len(l_str)+1:]) break except Exception: @@ -1846,7 +1846,7 @@ def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): @numba.njit(cache=True) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, num_ga, t): +def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): # If not a node, add to list of nodes and initialize neighbors if source not in nodes: _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) @@ -1856,7 +1856,7 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, # Make sure edge doesn't already exist # Make sure, if l=='', not to add the label - # Make sure, if edge exists, that we don't override the label label if it exists + # Make sure, if edge exists, that we don't override the l label if it exists edge = (source, target) new_edge = False if edge not in edges: @@ -1864,36 +1864,36 @@ def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, edges.append(edge) neighbors[source].append(target) reverse_neighbors[target].append(source) - if label.value!='': - interpretations_edge[edge] = world.World(numba.typed.List([label])) + if l.value!='': + interpretations_edge[edge] = world.World(numba.typed.List([l])) num_ga[t] += 1 - if label in predicate_map: - predicate_map[label].append(edge) + if l in predicate_map: + predicate_map[l].append(edge) else: - predicate_map[label] = numba.typed.List([edge]) + predicate_map[l] = numba.typed.List([edge]) else: interpretations_edge[edge] = 
world.World(numba.typed.List.empty_list(label.label_type)) else: - if label not in interpretations_edge[edge].world and label.value!='': + if l not in interpretations_edge[edge].world and l.value!='': new_edge = True - interpretations_edge[edge].world[label] = interval.closed(0, 1) + interpretations_edge[edge].world[l] = interval.closed(0, 1) num_ga[t] += 1 - if label in predicate_map: - predicate_map[label].append(edge) + if l in predicate_map: + predicate_map[l].append(edge) else: - predicate_map[label] = numba.typed.List([edge]) + predicate_map[l] = numba.typed.List([edge]) return edge, new_edge @numba.njit(cache=True) -def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, num_ga, t): +def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): changes = 0 edges_added = numba.typed.List.empty_list(edge_type) for source in sources: for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, label, interpretations_node, interpretations_edge, predicate_map, num_ga, t) + edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t) edges_added.append(edge) changes = changes+1 if new_edge else changes return edges_added, changes @@ -1905,9 +1905,9 @@ def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge edges.remove(edge) num_ga[-1] -= len(interpretations_edge[edge].world) del interpretations_edge[edge] - for lbl in predicate_map: - if edge in predicate_map[lbl]: - predicate_map[lbl].remove(edge) + for l in predicate_map: + if edge in predicate_map[l]: + predicate_map[l].remove(edge) neighbors[source].remove(target) reverse_neighbors[target].remove(source) @@ -1919,9 +1919,9 @@ def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node del interpretations_node[node] del neighbors[node] del reverse_neighbors[node] - for lbl in predicate_map: - if node in predicate_map[lbl]: - predicate_map[lbl].remove(node) + for l in predicate_map: + if node in predicate_map[l]: + predicate_map[l].remove(node) # Remove all occurrences of node in neighbors for n in neighbors.keys(): @@ -1964,4 +1964,4 @@ def str_to_int(value): for i, v in enumerate(value): result += (ord(v) - 48) * (10 ** (final_index - i)) result = -result if negative else result - return result + return result \ No newline at end of file From e15a4bcbb90eb452969e3c59dfe179b766f68648 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sun, 5 Oct 2025 14:56:45 -0400 Subject: [PATCH 03/32] Rm debug --- debug.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) delete mode 100644 debug.py diff --git a/debug.py b/debug.py deleted file mode 100644 index 52e29773..00000000 --- a/debug.py +++ /dev/null @@ -1,38 +0,0 @@ -import pyreason as pr - -def test_anyBurl_rule_1_fp(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.fp_version = True # Use the FP version of the reasoner - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = 
False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - -if __name__ == "__main__": - test_anyBurl_rule_1_fp() \ No newline at end of file From 2db5fb6f61326e0d0b0065031eab50571add257a Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 9 Oct 2025 11:01:49 -0400 Subject: [PATCH 04/32] Add linter to pre-commit hooks --- .pre-commit-config.yaml | 8 + .../interpretation/interpretation.py.bak | 1967 ----------------- 2 files changed, 8 insertions(+), 1967 deletions(-) delete mode 100755 pyreason/scripts/interpretation/interpretation.py.bak diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28a4906f..0b94a67e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,14 @@ repos: - repo: local hooks: # --- COMMIT STAGE: Fast unit tests only --- + - id: ruff-check + name: Run ruff linter + entry: .venv/bin/python -m ruff check + language: system + types: [python] + pass_filenames: false + stages: [pre-commit] + - id: pytest-unit-no-jit name: Run JIT-disabled unit tests entry: .venv/bin/python -m pytest tests/unit/disable_jit -m "not slow" --tb=short -q diff --git a/pyreason/scripts/interpretation/interpretation.py.bak b/pyreason/scripts/interpretation/interpretation.py.bak deleted file mode 100755 index 81bf6bb0..00000000 --- a/pyreason/scripts/interpretation/interpretation.py.bak +++ /dev/null @@ -1,1967 +0,0 @@ -from typing import Union, Tuple - -import pyreason.scripts.numba_wrapper.numba_types.world_type as world -import pyreason.scripts.numba_wrapper.numba_types.label_type as label -import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval -from pyreason.scripts.interpretation.interpretation_dict import InterpretationDict - -import numba -from numba import objmode, prange - - -# Types for the dictionaries -node_type = numba.types.string -edge_type = numba.types.UniTuple(numba.types.string, 2) - -# Type for storing list of qualified nodes/edges -list_of_nodes = numba.types.ListType(node_type) -list_of_edges = numba.types.ListType(edge_type) - -# Type for storing clause data -clause_data = numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string))) - -# Type for storing refine clause data -refine_data = numba.types.Tuple((numba.types.string, 
numba.types.string, numba.types.int8)) - -# Type for facts to be applied -facts_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) -facts_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean, numba.types.boolean)) - -# Type for returning list of applicable rules for a certain rule -# node/edge, annotations, qualified nodes, qualified edges, edges to be added -node_applicable_rule_type = numba.types.Tuple(( - node_type, - numba.types.ListType(numba.types.ListType(interval.interval_type)), - numba.types.ListType(numba.types.ListType(node_type)), - numba.types.ListType(numba.types.ListType(edge_type)), - numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) -)) - -edge_applicable_rule_type = numba.types.Tuple(( - edge_type, - numba.types.ListType(numba.types.ListType(interval.interval_type)), - numba.types.ListType(numba.types.ListType(node_type)), - numba.types.ListType(numba.types.ListType(edge_type)), - numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) -)) - -rules_to_be_applied_node_type = numba.types.Tuple((numba.types.uint16, node_type, label.label_type, interval.interval_type, numba.types.boolean)) -rules_to_be_applied_edge_type = numba.types.Tuple((numba.types.uint16, edge_type, label.label_type, interval.interval_type, numba.types.boolean)) -rules_to_be_applied_trace_type = numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), numba.types.string)) -edges_to_be_added_type = numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type)) - - -class Interpretation: - specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type)) - specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type)) - - def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules): - self.graph = graph - self.ipl = ipl - self.annotation_functions = annotation_functions - self.reverse_graph = reverse_graph - self.atom_trace = atom_trace - self.save_graph_attributes_to_rule_trace = save_graph_attributes_to_rule_trace - self.persistent = persistent - self.inconsistency_check = inconsistency_check - self.store_interpretation_changes = store_interpretation_changes - self.update_mode = update_mode - self.allow_ground_rules = allow_ground_rules - - # Counter for number of ground atoms for each timestep, start with zero for the zeroth timestep - self.num_ga = numba.typed.List.empty_list(numba.types.int64) - self.num_ga.append(0) - - # For reasoning and reasoning again (contains previous time and previous fp operation cnt) - self.time = 0 - self.prev_reasoning_data = numba.typed.List([0, 0]) - - # Initialize list of tuples for rules/facts to be applied, along with all the ground atoms that fired the rule. 
One to One correspondence between rules_to_be_applied_node and rules_to_be_applied_node_trace if atom_trace is true - self.rules_to_be_applied_node_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type) - self.rules_to_be_applied_edge_trace = numba.typed.List.empty_list(rules_to_be_applied_trace_type) - self.facts_to_be_applied_node_trace = numba.typed.List.empty_list(numba.types.string) - self.facts_to_be_applied_edge_trace = numba.typed.List.empty_list(numba.types.string) - self.rules_to_be_applied_node = numba.typed.List.empty_list(rules_to_be_applied_node_type) - self.rules_to_be_applied_edge = numba.typed.List.empty_list(rules_to_be_applied_edge_type) - self.facts_to_be_applied_node = numba.typed.List.empty_list(facts_to_be_applied_node_type) - self.facts_to_be_applied_edge = numba.typed.List.empty_list(facts_to_be_applied_edge_type) - self.edges_to_be_added_node_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))) - self.edges_to_be_added_edge_rule = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(node_type), numba.types.ListType(node_type), label.label_type))) - - # Keep track of all the rules that have affected each node/edge at each timestep/fp operation, and all ground atoms that have affected the rules as well. Keep track of previous bounds and name of the rule/fact here - self.rule_trace_node_atoms = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), interval.interval_type, numba.types.string))) - self.rule_trace_edge_atoms = numba.typed.List.empty_list(numba.types.Tuple((numba.types.ListType(numba.types.ListType(node_type)), numba.types.ListType(numba.types.ListType(edge_type)), interval.interval_type, numba.types.string))) - self.rule_trace_node = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, numba.types.uint16, node_type, label.label_type, interval.interval_type))) - self.rule_trace_edge = numba.typed.List.empty_list(numba.types.Tuple((numba.types.uint16, numba.types.uint16, edge_type, label.label_type, interval.interval_type))) - - # Nodes and edges of the graph - self.nodes = numba.typed.List.empty_list(node_type) - self.edges = numba.typed.List.empty_list(edge_type) - self.nodes.extend(numba.typed.List(self.graph.nodes())) - self.edges.extend(numba.typed.List(self.graph.edges())) - - self.interpretations_node, self.predicate_map_node = self._init_interpretations_node(self.nodes, self.specific_node_labels, self.num_ga) - self.interpretations_edge, self.predicate_map_edge = self._init_interpretations_edge(self.edges, self.specific_edge_labels, self.num_ga) - - # Setup graph neighbors and reverse neighbors - self.neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=numba.types.ListType(node_type)) - for n in self.graph.nodes(): - l = numba.typed.List.empty_list(node_type) - [l.append(neigh) for neigh in self.graph.neighbors(n)] - self.neighbors[n] = l - - self.reverse_neighbors = self._init_reverse_neighbors(self.neighbors) - - @staticmethod - @numba.njit(cache=True) - def _init_reverse_neighbors(neighbors): - reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) - for n, neighbor_nodes in neighbors.items(): - for neighbor_node in neighbor_nodes: - if neighbor_node in reverse_neighbors and n not in reverse_neighbors[neighbor_node]: - reverse_neighbors[neighbor_node].append(n) - else: - 
reverse_neighbors[neighbor_node] = numba.typed.List([n]) - # This makes sure each node has a value - if n not in reverse_neighbors: - reverse_neighbors[n] = numba.typed.List.empty_list(node_type) - - return reverse_neighbors - - @staticmethod - @numba.njit(cache=True) - def _init_interpretations_node(nodes, specific_labels, num_ga): - interpretations = numba.typed.Dict.empty(key_type=node_type, value_type=world.world_type) - predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_nodes) - - # Initialize nodes - for n in nodes: - interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) - - # Specific labels - for l, ns in specific_labels.items(): - predicate_map[l] = numba.typed.List(ns) - for n in ns: - interpretations[n].world[l] = interval.closed(0.0, 1.0) - num_ga[0] += 1 - - return interpretations, predicate_map - - @staticmethod - @numba.njit(cache=True) - def _init_interpretations_edge(edges, specific_labels, num_ga): - interpretations = numba.typed.Dict.empty(key_type=edge_type, value_type=world.world_type) - predicate_map = numba.typed.Dict.empty(key_type=label.label_type, value_type=list_of_edges) - - # Initialize edges - for n in edges: - interpretations[n] = world.World(numba.typed.List.empty_list(label.label_type)) - - # Specific labels - for l, es in specific_labels.items(): - predicate_map[l] = numba.typed.List(es) - for e in es: - interpretations[e].world[l] = interval.closed(0.0, 1.0) - num_ga[0] += 1 - - return interpretations, predicate_map - - @staticmethod - @numba.njit(cache=True) - def _init_convergence(convergence_bound_threshold, convergence_threshold): - if convergence_bound_threshold==-1 and convergence_threshold==-1: - convergence_mode = 'perfect_convergence' - convergence_delta = 0 - elif convergence_bound_threshold==-1: - convergence_mode = 'delta_interpretation' - convergence_delta = convergence_threshold - else: - convergence_mode = 'delta_bound' - convergence_delta = convergence_bound_threshold - return convergence_mode, convergence_delta - - def start_fp(self, tmax, facts_node, facts_edge, rules, verbose, convergence_threshold, convergence_bound_threshold, again=False, restart=True): - self.tmax = tmax - self._convergence_mode, self._convergence_delta = self._init_convergence(convergence_bound_threshold, convergence_threshold) - max_facts_time = self._init_facts(facts_node, facts_edge, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.atom_trace) - self._start_fp(rules, max_facts_time, verbose, again, restart) - - @staticmethod - @numba.njit(cache=True) - def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, atom_trace): - max_time = 0 - for fact in facts_node: - for t in range(fact.get_time_lower(), fact.get_time_upper() + 1): - max_time = max(max_time, t) - name = fact.get_name() - graph_attribute = True if name=='graph-attribute-fact' else False - facts_to_be_applied_node.append((numba.types.uint16(t), fact.get_component(), fact.get_label(), fact.get_bound(), fact.static, graph_attribute)) - if atom_trace: - facts_to_be_applied_node_trace.append(fact.get_name()) - for fact in facts_edge: - for t in range(fact.get_time_lower(), fact.get_time_upper() + 1): - max_time = max(max_time, t) - name = fact.get_name() - graph_attribute = True if name=='graph-attribute-fact' else False - 
facts_to_be_applied_edge.append((numba.types.uint16(t), fact.get_component(), fact.get_label(), fact.get_bound(), fact.static, graph_attribute)) - if atom_trace: - facts_to_be_applied_edge_trace.append(fact.get_name()) - return max_time - - def _start_fp(self, rules, max_facts_time, verbose, again, restart): - if again: - self.num_ga.append(self.num_ga[-1]) - if restart: - self.time = 0 - self.prev_reasoning_data[0] = 0 - fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.predicate_map_node, self.predicate_map_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.persistent, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, self.num_ga, verbose, again) - self.time = t - 1 - # If we need to reason again, store the next timestep to start from - self.prev_reasoning_data[0] = t - self.prev_reasoning_data[1] = fp_cnt - if verbose: - print('Fixed Point iterations:', fp_cnt) - - @staticmethod - @numba.njit(cache=True, parallel=False) - def reason(interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, persistent, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, num_ga, verbose, again): - t = prev_reasoning_data[0] - fp_cnt = prev_reasoning_data[1] - max_rules_time = 0 - timestep_loop = True - facts_to_be_applied_node_new = numba.typed.List.empty_list(facts_to_be_applied_node_type) - facts_to_be_applied_edge_new = numba.typed.List.empty_list(facts_to_be_applied_edge_type) - facts_to_be_applied_node_trace_new = numba.typed.List.empty_list(numba.types.string) - facts_to_be_applied_edge_trace_new = numba.typed.List.empty_list(numba.types.string) - rules_to_remove_idx = set() - rules_to_remove_idx.add(-1) - while timestep_loop: - if t==tmax: - timestep_loop = False - if verbose: - with objmode(): - print('Timestep:', t, flush=True) - # Reset Interpretation at beginning of timestep if non-persistent - if t>0 and not persistent: - # Reset nodes (only if not static) - for n in nodes: - w = interpretations_node[n].world - for l in w: - if not w[l].is_static(): - w[l].reset() - - # Reset edges (only if not static) - for e in edges: - w = interpretations_edge[e].world - for l in w: - if not w[l].is_static(): 
- w[l].reset() - - # Convergence parameters - changes_cnt = 0 - bound_delta = 0 - update = False - - # Start by applying facts - # Nodes - facts_to_be_applied_node_new.clear() - facts_to_be_applied_node_trace_new.clear() - nodes_set = set(nodes) - for i in range(len(facts_to_be_applied_node)): - if facts_to_be_applied_node[i][0] == t: - comp, l, bnd, static, graph_attribute = facts_to_be_applied_node[i][1], facts_to_be_applied_node[i][2], facts_to_be_applied_node[i][3], facts_to_be_applied_node[i][4], facts_to_be_applied_node[i][5] - # If the component is not in the graph, add it - if comp not in nodes_set: - _add_node(comp, neighbors, reverse_neighbors, nodes, interpretations_node) - nodes_set.add(comp) - - # Check if bnd is static. Then no need to update, just add to rule trace, check if graph attribute and add ipl complement to rule trace as well - if l in interpretations_node[comp].world and interpretations_node[comp].world[l].is_static(): - # Check if we should even store any of the changes to the rule trace etc. - # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute - if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, bnd)) - if atom_trace: - _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_node_trace[i]) - for p1, p2 in ipl: - if p1==l: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_node[comp].world[p2])) - if atom_trace: - _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p2], facts_to_be_applied_node_trace[i]) - elif p2==l: - rule_trace_node.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_node[comp].world[p1])) - if atom_trace: - _update_rule_trace(rule_trace_node_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_node[comp].world[p1], facts_to_be_applied_node_trace[i]) - - else: - # Check for inconsistencies (multiple facts) - if check_consistent_node(interpretations_node, comp, (l, bnd)): - mode = 'graph-attribute-fact' if graph_attribute else 'fact' - override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - # Resolve inconsistency if necessary otherwise override bounds - else: - mode = 'graph-attribute-fact' if graph_attribute else 'fact' - if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, 
mode=mode) - else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, i, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - - if static: - facts_to_be_applied_node_new.append((numba.types.uint16(facts_to_be_applied_node[i][0]+1), comp, l, bnd, static, graph_attribute)) - if atom_trace: - facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i]) - - # If time doesn't match, fact to be applied later - else: - facts_to_be_applied_node_new.append(facts_to_be_applied_node[i]) - if atom_trace: - facts_to_be_applied_node_trace_new.append(facts_to_be_applied_node_trace[i]) - - # Update list of facts with ones that have not been applied yet (delete applied facts) - facts_to_be_applied_node[:] = facts_to_be_applied_node_new.copy() - if atom_trace: - facts_to_be_applied_node_trace[:] = facts_to_be_applied_node_trace_new.copy() - facts_to_be_applied_node_new.clear() - facts_to_be_applied_node_trace_new.clear() - - # Edges - facts_to_be_applied_edge_new.clear() - facts_to_be_applied_edge_trace_new.clear() - edges_set = set(edges) - for i in range(len(facts_to_be_applied_edge)): - if facts_to_be_applied_edge[i][0]==t: - comp, l, bnd, static, graph_attribute = facts_to_be_applied_edge[i][1], facts_to_be_applied_edge[i][2], facts_to_be_applied_edge[i][3], facts_to_be_applied_edge[i][4], facts_to_be_applied_edge[i][5] - # If the component is not in the graph, add it - if comp not in edges_set: - _add_edge(comp[0], comp[1], neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) - edges_set.add(comp) - - # Check if bnd is static. 
Then no need to update, just add to rule trace, check if graph attribute, and add ipl complement to rule trace as well - if l in interpretations_edge[comp].world and interpretations_edge[comp].world[l].is_static(): - # Inverse of this is: if not save_graph_attributes_to_rule_trace and graph_attribute - if (save_graph_attributes_to_rule_trace or not graph_attribute) and store_interpretation_changes: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, l, interpretations_edge[comp].world[l])) - if atom_trace: - _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), bnd, facts_to_be_applied_edge_trace[i]) - for p1, p2 in ipl: - if p1==l: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p2, interpretations_edge[comp].world[p2])) - if atom_trace: - _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p2], facts_to_be_applied_edge_trace[i]) - elif p2==l: - rule_trace_edge.append((numba.types.uint16(t), numba.types.uint16(fp_cnt), comp, p1, interpretations_edge[comp].world[p1])) - if atom_trace: - _update_rule_trace(rule_trace_edge_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), interpretations_edge[comp].world[p1], facts_to_be_applied_edge_trace[i]) - else: - # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (l, bnd)): - mode = 'graph-attribute-fact' if graph_attribute else 'fact' - override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=override) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - # Resolve inconsistency - else: - mode = 'graph-attribute-fact' if graph_attribute else 'fact' - if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, i, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode=mode) - else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, i, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode=mode, override=True) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - - if static: - facts_to_be_applied_edge_new.append((numba.types.uint16(facts_to_be_applied_edge[i][0]+1), comp, l, bnd, static, graph_attribute)) - if atom_trace: - facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i]) - - # Time doesn't match, fact to be applied later - else: - 
facts_to_be_applied_edge_new.append(facts_to_be_applied_edge[i]) - if atom_trace: - facts_to_be_applied_edge_trace_new.append(facts_to_be_applied_edge_trace[i]) - - # Update list of facts with ones that have not been applied yet (delete applied facts) - facts_to_be_applied_edge[:] = facts_to_be_applied_edge_new.copy() - if atom_trace: - facts_to_be_applied_edge_trace[:] = facts_to_be_applied_edge_trace_new.copy() - facts_to_be_applied_edge_new.clear() - facts_to_be_applied_edge_trace_new.clear() - - in_loop = True - while in_loop: - # This will become true only if delta_t = 0 for some rule, otherwise we go to the next timestep - in_loop = False - - # Apply the rules that need to be applied at this timestep - # Nodes - rules_to_remove_idx.clear() - for idx, i in enumerate(rules_to_be_applied_node): - if i[0] == t: - comp, l, bnd, set_static = i[1], i[2], i[3], i[4] - # Check for inconsistencies - if check_consistent_node(interpretations_node, comp, (l, bnd)): - override = True if update_mode == 'override' else False - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - # Resolve inconsistency - else: - if inconsistency_check: - resolve_inconsistency_node(interpretations_node, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_node, rule_trace_node_atoms, rules_to_be_applied_node_trace, facts_to_be_applied_node_trace, store_interpretation_changes, mode='rule') - else: - u, changes = _update_node(interpretations_node, predicate_map_node, comp, (l, bnd), ipl, rule_trace_node, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_node_trace, idx, facts_to_be_applied_node_trace, rule_trace_node_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - - # Delete rules that have been applied from list by adding index to list - rules_to_remove_idx.add(idx) - - # Remove from rules to be applied and edges to be applied lists after coming out from loop - rules_to_be_applied_node[:] = numba.typed.List([rules_to_be_applied_node[i] for i in range(len(rules_to_be_applied_node)) if i not in rules_to_remove_idx]) - edges_to_be_added_node_rule[:] = numba.typed.List([edges_to_be_added_node_rule[i] for i in range(len(edges_to_be_added_node_rule)) if i not in rules_to_remove_idx]) - if atom_trace: - rules_to_be_applied_node_trace[:] = numba.typed.List([rules_to_be_applied_node_trace[i] for i in range(len(rules_to_be_applied_node_trace)) if i not in rules_to_remove_idx]) - - # Edges - rules_to_remove_idx.clear() - for idx, i in enumerate(rules_to_be_applied_edge): - if i[0] == t: - comp, l, bnd, set_static = i[1], i[2], i[3], i[4] - sources, targets, edge_l = edges_to_be_added_edge_rule[idx] - edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) - changes_cnt += changes - - # 
Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally - if edge_l.value != '': - for e in edges_added: - if interpretations_edge[e].world[edge_l].is_static(): - continue - if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)): - override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) - - update = u or update - - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - # Resolve inconsistency - else: - if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, e, (edge_l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') - else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) - - update = u or update - - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - - else: - # Check for inconsistencies - if check_consistent_edge(interpretations_edge, comp, (l, bnd)): - override = True if update_mode == 'override' else False - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - # Resolve inconsistency - else: - if inconsistency_check: - resolve_inconsistency_edge(interpretations_edge, comp, (l, bnd), ipl, t, fp_cnt, idx, atom_trace, rule_trace_edge, rule_trace_edge_atoms, rules_to_be_applied_edge_trace, facts_to_be_applied_edge_trace, store_interpretation_changes, mode='rule') - else: - u, changes = _update_edge(interpretations_edge, predicate_map_edge, comp, (l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=True) - - update = u or update - # Update convergence params - if convergence_mode=='delta_bound': - bound_delta = max(bound_delta, changes) - else: - changes_cnt += changes - - # Delete rules that have been applied from list by adding the index to list - rules_to_remove_idx.add(idx) - - # Remove from rules to be applied and edges to be applied lists after coming out from loop - rules_to_be_applied_edge[:] = numba.typed.List([rules_to_be_applied_edge[i] 
for i in range(len(rules_to_be_applied_edge)) if i not in rules_to_remove_idx]) - edges_to_be_added_edge_rule[:] = numba.typed.List([edges_to_be_added_edge_rule[i] for i in range(len(edges_to_be_added_edge_rule)) if i not in rules_to_remove_idx]) - if atom_trace: - rules_to_be_applied_edge_trace[:] = numba.typed.List([rules_to_be_applied_edge_trace[i] for i in range(len(rules_to_be_applied_edge_trace)) if i not in rules_to_remove_idx]) - - # Fixed point - if update: - # Increase fp operator count - fp_cnt += 1 - - # Lists or threadsafe operations (when parallel is on) - rules_to_be_applied_node_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_node_type) for _ in range(len(rules))]) - rules_to_be_applied_edge_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_edge_type) for _ in range(len(rules))]) - if atom_trace: - rules_to_be_applied_node_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))]) - rules_to_be_applied_edge_trace_threadsafe = numba.typed.List([numba.typed.List.empty_list(rules_to_be_applied_trace_type) for _ in range(len(rules))]) - edges_to_be_added_edge_rule_threadsafe = numba.typed.List([numba.typed.List.empty_list(edges_to_be_added_type) for _ in range(len(rules))]) - - for i in prange(len(rules)): - rule = rules[i] - - # Only go through if the rule can be applied within the given timesteps, or we're running until convergence - delta_t = rule.get_delta() - if t + delta_t <= tmax or tmax == -1 or again: - applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules, num_ga, t) - - # Loop through applicable rules and add them to the rules to be applied for later or next fp operation - for applicable_rule in applicable_node_rules: - n, annotations, qualified_nodes, qualified_edges, _ = applicable_rule - # If there is an edge to add or the predicate doesn't exist or the interpretation is not static - if rule.get_target() not in interpretations_node[n].world or not interpretations_node[n].world[rule.get_target()].is_static(): - bnd = annotate(annotation_functions, rule, annotations, rule.get_weights()) - # Bound annotations in between 0 and 1 - bnd_l = min(max(bnd[0], 0), 1) - bnd_u = min(max(bnd[1], 0), 1) - bnd = interval.closed(bnd_l, bnd_u) - max_rules_time = max(max_rules_time, t + delta_t) - rules_to_be_applied_node_threadsafe[i].append((numba.types.uint16(t + delta_t), n, rule.get_target(), bnd, rule.is_static_rule())) - if atom_trace: - rules_to_be_applied_node_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name())) - - # If delta_t is zero we apply the rules and check if more are applicable - if delta_t == 0: - in_loop = True - update = False - - for applicable_rule in applicable_edge_rules: - e, annotations, qualified_nodes, qualified_edges, edges_to_add = applicable_rule - # If there is an edge to add or the predicate doesn't exist or the interpretation is not static - if len(edges_to_add[0]) > 0 or rule.get_target() not in interpretations_edge[e].world or not interpretations_edge[e].world[rule.get_target()].is_static(): - bnd = annotate(annotation_functions, rule, annotations, rule.get_weights()) - # Bound annotations in between 0 and 1 - bnd_l = min(max(bnd[0], 0), 1) - bnd_u = min(max(bnd[1], 0), 1) - bnd = interval.closed(bnd_l, bnd_u) - 
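-                            # The clamp above is needed because user-supplied annotation functions may return bounds outside [0, 1], which the interval semantics assume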
max_rules_time = max(max_rules_time, t+delta_t) - # edges_to_be_added_edge_rule.append(edges_to_add) - edges_to_be_added_edge_rule_threadsafe[i].append(edges_to_add) - rules_to_be_applied_edge_threadsafe[i].append((numba.types.uint16(t+delta_t), e, rule.get_target(), bnd, rule.is_static_rule())) - if atom_trace: - # rules_to_be_applied_edge_trace.append((qualified_nodes, qualified_edges, rule.get_name())) - rules_to_be_applied_edge_trace_threadsafe[i].append((qualified_nodes, qualified_edges, rule.get_name())) - - # If delta_t is zero we apply the rules and check if more are applicable - if delta_t == 0: - in_loop = True - update = False - - # Update lists after parallel run - for i in range(len(rules)): - if len(rules_to_be_applied_node_threadsafe[i]) > 0: - rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i]) - if len(rules_to_be_applied_edge_threadsafe[i]) > 0: - rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i]) - if atom_trace: - if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0: - rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i]) - if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0: - rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i]) - if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0: - edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i]) - - # Check for convergence after each timestep (perfect convergence or convergence specified by user) - # Check number of changed interpretations or max bound change - # User specified convergence - if convergence_mode == 'delta_interpretation': - if changes_cnt <= convergence_delta: - if verbose: - print(f'\nConverged at time: {t} with {int(changes_cnt)} changes from the previous interpretation') - # Be consistent with time returned when we don't converge - t += 1 - break - elif convergence_mode == 'delta_bound': - if bound_delta <= convergence_delta: - if verbose: - print(f'\nConverged at time: {t} with {float_to_str(bound_delta)} as the maximum bound change from the previous interpretation') - # Be consistent with time returned when we don't converge - t += 1 - break - # Perfect convergence - # Make sure there are no rules to be applied, and no facts that will be applied in the future. 
We do this by checking the max time any rule/fact is applicable - # If no more rules/facts to be applied - elif convergence_mode == 'perfect_convergence': - if t>=max_facts_time and t >= max_rules_time: - if verbose: - print(f'\nConverged at time: {t}') - # Be consistent with time returned when we don't converge - t += 1 - break - - # Increment t, update number of ground atoms - t += 1 - num_ga.append(num_ga[-1]) - - return fp_cnt, t - - def add_edge(self, edge, l): - # This function is useful for pyreason gym, called externally - _add_edge(edge[0], edge[1], self.neighbors, self.reverse_neighbors, self.nodes, self.edges, l, self.interpretations_node, self.interpretations_edge, self.predicate_map_edge, self.num_ga, -1) - - def add_node(self, node, labels): - # This function is useful for pyreason gym, called externally - if node not in self.nodes: - _add_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node) - for l in labels: - self.interpretations_node[node].world[label.Label(l)] = interval.closed(0, 1) - - def delete_edge(self, edge): - # This function is useful for pyreason gym, called externally - _delete_edge(edge, self.neighbors, self.reverse_neighbors, self.edges, self.interpretations_edge, self.predicate_map_edge, self.num_ga) - - def delete_node(self, node): - # This function is useful for pyreason gym, called externally - _delete_node(node, self.neighbors, self.reverse_neighbors, self.nodes, self.interpretations_node, self.predicate_map_node, self.num_ga) - - def get_dict(self): - # This function can be called externally to retrieve a dict of the interpretation values - # Only values in the rule trace will be added - - # Initialize interpretations for each time and node and edge - interpretations = {} - for t in range(self.time+1): - interpretations[t] = {} - for node in self.nodes: - interpretations[t][node] = InterpretationDict() - for edge in self.edges: - interpretations[t][edge] = InterpretationDict() - - # Update interpretation nodes - for change in self.rule_trace_node: - time, _, node, l, bnd = change - interpretations[time][node][l._value] = (bnd.lower, bnd.upper) - - # If persistent, update all following timesteps as well - if self. persistent: - for t in range(time+1, self.time+1): - interpretations[t][node][l._value] = (bnd.lower, bnd.upper) - - # Update interpretation edges - for change in self.rule_trace_edge: - time, _, edge, l, bnd, = change - interpretations[time][edge][l._value] = (bnd.lower, bnd.upper) - - # If persistent, update all following timesteps as well - if self. 
persistent: - for t in range(time+1, self.time+1): - interpretations[t][edge][l._value] = (bnd.lower, bnd.upper) - - return interpretations - - def get_final_num_ground_atoms(self): - """ - This function returns the number of ground atoms after the reasoning process, for the final timestep - :return: int: Number of ground atoms in the interpretation after reasoning - """ - ga_cnt = 0 - - for node in self.nodes: - for l in self.interpretations_node[node].world: - ga_cnt += 1 - for edge in self.edges: - for l in self.interpretations_edge[edge].world: - ga_cnt += 1 - - return ga_cnt - - def get_num_ground_atoms(self): - """ - This function returns the number of ground atoms after the reasoning process, for each timestep - :return: list: Number of ground atoms in the interpretation after reasoning for each timestep - """ - if self.num_ga[-1] == 0: - self.num_ga.pop() - return self.num_ga - - def query(self, query, return_bool=True) -> Union[bool, Tuple[float, float]]: - """ - This function is used to query the graph after reasoning - :param query: A PyReason query object - :param return_bool: If True, returns boolean of query, else the bounds associated with it - :return: bool, or bounds - """ - - comp_type = query.get_component_type() - component = query.get_component() - pred = query.get_predicate() - bnd = query.get_bounds() - - # Check if the component exists - if comp_type == 'node': - if component not in self.nodes: - return False if return_bool else (0, 0) - else: - if component not in self.edges: - return False if return_bool else (0, 0) - - # Check if the predicate exists - if comp_type == 'node': - if pred not in self.interpretations_node[component].world: - return False if return_bool else (0, 0) - else: - if pred not in self.interpretations_edge[component].world: - return False if return_bool else (0, 0) - - # Check if the bounds are satisfied - if comp_type == 'node': - if self.interpretations_node[component].world[pred] in bnd: - return True if return_bool else (self.interpretations_node[component].world[pred].lower, self.interpretations_node[component].world[pred].upper) - else: - return False if return_bool else (0, 0) - else: - if self.interpretations_edge[component].world[pred] in bnd: - return True if return_bool else (self.interpretations_edge[component].world[pred].lower, self.interpretations_edge[component].world[pred].upper) - else: - return False if return_bool else (0, 0) - - -@numba.njit(cache=True) -def _ground_rule(rule, interpretations_node, interpretations_edge, predicate_map_node, predicate_map_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules, num_ga, t): - # Extract rule params - rule_type = rule.get_type() - head_variables = rule.get_head_variables() - clauses = rule.get_clauses() - thresholds = rule.get_thresholds() - ann_fn = rule.get_annotation_function() - rule_edges = rule.get_edges() - - if rule_type == 'node': - head_var_1 = head_variables[0] - else: - head_var_1, head_var_2 = head_variables[0], head_variables[1] - - # We return a list of tuples which specify the target nodes/edges that have made the rule body true - applicable_rules_node = numba.typed.List.empty_list(node_applicable_rule_type) - applicable_rules_edge = numba.typed.List.empty_list(edge_applicable_rule_type) - - # Grounding procedure - # 1. Go through each clause and check which variables have not been initialized in groundings - # 2. 
Check satisfaction of variables based on the predicate in the clause - - # Grounding variable that maps variables in the body to a list of grounded nodes - # Grounding edges that maps edge variables to a list of edges - groundings = numba.typed.Dict.empty(key_type=numba.types.string, value_type=list_of_nodes) - groundings_edges = numba.typed.Dict.empty(key_type=edge_type, value_type=list_of_edges) - - # Dependency graph that keeps track of the connections between the variables in the body - dependency_graph_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) - dependency_graph_reverse_neighbors = numba.typed.Dict.empty(key_type=node_type, value_type=list_of_nodes) - - nodes_set = set(nodes) - edges_set = set(edges) - - satisfaction = True - for i, clause in enumerate(clauses): - # Unpack clause variables - clause_type = clause[0] - clause_label = clause[1] - clause_variables = clause[2] - clause_bnd = clause[3] - clause_operator = clause[4] - - # This is a node clause - if clause_type == 'node': - clause_var_1 = clause_variables[0] - - # Get subset of nodes that can be used to ground the variable - # If we allow ground atoms, we can use the nodes directly - if allow_ground_rules and clause_var_1 in nodes_set: - grounding = numba.typed.List([clause_var_1]) - else: - grounding = get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map_node, clause_label, nodes) - - # Narrow subset based on predicate - qualified_groundings = get_qualified_node_groundings(interpretations_node, grounding, clause_label, clause_bnd) - groundings[clause_var_1] = qualified_groundings - qualified_groundings_set = set(qualified_groundings) - for c1, c2 in groundings_edges: - if c1 == clause_var_1: - groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[0] in qualified_groundings_set]) - if c2 == clause_var_1: - groundings_edges[(c1, c2)] = numba.typed.List([e for e in groundings_edges[(c1, c2)] if e[1] in qualified_groundings_set]) - - # Check satisfaction of those nodes wrt the threshold - # Only check satisfaction if the default threshold is used. This saves us from grounding the rest of the rule - # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges - # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0: - satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction - - # This is an edge clause - elif clause_type == 'edge': - clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] - - # Get subset of edges that can be used to ground the variables - # If we allow ground atoms, we can use the nodes directly - if allow_ground_rules and (clause_var_1, clause_var_2) in edges_set: - grounding = numba.typed.List([(clause_var_1, clause_var_2)]) - else: - grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map_edge, clause_label, edges) - - # Narrow subset based on predicate (save the edges that are qualified to use for finding future groundings faster) - qualified_groundings = get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, clause_bnd) - - # Check satisfaction of those edges wrt the threshold - # Only check satisfaction if the default threshold is used. 
This saves us from grounding the rest of the rule - # It doesn't make sense to check any other thresholds because the head could be grounded with multiple nodes/edges - # if thresholds[i][1][0] == 'number' and thresholds[i][1][1] == 'total' and thresholds[i][2] == 1.0: - satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_groundings, clause_label, thresholds[i]) and satisfaction - - # Update the groundings - groundings[clause_var_1] = numba.typed.List.empty_list(node_type) - groundings[clause_var_2] = numba.typed.List.empty_list(node_type) - groundings_clause_1_set = set(groundings[clause_var_1]) - groundings_clause_2_set = set(groundings[clause_var_2]) - for e in qualified_groundings: - if e[0] not in groundings_clause_1_set: - groundings[clause_var_1].append(e[0]) - groundings_clause_1_set.add(e[0]) - if e[1] not in groundings_clause_2_set: - groundings[clause_var_2].append(e[1]) - groundings_clause_2_set.add(e[1]) - - # Update the edge groundings (to use later for grounding other clauses with the same variables) - groundings_edges[(clause_var_1, clause_var_2)] = qualified_groundings - - # Update dependency graph - # Add a connection between clause_var_1 -> clause_var_2 and vice versa - if clause_var_1 not in dependency_graph_neighbors: - dependency_graph_neighbors[clause_var_1] = numba.typed.List([clause_var_2]) - elif clause_var_2 not in dependency_graph_neighbors[clause_var_1]: - dependency_graph_neighbors[clause_var_1].append(clause_var_2) - if clause_var_2 not in dependency_graph_reverse_neighbors: - dependency_graph_reverse_neighbors[clause_var_2] = numba.typed.List([clause_var_1]) - elif clause_var_1 not in dependency_graph_reverse_neighbors[clause_var_2]: - dependency_graph_reverse_neighbors[clause_var_2].append(clause_var_1) - - # This is a comparison clause - else: - pass - - # Refine the subsets based on any updates - if satisfaction: - refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors) - - # If satisfaction is false, break - if not satisfaction: - break - - # If satisfaction is still true, one final refinement to check if each edge pair is valid in edge rules - # Then continue to setup any edges to be added and annotations - # Fill out the rules to be applied lists - if satisfaction: - # Create temp grounding containers to verify if the head groundings are valid (only for edge rules) - # Setup edges to be added and fill rules to be applied - # Setup traces and inputs for annotation function - # Loop through the clause data and setup final annotations and trace variables - # Three cases: 1.node rule, 2. edge rule with infer edges, 3. 
edge rule - if rule_type == 'node': - # Loop through all the head variable groundings and add it to the rules to be applied - # Loop through the clauses and add appropriate trace data and annotations - - # If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph - head_var_1_in_nodes = head_var_1 in nodes - add_head_var_node_to_graph = False - if allow_ground_rules and head_var_1_in_nodes: - groundings[head_var_1] = numba.typed.List([head_var_1]) - elif head_var_1 not in groundings: - if not head_var_1_in_nodes: - add_head_var_node_to_graph = True - groundings[head_var_1] = numba.typed.List([head_var_1]) - - for head_grounding in groundings[head_var_1]: - qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) - qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) - annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type)) - edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1]) - - # Check for satisfaction one more time in case the refining process has changed the groundings - satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges) - if not satisfaction: - continue - - for i, clause in enumerate(clauses): - clause_type = clause[0] - clause_label = clause[1] - clause_variables = clause[2] - - if clause_type == 'node': - clause_var_1 = clause_variables[0] - - # 1. - if atom_trace: - if clause_var_1 == head_var_1: - qualified_nodes.append(numba.typed.List([head_grounding])) - else: - qualified_nodes.append(numba.typed.List(groundings[clause_var_1])) - qualified_edges.append(numba.typed.List.empty_list(edge_type)) - # 2. - if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - if clause_var_1 == head_var_1: - a.append(interpretations_node[head_grounding].world[clause_label]) - else: - for qn in groundings[clause_var_1]: - a.append(interpretations_node[qn].world[clause_label]) - annotations.append(a) - - elif clause_type == 'edge': - clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] - # 1. - if atom_trace: - # Cases: Both equal, one equal, none equal - qualified_nodes.append(numba.typed.List.empty_list(node_type)) - if clause_var_1 == head_var_1: - es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_grounding]) - qualified_edges.append(es) - elif clause_var_2 == head_var_1: - es = numba.typed.List([e for e in groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_grounding]) - qualified_edges.append(es) - else: - qualified_edges.append(numba.typed.List(groundings_edges[(clause_var_1, clause_var_2)])) - # 2. 
- if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - if clause_var_1 == head_var_1: - for e in groundings_edges[(clause_var_1, clause_var_2)]: - if e[0] == head_grounding: - a.append(interpretations_edge[e].world[clause_label]) - elif clause_var_2 == head_var_1: - for e in groundings_edges[(clause_var_1, clause_var_2)]: - if e[1] == head_grounding: - a.append(interpretations_edge[e].world[clause_label]) - else: - for qe in groundings_edges[(clause_var_1, clause_var_2)]: - a.append(interpretations_edge[qe].world[clause_label]) - annotations.append(a) - else: - # Comparison clause (we do not handle for now) - pass - - # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules) - if add_head_var_node_to_graph: - _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node) - - # For each grounding add a rule to be applied - applicable_rules_node.append((head_grounding, annotations, qualified_nodes, qualified_edges, edges_to_be_added)) - - elif rule_type == 'edge': - head_var_1 = head_variables[0] - head_var_2 = head_variables[1] - - # If there is no grounding for head_var_1 or head_var_2, we treat it as a ground atom and add it to the graph - head_var_1_in_nodes = head_var_1 in nodes - head_var_2_in_nodes = head_var_2 in nodes - add_head_var_1_node_to_graph = False - add_head_var_2_node_to_graph = False - add_head_edge_to_graph = False - if allow_ground_rules and head_var_1_in_nodes: - groundings[head_var_1] = numba.typed.List([head_var_1]) - if allow_ground_rules and head_var_2_in_nodes: - groundings[head_var_2] = numba.typed.List([head_var_2]) - - if head_var_1 not in groundings: - if not head_var_1_in_nodes: - add_head_var_1_node_to_graph = True - groundings[head_var_1] = numba.typed.List([head_var_1]) - if head_var_2 not in groundings: - if not head_var_2_in_nodes: - add_head_var_2_node_to_graph = True - groundings[head_var_2] = numba.typed.List([head_var_2]) - - # Artificially connect the head variables with an edge if both of them were not in the graph - if not head_var_1_in_nodes and not head_var_2_in_nodes: - add_head_edge_to_graph = True - - head_var_1_groundings = groundings[head_var_1] - head_var_2_groundings = groundings[head_var_2] - - source, target, _ = rule_edges - infer_edges = True if source != '' and target != '' else False - - # Prepare the edges that we will loop over. 
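-        # (infer_edges is only true when the rule head declares both a source and a target variable, i.e. the rule introduces a new edge)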
- # For infer edges we loop over each combination pair - # Else we loop over the valid edges in the graph - valid_edge_groundings = numba.typed.List.empty_list(edge_type) - for g1 in head_var_1_groundings: - for g2 in head_var_2_groundings: - if infer_edges: - valid_edge_groundings.append((g1, g2)) - else: - if (g1, g2) in edges_set: - valid_edge_groundings.append((g1, g2)) - - # Loop through the head variable groundings - for valid_e in valid_edge_groundings: - head_var_1_grounding, head_var_2_grounding = valid_e[0], valid_e[1] - qualified_nodes = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) - qualified_edges = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) - annotations = numba.typed.List.empty_list(numba.typed.List.empty_list(interval.interval_type)) - edges_to_be_added = (numba.typed.List.empty_list(node_type), numba.typed.List.empty_list(node_type), rule_edges[-1]) - - # Containers to keep track of groundings to make sure that the edge pair is valid - # We do this because we cannot know beforehand the edge matches from source groundings to target groundings - temp_groundings = groundings.copy() - temp_groundings_edges = groundings_edges.copy() - - # Refine the temp groundings for the specific edge head grounding - # We update the edge collection as well depending on if there's a match between the clause variables and head variables - temp_groundings[head_var_1] = numba.typed.List([head_var_1_grounding]) - temp_groundings[head_var_2] = numba.typed.List([head_var_2_grounding]) - for c1, c2 in temp_groundings_edges.keys(): - if c1 == head_var_1 and c2 == head_var_2: - temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_1_grounding, head_var_2_grounding)]) - elif c1 == head_var_2 and c2 == head_var_1: - temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e == (head_var_2_grounding, head_var_1_grounding)]) - elif c1 == head_var_1: - temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_1_grounding]) - elif c2 == head_var_1: - temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_1_grounding]) - elif c1 == head_var_2: - temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[0] == head_var_2_grounding]) - elif c2 == head_var_2: - temp_groundings_edges[(c1, c2)] = numba.typed.List([e for e in temp_groundings_edges[(c1, c2)] if e[1] == head_var_2_grounding]) - - refine_groundings(head_variables, temp_groundings, temp_groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors) - - # Check if the thresholds are still satisfied - # Check if all clauses are satisfied again in case the refining process changed anything - satisfaction = check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, temp_groundings, temp_groundings_edges) - - if not satisfaction: - continue - - if infer_edges: - # Prevent self loops while inferring edges if the clause variables are not the same - if source != target and head_var_1_grounding == head_var_2_grounding: - continue - edges_to_be_added[0].append(head_var_1_grounding) - edges_to_be_added[1].append(head_var_2_grounding) - - for i, clause in enumerate(clauses): - clause_type = clause[0] - clause_label = clause[1] - clause_variables = clause[2] - - if clause_type == 'node': - clause_var_1 = 
clause_variables[0] - # 1. - if atom_trace: - if clause_var_1 == head_var_1: - qualified_nodes.append(numba.typed.List([head_var_1_grounding])) - elif clause_var_1 == head_var_2: - qualified_nodes.append(numba.typed.List([head_var_2_grounding])) - else: - qualified_nodes.append(numba.typed.List(temp_groundings[clause_var_1])) - qualified_edges.append(numba.typed.List.empty_list(edge_type)) - # 2. - if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - if clause_var_1 == head_var_1: - a.append(interpretations_node[head_var_1_grounding].world[clause_label]) - elif clause_var_1 == head_var_2: - a.append(interpretations_node[head_var_2_grounding].world[clause_label]) - else: - for qn in temp_groundings[clause_var_1]: - a.append(interpretations_node[qn].world[clause_label]) - annotations.append(a) - - elif clause_type == 'edge': - clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] - # 1. - if atom_trace: - # Cases: - # 1. Both equal (cv1 = hv1 and cv2 = hv2 or cv1 = hv2 and cv2 = hv1) - # 2. One equal (cv1 = hv1 or cv2 = hv1 or cv1 = hv2 or cv2 = hv2) - # 3. None equal - qualified_nodes.append(numba.typed.List.empty_list(node_type)) - if clause_var_1 == head_var_1 and clause_var_2 == head_var_2: - es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding]) - qualified_edges.append(es) - elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1: - es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding]) - qualified_edges.append(es) - elif clause_var_1 == head_var_1: - es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_1_grounding]) - qualified_edges.append(es) - elif clause_var_1 == head_var_2: - es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[0] == head_var_2_grounding]) - qualified_edges.append(es) - elif clause_var_2 == head_var_1: - es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_1_grounding]) - qualified_edges.append(es) - elif clause_var_2 == head_var_2: - es = numba.typed.List([e for e in temp_groundings_edges[(clause_var_1, clause_var_2)] if e[1] == head_var_2_grounding]) - qualified_edges.append(es) - else: - qualified_edges.append(numba.typed.List(temp_groundings_edges[(clause_var_1, clause_var_2)])) - - # 2. 
- if ann_fn != '': - a = numba.typed.List.empty_list(interval.interval_type) - if clause_var_1 == head_var_1 and clause_var_2 == head_var_2: - for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: - if e[0] == head_var_1_grounding and e[1] == head_var_2_grounding: - a.append(interpretations_edge[e].world[clause_label]) - elif clause_var_1 == head_var_2 and clause_var_2 == head_var_1: - for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: - if e[0] == head_var_2_grounding and e[1] == head_var_1_grounding: - a.append(interpretations_edge[e].world[clause_label]) - elif clause_var_1 == head_var_1: - for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: - if e[0] == head_var_1_grounding: - a.append(interpretations_edge[e].world[clause_label]) - elif clause_var_1 == head_var_2: - for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: - if e[0] == head_var_2_grounding: - a.append(interpretations_edge[e].world[clause_label]) - elif clause_var_2 == head_var_1: - for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: - if e[1] == head_var_1_grounding: - a.append(interpretations_edge[e].world[clause_label]) - elif clause_var_2 == head_var_2: - for e in temp_groundings_edges[(clause_var_1, clause_var_2)]: - if e[1] == head_var_2_grounding: - a.append(interpretations_edge[e].world[clause_label]) - else: - for qe in temp_groundings_edges[(clause_var_1, clause_var_2)]: - a.append(interpretations_edge[qe].world[clause_label]) - annotations.append(a) - - # Now that we're sure that the rule is satisfied, we add the head to the graph if needed (only for ground rules) - if add_head_var_1_node_to_graph and head_var_1_grounding == head_var_1: - _add_node(head_var_1, neighbors, reverse_neighbors, nodes, interpretations_node) - if add_head_var_2_node_to_graph and head_var_2_grounding == head_var_2: - _add_node(head_var_2, neighbors, reverse_neighbors, nodes, interpretations_node) - if add_head_edge_to_graph and (head_var_1, head_var_2) == (head_var_1_grounding, head_var_2_grounding): - _add_edge(head_var_1, head_var_2, neighbors, reverse_neighbors, nodes, edges, label.Label(''), interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) - - # For each grounding combination add a rule to be applied - # Only if all the clauses have valid groundings - # if satisfaction: - e = (head_var_1_grounding, head_var_2_grounding) - applicable_rules_edge.append((e, annotations, qualified_nodes, qualified_edges, edges_to_be_added)) - - # Return the applicable rules - return applicable_rules_node, applicable_rules_edge - - -@numba.njit(cache=True) -def check_all_clause_satisfaction(interpretations_node, interpretations_edge, clauses, thresholds, groundings, groundings_edges): - # Check if the thresholds are satisfied for each clause - satisfaction = True - for i, clause in enumerate(clauses): - # Unpack clause variables - clause_type = clause[0] - clause_label = clause[1] - clause_variables = clause[2] - - if clause_type == 'node': - clause_var_1 = clause_variables[0] - satisfaction = check_node_grounding_threshold_satisfaction(interpretations_node, groundings[clause_var_1], groundings[clause_var_1], clause_label, thresholds[i]) and satisfaction - elif clause_type == 'edge': - clause_var_1, clause_var_2 = clause_variables[0], clause_variables[1] - satisfaction = check_edge_grounding_threshold_satisfaction(interpretations_edge, groundings_edges[(clause_var_1, clause_var_2)], groundings_edges[(clause_var_1, clause_var_2)], clause_label, thresholds[i]) and satisfaction - return 
satisfaction - - -@numba.njit(cache=True) -def refine_groundings(clause_variables, groundings, groundings_edges, dependency_graph_neighbors, dependency_graph_reverse_neighbors): - # Loop through the dependency graph and refine the groundings that have connections - all_variables_refined = numba.typed.List(clause_variables) - variables_just_refined = numba.typed.List(clause_variables) - new_variables_refined = numba.typed.List.empty_list(numba.types.string) - while len(variables_just_refined) > 0: - for refined_variable in variables_just_refined: - # Refine all the neighbors of the refined variable - if refined_variable in dependency_graph_neighbors: - for neighbor in dependency_graph_neighbors[refined_variable]: - old_edge_groundings = groundings_edges[(refined_variable, neighbor)] - new_node_groundings = groundings[refined_variable] - - # Delete old groundings for the variable being refined - del groundings[neighbor] - groundings[neighbor] = numba.typed.List.empty_list(node_type) - - # Update the edge groundings and node groundings - qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[0] in new_node_groundings]) - groundings_neighbor_set = set(groundings[neighbor]) - for e in qualified_groundings: - if e[1] not in groundings_neighbor_set: - groundings[neighbor].append(e[1]) - groundings_neighbor_set.add(e[1]) - groundings_edges[(refined_variable, neighbor)] = qualified_groundings - - # Add the neighbor to the list of refined variables so that we can refine for all its neighbors - if neighbor not in all_variables_refined: - new_variables_refined.append(neighbor) - - if refined_variable in dependency_graph_reverse_neighbors: - for reverse_neighbor in dependency_graph_reverse_neighbors[refined_variable]: - old_edge_groundings = groundings_edges[(reverse_neighbor, refined_variable)] - new_node_groundings = groundings[refined_variable] - - # Delete old groundings for the variable being refined - del groundings[reverse_neighbor] - groundings[reverse_neighbor] = numba.typed.List.empty_list(node_type) - - # Update the edge groundings and node groundings - qualified_groundings = numba.typed.List([edge for edge in old_edge_groundings if edge[1] in new_node_groundings]) - groundings_reverse_neighbor_set = set(groundings[reverse_neighbor]) - for e in qualified_groundings: - if e[0] not in groundings_reverse_neighbor_set: - groundings[reverse_neighbor].append(e[0]) - groundings_reverse_neighbor_set.add(e[0]) - groundings_edges[(reverse_neighbor, refined_variable)] = qualified_groundings - - # Add the neighbor to the list of refined variables so that we can refine for all its neighbors - if reverse_neighbor not in all_variables_refined: - new_variables_refined.append(reverse_neighbor) - - variables_just_refined = numba.typed.List(new_variables_refined) - all_variables_refined.extend(new_variables_refined) - new_variables_refined.clear() - - -@numba.njit(cache=True) -def check_node_grounding_threshold_satisfaction(interpretations_node, grounding, qualified_grounding, clause_label, threshold): - threshold_quantifier_type = threshold[1][1] - if threshold_quantifier_type == 'total': - neigh_len = len(grounding) - - # Available is all neighbors that have a particular label with bound inside [0,1] - elif threshold_quantifier_type == 'available': - neigh_len = len(get_qualified_node_groundings(interpretations_node, grounding, clause_label, interval.closed(0, 1))) - - qualified_neigh_len = len(qualified_grounding) - satisfaction = _satisfies_threshold(neigh_len, 
qualified_neigh_len, threshold) - return satisfaction - - -@numba.njit(cache=True) -def check_edge_grounding_threshold_satisfaction(interpretations_edge, grounding, qualified_grounding, clause_label, threshold): - threshold_quantifier_type = threshold[1][1] - if threshold_quantifier_type == 'total': - neigh_len = len(grounding) - - # Available is all neighbors that have a particular label with bound inside [0,1] - elif threshold_quantifier_type == 'available': - neigh_len = len(get_qualified_edge_groundings(interpretations_edge, grounding, clause_label, interval.closed(0, 1))) - - qualified_neigh_len = len(qualified_grounding) - satisfaction = _satisfies_threshold(neigh_len, qualified_neigh_len, threshold) - return satisfaction - - -@numba.njit(cache=True) -def get_rule_node_clause_grounding(clause_var_1, groundings, predicate_map, l, nodes): - # The groundings for a node clause can be either a previous grounding or all possible nodes - if l in predicate_map: - grounding = predicate_map[l] if clause_var_1 not in groundings else groundings[clause_var_1] - else: - grounding = nodes if clause_var_1 not in groundings else groundings[clause_var_1] - return grounding - - -@numba.njit(cache=True) -def get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, predicate_map, l, edges): - # There are 4 cases for predicate(Y,Z): - # 1. Both predicate variables Y and Z have not been encountered before - # 2. The source variable Y has not been encountered before but the target variable Z has - # 3. The target variable Z has not been encountered before but the source variable Y has - # 4. Both predicate variables Y and Z have been encountered before - edge_groundings = numba.typed.List.empty_list(edge_type) - - # Case 1: - # We replace Y by all nodes and Z by the neighbors of each of these nodes - if clause_var_1 not in groundings and clause_var_2 not in groundings: - if l in predicate_map: - edge_groundings = predicate_map[l] - else: - edge_groundings = edges - - # Case 2: - # We replace Y by the sources of Z - elif clause_var_1 not in groundings and clause_var_2 in groundings: - for n in groundings[clause_var_2]: - es = numba.typed.List([(nn, n) for nn in reverse_neighbors[n]]) - edge_groundings.extend(es) - - # Case 3: - # We replace Z by the neighbors of Y - elif clause_var_1 in groundings and clause_var_2 not in groundings: - for n in groundings[clause_var_1]: - es = numba.typed.List([(n, nn) for nn in neighbors[n]]) - edge_groundings.extend(es) - - # Case 4: - # We have seen both variables before - else: - # We have already seen these two variables in an edge clause - if (clause_var_1, clause_var_2) in groundings_edges: - edge_groundings = groundings_edges[(clause_var_1, clause_var_2)] - # We have seen both these variables but not in an edge clause together - else: - groundings_clause_var_2_set = set(groundings[clause_var_2]) - for n in groundings[clause_var_1]: - es = numba.typed.List([(n, nn) for nn in neighbors[n] if nn in groundings_clause_var_2_set]) - edge_groundings.extend(es) - - return edge_groundings - - -@numba.njit(cache=True) -def get_qualified_node_groundings(interpretations_node, grounding, clause_l, clause_bnd): - # Filter the grounding by the predicate and bound of the clause - qualified_groundings = numba.typed.List.empty_list(node_type) - for n in grounding: - if is_satisfied_node(interpretations_node, n, (clause_l, clause_bnd)): - qualified_groundings.append(n) - - return qualified_groundings - - 
-@numba.njit(cache=True) -def get_qualified_edge_groundings(interpretations_edge, grounding, clause_l, clause_bnd): - # Filter the grounding by the predicate and bound of the clause - qualified_groundings = numba.typed.List.empty_list(edge_type) - for e in grounding: - if is_satisfied_edge(interpretations_edge, e, (clause_l, clause_bnd)): - qualified_groundings.append(e) - - return qualified_groundings - - -@numba.njit(cache=True) -def _satisfies_threshold(num_neigh, num_qualified_component, threshold): - # Checks if qualified neighbors satisfy threshold. This is for one clause - if threshold[1][0]=='number': - if threshold[0]=='greater_equal': - result = True if num_qualified_component >= threshold[2] else False - elif threshold[0]=='greater': - result = True if num_qualified_component > threshold[2] else False - elif threshold[0]=='less_equal': - result = True if num_qualified_component <= threshold[2] else False - elif threshold[0]=='less': - result = True if num_qualified_component < threshold[2] else False - elif threshold[0]=='equal': - result = True if num_qualified_component == threshold[2] else False - - elif threshold[1][0]=='percent': - if num_neigh==0: - result = False - elif threshold[0]=='greater_equal': - result = True if num_qualified_component/num_neigh >= threshold[2]*0.01 else False - elif threshold[0]=='greater': - result = True if num_qualified_component/num_neigh > threshold[2]*0.01 else False - elif threshold[0]=='less_equal': - result = True if num_qualified_component/num_neigh <= threshold[2]*0.01 else False - elif threshold[0]=='less': - result = True if num_qualified_component/num_neigh < threshold[2]*0.01 else False - elif threshold[0]=='equal': - result = True if num_qualified_component/num_neigh == threshold[2]*0.01 else False - - return result - - -@numba.njit(cache=True) -def _update_node(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False): - updated = False - # This is to prevent a key error in case the label is a specific label - try: - world = interpretations[comp] - l, bnd = na - updated_bnds = numba.typed.List.empty_list(interval.interval_type) - - # Add label to world if it is not there - if l not in world.world: - world.world[l] = interval.closed(0, 1) - num_ga[t_cnt] += 1 - if l in predicate_map: - predicate_map[l].append(comp) - else: - predicate_map[l] = numba.typed.List([comp]) - - # Check if update is necessary with previous bnd - prev_bnd = world.world[l].copy() - - # override will not check for inconsistencies - if override: - world.world[l].set_lower_upper(bnd.lower, bnd.upper) - else: - world.update(l, bnd) - world.world[l].set_static(static) - if world.world[l]!=prev_bnd: - updated = True - updated_bnds.append(world.world[l]) - - # Add to rule trace if update happened and add to atom trace if necessary - if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy())) - if atom_trace: - # Mode can be fact or rule, updation of trace will happen accordingly - if mode=='fact' or mode=='graph-attribute-fact': - qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)) - qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)) - name 
= facts_to_be_applied_trace[idx] - _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name) - elif mode=='rule': - qn, qe, name = rules_to_be_applied_trace[idx] - _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name) - - # Update complement of predicate (if exists) based on new knowledge of predicate - if updated: - ip_update_cnt = 0 - for p1, p2 in ipl: - if p1 == l: - if p2 not in world.world: - world.world[p2] = interval.closed(0, 1) - if p2 in predicate_map: - predicate_map[p2].append(comp) - else: - predicate_map[p2] = numba.typed.List([comp]) - if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}') - lower = max(world.world[p2].lower, 1 - world.world[p1].upper) - upper = min(world.world[p2].upper, 1 - world.world[p1].lower) - world.world[p2].set_lower_upper(lower, upper) - world.world[p2].set_static(static) - ip_update_cnt += 1 - updated_bnds.append(world.world[p2]) - if store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper))) - if p2 == l: - if p1 not in world.world: - world.world[p1] = interval.closed(0, 1) - if p1 in predicate_map: - predicate_map[p1].append(comp) - else: - predicate_map[p1] = numba.typed.List([comp]) - if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}') - lower = max(world.world[p1].lower, 1 - world.world[p2].upper) - upper = min(world.world[p1].upper, 1 - world.world[p2].lower) - world.world[p1].set_lower_upper(lower, upper) - world.world[p1].set_static(static) - ip_update_cnt += 1 - updated_bnds.append(world.world[p1]) - if store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper))) - - # Gather convergence data - change = 0 - if updated: - # Find out if it has changed from previous interp - current_bnd = world.world[l] - prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper) - if current_bnd != prev_t_bnd: - if convergence_mode=='delta_bound': - for i in updated_bnds: - lower_delta = abs(i.lower-prev_t_bnd.lower) - upper_delta = abs(i.upper-prev_t_bnd.upper) - max_delta = max(lower_delta, upper_delta) - change = max(change, max_delta) - else: - change = 1 + ip_update_cnt - - return (updated, change) - - except: - return (False, 0) - - -@numba.njit(cache=True) -def _update_edge(interpretations, predicate_map, comp, na, ipl, rule_trace, fp_cnt, t_cnt, static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_trace, idx, facts_to_be_applied_trace, rule_trace_atoms, store_interpretation_changes, num_ga, mode, override=False): - updated = False - # This is to prevent a key error in case the label is a specific label - try: - world = interpretations[comp] - l, bnd = na - updated_bnds = numba.typed.List.empty_list(interval.interval_type) - - # Add label to world if it is not there - if l not in world.world: - world.world[l] = interval.closed(0, 1) - num_ga[t_cnt] += 1 - if l in predicate_map: - predicate_map[l].append(comp) - else: - predicate_map[l] = numba.typed.List([comp]) - - # Check if update is necessary with previous bnd - prev_bnd = world.world[l].copy() 
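-        # prev_bnd keeps a copy of the pre-update bound: it is compared against below to detect whether anything changed, and it is what the atom trace records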
-
-        # override will not check for inconsistencies
-        if override:
-            world.world[l].set_lower_upper(bnd.lower, bnd.upper)
-        else:
-            world.update(l, bnd)
-        world.world[l].set_static(static)
-        if world.world[l]!=prev_bnd:
-            updated = True
-            updated_bnds.append(world.world[l])
-
-        # Add to rule trace if update happened and add to atom trace if necessary
-        if (save_graph_attributes_to_rule_trace or not mode=='graph-attribute-fact') and store_interpretation_changes:
-            rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, l, world.world[l].copy()))
-            if atom_trace:
-                # Mode can be fact or rule; the trace is updated accordingly
-                if mode=='fact' or mode=='graph-attribute-fact':
-                    qn = numba.typed.List.empty_list(numba.typed.List.empty_list(node_type))
-                    qe = numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type))
-                    name = facts_to_be_applied_trace[idx]
-                    _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)
-                elif mode=='rule':
-                    qn, qe, name = rules_to_be_applied_trace[idx]
-                    _update_rule_trace(rule_trace_atoms, qn, qe, prev_bnd, name)
-
-        # Update complement of predicate (if exists) based on new knowledge of predicate
-        if updated:
-            ip_update_cnt = 0
-            for p1, p2 in ipl:
-                if p1 == l:
-                    if p2 not in world.world:
-                        world.world[p2] = interval.closed(0, 1)
-                        if p2 in predicate_map:
-                            predicate_map[p2].append(comp)
-                        else:
-                            predicate_map[p2] = numba.typed.List([comp])
-                    if atom_trace:
-                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'IPL: {l.get_value()}')
-                    lower = max(world.world[p2].lower, 1 - world.world[p1].upper)
-                    upper = min(world.world[p2].upper, 1 - world.world[p1].lower)
-                    world.world[p2].set_lower_upper(lower, upper)
-                    world.world[p2].set_static(static)
-                    ip_update_cnt += 1
-                    updated_bnds.append(world.world[p2])
-                    if store_interpretation_changes:
-                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(lower, upper)))
-                if p2 == l:
-                    if p1 not in world.world:
-                        world.world[p1] = interval.closed(0, 1)
-                        if p1 in predicate_map:
-                            predicate_map[p1].append(comp)
-                        else:
-                            predicate_map[p1] = numba.typed.List([comp])
-                    if atom_trace:
-                        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'IPL: {l.get_value()}')
-                    lower = max(world.world[p1].lower, 1 - world.world[p2].upper)
-                    upper = min(world.world[p1].upper, 1 - world.world[p2].lower)
-                    world.world[p1].set_lower_upper(lower, upper)
-                    world.world[p1].set_static(static)
-                    ip_update_cnt += 1
-                    updated_bnds.append(world.world[p1])
-                    if store_interpretation_changes:
-                        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(lower, upper)))
-
-        # Gather convergence data
-        change = 0
-        if updated:
-            # Find out if it has changed from previous interp
-            current_bnd = world.world[l]
-            prev_t_bnd = interval.closed(world.world[l].prev_lower, world.world[l].prev_upper)
-            if current_bnd != prev_t_bnd:
-                if convergence_mode=='delta_bound':
-                    for i in updated_bnds:
-                        lower_delta = abs(i.lower-prev_t_bnd.lower)
-                        upper_delta = abs(i.upper-prev_t_bnd.upper)
-                        max_delta = max(lower_delta, upper_delta)
-                        change = max(change, max_delta)
-                else:
-                    change = 1 + ip_update_cnt
-
-        return (updated, change)
-    except:
-        return (False, 0)
-
-
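The complement (IPL) update in _update_node and _update_edge above is ordinary interval arithmetic: once p1 is known to lie in [l1, u1], its declared complement p2 can be no larger than [1 - u1, 1 - l1], intersected with what was already known about p2. A worked sketch with plain floats; the function and variable names are illustrative only:

# Sketch of the IPL tightening step: given a fresh bound on p1, narrow its
# complement p2 to [max(l2, 1 - u1), min(u2, 1 - l1)].
def ipl_tighten(p1_bnd, p2_bnd):
    l1, u1 = p1_bnd
    l2, u2 = p2_bnd
    return max(l2, 1 - u1), min(u2, 1 - l1)

# p1 tightened to [0.8, 1.0]; complement p2 still at the vacuous [0.0, 1.0]
print(ipl_tighten((0.8, 1.0), (0.0, 1.0)))  # -> (0.0, 0.2) up to float rounding

If the tightened bounds ever cross (lower > upper), check_consistent_node/check_consistent_edge report an inconsistency, and the resolution functions below reset the atom to [0, 1] and freeze it.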
-@numba.njit(cache=True) -def _update_rule_trace(rule_trace, qn, qe, prev_bnd, name): - rule_trace.append((qn, qe, prev_bnd.copy(), name)) - - -@numba.njit(cache=True) -def are_satisfied_node(interpretations, comp, nas): - result = True - for (l, bnd) in nas: - result = result and is_satisfied_node(interpretations, comp, (l, bnd)) - return result - - -@numba.njit(cache=True) -def is_satisfied_node(interpretations, comp, na): - result = False - if not (na[0] is None or na[1] is None): - # This is to prevent a key error in case the label is a specific label - try: - world = interpretations[comp] - result = world.is_satisfied(na[0], na[1]) - except: - result = False - else: - result = True - return result - - -@numba.njit(cache=True) -def is_satisfied_node_comparison(interpretations, comp, na): - result = False - number = 0 - l, bnd = na - l_str = l.value - - if not (l is None or bnd is None): - # This is to prevent a key error in case the label is a specific label - try: - world = interpretations[comp] - for world_l in world.world.keys(): - world_l_str = world_l.value - if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): - # The label is contained in the world - result = world.is_satisfied(world_l, na[1]) - # Find the suffix number - number = str_to_float(world_l_str[len(l_str)+1:]) - break - - except: - result = False - else: - result = True - return result, number - - -@numba.njit(cache=True) -def are_satisfied_edge(interpretations, comp, nas): - result = True - for (l, bnd) in nas: - result = result and is_satisfied_edge(interpretations, comp, (l, bnd)) - return result - - -@numba.njit(cache=True) -def is_satisfied_edge(interpretations, comp, na): - result = False - if not (na[0] is None or na[1] is None): - # This is to prevent a key error in case the label is a specific label - try: - world = interpretations[comp] - result = world.is_satisfied(na[0], na[1]) - except: - result = False - else: - result = True - return result - - -@numba.njit(cache=True) -def is_satisfied_edge_comparison(interpretations, comp, na): - result = False - number = 0 - l, bnd = na - l_str = l.value - - if not (l is None or bnd is None): - # This is to prevent a key error in case the label is a specific label - try: - world = interpretations[comp] - for world_l in world.world.keys(): - world_l_str = world_l.value - if l_str in world_l_str and world_l_str[len(l_str)+1:].replace('.', '').replace('-', '').isdigit(): - # The label is contained in the world - result = world.is_satisfied(world_l, na[1]) - # Find the suffix number - number = str_to_float(world_l_str[len(l_str)+1:]) - break - - except: - result = False - else: - result = True - return result, number - - -@numba.njit(cache=True) -def annotate(annotation_functions, rule, annotations, weights): - func_name = rule.get_annotation_function() - if func_name == '': - return rule.get_bnd().lower, rule.get_bnd().upper - else: - with numba.objmode(annotation='Tuple((float64, float64))'): - for func in annotation_functions: - if func.__name__ == func_name: - annotation = func(annotations, weights) - return annotation - - -@numba.njit(cache=True) -def check_consistent_node(interpretations, comp, na): - world = interpretations[comp] - if na[0] in world.world: - bnd = world.world[na[0]] - else: - bnd = interval.closed(0, 1) - if (na[1].lower > bnd.upper) or (bnd.lower > na[1].upper): - return False - else: - return True - - -@numba.njit(cache=True) -def check_consistent_edge(interpretations, comp, na): - world = 
interpretations[comp]
-    if na[0] in world.world:
-        bnd = world.world[na[0]]
-    else:
-        bnd = interval.closed(0, 1)
-    if (na[1].lower > bnd.upper) or (bnd.lower > na[1].upper):
-        return False
-    else:
-        return True
-
-
-@numba.njit(cache=True)
-def resolve_inconsistency_node(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
-    world = interpretations[comp]
-    if store_interpretation_changes:
-        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1)))
-    if (mode == 'fact' or mode == 'graph-attribute-fact') and atom_trace:
-        name = facts_to_be_applied_trace[idx]
-    elif mode == 'rule' and atom_trace:
-        name = rules_to_be_applied_trace[idx][2]
-    else:
-        name = '-'
-    if atom_trace:
-        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[na[0]], f'Inconsistency due to {name}')
-    # Resolve inconsistency and set static
-    world.world[na[0]].set_lower_upper(0, 1)
-    world.world[na[0]].set_static(True)
-    for p1, p2 in ipl:
-        if p1==na[0]:
-            if atom_trace:
-                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p2], f'Inconsistency due to {name}')
-            world.world[p2].set_lower_upper(0, 1)
-            world.world[p2].set_static(True)
-            if store_interpretation_changes:
-                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(0,1)))
-
-        if p2==na[0]:
-            if atom_trace:
-                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), world.world[p1], f'Inconsistency due to {name}')
-            world.world[p1].set_lower_upper(0, 1)
-            world.world[p1].set_static(True)
-            if store_interpretation_changes:
-                rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1)))
-    # Add inconsistent predicates to a list
-
-
-@numba.njit(cache=True)
-def resolve_inconsistency_edge(interpretations, comp, na, ipl, t_cnt, fp_cnt, idx, atom_trace, rule_trace, rule_trace_atoms, rules_to_be_applied_trace, facts_to_be_applied_trace, store_interpretation_changes, mode):
-    w = interpretations[comp]
-    if store_interpretation_changes:
-        rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, na[0], interval.closed(0,1)))
-    if (mode == 'fact' or mode == 'graph-attribute-fact') and atom_trace:
-        name = facts_to_be_applied_trace[idx]
-    elif mode == 'rule' and atom_trace:
-        name = rules_to_be_applied_trace[idx][2]
-    else:
-        name = '-'
-    if atom_trace:
-        _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[na[0]], f'Inconsistency due to {name}')
-    # Resolve inconsistency and set static
-    w.world[na[0]].set_lower_upper(0, 1)
-    w.world[na[0]].set_static(True)
-    for p1, p2 in ipl:
-        if p1==na[0]:
-            if atom_trace:
-                _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p2], f'Inconsistency due to {name}')
-            w.world[p2].set_lower_upper(0, 1)
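-            # Marking the bound static freezes the atom at [0, 1] so later updates cannot reintroduce the inconsistency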
w.world[p2].set_static(True) - if store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p2, interval.closed(0,1))) - - if p2==na[0]: - if atom_trace: - _update_rule_trace(rule_trace_atoms, numba.typed.List.empty_list(numba.typed.List.empty_list(node_type)), numba.typed.List.empty_list(numba.typed.List.empty_list(edge_type)), w.world[p1], f'Inconsistency due to {name}') - w.world[p1].set_lower_upper(0, 1) - w.world[p1].set_static(True) - if store_interpretation_changes: - rule_trace.append((numba.types.uint16(t_cnt), numba.types.uint16(fp_cnt), comp, p1, interval.closed(0,1))) - - -@numba.njit(cache=True) -def _add_node(node, neighbors, reverse_neighbors, nodes, interpretations_node): - nodes.append(node) - neighbors[node] = numba.typed.List.empty_list(node_type) - reverse_neighbors[node] = numba.typed.List.empty_list(node_type) - interpretations_node[node] = world.World(numba.typed.List.empty_list(label.label_type)) - - -@numba.njit(cache=True) -def _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): - # If not a node, add to list of nodes and initialize neighbors - if source not in nodes: - _add_node(source, neighbors, reverse_neighbors, nodes, interpretations_node) - - if target not in nodes: - _add_node(target, neighbors, reverse_neighbors, nodes, interpretations_node) - - # Make sure edge doesn't already exist - # Make sure, if l=='', not to add the label - # Make sure, if edge exists, that we don't override the l label if it exists - edge = (source, target) - new_edge = False - if edge not in edges: - new_edge = True - edges.append(edge) - neighbors[source].append(target) - reverse_neighbors[target].append(source) - if l.value!='': - interpretations_edge[edge] = world.World(numba.typed.List([l])) - num_ga[t] += 1 - if l in predicate_map: - predicate_map[l].append(edge) - else: - predicate_map[l] = numba.typed.List([edge]) - else: - interpretations_edge[edge] = world.World(numba.typed.List.empty_list(label.label_type)) - else: - if l not in interpretations_edge[edge].world and l.value!='': - new_edge = True - interpretations_edge[edge].world[l] = interval.closed(0, 1) - num_ga[t] += 1 - - if l in predicate_map: - predicate_map[l].append(edge) - else: - predicate_map[l] = numba.typed.List([edge]) - - return edge, new_edge - - -@numba.njit(cache=True) -def _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t): - changes = 0 - edges_added = numba.typed.List.empty_list(edge_type) - for source in sources: - for target in targets: - edge, new_edge = _add_edge(source, target, neighbors, reverse_neighbors, nodes, edges, l, interpretations_node, interpretations_edge, predicate_map, num_ga, t) - edges_added.append(edge) - changes = changes+1 if new_edge else changes - return edges_added, changes - - -@numba.njit(cache=True) -def _delete_edge(edge, neighbors, reverse_neighbors, edges, interpretations_edge, predicate_map, num_ga): - source, target = edge - edges.remove(edge) - num_ga[-1] -= len(interpretations_edge[edge].world) - del interpretations_edge[edge] - for l in predicate_map: - if edge in predicate_map[l]: - predicate_map[l].remove(edge) - neighbors[source].remove(target) - reverse_neighbors[target].remove(source) - - -@numba.njit(cache=True) -def _delete_node(node, neighbors, reverse_neighbors, nodes, interpretations_node, predicate_map, num_ga): - 
nodes.remove(node) - num_ga[-1] -= len(interpretations_node[node].world) - del interpretations_node[node] - del neighbors[node] - del reverse_neighbors[node] - for l in predicate_map: - if node in predicate_map[l]: - predicate_map[l].remove(node) - - # Remove all occurrences of node in neighbors - for n in neighbors.keys(): - if node in neighbors[n]: - neighbors[n].remove(node) - for n in reverse_neighbors.keys(): - if node in reverse_neighbors[n]: - reverse_neighbors[n].remove(node) - - -@numba.njit(cache=True) -def float_to_str(value): - number = int(value) - decimal = int(value % 1 * 1000) - float_str = f'{number}.{decimal}' - return float_str - - -@numba.njit(cache=True) -def str_to_float(value): - decimal_pos = value.find('.') - if decimal_pos != -1: - after_decimal_len = len(value[decimal_pos+1:]) - else: - after_decimal_len = 0 - value = value.replace('.', '') - value = str_to_int(value) - value = value / 10**after_decimal_len - return value - - -@numba.njit(cache=True) -def str_to_int(value): - if value[0] == '-': - negative = True - value = value.replace('-','') - else: - negative = False - final_index, result = len(value) - 1, 0 - for i, v in enumerate(value): - result += (ord(v) - 48) * (10 ** (final_index - i)) - result = -result if negative else result - return result \ No newline at end of file From 5a6f250abbb6c4e58484622725dd7e349ed4f40f Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 9 Oct 2025 11:21:07 -0400 Subject: [PATCH 05/32] More lint fixes --- .pre-commit-config.yaml | 2 +- pyreason/__init__.py | 10 ++++------ pyreason/pyreason.py | 12 ++++++------ 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b94a67e..3002f2a1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: # --- COMMIT STAGE: Fast unit tests only --- - id: ruff-check name: Run ruff linter - entry: .venv/bin/python -m ruff check + entry: .venv/bin/python -m ruff check pyreason/scripts language: system types: [python] pass_filenames: false diff --git a/pyreason/__init__.py b/pyreason/__init__.py index 15fec585..366d0595 100755 --- a/pyreason/__init__.py +++ b/pyreason/__init__.py @@ -1,16 +1,14 @@ # Set numba environment variable import os +import yaml +from pyreason.pyreason import settings, load_graphml, add_rule, Rule, add_fact, Fact, reason, reset, reset_rules +from pkg_resources import get_distribution, DistributionNotFound + package_path = os.path.abspath(os.path.dirname(__file__)) cache_path = os.path.join(package_path, 'cache') cache_status_path = os.path.join(package_path, '.cache_status.yaml') os.environ['NUMBA_CACHE_DIR'] = cache_path - -from pyreason.pyreason import * -import yaml -from importlib.metadata import version -from pkg_resources import get_distribution, DistributionNotFound - try: __version__ = get_distribution(__name__).version except DistributionNotFound: diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py index 10ab683e..6e2de24a 100755 --- a/pyreason/pyreason.py +++ b/pyreason/pyreason.py @@ -1,4 +1,5 @@ # This is the file that will be imported when "import pyreason" is called. 
All content will be run automatically +import importlib import networkx as nx import numba import time @@ -25,15 +26,14 @@ import pyreason.scripts.numba_wrapper.numba_types.fact_edge_type as fact_edge import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval from pyreason.scripts.utils.reorder_clauses import reorder_clauses -try: - import torch -except ImportError: +if importlib.util.find_spec("torch") is not None: + from pyreason.scripts.learning.classification.classifier import LogicIntegratedClassifier + from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions +else: LogicIntegratedClassifier = None ModelInterfaceOptions = None print('torch is not installed, model integration is disabled') -else: - from pyreason.scripts.learning.classification.classifier import LogicIntegratedClassifier - from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions + # USER VARIABLES From cd124e1981baa6415cb6853dbf9d55c29dda7a7a Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 9 Oct 2025 12:14:59 -0400 Subject: [PATCH 06/32] EOF fixer --- .github/workflows/python-package-version-test.yml | 3 +-- .gitignore | 1 - .numba_config.yaml | 2 +- .pre-commit-config.yaml | 14 ++++++++++---- Makefile | 10 +++++----- docs/hello-world.md | 2 +- docs/hello-world.py | 1 - docs/source/api_reference/index.rst | 2 +- docs/source/examples_rst/advanced_example.rst | 1 - .../examples_rst/advanced_output_example.rst | 2 +- docs/source/examples_rst/annF_average_example.rst | 1 - .../annF_linear_combination_example.rst | 2 +- docs/source/examples_rst/index.rst | 2 +- docs/source/examples_rst/infer_edges_example.rst | 2 +- docs/source/installation.rst | 2 +- docs/source/license.rst | 1 - docs/source/tutorials/advanced_tutorial.rst | 2 +- docs/source/tutorials/index.rst | 2 +- docs/source/tutorials/infer_edges.rst | 3 --- docs/source/user_guide/3_pyreason_rules.rst | 2 +- docs/source/user_guide/4_pyreason_settings.rst | 2 +- .../user_guide/5_inconsistent_predicate_list.rst | 2 +- .../source/user_guide/7_jupyter_notebook_usage.rst | 1 - docs/source/user_guide/index.rst | 3 --- examples/advanced_output.txt | 2 +- examples/temporal_classifier_integration_ex.py | 2 +- initialize.py | 2 +- jobs/.gitignore | 2 +- output/.gitignore | 2 +- profiling/.gitignore | 2 +- pyreason/__init__.py | 2 +- .../annotation_functions/annotation_functions.py | 2 -- pyreason/scripts/components/label.py | 2 +- pyreason/scripts/interpretation/interpretation.py | 2 +- .../scripts/interpretation/interpretation_fp.py | 2 +- .../interpretation/interpretation_parallel.py | 2 +- .../numba_wrapper/numba_types/label_type.py | 1 - pyreason/scripts/utils/graphml_parser.py | 2 +- pyreason/scripts/utils/plotter.py | 2 +- pyreason/scripts/utils/visuals.py | 2 +- pytest.ini | 2 +- run_on_agave.sh | 2 -- run_tests.py | 2 +- test_config.json | 2 +- .../functional/knowledge_graph_test_subset.graphml | 2 +- tests/functional/test_annotation_function.py | 2 +- tests/functional/test_anyBurl_infer_edges_rules.py | 2 +- tests/functional/test_custom_thresholds.py | 2 +- tests/functional/test_hello_world.py | 2 +- tests/functional/test_hello_world_parallel.py | 1 - tests/functional/test_pyreason_comprehensive.py | 2 +- tests/functional/test_reason_again.py | 2 +- tests/unit/api_tests/conftest.py | 2 +- .../unit/api_tests/test_pyreason_add_operations.py | 2 +- tests/unit/api_tests/test_pyreason_file_loading.py | 2 +- tests/unit/api_tests/test_pyreason_reasoning.py | 
2 +- tests/unit/api_tests/test_pyreason_settings.py | 2 +- tests/unit/api_tests/test_pyreason_validation.py | 2 +- .../interpretations/test_reason_core.py | 2 -- .../interpretations/test_reason_update.py | 2 -- 60 files changed, 59 insertions(+), 76 deletions(-) diff --git a/.github/workflows/python-package-version-test.yml b/.github/workflows/python-package-version-test.yml index edc07f83..e8f52d67 100644 --- a/.github/workflows/python-package-version-test.yml +++ b/.github/workflows/python-package-version-test.yml @@ -32,7 +32,7 @@ jobs: if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with ruff run: | - python -m ruff check pyreason/scripts/ + python -m ruff check pyreason - name: Pytest Unit Tests with JIT Disabled run: | pytest tests/unit/disable_jit @@ -42,4 +42,3 @@ jobs: - name: Pytest Functional Tests run: | pytest tests/functional - diff --git a/.gitignore b/.gitignore index 869f45d1..a8b74e70 100755 --- a/.gitignore +++ b/.gitignore @@ -173,4 +173,3 @@ cython_debug/ # Sphinx Documentation /docs/source/_static/css/fonts/ - diff --git a/.numba_config.yaml b/.numba_config.yaml index b9caca04..25013e34 100644 --- a/.numba_config.yaml +++ b/.numba_config.yaml @@ -1,2 +1,2 @@ disable_jit: 0 -cache_dir: ./pyreason/cache/ \ No newline at end of file +cache_dir: ./pyreason/cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3002f2a1..d9d8c66d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,11 @@ repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: end-of-file-fixer + name: Fix end of files + stages: [pre-commit] + - repo: local hooks: # --- COMMIT STAGE: Fast unit tests only --- @@ -11,23 +18,22 @@ repos: stages: [pre-commit] - id: pytest-unit-no-jit - name: Run JIT-disabled unit tests + name: Run JIT-disabled unit tests entry: .venv/bin/python -m pytest tests/unit/disable_jit -m "not slow" --tb=short -q language: system pass_filenames: false stages: [pre-commit] - id: pytest-unit-jit - name: Run JIT-enabled unit tests + name: Run JIT-enabled unit tests entry: .venv/bin/python -m pytest tests/unit/dont_disable_jit -m "not slow" --tb=short -q language: system pass_filenames: false stages: [pre-commit] # --- PUSH STAGE: Complete test suite --- - - id: pytest-unit-api - name: Run pyreason api unit tests + name: Run pyreason api unit tests entry: .venv/bin/python -m pytest tests/unit/api_tests --tb=short -q language: system pass_filenames: false diff --git a/Makefile b/Makefile index 4235deee..05a049c1 100644 --- a/Makefile +++ b/Makefile @@ -131,10 +131,10 @@ coverage-xml: ## Show path to XML coverage report # Development targets lint: ## Run linting checks @echo "$(BOLD)$(BLUE)Running linting checks...$(RESET)" - @echo "$(YELLOW)Note: Add your preferred linter commands here$(RESET)" - # Example: flake8 pyreason tests - # Example: black --check pyreason tests - # Example: mypy pyreason + @echo "Fixing end of files..." + @pre-commit run end-of-file-fixer --all-files || true + @echo "Running ruff..." 
+ ./.venv/bin/python -m ruff check pyreason/scripts check-deps: ## Check if required dependencies are installed @echo "$(BOLD)$(BLUE)Checking dependencies...$(RESET)" @@ -176,4 +176,4 @@ info: ## Show project and tool versions @echo "Pytest: $$($(PYTHON) -c 'import pytest; print(pytest.__version__)' 2>/dev/null || echo 'Not installed')" @echo "Coverage: $$($(PYTHON) -c 'import coverage; print(coverage.__version__)' 2>/dev/null || echo 'Not installed')" @echo "Working Directory: $$(pwd)" - @echo "Test Runner: $$(ls -la run_tests.py 2>/dev/null || echo 'Not found')" \ No newline at end of file + @echo "Test Runner: $$(ls -la run_tests.py 2>/dev/null || echo 'Not found')" diff --git a/docs/hello-world.md b/docs/hello-world.md index cc4a75cf..401670e4 100755 --- a/docs/hello-world.md +++ b/docs/hello-world.md @@ -131,4 +131,4 @@ After running the python file, the expected output is: 3. For timestep 2, since Justin has just become popular, John now has one popular friend (Justin) and the same pet as Justin (dog). Therefore `Justin -> popular: [1,1]` -We also output two CSV files detailing all the events that took place during reasoning (one for nodes, one for edges) \ No newline at end of file +We also output two CSV files detailing all the events that took place during reasoning (one for nodes, one for edges) diff --git a/docs/hello-world.py b/docs/hello-world.py index e2c4b379..77a6b2f9 100644 --- a/docs/hello-world.py +++ b/docs/hello-world.py @@ -51,4 +51,3 @@ # Get all interpretations in a dictionary interpretations_dict = interpretation.get_interpretation_dict() - diff --git a/docs/source/api_reference/index.rst b/docs/source/api_reference/index.rst index 3c1b4a08..b0fdadea 100644 --- a/docs/source/api_reference/index.rst +++ b/docs/source/api_reference/index.rst @@ -5,4 +5,4 @@ API Documentation .. 
automodule:: pyreason :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: diff --git a/docs/source/examples_rst/advanced_example.rst b/docs/source/examples_rst/advanced_example.rst index b7762bad..7f096dc0 100644 --- a/docs/source/examples_rst/advanced_example.rst +++ b/docs/source/examples_rst/advanced_example.rst @@ -164,4 +164,3 @@ Advanced Example # Display filtered node and edge data print(df1) print(df2) - diff --git a/docs/source/examples_rst/advanced_output_example.rst b/docs/source/examples_rst/advanced_output_example.rst index 9c4c416c..edd67965 100644 --- a/docs/source/examples_rst/advanced_output_example.rst +++ b/docs/source/examples_rst/advanced_output_example.rst @@ -695,4 +695,4 @@ Advanced Example Full Output ('customer_6', 'Car_6'): {}, ('customer_6', 'Pet_4'): {}, ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}}} - \ No newline at end of file + diff --git a/docs/source/examples_rst/annF_average_example.rst b/docs/source/examples_rst/annF_average_example.rst index c27541e0..afcfdbb2 100644 --- a/docs/source/examples_rst/annF_average_example.rst +++ b/docs/source/examples_rst/annF_average_example.rst @@ -58,4 +58,3 @@ Average Annotation Function Example assert interpretation.query('average_function(A, B) : [0.105, 1]'), 'Average function should be [0.105, 1]' average_annotation_function() - diff --git a/docs/source/examples_rst/annF_linear_combination_example.rst b/docs/source/examples_rst/annF_linear_combination_example.rst index 18a897c5..199a3cae 100644 --- a/docs/source/examples_rst/annF_linear_combination_example.rst +++ b/docs/source/examples_rst/annF_linear_combination_example.rst @@ -85,4 +85,4 @@ Linear Combination Annotation Function Example assert interpretation.query('linear_combination_function(A, B) : [0.1, 0.4]'), 'Linear combination function should be [0.105, 1]' # Run the test function - linear_combination_annotation_function() \ No newline at end of file + linear_combination_annotation_function() diff --git a/docs/source/examples_rst/index.rst b/docs/source/examples_rst/index.rst index c5927584..a120ef53 100644 --- a/docs/source/examples_rst/index.rst +++ b/docs/source/examples_rst/index.rst @@ -13,4 +13,4 @@ Examples ./* - \ No newline at end of file + diff --git a/docs/source/examples_rst/infer_edges_example.rst b/docs/source/examples_rst/infer_edges_example.rst index b295a589..75ad1a1f 100644 --- a/docs/source/examples_rst/infer_edges_example.rst +++ b/docs/source/examples_rst/infer_edges_example.rst @@ -83,4 +83,4 @@ Infer Edges Example assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' nx.draw(G, with_labels=True, node_color='lightblue', font_weight='bold', node_size=3000) - plt.show() \ No newline at end of file + plt.show() diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 20060dc0..d3586355 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -20,4 +20,4 @@ You will see a message like this when you import PyReason for the first time: .. code:: text - Imported PyReason for the first time. Initializing caches for faster runtimes ... this will take a minute \ No newline at end of file + Imported PyReason for the first time. Initializing caches for faster runtimes ... 
this will take a minute diff --git a/docs/source/license.rst b/docs/source/license.rst index c9569d7d..2b02fe4e 100644 --- a/docs/source/license.rst +++ b/docs/source/license.rst @@ -9,4 +9,3 @@ Trademark Permission :width: 50 PyReason™ and PyReason Design Logo |logo| ™ are trademarks of the Arizona Board of Regents/Arizona State University. Users of the software are permitted to use PyReason™ in association with the software for any purpose, provided such use is related to the software (e.g., Powered by PyReason™). Additionally, educational institutions are permitted to use the PyReason Design Logo |logo| ™ for non-commercial purposes. - diff --git a/docs/source/tutorials/advanced_tutorial.rst b/docs/source/tutorials/advanced_tutorial.rst index 22248abf..54c8297d 100644 --- a/docs/source/tutorials/advanced_tutorial.rst +++ b/docs/source/tutorials/advanced_tutorial.rst @@ -161,4 +161,4 @@ Below is the expected output at timestep ``0`` ('customer_6', 'Car_6'): {}, ('customer_6', 'Pet_4'): {}, ('customer_6', 'customer_0'): {'same_color_car': (1.0, 1.0)}}, - \ No newline at end of file + diff --git a/docs/source/tutorials/index.rst b/docs/source/tutorials/index.rst index e0c2f755..947b3295 100644 --- a/docs/source/tutorials/index.rst +++ b/docs/source/tutorials/index.rst @@ -16,4 +16,4 @@ Contents ./custom_thresholds.rst ./infer_edges.rst ./annotation_function.rst - \ No newline at end of file + diff --git a/docs/source/tutorials/infer_edges.rst b/docs/source/tutorials/infer_edges.rst index 6442322b..67655b62 100644 --- a/docs/source/tutorials/infer_edges.rst +++ b/docs/source/tutorials/infer_edges.rst @@ -146,6 +146,3 @@ The graph after running shows a new connection from ``Vnukovo_International_Airp .. image:: ../../../media/infer_edges2.png :align: center - - - diff --git a/docs/source/user_guide/3_pyreason_rules.rst b/docs/source/user_guide/3_pyreason_rules.rst index 51e9be53..96f87be2 100644 --- a/docs/source/user_guide/3_pyreason_rules.rst +++ b/docs/source/user_guide/3_pyreason_rules.rst @@ -208,4 +208,4 @@ PyReason's ``Threshold`` class is used to define custom thresholds for a rule. T and the second is either ``"total"`` or ``"available"``. ``"total"`` refers to all groundings of the clause, while ``"available"`` refers to the groundings that have the predicate of the clause. #. ``thresh`` **(int)**: The value of the threshold -An example usage can be found :ref:`here `. \ No newline at end of file +An example usage can be found :ref:`here `. diff --git a/docs/source/user_guide/4_pyreason_settings.rst b/docs/source/user_guide/4_pyreason_settings.rst index bd7d3305..8f037d97 100644 --- a/docs/source/user_guide/4_pyreason_settings.rst +++ b/docs/source/user_guide/4_pyreason_settings.rst @@ -91,4 +91,4 @@ Notes on Parallelism ~~~~~~~~~~~~~~~~~~~~ PyReason is parallelized over rules, so for large rulesets it is recommended that this setting is used. However, for small rulesets, the overhead might be more than the speedup and it is worth checking the performance on your specific use case. -When possible we recommend using the same number of cores (or a multiple) as the number of rules in the program. \ No newline at end of file +When possible we recommend using the same number of cores (or a multiple) as the number of rules in the program. 
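[Editor's sketch] The parallelism note in the hunk above pairs core count with rule count; a minimal illustration of that advice, assuming the `parallel_computing` flag used elsewhere in this series (the rule here is a placeholder borrowed from the hello-world test):

import multiprocessing

import pyreason as pr

pr.reset()
pr.reset_rules()
pr.reset_settings()
# Parallelize reasoning over rules. Worthwhile for large rulesets;
# for small ones the overhead can outweigh the speedup.
pr.settings.parallel_computing = True

# Placeholder ruleset; ideally sized to (a multiple of) the available cores.
pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y)', 'popular_rule'))
print('cores available:', multiprocessing.cpu_count())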
diff --git a/docs/source/user_guide/5_inconsistent_predicate_list.rst b/docs/source/user_guide/5_inconsistent_predicate_list.rst
index 2be00a50..136fc65a 100644
--- a/docs/source/user_guide/5_inconsistent_predicate_list.rst
+++ b/docs/source/user_guide/5_inconsistent_predicate_list.rst
@@ -18,4 +18,4 @@ This can be done by using the following code:
     pr.add_inconsistent_predicate('sick', 'healthy')
 
 This allows PyReason to automatically update the bounds of the predicates in the inconsistent predicate list to the
-negation of a predicate that is updated.
\ No newline at end of file
+negation of a predicate that is updated.
diff --git a/docs/source/user_guide/7_jupyter_notebook_usage.rst b/docs/source/user_guide/7_jupyter_notebook_usage.rst
index 4eb2bea0..80c59643 100644
--- a/docs/source/user_guide/7_jupyter_notebook_usage.rst
+++ b/docs/source/user_guide/7_jupyter_notebook_usage.rst
@@ -8,4 +8,3 @@ However, if you want to use PyReason in a Jupyter Notebook, make sure you unders
 1. When using functions like ``add_rule`` or ``add_fact`` in a Jupyter Notebook, make sure to run the cell only once. Running the cell multiple times will add the same rule/fact multiple times. It is recommended to store all the rules and facts in an array and then add them all at once in one cell towards the end
 
 2. Functions like ``load_graph`` and ``load_graphml`` which are run multiple times can also have the same issue. Make sure to run them only once.
-
diff --git a/docs/source/user_guide/index.rst b/docs/source/user_guide/index.rst
index 584d3f01..2587e974 100644
--- a/docs/source/user_guide/index.rst
+++ b/docs/source/user_guide/index.rst
@@ -10,6 +10,3 @@ In this section we demonstrate the functionality of the `pyreason` library and h
    :glob:
 
    ./*
-
-
-
diff --git a/examples/advanced_output.txt b/examples/advanced_output.txt
index b24449d2..5c785988 100644
--- a/examples/advanced_output.txt
+++ b/examples/advanced_output.txt
@@ -836,4 +836,4 @@ Filtered Edges:
 6  (customer_4, customer_5)  [0, 1]  [1.0, 1.0]
 7  (customer_5, customer_3)  [0, 1]  [1.0, 1.0]
 8  (customer_5, customer_6)  [0, 1]  [1.0, 1.0]
-9  (customer_6, customer_0)  [0, 1]  [1.0, 1.0]]
\ No newline at end of file
+9  (customer_6, customer_0)  [0, 1]  [1.0, 1.0]]
diff --git a/examples/temporal_classifier_integration_ex.py b/examples/temporal_classifier_integration_ex.py
index f2f97d57..7113356e 100644
--- a/examples/temporal_classifier_integration_ex.py
+++ b/examples/temporal_classifier_integration_ex.py
@@ -77,4 +77,4 @@
 trace = pr.get_rule_trace(interpretation)
 
 print("\n=== Reasoning Rule Trace ===")
-print(trace[0])
\ No newline at end of file
+print(trace[0])
diff --git a/initialize.py b/initialize.py
index ef1637f2..fcdb6d8b 100644
--- a/initialize.py
+++ b/initialize.py
@@ -1,4 +1,4 @@
 # Run this script after cloning repository to generate the numba caches. This script runs the hello-world program internally
 print('Initializing PyReason caches')
 
 import pyreason as pr
-
\ No newline at end of file
+
diff --git a/jobs/.gitignore b/jobs/.gitignore
index e7a210ec..94548af5 100755
--- a/jobs/.gitignore
+++ b/jobs/.gitignore
@@ -1,3 +1,3 @@
 *
 */
-!.gitignore
\ No newline at end of file
+!.gitignore
diff --git a/output/.gitignore b/output/.gitignore
index e7a210ec..94548af5 100755
--- a/output/.gitignore
+++ b/output/.gitignore
@@ -1,3 +1,3 @@
 *
 */
-!.gitignore
\ No newline at end of file
+!.gitignore
diff --git a/profiling/.gitignore b/profiling/.gitignore
index e7a210ec..94548af5 100755
--- a/profiling/.gitignore
+++ b/profiling/.gitignore
@@ -1,3 +1,3 @@
 *
 */
-!.gitignore
\ No newline at end of file
+!.gitignore
diff --git a/pyreason/__init__.py b/pyreason/__init__.py
index 366d0595..5fc960e8 100755
--- a/pyreason/__init__.py
+++ b/pyreason/__init__.py
@@ -1,7 +1,7 @@
 # Set numba environment variable
 import os
 import yaml
-from pyreason.pyreason import settings, load_graphml, add_rule, Rule, add_fact, Fact, reason, reset, reset_rules
+from pyreason.pyreason import *
 from pkg_resources import get_distribution, DistributionNotFound
 
 package_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/pyreason/scripts/annotation_functions/annotation_functions.py b/pyreason/scripts/annotation_functions/annotation_functions.py
index 9ef577de..75eb9b6f 100755
--- a/pyreason/scripts/annotation_functions/annotation_functions.py
+++ b/pyreason/scripts/annotation_functions/annotation_functions.py
@@ -99,5 +99,3 @@ def minimum(annotations, weights):
     lower, upper = _check_bound(min_lower, min_upper)
 
     return interval.closed(lower, upper)
-
-
diff --git a/pyreason/scripts/components/label.py b/pyreason/scripts/components/label.py
index 443d3a42..6bb5dce5 100755
--- a/pyreason/scripts/components/label.py
+++ b/pyreason/scripts/components/label.py
@@ -17,4 +17,4 @@ def __hash__(self):
         return hash(str(self))
 
     def __repr__(self):
-        return self.get_value()
\ No newline at end of file
+        return self.get_value()
diff --git a/pyreason/scripts/interpretation/interpretation.py b/pyreason/scripts/interpretation/interpretation.py
index 5cadb8bd..4f7ac3c1 100755
--- a/pyreason/scripts/interpretation/interpretation.py
+++ b/pyreason/scripts/interpretation/interpretation.py
@@ -1964,4 +1964,4 @@ def str_to_int(value):
     for i, v in enumerate(value):
         result += (ord(v) - 48) * (10 ** (final_index - i))
     result = -result if negative else result
-    return result
\ No newline at end of file
+    return result
diff --git a/pyreason/scripts/interpretation/interpretation_fp.py b/pyreason/scripts/interpretation/interpretation_fp.py
index cfd7e4e4..f8e0bef9 100755
--- a/pyreason/scripts/interpretation/interpretation_fp.py
+++ b/pyreason/scripts/interpretation/interpretation_fp.py
@@ -2058,4 +2058,4 @@ def str_to_int(value):
     for i, v in enumerate(value):
         result += (ord(v) - 48) * (10 ** (final_index - i))
     result = -result if negative else result
-    return result
\ No newline at end of file
+    return result
diff --git a/pyreason/scripts/interpretation/interpretation_parallel.py b/pyreason/scripts/interpretation/interpretation_parallel.py
index 3f2dcdcb..3d98bd7e 100644
--- a/pyreason/scripts/interpretation/interpretation_parallel.py
+++ b/pyreason/scripts/interpretation/interpretation_parallel.py
@@ -1964,4 +1964,4 @@ def str_to_int(value):
     for i, v in enumerate(value):
         result += (ord(v) - 48) * (10 ** (final_index - i))
     result = -result if negative else result
-    return result
\ No newline at end of file
+    return result
diff --git a/pyreason/scripts/numba_wrapper/numba_types/label_type.py b/pyreason/scripts/numba_wrapper/numba_types/label_type.py
index 97e3f7fa..bf007dcc 100755
--- a/pyreason/scripts/numba_wrapper/numba_types/label_type.py
+++ b/pyreason/scripts/numba_wrapper/numba_types/label_type.py
@@ -103,4 +103,3 @@ def box_label(typ, val, c):
     c.pyapi.decref(value_obj)
     c.pyapi.decref(class_obj)
     return res
-
diff --git a/pyreason/scripts/utils/graphml_parser.py b/pyreason/scripts/utils/graphml_parser.py
index 7b6ec87e..2eb2c7ae 100755
--- a/pyreason/scripts/utils/graphml_parser.py
+++ b/pyreason/scripts/utils/graphml_parser.py
@@ -90,4 +90,4 @@ def parse_graph_attributes(self, static_facts):
                 f = fact_edge.Fact('graph-attribute-fact', (e[0], e[1]), label.Label(label_str), interval.closed(lower_bnd, upper_bnd), 0, 0, static=static_facts)
                 facts_edge.append(f)
 
-        return facts_node, facts_edge, specific_node_labels, specific_edge_labels
\ No newline at end of file
+        return facts_node, facts_edge, specific_node_labels, specific_edge_labels
diff --git a/pyreason/scripts/utils/plotter.py b/pyreason/scripts/utils/plotter.py
index ddc1e828..06251fb3 100755
--- a/pyreason/scripts/utils/plotter.py
+++ b/pyreason/scripts/utils/plotter.py
@@ -69,4 +69,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
diff --git a/pyreason/scripts/utils/visuals.py b/pyreason/scripts/utils/visuals.py
index 6ec4ac4d..0d920e8e 100755
--- a/pyreason/scripts/utils/visuals.py
+++ b/pyreason/scripts/utils/visuals.py
@@ -21,4 +21,4 @@ def make_visuals(graph_data, nodelist):
         else:
             color_map.append('green')
     labels_g=nx.get_node_attributes(graph_data, "name")
-    nx.draw(graph_data, pos=pos_g, node_color=color_map, node_size=100, font_size=10, font_color='DarkBlue', with_labels=True, labels=labels_g)
\ No newline at end of file
+    nx.draw(graph_data, pos=pos_g, node_color=color_map, node_size=100, font_size=10, font_color='DarkBlue', with_labels=True, labels=labels_g)
diff --git a/pytest.ini b/pytest.ini
index 9f4a72e8..ae7474ce 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -64,4 +64,4 @@ console_output_style = progress
 # timeout = 60  # Disabled for functional tests that may take longer
 
 # JUnit XML output (useful for CI/CD)
-junit_family = xunit2
\ No newline at end of file
+junit_family = xunit2
diff --git a/run_on_agave.sh b/run_on_agave.sh
index 44d77bad..cd4fab19 100755
--- a/run_on_agave.sh
+++ b/run_on_agave.sh
@@ -45,5 +45,3 @@ fi
 # Run pyreason
 python3 -u -m pyreason.scripts.diffuse --graph_path $graph_path --timesteps $timesteps --rules $rules_yaml_path --facts $facts_yaml_path --labels $labels_yaml_path --ipl $ipl_yaml_path --output_to_file --output_file $output_file_name
 #-------------------------------------------------------------------------
-
-
diff --git a/run_tests.py b/run_tests.py
index 47cf533f..9f6a5ca2 100755
--- a/run_tests.py
+++ b/run_tests.py
@@ -525,4 +525,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
diff --git a/test_config.json b/test_config.json
index f0aa588f..086502e8 100644
--- a/test_config.json
+++ b/test_config.json
@@ -73,4 +73,4 @@
     "junit_xml": true,
     "console_output": "detailed"
   }
-}
\ No newline at end of file
+}
diff --git a/tests/functional/knowledge_graph_test_subset.graphml b/tests/functional/knowledge_graph_test_subset.graphml
index 72e5c23b..bdd7dcea 100644
--- a/tests/functional/knowledge_graph_test_subset.graphml
+++ b/tests/functional/knowledge_graph_test_subset.graphml
@@ -68,4 +68,4 @@
-
\ No newline at end of file
+
diff --git a/tests/functional/test_annotation_function.py b/tests/functional/test_annotation_function.py
index 8193559c..635cc425 100644
--- a/tests/functional/test_annotation_function.py
+++ b/tests/functional/test_annotation_function.py
@@ -79,4 +79,4 @@ def test_annotation_function_fp():
         print(df)
         print()
 
-    assert interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]')), 'Union probability should be 0.21'
\ No newline at end of file
+    assert interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]')), 'Union probability should be 0.21'
diff --git a/tests/functional/test_anyBurl_infer_edges_rules.py b/tests/functional/test_anyBurl_infer_edges_rules.py
index ab3df7a8..c58c6ce3 100644
--- a/tests/functional/test_anyBurl_infer_edges_rules.py
+++ b/tests/functional/test_anyBurl_infer_edges_rules.py
@@ -299,4 +299,4 @@ def test_anyBurl_rule_4_fp():
         print()
     assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations'
     assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom'
-    assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
\ No newline at end of file
+    assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps'
diff --git a/tests/functional/test_custom_thresholds.py b/tests/functional/test_custom_thresholds.py
index f412a47c..4edb804d 100644
--- a/tests/functional/test_custom_thresholds.py
+++ b/tests/functional/test_custom_thresholds.py
@@ -125,4 +125,4 @@ def test_custom_thresholds_fp():
     ].ViewedByAll == [
         1,
         1,
-    ], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps"
\ No newline at end of file
+    ], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps"
diff --git a/tests/functional/test_hello_world.py b/tests/functional/test_hello_world.py
index 016e8006..5b47c589 100644
--- a/tests/functional/test_hello_world.py
+++ b/tests/functional/test_hello_world.py
@@ -170,4 +170,4 @@ def py_closed(lower, upper, static=False):
     finally:
         interval_type.closed = original_closed
 
-    assert jit_res == py_res
\ No newline at end of file
+    assert jit_res == py_res
diff --git a/tests/functional/test_hello_world_parallel.py b/tests/functional/test_hello_world_parallel.py
index f809a6c4..18f4a182 100644
--- a/tests/functional/test_hello_world_parallel.py
+++ b/tests/functional/test_hello_world_parallel.py
@@ -45,4 +45,3 @@ def test_hello_world_parallel():
 
     # John should be popular in timestep 3
     assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
-
diff --git a/tests/functional/test_pyreason_comprehensive.py b/tests/functional/test_pyreason_comprehensive.py
index d1113de8..ea4b06ac 100644
--- a/tests/functional/test_pyreason_comprehensive.py
+++ b/tests/functional/test_pyreason_comprehensive.py
@@ -433,4 +433,4 @@ def test_settings_validation_fp(self):
 
 
 if __name__ == '__main__':
-    pytest.main([__file__])
\ No newline at end of file
+    pytest.main([__file__])
diff --git a/tests/functional/test_reason_again.py b/tests/functional/test_reason_again.py
index 2ad8c338..ec853438 100644
--- a/tests/functional/test_reason_again.py
+++ b/tests/functional/test_reason_again.py
@@ -108,4 +108,4 @@ def test_reason_again_fp():
     assert 'Justin' in dataframes[4]['component'].values and dataframes[4].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps'
 
     # John should be popular in timestep 3
-    assert 'John' in dataframes[4]['component'].values and dataframes[4].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
\ No newline at end of file
+    assert 'John' in dataframes[4]['component'].values and dataframes[4].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps'
diff --git a/tests/unit/api_tests/conftest.py b/tests/unit/api_tests/conftest.py
index f8cf07ee..5aea3fb2 100644
--- a/tests/unit/api_tests/conftest.py
+++ b/tests/unit/api_tests/conftest.py
@@ -4,4 +4,4 @@
 """
 
 # No special setup needed - let pyreason import normally with JIT enabled
-# This allows us to test the actual API behavior
\ No newline at end of file
+# This allows us to test the actual API behavior
diff --git a/tests/unit/api_tests/test_pyreason_add_operations.py b/tests/unit/api_tests/test_pyreason_add_operations.py
index 426ea2b7..fb84f6f9 100644
--- a/tests/unit/api_tests/test_pyreason_add_operations.py
+++ b/tests/unit/api_tests/test_pyreason_add_operations.py
@@ -421,4 +421,4 @@ def test_counters_after_reset(self):
         pr.add_fact(fact2)
 
         # Fact should get auto-generated name
-        assert fact2.name.startswith('fact_')
\ No newline at end of file
+        assert fact2.name.startswith('fact_')
diff --git a/tests/unit/api_tests/test_pyreason_file_loading.py b/tests/unit/api_tests/test_pyreason_file_loading.py
index 3c9219dd..2767d915 100644
--- a/tests/unit/api_tests/test_pyreason_file_loading.py
+++ b/tests/unit/api_tests/test_pyreason_file_loading.py
@@ -963,4 +963,4 @@ def test_get_rule_trace_returns_dataframes(self):
         # DataFrames should have some basic expected structure
         # (exact columns depend on implementation, but they should be valid DataFrames)
         assert hasattr(node_trace, 'columns')
-        assert hasattr(edge_trace, 'columns')
\ No newline at end of file
+        assert hasattr(edge_trace, 'columns')
diff --git a/tests/unit/api_tests/test_pyreason_reasoning.py b/tests/unit/api_tests/test_pyreason_reasoning.py
index 9049e24a..bcbbf4f7 100644
--- a/tests/unit/api_tests/test_pyreason_reasoning.py
+++ b/tests/unit/api_tests/test_pyreason_reasoning.py
@@ -941,4 +941,4 @@ def test_reason_again_assert_coverage(self):
 
         # Now the assert in _reason_again should pass
         interpretation2 = pr.reason(timesteps=1, again=True)
-        assert interpretation2 is not None
\ No newline at end of file
+        assert interpretation2 is not None
diff --git a/tests/unit/api_tests/test_pyreason_settings.py b/tests/unit/api_tests/test_pyreason_settings.py
index 3614b0a1..bf907841 100644
--- a/tests/unit/api_tests/test_pyreason_settings.py
+++ b/tests/unit/api_tests/test_pyreason_settings.py
@@ -522,4 +522,4 @@ def test_multiple_settings_modifications(self):
         # All should be set correctly
         assert pr.settings.verbose is True
         assert pr.settings.memory_profile is True
-        assert pr.settings.output_file_name == "test"
\ No newline at end of file
+        assert pr.settings.output_file_name == "test"
diff --git a/tests/unit/api_tests/test_pyreason_validation.py b/tests/unit/api_tests/test_pyreason_validation.py
index 0d91eff1..7c1c4a93 100644
--- a/tests/unit/api_tests/test_pyreason_validation.py
+++ b/tests/unit/api_tests/test_pyreason_validation.py
@@ -215,4 +215,4 @@ def test_reset_after_errors(self):
         pr.add_rule(rule)  # Should succeed without exception
 
         # Verify the new rule was created correctly
-        assert hasattr(rule.rule, 'get_rule_name')  # Rule object exists
\ No newline at end of file
+        assert hasattr(rule.rule, 'get_rule_name')  # Rule object exists
diff --git a/tests/unit/disable_jit/interpretations/test_reason_core.py b/tests/unit/disable_jit/interpretations/test_reason_core.py
index d9cefd4b..872c6171 100644
--- a/tests/unit/disable_jit/interpretations/test_reason_core.py
+++ b/tests/unit/disable_jit/interpretations/test_reason_core.py
@@ -1611,5 +1611,3 @@ def test_reason_breaks_on_perfect_convergence(monkeypatch, reason_env):
 
     assert fp == 0 and max_t == 1
     assert any("Converged at fp" in line for line in printed)
-
-
diff --git a/tests/unit/disable_jit/interpretations/test_reason_update.py b/tests/unit/disable_jit/interpretations/test_reason_update.py
index ccbb0287..4a31bae5 100644
--- a/tests/unit/disable_jit/interpretations/test_reason_update.py
+++ b/tests/unit/disable_jit/interpretations/test_reason_update.py
@@ -1603,5 +1603,3 @@ def __init__(self):
 
     assert new_edge is True
     assert l in interpretations_edge[edge].world
     assert predicate_map[l] == [edge]
-
-

From 3e1a6f247b6fb4236adccde17dbffa016cfd58b3 Mon Sep 17 00:00:00 2001
From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com>
Date: Thu, 9 Oct 2025 14:36:51 -0400
Subject: [PATCH 07/32] Excuse main pyreason from broad import lint fail

---
 pyreason/__init__.py | 1 +
 pyreason/pyreason.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/pyreason/__init__.py b/pyreason/__init__.py
index 5fc960e8..3324cbbd 100755
--- a/pyreason/__init__.py
+++ b/pyreason/__init__.py
@@ -1,4 +1,5 @@
 # Set numba environment variable
+# ruff: noqa: F403 F405 (Ignore Pyreason import * for public api)
 import os
 import yaml
 from pyreason.pyreason import *
diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py
index 6e2de24a..61a4d4b7 100755
--- a/pyreason/pyreason.py
+++ b/pyreason/pyreason.py
@@ -1,4 +1,5 @@
 # This is the file that will be imported when "import pyreason" is called. All content will be run automatically
+# ruff: noqa: F401 (Ignore Pyreason import * for public api)
 import importlib
 import networkx as nx
 import numba

From b2e04ea12d4d3f558f49a1711ebdd82db88673fc Mon Sep 17 00:00:00 2001
From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com>
Date: Thu, 9 Oct 2025 14:41:00 -0400
Subject: [PATCH 08/32] add back import

---
 pyreason/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyreason/__init__.py b/pyreason/__init__.py
index 3324cbbd..544ddf94 100755
--- a/pyreason/__init__.py
+++ b/pyreason/__init__.py
@@ -4,6 +4,7 @@
 import yaml
 from pyreason.pyreason import *
 from pkg_resources import get_distribution, DistributionNotFound
+from importlib.metadata import version
 
 package_path = os.path.abspath(os.path.dirname(__file__))
 cache_path = os.path.join(package_path, 'cache')
@@ -16,7 +17,6 @@
     # package is not installed
     pass
 
-
 with open(cache_status_path) as file:
     cache_status = yaml.safe_load(file)

From 5cc0ca118340f06967dda6b412b122a87298a488 Mon Sep 17 00:00:00 2001
From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com>
Date: Sat, 18 Oct 2025 14:16:20 -0400
Subject: [PATCH 09/32] Upd

---
 .pre-commit-config.yaml                            |  21 +-
 check_variable_names.py                            | 201 ++++++++++++++++++
 .../scripts/interpretation/interpretation.py       |  12 +-
 .../interpretation/interpretation_parallel.py      |  12 +-
 run_tests.py                                       |   1 +
 sync_interpretation_parallel.py                    |  81 +++++++
 tests/functional/test_hello_world_parallel.py      |  48 -----
 .../test_pyreason_file_consistency.py              |  78 +++++++
 8 files changed, 381 insertions(+), 73 deletions(-)
 create mode 100644 check_variable_names.py
 create mode 100644 sync_interpretation_parallel.py
 delete mode 100644 tests/functional/test_hello_world_parallel.py
 create mode 100644 tests/unit/api_tests/test_pyreason_file_consistency.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 08a5ee59..44d9580b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,15 +3,23 @@ repos:
   - repo: local
     hooks:
       # --- COMMIT STAGE: Fast unit tests only ---
+      - id: sync-interpretation-parallel
+        name: Sync interpretation_parallel.py from interpretation.py
+        entry: .venv/bin/python sync_interpretation_parallel.py
+        language: system
+        pass_filenames: false
+        files: 'pyreason/scripts/interpretation/interpretation\.py'
+        stages: [pre-commit]
+
       - id: pytest-unit-no-jit
-        name: Run JIT-disabled unit tests 
+        name: Run JIT-disabled unit tests
         entry: .venv/bin/python -m pytest tests/unit/disable_jit -m "not slow" --tb=short -q
         language: system
         pass_filenames: false
         stages: [pre-commit]
 
       - id: pytest-unit-jit
-        name: Run JIT-enabled unit tests 
+        name: Run JIT-enabled unit tests
        entry: .venv/bin/python -m pytest tests/unit/dont_disable_jit -m "not slow" --tb=short -q
         language: system
         pass_filenames: false
@@ -20,12 +28,19 @@ repos:
 
       # --- PUSH STAGE: Complete test suite ---
       - id: pytest-unit-api
-        name: Run pyreason api unit tests 
+        name: Run pyreason api unit tests
         entry: .venv/bin/python -m pytest tests/unit/api_tests --tb=short -q
         language: system
         pass_filenames: false
         stages: [pre-push]
 
+      - id: pytest-file-consistency
+        name: Verify interpretation files are in sync
+        entry: .venv/bin/python -m pytest tests/unit/api_tests/test_pyreason_file_consistency.py -v
+        language: system
+        pass_filenames: false
+        stages: [pre-push]
+
       - id: pytest-functional-complete
         name: Run functional test suite
         entry: .venv/bin/python -m pytest tests/functional/ --tb=short
diff --git a/check_variable_names.py b/check_variable_names.py
new file mode 100644
index 00000000..c228b2e0
--- /dev/null
+++ b/check_variable_names.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python3
+"""
+Script to check for occurrences of specific variable names in Python codebase.
+Uses AST parsing to avoid false positives from comments and strings.
+"""
+
+import argparse
+import ast
+import sys
+from pathlib import Path
+from typing import List, Tuple
+
+
+class VariableNameVisitor(ast.NodeVisitor):
+    """AST visitor that finds all occurrences of a specific variable name."""
+
+    def __init__(self, target_name: str, filepath: str):
+        self.target_name = target_name
+        self.filepath = filepath
+        self.occurrences: List[Tuple[int, int, str, str]] = []
+
+    def visit_Name(self, node: ast.Name) -> None:
+        """Visit Name nodes (variable references)."""
+        if node.id == self.target_name:
+            context = "reference"
+            if isinstance(node.ctx, ast.Store):
+                context = "assignment"
+            elif isinstance(node.ctx, ast.Del):
+                context = "deletion"
+            self.occurrences.append((node.lineno, node.col_offset, context, node.id))
+        self.generic_visit(node)
+
+    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
+        """Visit function definitions to check parameter names."""
+        for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs:
+            if arg.arg == self.target_name:
+                self.occurrences.append((arg.lineno, arg.col_offset, "function_parameter", arg.arg))
+        if node.args.vararg and node.args.vararg.arg == self.target_name:
+            self.occurrences.append((node.args.vararg.lineno, node.args.vararg.col_offset, "vararg", node.args.vararg.arg))
+        if node.args.kwarg and node.args.kwarg.arg == self.target_name:
+            self.occurrences.append((node.args.kwarg.lineno, node.args.kwarg.col_offset, "kwarg", node.args.kwarg.arg))
+        self.generic_visit(node)
+
+    def visit_arg(self, node: ast.arg) -> None:
+        """Visit argument nodes."""
+        if node.arg == self.target_name:
+            self.occurrences.append((node.lineno, node.col_offset, "argument", node.arg))
+        self.generic_visit(node)
+
+    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
+        """Visit exception handler to check exception variable names."""
+        if node.name == self.target_name:
+            self.occurrences.append((node.lineno, node.col_offset, "exception_var", node.name))
+        self.generic_visit(node)
+
+    def visit_comprehension(self, node: ast.comprehension) -> None:
+        """Visit comprehension to check target variables."""
+        if isinstance(node.target, ast.Name) and node.target.id == self.target_name:
+            self.occurrences.append((node.target.lineno, node.target.col_offset, "comprehension_var", node.target.id))
+        self.generic_visit(node)
+
+
+def find_python_files(root_dir: Path, exclude_patterns: List[str] = None) -> List[Path]:
+    """Find all Python files in the directory tree."""
+    if exclude_patterns is None:
+        exclude_patterns = []
+
+    python_files = []
+    for py_file in root_dir.rglob("*.py"):
+        # Check if file should be excluded
+        should_exclude = False
+        for pattern in exclude_patterns:
+            if pattern in str(py_file):
+                should_exclude = True
+                break
+
+        if not should_exclude:
+            python_files.append(py_file)
+
+    return python_files
+
+
+def check_file_for_variable(filepath: Path, variable_name: str) -> List[Tuple[int, int, str, str]]:
+    """Check a single file for occurrences of the variable name."""
+    try:
+        with open(filepath, 'r', encoding='utf-8') as f:
+            content = f.read()
+
+        tree = ast.parse(content, filename=str(filepath))
+        visitor = VariableNameVisitor(variable_name, str(filepath))
+        visitor.visit(tree)
+
+        return visitor.occurrences
+    except SyntaxError:
+        print(f"Warning: Could not parse {filepath} (syntax error)", file=sys.stderr)
+        return []
+    except Exception as e:
+        print(f"Warning: Error processing {filepath}: {e}", file=sys.stderr)
+        return []
+
+
+def get_line_content(filepath: Path, lineno: int) -> str:
+    """Get the content of a specific line from a file."""
+    try:
+        with open(filepath, 'r', encoding='utf-8') as f:
+            lines = f.readlines()
+            if 0 < lineno <= len(lines):
+                return lines[lineno - 1].rstrip()
+    except Exception:
+        pass
+    return ""
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Check for occurrences of specific variable names in Python codebase"
+    )
+    parser.add_argument(
+        "variable_name",
+        help="The variable name to search for"
+    )
+    parser.add_argument(
+        "--path",
+        type=Path,
+        default=Path("."),
+        help="Root directory to search (default: current directory)"
+    )
+    parser.add_argument(
+        "--exclude",
+        nargs="*",
+        default=[".venv", "__pycache__", ".git", "build", "dist", ".eggs"],
+        help="Patterns to exclude from search"
+    )
+    parser.add_argument(
+        "--include-tests",
+        action="store_true",
+        help="Include test files (excluded by default)"
+    )
+    parser.add_argument(
+        "--include-docs",
+        action="store_true",
+        help="Include documentation files (excluded by default)"
+    )
+
+    args = parser.parse_args()
+
+    # Add default exclusions
+    exclude_patterns = args.exclude.copy()
+    if not args.include_tests:
+        exclude_patterns.extend(["test_", "tests/"])
+    if not args.include_docs:
+        exclude_patterns.extend(["docs/", "examples/"])
+
+    # Find all Python files
+    print(f"Searching for variable '{args.variable_name}' in {args.path}")
+    print(f"Excluding patterns: {exclude_patterns}\n")
+
+    python_files = find_python_files(args.path, exclude_patterns)
+    print(f"Found {len(python_files)} Python files to check\n")
+
+    # Check each file
+    total_occurrences = 0
+    files_with_occurrences = 0
+
+    for filepath in sorted(python_files):
+        occurrences = check_file_for_variable(filepath, args.variable_name)
+
+        if occurrences:
+            files_with_occurrences += 1
+            total_occurrences += len(occurrences)
+
+            # Display relative path
+            try:
+                rel_path = filepath.relative_to(args.path)
+            except ValueError:
+                rel_path = filepath
+
+            print(f"\n{rel_path}:")
+            for lineno, col_offset, context, var_name in sorted(occurrences):
+                line_content = get_line_content(filepath, lineno)
+                print(f"  Line {lineno}, Col {col_offset} ({context}): {line_content.strip()}")
+
+    # Summary
+    print(f"\n{'=' * 60}")
+    print(f"Summary:")
+    print(f"  Variable name: '{args.variable_name}'")
+    print(f"  Files checked: {len(python_files)}")
+    print(f"  Files with occurrences: {files_with_occurrences}")
+    print(f"  Total occurrences: {total_occurrences}")
+    print(f"{'=' * 60}")
+
+    # Exit with non-zero if occurrences found
+    if total_occurrences > 0:
+        sys.exit(1)
+    else:
+        print("\nNo occurrences found!")
+        sys.exit(0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyreason/scripts/interpretation/interpretation.py b/pyreason/scripts/interpretation/interpretation.py
index 03e46290..8bdc3051 100755
--- a/pyreason/scripts/interpretation/interpretation.py
+++ b/pyreason/scripts/interpretation/interpretation.py
@@ -463,11 +463,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
             for idx, i in enumerate(rules_to_be_applied_edge):
                 if i[0] == t:
                     comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
-                    print('applying edge rule at time', t, 'for component', comp, 'label', l, 'bound', bnd, 'set_static', set_static)
                     sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
-                    print("adding edges:", sources, targets, edge_l)
                     edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
-                    print('after adding, edges are:', edges)
                     changes_cnt += changes
 
                     # Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
@@ -605,26 +602,19 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
                         in_loop_threadsafe[i] = True
                         update_threadsafe[i] = False
 
-            # Update lists after parallel run
-            print("len", len(rules_to_be_applied_edge_threadsafe))
-            for i in rules_to_be_applied_edge_threadsafe:
-                print(i)
+            # Update lists after parallel run
             for i in range(len(rules)):
                 if len(rules_to_be_applied_node_threadsafe[i]) > 0:
                     rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i])
                 if len(rules_to_be_applied_edge_threadsafe[i]) > 0:
-                    print('here, edge rules')
                     rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i])
-                    print("rules_to_be_applied_edge", rules_to_be_applied_edge)
                 if atom_trace:
                     if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0:
                         rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i])
                     if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0:
                         rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i])
                 if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0:
-                    print('here, edge add')
                     edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i])
-                    print("edges_to_be_added_edge_rule", edges_to_be_added_edge_rule)
 
             # Merge threadsafe flags for in_loop and update
             in_loop = in_loop
diff --git a/pyreason/scripts/interpretation/interpretation_parallel.py b/pyreason/scripts/interpretation/interpretation_parallel.py
index 1972c404..45c41054 100644
--- a/pyreason/scripts/interpretation/interpretation_parallel.py
+++ b/pyreason/scripts/interpretation/interpretation_parallel.py
@@ -463,11 +463,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi
             for idx, i in enumerate(rules_to_be_applied_edge):
                 if i[0] == t:
                     comp, l, bnd, set_static = i[1], i[2], i[3], i[4]
-                    print('applying edge rule at time', t, 'for component', comp, 'label', l, 'bound', bnd, 'set_static', set_static)
                     sources, targets, edge_l = edges_to_be_added_edge_rule[idx]
-                    print("adding edges:", sources, targets, edge_l)
                     edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t)
-                    print('after adding, edges are:', edges)
                     changes_cnt += changes
 
                     # Update bound for newly added edges. Use bnd to update all edges if label is specified, else use bnd to update normally
Use bnd to update all edges if label is specified, else use bnd to update normally @@ -605,26 +602,19 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi in_loop_threadsafe[i] = True update_threadsafe[i] = False - # Update lists after parallel run - print("len", len(rules_to_be_applied_edge_threadsafe)) - for i in rules_to_be_applied_edge_threadsafe: - print(i) + # Update lists after parallel run for i in range(len(rules)): if len(rules_to_be_applied_node_threadsafe[i]) > 0: rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i]) if len(rules_to_be_applied_edge_threadsafe[i]) > 0: - print('here, edge rules') rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i]) - print("rules_to_be_applied_edge", rules_to_be_applied_edge) if atom_trace: if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0: rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i]) if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0: rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i]) if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0: - print('here, edge add') edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i]) - print("edges_to_be_added_edge_rule", edges_to_be_added_edge_rule) # Merge threadsafe flags for in_loop and update in_loop = in_loop diff --git a/run_tests.py b/run_tests.py index 47cf533f..8cbe0117 100755 --- a/run_tests.py +++ b/run_tests.py @@ -167,6 +167,7 @@ def run_suite(self, suite: TestSuite, coverage: bool = True) -> Tuple[bool, Opti # Prepare pytest command - try to find python/pytest python_cmd = self._find_python_command() cmd = [python_cmd, '-m', 'pytest'] + print("Running command with Python:", python_cmd) coverage_file_path = None if coverage: diff --git a/sync_interpretation_parallel.py b/sync_interpretation_parallel.py new file mode 100644 index 00000000..2ad25724 --- /dev/null +++ b/sync_interpretation_parallel.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +""" +Pre-commit hook script to synchronize interpretation_parallel.py from interpretation.py. + +This script ensures that interpretation_parallel.py is always an exact copy of +interpretation.py, except for line 226 which should have parallel=True instead of +parallel=False in the @numba.njit decorator. +""" + +import sys +from pathlib import Path + + +def sync_interpretation_files(): + """ + Synchronize interpretation_parallel.py from interpretation.py. 
+ + Returns: + int: 0 on success, 1 on failure + """ + # Get the path to the interpretation files + script_dir = Path(__file__).resolve().parent + project_root = script_dir + interpretation_dir = project_root / "pyreason" / "scripts" / "interpretation" + + interpretation_file = interpretation_dir / "interpretation.py" + interpretation_parallel_file = interpretation_dir / "interpretation_parallel.py" + + # Verify source file exists + if not interpretation_file.exists(): + print(f"Error: Source file not found: {interpretation_file}", file=sys.stderr) + return 1 + + # Read the source file + try: + with open(interpretation_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + except Exception as e: + print(f"Error reading {interpretation_file}: {e}", file=sys.stderr) + return 1 + + # Verify we have at least 226 lines + if len(lines) < 226: + print(f"Error: {interpretation_file} has fewer than 226 lines", file=sys.stderr) + return 1 + + # Expected line 226 (index 225) in source file + expected_line = "\t@numba.njit(cache=True, parallel=False)\n" + + if lines[225] != expected_line: + print(f"Warning: Line 226 in {interpretation_file} is not as expected.", file=sys.stderr) + print(f" Expected: {expected_line.strip()}", file=sys.stderr) + print(f" Got: {lines[225].strip()}", file=sys.stderr) + print(f" Proceeding with replacement anyway...", file=sys.stderr) + + # Replace line 226 for parallel version + modified_lines = lines.copy() + modified_lines[225] = "\t@numba.njit(cache=True, parallel=True)\n" + + # Write to parallel file + try: + with open(interpretation_parallel_file, 'w', encoding='utf-8') as f: + f.writelines(modified_lines) + except Exception as e: + print(f"Error writing {interpretation_parallel_file}: {e}", file=sys.stderr) + return 1 + + print(f"✓ Successfully synced {interpretation_parallel_file.name} from {interpretation_file.name}") + print(f" Modified line 226: parallel=False → parallel=True") + + return 0 + + +def main(): + """Main entry point.""" + result = sync_interpretation_files() + sys.exit(result) + + +if __name__ == "__main__": + main() diff --git a/tests/functional/test_hello_world_parallel.py b/tests/functional/test_hello_world_parallel.py deleted file mode 100644 index 888dbed1..00000000 --- a/tests/functional/test_hello_world_parallel.py +++ /dev/null @@ -1,48 +0,0 @@ -# Test if the simple hello world program works. 
-import pyreason as pr - - -def test_hello_world_parallel(): - # Reset PyReason - pr.reset() - pr.reset_rules() - - # Modify the paths based on where you've stored the files we made above - graph_path = './tests/functional/friends_graph.graphml' - - # Modify pyreason settings to make verbose - pr.reset_settings() - pr.settings.verbose = True # Print info to screen - pr.settings.parallel_computing = True # Use parallel computing - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) - pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=2) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_nodes(interpretation, ['popular']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - - assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person' - assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people' - assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people' - - # Mary should be popular in all three timesteps - assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps' - assert 'Mary' in dataframes[1]['component'].values and dataframes[1].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps' - assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps' - - # Justin should be popular in timesteps 1, 2 - assert 'Justin' in dataframes[1]['component'].values and dataframes[1].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps' - assert 'Justin' in dataframes[2]['component'].values and dataframes[2].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps' - - # John should be popular in timestep 3 - assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' - diff --git a/tests/unit/api_tests/test_pyreason_file_consistency.py b/tests/unit/api_tests/test_pyreason_file_consistency.py new file mode 100644 index 00000000..f7f30e7e --- /dev/null +++ b/tests/unit/api_tests/test_pyreason_file_consistency.py @@ -0,0 +1,78 @@ +""" +Tests to validate consistency between related files in the PyReason codebase. +""" + +import os +from pathlib import Path + + +def test_interpretation_parallel_consistency(): + """ + Test that interpretation_parallel.py is identical to interpretation.py + except for the parallel flag in the @numba.njit decorator on line 226. 
From dd0904c569b6ca28617aa871dd7a8f7a8fe8a3ba Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sat, 18 Oct 2025 14:16:20 -0400 Subject: [PATCH 10/32] Add consistency check for interpretation and interpretation_parallel --- .pre-commit-config.yaml | 14 +++- .../scripts/interpretation/interpretation.py | 12 +-- .../interpretation/interpretation_parallel.py | 12 +-- run_tests.py | 1 + sync_interpretation_parallel.py | 81 +++++++++++++++ tests/functional/test_hello_world_parallel.py | 48 ----------- .../test_pyreason_file_consistency.py | 78 ++++++++++++++++++ 7 files changed, 173 insertions(+), 73
deletions(-) create mode 100644 sync_interpretation_parallel.py delete mode 100644 tests/functional/test_hello_world_parallel.py create mode 100644 tests/unit/api_tests/test_pyreason_file_consistency.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 08a5ee59..b23c025a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,15 +3,23 @@ repos: - repo: local hooks: # --- COMMIT STAGE: Fast unit tests only --- + - id: sync-interpretation-parallel + name: Sync interpretation_parallel.py from interpretation.py + entry: .venv/bin/python sync_interpretation_parallel.py + language: system + pass_filenames: false + files: 'pyreason/scripts/interpretation/interpretation\.py' + stages: [pre-commit] + - id: pytest-unit-no-jit - name: Run JIT-disabled unit tests + name: Run JIT-disabled unit tests entry: .venv/bin/python -m pytest tests/unit/disable_jit -m "not slow" --tb=short -q language: system pass_filenames: false stages: [pre-commit] - id: pytest-unit-jit - name: Run JIT-enabled unit tests + name: Run JIT-enabled unit tests entry: .venv/bin/python -m pytest tests/unit/dont_disable_jit -m "not slow" --tb=short -q language: system pass_filenames: false @@ -20,7 +28,7 @@ repos: # --- PUSH STAGE: Complete test suite --- - id: pytest-unit-api - name: Run pyreason api unit tests + name: Run pyreason api unit tests entry: .venv/bin/python -m pytest tests/unit/api_tests --tb=short -q language: system pass_filenames: false diff --git a/pyreason/scripts/interpretation/interpretation.py b/pyreason/scripts/interpretation/interpretation.py index 03e46290..8bdc3051 100755 --- a/pyreason/scripts/interpretation/interpretation.py +++ b/pyreason/scripts/interpretation/interpretation.py @@ -463,11 +463,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi for idx, i in enumerate(rules_to_be_applied_edge): if i[0] == t: comp, l, bnd, set_static = i[1], i[2], i[3], i[4] - print('applying edge rule at time', t, 'for component', comp, 'label', l, 'bound', bnd, 'set_static', set_static) sources, targets, edge_l = edges_to_be_added_edge_rule[idx] - print("adding edges:", sources, targets, edge_l) edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) - print('after adding, edges are:', edges) changes_cnt += changes # Update bound for newly added edges. 
Use bnd to update all edges if label is specified, else use bnd to update normally @@ -605,26 +602,19 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi in_loop_threadsafe[i] = True update_threadsafe[i] = False - # Update lists after parallel run - print("len", len(rules_to_be_applied_edge_threadsafe)) - for i in rules_to_be_applied_edge_threadsafe: - print(i) + # Update lists after parallel run for i in range(len(rules)): if len(rules_to_be_applied_node_threadsafe[i]) > 0: rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i]) if len(rules_to_be_applied_edge_threadsafe[i]) > 0: - print('here, edge rules') rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i]) - print("rules_to_be_applied_edge", rules_to_be_applied_edge) if atom_trace: if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0: rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i]) if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0: rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i]) if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0: - print('here, edge add') edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i]) - print("edges_to_be_added_edge_rule", edges_to_be_added_edge_rule) # Merge threadsafe flags for in_loop and update in_loop = in_loop diff --git a/pyreason/scripts/interpretation/interpretation_parallel.py b/pyreason/scripts/interpretation/interpretation_parallel.py index 1972c404..45c41054 100644 --- a/pyreason/scripts/interpretation/interpretation_parallel.py +++ b/pyreason/scripts/interpretation/interpretation_parallel.py @@ -463,11 +463,8 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi for idx, i in enumerate(rules_to_be_applied_edge): if i[0] == t: comp, l, bnd, set_static = i[1], i[2], i[3], i[4] - print('applying edge rule at time', t, 'for component', comp, 'label', l, 'bound', bnd, 'set_static', set_static) sources, targets, edge_l = edges_to_be_added_edge_rule[idx] - print("adding edges:", sources, targets, edge_l) edges_added, changes = _add_edges(sources, targets, neighbors, reverse_neighbors, nodes, edges, edge_l, interpretations_node, interpretations_edge, predicate_map_edge, num_ga, t) - print('after adding, edges are:', edges) changes_cnt += changes # Update bound for newly added edges. 
Use bnd to update all edges if label is specified, else use bnd to update normally @@ -605,26 +602,19 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi in_loop_threadsafe[i] = True update_threadsafe[i] = False - # Update lists after parallel run - print("len", len(rules_to_be_applied_edge_threadsafe)) - for i in rules_to_be_applied_edge_threadsafe: - print(i) + # Update lists after parallel run for i in range(len(rules)): if len(rules_to_be_applied_node_threadsafe[i]) > 0: rules_to_be_applied_node.extend(rules_to_be_applied_node_threadsafe[i]) if len(rules_to_be_applied_edge_threadsafe[i]) > 0: - print('here, edge rules') rules_to_be_applied_edge.extend(rules_to_be_applied_edge_threadsafe[i]) - print("rules_to_be_applied_edge", rules_to_be_applied_edge) if atom_trace: if len(rules_to_be_applied_node_trace_threadsafe[i]) > 0: rules_to_be_applied_node_trace.extend(rules_to_be_applied_node_trace_threadsafe[i]) if len(rules_to_be_applied_edge_trace_threadsafe[i]) > 0: rules_to_be_applied_edge_trace.extend(rules_to_be_applied_edge_trace_threadsafe[i]) if len(edges_to_be_added_edge_rule_threadsafe[i]) > 0: - print('here, edge add') edges_to_be_added_edge_rule.extend(edges_to_be_added_edge_rule_threadsafe[i]) - print("edges_to_be_added_edge_rule", edges_to_be_added_edge_rule) # Merge threadsafe flags for in_loop and update in_loop = in_loop diff --git a/run_tests.py b/run_tests.py index 47cf533f..8cbe0117 100755 --- a/run_tests.py +++ b/run_tests.py @@ -167,6 +167,7 @@ def run_suite(self, suite: TestSuite, coverage: bool = True) -> Tuple[bool, Opti # Prepare pytest command - try to find python/pytest python_cmd = self._find_python_command() cmd = [python_cmd, '-m', 'pytest'] + print("Running command with Python:", python_cmd) coverage_file_path = None if coverage: diff --git a/sync_interpretation_parallel.py b/sync_interpretation_parallel.py new file mode 100644 index 00000000..2ad25724 --- /dev/null +++ b/sync_interpretation_parallel.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +""" +Pre-commit hook script to synchronize interpretation_parallel.py from interpretation.py. + +This script ensures that interpretation_parallel.py is always an exact copy of +interpretation.py, except for line 226 which should have parallel=True instead of +parallel=False in the @numba.njit decorator. +""" + +import sys +from pathlib import Path + + +def sync_interpretation_files(): + """ + Synchronize interpretation_parallel.py from interpretation.py. 
+ + Returns: + int: 0 on success, 1 on failure + """ + # Get the path to the interpretation files + script_dir = Path(__file__).resolve().parent + project_root = script_dir + interpretation_dir = project_root / "pyreason" / "scripts" / "interpretation" + + interpretation_file = interpretation_dir / "interpretation.py" + interpretation_parallel_file = interpretation_dir / "interpretation_parallel.py" + + # Verify source file exists + if not interpretation_file.exists(): + print(f"Error: Source file not found: {interpretation_file}", file=sys.stderr) + return 1 + + # Read the source file + try: + with open(interpretation_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + except Exception as e: + print(f"Error reading {interpretation_file}: {e}", file=sys.stderr) + return 1 + + # Verify we have at least 226 lines + if len(lines) < 226: + print(f"Error: {interpretation_file} has fewer than 226 lines", file=sys.stderr) + return 1 + + # Expected line 226 (index 225) in source file + expected_line = "\t@numba.njit(cache=True, parallel=False)\n" + + if lines[225] != expected_line: + print(f"Warning: Line 226 in {interpretation_file} is not as expected.", file=sys.stderr) + print(f" Expected: {expected_line.strip()}", file=sys.stderr) + print(f" Got: {lines[225].strip()}", file=sys.stderr) + print(f" Proceeding with replacement anyway...", file=sys.stderr) + + # Replace line 226 for parallel version + modified_lines = lines.copy() + modified_lines[225] = "\t@numba.njit(cache=True, parallel=True)\n" + + # Write to parallel file + try: + with open(interpretation_parallel_file, 'w', encoding='utf-8') as f: + f.writelines(modified_lines) + except Exception as e: + print(f"Error writing {interpretation_parallel_file}: {e}", file=sys.stderr) + return 1 + + print(f"✓ Successfully synced {interpretation_parallel_file.name} from {interpretation_file.name}") + print(f" Modified line 226: parallel=False → parallel=True") + + return 0 + + +def main(): + """Main entry point.""" + result = sync_interpretation_files() + sys.exit(result) + + +if __name__ == "__main__": + main() diff --git a/tests/functional/test_hello_world_parallel.py b/tests/functional/test_hello_world_parallel.py deleted file mode 100644 index 888dbed1..00000000 --- a/tests/functional/test_hello_world_parallel.py +++ /dev/null @@ -1,48 +0,0 @@ -# Test if the simple hello world program works. 
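(For orientation, a minimal sketch of the invariant the sync hook above enforces — illustrative only, not part of the patch; it assumes the repository root as the working directory and uses only the standard library. It should print [226] when the two copies are in sync:)

    from pathlib import Path

    base = Path("pyreason") / "scripts" / "interpretation"
    src = (base / "interpretation.py").read_text(encoding="utf-8").splitlines()
    par = (base / "interpretation_parallel.py").read_text(encoding="utf-8").splitlines()

    # Both copies must have the same length; only line 226 (the parallel= flag) may differ.
    assert len(src) == len(par), f"line counts differ: {len(src)} vs {len(par)}"
    differing = [n for n, (a, b) in enumerate(zip(src, par), start=1) if a != b]
    print(differing)  # expected: [226]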
-import pyreason as pr - - -def test_hello_world_parallel(): - # Reset PyReason - pr.reset() - pr.reset_rules() - - # Modify the paths based on where you've stored the files we made above - graph_path = './tests/functional/friends_graph.graphml' - - # Modify pyreason settings to make verbose - pr.reset_settings() - pr.settings.verbose = True # Print info to screen - pr.settings.parallel_computing = True # Use parallel computing - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) - pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=2) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_nodes(interpretation, ['popular']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - - assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person' - assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people' - assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people' - - # Mary should be popular in all three timesteps - assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps' - assert 'Mary' in dataframes[1]['component'].values and dataframes[1].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps' - assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps' - - # Justin should be popular in timesteps 1, 2 - assert 'Justin' in dataframes[1]['component'].values and dataframes[1].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps' - assert 'Justin' in dataframes[2]['component'].values and dataframes[2].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps' - - # John should be popular in timestep 3 - assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' - diff --git a/tests/unit/api_tests/test_pyreason_file_consistency.py b/tests/unit/api_tests/test_pyreason_file_consistency.py new file mode 100644 index 00000000..f7f30e7e --- /dev/null +++ b/tests/unit/api_tests/test_pyreason_file_consistency.py @@ -0,0 +1,78 @@ +""" +Tests to validate consistency between related files in the PyReason codebase. +""" + +import os +from pathlib import Path + + +def test_interpretation_parallel_consistency(): + """ + Test that interpretation_parallel.py is identical to interpretation.py + except for the parallel flag in the @numba.njit decorator on line 226. 
+ + interpretation.py line 226 should have: @numba.njit(cache=True, parallel=False) + interpretation_parallel.py line 226 should have: @numba.njit(cache=True, parallel=True) + """ + # Get the path to the interpretation files + scripts_dir = Path(__file__).parent.parent.parent.parent / "pyreason" / "scripts" / "interpretation" + interpretation_file = scripts_dir / "interpretation.py" + interpretation_parallel_file = scripts_dir / "interpretation_parallel.py" + + # Verify both files exist + assert interpretation_file.exists(), f"File not found: {interpretation_file}" + assert interpretation_parallel_file.exists(), f"File not found: {interpretation_parallel_file}" + + # Read both files + with open(interpretation_file, 'r', encoding='utf-8') as f: + interpretation_lines = f.readlines() + + with open(interpretation_parallel_file, 'r', encoding='utf-8') as f: + interpretation_parallel_lines = f.readlines() + + # Check that both files have the same number of lines + assert len(interpretation_lines) == len(interpretation_parallel_lines), \ + f"Files have different number of lines: {len(interpretation_lines)} vs {len(interpretation_parallel_lines)}" + + # Expected difference on line 226 (index 225) + expected_line_226_interpretation = "\t@numba.njit(cache=True, parallel=False)\n" + expected_line_226_interpretation_parallel = "\t@numba.njit(cache=True, parallel=True)\n" + + # Track differences + differences = [] + + # Compare line by line + for line_num, (line1, line2) in enumerate(zip(interpretation_lines, interpretation_parallel_lines), start=1): + if line1 != line2: + # Line 226 should be the only difference + if line_num == 226: + # Verify the expected difference + if line1 != expected_line_226_interpretation: + differences.append( + f"Line {line_num} in interpretation.py is not as expected.\n" + f" Expected: {expected_line_226_interpretation.strip()}\n" + f" Got: {line1.strip()}" + ) + if line2 != expected_line_226_interpretation_parallel: + differences.append( + f"Line {line_num} in interpretation_parallel.py is not as expected.\n" + f" Expected: {expected_line_226_interpretation_parallel.strip()}\n" + f" Got: {line2.strip()}" + ) + else: + # Any other difference is unexpected + differences.append( + f"Unexpected difference at line {line_num}:\n" + f" interpretation.py: {line1.strip()}\n" + f" interpretation_parallel.py: {line2.strip()}" + ) + + # Assert no unexpected differences + assert len(differences) == 0, \ + f"Found {len(differences)} unexpected difference(s) between the files:\n" + "\n\n".join(differences) + + # Verify the expected difference exists on line 226 + assert interpretation_lines[225] == expected_line_226_interpretation, \ + f"Line 226 in interpretation.py is not as expected" + assert interpretation_parallel_lines[225] == expected_line_226_interpretation_parallel, \ + f"Line 226 in interpretation_parallel.py is not as expected" From f71820469f72615abe5843d44252ce2cbc7a1469 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sat, 18 Oct 2025 14:32:29 -0400 Subject: [PATCH 11/32] Del check var names --- check_variable_names.py | 201 ---------------------------------------- 1 file changed, 201 deletions(-) delete mode 100644 check_variable_names.py diff --git a/check_variable_names.py b/check_variable_names.py deleted file mode 100644 index c228b2e0..00000000 --- a/check_variable_names.py +++ /dev/null @@ -1,201 +0,0 @@ -#!/usr/bin/env python3 -""" -Script to check for occurrences of specific variable names in Python codebase. 
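(Usage note: assuming the .venv layout the hook entries above rely on, this consistency test can also be run on its own before a push, with the same invocation style the pre-push hooks use: .venv/bin/python -m pytest tests/unit/api_tests/test_pyreason_file_consistency.py -q)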
-Uses AST parsing to avoid false positives from comments and strings. -""" - -import argparse -import ast -import sys -from pathlib import Path -from typing import List, Tuple - - -class VariableNameVisitor(ast.NodeVisitor): - """AST visitor that finds all occurrences of a specific variable name.""" - - def __init__(self, target_name: str, filepath: str): - self.target_name = target_name - self.filepath = filepath - self.occurrences: List[Tuple[int, int, str, str]] = [] - - def visit_Name(self, node: ast.Name) -> None: - """Visit Name nodes (variable references).""" - if node.id == self.target_name: - context = "reference" - if isinstance(node.ctx, ast.Store): - context = "assignment" - elif isinstance(node.ctx, ast.Del): - context = "deletion" - self.occurrences.append((node.lineno, node.col_offset, context, node.id)) - self.generic_visit(node) - - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - """Visit function definitions to check parameter names.""" - for arg in node.args.args + node.args.posonlyargs + node.args.kwonlyargs: - if arg.arg == self.target_name: - self.occurrences.append((arg.lineno, arg.col_offset, "function_parameter", arg.arg)) - if node.args.vararg and node.args.vararg.arg == self.target_name: - self.occurrences.append((node.args.vararg.lineno, node.args.vararg.col_offset, "vararg", node.args.vararg.arg)) - if node.args.kwarg and node.args.kwarg.arg == self.target_name: - self.occurrences.append((node.args.kwarg.lineno, node.args.kwarg.col_offset, "kwarg", node.args.kwarg.arg)) - self.generic_visit(node) - - def visit_arg(self, node: ast.arg) -> None: - """Visit argument nodes.""" - if node.arg == self.target_name: - self.occurrences.append((node.lineno, node.col_offset, "argument", node.arg)) - self.generic_visit(node) - - def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None: - """Visit exception handler to check exception variable names.""" - if node.name == self.target_name: - self.occurrences.append((node.lineno, node.col_offset, "exception_var", node.name)) - self.generic_visit(node) - - def visit_comprehension(self, node: ast.comprehension) -> None: - """Visit comprehension to check target variables.""" - if isinstance(node.target, ast.Name) and node.target.id == self.target_name: - self.occurrences.append((node.target.lineno, node.target.col_offset, "comprehension_var", node.target.id)) - self.generic_visit(node) - - -def find_python_files(root_dir: Path, exclude_patterns: List[str] = None) -> List[Path]: - """Find all Python files in the directory tree.""" - if exclude_patterns is None: - exclude_patterns = [] - - python_files = [] - for py_file in root_dir.rglob("*.py"): - # Check if file should be excluded - should_exclude = False - for pattern in exclude_patterns: - if pattern in str(py_file): - should_exclude = True - break - - if not should_exclude: - python_files.append(py_file) - - return python_files - - -def check_file_for_variable(filepath: Path, variable_name: str) -> List[Tuple[int, int, str, str]]: - """Check a single file for occurrences of the variable name.""" - try: - with open(filepath, 'r', encoding='utf-8') as f: - content = f.read() - - tree = ast.parse(content, filename=str(filepath)) - visitor = VariableNameVisitor(variable_name, str(filepath)) - visitor.visit(tree) - - return visitor.occurrences - except SyntaxError: - print(f"Warning: Could not parse {filepath} (syntax error)", file=sys.stderr) - return [] - except Exception as e: - print(f"Warning: Error processing {filepath}: {e}", file=sys.stderr) - return [] 
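(For context on the visitor pattern above, a stripped-down sketch — illustrative, not part of the deleted script — of why AST parsing avoids the false positives a plain text search would hit: identifiers inside strings and comments never appear as ast.Name nodes, so only genuine variable uses are reported:)

    import ast

    source = '''
    l = 1              # a real assignment to "l"
    msg = "l is here"  # "l" inside a string literal does not count
    print(l)           # a real load of "l"
    '''

    class FindName(ast.NodeVisitor):
        def __init__(self, target):
            self.target = target
            self.hits = []

        def visit_Name(self, node):
            # ast.Name nodes only exist for genuine identifier uses
            if node.id == self.target:
                self.hits.append((node.lineno, type(node.ctx).__name__))
            self.generic_visit(node)

    finder = FindName("l")
    finder.visit(ast.parse(source))
    print(finder.hits)  # [(2, 'Store'), (4, 'Load')]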
- - -def get_line_content(filepath: Path, lineno: int) -> str: - """Get the content of a specific line from a file.""" - try: - with open(filepath, 'r', encoding='utf-8') as f: - lines = f.readlines() - if 0 < lineno <= len(lines): - return lines[lineno - 1].rstrip() - except Exception: - pass - return "" - - -def main(): - parser = argparse.ArgumentParser( - description="Check for occurrences of specific variable names in Python codebase" - ) - parser.add_argument( - "variable_name", - help="The variable name to search for" - ) - parser.add_argument( - "--path", - type=Path, - default=Path("."), - help="Root directory to search (default: current directory)" - ) - parser.add_argument( - "--exclude", - nargs="*", - default=[".venv", "__pycache__", ".git", "build", "dist", ".eggs"], - help="Patterns to exclude from search" - ) - parser.add_argument( - "--include-tests", - action="store_true", - help="Include test files (excluded by default)" - ) - parser.add_argument( - "--include-docs", - action="store_true", - help="Include documentation files (excluded by default)" - ) - - args = parser.parse_args() - - # Add default exclusions - exclude_patterns = args.exclude.copy() - if not args.include_tests: - exclude_patterns.extend(["test_", "tests/"]) - if not args.include_docs: - exclude_patterns.extend(["docs/", "examples/"]) - - # Find all Python files - print(f"Searching for variable '{args.variable_name}' in {args.path}") - print(f"Excluding patterns: {exclude_patterns}\n") - - python_files = find_python_files(args.path, exclude_patterns) - print(f"Found {len(python_files)} Python files to check\n") - - # Check each file - total_occurrences = 0 - files_with_occurrences = 0 - - for filepath in sorted(python_files): - occurrences = check_file_for_variable(filepath, args.variable_name) - - if occurrences: - files_with_occurrences += 1 - total_occurrences += len(occurrences) - - # Display relative path - try: - rel_path = filepath.relative_to(args.path) - except ValueError: - rel_path = filepath - - print(f"\n{rel_path}:") - for lineno, col_offset, context, var_name in sorted(occurrences): - line_content = get_line_content(filepath, lineno) - print(f" Line {lineno}, Col {col_offset} ({context}): {line_content.strip()}") - - # Summary - print(f"\n{'=' * 60}") - print(f"Summary:") - print(f" Variable name: '{args.variable_name}'") - print(f" Files checked: {len(python_files)}") - print(f" Files with occurrences: {files_with_occurrences}") - print(f" Total occurrences: {total_occurrences}") - print(f"{'=' * 60}") - - # Exit with non-zero if occurrences found - if total_occurrences > 0: - sys.exit(1) - else: - print("\nNo occurrences found!") - sys.exit(0) - - -if __name__ == "__main__": - main() From 05e566d565586b6a9b830014ab3012096952da7d Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sat, 18 Oct 2025 14:33:03 -0400 Subject: [PATCH 12/32] Remove consistency check from push stage --- .pre-commit-config.yaml | 7 ------- 1 file changed, 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 44d9580b..b23c025a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,13 +34,6 @@ repos: pass_filenames: false stages: [pre-push] - - id: pytest-file-consistency - name: Verify interpretation files are in sync - entry: .venv/bin/python -m pytest tests/unit/api_tests/test_pyreason_file_consistency.py -v - language: system - pass_filenames: false - stages: [pre-push] - - id: pytest-functional-complete name: Run 
functional test suite entry: .venv/bin/python -m pytest tests/functional/ --tb=short From 98ac3732037be7b43b28cfac6cc739e05b02fc92 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sat, 18 Oct 2025 14:33:43 -0400 Subject: [PATCH 13/32] Remove debug scripts --- debug.py | 100 ----------------------------------- debug_thresholds.py | 126 -------------------------------------------- 2 files changed, 226 deletions(-) delete mode 100644 debug.py delete mode 100644 debug_thresholds.py diff --git a/debug.py b/debug.py deleted file mode 100644 index 65230fce..00000000 --- a/debug.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Debug script for test_annotation_function parallel mode issue.""" -import pyreason as pr -from pyreason import Threshold -import numba -import numpy as np -from pyreason.scripts.numba_wrapper.numba_types.interval_type import closed - - -@numba.njit -def probability_func(annotations, weights): - prob_A = annotations[0][0].lower - prob_B = annotations[1][0].lower - union_prob = prob_A + prob_B - union_prob = np.round(union_prob, 3) - return union_prob, 1 - - -def main(): - # Setup parallel mode - pr.reset() - pr.reset_rules() - pr.reset_settings() - pr.settings.verbose = False # Disable verbose to speed up - pr.settings.parallel_computing = True - pr.settings.allow_ground_rules = True - - print("Settings configured:") - print(f" parallel_computing: {pr.settings.parallel_computing}") - print(f" allow_ground_rules: {pr.settings.allow_ground_rules}") - - print("=" * 80) - print("PARALLEL MODE DEBUG") - print("=" * 80) - - # Add facts - pr.add_fact(pr.Fact('P(A) : [0.01, 1]')) - pr.add_fact(pr.Fact('P(B) : [0.2, 1]')) - - # Add annotation function - pr.add_annotation_function(probability_func) - - # Add rule - pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True)) - - # Run reasoning - print("\nRunning reasoning for 1 timestep...") - interpretation = pr.reason(timesteps=1) - - # Display results - print("\n" + "=" * 80) - print("RESULTS") - print("=" * 80) - - dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability']) - for t, df in enumerate(dataframes): - print(f'\nTIMESTEP - {t}') - print(df) - print() - - # Check what we actually got - print("\n" + "=" * 80) - print("QUERY RESULTS") - print("=" * 80) - - # Try to query the actual value - query_result = interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]')) - print(f"\nQuery for [0.21, 1]: {query_result}") - - # Let's also try to see what value we actually got - # Query with a wider range to see if it exists at all - wider_query = interpretation.query(pr.Query('union_probability(A, B) : [0, 1]')) - print(f"Query for [0, 1] (wider range): {wider_query}") - - # Get the actual edge data - print("\n" + "=" * 80) - print("DETAILED EDGE INSPECTION") - print("=" * 80) - - # Access the interpretation's internal data to see actual values - if hasattr(interpretation, 'get_dict'): - edge_dict = interpretation.get_dict() - print(f"\nEdge dictionary keys: {edge_dict.keys()}") - if ('A', 'B') in edge_dict: - print(f"\nEdge ('A', 'B') data:") - for key, value in edge_dict[('A', 'B')].items(): - print(f" {key}: {value}") - - # Alternative: inspect atoms directly - if hasattr(interpretation, 'atoms'): - print(f"\nAtoms available: {interpretation.atoms}") - - print("\n" + "=" * 80) - print("EXPECTED vs ACTUAL") - print("=" * 80) - print(f"Expected: union_probability(A, B) with bounds [0.21, 1]") - print(f"Actual: See 
dataframe above") - - -if __name__ == "__main__": - main() diff --git a/debug_thresholds.py b/debug_thresholds.py deleted file mode 100644 index ca79f154..00000000 --- a/debug_thresholds.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Debug script for test_custom_thresholds parallel mode issue.""" -import pyreason as pr -from pyreason import Threshold - - -def main(): - # Setup parallel mode - pr.reset() - pr.reset_rules() - pr.reset_settings() - pr.settings.verbose = False # Disable verbose to speed up - pr.settings.parallel_computing = True - pr.settings.atom_trace = True - - print("=" * 80) - print("CUSTOM THRESHOLDS PARALLEL MODE DEBUG") - print("=" * 80) - print(f"Settings:") - print(f" parallel_computing: {pr.settings.parallel_computing}") - print(f" atom_trace: {pr.settings.atom_trace}") - - # Load graph - graph_path = "./tests/functional/group_chat_graph.graphml" - print(f"\nLoading graph from: {graph_path}") - pr.load_graphml(graph_path) - - # Add custom thresholds - user_defined_thresholds = [ - Threshold("greater_equal", ("number", "total"), 1), - Threshold("greater_equal", ("percent", "total"), 100), - ] - print(f"\nCustom thresholds: {user_defined_thresholds}") - - # Add rule - pr.add_rule( - pr.Rule( - "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)", - "viewed_by_all_rule", - custom_thresholds=user_defined_thresholds, - ) - ) - print("Rule added: ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)") - - # Add facts - pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3)) - pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3)) - pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3)) - pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3)) - print("\nFacts added:") - print(" Viewed(Zach) at t=0") - print(" Viewed(Justin) at t=0") - print(" Viewed(Michelle) at t=1") - print(" Viewed(Amy) at t=2") - - # Run reasoning - print("\n" + "=" * 80) - print("Running reasoning for 3 timesteps...") - print("=" * 80) - interpretation = pr.reason(timesteps=3) - print("Reasoning completed!") - - # Display results - print("\n" + "=" * 80) - print("RESULTS - ViewedByAll at each timestep") - print("=" * 80) - - dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"]) - for t, df in enumerate(dataframes): - print(f"\nTIMESTEP {t}:") - print(f" Number of nodes with ViewedByAll: {len(df)}") - if len(df) > 0: - print(df) - else: - print(" (no nodes with ViewedByAll)") - - # Check specific assertions - print("\n" + "=" * 80) - print("ASSERTION CHECKS") - print("=" * 80) - - t0_check = len(dataframes[0]) == 0 - print(f"✓ t=0: ViewedByAll count = {len(dataframes[0])} (expected: 0) - {'PASS' if t0_check else 'FAIL'}") - - t2_check = len(dataframes[2]) == 1 - print(f"✓ t=2: ViewedByAll count = {len(dataframes[2])} (expected: 1) - {'PASS' if t2_check else 'FAIL'}") - - if len(dataframes[2]) > 0: - has_textmsg = "TextMessage" in dataframes[2]["component"].values - if has_textmsg: - bounds = dataframes[2].iloc[0].ViewedByAll - bounds_check = bounds == [1, 1] - print(f"✓ t=2: TextMessage bounds = {bounds} (expected: [1, 1]) - {'PASS' if bounds_check else 'FAIL'}") - else: - print(f"✗ t=2: TextMessage not found in ViewedByAll nodes") - print(f" Available nodes: {dataframes[2]['component'].values}") - else: - print("✗ t=2: No ViewedByAll nodes found (expected TextMessage)") - - # Additional debugging: show all Viewed facts at each timestep - print("\n" + "=" * 80) - print("DEBUG - Viewed nodes at each timestep") - print("=" * 80) - viewed_dataframes = 
pr.filter_and_sort_nodes(interpretation, ["Viewed"]) - for t, df in enumerate(viewed_dataframes): - print(f"\nTIMESTEP {t}:") - if len(df) > 0: - print(df) - else: - print(" (no Viewed nodes)") - - # Show HaveAccess edges if possible - print("\n" + "=" * 80) - print("DEBUG - HaveAccess edges") - print("=" * 80) - try: - access_dataframes = pr.filter_and_sort_edges(interpretation, ["HaveAccess"]) - print(f"Number of HaveAccess edges at t=0: {len(access_dataframes[0]) if access_dataframes else 'N/A'}") - if access_dataframes and len(access_dataframes[0]) > 0: - print("\nSample HaveAccess edges:") - print(access_dataframes[0].head(10)) - except Exception as e: - print(f"Could not retrieve HaveAccess edges: {e}") - - -if __name__ == "__main__": - main() \ No newline at end of file From 90ec93033c39265d6e376f5de30998f322b605d4 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sat, 18 Oct 2025 14:39:04 -0400 Subject: [PATCH 14/32] Remove print --- pyreason/scripts/interpretation/interpretation.py | 1 - pyreason/scripts/interpretation/interpretation_parallel.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pyreason/scripts/interpretation/interpretation.py b/pyreason/scripts/interpretation/interpretation.py index 8bdc3051..befbdd95 100755 --- a/pyreason/scripts/interpretation/interpretation.py +++ b/pyreason/scripts/interpretation/interpretation.py @@ -475,7 +475,6 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)): override = True if update_mode == 'override' else False u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) - print('updating edge', e, 'label', edge_l, 'to bound', bnd) update = u or update # Update convergence params diff --git a/pyreason/scripts/interpretation/interpretation_parallel.py b/pyreason/scripts/interpretation/interpretation_parallel.py index 45c41054..23afb6cc 100644 --- a/pyreason/scripts/interpretation/interpretation_parallel.py +++ b/pyreason/scripts/interpretation/interpretation_parallel.py @@ -475,7 +475,6 @@ def reason(interpretations_node, interpretations_edge, predicate_map_node, predi if check_consistent_edge(interpretations_edge, e, (edge_l, bnd)): override = True if update_mode == 'override' else False u, changes = _update_edge(interpretations_edge, predicate_map_edge, e, (edge_l, bnd), ipl, rule_trace_edge, fp_cnt, t, set_static, convergence_mode, atom_trace, save_graph_attributes_to_rule_trace, rules_to_be_applied_edge_trace, idx, facts_to_be_applied_edge_trace, rule_trace_edge_atoms, store_interpretation_changes, num_ga, mode='rule', override=override) - print('updating edge', e, 'label', edge_l, 'to bound', bnd) update = u or update # Update convergence params From b775c82619a56f6caeebf39ca2625a1116e1f860 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sat, 18 Oct 2025 14:51:02 -0400 Subject: [PATCH 15/32] Run unit tests in parallel on make test --- Makefile | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index 4235deee..0b7a98b2 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ # Provides
convenient shortcuts for running different test configurations .PHONY: help test test-all test-fast test-api test-jit test-no-jit test-consistency \ - test-parallel test-no-coverage coverage-report coverage-html coverage-xml \ + test-parallel test-sequential test-no-coverage coverage-report coverage-html coverage-xml \ clean clean-coverage clean-reports install-deps lint check-deps # Default target @@ -37,16 +37,16 @@ help: ## Show this help message @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " $(GREEN)%-20s$(RESET) %s\n", $$1, $$2}' $(MAKEFILE_LIST) @echo "" @echo "$(BLUE)Examples:$(RESET)" - @echo " make test # Run all test suites with coverage and open report" + @echo " make test # Run all test suites in parallel with coverage" + @echo " make test-sequential # Run all test suites sequentially (slower)" @echo " make test-fast # Run only fast test suites" @echo " make test-api # Run only API tests" - @echo " make test-parallel # Run tests in parallel where possible" @echo " make coverage-html # Generate HTML coverage report" # Main test targets -test: ## Run all test suites with coverage and open report - @echo "$(BOLD)$(BLUE)Running all test suites...$(RESET)" - $(RUN_TESTS) +test: ## Run all test suites with coverage and open report (in parallel) + @echo "$(BOLD)$(BLUE)Running all test suites in parallel...$(RESET)" + $(RUN_TESTS) --parallel @echo "$(BOLD)$(GREEN)Opening coverage report in browser...$(RESET)" @if [ -f test_reports/htmlcov/index.html ]; then \ open test_reports/htmlcov/index.html 2>/dev/null || \ @@ -58,8 +58,12 @@ test: ## Run all test suites with coverage and open report test-all: test ## Alias for 'test' target -test-only: ## Run all test suites with coverage (no browser) - @echo "$(BOLD)$(BLUE)Running all test suites...$(RESET)" +test-only: ## Run all test suites with coverage (no browser, in parallel) + @echo "$(BOLD)$(BLUE)Running all test suites in parallel...$(RESET)" + $(RUN_TESTS) --parallel + +test-sequential: ## Run all test suites sequentially (no parallelization) + @echo "$(BOLD)$(BLUE)Running all test suites sequentially...$(RESET)" $(RUN_TESTS) test-fast: ## Run only fast test suites (api_tests, dont_disable_jit) @@ -96,9 +100,9 @@ test-functional: ## Run functional/end-to-end tests $(RUN_TESTS) --suite functional -test-all-suites: ## Run all test suites including functional tests - @echo "$(BOLD)$(BLUE)Running all test suites including functional...$(RESET)" - $(RUN_TESTS) +test-all-suites: ## Run all test suites including functional tests (in parallel) + @echo "$(BOLD)$(BLUE)Running all test suites including functional in parallel...$(RESET)" + $(RUN_TESTS) --parallel # Coverage targets coverage-report: ## Show coverage report in terminal From 07fc33cc0a39e7c2a0ccddd2123fef06c245029f Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sat, 18 Oct 2025 14:55:40 -0400 Subject: [PATCH 16/32] Don't skip slow tests in pre-commit hooks --- .pre-commit-config.yaml | 6 +- .../test_anyBurl_infer_edges_rules.py | 302 ------------------ 2 files changed, 3 insertions(+), 305 deletions(-) delete mode 100644 tests/functional/test_anyBurl_infer_edges_rules.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b23c025a..7445ca1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,14 +13,14 @@ repos: - id: pytest-unit-no-jit name: Run JIT-disabled unit tests - entry: .venv/bin/python -m pytest tests/unit/disable_jit -m "not slow" --tb=short -q + entry: .venv/bin/python -m 
pytest tests/unit/disable_jit --tb=short -q language: system pass_filenames: false stages: [pre-commit] - id: pytest-unit-jit name: Run JIT-enabled unit tests - entry: .venv/bin/python -m pytest tests/unit/dont_disable_jit -m "not slow" --tb=short -q + entry: .venv/bin/python -m pytest tests/unit/dont_disable_jit --tb=short -q language: system pass_filenames: false stages: [pre-commit] @@ -36,7 +36,7 @@ repos: - id: pytest-functional-complete name: Run functional test suite - entry: .venv/bin/python -m pytest tests/functional/ --tb=short + entry: .venv/bin/python -m pytest tests/functional/ --tb=short -q language: system pass_filenames: false stages: [pre-push] diff --git a/tests/functional/test_anyBurl_infer_edges_rules.py b/tests/functional/test_anyBurl_infer_edges_rules.py deleted file mode 100644 index ab3df7a8..00000000 --- a/tests/functional/test_anyBurl_infer_edges_rules.py +++ /dev/null @@ -1,302 +0,0 @@ -import pyreason as pr -import pytest - - -@pytest.mark.slow -def test_anyBurl_rule_1(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -def test_anyBurl_rule_2(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - # Load all the files into pyreason - pr.load_graphml(graph_path) - - pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_2', infer_edges=True)) - 
- # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Riga_International_Airport', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Riga_International_Airport, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -def test_anyBurl_rule_3(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - # Load all the files into pyreason - pr.load_graphml(graph_path) - - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_3', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Yali') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Yali) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -def test_anyBurl_rule_4(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - # Load all the files into pyreason - pr.load_graphml(graph_path) - - pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_4', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # 
pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.fp -@pytest.mark.slow -def test_anyBurl_rule_1_fp(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.fp_version = True # Use the FP version of the reasoner - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.fp -@pytest.mark.slow -def test_anyBurl_rule_2_fp(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.fp_version = True # Use the FP version of the reasoner - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - # Load all the files into pyreason - pr.load_graphml(graph_path) - - pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_2', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = 
pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Riga_International_Airport', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Riga_International_Airport, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.fp -@pytest.mark.slow -def test_anyBurl_rule_3_fp(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.fp_version = True # Use the FP version of the reasoner - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - # Load all the files into pyreason - pr.load_graphml(graph_path) - - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_3', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Yali') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Yali) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.fp -@pytest.mark.slow -def test_anyBurl_rule_4_fp(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.fp_version = True # Use the FP version of the reasoner - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - # Load all the files into pyreason - pr.load_graphml(graph_path) - - pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_4', infer_edges=True)) - - # Run the program for two 
timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' \ No newline at end of file From e088709b12075f092e4abd6a4fca78487e51362f Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Mon, 20 Oct 2025 09:59:42 -0400 Subject: [PATCH 17/32] Revert import order and function redefinition --- pyreason/__init__.py | 14 ++++++++------ .../numba_wrapper/numba_types/world_type.py | 5 +++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pyreason/__init__.py b/pyreason/__init__.py index 544ddf94..05fa8237 100755 --- a/pyreason/__init__.py +++ b/pyreason/__init__.py @@ -1,22 +1,24 @@ -# Set numba environment variable # ruff: noqa: F403 F405 (Ignore Pyreason import * for public api) +# Set numba environment variable import os -import yaml -from pyreason.pyreason import * -from pkg_resources import get_distribution, DistributionNotFound -from importlib.metadata import version - package_path = os.path.abspath(os.path.dirname(__file__)) cache_path = os.path.join(package_path, 'cache') cache_status_path = os.path.join(package_path, '.cache_status.yaml') os.environ['NUMBA_CACHE_DIR'] = cache_path + +from pyreason.pyreason import * +import yaml +from importlib.metadata import version +from pkg_resources import get_distribution, DistributionNotFound + try: __version__ = get_distribution(__name__).version except DistributionNotFound: # package is not installed pass + with open(cache_status_path) as file: cache_status = yaml.safe_load(file) diff --git a/pyreason/scripts/numba_wrapper/numba_types/world_type.py b/pyreason/scripts/numba_wrapper/numba_types/world_type.py index e79748da..1518dae0 100755 --- a/pyreason/scripts/numba_wrapper/numba_types/world_type.py +++ b/pyreason/scripts/numba_wrapper/numba_types/world_type.py @@ -38,7 +38,8 @@ def typer(labels, world): return typer @type_callable(World) -def type_world_labels_only(context): +# ruff: noqa: F811 +def type_world(context): def typer(labels): if isinstance(labels, types.ListType): return world_type @@ -75,7 +76,7 @@ def impl_world(context, builder, sig, args): return w._getvalue() @lower_builtin(World, types.ListType(label.label_type)) -def impl_world_labels_only(context, builder, sig, args): +def impl_world(context, builder, sig, args): # ruff: noqa: F811 def make_world(labels_arg): d = numba.typed.Dict.empty(key_type=label.label_type, value_type=interval.interval_type) for lab in labels_arg: From db51d5bdded41f9b2a787209396d5549206caa54 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:59:14 -0400 Subject: [PATCH 18/32] Update tests --- tests/functional/test_advanced_features.py | 74 ++- tests/functional/test_annotation_function.py | 82 ---- .../test_anyBurl_infer_edges_rules.py | 181 +------- 
tests/functional/test_basic_reasoning.py | 49 -- tests/functional/test_continuation.py | 65 --- tests/functional/test_custom_thresholds.py | 128 ----- tests/functional/test_edge_inference.py | 127 ----- tests/functional/test_hello_world_parallel.py | 47 -- .../functional/test_pyreason_comprehensive.py | 436 ------------------ tests/functional/test_reason_again.py | 111 ----- .../api_tests/test_pyreason_file_loading.py | 34 +- .../unit/api_tests/test_pyreason_reasoning.py | 45 +- .../test_pyreason_state_management.py | 11 + 13 files changed, 147 insertions(+), 1243 deletions(-) delete mode 100644 tests/functional/test_annotation_function.py delete mode 100644 tests/functional/test_continuation.py delete mode 100644 tests/functional/test_custom_thresholds.py delete mode 100644 tests/functional/test_edge_inference.py delete mode 100644 tests/functional/test_hello_world_parallel.py delete mode 100644 tests/functional/test_pyreason_comprehensive.py delete mode 100644 tests/functional/test_reason_again.py diff --git a/tests/functional/test_advanced_features.py b/tests/functional/test_advanced_features.py index 5ecb5676..23e7c2f6 100644 --- a/tests/functional/test_advanced_features.py +++ b/tests/functional/test_advanced_features.py @@ -1,8 +1,10 @@ # Advanced feature tests for PyReason (annotation functions, custom thresholds, classifier integration) +import faulthandler import pyreason as pr from pyreason import Threshold import torch import torch.nn as nn +import networkx as nx import numba import numpy as np import pytest @@ -31,8 +33,10 @@ def probability_func(annotations, weights): return union_prob, 1 -def test_probability_func_consistency(): +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) +def test_probability_func_consistency(mode): """Ensure annotation function behaves the same with and without JIT.""" + setup_mode(mode) annotations = numba.typed.List() annotations.append(numba.typed.List([closed(0.01, 1.0)])) annotations.append(numba.typed.List([closed(0.2, 1.0)])) @@ -172,3 +176,71 @@ def test_classifier_integration(mode): print("\nGenerated PyReason Facts:") for fact in facts: print(fact) + + +@pytest.mark.skipif(True, reason="Reason again functionality not implemented for FP version") +@pytest.mark.parametrize("mode", ["regular"]) +def test_reason_again(mode): + """Test reasoning continuation functionality.""" + setup_mode(mode) + + # Modify the paths based on where you've stored the files we made above + graph_path = './tests/functional/friends_graph.graphml' + + # Load all the files into pyreason + pr.load_graphml(graph_path) + pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) + pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 1)) + + # Run the program for two timesteps to see the diffusion take place + faulthandler.enable() + interpretation = pr.reason(timesteps=1) + + # Now reason again + new_fact = pr.Fact('popular(Mary)', 'popular_fact2', 2, 4) + pr.add_fact(new_fact) + interpretation = pr.reason(timesteps=3, again=True, restart=False) + + # Display the changes in the interpretation for each timestep + dataframes = pr.filter_and_sort_nodes(interpretation, ['popular']) + for t, df in enumerate(dataframes): + print(f'TIMESTEP - {t}') + print(df) + print() + + assert len(dataframes[2]) == 1, 'At t=0 there should be one popular person' + assert len(dataframes[3]) == 2, 'At t=1 there should be two popular people' + assert len(dataframes[4]) == 3, 'At t=2 there should be three popular people' + + # Mary 
should be popular in all three timesteps + assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps' + assert 'Mary' in dataframes[3]['component'].values and dataframes[3].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps' + assert 'Mary' in dataframes[4]['component'].values and dataframes[4].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps' + + # Justin should be popular in timesteps 1, 2 + assert 'Justin' in dataframes[3]['component'].values and dataframes[3].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps' + assert 'Justin' in dataframes[4]['component'].values and dataframes[4].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps' + + # John should be popular in timestep 3 + assert 'John' in dataframes[4]['component'].values and dataframes[4].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' + + +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) +def test_reason_with_queries(mode): + """Test reasoning with query-based rule filtering""" + setup_mode(mode) + # Set up test scenario + graph = nx.DiGraph() + graph.add_edges_from([("A", "B"), ("B", "C")]) + pr.load_graph(graph) + + pr.add_rule(pr.Rule('popular(x) <-1 friend(x, y)', 'rule1')) + pr.add_rule(pr.Rule('friend(x, y) <-1 knows(x, y)', 'rule2')) + pr.add_fact(pr.Fact('knows(A, B)', 'fact1')) + + # Create query to filter rules + query = pr.Query('popular(A)') + pr.settings.verbose = False # Reduce output noise + + interpretation = pr.reason(timesteps=1, queries=[query]) + # Should complete and apply rule filtering logic diff --git a/tests/functional/test_annotation_function.py b/tests/functional/test_annotation_function.py deleted file mode 100644 index 635cc425..00000000 --- a/tests/functional/test_annotation_function.py +++ /dev/null @@ -1,82 +0,0 @@ -# Test if annotation functions work -import pyreason as pr -import numba -import numpy as np -import pytest -from pyreason.scripts.numba_wrapper.numba_types.interval_type import closed - - -@numba.njit -def probability_func(annotations, weights): - prob_A = annotations[0][0].lower - prob_B = annotations[1][0].lower - union_prob = prob_A + prob_B - union_prob = np.round(union_prob, 3) - return union_prob, 1 - - -def test_probability_func_consistency(): - """Ensure annotation function behaves the same with and without JIT.""" - annotations = numba.typed.List() - annotations.append(numba.typed.List([closed(0.01, 1.0)])) - annotations.append(numba.typed.List([closed(0.2, 1.0)])) - weights = numba.typed.List([1.0, 1.0]) - jit_res = probability_func(annotations, weights) - py_res = probability_func.py_func(annotations, weights) - assert jit_res == py_res - - -@pytest.mark.slow -def test_annotation_function(): - # Reset PyReason - pr.reset() - pr.reset_rules() - pr.reset_settings() - print("fp version", pr.settings.fp_version) - - pr.settings.allow_ground_rules = True - - pr.add_fact(pr.Fact('P(A) : [0.01, 1]')) - pr.add_fact(pr.Fact('P(B) : [0.2, 1]')) - pr.add_annotation_function(probability_func) - pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True)) - - interpretation = pr.reason(timesteps=1) - - dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - 
print() - - assert interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]')), 'Union probability should be 0.21' - - -@pytest.mark.fp -@pytest.mark.slow -def test_annotation_function_fp(): - # Reset PyReason - pr.reset() - pr.reset_rules() - pr.reset_settings() - - # Set FP version - pr.settings.fp_version = True - print("fp version", pr.settings.fp_version) - - pr.settings.allow_ground_rules = True - - pr.add_fact(pr.Fact('P(A) : [0.01, 1]')) - pr.add_fact(pr.Fact('P(B) : [0.2, 1]')) - pr.add_annotation_function(probability_func) - pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True)) - - interpretation = pr.reason(timesteps=1) - - dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - - assert interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]')), 'Union probability should be 0.21' diff --git a/tests/functional/test_anyBurl_infer_edges_rules.py b/tests/functional/test_anyBurl_infer_edges_rules.py index c58c6ce3..28d5774b 100644 --- a/tests/functional/test_anyBurl_infer_edges_rules.py +++ b/tests/functional/test_anyBurl_infer_edges_rules.py @@ -1,163 +1,24 @@ import pyreason as pr import pytest - -@pytest.mark.slow -def test_anyBurl_rule_1(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -def test_anyBurl_rule_2(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = 
True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - # Load all the files into pyreason - pr.load_graphml(graph_path) - - pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_2', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Riga_International_Airport', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Riga_International_Airport, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -def test_anyBurl_rule_3(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - pr.reset() - pr.reset_rules() - pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - # Load all the files into pyreason - pr.load_graphml(graph_path) - - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_3', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - # pr.save_rule_trace(interpretation) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Yali') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Yali) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -def test_anyBurl_rule_4(): - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' +def setup_mode(mode): + """Configure PyReason settings for the specified mode.""" pr.reset() pr.reset_rules() pr.reset_settings() - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - 
pr.settings.save_graph_attributes_to_trace = True -    pr.settings.parallel_computing = False -    # Load all the files into pyreason -    pr.load_graphml(graph_path) +    pr.settings.verbose = True

 -    pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_4', infer_edges=True)) +    if mode == "fp": +        pr.settings.fp_version = True +    elif mode == "parallel": +        pr.settings.parallel_computing = True

 -    # Run the program for two timesteps to see the diffusion take place -    interpretation = pr.reason(timesteps=1) -    # pr.save_rule_trace(interpretation) - -    # Display the changes in the interpretation for each timestep -    dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) -    for t, df in enumerate(dataframes): -        print(f'TIMESTEP - {t}') -        print(df) -        print() -    assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations' -    assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' -    assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.fp @pytest.mark.slow -def test_anyBurl_rule_1_fp(): +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) +def test_anyBurl_rule_1(mode): graph_path = './tests/functional/knowledge_graph_test_subset.graphml' -    pr.reset() -    pr.reset_rules() -    pr.reset_settings() +    setup_mode(mode) # Modify pyreason settings to make verbose and to save the rule trace to a file pr.settings.verbose = True -    pr.settings.fp_version = True  # Use the FP version of the reasoner pr.settings.atom_trace = True pr.settings.memory_profile = False pr.settings.canonical = True @@ -185,16 +46,12 @@ def test_anyBurl_rule_1_fp(): assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' -@pytest.mark.fp @pytest.mark.slow -def test_anyBurl_rule_2_fp(): +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) +def test_anyBurl_rule_2(mode): graph_path = './tests/functional/knowledge_graph_test_subset.graphml' -    pr.reset() -    pr.reset_rules() -    pr.reset_settings() +    setup_mode(mode) # Modify pyreason settings to make verbose and to save the rule trace to a file pr.settings.verbose = True -    pr.settings.fp_version = True  # Use the FP version of the reasoner pr.settings.atom_trace = True pr.settings.memory_profile = False pr.settings.canonical = True @@ -224,16 +81,12 @@ def test_anyBurl_rule_2_fp(): assert ('Riga_International_Airport', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Riga_International_Airport, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' -@pytest.mark.fp @pytest.mark.slow -def test_anyBurl_rule_3_fp(): +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) +def test_anyBurl_rule_3(mode): graph_path = './tests/functional/knowledge_graph_test_subset.graphml' -    pr.reset() -    pr.reset_rules() -    pr.reset_settings() +    setup_mode(mode) # Modify pyreason settings to make verbose and to save the rule trace to a file pr.settings.verbose = True -    pr.settings.fp_version = True  # Use the FP version of the reasoner pr.settings.atom_trace = True pr.settings.memory_profile = False pr.settings.canonical = True @@ -263,16 +116,12 @@ def test_anyBurl_rule_3_fp(): assert
('Vnukovo_International_Airport', 'Yali') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Yali) should have isConnectedTo bounds [1,1] for t=1 timesteps' -@pytest.mark.fp @pytest.mark.slow -def test_anyBurl_rule_4_fp(): +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) +def test_anyBurl_rule_4(mode): graph_path = './tests/functional/knowledge_graph_test_subset.graphml' -    pr.reset() -    pr.reset_rules() -    pr.reset_settings() +    setup_mode(mode) # Modify pyreason settings to make verbose and to save the rule trace to a file pr.settings.verbose = True -    pr.settings.fp_version = True  # Use the FP version of the reasoner pr.settings.atom_trace = True pr.settings.memory_profile = False pr.settings.canonical = True diff --git a/tests/functional/test_basic_reasoning.py index 72b835d5..b0b39d30 100644 --- a/tests/functional/test_basic_reasoning.py +++ b/tests/functional/test_basic_reasoning.py @@ -59,55 +59,6 @@ def test_hello_world(mode): # John should be popular in timestep 3 assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' - -@pytest.mark.slow -def test_hello_world_consistency(): -    """Test consistency between JIT and pure Python implementations.""" -    # Reset PyReason -    pr.reset() -    pr.reset_rules() -    pr.reset_settings() - -    # Modify the paths based on where you've stored the files we made above -    graph_path = './tests/functional/friends_graph.graphml' - -    # Modify pyreason settings to make verbose -    pr.settings.verbose = True     # Print info to screen -    pr.settings.atom_trace = True  # Print atom trace -    pr.settings.test_inconsistency = True - -    # Load all the files into pyreason -    pr.load_graphml(graph_path) -    pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) -    pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2)) - -    # Run the program for two timesteps to see the diffusion take place -    interpretation = pr.reason(timesteps=2) - -    # Display the changes in the interpretation for each timestep -    dataframes = pr.filter_and_sort_nodes(interpretation, ['popular']) -    for t, df in enumerate(dataframes): -        print(f'TIMESTEP - {t}') -        print(df) -        print() - -    assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person' -    assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people' -    assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people' - -    # Mary should be popular in all three timesteps -    assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps' -    assert 'Mary' in dataframes[1]['component'].values and dataframes[1].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps' -    assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps' - -    # Justin should be popular in timesteps 1, 2 -    assert 'Justin' in dataframes[1]['component'].values and dataframes[1].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps' -    assert 'Justin' in dataframes[2]['component'].values and dataframes[2].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps' - -    # John should be popular in timestep 3 -    assert 'John' in dataframes[2]['component'].values and
dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' - - @pytest.mark.parametrize("mode", ["regular", "fp"]) def test_reorder_clauses(mode): """Test clause reordering functionality.""" diff --git a/tests/functional/test_continuation.py b/tests/functional/test_continuation.py deleted file mode 100644 index 1b572ca1..00000000 --- a/tests/functional/test_continuation.py +++ /dev/null @@ -1,65 +0,0 @@ -# Tests for reasoning continuation functionality (reason again) -import pyreason as pr -import faulthandler -import pytest - - -def setup_mode(mode): - """Configure PyReason settings for the specified mode.""" - pr.reset() - pr.reset_rules() - pr.reset_settings() - pr.settings.verbose = True - pr.settings.atom_trace = True - - if mode == "fp": - pr.settings.fp_version = True - elif mode == "parallel": - pr.settings.parallel_computing = True - - -@pytest.mark.skipif(True, reason="Reason again functionality not implemented for FP version") -@pytest.mark.parametrize("mode", ["regular"]) -def test_reason_again(mode): - """Test reasoning continuation functionality.""" - setup_mode(mode) - - # Modify the paths based on where you've stored the files we made above - graph_path = './tests/functional/friends_graph.graphml' - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) - pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 1)) - - # Run the program for two timesteps to see the diffusion take place - faulthandler.enable() - interpretation = pr.reason(timesteps=1) - - # Now reason again - new_fact = pr.Fact('popular(Mary)', 'popular_fact2', 2, 4) - pr.add_fact(new_fact) - interpretation = pr.reason(timesteps=3, again=True, restart=False) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_nodes(interpretation, ['popular']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - - assert len(dataframes[2]) == 1, 'At t=0 there should be one popular person' - assert len(dataframes[3]) == 2, 'At t=1 there should be two popular people' - assert len(dataframes[4]) == 3, 'At t=2 there should be three popular people' - - # Mary should be popular in all three timesteps - assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps' - assert 'Mary' in dataframes[3]['component'].values and dataframes[3].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps' - assert 'Mary' in dataframes[4]['component'].values and dataframes[4].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps' - - # Justin should be popular in timesteps 1, 2 - assert 'Justin' in dataframes[3]['component'].values and dataframes[3].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps' - assert 'Justin' in dataframes[4]['component'].values and dataframes[4].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps' - - # John should be popular in timestep 3 - assert 'John' in dataframes[4]['component'].values and dataframes[4].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' diff --git a/tests/functional/test_custom_thresholds.py b/tests/functional/test_custom_thresholds.py deleted file mode 100644 index 4edb804d..00000000 --- 
a/tests/functional/test_custom_thresholds.py +++ /dev/null @@ -1,128 +0,0 @@ -# Test if the simple program works with thresholds defined -import pyreason as pr -from pyreason import Threshold -import pytest - - -@pytest.mark.slow -def test_custom_thresholds(): - # Reset PyReason - pr.reset() - pr.reset_rules() - - # Modify the paths based on where you've stored the files we made above - graph_path = "./tests/functional/group_chat_graph.graphml" - - # Modify pyreason settings to make verbose - pr.reset_settings() - pr.settings.verbose = True # Print info to screen - - # Load all the files into pyreason - pr.load_graphml(graph_path) - - # add custom thresholds - user_defined_thresholds = [ - Threshold("greater_equal", ("number", "total"), 1), - Threshold("greater_equal", ("percent", "total"), 100), - ] - - pr.add_rule( - pr.Rule( - "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)", - "viewed_by_all_rule", - custom_thresholds=user_defined_thresholds, - ) - ) - - pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3)) - pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3)) - pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3)) - pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3)) - - # Run the program for three timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=3) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"]) - for t, df in enumerate(dataframes): - print(f"TIMESTEP - {t}") - print(df) - print() - - assert ( - len(dataframes[0]) == 0 - ), "At t=0 the TextMessage should not have been ViewedByAll" - assert ( - len(dataframes[2]) == 1 - ), "At t=2 the TextMessage should have been ViewedByAll" - - # TextMessage should be ViewedByAll in t=2 - assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[ - 0 - ].ViewedByAll == [ - 1, - 1, - ], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps" - - -@pytest.mark.fp -@pytest.mark.slow -def test_custom_thresholds_fp(): - # Reset PyReason - pr.reset() - pr.reset_rules() - - # Modify the paths based on where you've stored the files we made above - graph_path = "./tests/functional/group_chat_graph.graphml" - - # Modify pyreason settings to make verbose - pr.reset_settings() - pr.settings.verbose = True # Print info to screen - pr.settings.fp_version = True # Use the FP version of the reasoner - - # Load all the files into pyreason - pr.load_graphml(graph_path) - - # add custom thresholds - user_defined_thresholds = [ - Threshold("greater_equal", ("number", "total"), 1), - Threshold("greater_equal", ("percent", "total"), 100), - ] - - pr.add_rule( - pr.Rule( - "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)", - "viewed_by_all_rule", - custom_thresholds=user_defined_thresholds, - ) - ) - - pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3)) - pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3)) - pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3)) - pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3)) - - # Run the program for three timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=3) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"]) - for t, df in enumerate(dataframes): - print(f"TIMESTEP - {t}") - print(df) - print() - - assert ( - len(dataframes[0]) == 0 - ), "At t=0 the TextMessage should not have been 
ViewedByAll" - assert ( - len(dataframes[2]) == 1 - ), "At t=2 the TextMessage should have been ViewedByAll" - - # TextMessage should be ViewedByAll in t=2 - assert "TextMessage" in dataframes[2]["component"].values and dataframes[2].iloc[ - 0 - ].ViewedByAll == [ - 1, - 1, - ], "TextMessage should have ViewedByAll bounds [1,1] for t=2 timesteps" diff --git a/tests/functional/test_edge_inference.py b/tests/functional/test_edge_inference.py deleted file mode 100644 index ad9eb969..00000000 --- a/tests/functional/test_edge_inference.py +++ /dev/null @@ -1,127 +0,0 @@ -# Edge inference rule tests for PyReason -import pyreason as pr -import pytest - - -def setup_mode(mode): - """Configure PyReason settings for the specified mode.""" - pr.reset() - pr.reset_rules() - pr.reset_settings() - - # Modify pyreason settings to make verbose and to save the rule trace to a file - pr.settings.verbose = True - pr.settings.atom_trace = True - pr.settings.memory_profile = False - pr.settings.canonical = True - pr.settings.inconsistency_check = False - pr.settings.static_graph_facts = False - pr.settings.output_to_file = False - pr.settings.store_interpretation_changes = True - pr.settings.save_graph_attributes_to_trace = True - pr.settings.parallel_computing = False - - if mode == "fp": - pr.settings.fp_version = True - elif mode == "parallel": - pr.settings.parallel_computing = True - - -@pytest.mark.slow -@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) -def test_anyBurl_rule_1(mode): - """Test anyBurl rule 1: isConnectedTo(A, Y) <- isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)""" - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - setup_mode(mode) - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_1', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Riga_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Riga_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) -def test_anyBurl_rule_2(mode): - """Test anyBurl rule 2: isConnectedTo(Y, A) <- isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)""" - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - setup_mode(mode) - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(Y, B), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_2', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - - # Display the changes in the interpretation for each timestep 
- dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 2 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Riga_International_Airport', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Riga_International_Airport, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) -def test_anyBurl_rule_3(mode): - """Test anyBurl rule 3: isConnectedTo(A, Y) <- isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)""" - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - setup_mode(mode) - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('isConnectedTo(A, Y) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_3', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Vnukovo_International_Airport', 'Yali') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Vnukovo_International_Airport, Yali) should have isConnectedTo bounds [1,1] for t=1 timesteps' - - -@pytest.mark.slow -@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) -def test_anyBurl_rule_4(mode): - """Test anyBurl rule 4: isConnectedTo(Y, A) <- isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)""" - graph_path = './tests/functional/knowledge_graph_test_subset.graphml' - setup_mode(mode) - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('isConnectedTo(Y, A) <-1 isConnectedTo(B, Y), Amsterdam_Airport_Schiphol(B), Vnukovo_International_Airport(A)', 'connected_rule_4', infer_edges=True)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=1) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_edges(interpretation, ['isConnectedTo']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - assert len(dataframes) == 2, 'Pyreason should run exactly 1 fixpoint operations' - assert len(dataframes[1]) == 1, 'At t=1 there should be only 1 new isConnectedTo atom' - assert ('Yali', 'Vnukovo_International_Airport') in dataframes[1]['component'].values.tolist() and dataframes[1]['isConnectedTo'].iloc[0] == [1, 1], '(Yali, Vnukovo_International_Airport) should have isConnectedTo bounds [1,1] for t=1 timesteps' diff --git a/tests/functional/test_hello_world_parallel.py b/tests/functional/test_hello_world_parallel.py deleted file mode 100644 index f9c04325..00000000 --- a/tests/functional/test_hello_world_parallel.py +++ /dev/null 
@@ -1,47 +0,0 @@ -# Test if the simple hello world program works. -import pyreason as pr - - -def test_hello_world_parallel(): - # Reset PyReason - pr.reset() - pr.reset_rules() - - # Modify the paths based on where you've stored the files we made above - graph_path = './tests/functional/friends_graph.graphml' - - # Modify pyreason settings to make verbose - pr.reset_settings() - pr.settings.verbose = True # Print info to screen - pr.settings.parallel_computing = True # Use parallel computing - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) - pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 2)) - - # Run the program for two timesteps to see the diffusion take place - interpretation = pr.reason(timesteps=2) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_nodes(interpretation, ['popular']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - - assert len(dataframes[0]) == 1, 'At t=0 there should be one popular person' - assert len(dataframes[1]) == 2, 'At t=1 there should be two popular people' - assert len(dataframes[2]) == 3, 'At t=2 there should be three popular people' - - # Mary should be popular in all three timesteps - assert 'Mary' in dataframes[0]['component'].values and dataframes[0].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps' - assert 'Mary' in dataframes[1]['component'].values and dataframes[1].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps' - assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps' - - # Justin should be popular in timesteps 1, 2 - assert 'Justin' in dataframes[1]['component'].values and dataframes[1].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps' - assert 'Justin' in dataframes[2]['component'].values and dataframes[2].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps' - - # John should be popular in timestep 3 - assert 'John' in dataframes[2]['component'].values and dataframes[2].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' diff --git a/tests/functional/test_pyreason_comprehensive.py b/tests/functional/test_pyreason_comprehensive.py deleted file mode 100644 index ea4b06ac..00000000 --- a/tests/functional/test_pyreason_comprehensive.py +++ /dev/null @@ -1,436 +0,0 @@ -""" -Comprehensive functional tests for pyreason.py to cover missing branches. -These tests focus on error conditions, settings validation, and edge cases. 
-""" - -import pytest -import tempfile -import os -import networkx as nx -from unittest import mock - -import pyreason as pr - - -class TestSettingsValidation: - """Test settings validation - covers all TypeError branches""" - - def setup_method(self): - """Reset settings before each test""" - pr.reset() - pr.reset_settings() - - def test_verbose_type_error(self): - """Test verbose setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.verbose = "not_bool" - - def test_output_to_file_type_error(self): - """Test output_to_file setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.output_to_file = 123 - - def test_output_file_name_type_error(self): - """Test output_file_name setter with invalid type""" - with pytest.raises(TypeError, match='file_name has to be a string'): - pr.settings.output_file_name = 123 - - def test_graph_attribute_parsing_type_error(self): - """Test graph_attribute_parsing setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.graph_attribute_parsing = "not_bool" - - def test_abort_on_inconsistency_type_error(self): - """Test abort_on_inconsistency setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.abort_on_inconsistency = 1.5 - - def test_memory_profile_type_error(self): - """Test memory_profile setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.memory_profile = [] - - def test_reverse_digraph_type_error(self): - """Test reverse_digraph setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.reverse_digraph = {} - - def test_atom_trace_type_error(self): - """Test atom_trace setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.atom_trace = "false" - - def test_save_graph_attributes_to_trace_type_error(self): - """Test save_graph_attributes_to_trace setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.save_graph_attributes_to_trace = 0 - - def test_canonical_type_error(self): - """Test canonical setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.canonical = "canonical" - - def test_persistent_type_error(self): - """Test persistent setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.persistent = None - - def test_inconsistency_check_type_error(self): - """Test inconsistency_check setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.inconsistency_check = 42 - - def test_static_graph_facts_type_error(self): - """Test static_graph_facts setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.static_graph_facts = [True] - - def test_store_interpretation_changes_type_error(self): - """Test store_interpretation_changes setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.store_interpretation_changes = 1 - - def test_parallel_computing_type_error(self): - """Test parallel_computing setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.parallel_computing = "True" - - def test_update_mode_type_error(self): - """Test update_mode setter with 
invalid type""" - with pytest.raises(TypeError, match='value has to be a str'): - pr.settings.update_mode = True - - def test_allow_ground_rules_type_error(self): - """Test allow_ground_rules setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.allow_ground_rules = 3.14 - - def test_fp_version_type_error(self): - """Test fp_version setter with invalid type""" - with pytest.raises(TypeError, match='value has to be a bool'): - pr.settings.fp_version = "optimized" - - -class TestFileOperations: - """Test file loading operations and error conditions""" - - def setup_method(self): - """Reset state before each test""" - pr.reset() - pr.reset_settings() - - def test_load_graphml_nonexistent_file(self): - """Test loading non-existent GraphML file""" - with pytest.raises((FileNotFoundError, OSError)): - pr.load_graphml("non_existent_file.graphml") - - def test_load_ipl_nonexistent_file(self): - """Test loading non-existent IPL file""" - with pytest.raises((FileNotFoundError, OSError)): - pr.load_inconsistent_predicate_list("non_existent_ipl.yaml") - - def test_add_rules_from_nonexistent_file(self): - """Test adding rules from non-existent file""" - with pytest.raises((FileNotFoundError, OSError)): - pr.add_rules_from_file("non_existent_rules.txt") - - def test_add_rules_from_file_with_comments_and_empty_lines(self): - """Test rule file parsing handles comments and empty lines""" - with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: - f.write("# This is a comment\n") - f.write("\n") # Empty line - f.write(" \n") # Whitespace-only line - f.write("test_rule(x) <-1 other_rule(x)\n") - f.write("# Another comment\n") - f.write("another_rule(y) <-1 test_rule(y)\n") - temp_path = f.name - - try: - pr.add_rules_from_file(temp_path) - rules = pr.get_rules() - assert len(rules) == 2 # Should only include the 2 actual rules - finally: - os.unlink(temp_path) - - def test_add_inconsistent_predicates(self): - """Test adding inconsistent predicate pairs""" - pr.add_inconsistent_predicate("pred1", "pred2") - pr.add_inconsistent_predicate("pred3", "pred4") - # Should not raise exceptions - - -class TestReasoningErrorConditions: - """Test reasoning function error conditions and edge cases""" - - def setup_method(self): - """Reset state before each test""" - pr.reset() - pr.reset_settings() - - def test_reason_without_rules_exception(self): - """Test reasoning without rules raises exception""" - # Load a graph but no rules - graph = nx.DiGraph() - graph.add_edge("A", "B") - pr.load_graph(graph) - - with pytest.raises(Exception, match='There are no rules'): - pr.reason() - - def test_reason_without_graph_uses_empty_graph(self): - """Test reasoning without graph uses empty graph and warns""" - pr.add_rule(pr.Rule('test(x) <- test2(x)', 'test_rule')) - - with pytest.warns(UserWarning, match='Graph not loaded'): - interpretation = pr.reason() - # Should complete without crashing - - def test_reason_auto_names_rules(self): - """Test that rules get auto-named when no name provided""" - pr.add_rule(pr.Rule('test1(x) <- test2(x)')) # No name - pr.add_rule(pr.Rule('test3(x) <- test4(x)')) # No name - - rules = pr.get_rules() - assert rules[0].get_rule_name() == 'rule_0' - assert rules[1].get_rule_name() == 'rule_1' - - def test_reason_auto_names_facts(self): - """Test that facts get auto-named when no name provided""" - fact1 = pr.Fact('test(node1)') # No name - fact2 = pr.Fact('test(node1, node2)') # No name - - pr.add_fact(fact1) - 
pr.add_fact(fact2) - - # Names should be auto-generated - assert fact1.name.startswith('fact_') - assert fact2.name.startswith('fact_') - - -class TestGraphAttributeParsing: - """Test graph attribute parsing branches""" - - def setup_method(self): - """Reset state before each test""" - pr.reset() - pr.reset_settings() - - def test_load_graph_with_attribute_parsing_enabled(self): - """Test loading graph with attribute parsing enabled""" - graph = nx.DiGraph() - graph.add_node("A", label="person", age=25) - graph.add_node("B", label="person", age=30) - graph.add_edge("A", "B", relation="knows", weight=0.8) - - pr.settings.graph_attribute_parsing = True - pr.load_graph(graph) - # Should complete without errors - - def test_load_graph_with_attribute_parsing_disabled(self): - """Test loading graph with attribute parsing disabled (lines 540-543, 562-565)""" - graph = nx.DiGraph() - graph.add_node("A", label="person") - graph.add_edge("A", "B", relation="knows") - - pr.settings.graph_attribute_parsing = False - pr.load_graph(graph) - # Should complete without errors and use empty collections - - -class TestOutputFunctionAssertions: - """Test output functions when store_interpretation_changes=False""" - - def setup_method(self): - """Reset state before each test""" - pr.reset() - pr.reset_settings() - - def test_save_rule_trace_assertion(self): - """Test save_rule_trace assertion when store_interpretation_changes=False""" - pr.settings.store_interpretation_changes = False - - with pytest.raises(AssertionError, match='store interpretation changes setting is off'): - pr.save_rule_trace(mock.MagicMock(), './test/') - - def test_get_rule_trace_assertion(self): - """Test get_rule_trace assertion when store_interpretation_changes=False""" - pr.settings.store_interpretation_changes = False - - with pytest.raises(AssertionError, match='store interpretation changes setting is off'): - pr.get_rule_trace(mock.MagicMock()) - - def test_filter_and_sort_nodes_assertion(self): - """Test filter_and_sort_nodes assertion when store_interpretation_changes=False""" - pr.settings.store_interpretation_changes = False - - with pytest.raises(AssertionError, match='store interpretation changes setting is off'): - pr.filter_and_sort_nodes(mock.MagicMock(), ['test']) - - def test_filter_and_sort_edges_assertion(self): - """Test filter_and_sort_edges assertion when store_interpretation_changes=False""" - pr.settings.store_interpretation_changes = False - - with pytest.raises(AssertionError, match='store interpretation changes setting is off'): - pr.filter_and_sort_edges(mock.MagicMock(), ['test']) - - -class TestReasoningModes: - """Test different reasoning modes and settings""" - - def setup_method(self): - """Reset state before each test""" - pr.reset() - pr.reset_settings() - - def test_reason_with_memory_profiling(self): - """Test reasoning with memory profiling enabled""" - # Set up minimal working example - graph = nx.DiGraph() - graph.add_edge("A", "B") - pr.load_graph(graph) - pr.add_rule(pr.Rule('test(x) <- test(y)', 'test_rule')) - pr.add_fact(pr.Fact('test(A)', 'test_fact')) - - pr.settings.memory_profile = True - pr.settings.verbose = False # Reduce output noise - - # Should complete without errors - interpretation = pr.reason(timesteps=1) - - def test_reason_with_output_to_file(self): - """Test reasoning with output_to_file enabled""" - # Set up minimal working example - graph = nx.DiGraph() - graph.add_edge("A", "B") - pr.load_graph(graph) - pr.add_rule(pr.Rule('test(x) <- test(y)', 'test_rule')) - 
pr.add_fact(pr.Fact('test(A)', 'test_fact')) - - pr.settings.output_to_file = True - pr.settings.output_file_name = "test_output" - - interpretation = pr.reason(timesteps=1) - - # Check if output file was created (and clean up) - import glob - output_files = glob.glob("test_output_*.txt") - for f in output_files: - os.unlink(f) - - def test_reason_again_functionality(self): - """Test reason again functionality (lines 688-693, 788-799)""" - # Set up initial reasoning - graph = nx.DiGraph() - graph.add_edge("A", "B") - pr.load_graph(graph) - pr.add_rule(pr.Rule('test(x) <- test(y)', 'test_rule')) - pr.add_fact(pr.Fact('test(A)', 'test_fact', 0, 1)) - - # First reasoning - interpretation1 = pr.reason(timesteps=1) - - # Add new fact and reason again - pr.add_fact(pr.Fact('test(B)', 'test_fact2', 2, 3)) - interpretation2 = pr.reason(timesteps=2, again=True, restart=False) - - # Should complete without errors - - -class TestAnnotationFunctions: - """Test annotation function management""" - - def test_add_annotation_function(self): - """Test adding annotation function""" - def test_func(annotations, weights): - return sum(w * a[0].lower for w, a in zip(weights, annotations)), 1.0 - - pr.add_annotation_function(test_func) - # Should complete without errors - - -class TestTorchIntegrationHandling: - """Test torch integration state consistency""" - - def test_torch_integration_consistency(self): - """Test that torch integration variables are consistent""" - # Just verify the current state is consistent - if hasattr(pr, 'LogicIntegratedClassifier'): - if pr.LogicIntegratedClassifier is None: - # If LogicIntegratedClassifier is None, ModelInterfaceOptions should also be None - assert pr.ModelInterfaceOptions is None - else: - # If LogicIntegratedClassifier exists, ModelInterfaceOptions should also exist - assert pr.ModelInterfaceOptions is not None - - -class TestQueryFiltering: - """Test query-based rule filtering""" - - def setup_method(self): - """Reset state before each test""" - pr.reset() - pr.reset_settings() - - def test_reason_with_queries(self): - """Test reasoning with query-based rule filtering""" - # Set up test scenario - graph = nx.DiGraph() - graph.add_edges_from([("A", "B"), ("B", "C")]) - pr.load_graph(graph) - - pr.add_rule(pr.Rule('popular(x) <-1 friend(x, y)', 'rule1')) - pr.add_rule(pr.Rule('friend(x, y) <-1 knows(x, y)', 'rule2')) - pr.add_fact(pr.Fact('knows(A, B)', 'fact1')) - - # Create query to filter rules - query = pr.Query('popular(A)') - pr.settings.verbose = False # Reduce output noise - - interpretation = pr.reason(timesteps=1, queries=[query]) - # Should complete and apply rule filtering logic - - -@pytest.mark.fp -class TestFixedPointVersions: - """Test key functionality with FP version enabled""" - - def setup_method(self): - """Reset settings before each test""" - pr.reset() - pr.reset_settings() - pr.settings.fp_version = True # Enable FP version for all tests in this class - - @pytest.mark.fp - @pytest.mark.skip(reason="Pure edge-to-edge transitive reasoning not supported in PyReason architecture. Both regular and FP versions fail this test. 
PyReason requires node clause anchors for multi-clause rule grounding.") - def test_basic_reasoning_fp(self): - """Test basic reasoning functionality with FP version""" - # Create simple graph - graph = nx.DiGraph() - graph.add_edge("A", "B") - graph.add_edge("B", "C") - pr.load_graph(graph) - - # Add rule and fact - pr.add_rule(pr.Rule('connected(x, z) <-1 connected(x, y), connected(y, z)', 'transitive_rule')) - pr.add_fact(pr.Fact('connected(A, B)', 'fact1')) - pr.add_fact(pr.Fact('connected(B, C)', 'fact2')) - - # Reason - interpretation = pr.reason(timesteps=2) - - # Verify transitivity worked - assert interpretation.query(pr.Query('connected(A, C)')), 'Should infer connected(A, C) via transitivity' - - @pytest.mark.fp - def test_settings_validation_fp(self): - """Test that settings validation works with FP version""" - # Test that fp_version setting is properly set - assert pr.settings.fp_version == True, 'FP version should be enabled' - - # Test that other settings still work - pr.settings.verbose = True - assert pr.settings.verbose == True, 'Verbose setting should work' - - -if __name__ == '__main__': - pytest.main([__file__]) diff --git a/tests/functional/test_reason_again.py b/tests/functional/test_reason_again.py deleted file mode 100644 index ec853438..00000000 --- a/tests/functional/test_reason_again.py +++ /dev/null @@ -1,111 +0,0 @@ -# Test if the simple hello world program works -import pyreason as pr -import faulthandler -import pytest - - -def test_reason_again(): - # Reset PyReason - pr.reset() - pr.reset_rules() - pr.reset_settings() - - # Modify the paths based on where you've stored the files we made above - graph_path = './tests/functional/friends_graph.graphml' - - # Modify pyreason settings to make verbose - pr.settings.verbose = True # Print info to screen - pr.settings.atom_trace = True # Save atom trace - # pr.settings.optimize_rules = False # Disable rule optimization for debugging - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) - pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 1)) - - # Run the program for two timesteps to see the diffusion take place - faulthandler.enable() - interpretation = pr.reason(timesteps=1) - - # Now reason again - new_fact = pr.Fact('popular(Mary)', 'popular_fact2', 2, 4) - pr.add_fact(new_fact) - interpretation = pr.reason(timesteps=3, again=True, restart=False) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_nodes(interpretation, ['popular']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - - assert len(dataframes[2]) == 1, 'At t=0 there should be one popular person' - assert len(dataframes[3]) == 2, 'At t=1 there should be two popular people' - assert len(dataframes[4]) == 3, 'At t=2 there should be three popular people' - - # Mary should be popular in all three timesteps - assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps' - assert 'Mary' in dataframes[3]['component'].values and dataframes[3].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps' - assert 'Mary' in dataframes[4]['component'].values and dataframes[4].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps' - - # Justin should be popular in timesteps 1, 2 - assert 'Justin' in 
dataframes[3]['component'].values and dataframes[3].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps' - assert 'Justin' in dataframes[4]['component'].values and dataframes[4].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps' - - # John should be popular in timestep 3 - assert 'John' in dataframes[4]['component'].values and dataframes[4].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' - - -@pytest.mark.fp -@pytest.mark.skip(reason="Reason again functionality not implemented for FP version") -def test_reason_again_fp(): - # Reset PyReason - pr.reset() - pr.reset_rules() - pr.reset_settings() - - # Modify the paths based on where you've stored the files we made above - graph_path = './tests/functional/friends_graph.graphml' - - # Modify pyreason settings to make verbose - pr.settings.verbose = True # Print info to screen - pr.settings.fp_version = True # Use the FP version of the reasoner - pr.settings.atom_trace = True # Save atom trace - # pr.settings.optimize_rules = False # Disable rule optimization for debugging - - # Load all the files into pyreason - pr.load_graphml(graph_path) - pr.add_rule(pr.Rule('popular(x) <-1 popular(y), Friends(x,y), owns(y,z), owns(x,z)', 'popular_rule')) - pr.add_fact(pr.Fact('popular(Mary)', 'popular_fact', 0, 1)) - - # Run the program for two timesteps to see the diffusion take place - faulthandler.enable() - interpretation = pr.reason(timesteps=1) - - # Now reason again - new_fact = pr.Fact('popular(Mary)', 'popular_fact2', 2, 4) - pr.add_fact(new_fact) - interpretation = pr.reason(timesteps=3, again=True, restart=False) - - # Display the changes in the interpretation for each timestep - dataframes = pr.filter_and_sort_nodes(interpretation, ['popular']) - for t, df in enumerate(dataframes): - print(f'TIMESTEP - {t}') - print(df) - print() - - assert len(dataframes[2]) == 1, 'At t=0 there should be one popular person' - assert len(dataframes[3]) == 2, 'At t=1 there should be two popular people' - assert len(dataframes[4]) == 3, 'At t=2 there should be three popular people' - - # Mary should be popular in all three timesteps - assert 'Mary' in dataframes[2]['component'].values and dataframes[2].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=0 timesteps' - assert 'Mary' in dataframes[3]['component'].values and dataframes[3].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=1 timesteps' - assert 'Mary' in dataframes[4]['component'].values and dataframes[4].iloc[0].popular == [1, 1], 'Mary should have popular bounds [1,1] for t=2 timesteps' - - # Justin should be popular in timesteps 1, 2 - assert 'Justin' in dataframes[3]['component'].values and dataframes[3].iloc[1].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=1 timesteps' - assert 'Justin' in dataframes[4]['component'].values and dataframes[4].iloc[2].popular == [1, 1], 'Justin should have popular bounds [1,1] for t=2 timesteps' - - # John should be popular in timestep 3 - assert 'John' in dataframes[4]['component'].values and dataframes[4].iloc[1].popular == [1, 1], 'John should have popular bounds [1,1] for t=2 timesteps' diff --git a/tests/unit/api_tests/test_pyreason_file_loading.py b/tests/unit/api_tests/test_pyreason_file_loading.py index 2767d915..09ef3f63 100644 --- a/tests/unit/api_tests/test_pyreason_file_loading.py +++ b/tests/unit/api_tests/test_pyreason_file_loading.py @@ -703,22 +703,23 @@ def 
test_add_rules_from_file_simple_rules(self): finally: os.unlink(tmp_path) - def test_add_rules_from_file_with_comments(self): - """Test loading rules from file with comments.""" - rules_content = """# This is a comment -friend(A, B) <- knows(A, B) -# Another comment -enemy(A, B) <- ~friend(A, B) -# Final comment""" - - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp: - tmp.write(rules_content) - tmp_path = tmp.name + def test_add_rules_from_file_with_comments_and_empty_lines(self): + """Test rule file parsing handles comments and empty lines""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("# This is a comment\n") + f.write("\n") # Empty line + f.write(" \n") # Whitespace-only line + f.write("test_rule(x) <-1 other_rule(x)\n") + f.write("# Another comment\n") + f.write("another_rule(y) <-1 test_rule(y)\n") + temp_path = f.name try: - pr.add_rules_from_file(tmp_path) + pr.add_rules_from_file(temp_path) + rules = pr.get_rules() + assert len(rules) == 2 # Should only include the 2 actual rules finally: - os.unlink(tmp_path) + os.unlink(temp_path) def test_add_rules_from_file_with_empty_lines(self): """Test loading rules from file with empty lines.""" @@ -842,6 +843,13 @@ def test_add_rules_from_file_after_existing_rules(self): os.unlink(tmp_path) + def test_add_inconsistent_predicates(self): + """Test adding inconsistent predicate pairs""" + pr.add_inconsistent_predicate("pred1", "pred2") + pr.add_inconsistent_predicate("pred3", "pred4") + # Should not raise exceptions + + class TestRuleTrace: """Test save_rule_trace() and get_rule_trace() functions.""" diff --git a/tests/unit/api_tests/test_pyreason_reasoning.py b/tests/unit/api_tests/test_pyreason_reasoning.py index ba371aa4..20a2ec49 100644 --- a/tests/unit/api_tests/test_pyreason_reasoning.py +++ b/tests/unit/api_tests/test_pyreason_reasoning.py @@ -22,25 +22,13 @@ def setup_method(self): pr.reset() pr.reset_settings() - def test_reason_with_no_graph_loads_empty_graph(self): - """Test reasoning without loading a graph (should load empty graph with warning).""" - # Don't load any graph - pr.add_rule(Rule("test(A) <- test(A)", "test_rule", False)) + def test_reason_without_graph_uses_empty_graph(self): + """Test reasoning without graph uses empty graph and warns""" + pr.add_rule(pr.Rule('test(x) <- test2(x)', 'test_rule')) - # Capture stdout to check for warning - captured_output = StringIO() - original_stdout = sys.stdout - pr.settings.verbose = True - - try: - sys.stdout = captured_output - interpretation = pr.reason(timesteps=1) - output = captured_output.getvalue() - - # Should contain warning about no graph - assert "Graph not loaded" in output or interpretation is not None - finally: - sys.stdout = original_stdout + with pytest.warns(UserWarning, match='Graph not loaded'): + interpretation = pr.reason() + # Should complete without crashing def test_reason_with_no_rules_raises_exception(self): """Test reasoning without any rules raises exception.""" @@ -51,6 +39,27 @@ def test_reason_with_no_rules_raises_exception(self): with pytest.raises(Exception, match="There are no rules"): pr.reason(timesteps=1) + def test_reason_auto_names_rules(self): + """Test that rules get auto-named when no name provided""" + pr.add_rule(pr.Rule('test1(x) <- test2(x)')) # No name + pr.add_rule(pr.Rule('test3(x) <- test4(x)')) # No name + + rules = pr.get_rules() + assert rules[0].get_rule_name() == 'rule_0' + assert rules[1].get_rule_name() == 'rule_1' + + def 
test_reason_auto_names_facts(self): + """Test that facts get auto-named when no name provided""" + fact1 = pr.Fact('test(node1)') # No name + fact2 = pr.Fact('test(node1, node2)') # No name + + pr.add_fact(fact1) + pr.add_fact(fact2) + + # Names should be auto-generated + assert fact1.name.startswith('fact_') + assert fact2.name.startswith('fact_') + def test_reason_with_output_to_file(self): """Test reasoning with output_to_file setting.""" graph = nx.DiGraph() diff --git a/tests/unit/api_tests/test_pyreason_state_management.py b/tests/unit/api_tests/test_pyreason_state_management.py index c500be72..91c709f8 100644 --- a/tests/unit/api_tests/test_pyreason_state_management.py +++ b/tests/unit/api_tests/test_pyreason_state_management.py @@ -173,6 +173,17 @@ def setup_method(self): pr.reset() pr.reset_settings() + def test_torch_integration_consistency(self): + """Test that torch integration variables are consistent""" + # Just verify the current state is consistent + if hasattr(pr, 'LogicIntegratedClassifier'): + if pr.LogicIntegratedClassifier is None: + # If LogicIntegratedClassifier is None, ModelInterfaceOptions should also be None + assert pr.ModelInterfaceOptions is None + else: + # If LogicIntegratedClassifier exists, ModelInterfaceOptions should also exist + assert pr.ModelInterfaceOptions is not None + def test_state_isolation_between_operations(self): """Test that state is properly isolated between operations.""" # This test verifies that subsequent operations don't interfere From 23d02098aa2ef99fd1269be349efe2198628f0c6 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 15:09:30 -0400 Subject: [PATCH 19/32] Parameterize functional test --- tests/functional/test_advanced_features.py | 1 - tests/functional/test_anyBurl_infer_edges_rules.py | 6 +++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/functional/test_advanced_features.py b/tests/functional/test_advanced_features.py index 23e7c2f6..e44f9fa2 100644 --- a/tests/functional/test_advanced_features.py +++ b/tests/functional/test_advanced_features.py @@ -1,5 +1,4 @@ # Advanced feature tests for PyReason (annotation functions, custom thresholds, classifier integration) -import faulthandler import pyreason as pr from pyreason import Threshold import torch diff --git a/tests/functional/test_anyBurl_infer_edges_rules.py b/tests/functional/test_anyBurl_infer_edges_rules.py index 28d5774b..fe9e9501 100644 --- a/tests/functional/test_anyBurl_infer_edges_rules.py +++ b/tests/functional/test_anyBurl_infer_edges_rules.py @@ -6,7 +6,7 @@ def setup_mode(mode): pr.reset() pr.reset_rules() pr.reset_settings() - pr.settings.verbose = Tru + pr.settings.verbose = True if mode == "fp": pr.settings.fp_version = True @@ -14,6 +14,7 @@ def setup_mode(mode): pr.settings.parallel_computing = True @pytest.mark.slow +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) def test_anyBurl_rule_1(mode): graph_path = './tests/functional/knowledge_graph_test_subset.graphml' setup_mode(mode) @@ -47,6 +48,7 @@ def test_anyBurl_rule_1(mode): @pytest.mark.slow +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) def test_anyBurl_rule_2(mode): graph_path = './tests/functional/knowledge_graph_test_subset.graphml' setup_mode(mode) @@ -82,6 +84,7 @@ def test_anyBurl_rule_2(mode): @pytest.mark.slow +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) def test_anyBurl_rule_3(mode): graph_path = './tests/functional/knowledge_graph_test_subset.graphml' 
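+    # Usage sketch (assuming pytest's standard parametrized node-id format):
+    #   pytest "tests/functional/test_anyBurl_infer_edges_rules.py::test_anyBurl_rule_3[fp]"
+    # runs only the FP variant; [regular] and [parallel] select the other modes.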
setup_mode(mode) @@ -117,6 +120,7 @@ def test_anyBurl_rule_3(mode): @pytest.mark.slow +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) def test_anyBurl_rule_4(mode): graph_path = './tests/functional/knowledge_graph_test_subset.graphml' setup_mode(mode) From 5c0e4359dbe4e344b0a16fffd472bd463ee40f7a Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 15:13:32 -0400 Subject: [PATCH 20/32] Remove debug scripts --- debug.py | 100 ----------------------------------- debug_thresholds.py | 126 -------------------------------------------- 2 files changed, 226 deletions(-) delete mode 100644 debug.py delete mode 100644 debug_thresholds.py diff --git a/debug.py b/debug.py deleted file mode 100644 index 65230fce..00000000 --- a/debug.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Debug script for test_annotation_function parallel mode issue.""" -import pyreason as pr -from pyreason import Threshold -import numba -import numpy as np -from pyreason.scripts.numba_wrapper.numba_types.interval_type import closed - - -@numba.njit -def probability_func(annotations, weights): - prob_A = annotations[0][0].lower - prob_B = annotations[1][0].lower - union_prob = prob_A + prob_B - union_prob = np.round(union_prob, 3) - return union_prob, 1 - - -def main(): - # Setup parallel mode - pr.reset() - pr.reset_rules() - pr.reset_settings() - pr.settings.verbose = False # Disable verbose to speed up - pr.settings.parallel_computing = True - pr.settings.allow_ground_rules = True - - print("Settings configured:") - print(f" parallel_computing: {pr.settings.parallel_computing}") - print(f" allow_ground_rules: {pr.settings.allow_ground_rules}") - - print("=" * 80) - print("PARALLEL MODE DEBUG") - print("=" * 80) - - # Add facts - pr.add_fact(pr.Fact('P(A) : [0.01, 1]')) - pr.add_fact(pr.Fact('P(B) : [0.2, 1]')) - - # Add annotation function - pr.add_annotation_function(probability_func) - - # Add rule - pr.add_rule(pr.Rule('union_probability(A, B):probability_func <- P(A):[0, 1], P(B):[0, 1]', infer_edges=True)) - - # Run reasoning - print("\nRunning reasoning for 1 timestep...") - interpretation = pr.reason(timesteps=1) - - # Display results - print("\n" + "=" * 80) - print("RESULTS") - print("=" * 80) - - dataframes = pr.filter_and_sort_edges(interpretation, ['union_probability']) - for t, df in enumerate(dataframes): - print(f'\nTIMESTEP - {t}') - print(df) - print() - - # Check what we actually got - print("\n" + "=" * 80) - print("QUERY RESULTS") - print("=" * 80) - - # Try to query the actual value - query_result = interpretation.query(pr.Query('union_probability(A, B) : [0.21, 1]')) - print(f"\nQuery for [0.21, 1]: {query_result}") - - # Let's also try to see what value we actually got - # Query with a wider range to see if it exists at all - wider_query = interpretation.query(pr.Query('union_probability(A, B) : [0, 1]')) - print(f"Query for [0, 1] (wider range): {wider_query}") - - # Get the actual edge data - print("\n" + "=" * 80) - print("DETAILED EDGE INSPECTION") - print("=" * 80) - - # Access the interpretation's internal data to see actual values - if hasattr(interpretation, 'get_dict'): - edge_dict = interpretation.get_dict() - print(f"\nEdge dictionary keys: {edge_dict.keys()}") - if ('A', 'B') in edge_dict: - print(f"\nEdge ('A', 'B') data:") - for key, value in edge_dict[('A', 'B')].items(): - print(f" {key}: {value}") - - # Alternative: inspect atoms directly - if hasattr(interpretation, 'atoms'): - print(f"\nAtoms available: 
{interpretation.atoms}") - - print("\n" + "=" * 80) - print("EXPECTED vs ACTUAL") - print("=" * 80) - print(f"Expected: union_probability(A, B) with bounds [0.21, 1]") - print(f"Actual: See dataframe above") - - -if __name__ == "__main__": - main() diff --git a/debug_thresholds.py b/debug_thresholds.py deleted file mode 100644 index ca79f154..00000000 --- a/debug_thresholds.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Debug script for test_custom_thresholds parallel mode issue.""" -import pyreason as pr -from pyreason import Threshold - - -def main(): - # Setup parallel mode - pr.reset() - pr.reset_rules() - pr.reset_settings() - pr.settings.verbose = False # Disable verbose to speed up - pr.settings.parallel_computing = True - pr.settings.atom_trace = True - - print("=" * 80) - print("CUSTOM THRESHOLDS PARALLEL MODE DEBUG") - print("=" * 80) - print(f"Settings:") - print(f" parallel_computing: {pr.settings.parallel_computing}") - print(f" atom_trace: {pr.settings.atom_trace}") - - # Load graph - graph_path = "./tests/functional/group_chat_graph.graphml" - print(f"\nLoading graph from: {graph_path}") - pr.load_graphml(graph_path) - - # Add custom thresholds - user_defined_thresholds = [ - Threshold("greater_equal", ("number", "total"), 1), - Threshold("greater_equal", ("percent", "total"), 100), - ] - print(f"\nCustom thresholds: {user_defined_thresholds}") - - # Add rule - pr.add_rule( - pr.Rule( - "ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)", - "viewed_by_all_rule", - custom_thresholds=user_defined_thresholds, - ) - ) - print("Rule added: ViewedByAll(y) <- HaveAccess(x,y), Viewed(x)") - - # Add facts - pr.add_fact(pr.Fact("Viewed(Zach)", "seen-fact-zach", 0, 3)) - pr.add_fact(pr.Fact("Viewed(Justin)", "seen-fact-justin", 0, 3)) - pr.add_fact(pr.Fact("Viewed(Michelle)", "seen-fact-michelle", 1, 3)) - pr.add_fact(pr.Fact("Viewed(Amy)", "seen-fact-amy", 2, 3)) - print("\nFacts added:") - print(" Viewed(Zach) at t=0") - print(" Viewed(Justin) at t=0") - print(" Viewed(Michelle) at t=1") - print(" Viewed(Amy) at t=2") - - # Run reasoning - print("\n" + "=" * 80) - print("Running reasoning for 3 timesteps...") - print("=" * 80) - interpretation = pr.reason(timesteps=3) - print("Reasoning completed!") - - # Display results - print("\n" + "=" * 80) - print("RESULTS - ViewedByAll at each timestep") - print("=" * 80) - - dataframes = pr.filter_and_sort_nodes(interpretation, ["ViewedByAll"]) - for t, df in enumerate(dataframes): - print(f"\nTIMESTEP {t}:") - print(f" Number of nodes with ViewedByAll: {len(df)}") - if len(df) > 0: - print(df) - else: - print(" (no nodes with ViewedByAll)") - - # Check specific assertions - print("\n" + "=" * 80) - print("ASSERTION CHECKS") - print("=" * 80) - - t0_check = len(dataframes[0]) == 0 - print(f"✓ t=0: ViewedByAll count = {len(dataframes[0])} (expected: 0) - {'PASS' if t0_check else 'FAIL'}") - - t2_check = len(dataframes[2]) == 1 - print(f"✓ t=2: ViewedByAll count = {len(dataframes[2])} (expected: 1) - {'PASS' if t2_check else 'FAIL'}") - - if len(dataframes[2]) > 0: - has_textmsg = "TextMessage" in dataframes[2]["component"].values - if has_textmsg: - bounds = dataframes[2].iloc[0].ViewedByAll - bounds_check = bounds == [1, 1] - print(f"✓ t=2: TextMessage bounds = {bounds} (expected: [1, 1]) - {'PASS' if bounds_check else 'FAIL'}") - else: - print(f"✗ t=2: TextMessage not found in ViewedByAll nodes") - print(f" Available nodes: {dataframes[2]['component'].values}") - else: - print("✗ t=2: No ViewedByAll nodes found (expected TextMessage)") - - # Additional 
debugging: show all Viewed facts at each timestep - print("\n" + "=" * 80) - print("DEBUG - Viewed nodes at each timestep") - print("=" * 80) - viewed_dataframes = pr.filter_and_sort_nodes(interpretation, ["Viewed"]) - for t, df in enumerate(viewed_dataframes): - print(f"\nTIMESTEP {t}:") - if len(df) > 0: - print(df) - else: - print(" (no Viewed nodes)") - - # Show HaveAccess edges if possible - print("\n" + "=" * 80) - print("DEBUG - HaveAccess edges") - print("=" * 80) - try: - access_dataframes = pr.filter_and_sort_edges(interpretation, ["HaveAccess"]) - print(f"Number of HaveAccess edges at t=0: {len(access_dataframes[0]) if access_dataframes else 'N/A'}") - if access_dataframes and len(access_dataframes[0]) > 0: - print("\nSample HaveAccess edges:") - print(access_dataframes[0].head(10)) - except Exception as e: - print(f"Could not retrieve HaveAccess edges: {e}") - - -if __name__ == "__main__": - main() \ No newline at end of file From 5a16be77de8b98e4d78e5bfd722665e56eb5ab45 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 15:15:03 -0400 Subject: [PATCH 21/32] Rem rule trace csv --- renamed_nodes_rule_trace.csv | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 renamed_nodes_rule_trace.csv diff --git a/renamed_nodes_rule_trace.csv b/renamed_nodes_rule_trace.csv deleted file mode 100644 index 93ba1d0d..00000000 --- a/renamed_nodes_rule_trace.csv +++ /dev/null @@ -1,24 +0,0 @@ -Time,Fixed-Point-Operation,Node,Label,Old Bound,New Bound,Occurred Due To,Clause-1,Clause-2 -0,1,card_drawn_obj,_9s,"[0.0,1.0]","[1.0,1.0]",card_drawn_obj-9s-fact,, -0,2,_9s,player_holds_,"[0.0,1.0]","[0.6,1.0]",player_holds_9s_rule,['card_drawn_obj'], -0,3,hand,hand_as_point_vals,"[0.0,1.0]","[0.6,1.0]",hand_as_point_vals_rule,['_9s'], -0,4,card_drawn_obj,_3c,"[0.0,1.0]","[1.0,1.0]",card_drawn_obj-3c-fact,, -0,5,_3c,player_holds_,"[0.0,1.0]","[0.6,1.0]",player_holds_3c_rule,['card_drawn_obj'], -0,6,hand,hand_as_point_vals,"[0.6,1.0]","[0.66,1.0]",hand_as_point_vals_rule,"['_9s', '_3c']", -0,7,card_drawn_obj,_Ad,"[0.0,1.0]","[1.0,1.0]",card_drawn_obj-Ad-fact,, -0,8,_Ad,player_holds_,"[0.0,1.0]","[0.3,1.0]",player_holds_Ad_rule,['card_drawn_obj'], -0,9,hand,hand_as_point_vals,"[0.66,1.0]","[0.663,1.0]",hand_as_point_vals_rule,"['_9s', '_3c', '_Ad']", -0,10,card_drawn_obj,_Kc,"[0.0,1.0]","[1.0,1.0]",card_drawn_obj-Kc-fact,, -0,11,_Kc,player_holds_,"[0.0,1.0]","[0.9,1.0]",player_holds_Kc_rule,['card_drawn_obj'], -0,12,hand,hand_as_point_vals,"[0.663,1.0]","[0.6639,1.0]",hand_as_point_vals_rule,"['_9s', '_3c', '_Ad', '_Kc']", -0,13,card_drawn_obj,_4d,"[0.0,1.0]","[1.0,1.0]",card_drawn_obj-4d-fact,, -0,14,_4d,player_holds_,"[0.0,1.0]","[0.6,1.0]",player_holds_4d_rule,['card_drawn_obj'], -0,15,hand,hand_as_point_vals,"[0.6639,1.0]","[0.66396,1.0]",hand_as_point_vals_rule,"['_9s', '_3c', '_Ad', '_Kc', '_4d']", -0,16,card_drawn_obj,_9c,"[0.0,1.0]","[1.0,1.0]",card_drawn_obj-9c-fact,, -0,17,_9c,player_holds_,"[0.0,1.0]","[0.6,1.0]",player_holds_9c_rule,['card_drawn_obj'], -0,18,hand,hand_as_point_vals,"[0.66396,1.0]","[0.663966,1.0]",hand_as_point_vals_rule,"['_9s', '_3c', '_Ad', '_Kc', '_4d', '_9c']", -0,19,hand,odds_of_losing,"[0.0,1.0]","[0.2391304347826087,1.0]",odds_of_losing_rule,['hand'],"[('Ah', 'full_deck'), ('Ad', 'full_deck'), ('Ac', 'full_deck'), ('As', 'full_deck'), ('2h', 'full_deck'), ('2d', 'full_deck'), ('2c', 'full_deck'), ('2s', 'full_deck'), ('3h', 'full_deck'), ('3d', 'full_deck'), ('3c', 'full_deck'), 
('3s', 'full_deck'), ('4h', 'full_deck'), ('4d', 'full_deck'), ('4c', 'full_deck'), ('4s', 'full_deck'), ('5h', 'full_deck'), ('5d', 'full_deck'), ('5c', 'full_deck'), ('5s', 'full_deck'), ('6h', 'full_deck'), ('6d', 'full_deck'), ('6c', 'full_deck'), ('6s', 'full_deck'), ('7h', 'full_deck'), ('7d', 'full_deck'), ('7c', 'full_deck'), ('7s', 'full_deck'), ('8h', 'full_deck'), ('8d', 'full_deck'), ('8c', 'full_deck'), ('8s', 'full_deck'), ('9h', 'full_deck'), ('9d', 'full_deck'), ('9c', 'full_deck'), ('9s', 'full_deck'), ('10h', 'full_deck'), ('10d', 'full_deck'), ('10c', 'full_deck'), ('10s', 'full_deck'), ('Jh', 'full_deck'), ('Jd', 'full_deck'), ('Jc', 'full_deck'), ('Js', 'full_deck'), ('Qh', 'full_deck'), ('Qd', 'full_deck'), ('Qc', 'full_deck'), ('Qs', 'full_deck'), ('Kh', 'full_deck'), ('Kd', 'full_deck'), ('Kc', 'full_deck'), ('Ks', 'full_deck')]" -0,20,card_drawn_obj,_2s,"[0.0,1.0]","[1.0,1.0]",card_drawn_obj-2s-fact,, -0,21,_2s,player_holds_,"[0.0,1.0]","[0.6,1.0]",player_holds_2s_rule,['card_drawn_obj'], -0,22,hand,hand_as_point_vals,"[0.663966,1.0]","[0.6639666,1.0]",hand_as_point_vals_rule,"['_9s', '_3c', '_Ad', '_Kc', '_4d', '_9c', '_2s']", -0,23,hand,odds_of_losing,"[0.2391304347826087,1.0]","[1.0,1.0]",odds_of_losing_rule,['hand'],"[('Ah', 'full_deck'), ('Ad', 'full_deck'), ('Ac', 'full_deck'), ('As', 'full_deck'), ('2h', 'full_deck'), ('2d', 'full_deck'), ('2c', 'full_deck'), ('2s', 'full_deck'), ('3h', 'full_deck'), ('3d', 'full_deck'), ('3c', 'full_deck'), ('3s', 'full_deck'), ('4h', 'full_deck'), ('4d', 'full_deck'), ('4c', 'full_deck'), ('4s', 'full_deck'), ('5h', 'full_deck'), ('5d', 'full_deck'), ('5c', 'full_deck'), ('5s', 'full_deck'), ('6h', 'full_deck'), ('6d', 'full_deck'), ('6c', 'full_deck'), ('6s', 'full_deck'), ('7h', 'full_deck'), ('7d', 'full_deck'), ('7c', 'full_deck'), ('7s', 'full_deck'), ('8h', 'full_deck'), ('8d', 'full_deck'), ('8c', 'full_deck'), ('8s', 'full_deck'), ('9h', 'full_deck'), ('9d', 'full_deck'), ('9c', 'full_deck'), ('9s', 'full_deck'), ('10h', 'full_deck'), ('10d', 'full_deck'), ('10c', 'full_deck'), ('10s', 'full_deck'), ('Jh', 'full_deck'), ('Jd', 'full_deck'), ('Jc', 'full_deck'), ('Js', 'full_deck'), ('Qh', 'full_deck'), ('Qd', 'full_deck'), ('Qc', 'full_deck'), ('Qs', 'full_deck'), ('Kh', 'full_deck'), ('Kd', 'full_deck'), ('Kc', 'full_deck'), ('Ks', 'full_deck')]" From 218247c361058109607a85f22c80de8f99246058 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 15:39:27 -0400 Subject: [PATCH 22/32] Move parrallel consistency --- .../api_tests/test_pyreason_add_operations.py | 53 ------------------- .../test_parralell_consistency.py} | 0 2 files changed, 53 deletions(-) rename tests/unit/{api_tests/test_pyreason_file_consistency.py => dont_disable_jit/test_parralell_consistency.py} (100%) diff --git a/tests/unit/api_tests/test_pyreason_add_operations.py b/tests/unit/api_tests/test_pyreason_add_operations.py index fb84f6f9..a7b21bcc 100644 --- a/tests/unit/api_tests/test_pyreason_add_operations.py +++ b/tests/unit/api_tests/test_pyreason_add_operations.py @@ -17,8 +17,6 @@ def setup_method(self): def test_add_rule_creates_rules_list_when_none(self): """Test add_rule() works when starting with no rules.""" - - # Start with clean state pr.reset_rules() @@ -30,8 +28,6 @@ def test_add_rule_creates_rules_list_when_none(self): def test_add_rule_with_valid_rule(self): """Test adding a valid rule.""" - - rule = pr.Rule('test(x) <- fact(x)', 'test_rule') # Should not raise an 
exception @@ -39,8 +35,6 @@ def test_add_rule_with_valid_rule(self): def test_add_rule_with_rule_without_name(self): """Test adding a rule without a name (should work).""" - - rule = pr.Rule('test(x) <- fact(x)') # No name provided # Should work without errors even without explicit name @@ -48,8 +42,6 @@ def test_add_rule_with_rule_without_name(self): def test_add_multiple_rules_auto_naming(self): """Test adding multiple rules without explicit names.""" - - rule1 = pr.Rule('test1(x) <- fact1(x)') rule2 = pr.Rule('test2(x) <- fact2(x)') @@ -59,8 +51,6 @@ def test_add_multiple_rules_auto_naming(self): def test_add_rule_with_named_and_unnamed_rules(self): """Test mixing named and unnamed rules.""" - - rule1 = pr.Rule('test1(x) <- fact1(x)', 'named_rule') rule2 = pr.Rule('test2(x) <- fact2(x)') # Without explicit name @@ -70,8 +60,6 @@ def test_add_rule_with_named_and_unnamed_rules(self): def test_add_rule_appends_to_existing_list(self): """Test that add_rule works with multiple rules.""" - - rule1 = pr.Rule('test1(x) <- fact1(x)') rule2 = pr.Rule('test2(x) <- fact2(x)') @@ -92,8 +80,6 @@ def setup_method(self): def test_add_fact_creates_lists_when_none(self): """Test add_fact() works when starting with no facts.""" - - # Start with clean state pr.reset() @@ -104,8 +90,6 @@ def test_add_fact_creates_lists_when_none(self): def test_add_fact_node_fact(self): """Test adding a node fact.""" - - fact = pr.Fact('test(node1)') # Should not raise an exception @@ -113,8 +97,6 @@ def test_add_fact_node_fact(self): def test_add_fact_edge_fact(self): """Test adding an edge fact.""" - - fact = pr.Fact('test(node1, node2)') # Should not raise an exception @@ -122,8 +104,6 @@ def test_add_fact_edge_fact(self): def test_add_fact_with_name(self): """Test adding a fact with a name.""" - - fact = pr.Fact('test(node1)', 'named_fact') pr.add_fact(fact) @@ -132,8 +112,6 @@ def test_add_fact_with_name(self): def test_add_fact_without_name_auto_generates(self): """Test adding a fact without a name auto-generates one.""" - - fact = pr.Fact('test(node1)') # No name pr.add_fact(fact) @@ -143,8 +121,6 @@ def test_add_fact_without_name_auto_generates(self): def test_add_multiple_facts_auto_naming(self): """Test adding multiple facts with auto-naming.""" - - fact1 = pr.Fact('test1(node1)') fact2 = pr.Fact('test2(node2)') @@ -158,16 +134,12 @@ def test_add_multiple_facts_auto_naming(self): def test_add_fact_with_time_bounds(self): """Test adding a fact with time bounds.""" - - fact = pr.Fact('test(node1)', 'timed_fact', 0, 5) pr.add_fact(fact) def test_add_mixed_node_and_edge_facts(self): """Test adding both node and edge facts.""" - - node_fact = pr.Fact('test_node(node1)') edge_fact = pr.Fact('test_edge(node1, node2)') @@ -186,7 +158,6 @@ def setup_method(self): def test_add_annotation_function_valid(self): """Test adding a valid annotation function.""" - def test_func(annotations, weights): return 0.5, 0.5 @@ -196,15 +167,12 @@ def test_func(annotations, weights): def test_add_annotation_function_lambda(self): """Test adding a lambda annotation function.""" - - func = lambda annotations, weights: (0.8, 0.8) pr.add_annotation_function(func) def test_add_multiple_annotation_functions(self): """Test adding multiple annotation functions.""" - def func1(annotations, weights): return 0.1, 0.1 @@ -222,7 +190,6 @@ def func3(annotations, weights): def test_add_annotation_function_with_complex_logic(self): """Test adding annotation function with complex logic.""" - def complex_func(annotations, weights): if not annotations: return 
0.0, 0.0 @@ -244,15 +211,11 @@ def setup_method(self): def test_get_rules_when_none(self): """Test get_rules() when no rules have been added.""" - - rules = pr.get_rules() assert rules is None def test_get_rules_after_adding_rules(self): """Test get_rules() returns added rules.""" - - rule1 = pr.Rule('test1(x) <- fact1(x)', 'rule1') rule2 = pr.Rule('test2(x) <- fact2(x)', 'rule2') @@ -274,15 +237,11 @@ def setup_method(self): def test_add_inconsistent_predicate_pair(self): """Test adding a pair of inconsistent predicates.""" - - # Should not raise an exception pr.add_inconsistent_predicate('pred1', 'pred2') def test_add_multiple_inconsistent_predicate_pairs(self): """Test adding multiple pairs of inconsistent predicates.""" - - pr.add_inconsistent_predicate('pred1', 'pred2') pr.add_inconsistent_predicate('pred3', 'pred4') pr.add_inconsistent_predicate('pred5', 'pred6') @@ -306,8 +265,6 @@ def setup_method(self): def test_add_rules_and_facts_sequence(self): """Test adding rules and facts in sequence.""" - - # Add rules rule1 = pr.Rule('test1(x) <- fact1(x)') rule2 = pr.Rule('test2(x) <- fact2(x)') @@ -330,8 +287,6 @@ def test_func(annotations, weights): def test_complex_operation_sequence(self): """Test a complex sequence of operations.""" - - # Mixed sequence pr.add_rule(pr.Rule('rule1(x) <- fact1(x)', 'named_rule')) pr.add_fact(pr.Fact('fact1(node1)', 'named_fact')) @@ -346,8 +301,6 @@ def annotation_func(annotations, weights): def test_operations_after_reset(self): """Test operations work correctly after reset.""" - - # Add some content pr.add_rule(pr.Rule('test(x) <- fact(x)')) pr.add_fact(pr.Fact('fact(node1)')) @@ -376,8 +329,6 @@ def setup_method(self): def test_rule_counter_persistence(self): """Test that multiple rules can be added successfully.""" - - rule1 = pr.Rule('test1(x) <- fact1(x)') rule2 = pr.Rule('test2(x) <- fact2(x)', 'named_rule') rule3 = pr.Rule('test3(x) <- fact3(x)') @@ -389,8 +340,6 @@ def test_rule_counter_persistence(self): def test_fact_counter_independence(self): """Test that rules and facts can be added independently.""" - - rule = pr.Rule('test(x) <- fact(x)') fact = pr.Fact('fact(node1)') @@ -403,8 +352,6 @@ def test_fact_counter_independence(self): def test_counters_after_reset(self): """Test that rules and facts work after reset.""" - - # Add some rules and facts rule1 = pr.Rule('test1(x) <- fact1(x)') fact1 = pr.Fact('fact1(node1)') diff --git a/tests/unit/api_tests/test_pyreason_file_consistency.py b/tests/unit/dont_disable_jit/test_parralell_consistency.py similarity index 100% rename from tests/unit/api_tests/test_pyreason_file_consistency.py rename to tests/unit/dont_disable_jit/test_parralell_consistency.py From 8f8dc6cb727bf215cb6b521a662947b3c0ab9e54 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 16:25:36 -0400 Subject: [PATCH 23/32] Run full functional test suite on push --- .../workflows/python-package-version-test.yml | 6 +- .pre-commit-config.yaml | 2 +- contributing.md | 4 +- run_tests.py | 1 + test_config.json | 4 +- .../api_tests/conftest.py | 0 .../api_tests/test_pyreason_add_operations.py | 2 - .../api_tests/test_pyreason_file_loading.py | 0 .../api_tests/test_pyreason_reasoning.py | 14 -- .../api_tests/test_pyreason_settings.py | 0 .../test_pyreason_state_management.py | 0 .../api_tests/test_pyreason_validation.py | 0 tests/functional/test_basic_reasoning.py | 146 ++++++++++++++++++ 13 files changed, 155 insertions(+), 24 deletions(-) rename tests/{unit => 
functional}/api_tests/conftest.py (100%) rename tests/{unit => functional}/api_tests/test_pyreason_add_operations.py (99%) rename tests/{unit => functional}/api_tests/test_pyreason_file_loading.py (100%) rename tests/{unit => functional}/api_tests/test_pyreason_reasoning.py (98%) rename tests/{unit => functional}/api_tests/test_pyreason_settings.py (100%) rename tests/{unit => functional}/api_tests/test_pyreason_state_management.py (100%) rename tests/{unit => functional}/api_tests/test_pyreason_validation.py (100%) diff --git a/.github/workflows/python-package-version-test.yml b/.github/workflows/python-package-version-test.yml index e4efb8bb..7717ec1c 100644 --- a/.github/workflows/python-package-version-test.yml +++ b/.github/workflows/python-package-version-test.yml @@ -35,13 +35,13 @@ jobs: python -m ruff check pyreason/scripts - name: Pytest Unit Tests with JIT Disabled run: | - pytest tests/unit/disable_jit -m "not slow" --tb=short -q + pytest tests/unit/disable_jit --tb=short -q - name: Pytest Unit Tests with JIT Enabled run: | - pytest tests/unit/dont_disable_jit -m "not slow" --tb=short -q + pytest tests/unit/dont_disable_jit --tb=short -q - name: Pytest API Tests run: | - pytest tests/unit/api_tests --tb=short -q + pytest tests/functional/api_tests --tb=short -q - name: Pytest Functional Tests run: | pytest tests/functional/ --tb=short diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 20a21134..e11fbf52 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,7 @@ repos: # --- PUSH STAGE: Complete test suite --- - id: pytest-unit-api name: Run pyreason api unit tests - entry: .venv/bin/python -m pytest tests/unit/api_tests --tb=short -q + entry: .venv/bin/python -m pytest tests/functional/api_tests --tb=short -q language: system pass_filenames: false stages: [pre-push] diff --git a/contributing.md b/contributing.md index 52461586..1a32478c 100644 --- a/contributing.md +++ b/contributing.md @@ -36,7 +36,7 @@ this command to view linting results: PyReason uses a unified test runner that handles multiple test configurations automatically. 
The test suite is organized into four directories: -- **`tests/unit/api_tests/`** - Tests for main pyreason.py API functions (JIT enabled, real pyreason) +- **`tests/functional/api_tests/`** - Tests for main pyreason.py API functions (JIT enabled, real pyreason) - **`tests/unit/disable_jit/`** - Tests for internal interpretation logic (JIT disabled, stubbed environment) - **`tests/unit/dont_disable_jit/`** - Tests for components that benefit from JIT (JIT enabled, lightweight stubs) - **`tests/functional/`** - End-to-end functional tests (JIT enabled, real pyreason, longer running) @@ -99,7 +99,7 @@ You can still run pytest directly on individual directories: ```bash # API tests -pytest tests/unit/api_tests/ -v +pytest tests/functional/api_tests/ -v # JIT disabled tests NUMBA_DISABLE_JIT=1 pytest tests/unit/disable_jit/ -v diff --git a/run_tests.py b/run_tests.py index 8a2d01ed..18982992 100755 --- a/run_tests.py +++ b/run_tests.py @@ -263,6 +263,7 @@ def _create_coverage_config(self): */conftest.py */__pycache__/* */setup.py + */interpretation_parallel.py [report] exclude_lines = diff --git a/test_config.json b/test_config.json index 086502e8..260a8077 100644 --- a/test_config.json +++ b/test_config.json @@ -3,7 +3,7 @@ "api_tests": { "name": "API Tests", "description": "Tests for main pyreason.py API functions", - "path": "tests/unit/api_tests", + "path": "tests/functional/api_tests", "coverage_source": "pyreason", "jit_disabled": false, "uses_real_pyreason": true, @@ -43,7 +43,7 @@ "jit_disabled": false, "uses_real_pyreason": true, "environment_vars": {}, - "pytest_args": ["-v", "-m", "not slow"], + "pytest_args": ["-v"], "timeout": 600 } }, diff --git a/tests/unit/api_tests/conftest.py b/tests/functional/api_tests/conftest.py similarity index 100% rename from tests/unit/api_tests/conftest.py rename to tests/functional/api_tests/conftest.py diff --git a/tests/unit/api_tests/test_pyreason_add_operations.py b/tests/functional/api_tests/test_pyreason_add_operations.py similarity index 99% rename from tests/unit/api_tests/test_pyreason_add_operations.py rename to tests/functional/api_tests/test_pyreason_add_operations.py index a7b21bcc..7412e0ad 100644 --- a/tests/unit/api_tests/test_pyreason_add_operations.py +++ b/tests/functional/api_tests/test_pyreason_add_operations.py @@ -248,8 +248,6 @@ def test_add_multiple_inconsistent_predicate_pairs(self): def test_add_inconsistent_predicate_same_predicates(self): """Test adding the same predicate as inconsistent with itself.""" - - # This might be an edge case, but should be handled gracefully pr.add_inconsistent_predicate('pred1', 'pred1') diff --git a/tests/unit/api_tests/test_pyreason_file_loading.py b/tests/functional/api_tests/test_pyreason_file_loading.py similarity index 100% rename from tests/unit/api_tests/test_pyreason_file_loading.py rename to tests/functional/api_tests/test_pyreason_file_loading.py diff --git a/tests/unit/api_tests/test_pyreason_reasoning.py b/tests/functional/api_tests/test_pyreason_reasoning.py similarity index 98% rename from tests/unit/api_tests/test_pyreason_reasoning.py rename to tests/functional/api_tests/test_pyreason_reasoning.py index 20a2ec49..04f9f35f 100644 --- a/tests/unit/api_tests/test_pyreason_reasoning.py +++ b/tests/functional/api_tests/test_pyreason_reasoning.py @@ -314,20 +314,6 @@ def test_reason_with_different_update_modes(self): interpretation = pr.reason(timesteps=1) assert interpretation is not None - @pytest.mark.skip(reason="This test is very slow and we already test this extensively") - 
def test_reason_with_different_fp_versions(self): - """Test reasoning with different fixed point versions.""" - graph = nx.DiGraph() - graph.add_edge('A', 'B') - pr.load_graph(graph) - pr.add_rule(Rule("friend(A, B) <- connected(A, B)", "test_rule", False)) - - # Test different fp versions (boolean) - for fp_version in [True, False]: - pr.settings.fp_version = fp_version - interpretation = pr.reason(timesteps=1) - assert interpretation is not None - def test_reason_with_complex_rule_structure(self): """Test reasoning with complex rules that might trigger clause reordering.""" graph = nx.DiGraph() diff --git a/tests/unit/api_tests/test_pyreason_settings.py b/tests/functional/api_tests/test_pyreason_settings.py similarity index 100% rename from tests/unit/api_tests/test_pyreason_settings.py rename to tests/functional/api_tests/test_pyreason_settings.py diff --git a/tests/unit/api_tests/test_pyreason_state_management.py b/tests/functional/api_tests/test_pyreason_state_management.py similarity index 100% rename from tests/unit/api_tests/test_pyreason_state_management.py rename to tests/functional/api_tests/test_pyreason_state_management.py diff --git a/tests/unit/api_tests/test_pyreason_validation.py b/tests/functional/api_tests/test_pyreason_validation.py similarity index 100% rename from tests/unit/api_tests/test_pyreason_validation.py rename to tests/functional/api_tests/test_pyreason_validation.py diff --git a/tests/functional/test_basic_reasoning.py b/tests/functional/test_basic_reasoning.py index b0b39d30..8061c953 100644 --- a/tests/functional/test_basic_reasoning.py +++ b/tests/functional/test_basic_reasoning.py @@ -113,3 +113,149 @@ def test_reorder_clauses(mode): else: # Regular version: The second row, clause 1 should be the edge grounding ('Justin', 'Mary') assert rule_trace_node.iloc[2]['Clause-1'][0] == ('Justin', 'Mary') + + +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) +def test_filter_and_sort_nodes_sorting_verification(mode): + """Test that filter_and_sort_nodes actually sorts nodes correctly by different criteria.""" + setup_mode(mode) + pr.settings.store_interpretation_changes = True + + import networkx as nx + + # Create a simple graph with multiple nodes + graph = nx.DiGraph() + graph.add_node('A') + graph.add_node('B') + graph.add_node('C') + graph.add_node('D') + + pr.load_graph(graph) + + # Add facts with different interval bounds to create varied data for sorting + # Node A: [0.7, 0.9] - high lower, high upper + pr.add_fact(pr.Fact('score(A) : [0.7, 0.9]', 'fact_a', 0, 1)) + # Node B: [0.1, 0.2] - low lower, low upper + pr.add_fact(pr.Fact('score(B) : [0.1, 0.2]', 'fact_b', 0, 1)) + # Node C: [0.4, 0.6] - medium lower, medium upper + pr.add_fact(pr.Fact('score(C) : [0.4, 0.6]', 'fact_c', 0, 1)) + # Node D: [0.2, 0.8] - low lower, high upper + pr.add_fact(pr.Fact('score(D) : [0.2, 0.8]', 'fact_d', 0, 1)) + + # Add a simple rule to ensure reasoning happens + pr.add_rule(pr.Rule('result(x) <- score(x)', 'test_rule')) + + # Run reasoning + interpretation = pr.reason(timesteps=1) + + # Test 1: Sort by lower bound, descending (default) + # Expected order: A(0.7), C(0.4), D(0.2), B(0.1) + result = pr.filter_and_sort_nodes(interpretation, ['score'], sort_by='lower', descending=True) + df = result[0] # Get timestep 0 + assert len(df) == 4, 'Should have 4 nodes' + + # Extract lower bounds for each row + lower_bounds = [df.iloc[i]['score'][0] for i in range(len(df))] + # Verify descending order + for i in range(len(lower_bounds) - 1): + assert lower_bounds[i] 
>= lower_bounds[i+1], f'Lower bounds should be descending: {lower_bounds}' + + # Test 2: Sort by lower bound, ascending + # Expected order: B(0.1), D(0.2), C(0.4), A(0.7) + result = pr.filter_and_sort_nodes(interpretation, ['score'], sort_by='lower', descending=False) + df = result[0] + lower_bounds = [df.iloc[i]['score'][0] for i in range(len(df))] + # Verify ascending order + for i in range(len(lower_bounds) - 1): + assert lower_bounds[i] <= lower_bounds[i+1], f'Lower bounds should be ascending: {lower_bounds}' + + # Test 3: Sort by upper bound, descending + # Expected order: A(0.9), D(0.8), C(0.6), B(0.2) + result = pr.filter_and_sort_nodes(interpretation, ['score'], sort_by='upper', descending=True) + df = result[0] + upper_bounds = [df.iloc[i]['score'][1] for i in range(len(df))] + # Verify descending order + for i in range(len(upper_bounds) - 1): + assert upper_bounds[i] >= upper_bounds[i+1], f'Upper bounds should be descending: {upper_bounds}' + + # Test 4: Sort by upper bound, ascending + # Expected order: B(0.2), C(0.6), D(0.8), A(0.9) + result = pr.filter_and_sort_nodes(interpretation, ['score'], sort_by='upper', descending=False) + df = result[0] + upper_bounds = [df.iloc[i]['score'][1] for i in range(len(df))] + # Verify ascending order + for i in range(len(upper_bounds) - 1): + assert upper_bounds[i] <= upper_bounds[i+1], f'Upper bounds should be ascending: {upper_bounds}' + + +@pytest.mark.parametrize("mode", ["regular", "fp", "parallel"]) +def test_filter_and_sort_edges_sorting_verification(mode): + """Test that filter_and_sort_edges actually sorts edges correctly by different criteria.""" + setup_mode(mode) + pr.settings.store_interpretation_changes = True + + import networkx as nx + + # Create a simple graph with multiple edges + graph = nx.DiGraph() + graph.add_edge('A', 'B') + graph.add_edge('B', 'C') + graph.add_edge('C', 'D') + graph.add_edge('D', 'E') + + pr.load_graph(graph) + + # Add facts with different interval bounds to create varied data for sorting + # Edge A->B: [0.7, 0.9] - high lower, high upper + pr.add_fact(pr.Fact('weight(A, B) : [0.7, 0.9]', 'fact_ab', 0, 1)) + # Edge B->C: [0.1, 0.2] - low lower, low upper + pr.add_fact(pr.Fact('weight(B, C) : [0.1, 0.2]', 'fact_bc', 0, 1)) + # Edge C->D: [0.4, 0.6] - medium lower, medium upper + pr.add_fact(pr.Fact('weight(C, D) : [0.4, 0.6]', 'fact_cd', 0, 1)) + # Edge D->E: [0.2, 0.8] - low lower, high upper + pr.add_fact(pr.Fact('weight(D, E) : [0.2, 0.8]', 'fact_de', 0, 1)) + + # Add a simple rule to ensure reasoning happens + pr.add_rule(pr.Rule('result(x, y) <- weight(x, y)', 'test_rule')) + + # Run reasoning + interpretation = pr.reason(timesteps=1) + + # Test 1: Sort by lower bound, descending (default) + # Expected order: A->B(0.7), C->D(0.4), D->E(0.2), B->C(0.1) + result = pr.filter_and_sort_edges(interpretation, ['weight'], sort_by='lower', descending=True) + df = result[0] # Get timestep 0 + assert len(df) == 4, 'Should have 4 edges' + + # Extract lower bounds for each row + lower_bounds = [df.iloc[i]['weight'][0] for i in range(len(df))] + # Verify descending order + for i in range(len(lower_bounds) - 1): + assert lower_bounds[i] >= lower_bounds[i+1], f'Lower bounds should be descending: {lower_bounds}' + + # Test 2: Sort by lower bound, ascending + # Expected order: B->C(0.1), D->E(0.2), C->D(0.4), A->B(0.7) + result = pr.filter_and_sort_edges(interpretation, ['weight'], sort_by='lower', descending=False) + df = result[0] + lower_bounds = [df.iloc[i]['weight'][0] for i in range(len(df))] + # Verify 
ascending order + for i in range(len(lower_bounds) - 1): + assert lower_bounds[i] <= lower_bounds[i+1], f'Lower bounds should be ascending: {lower_bounds}' + + # Test 3: Sort by upper bound, descending + # Expected order: A->B(0.9), D->E(0.8), C->D(0.6), B->C(0.2) + result = pr.filter_and_sort_edges(interpretation, ['weight'], sort_by='upper', descending=True) + df = result[0] + upper_bounds = [df.iloc[i]['weight'][1] for i in range(len(df))] + # Verify descending order + for i in range(len(upper_bounds) - 1): + assert upper_bounds[i] >= upper_bounds[i+1], f'Upper bounds should be descending: {upper_bounds}' + + # Test 4: Sort by upper bound, ascending + # Expected order: B->C(0.2), C->D(0.6), D->E(0.8), A->B(0.9) + result = pr.filter_and_sort_edges(interpretation, ['weight'], sort_by='upper', descending=False) + df = result[0] + upper_bounds = [df.iloc[i]['weight'][1] for i in range(len(df))] + # Verify ascending order + for i in range(len(upper_bounds) - 1): + assert upper_bounds[i] <= upper_bounds[i+1], f'Upper bounds should be ascending: {upper_bounds}' From 7bb31a9f5533298152787b2b16bfc76ebf61b1dc Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 16:33:10 -0400 Subject: [PATCH 24/32] Make parallel check more flexible --- .../test_parallel_consistency.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 tests/unit/dont_disable_jit/test_parallel_consistency.py diff --git a/tests/unit/dont_disable_jit/test_parallel_consistency.py b/tests/unit/dont_disable_jit/test_parallel_consistency.py new file mode 100644 index 00000000..b35b1f56 --- /dev/null +++ b/tests/unit/dont_disable_jit/test_parallel_consistency.py @@ -0,0 +1,77 @@ +""" +Tests to validate consistency between related files in the PyReason codebase. +""" + +from pathlib import Path + + +def test_interpretation_parallel_consistency(): + """ + Test that interpretation_parallel.py is identical to interpretation.py + except for the parallel flag in the @numba.njit decorator. 
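+
+    A quick way to eyeball any drift by hand is a minimal sketch using only
+    the standard library (paths assumed relative to the repository root):
+
+        import difflib
+        a = open("pyreason/scripts/interpretation/interpretation.py").readlines()
+        b = open("pyreason/scripts/interpretation/interpretation_parallel.py").readlines()
+        print("".join(difflib.unified_diff(a, b)))
+
+    Only the decorator line should appear as changed.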
+ + The only difference should be: + - interpretation.py should have: @numba.njit(cache=True, parallel=False) + - interpretation_parallel.py should have: @numba.njit(cache=True, parallel=True) + """ + # Get the path to the interpretation files + scripts_dir = Path(__file__).parent.parent.parent.parent / "pyreason" / "scripts" / "interpretation" + interpretation_file = scripts_dir / "interpretation.py" + interpretation_parallel_file = scripts_dir / "interpretation_parallel.py" + + # Verify both files exist + assert interpretation_file.exists(), f"File not found: {interpretation_file}" + assert interpretation_parallel_file.exists(), f"File not found: {interpretation_parallel_file}" + + # Read both files + with open(interpretation_file, 'r', encoding='utf-8') as f: + interpretation_lines = f.readlines() + + with open(interpretation_parallel_file, 'r', encoding='utf-8') as f: + interpretation_parallel_lines = f.readlines() + + # Check that both files have the same number of lines + assert len(interpretation_lines) == len(interpretation_parallel_lines), \ + f"Files have different number of lines: {len(interpretation_lines)} vs {len(interpretation_parallel_lines)}" + + # Track all differences + differences = [] + numba_decorator_difference_found = False + numba_decorator_line_num = None + + # Compare line by line + for line_num, (line1, line2) in enumerate(zip(interpretation_lines, interpretation_parallel_lines), start=1): + if line1 != line2: + # Check if this is the expected numba decorator difference + # Strip whitespace for comparison + line1_stripped = line1.strip() + line2_stripped = line2.strip() + + # Check if this matches the expected decorator pattern + if (line1_stripped == "@numba.njit(cache=True, parallel=False)" and + line2_stripped == "@numba.njit(cache=True, parallel=True)"): + # This is the expected difference + if numba_decorator_difference_found: + differences.append( + f"Found multiple @numba.njit decorator differences (line {numba_decorator_line_num} and line {line_num})" + ) + numba_decorator_difference_found = True + numba_decorator_line_num = line_num + else: + # This is an unexpected difference + differences.append( + f"Unexpected difference at line {line_num}:\n" + f" interpretation.py: {line1.strip()}\n" + f" interpretation_parallel.py: {line2.strip()}" + ) + + # Check that we found exactly one expected difference + if not numba_decorator_difference_found: + differences.append( + "Expected difference not found: " + "@numba.njit(cache=True, parallel=False) vs @numba.njit(cache=True, parallel=True)" + ) + + # Assert no unexpected differences + assert len(differences) == 0, \ + f"Found {len(differences)} issue(s) between the files:\n" + "\n\n".join(differences) From a7702dfb03b83fa56744a44dc4df8a0924e42a18 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 16:35:32 -0400 Subject: [PATCH 25/32] Update path functional tests --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 8e36a702..1e365f6b 100644 --- a/Makefile +++ b/Makefile @@ -79,7 +79,7 @@ test-no-coverage: ## Run all tests without coverage collection $(RUN_TESTS) --no-coverage # Individual test suite targets -test-api: ## Run only API tests (tests/unit/api_tests) +test-api: ## Run only API tests (tests/functional/api_tests) @echo "$(BOLD)$(BLUE)Running API tests...$(RESET)" $(RUN_TESTS) --suite api_tests From 546f8fac1fe9501554ae4ee34c6833fcd4a96851 Mon Sep 17 00:00:00 2001 From: ColtonPayne 
<72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 16:45:34 -0400 Subject: [PATCH 26/32] Update sync function --- .../workflows/python-package-version-test.yml | 2 +- sync_interpretation_parallel.py | 45 ++++++++++++------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/.github/workflows/python-package-version-test.yml b/.github/workflows/python-package-version-test.yml index 7717ec1c..6666fe2b 100644 --- a/.github/workflows/python-package-version-test.yml +++ b/.github/workflows/python-package-version-test.yml @@ -44,4 +44,4 @@ jobs: pytest tests/functional/api_tests --tb=short -q - name: Pytest Functional Tests run: | - pytest tests/functional/ --tb=short + pytest tests/functional/ --tb=short -q diff --git a/sync_interpretation_parallel.py b/sync_interpretation_parallel.py index 2ad25724..247c11db 100644 --- a/sync_interpretation_parallel.py +++ b/sync_interpretation_parallel.py @@ -3,8 +3,8 @@ Pre-commit hook script to synchronize interpretation_parallel.py from interpretation.py. This script ensures that interpretation_parallel.py is always an exact copy of -interpretation.py, except for line 226 which should have parallel=True instead of -parallel=False in the @numba.njit decorator. +interpretation.py, except for the @numba.njit decorator which should have +parallel=True instead of parallel=False. """ import sys @@ -39,23 +39,38 @@ def sync_interpretation_files(): print(f"Error reading {interpretation_file}: {e}", file=sys.stderr) return 1 - # Verify we have at least 226 lines - if len(lines) < 226: - print(f"Error: {interpretation_file} has fewer than 226 lines", file=sys.stderr) + # Find the line with the @numba.njit decorator that needs to be changed + target_line = "@numba.njit(cache=True, parallel=False)" + replacement_line = "@numba.njit(cache=True, parallel=True)" + + found_indices = [] + for i, line in enumerate(lines): + if line.strip() == target_line: + found_indices.append(i) + + # Validate we found exactly one occurrence + if len(found_indices) == 0: + print(f"Error: Could not find the expected decorator in {interpretation_file}", file=sys.stderr) + print(f" Looking for: {target_line}", file=sys.stderr) + return 1 + + if len(found_indices) > 1: + print(f"Error: Found multiple occurrences of the decorator in {interpretation_file}", file=sys.stderr) + print(f" Found on lines: {[i + 1 for i in found_indices]}", file=sys.stderr) + print(f" Expected exactly one occurrence", file=sys.stderr) return 1 - # Expected line 226 (index 225) in source file - expected_line = "\t@numba.njit(cache=True, parallel=False)\n" + # Found exactly one occurrence - replace it + line_index = found_indices[0] + line_num = line_index + 1 - if lines[225] != expected_line: - print(f"Warning: Line 226 in {interpretation_file} is not as expected.", file=sys.stderr) - print(f" Expected: {expected_line.strip()}", file=sys.stderr) - print(f" Got: {lines[225].strip()}", file=sys.stderr) - print(f" Proceeding with replacement anyway...", file=sys.stderr) + # Preserve the original indentation + original_line = lines[line_index] + indentation = original_line[:len(original_line) - len(original_line.lstrip())] - # Replace line 226 for parallel version + # Create modified lines with the replacement modified_lines = lines.copy() - modified_lines[225] = "\t@numba.njit(cache=True, parallel=True)\n" + modified_lines[line_index] = f"{indentation}{replacement_line}\n" # Write to parallel file try: @@ -66,7 +81,7 @@ def sync_interpretation_files(): return 1 print(f"✓ Successfully 
synced {interpretation_parallel_file.name} from {interpretation_file.name}") - print(f" Modified line 226: parallel=False → parallel=True") + print(f" Modified line {line_num}: parallel=False → parallel=True") return 0 From 067fd14439276c4d52af5f4e392650b52cd93d59 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 16:58:30 -0400 Subject: [PATCH 27/32] Delete duplicate file --- .../test_parralell_consistency.py | 78 ------------------- 1 file changed, 78 deletions(-) delete mode 100644 tests/unit/dont_disable_jit/test_parralell_consistency.py diff --git a/tests/unit/dont_disable_jit/test_parralell_consistency.py b/tests/unit/dont_disable_jit/test_parralell_consistency.py deleted file mode 100644 index f7f30e7e..00000000 --- a/tests/unit/dont_disable_jit/test_parralell_consistency.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -Tests to validate consistency between related files in the PyReason codebase. -""" - -import os -from pathlib import Path - - -def test_interpretation_parallel_consistency(): - """ - Test that interpretation_parallel.py is identical to interpretation.py - except for the parallel flag in the @numba.njit decorator on line 226. - - interpretation.py line 226 should have: @numba.njit(cache=True, parallel=False) - interpretation_parallel.py line 226 should have: @numba.njit(cache=True, parallel=True) - """ - # Get the path to the interpretation files - scripts_dir = Path(__file__).parent.parent.parent.parent / "pyreason" / "scripts" / "interpretation" - interpretation_file = scripts_dir / "interpretation.py" - interpretation_parallel_file = scripts_dir / "interpretation_parallel.py" - - # Verify both files exist - assert interpretation_file.exists(), f"File not found: {interpretation_file}" - assert interpretation_parallel_file.exists(), f"File not found: {interpretation_parallel_file}" - - # Read both files - with open(interpretation_file, 'r', encoding='utf-8') as f: - interpretation_lines = f.readlines() - - with open(interpretation_parallel_file, 'r', encoding='utf-8') as f: - interpretation_parallel_lines = f.readlines() - - # Check that both files have the same number of lines - assert len(interpretation_lines) == len(interpretation_parallel_lines), \ - f"Files have different number of lines: {len(interpretation_lines)} vs {len(interpretation_parallel_lines)}" - - # Expected difference on line 226 (index 225) - expected_line_226_interpretation = "\t@numba.njit(cache=True, parallel=False)\n" - expected_line_226_interpretation_parallel = "\t@numba.njit(cache=True, parallel=True)\n" - - # Track differences - differences = [] - - # Compare line by line - for line_num, (line1, line2) in enumerate(zip(interpretation_lines, interpretation_parallel_lines), start=1): - if line1 != line2: - # Line 226 should be the only difference - if line_num == 226: - # Verify the expected difference - if line1 != expected_line_226_interpretation: - differences.append( - f"Line {line_num} in interpretation.py is not as expected.\n" - f" Expected: {expected_line_226_interpretation.strip()}\n" - f" Got: {line1.strip()}" - ) - if line2 != expected_line_226_interpretation_parallel: - differences.append( - f"Line {line_num} in interpretation_parallel.py is not as expected.\n" - f" Expected: {expected_line_226_interpretation_parallel.strip()}\n" - f" Got: {line2.strip()}" - ) - else: - # Any other difference is unexpected - differences.append( - f"Unexpected difference at line {line_num}:\n" - f" interpretation.py: {line1.strip()}\n" - f" 
interpretation_parallel.py: {line2.strip()}"
-                )
-
-    # Assert no unexpected differences
-    assert len(differences) == 0, \
-        f"Found {len(differences)} unexpected difference(s) between the files:\n" + "\n\n".join(differences)
-
-    # Verify the expected difference exists on line 226
-    assert interpretation_lines[225] == expected_line_226_interpretation, \
-        f"Line 226 in interpretation.py is not as expected"
-    assert interpretation_parallel_lines[225] == expected_line_226_interpretation_parallel, \
-        f"Line 226 in interpretation_parallel.py is not as expected"

From 37a3e1b041da1d7bbed79a3160a95a13670366c9 Mon Sep 17 00:00:00 2001
From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com>
Date: Thu, 23 Oct 2025 17:16:05 -0400
Subject: [PATCH 28/32] Move API tests to their own folder

---
 .github/workflows/python-package-version-test.yml             | 2 +-
 .pre-commit-config.yaml                                       | 2 +-
 Makefile                                                      | 2 +-
 contributing.md                                               | 4 ++--
 run_tests.py                                                  | 2 ++
 test_config.json                                              | 2 +-
 tests/{functional => }/api_tests/conftest.py                  | 0
 .../api_tests/test_pyreason_add_operations.py                 | 0
 .../{functional => }/api_tests/test_pyreason_file_loading.py  | 0
 tests/{functional => }/api_tests/test_pyreason_reasoning.py   | 0
 tests/{functional => }/api_tests/test_pyreason_settings.py    | 0
 .../api_tests/test_pyreason_state_management.py               | 0
 tests/{functional => }/api_tests/test_pyreason_validation.py  | 0
 13 files changed, 8 insertions(+), 6 deletions(-)
 rename tests/{functional => }/api_tests/conftest.py (100%)
 rename tests/{functional => }/api_tests/test_pyreason_add_operations.py (100%)
 rename tests/{functional => }/api_tests/test_pyreason_file_loading.py (100%)
 rename tests/{functional => }/api_tests/test_pyreason_reasoning.py (100%)
 rename tests/{functional => }/api_tests/test_pyreason_settings.py (100%)
 rename tests/{functional => }/api_tests/test_pyreason_state_management.py (100%)
 rename tests/{functional => }/api_tests/test_pyreason_validation.py (100%)

diff --git a/.github/workflows/python-package-version-test.yml b/.github/workflows/python-package-version-test.yml
index 6666fe2b..fb4a0da0 100644
--- a/.github/workflows/python-package-version-test.yml
+++ b/.github/workflows/python-package-version-test.yml
@@ -41,7 +41,7 @@ jobs:
         pytest tests/unit/dont_disable_jit --tb=short -q
     - name: Pytest API Tests
       run: |
-        pytest tests/functional/api_tests --tb=short -q
+        pytest tests/api_tests --tb=short -q
     - name: Pytest Functional Tests
       run: |
         pytest tests/functional/ --tb=short -q
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e11fbf52..3426c5ab 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -43,7 +43,7 @@ repos:
   # --- PUSH STAGE: Complete test suite ---
   - id: pytest-unit-api
     name: Run pyreason api unit tests
-    entry: .venv/bin/python -m pytest tests/functional/api_tests --tb=short -q
+    entry: .venv/bin/python -m pytest tests/api_tests --tb=short -q
     language: system
     pass_filenames: false
     stages: [pre-push]
diff --git a/Makefile b/Makefile
index 1e365f6b..10384b0e 100644
--- a/Makefile
+++ b/Makefile
@@ -79,7 +79,7 @@ test-no-coverage: ## Run all tests without coverage collection
 	$(RUN_TESTS) --no-coverage
 
 # Individual test suite targets
-test-api: ## Run only API tests (tests/functional/api_tests)
+test-api: ## Run only API tests (tests/api_tests)
 	@echo "$(BOLD)$(BLUE)Running API tests...$(RESET)"
 	$(RUN_TESTS) --suite api_tests
diff --git a/contributing.md b/contributing.md
index 1a32478c..7a72b0a5 100644
--- a/contributing.md
+++ b/contributing.md
@@ -36,7 +36,7 @@ this command to view
linting results: PyReason uses a unified test runner that handles multiple test configurations automatically. The test suite is organized into four directories: -- **`tests/functional/api_tests/`** - Tests for main pyreason.py API functions (JIT enabled, real pyreason) +- **`tests/api_tests/`** - Tests for main pyreason.py API functions (JIT enabled, real pyreason) - **`tests/unit/disable_jit/`** - Tests for internal interpretation logic (JIT disabled, stubbed environment) - **`tests/unit/dont_disable_jit/`** - Tests for components that benefit from JIT (JIT enabled, lightweight stubs) - **`tests/functional/`** - End-to-end functional tests (JIT enabled, real pyreason, longer running) @@ -99,7 +99,7 @@ You can still run pytest directly on individual directories: ```bash # API tests -pytest tests/functional/api_tests/ -v +pytest tests/api_tests/ -v # JIT disabled tests NUMBA_DISABLE_JIT=1 pytest tests/unit/disable_jit/ -v diff --git a/run_tests.py b/run_tests.py index 18982992..582949c2 100755 --- a/run_tests.py +++ b/run_tests.py @@ -264,6 +264,8 @@ def _create_coverage_config(self): */__pycache__/* */setup.py */interpretation_parallel.py + */yaml_parser.py + [report] exclude_lines = diff --git a/test_config.json b/test_config.json index 260a8077..454a4a49 100644 --- a/test_config.json +++ b/test_config.json @@ -3,7 +3,7 @@ "api_tests": { "name": "API Tests", "description": "Tests for main pyreason.py API functions", - "path": "tests/functional/api_tests", + "path": "tests/api_tests", "coverage_source": "pyreason", "jit_disabled": false, "uses_real_pyreason": true, diff --git a/tests/functional/api_tests/conftest.py b/tests/api_tests/conftest.py similarity index 100% rename from tests/functional/api_tests/conftest.py rename to tests/api_tests/conftest.py diff --git a/tests/functional/api_tests/test_pyreason_add_operations.py b/tests/api_tests/test_pyreason_add_operations.py similarity index 100% rename from tests/functional/api_tests/test_pyreason_add_operations.py rename to tests/api_tests/test_pyreason_add_operations.py diff --git a/tests/functional/api_tests/test_pyreason_file_loading.py b/tests/api_tests/test_pyreason_file_loading.py similarity index 100% rename from tests/functional/api_tests/test_pyreason_file_loading.py rename to tests/api_tests/test_pyreason_file_loading.py diff --git a/tests/functional/api_tests/test_pyreason_reasoning.py b/tests/api_tests/test_pyreason_reasoning.py similarity index 100% rename from tests/functional/api_tests/test_pyreason_reasoning.py rename to tests/api_tests/test_pyreason_reasoning.py diff --git a/tests/functional/api_tests/test_pyreason_settings.py b/tests/api_tests/test_pyreason_settings.py similarity index 100% rename from tests/functional/api_tests/test_pyreason_settings.py rename to tests/api_tests/test_pyreason_settings.py diff --git a/tests/functional/api_tests/test_pyreason_state_management.py b/tests/api_tests/test_pyreason_state_management.py similarity index 100% rename from tests/functional/api_tests/test_pyreason_state_management.py rename to tests/api_tests/test_pyreason_state_management.py diff --git a/tests/functional/api_tests/test_pyreason_validation.py b/tests/api_tests/test_pyreason_validation.py similarity index 100% rename from tests/functional/api_tests/test_pyreason_validation.py rename to tests/api_tests/test_pyreason_validation.py From c3d04942ac725245dc51b45921f627b64bf4f3fa Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Thu, 23 Oct 2025 17:19:09 -0400 Subject: [PATCH 
29/32] Fix typo

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3426c5ab..4b2ef05a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -42,7 +42,7 @@ repos:
 
   # --- PUSH STAGE: Complete test suite ---
   - id: pytest-unit-api
-    name: Run pyreason api unit tests
+    name: Run pyreason api tests
     entry: .venv/bin/python -m pytest tests/api_tests --tb=short -q
     language: system
     pass_filenames: false

From 3540cfb32f3e0b683fd0573ee060b58b3e864023 Mon Sep 17 00:00:00 2001
From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com>
Date: Thu, 23 Oct 2025 17:43:39 -0400
Subject: [PATCH 30/32] Update contributing.md

---
 contributing.md | 36 ++----------------------------------
 1 file changed, 2 insertions(+), 34 deletions(-)

diff --git a/contributing.md b/contributing.md
index 7a72b0a5..f549c7b2 100644
--- a/contributing.md
+++ b/contributing.md
@@ -5,7 +5,6 @@ Install the project requirements and the pre-commit framework:
 
 ```bash
 pip install -r requirements.txt
-pip install pre-commit
 ```
 
 ## Setting up Pre-Commit Hooks
@@ -57,33 +56,12 @@ make test-fast
 make coverage-html
 ```
 
-### Individual Test Suites
-
-```bash
-# API tests (real pyreason, JIT enabled)
-make test-api
-
-# JIT disabled tests (stubbed environment)
-make test-jit
-
-# JIT enabled tests (lightweight stubs)
-make test-no-jit
-
-# Consistency tests
-make test-consistency
-
-# Functional/end-to-end tests
-make test-functional
-```
-
 ### Advanced Options
 
 ```bash
-# Run with parallel execution where possible
-make test-parallel
+# Run with sequential execution
+make test-sequential
 
-# Run without coverage collection (faster)
-python run_tests.py --no-coverage
 
 # Run specific suites
 python run_tests.py --suite api_tests --suite dont_disable_jit
@@ -101,9 +79,6 @@ You can still run pytest directly on individual directories:
 
 ```bash
 # API tests
 pytest tests/api_tests/ -v
 
-# JIT disabled tests
-NUMBA_DISABLE_JIT=1 pytest tests/unit/disable_jit/ -v
-
 # JIT enabled tests
 pytest tests/unit/dont_disable_jit/ -v
@@ -138,10 +113,3 @@ pytest tests/functional/test_hello_world.py -v
 
 # Clean up generated files
 make clean
 ```
-
-**Common Issues:**
-- **Functional tests fail with warnings**: The pytest.ini has been updated to ignore expected warnings from numba and networkx
-- **Tests time out**: Functional tests have longer timeouts (600s) and global timeout is disabled
-- **Import errors**: Ensure pytest and dependencies are installed with `make install-deps`
-
-Running tests locally before committing or pushing helps catch issues early and speeds up code review. The unified test runner ensures consistent behavior across different development environments.
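The `--suite` selection shown above is driven by the suite table in test_config.json (see the `api_tests` entry earlier in this series, with its `name` and `path` fields). As a rough sketch of that dispatch pattern, where the wrapper key, pytest flags, and error handling are all assumptions rather than the actual run_tests.py implementation:

```python
#!/usr/bin/env python3
# Sketch of a suite dispatcher in the style of run_tests.py. Only the
# "api_tests" entry shape ({"name", "path", ...}) is taken from this series;
# everything else here is assumed for illustration.
import argparse
import json
import subprocess
import sys
from pathlib import Path


def main() -> int:
    parser = argparse.ArgumentParser(description="Run PyReason test suites")
    parser.add_argument("--suite", action="append",
                        help="suite key from test_config.json (repeatable)")
    args = parser.parse_args()

    suites = json.loads(Path("test_config.json").read_text())
    # Assumed: the suite entries may sit under a wrapper key such as "test_suites"
    suites = suites.get("test_suites", suites)

    for key in args.suite or list(suites):
        entry = suites[key]
        print(f"Running {entry['name']} ({entry['path']})")
        rc = subprocess.run(
            [sys.executable, "-m", "pytest", entry["path"], "--tb=short", "-q"]
        ).returncode
        if rc != 0:
            return rc  # stop at the first failing suite
    return 0


if __name__ == "__main__":
    sys.exit(main())
```

With a config shaped like the earlier snippet, `python run_tests.py --suite api_tests --suite dont_disable_jit` maps each selected key to one pytest invocation.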
From 11b546947e6f9e8fcb964c428d854c78758b2b80 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sun, 26 Oct 2025 09:18:30 -0700 Subject: [PATCH 31/32] Add scripts to api docs --- docs/source/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index dc1279d2..ec9361d1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -38,9 +38,9 @@ #autoapi_dirs = ['../pyreason/pyreason'] autoapi_root = 'pyreason' -autoapi_ignore = ['*/scripts/*', '*/examples/*', '*/pyreason.pyreason/*'] +autoapi_ignore = ['*/examples/*', '*/pyreason.pyreason/*'] -# Ignore modules in the 'scripts' folder +# Scripts folder is now included in documentation # autoapi_ignore_modules = ['pyreason.scripts'] From c95621d724657e193f2fae770a1e7a67906de0e8 Mon Sep 17 00:00:00 2001 From: ColtonPayne <72282946+ColtonPayne@users.noreply.github.com> Date: Sun, 26 Oct 2025 09:34:07 -0700 Subject: [PATCH 32/32] Delete fp test folder --- tests/fp_tests/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/fp_tests/__init__.py diff --git a/tests/fp_tests/__init__.py b/tests/fp_tests/__init__.py deleted file mode 100644 index e69de29b..00000000
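A closing note on the docs change above: dropping `*/scripts/*` from `autoapi_ignore` is what pulls the pyreason/scripts modules into the generated API reference. A minimal sketch of the relevant conf.py settings, assuming a standard sphinx-autoapi setup (the `extensions` list and the `autoapi_dirs` value are assumptions; the diff itself only shows `autoapi_root` and `autoapi_ignore`):

```python
# docs/source/conf.py: sketch of the AutoAPI settings touched in PATCH 31/32.
# `extensions` and `autoapi_dirs` are assumed; only `autoapi_root` and
# `autoapi_ignore` appear in the diff.
extensions = [
    'autoapi.extension',  # sphinx-autoapi: generate API docs from source
]

autoapi_dirs = ['../../pyreason']  # assumed package location relative to docs/source
autoapi_root = 'pyreason'          # subtree of the docs where pages are generated

# scripts/ is intentionally no longer ignored, so pyreason/scripts/** is documented
autoapi_ignore = ['*/examples/*', '*/pyreason.pyreason/*']
```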