From 48697db4758ccf8fc619e74ba4873f2673677461 Mon Sep 17 00:00:00 2001
From: dugubuyan
Date: Wed, 11 Mar 2026 10:49:26 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(memory):=20deduplicate=20epi?=
 =?UTF-8?q?sodic/event=5Flog=20on=20re-memorize=20and=20add=20foresight=20?=
 =?UTF-8?q?expiry=20cleanup?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/biz_layer/mem_cleanup.py                   | 82 +++++++++++++++++++
 src/biz_layer/mem_memorize.py                  | 37 +++++++++
 .../episodic_memory_raw_repository.py          | 40 +++++++++
 .../repository/foresight_record_repository.py  | 33 ++++++++
 .../episodic_memory_es_repository.py           | 44 ++++++++++
 .../episodic_memory_milvus_repository.py       | 39 +++++++++
 .../repository/event_log_es_repository.py      | 46 +++++++++++
 .../repository/foresight_es_repository.py      | 46 +++++++++++
 8 files changed, 367 insertions(+)
 create mode 100644 src/biz_layer/mem_cleanup.py

diff --git a/src/biz_layer/mem_cleanup.py b/src/biz_layer/mem_cleanup.py
new file mode 100644
index 00000000..052f0f51
--- /dev/null
+++ b/src/biz_layer/mem_cleanup.py
@@ -0,0 +1,82 @@
+"""
+Memory cleanup utilities.
+
+Provides scheduled cleanup tasks for expired memory records.
+Currently handles foresight expiry: records whose validity window has passed
+are removed from all three stores (Milvus → Elasticsearch → MongoDB) in that
+order to minimise the window where a record is searchable but absent from the
+primary store.
+"""
+
+from datetime import datetime
+from typing import Dict
+
+from common_utils.datetime_utils import get_now_with_timezone
+from core.di import get_bean_by_type
+from core.observation.logger import get_logger
+from infra_layer.adapters.out.persistence.repository.foresight_record_repository import (
+    ForesightRecordRawRepository,
+)
+from infra_layer.adapters.out.search.repository.foresight_es_repository import (
+    ForesightEsRepository,
+)
+from infra_layer.adapters.out.search.repository.foresight_milvus_repository import (
+    ForesightMilvusRepository,
+)
+
+logger = get_logger(__name__)
+
+
+async def cleanup_expired_foresights(
+    before: datetime | None = None,
+) -> Dict[str, int]:
+    """
+    Delete foresight records that have passed their validity end time.
+
+    Deletion order: Milvus → Elasticsearch → MongoDB.
+    This ensures that even if a later step fails, the record is no longer
+    returned by vector or keyword search.
+
+    Args:
+        before: Treat records with end_time < before as expired.
+            Defaults to the current time when not provided.
+
+    Returns:
+        Dict with keys ``milvus``, ``es``, ``mongo`` and the number of
+        records deleted from each store.
+    """
+    if before is None:
+        before = get_now_with_timezone()
+
+    stats: Dict[str, int] = {"milvus": 0, "es": 0, "mongo": 0}
+
+    foresight_milvus_repo = get_bean_by_type(ForesightMilvusRepository)
+    foresight_es_repo = get_bean_by_type(ForesightEsRepository)
+    foresight_mongo_repo = get_bean_by_type(ForesightRecordRawRepository)
+
+    # Step 1: remove from Milvus (vector search)
+    try:
+        stats["milvus"] = await foresight_milvus_repo.delete_by_filters(end_time=before)
+    except Exception as exc:
+        logger.error("Failed to delete expired foresights from Milvus: %s", exc)
+
+    # Step 2: remove from Elasticsearch (keyword search)
+    try:
+        stats["es"] = await foresight_es_repo.delete_expired(before=before)
+    except Exception as exc:
+        logger.error("Failed to delete expired foresights from ES: %s", exc)
+
+    # Step 3: remove from MongoDB (primary store)
+    try:
+        stats["mongo"] = await foresight_mongo_repo.delete_expired(before=before)
+    except Exception as exc:
+        logger.error("Failed to delete expired foresights from MongoDB: %s", exc)
+
+    logger.info(
+        "✅ Expired foresight cleanup complete (before=%s): milvus=%d es=%d mongo=%d",
+        before.isoformat(),
+        stats["milvus"],
+        stats["es"],
+        stats["mongo"],
+    )
+    return stats
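The patch adds the cleanup coroutine but not its trigger; scheduling is left to
the application. A minimal sketch of one way to run it periodically with plain
asyncio (the hourly interval and the run_cleanup_loop name are illustrative,
not part of this patch):

    import asyncio

    from biz_layer.mem_cleanup import cleanup_expired_foresights

    async def run_cleanup_loop(interval_seconds: int = 3600) -> None:
        """Invoke foresight expiry cleanup once per interval, forever."""
        while True:
            # `before` defaults to the current time inside the coroutine;
            # the returned dict carries per-store deletion counts.
            stats = await cleanup_expired_foresights()
            print(f"foresight cleanup stats: {stats}")
            await asyncio.sleep(interval_seconds)

    # e.g. asyncio.create_task(run_cleanup_loop()) at application startup

Because each store is wrapped in its own try/except, one failing backend does
not stop the others; the per-store counts in the result make partial failures
visible to the caller.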
diff --git a/src/biz_layer/mem_memorize.py b/src/biz_layer/mem_memorize.py
index e951a13c..29e10345 100644
--- a/src/biz_layer/mem_memorize.py
+++ b/src/biz_layer/mem_memorize.py
@@ -79,6 +79,12 @@
 from infra_layer.adapters.out.search.repository.episodic_memory_es_repository import (
     EpisodicMemoryEsRepository,
 )
+from infra_layer.adapters.out.search.repository.event_log_es_repository import (
+    EventLogEsRepository,
+)
+from infra_layer.adapters.out.search.repository.event_log_milvus_repository import (
+    EventLogMilvusRepository,
+)
 from biz_layer.mem_sync import MemorySyncService
 from core.context.context import get_current_app_info
 
@@ -1122,6 +1128,19 @@ async def save_memory_docs(
 
     saved_episodic: List[Any] = []
     for doc in episodic_docs:
+        # Deduplicate: remove any existing records from the same source MemCell
+        # for this specific user before inserting the new one.
+        # Key is (parent_id, user_id) because one MemCell produces one episode
+        # per participant (personal) plus one group episode (user_id=None/"").
+        parent_id = getattr(doc, "parent_id", None)
+        user_id = getattr(doc, "user_id", None)
+        if parent_id:
+            await asyncio.gather(
+                episodic_repo.delete_by_parent_id(parent_id, user_id=user_id),
+                episodic_es_repo.delete_by_parent_id(parent_id, user_id=user_id),
+                episodic_milvus_repo.delete_by_parent_id(parent_id, user_id=user_id),
+            )
+
         saved_doc = await episodic_repo.append_episodic_memory(doc)
         saved_episodic.append(saved_doc)
 
@@ -1158,6 +1177,24 @@
     event_log_docs = grouped_docs.get(MemoryType.EVENT_LOG, [])
     if event_log_docs:
         event_log_repo = get_bean_by_type(EventLogRecordRawRepository)
+        event_log_es_repo = get_bean_by_type(EventLogEsRepository)
+        event_log_milvus_repo = get_bean_by_type(EventLogMilvusRepository)
+
+        # Deduplicate: collect unique (parent_id, user_id) pairs and delete old records
+        # before batch-inserting the new ones.
+        seen_parent_keys: set = set()
+        for doc in event_log_docs:
+            parent_id = getattr(doc, "parent_id", None)
+            user_id = getattr(doc, "user_id", None)
+            key = (parent_id, user_id)
+            if parent_id and key not in seen_parent_keys:
+                seen_parent_keys.add(key)
+                await asyncio.gather(
+                    event_log_repo.delete_by_parent_id(parent_id),
+                    event_log_es_repo.delete_by_parent_id(parent_id, user_id=user_id),
+                    event_log_milvus_repo.delete_by_parent_id(parent_id),
+                )
+
         saved_event_logs = await event_log_repo.create_batch(event_log_docs)
         saved_result[MemoryType.EVENT_LOG] = saved_event_logs
 
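One subtlety in the event-log pass: the seen-set is keyed on (parent_id,
user_id), but only the ES delete is user-scoped; the Mongo and Milvus calls
take just parent_id, so for a parent shared by several users they delete
parent-wide on the first key and are no-ops on later ones. The outcome is
still correct because every delete completes before create_batch inserts the
new documents. A standalone sketch of the keying behaviour (toy dicts, not the
real document models):

    def unique_parent_keys(docs: list) -> list:
        """Collect (parent_id, user_id) pairs in first-seen order, skipping docs without a parent."""
        seen = set()
        keys = []
        for doc in docs:
            key = (doc.get("parent_id"), doc.get("user_id"))
            if key[0] and key not in seen:
                seen.add(key)
                keys.append(key)
        return keys

    assert unique_parent_keys([
        {"parent_id": "m1", "user_id": "u1"},
        {"parent_id": "m1", "user_id": "u1"},   # duplicate -> ignored
        {"parent_id": "m1", "user_id": None},   # group record -> second key
        {"user_id": "u2"},                      # no parent -> skipped
    ]) == [("m1", "u1"), ("m1", None)]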
diff --git a/src/infra_layer/adapters/out/persistence/repository/episodic_memory_raw_repository.py b/src/infra_layer/adapters/out/persistence/repository/episodic_memory_raw_repository.py
index 7240c403..00cdc54d 100644
--- a/src/infra_layer/adapters/out/persistence/repository/episodic_memory_raw_repository.py
+++ b/src/infra_layer/adapters/out/persistence/repository/episodic_memory_raw_repository.py
@@ -292,6 +292,46 @@ async def delete_by_event_id(
             )
         return False
 
+    async def delete_by_parent_id(
+        self,
+        parent_id: str,
+        user_id: Optional[str] = None,
+        session: Optional[AsyncClientSession] = None,
+    ) -> int:
+        """
+        Delete episodic memories by parent MemCell ID, optionally scoped to a user.
+
+        Args:
+            parent_id: Source MemCell ID (stored as parent_id field)
+            user_id: If provided, only delete records belonging to this user.
+                Pass None to delete all records (group + personal) for the parent.
+            session: Optional MongoDB session for transaction support
+
+        Returns:
+            Number of deleted records
+        """
+        try:
+            query_filter: Dict[str, Any] = {"parent_id": parent_id}
+            if user_id is not None:
+                query_filter["user_id"] = user_id
+
+            result = await self.model.find(query_filter, session=session).delete()
+            count = result.deleted_count if result else 0
+            logger.info(
+                "✅ Deleted episodic memories by parent_id=%s user_id=%s: %d records",
+                parent_id,
+                user_id,
+                count,
+            )
+            return count
+        except Exception as e:
+            logger.error(
+                "❌ Failed to delete episodic memories by parent_id=%s: %s",
+                parent_id,
+                e,
+            )
+            return 0
+
     async def delete_by_user_id(
         self, user_id: str, session: Optional[AsyncClientSession] = None
     ) -> int:

diff --git a/src/infra_layer/adapters/out/persistence/repository/foresight_record_repository.py b/src/infra_layer/adapters/out/persistence/repository/foresight_record_repository.py
index 220affd4..61ec06e2 100644
--- a/src/infra_layer/adapters/out/persistence/repository/foresight_record_repository.py
+++ b/src/infra_layer/adapters/out/persistence/repository/foresight_record_repository.py
@@ -266,6 +266,39 @@ async def find_by_filters(
             logger.error("❌ Failed to retrieve foresights: %s", e)
             return []
 
+    async def delete_expired(
+        self,
+        before: datetime,
+        session: Optional[AsyncClientSession] = None,
+    ) -> int:
+        """
+        Delete foresight records whose validity period ended before the given time.
+
+        Args:
+            before: Delete foresights with end_time strictly before this datetime.
+                end_time is stored as an ISO date string (YYYY-MM-DD).
+            session: Optional MongoDB session for transaction support
+
+        Returns:
+            Number of deleted records
+        """
+        try:
+            from common_utils.datetime_utils import to_date_str
+
+            before_str = to_date_str(before)
+            query_filter: Dict[str, Any] = {"end_time": {"$lt": before_str}}
+            result = await self.model.find(query_filter, session=session).delete()
+            count = result.deleted_count if result else 0
+            logger.info(
+                "✅ Deleted expired foresights from MongoDB (before=%s): %d records",
+                before_str,
+                count,
+            )
+            return count
+        except Exception as e:
+            logger.error("❌ Failed to delete expired foresights from MongoDB: %s", e)
+            return 0
+
     async def delete_by_id(
         self, memory_id: str, session: Optional[AsyncClientSession] = None
     ) -> bool:
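For reference, the MongoDB filters the two helpers above end up issuing
(values illustrative). Note that $lt on end_time compares strings, which is
safe for zero-padded YYYY-MM-DD values because they sort lexicographically in
date order:

    # delete_by_parent_id("mem-cell-1", user_id="u1") -> user-scoped delete
    {"parent_id": "mem-cell-1", "user_id": "u1"}

    # delete_by_parent_id("mem-cell-1") -> group + personal records
    {"parent_id": "mem-cell-1"}

    # delete_expired(before=<2026-03-11, local tz>) -> to_date_str() output
    {"end_time": {"$lt": "2026-03-11"}}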
diff --git a/src/infra_layer/adapters/out/search/repository/episodic_memory_es_repository.py b/src/infra_layer/adapters/out/search/repository/episodic_memory_es_repository.py
index 999fc6b0..22b67a56 100644
--- a/src/infra_layer/adapters/out/search/repository/episodic_memory_es_repository.py
+++ b/src/infra_layer/adapters/out/search/repository/episodic_memory_es_repository.py
@@ -494,6 +494,50 @@ async def append_episodic_memory(
 
     # ==================== Deletion functionality ====================
 
+    async def delete_by_parent_id(
+        self,
+        parent_id: str,
+        user_id: Optional[str] = None,
+        refresh: bool = False,
+    ) -> int:
+        """
+        Delete episodic memory documents by parent MemCell ID.
+
+        Args:
+            parent_id: Source MemCell ID
+            user_id: If provided, only delete records for this user.
+            refresh: Whether to refresh the index immediately
+
+        Returns:
+            Number of deleted documents
+        """
+        try:
+            filter_queries: List[Dict[str, Any]] = [{"term": {"parent_id": parent_id}}]
+            if user_id is not None:
+                filter_queries.append({"term": {"user_id": user_id}})
+
+            delete_query = {"bool": {"must": filter_queries}}
+            client = await self.get_client()
+            index_name = self.get_index_name()
+            response = await client.delete_by_query(
+                index=index_name,
+                body={"query": delete_query},
+                refresh=refresh,
+            )
+            deleted_count = response.get("deleted", 0)
+            logger.debug(
+                "✅ Deleted episodic memory by parent_id=%s user_id=%s: %d records",
+                parent_id,
+                user_id,
+                deleted_count,
+            )
+            return deleted_count
+        except Exception as e:
+            logger.error(
+                "❌ Failed to delete episodic memory by parent_id=%s: %s", parent_id, e
+            )
+            raise
+
     async def delete_by_event_id(self, event_id: str, refresh: bool = False) -> bool:
         """
         Delete episodic memory document by event_id

diff --git a/src/infra_layer/adapters/out/search/repository/episodic_memory_milvus_repository.py b/src/infra_layer/adapters/out/search/repository/episodic_memory_milvus_repository.py
index 4aad056b..55609a45 100644
--- a/src/infra_layer/adapters/out/search/repository/episodic_memory_milvus_repository.py
+++ b/src/infra_layer/adapters/out/search/repository/episodic_memory_milvus_repository.py
@@ -308,6 +308,45 @@ async def vector_search(
 
     # ==================== Deletion Functionality ====================
 
+    async def delete_by_parent_id(
+        self,
+        parent_id: str,
+        user_id: Optional[str] = None,
+    ) -> int:
+        """
+        Delete episodic memory vectors by parent MemCell ID.
+
+        Args:
+            parent_id: Source MemCell ID
+            user_id: If provided, only delete records for this user.
+
+        Returns:
+            Number of deleted records
+        """
+        try:
+            expr = f'parent_id == "{parent_id}"'
+            if user_id is not None:
+                expr += f' and user_id == "{user_id}"'
+
+            results = await self.collection.query(expr=expr, output_fields=["id"])
+            delete_count = len(results)
+            if delete_count > 0:
+                await self.collection.delete(expr)
+            logger.debug(
+                "✅ Deleted episodic memory vectors by parent_id=%s user_id=%s: %d records",
+                parent_id,
+                user_id,
+                delete_count,
+            )
+            return delete_count
+        except Exception as e:
+            logger.error(
+                "❌ Failed to delete episodic memory vectors by parent_id=%s: %s",
+                parent_id,
+                e,
+            )
+            raise
+
     async def delete_by_event_id(self, event_id: str) -> bool:
         """
         Delete episodic memory document by event_id
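Both helpers above express the same (parent_id, user_id) scope in each engine's
native filter syntax. With parent "mem-cell-1" and user "u1" (illustrative
values), the ES helper sends this delete_by_query body:

    {
        "query": {
            "bool": {
                "must": [
                    {"term": {"parent_id": "mem-cell-1"}},
                    {"term": {"user_id": "u1"}}
                ]
            }
        }
    }

and the Milvus helper builds the boolean expression:

    parent_id == "mem-cell-1" and user_id == "u1"

The term queries assume exact (keyword-style) matching on these fields, and
the Milvus helper runs a query for matching ids first purely so it can report
a deletion count. One caveat visible in the code: an id containing a double
quote would break the f-string expression, which is fine only as long as
parent/user ids never contain quotes.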
diff --git a/src/infra_layer/adapters/out/search/repository/event_log_es_repository.py b/src/infra_layer/adapters/out/search/repository/event_log_es_repository.py
index 39932191..7b6bd0b1 100644
--- a/src/infra_layer/adapters/out/search/repository/event_log_es_repository.py
+++ b/src/infra_layer/adapters/out/search/repository/event_log_es_repository.py
@@ -354,3 +354,49 @@ async def multi_search(
                 e,
             )
             raise
+
+    # ==================== Deletion functionality ====================
+
+    async def delete_by_parent_id(
+        self,
+        parent_id: str,
+        user_id: Optional[str] = None,
+        refresh: bool = False,
+    ) -> int:
+        """
+        Delete event log documents by parent memory ID.
+
+        Args:
+            parent_id: Parent memory ID (MemCell or Episode ID)
+            user_id: If provided, only delete records for this user.
+            refresh: Whether to refresh the index immediately
+
+        Returns:
+            Number of deleted documents
+        """
+        try:
+            filter_queries: List[Dict[str, Any]] = [{"term": {"parent_id": parent_id}}]
+            if user_id is not None:
+                filter_queries.append({"term": {"user_id": user_id}})
+
+            delete_query = {"bool": {"must": filter_queries}}
+            client = await self.get_client()
+            index_name = self.get_index_name()
+            response = await client.delete_by_query(
+                index=index_name,
+                body={"query": delete_query},
+                refresh=refresh,
+            )
+            deleted_count = response.get("deleted", 0)
+            logger.debug(
+                "✅ Deleted event logs by parent_id=%s user_id=%s: %d records",
+                parent_id,
+                user_id,
+                deleted_count,
+            )
+            return deleted_count
+        except Exception as e:
+            logger.error(
+                "❌ Failed to delete event logs by parent_id=%s: %s", parent_id, e
+            )
+            raise
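The refresh flag matters for read-after-delete consistency: with the default
refresh=False, deleted event logs can keep surfacing in searches until the next
index refresh. A hypothetical call site that needs the deletion visible
immediately (e.g. a test) would pass it explicitly:

    deleted = await event_log_es_repo.delete_by_parent_id(
        "mem-cell-1",          # illustrative parent id
        user_id="u1",
        refresh=True,          # force an index refresh before returning
    )
    assert deleted >= 0

The save_memory_docs path leaves refresh at its default, trading immediate
visibility for cheaper bulk writes.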
"2024-03-01T00:00:00" + before_str = before.isoformat() + delete_query = { + "range": { + "extend.end_time": {"lt": before_str} + } + } + client = await self.get_client() + index_name = self.get_index_name() + response = await client.delete_by_query( + index=index_name, + body={"query": delete_query}, + refresh=refresh, + ) + deleted_count = response.get("deleted", 0) + logger.info( + "✅ Deleted expired foresights from ES (before=%s): %d records", + before_str, + deleted_count, + ) + return deleted_count + except Exception as e: + logger.error("❌ Failed to delete expired foresights from ES: %s", e) + raise