diff --git a/src/api_specs/dtos/memory.py b/src/api_specs/dtos/memory.py index 4b5a6f57..2b46e994 100644 --- a/src/api_specs/dtos/memory.py +++ b/src/api_specs/dtos/memory.py @@ -284,6 +284,10 @@ class MemorizeRequest(BaseModel): # Optional extraction control parameters enable_foresight_extraction: bool = True # Whether to extract foresight enable_event_log_extraction: bool = True # Whether to extract event logs + force_boundary: bool = Field( + default=False, + description="Force the extraction of a new memcell", + ) model_config = {"arbitrary_types_allowed": True} @@ -340,6 +344,10 @@ class MemorizeMessageRequest(BaseModel): description="List of referenced message IDs", examples=[["msg_000"]], ) + force_boundary: bool = Field( + default=False, + description="Force the extraction of a new memcell", + ) @model_validator(mode="after") def validate_role(self): diff --git a/src/api_specs/request_converter.py b/src/api_specs/request_converter.py index b59438f0..d602ce91 100644 --- a/src/api_specs/request_converter.py +++ b/src/api_specs/request_converter.py @@ -420,4 +420,5 @@ async def convert_simple_message_to_memorize_request( group_id=group_id, group_name=group_name, current_time=timestamp, + force_boundary=message_data.get("force_boundary", False), ) diff --git a/src/biz_layer/mem_memorize.py b/src/biz_layer/mem_memorize.py index e951a13c..068e245b 100644 --- a/src/biz_layer/mem_memorize.py +++ b/src/biz_layer/mem_memorize.py @@ -1317,6 +1317,8 @@ async def memorize(request: MemorizeRequest) -> int: request.group_id, request.group_name, request.user_id_list, + old_memory_list=None, + force_boundary=request.force_boundary, ) record_extraction_stage( space_id=space_id, diff --git a/src/memory_layer/memcell_extractor/conv_memcell_extractor.py b/src/memory_layer/memcell_extractor/conv_memcell_extractor.py index 881628e4..96b821b3 100644 --- a/src/memory_layer/memcell_extractor/conv_memcell_extractor.py +++ b/src/memory_layer/memcell_extractor/conv_memcell_extractor.py @@ -49,7 +49,7 @@ class BoundaryDetectionResult: @dataclass class ConversationMemCellExtractRequest(MemCellExtractRequest): - pass + force_boundary: bool = False class ConvMemCellExtractor(MemCellExtractor): @@ -478,7 +478,15 @@ async def extract_memcell( ) # === Normal LLM-based boundary detection === - if request.smart_mask_flag: + if getattr(request, 'force_boundary', False): + boundary_detection_result = BoundaryDetectionResult( + should_end=True, + should_wait=False, + reasoning="Force boundary requested by API client payload", + confidence=1.0, + topic_summary="", + ) + elif request.smart_mask_flag: boundary_detection_result = await self._detect_boundary( conversation_history=history_message_dict_list[:-1], new_messages=new_message_dict_list, diff --git a/src/memory_layer/memory_manager.py b/src/memory_layer/memory_manager.py index e1af421e..14350b40 100644 --- a/src/memory_layer/memory_manager.py +++ b/src/memory_layer/memory_manager.py @@ -81,6 +81,7 @@ async def extract_memcell( group_name: Optional[str] = None, user_id_list: Optional[List[str]] = None, old_memory_list: Optional[List[BaseMemory]] = None, + force_boundary: bool = False, ) -> tuple[Optional[MemCell], Optional[StatusResult]]: """ Extract MemCell (boundary detection + raw data) @@ -115,6 +116,7 @@ async def extract_memcell( group_name=group_name, old_memory_list=old_memory_list, smart_mask_flag=smart_mask_flag, + force_boundary=force_boundary, ) extractor = ConvMemCellExtractor(self.llm_provider)