diff --git a/py/trace.py b/py/trace.py index 83f55502..cf674aa2 100644 --- a/py/trace.py +++ b/py/trace.py @@ -10,6 +10,8 @@ def trace(cls, start_node_id, prompt): class_type = prompt[start_node_id]["class_type"] Q = deque() Q.append((start_node_id, 0)) + visited = set() # Keep track of visited nodes + visited.add(start_node_id) trace_tree = {start_node_id: (0, class_type)} while len(Q) > 0: current_node_id, distance = Q.popleft() @@ -17,9 +19,11 @@ def trace(cls, start_node_id, prompt): for value in input_fields.values(): if isinstance(value, list): nid = value[0] - class_type = prompt[nid]["class_type"] - trace_tree[nid] = (distance + 1, class_type) - Q.append((nid, distance + 1)) + if nid not in visited: # Ensure the node is not visited + class_type = prompt[nid]["class_type"] + trace_tree[nid] = (distance + 1, class_type) + Q.append((nid, distance + 1)) + visited.add(nid) # Mark the node as visited return trace_tree @classmethod