diff --git a/backend/database/knowledge_graph.py b/backend/database/knowledge_graph.py
index 87b5d7766a..39cfe3101a 100644
--- a/backend/database/knowledge_graph.py
+++ b/backend/database/knowledge_graph.py
@@ -4,6 +4,7 @@
 from google.cloud import firestore
 from google.cloud.firestore_v1 import FieldFilter
+from google.api_core import exceptions as google_exceptions
 
 from ._client import db
 
@@ -244,3 +245,71 @@ def _batch_delete(coll_ref):
 
     edges_ref = user_ref.collection(knowledge_edges_collection)
     _batch_delete(edges_ref)
+
+
+def cleanup_for_memory(uid: str, memory_id: str):
+    """
+    Remove ``memory_id`` from every node and edge in the user's knowledge graph.
+
+    Nodes and edges whose only remaining association was ``memory_id`` are
+    deleted, and edges whose source or target node gets deleted are removed
+    as well. Firestore transactions require every read to happen before any
+    write, so all queries run first and writes are issued afterwards.
+    """
+    try:
+        user_ref = db.collection(users_collection).document(uid)
+        nodes_coll = user_ref.collection(knowledge_nodes_collection)
+        edges_coll = user_ref.collection(knowledge_edges_collection)
+
+        @firestore.transactional
+        def update_in_transaction(transaction):
+            # --- Reads (must all precede writes in a Firestore transaction) ---
+            nodes_query = nodes_coll.where(filter=FieldFilter('memory_ids', 'array_contains', memory_id))
+            edges_query = edges_coll.where(filter=FieldFilter('memory_ids', 'array_contains', memory_id))
+            nodes_docs = list(nodes_query.stream(transaction=transaction))
+            edges_docs = list(edges_query.stream(transaction=transaction))
+
+            # Decide which nodes become orphaned *before* issuing any write.
+            # Filtering (rather than checking len == 1) also handles duplicate
+            # occurrences of memory_id, which ArrayRemove strips entirely.
+            deleted_node_ids = set()
+            for doc in nodes_docs:
+                remaining = [m for m in doc.to_dict().get('memory_ids', []) if m != memory_id]
+                if not remaining:
+                    deleted_node_ids.add(doc.id)
+
+            # Read edges attached to soon-to-be-deleted nodes. Firestore 'in'
+            # filters accept at most 10 values, so chunk the node ids.
+            orphan_edge_refs = []
+            node_ids = list(deleted_node_ids)
+            for i in range(0, len(node_ids), 10):
+                chunk = node_ids[i:i + 10]
+                for field in ('source_id', 'target_id'):
+                    q = edges_coll.where(filter=FieldFilter(field, 'in', chunk))
+                    orphan_edge_refs.extend(d.reference for d in q.stream(transaction=transaction))
+
+            # --- Writes ---
+            for doc in nodes_docs:
+                if doc.id in deleted_node_ids:
+                    transaction.delete(doc.reference)
+                else:
+                    transaction.update(doc.reference, {'memory_ids': firestore.ArrayRemove([memory_id])})
+
+            for doc in edges_docs:
+                remaining = [m for m in doc.to_dict().get('memory_ids', []) if m != memory_id]
+                if not remaining:
+                    transaction.delete(doc.reference)
+                else:
+                    transaction.update(doc.reference, {'memory_ids': firestore.ArrayRemove([memory_id])})
+
+            for ref in orphan_edge_refs:
+                transaction.delete(ref)
+
+        update_in_transaction(db.transaction())
+        print(f"Knowledge graph cleanup complete for memory_id: {memory_id}")
+
+    except google_exceptions.GoogleAPICallError as e:
+        print(f"ERROR: Firestore API error during KG cleanup for memory_id {memory_id}: {e}")
+        raise  # Re-raise to indicate a critical failure
+
+    except Exception as e:
+        print(f"ERROR: Unexpected error during KG cleanup for memory_id {memory_id}: {e}")
+        raise  # Re-raise to indicate a critical failure
diff --git a/backend/database/memories.py b/backend/database/memories.py
index 5d31e1ce8a..dcab2d4e09 100644
--- a/backend/database/memories.py
+++ b/backend/database/memories.py
@@ -7,6 +7,7 @@
 from ._client import db
 from database import users as users_db
+from database import knowledge_graph as kg_db
 from utils import encryption
 from .helpers import set_data_protection_level, prepare_for_write, prepare_for_read
 
@@ -223,16 +224,24 @@ def delete_memory(uid: str, memory_id: str):
     memories_ref = user_ref.collection(memories_collection)
     memory_ref = memories_ref.document(memory_id)
     memory_ref.delete()
+
+    # Trigger cascading cleanup for the knowledge graph
+    kg_db.cleanup_for_memory(uid, memory_id)
 
 
 def delete_all_memories(uid: str):
     user_ref = db.collection(users_collection).document(uid)
     memories_ref = user_ref.collection(memories_collection)
+
+    # Efficiently delete all documents in the collection
     batch = db.batch()
     for doc in memories_ref.stream():
         batch.delete(doc.reference)
     batch.commit()
 
+    # Trigger a single, efficient cleanup of the entire knowledge graph
+    kg_db.delete_knowledge_graph(uid)
+
 
 def delete_memories_for_conversation(uid: str, memory_id: str):
     batch = db.batch()