Skip to content

Commit ec64e12

Browse files
Merge branch 'main' into add-checkpoints
2 parents 70cc64d + 8cc64bb commit ec64e12

File tree

10 files changed

+765
-61
lines changed

10 files changed

+765
-61
lines changed

dynamiq/nodes/agents/base.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class AgentInputSchema(BaseModel):
142142

143143
user_id: str | None = Field(default=None, description="Parameter to provide user ID.")
144144
session_id: str | None = Field(default=None, description="Parameter to provide session ID.")
145-
metadata: dict | list = Field(default={}, description="Parameter to provide metadata.")
145+
metadata: dict = Field(default={}, description="Parameter to provide metadata in key-value pairs.")
146146

147147
model_config = ConfigDict(extra="allow", strict=True, arbitrary_types_allowed=True)
148148

@@ -398,34 +398,34 @@ def set_prompt_variable(self, variable_name: str, value: Any):
398398
"""Sets or updates a prompt variable."""
399399
self.system_prompt_manager.set_variable(variable_name, value)
400400

401-
def _prepare_metadata(self, input_data: dict) -> dict:
401+
def _prepare_metadata(self, input_data: AgentInputSchema) -> dict:
402402
"""
403403
Prepare metadata from input data.
404404
405405
Args:
406-
input_data (dict): Input data containing user information
406+
input_data: Agent input schema containing user information
407407
408408
Returns:
409409
dict: Processed metadata
410410
"""
411-
EXCLUDED_KEYS = {"user_id", "session_id", "input", "metadata", "files", "images", "tool_params"}
412-
custom_metadata = input_data.get("metadata", {}).copy()
413-
custom_metadata.update({k: v for k, v in input_data.items() if k not in EXCLUDED_KEYS})
411+
custom_metadata = input_data.metadata.copy()
414412

413+
# Add extra fields that were provided (model allows extra fields with ConfigDict extra="allow")
414+
if input_data.model_extra:
415+
custom_metadata.update(input_data.model_extra)
416+
417+
# Clean up any leaked fields
415418
if "files" in custom_metadata:
416419
del custom_metadata["files"]
417420
if "images" in custom_metadata:
418421
del custom_metadata["images"]
419422
if "tool_params" in custom_metadata:
420423
del custom_metadata["tool_params"]
421424

422-
user_id = input_data.get("user_id")
423-
session_id = input_data.get("session_id")
424-
425-
if user_id:
426-
custom_metadata["user_id"] = user_id
427-
if session_id:
428-
custom_metadata["session_id"] = session_id
425+
if input_data.user_id:
426+
custom_metadata["user_id"] = input_data.user_id
427+
if input_data.session_id:
428+
custom_metadata["session_id"] = input_data.session_id
429429

430430
return custom_metadata
431431

@@ -439,12 +439,10 @@ def execute(
439439
"""
440440
Executes the agent with the given input data.
441441
"""
442-
input_dict = dict(input_data)
443-
log_data = input_dict.copy()
444-
442+
# Convert to dict only for logging (to avoid logging BytesIO objects)
443+
log_data = input_data.model_dump()
445444
if log_data.get("images"):
446445
log_data["images"] = [f"image_{i}" for i in range(len(log_data["images"]))]
447-
448446
if log_data.get("files"):
449447
log_data["files"] = [f"file_{i}" for i in range(len(log_data["files"]))]
450448

@@ -453,20 +451,24 @@ def execute(
453451
config = ensure_config(config)
454452
self.run_on_node_execute_run(config.callbacks, **kwargs)
455453

456-
custom_metadata = self._prepare_metadata(input_dict)
454+
custom_metadata = self._prepare_metadata(input_data)
457455
self._current_call_context = {
458-
"user_id": input_dict.get("user_id"),
459-
"session_id": input_dict.get("session_id"),
456+
"user_id": input_data.user_id,
457+
"session_id": input_data.session_id,
460458
"metadata": custom_metadata,
461459
}
462460

463461
input_message = input_message or self.input_message or Message(role=MessageRole.USER, content=input_data.input)
464-
input_message = input_message.format_message(**input_dict)
462+
# Convert to dict for format_message, excluding fields that are unsafe for templates
463+
# (binary data like files/images, complex objects like tool_params, and input which is already handled)
464+
standard_fields = set(AgentInputSchema.model_fields.keys())
465+
extra_fields = input_data.model_dump(exclude=standard_fields)
466+
input_message = input_message.format_message(**extra_fields)
465467

466-
use_memory = self.memory and (input_dict.get("user_id") or input_dict.get("session_id"))
468+
use_memory = self.memory and (input_data.user_id or input_data.session_id)
467469

468470
if use_memory:
469-
history_messages = self._retrieve_memory(input_dict)
471+
history_messages = self._retrieve_memory(input_data)
470472
if len(history_messages) > 0:
471473
history_messages.insert(
472474
0,
@@ -596,18 +598,19 @@ def retrieve_conversation_history(
596598
)
597599
return conversation
598600

599-
def _retrieve_memory(self, input_data: dict) -> list[Message]:
601+
def _retrieve_memory(self, input_data: AgentInputSchema) -> list[Message]:
600602
"""
603+
Args:
604+
input_data: Agent input schema containing user information
605+
606+
Returns:
607+
list[Message]: List of messages forming a valid conversation context
601608
Retrieves memory messages when user_id and/or session_id are provided.
602609
"""
603-
user_id = input_data.get("user_id")
604-
session_id = input_data.get("session_id")
605-
606-
user_query = input_data.get("input", "")
607610
history_messages = self.retrieve_conversation_history(
608-
user_query=user_query,
609-
user_id=user_id,
610-
session_id=session_id,
611+
user_query=input_data.input,
612+
user_id=input_data.user_id,
613+
session_id=input_data.session_id,
611614
strategy=self.memory_retrieval_strategy,
612615
)
613616
logger.info("Agent %s - %s: retrieved %d messages from memory", self.name, self.id, len(history_messages))

0 commit comments

Comments
 (0)