Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/biz_layer/mem_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
Memory cleanup utilities.

Provides scheduled cleanup tasks for expired memory records.
Currently handles foresight expiry: records whose validity window has passed
are removed from all three stores (Milvus → Elasticsearch → MongoDB) in that
order to minimise the window where a record is searchable but absent from the
primary store.
"""

from datetime import datetime
from typing import Dict

from common_utils.datetime_utils import get_now_with_timezone
from core.di import get_bean_by_type
from core.observation.logger import get_logger
from infra_layer.adapters.out.persistence.repository.foresight_record_repository import (
ForesightRecordRawRepository,
)
from infra_layer.adapters.out.search.repository.foresight_es_repository import (
ForesightEsRepository,
)
from infra_layer.adapters.out.search.repository.foresight_milvus_repository import (
ForesightMilvusRepository,
)

logger = get_logger(__name__)


async def cleanup_expired_foresights(
    before: datetime | None = None,
) -> Dict[str, int]:
    """
    Delete foresight records that have passed their validity end time.

    Deletion order: Milvus → Elasticsearch → MongoDB. The primary store
    (MongoDB) is only touched after both search stores have been purged.
    If an earlier step fails, the cascade stops: this guarantees a record
    is never removed from the primary store while it is still discoverable
    via vector or keyword search.

    Args:
        before: Treat records with end_time < before as expired.
            Defaults to the current time when not provided.

    Returns:
        Dict with keys ``milvus``, ``es``, ``mongo`` and the number of
        records deleted from each store. Stores skipped because an earlier
        step failed report 0.
    """
    if before is None:
        before = get_now_with_timezone()

    stats: Dict[str, int] = {"milvus": 0, "es": 0, "mongo": 0}

    foresight_milvus_repo = get_bean_by_type(ForesightMilvusRepository)
    foresight_es_repo = get_bean_by_type(ForesightEsRepository)
    foresight_mongo_repo = get_bean_by_type(ForesightRecordRawRepository)

    # Step 1: remove from Milvus (vector search)
    try:
        stats["milvus"] = await foresight_milvus_repo.delete_by_filters(end_time=before)
    except Exception as exc:
        logger.error("Failed to delete expired foresights from Milvus: %s", exc)
        # Abort: deleting from ES/MongoDB now would leave records that are
        # still vector-searchable but missing from the primary store.
        return stats

    # Step 2: remove from Elasticsearch (keyword search)
    try:
        stats["es"] = await foresight_es_repo.delete_expired(before=before)
    except Exception as exc:
        logger.error("Failed to delete expired foresights from ES: %s", exc)
        # Abort for the same reason: keep the primary store as the superset
        # of everything still searchable; a later run will retry cleanly.
        return stats

    # Step 3: remove from MongoDB (primary store)
    try:
        stats["mongo"] = await foresight_mongo_repo.delete_expired(before=before)
    except Exception as exc:
        logger.error("Failed to delete expired foresights from MongoDB: %s", exc)

    logger.info(
        "✅ Expired foresight cleanup complete (before=%s): milvus=%d es=%d mongo=%d",
        before.isoformat(),
        stats["milvus"],
        stats["es"],
        stats["mongo"],
    )
    return stats
37 changes: 37 additions & 0 deletions src/biz_layer/mem_memorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@
from infra_layer.adapters.out.search.repository.episodic_memory_es_repository import (
EpisodicMemoryEsRepository,
)
from infra_layer.adapters.out.search.repository.event_log_es_repository import (
EventLogEsRepository,
)
from infra_layer.adapters.out.search.repository.event_log_milvus_repository import (
EventLogMilvusRepository,
)
from biz_layer.mem_sync import MemorySyncService
from core.context.context import get_current_app_info

Expand Down Expand Up @@ -1122,6 +1128,19 @@ async def save_memory_docs(
saved_episodic: List[Any] = []

for doc in episodic_docs:
# Deduplicate: remove any existing records from the same source MemCell
# for this specific user before inserting the new one.
# Key is (parent_id, user_id) because one MemCell produces one episode
# per participant (personal) plus one group episode (user_id=None/"").
parent_id = getattr(doc, "parent_id", None)
user_id = getattr(doc, "user_id", None)
if parent_id:
await asyncio.gather(
episodic_repo.delete_by_parent_id(parent_id, user_id=user_id),
episodic_es_repo.delete_by_parent_id(parent_id, user_id=user_id),
episodic_milvus_repo.delete_by_parent_id(parent_id, user_id=user_id),
)

saved_doc = await episodic_repo.append_episodic_memory(doc)
saved_episodic.append(saved_doc)

Expand Down Expand Up @@ -1158,6 +1177,24 @@ async def save_memory_docs(
event_log_docs = grouped_docs.get(MemoryType.EVENT_LOG, [])
if event_log_docs:
event_log_repo = get_bean_by_type(EventLogRecordRawRepository)
event_log_es_repo = get_bean_by_type(EventLogEsRepository)
event_log_milvus_repo = get_bean_by_type(EventLogMilvusRepository)

# Deduplicate: collect unique (parent_id, user_id) pairs and delete old records
# before batch-inserting the new ones.
seen_parent_keys: set = set()
for doc in event_log_docs:
parent_id = getattr(doc, "parent_id", None)
user_id = getattr(doc, "user_id", None)
key = (parent_id, user_id)
if parent_id and key not in seen_parent_keys:
seen_parent_keys.add(key)
await asyncio.gather(
event_log_repo.delete_by_parent_id(parent_id),
event_log_es_repo.delete_by_parent_id(parent_id, user_id=user_id),
event_log_milvus_repo.delete_by_parent_id(parent_id),
)

saved_event_logs = await event_log_repo.create_batch(event_log_docs)
saved_result[MemoryType.EVENT_LOG] = saved_event_logs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,46 @@ async def delete_by_event_id(
)
return False

async def delete_by_parent_id(
    self,
    parent_id: str,
    user_id: Optional[str] = None,
    session: Optional[AsyncClientSession] = None,
) -> int:
    """
    Remove episodic memories that originate from a given MemCell.

    Args:
        parent_id: Source MemCell ID (stored as parent_id field)
        user_id: If provided, only delete records belonging to this user.
            Pass None to delete all records (group + personal) for the parent.
        session: Optional MongoDB session for transaction support

    Returns:
        Number of deleted records (0 on any failure)
    """
    criteria: Dict[str, Any] = {"parent_id": parent_id}
    if user_id is not None:
        criteria["user_id"] = user_id

    try:
        deletion = await self.model.find(criteria, session=session).delete()
        deleted = deletion.deleted_count if deletion else 0
    except Exception as e:
        # Best-effort semantics: callers treat 0 as "nothing removed".
        logger.error(
            "❌ Failed to delete episodic memories by parent_id=%s: %s",
            parent_id,
            e,
        )
        return 0

    logger.info(
        "✅ Deleted episodic memories by parent_id=%s user_id=%s: %d records",
        parent_id,
        user_id,
        deleted,
    )
    return deleted

async def delete_by_user_id(
self, user_id: str, session: Optional[AsyncClientSession] = None
) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,39 @@ async def find_by_filters(
logger.error("❌ Failed to retrieve foresights: %s", e)
return []

async def delete_expired(
    self,
    before: datetime,
    session: Optional[AsyncClientSession] = None,
) -> int:
    """
    Purge foresight records whose validity window ended before *before*.

    Args:
        before: Delete foresights with end_time strictly before this datetime.
            end_time is stored as an ISO date string (YYYY-MM-DD), so the
            cutoff is compared in date-string form.
        session: Optional MongoDB session for transaction support

    Returns:
        Number of deleted records (0 on any failure)
    """
    try:
        # Local import, mirroring the rest of this module's lazy-import style.
        from common_utils.datetime_utils import to_date_str

        cutoff = to_date_str(before)
        deletion = await self.model.find(
            {"end_time": {"$lt": cutoff}}, session=session
        ).delete()
        removed = deletion.deleted_count if deletion else 0
        logger.info(
            "✅ Deleted expired foresights from MongoDB (before=%s): %d records",
            cutoff,
            removed,
        )
        return removed
    except Exception as e:
        logger.error("❌ Failed to delete expired foresights from MongoDB: %s", e)
        return 0

async def delete_by_id(
self, memory_id: str, session: Optional[AsyncClientSession] = None
) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,50 @@ async def append_episodic_memory(

# ==================== Deletion functionality ====================

async def delete_by_parent_id(
    self,
    parent_id: str,
    user_id: Optional[str] = None,
    refresh: bool = False,
) -> int:
    """
    Delete episodic memory documents sourced from a given MemCell.

    Args:
        parent_id: Source MemCell ID
        user_id: If provided, only delete records for this user.
        refresh: Whether to refresh the index immediately

    Returns:
        Number of deleted documents

    Raises:
        Exception: Propagates any Elasticsearch client failure.
    """
    must_clauses: List[Dict[str, Any]] = [{"term": {"parent_id": parent_id}}]
    if user_id is not None:
        must_clauses.append({"term": {"user_id": user_id}})

    try:
        es_client = await self.get_client()
        result = await es_client.delete_by_query(
            index=self.get_index_name(),
            body={"query": {"bool": {"must": must_clauses}}},
            refresh=refresh,
        )
        removed = result.get("deleted", 0)
        logger.debug(
            "✅ Deleted episodic memory by parent_id=%s user_id=%s: %d records",
            parent_id,
            user_id,
            removed,
        )
        return removed
    except Exception as e:
        logger.error(
            "❌ Failed to delete episodic memory by parent_id=%s: %s", parent_id, e
        )
        raise

async def delete_by_event_id(self, event_id: str, refresh: bool = False) -> bool:
"""
Delete episodic memory document by event_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,45 @@ async def vector_search(

# ==================== Deletion Functionality ====================

async def delete_by_parent_id(
    self,
    parent_id: str,
    user_id: Optional[str] = None,
) -> int:
    """
    Delete episodic memory vectors sourced from a given MemCell.

    Args:
        parent_id: Source MemCell ID
        user_id: If provided, only delete records for this user.

    Returns:
        Number of deleted records

    Raises:
        Exception: Propagates any Milvus client failure.
    """
    # NOTE(review): IDs are interpolated into the Milvus boolean expression;
    # assumes they never contain double quotes — confirm upstream validation.
    clauses = [f'parent_id == "{parent_id}"']
    if user_id is not None:
        clauses.append(f'user_id == "{user_id}"')
    expr = " and ".join(clauses)

    try:
        # Count matches first so the returned total reflects what was removed.
        matches = await self.collection.query(expr=expr, output_fields=["id"])
        removed = len(matches)
        if removed > 0:
            await self.collection.delete(expr)
            logger.debug(
                "✅ Deleted episodic memory vectors by parent_id=%s user_id=%s: %d records",
                parent_id,
                user_id,
                removed,
            )
        return removed
    except Exception as e:
        logger.error(
            "❌ Failed to delete episodic memory vectors by parent_id=%s: %s",
            parent_id,
            e,
        )
        raise

async def delete_by_event_id(self, event_id: str) -> bool:
"""
Delete episodic memory document by event_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,49 @@ async def multi_search(
e,
)
raise

# ==================== Deletion functionality ====================

async def delete_by_parent_id(
    self,
    parent_id: str,
    user_id: Optional[str] = None,
    refresh: bool = False,
) -> int:
    """
    Delete event log documents attached to a given parent memory.

    Args:
        parent_id: Parent memory ID (MemCell or Episode ID)
        user_id: If provided, only delete records for this user.
        refresh: Whether to refresh the index immediately

    Returns:
        Number of deleted documents

    Raises:
        Exception: Propagates any Elasticsearch client failure.
    """
    try:
        terms: List[Dict[str, Any]] = [{"term": {"parent_id": parent_id}}]
        if user_id is not None:
            terms.append({"term": {"user_id": user_id}})

        client = await self.get_client()
        response = await client.delete_by_query(
            index=self.get_index_name(),
            body={"query": {"bool": {"must": terms}}},
            refresh=refresh,
        )
        count = response.get("deleted", 0)
        logger.debug(
            "✅ Deleted event logs by parent_id=%s user_id=%s: %d records",
            parent_id,
            user_id,
            count,
        )
        return count
    except Exception as e:
        logger.error(
            "❌ Failed to delete event logs by parent_id=%s: %s", parent_id, e
        )
        raise
Loading