From 958fa673c0d66e77304b0fb3e884d077d24d18d8 Mon Sep 17 00:00:00 2001 From: SparkLabScout Date: Tue, 3 Mar 2026 04:10:37 +0800 Subject: [PATCH] fix: support multiple memory_types in search API Fixes #78: Search API now iterates over all requested memory_types instead of only using the first one. Changes: - get_keyword_search_results: now iterates over all supported memory_types (EPISODIC_MEMORY, EVENT_LOG, FORESIGHT) and merges results with deduplication - get_vector_search_results: now iterates over all supported memory_types and merges results with deduplication - For unsupported types (e.g., profile which is stored in MongoDB), logs an info message and skips instead of erroring out Behavior: - When multiple memory_types are provided, all supported types are searched - Results are merged and deduplicated by id - The first memory_type is still used for metrics/logging purposes --- src/agentic_layer/memory_manager.py | 288 +++++++++++++++------------- 1 file changed, 151 insertions(+), 137 deletions(-) diff --git a/src/agentic_layer/memory_manager.py b/src/agentic_layer/memory_manager.py index df42b05a..679d6284 100644 --- a/src/agentic_layer/memory_manager.py +++ b/src/agentic_layer/memory_manager.py @@ -337,12 +337,11 @@ async def get_keyword_search_results( retrieve_mem_request: 'RetrieveMemRequest', retrieve_method: str = RetrieveMethod.KEYWORD.value, ) -> List[Dict[str, Any]]: - """Keyword search with stage-level metrics""" + """Keyword search with stage-level metrics - supports multiple memory_types""" stage_start = time.perf_counter() + memory_types = retrieve_mem_request.memory_types memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' + memory_types[0].value if memory_types else 'unknown' ) try: @@ -356,7 +355,6 @@ async def get_keyword_search_results( group_id = retrieve_mem_request.group_id start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - memory_types = retrieve_mem_request.memory_types # Convert query string to search word list # Use jieba for search mode word segmentation, then filter stopwords @@ -375,32 +373,41 @@ async def get_keyword_search_results( if end_time is not None: date_range["lte"] = end_time - mem_type = memory_types[0] - - repo_class = ES_REPO_MAP.get(mem_type) - if not repo_class: - logger.warning(f"Unsupported memory_type: {mem_type}") - return [] - - es_repo = get_bean_by_type(repo_class) - logger.debug(f"Using {repo_class.__name__} for {mem_type}") - - results = await es_repo.multi_search( - query=query_words, - user_id=user_id, - group_id=group_id, - size=top_k, - from_=0, - date_range=date_range, - ) + # Iterate over all requested memory_types and collect results + all_results = [] + seen_ids = set() + + for mem_type in memory_types: + # Skip unsupported memory types (e.g., profile which is stored in MongoDB) + repo_class = ES_REPO_MAP.get(mem_type) + if not repo_class: + logger.info(f"Skipping unsupported memory_type for keyword search: {mem_type}") + continue + + es_repo = get_bean_by_type(repo_class) + logger.debug(f"Using {repo_class.__name__} for {mem_type}") + + results = await es_repo.multi_search( + query=query_words, + user_id=user_id, + group_id=group_id, + size=top_k, + from_=0, + date_range=date_range, + ) - # Mark memory_type, search_source, and unified score - if results: - for r in results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.KEYWORD.value - r['id'] = r.get('_id', '') # Unify ES '_id' to 'id' - r['score'] = r.get('_score', 0.0) # Unified score field + # Mark memory_type, search_source, and unified score + # Deduplicate by id + if results: + for r in results: + result_id = r.get('_id', '') + if result_id not in seen_ids: + seen_ids.add(result_id) + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.KEYWORD.value + r['id'] = result_id # Unify ES '_id' to 'id' + r['score'] = r.get('_score', 0.0) # Unified score field + all_results.append(r) # Record stage metrics record_retrieve_stage( @@ -410,7 +417,7 @@ async def get_keyword_search_results( duration_seconds=time.perf_counter() - stage_start, ) - return results or [] + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, @@ -472,24 +479,41 @@ async def get_vector_search_results( retrieve_mem_request: 'RetrieveMemRequest', retrieve_method: str = RetrieveMethod.VECTOR.value, ) -> List[Dict[str, Any]]: - """Vector search with stage-level metrics (embedding + milvus_search)""" + """Vector search with stage-level metrics - supports multiple memory_types""" + memory_types = retrieve_mem_request.memory_types memory_type = ( - retrieve_mem_request.memory_types[0].value - if retrieve_mem_request.memory_types - else 'unknown' + memory_types[0].value if memory_types else 'unknown' ) + # Get vectorization service (shared across all memory types) + vectorize_service = get_vectorize_service() + + # Convert query text to vector (embedding stage) - shared across all memory types + logger.debug(f"Starting to vectorize query text: {retrieve_mem_request.query}") + embedding_start = time.perf_counter() + query_vector = await vectorize_service.get_embedding(retrieve_mem_request.query) + query_vector_list = query_vector.tolist() # Convert to list format + record_retrieve_stage( + retrieve_method=retrieve_method, + stage='embedding', + memory_type=memory_type, + duration_seconds=time.perf_counter() - embedding_start, + ) + logger.debug( + f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" + ) + + # Iterate over all requested memory_types and collect results + all_results = [] + seen_ids = set() + try: - # Get parameters from Request - logger.debug( - f"get_vector_search_results called with retrieve_mem_request: {retrieve_mem_request}" - ) + # Get common parameters from Request if not retrieve_mem_request: raise ValueError( "retrieve_mem_request is required for get_vector_search_results" ) - query = retrieve_mem_request.query - if not query: + if not retrieve_mem_request.query: raise ValueError("query is required for retrieve_mem_vector") user_id = retrieve_mem_request.user_id @@ -497,117 +521,107 @@ async def get_vector_search_results( top_k = retrieve_mem_request.top_k start_time = retrieve_mem_request.start_time end_time = retrieve_mem_request.end_time - mem_type = retrieve_mem_request.memory_types[0] - - logger.debug( - f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}" - ) - # Get vectorization service - vectorize_service = get_vectorize_service() - - # Convert query text to vector (embedding stage) - logger.debug(f"Starting to vectorize query text: {query}") - embedding_start = time.perf_counter() - query_vector = await vectorize_service.get_embedding(query) - query_vector_list = query_vector.tolist() # Convert to list format - record_retrieve_stage( - retrieve_method=retrieve_method, - stage='embedding', - memory_type=memory_type, - duration_seconds=time.perf_counter() - embedding_start, - ) logger.debug( - f"Query text vectorization completed, vector dimension: {len(query_vector_list)}" + f"retrieve_mem_vector called with query: {retrieve_mem_request.query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}" ) - # Select Milvus repository based on memory type - match mem_type: - case MemoryType.FORESIGHT: - milvus_repo = get_bean_by_type(ForesightMilvusRepository) - case MemoryType.EVENT_LOG: - milvus_repo = get_bean_by_type(EventLogMilvusRepository) - case MemoryType.EPISODIC_MEMORY: - milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) - case _: - raise ValueError(f"Unsupported memory type: {mem_type}") + for mem_type in memory_types: + # Skip unsupported memory types (e.g., profile which is stored in MongoDB) + # Select Milvus repository based on memory type + match mem_type: + case MemoryType.FORESIGHT: + milvus_repo = get_bean_by_type(ForesightMilvusRepository) + case MemoryType.EVENT_LOG: + milvus_repo = get_bean_by_type(EventLogMilvusRepository) + case MemoryType.EPISODIC_MEMORY: + milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository) + case _: + logger.info(f"Skipping unsupported memory_type for vector search: {mem_type}") + continue - # Handle time range filter conditions - start_time_dt = None - end_time_dt = None - current_time_dt = None + # Handle time range filter conditions + start_time_dt = None + end_time_dt = None + current_time_dt = None - if start_time is not None: - start_time_dt = ( - from_iso_format(start_time) - if isinstance(start_time, str) - else start_time - ) + if start_time is not None: + start_time_dt = ( + from_iso_format(start_time) + if isinstance(start_time, str) + else start_time + ) - if end_time is not None: - if isinstance(end_time, str): - end_time_dt = from_iso_format(end_time) - # If date only format, set to end of day - if len(end_time) == 10: - end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) + if end_time is not None: + if isinstance(end_time, str): + end_time_dt = from_iso_format(end_time) + # If date only format, set to end of day + if len(end_time) == 10: + end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59) + else: + end_time_dt = end_time + + # Handle foresight time range (only valid for foresight) + if mem_type == MemoryType.FORESIGHT: + if retrieve_mem_request.start_time: + start_time_dt = from_iso_format(retrieve_mem_request.start_time) + if retrieve_mem_request.end_time: + end_time_dt = from_iso_format(retrieve_mem_request.end_time) + if retrieve_mem_request.current_time: + current_time_dt = from_iso_format(retrieve_mem_request.current_time) + + # Call Milvus vector search (pass different parameters based on memory type) + milvus_start = time.perf_counter() + if mem_type == MemoryType.FORESIGHT: + # Foresight: supports time range and validity filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + current_time=current_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) else: - end_time_dt = end_time - - # Handle foresight time range (only valid for foresight) - if mem_type == MemoryType.FORESIGHT: - if retrieve_mem_request.start_time: - start_time_dt = from_iso_format(retrieve_mem_request.start_time) - if retrieve_mem_request.end_time: - end_time_dt = from_iso_format(retrieve_mem_request.end_time) - if retrieve_mem_request.current_time: - current_time_dt = from_iso_format(retrieve_mem_request.current_time) - - # Call Milvus vector search (pass different parameters based on memory type) - milvus_start = time.perf_counter() - if mem_type == MemoryType.FORESIGHT: - # Foresight: supports time range and validity filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - current_time=current_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, - ) - else: - # Episodic memory and event log: use timestamp filtering, supports radius parameter - search_results = await milvus_repo.vector_search( - query_vector=query_vector_list, - user_id=user_id, - group_id=group_id, - start_time=start_time_dt, - end_time=end_time_dt, - limit=top_k, - score_threshold=0.0, - radius=retrieve_mem_request.radius, + # Episodic memory and event log: use timestamp filtering, supports radius parameter + search_results = await milvus_repo.vector_search( + query_vector=query_vector_list, + user_id=user_id, + group_id=group_id, + start_time=start_time_dt, + end_time=end_time_dt, + limit=top_k, + score_threshold=0.0, + radius=retrieve_mem_request.radius, + ) + record_retrieve_stage( + retrieve_method=retrieve_method, + stage='milvus_search', + memory_type=mem_type.value, + duration_seconds=time.perf_counter() - milvus_start, ) - record_retrieve_stage( - retrieve_method=retrieve_method, - stage='milvus_search', - memory_type=memory_type, - duration_seconds=time.perf_counter() - milvus_start, - ) - - for r in search_results: - r['memory_type'] = mem_type.value - r['_search_source'] = RetrieveMethod.VECTOR.value - # Milvus already uses 'score', no need to rename - return search_results + # Deduplicate by id + if search_results: + for r in search_results: + result_id = r.get('id', '') + if result_id not in seen_ids: + seen_ids.add(result_id) + r['memory_type'] = mem_type.value + r['_search_source'] = RetrieveMethod.VECTOR.value + # Milvus already uses 'score', no need to rename + all_results.append(r) + + return all_results except Exception as e: record_retrieve_stage( retrieve_method=retrieve_method, stage=RetrieveMethod.VECTOR.value, memory_type=memory_type, - duration_seconds=time.perf_counter() - milvus_start, + duration_seconds=time.perf_counter() - embedding_start, ) record_retrieve_error( retrieve_method=retrieve_method,