Skip to content
85 changes: 85 additions & 0 deletions backend/database/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from google.cloud import firestore
from google.cloud.firestore_v1 import FieldFilter
from google.api_core import exceptions as google_exceptions

from ._client import db

Expand Down Expand Up @@ -244,3 +245,87 @@ def _batch_delete(coll_ref):

edges_ref = user_ref.collection(knowledge_edges_collection)
_batch_delete(edges_ref)


def cleanup_for_memory(uid: str, memory_id: str):
    """
    Remove *memory_id* from every node and edge of the user's knowledge graph.

    Nodes and edges whose only remaining memory reference is *memory_id* are
    deleted outright, and edges whose source or target node gets deleted in
    the same pass are removed as well. All reads and writes happen inside a
    single Firestore transaction so the graph never observes a partial
    cleanup.

    Args:
        uid: Firestore user document id.
        memory_id: The memory whose references should be scrubbed.

    Raises:
        google_exceptions.GoogleAPICallError: on Firestore API failures.
        ValueError: on data validation errors.
        Exception: any other unexpected failure is logged and re-raised.
    """
    try:
        user_ref = db.collection(users_collection).document(uid)
        nodes_ref = user_ref.collection(knowledge_nodes_collection)
        edges_ref = user_ref.collection(knowledge_edges_collection)

        @firestore.transactional
        def update_in_transaction(transaction):
            # --- Phase 1: ALL reads first. ---------------------------------
            # Firestore transactions forbid reads after the first buffered
            # write, and a query is only part of the transaction when streamed
            # with the transaction handle.
            nodes_query = nodes_ref.where(filter=FieldFilter('memory_ids', 'array_contains', memory_id))
            edges_query = edges_ref.where(filter=FieldFilter('memory_ids', 'array_contains', memory_id))
            nodes_docs = list(nodes_query.stream(transaction=transaction))
            edges_docs = list(edges_query.stream(transaction=transaction))

            # Nodes whose only memory reference is memory_id will be deleted.
            # (The query guarantees memory_id is present, so len <= 1 means
            # it is the sole remaining reference.)
            deleted_node_ids = {
                doc.id for doc in nodes_docs
                if len(doc.to_dict().get('memory_ids', [])) <= 1
            }

            # Read edges orphaned by those node deletions — still before any
            # write. Firestore 'in' filters accept at most 10 values, so the
            # node ids are chunked.
            directly_handled_edge_ids = {doc.id for doc in edges_docs}
            orphan_edge_refs = {}  # id -> reference, dedupes source/target overlap
            node_id_list = list(deleted_node_ids)
            for i in range(0, len(node_id_list), 10):
                chunk = node_id_list[i:i + 10]
                for endpoint_field in ('source_id', 'target_id'):
                    orphan_query = edges_ref.where(filter=FieldFilter(endpoint_field, 'in', chunk))
                    for doc in orphan_query.stream(transaction=transaction):
                        # Skip edges already scheduled for update/delete above;
                        # a transaction must not write the same doc twice.
                        if doc.id not in directly_handled_edge_ids:
                            orphan_edge_refs[doc.id] = doc.reference

            # --- Phase 2: ALL writes. --------------------------------------
            for doc in nodes_docs:
                if doc.id in deleted_node_ids:
                    transaction.delete(doc.reference)
                else:
                    # Node is still backed by other memories; just detach this one.
                    transaction.update(doc.reference, {'memory_ids': firestore.ArrayRemove([memory_id])})

            for doc in edges_docs:
                if len(doc.to_dict().get('memory_ids', [])) <= 1:
                    # This memory was the edge's only support; remove the edge.
                    transaction.delete(doc.reference)
                else:
                    transaction.update(doc.reference, {'memory_ids': firestore.ArrayRemove([memory_id])})

            for edge_ref in orphan_edge_refs.values():
                transaction.delete(edge_ref)

        update_in_transaction(db.transaction())
        print(f"Knowledge graph transaction complete for memory_id: {memory_id}")

    except google_exceptions.GoogleAPICallError as e:
        print(f"ERROR: Firestore API error during KG cleanup for memory_id {memory_id}: {e}")
        raise  # Re-raise to indicate a critical failure

    except ValueError as e:
        print(f"ERROR: Data validation error during KG cleanup for memory_id {memory_id}: {e}")
        raise  # Re-raise to indicate a critical failure

    except Exception as e:  # Catch any other unexpected errors
        print(f"ERROR: Unexpected error during KG cleanup for memory_id {memory_id}: {e}")
        raise  # Re-raise to indicate a critical failure
9 changes: 9 additions & 0 deletions backend/database/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ._client import db
from database import users as users_db
from database import knowledge_graph as kg_db
from utils import encryption
from .helpers import set_data_protection_level, prepare_for_write, prepare_for_read

Expand Down Expand Up @@ -223,16 +224,24 @@ def delete_memory(uid: str, memory_id: str):
memories_ref = user_ref.collection(memories_collection)
memory_ref = memories_ref.document(memory_id)
memory_ref.delete()

# Trigger cascading cleanup for the knowledge graph
kg_db.cleanup_for_memory(uid, memory_id)


def delete_all_memories(uid: str):
    """
    Delete every memory document for *uid*, then wipe the knowledge graph.

    Firestore commits are limited to 500 writes, so the deletes are flushed
    in chunks rather than a single batch (which would fail for users with
    more than 500 memories).
    """
    user_ref = db.collection(users_collection).document(uid)
    memories_ref = user_ref.collection(memories_collection)

    # Delete in chunks of 450 to stay safely under Firestore's 500-write
    # per-commit limit.
    batch = db.batch()
    pending = 0
    for doc in memories_ref.stream():
        batch.delete(doc.reference)
        pending += 1
        if pending >= 450:
            batch.commit()
            batch = db.batch()
            pending = 0
    if pending:
        batch.commit()

    # Trigger a single, efficient cleanup of the entire knowledge graph.
    kg_db.delete_knowledge_graph(uid)


def delete_memories_for_conversation(uid: str, memory_id: str):
batch = db.batch()
Expand Down