diff --git a/dynamiq/callbacks/streaming.py b/dynamiq/callbacks/streaming.py index 0c757fdaa..73e9f6a03 100644 --- a/dynamiq/callbacks/streaming.py +++ b/dynamiq/callbacks/streaming.py @@ -7,9 +7,9 @@ from dynamiq.callbacks import BaseCallbackHandler from dynamiq.callbacks.base import get_run_id from dynamiq.types.streaming import ( - AgentReasoningEventMessageData, AgentToolData, AgentToolInputDeltaData, + AgentToolInputStartData, StreamingEntitySource, StreamingEventMessage, StreamingMode, @@ -607,12 +607,10 @@ def _emit(self, content: str, step: str, force: bool = False) -> None: def _emit_tool_input_start(self) -> None: """Emit a tool_input start event with full metadata before the first delta.""" tool_data = self._resolve_tool_data() - start_model = AgentReasoningEventMessageData( + start_model = AgentToolInputStartData( tool_run_id=self.agent._streaming_tool_run_id or "", - thought="", action=self._current_action_name or "", tool=tool_data, - action_input="", loop_num=self.loop_num, ) self.agent.stream_content( @@ -1030,21 +1028,16 @@ def _initialize_json_object_field_state(self, buf: str, field_name: str, state: return True return False - def _process_json_mode(self, final_answer_only: bool) -> None: - """ - Processing for function calling mode. + def _try_initialize_next_json_field(self, buf: str, final_answer_only: bool) -> None: + """Try to initialize the next JSON field state (thought, answer, or action_input). - Supports multiple tool calls (parallel function calling) — no single-cycle - constraint is enforced here, unlike structured output mode. - - Args: - final_answer_only: Whether to stream only final answers + Each initializer is a no-op when _current_state is already set, so this is safe + to call multiple times within a single chunk processing cycle. """ - buf = self._buffer - - self._initialize_json_field_state( - buf, JSONStreamingField.THOUGHT.value, StreamingState.REASONING, final_answer_only - ) + if not self._state_has_emitted.get(StreamingState.REASONING, False): + self._initialize_json_field_state( + buf, JSONStreamingField.THOUGHT.value, StreamingState.REASONING, final_answer_only + ) if self._answer_started: self._initialize_json_field_state(buf, JSONStreamingField.ANSWER.value, StreamingState.ANSWER) @@ -1058,15 +1051,38 @@ def _process_json_mode(self, final_answer_only: bool) -> None: buf, JSONStreamingField.ACTION_INPUT.value, StreamingState.TOOL_INPUT ) + def _emit_tool_input_state(self, buf: str) -> None: + """Emit content for the current TOOL_INPUT state.""" + if self._fc_object_tool_input: + self._emit_json_object_field_content(buf, StreamingState.TOOL_INPUT) + else: + self._emit_json_field_content(buf, StreamingState.TOOL_INPUT) + + def _process_json_mode(self, final_answer_only: bool) -> None: + """ + Processing for function calling mode. + + Supports multiple tool calls (parallel function calling) — no single-cycle + constraint is enforced here, unlike structured output mode. + + Args: + final_answer_only: Whether to stream only final answers + """ + buf = self._buffer + + self._try_initialize_next_json_field(buf, final_answer_only) + if self._current_state == StreamingState.REASONING: - self._emit_json_field_content(buf, StreamingState.REASONING) + field_complete = self._emit_json_field_content(buf, StreamingState.REASONING) + if field_complete: + # Reasoning completed — the buffer may already contain the next field + # (e.g. action_input). Re-run to detect and process it in the same chunk, + # before _reset_tool_call_state clears the buffer on the next parallel call. + self._process_json_mode(final_answer_only) elif self._current_state == StreamingState.ANSWER: self._emit_json_field_content(buf, StreamingState.ANSWER) elif self._current_state == StreamingState.TOOL_INPUT: - if self._fc_object_tool_input: - self._emit_json_object_field_content(buf, StreamingState.TOOL_INPUT) - else: - self._emit_json_field_content(buf, StreamingState.TOOL_INPUT) + self._emit_tool_input_state(buf) def _skip_whitespace(self, text: str, start: int) -> int: """Skip whitespace characters starting from the given position.""" @@ -1096,7 +1112,9 @@ def _emit_json_field_content(self, buf: str, step: str) -> bool: self._emit(buf[segment_start:segment_end], step=step) segment_start = segment_end self._state_last_emit_index = end_quote - # Reset the state + # Mark the field as emitted and reset the state + if step in self._state_has_emitted: + self._state_has_emitted[step] = True self._current_state = None return True diff --git a/dynamiq/nodes/agents/agent.py b/dynamiq/nodes/agents/agent.py index 922612fde..abce6fb7b 100644 --- a/dynamiq/nodes/agents/agent.py +++ b/dynamiq/nodes/agents/agent.py @@ -371,7 +371,7 @@ def _stream_batch_reasoning_event( per_tool_reasoning.append( AgentReasoningEventMessageData( tool_run_id=tid, - thought="", + thought=tp.get("thought", ""), action=tp["name"], tool=tool_data, action_input=tp["input"], @@ -585,7 +585,7 @@ def _handle_function_calling_mode( tc_input = tc_args["action_input"] if not isinstance(tc_input, dict): tc_input = {"input": tc_input} - tool_items.append(ToolCallItem(name=tc_name, input=tc_input)) + tool_items.append(ToolCallItem(name=tc_name, input=tc_input, thought=tc_args.get("thought", ""))) validated = ParallelToolCallsInputSchema(tools=tool_items) action_input = validated.model_dump() @@ -850,19 +850,20 @@ def _execute_single_tool( action_type=tool.action_type.value if tool.action_type else None, ) - self._stream_agent_event( - AgentReasoningEventMessageData( - tool_run_id=tool_run_id, - thought=thought or "", - action=action, - tool=tool_data, - action_input=action_input, - loop_num=loop_num, - ), - "reasoning", - config, - **kwargs, - ) + if not is_parallel: + self._stream_agent_event( + AgentReasoningEventMessageData( + tool_run_id=tool_run_id, + thought=thought or "", + action=action, + tool=tool_data, + action_input=action_input, + loop_num=loop_num, + ), + "reasoning", + config, + **kwargs, + ) try: if isinstance(tool, ContextManagerTool): tool_result = None @@ -1669,7 +1670,9 @@ def _execute_tools( } ) continue - prepared_tools.append({"order": idx, "name": tool_name, "input": tool_input}) + prepared_tools.append( + {"order": idx, "name": tool_name, "input": tool_input, "thought": td.get("thought", "")} + ) def _execute_single_tool_to_result(tool_payload: dict[str, Any], **extra) -> dict[str, Any]: """Execute a single tool and wrap the result as a dict.""" diff --git a/dynamiq/nodes/tools/parallel_tool_calls.py b/dynamiq/nodes/tools/parallel_tool_calls.py index b109c445a..87f13ed89 100644 --- a/dynamiq/nodes/tools/parallel_tool_calls.py +++ b/dynamiq/nodes/tools/parallel_tool_calls.py @@ -21,6 +21,7 @@ class ToolCallItem(BaseModel): default_factory=dict, description="Input parameters for the tool as key-value pairs", ) + thought: str = Field(default="", description="Reasoning for this tool call.") model_config = ConfigDict(extra="forbid") diff --git a/dynamiq/types/streaming.py b/dynamiq/types/streaming.py index 4cffc0cbc..8a8d902f0 100644 --- a/dynamiq/types/streaming.py +++ b/dynamiq/types/streaming.py @@ -106,6 +106,20 @@ class AgentReasoningEventMessageData(BaseModel): loop_num: int +# --------------------------------------------------------------------------- +# Tool input streaming models +# --------------------------------------------------------------------------- + + +class AgentToolInputStartData(BaseModel): + """Emitted once when tool_input streaming begins for a tool call.""" + + tool_run_id: str + action: str + tool: AgentToolData + loop_num: int + + class AgentToolInputDeltaData(BaseModel): """Lean delta for tool_input streaming. Only tool_run_id and action_input change.""" @@ -113,6 +127,22 @@ class AgentToolInputDeltaData(BaseModel): action_input: Any +class AgentToolInputErrorEventMessageData(BaseModel): + """Emitted when action parsing fails after tool input was already + partially streamed, so consumers can discard the invalid chunks. + """ + + tool_run_id: str + name: str + error: str + loop_num: int + + +# --------------------------------------------------------------------------- +# Tool result streaming model +# --------------------------------------------------------------------------- + + class AgentToolResultEventMessageData(BaseModel): """Model for agent tool result streaming event data.""" @@ -144,19 +174,6 @@ def to_dict(self, **kwargs) -> dict: return data -class AgentToolInputErrorEventMessageData(BaseModel): - """Model for agent tool input error streaming event data. - - Emitted when action parsing fails after tool input was already - partially streamed, so consumers can discard the invalid chunks. - """ - - tool_run_id: str - name: str - error: str - loop_num: int - - class StreamingConfig(BaseModel): """Configuration for streaming. diff --git a/tests/integration_with_creds/agents/streaming_assertions.py b/tests/integration_with_creds/agents/streaming_assertions.py new file mode 100644 index 000000000..cd78a1bfe --- /dev/null +++ b/tests/integration_with_creds/agents/streaming_assertions.py @@ -0,0 +1,666 @@ +import json +from enum import Enum, auto + +from dynamiq.nodes.types import InferenceMode +from dynamiq.types.streaming import StreamingMode +from dynamiq.utils.logger import logger + + +def collect_streaming_events(streaming_iterator, agent_id): + """Collect streaming events in chronological order. + + Returns: + list[tuple[str, Any]]: [(step, content), ...] in the order received. + """ + ordered_events = [] + raw_events = [] + + for event in streaming_iterator: + raw_events.append(event) + + if event.entity_id != agent_id: + continue + data = event.data + if not isinstance(data, dict): + continue + + choices = data.get("choices") or [] + if not choices: + continue + delta = choices[0].get("delta", {}) + step = delta.get("step") + content = delta.get("content") + if step is not None: + ordered_events.append((step, content)) + + logger.info(f"Collected {len(ordered_events)} streaming events from {len(raw_events)} raw events") + + return ordered_events + + +# --------------------------------------------------------------------------- +# FSM states and event classification +# --------------------------------------------------------------------------- + + +class State(Enum): + INIT = auto() + REASONING = auto() + TOOL_INPUT = auto() + POST_PARSE = auto() + TOOL_RESULT = auto() + ERROR = auto() + ANSWER = auto() + + +def _classify_event(step, content): + """Map a raw (step, content) pair to an FSM event name.""" + if step == "reasoning": + if isinstance(content, dict) and "tool_run_id" in content: + return "post_parse_reasoning" + return "reasoning" + if step == "tool_input_start": + return "tool_input_start" + if step == "tool_input": + return "tool_input" + if step == "tool": + return "tool_result" + if step == "tool_input_error": + return "tool_input_error" + if step == "answer": + return "answer" + return None + + +# --------------------------------------------------------------------------- +# Transition tables +# --------------------------------------------------------------------------- + +_TRANSITIONS_WITH_TOOL_INPUT = { + State.INIT: { + "reasoning": State.REASONING, + }, + State.REASONING: { + "reasoning": State.REASONING, + "tool_input_start": State.TOOL_INPUT, + "answer": State.ANSWER, + }, + State.TOOL_INPUT: { + "tool_input_start": State.TOOL_INPUT, + "tool_input": State.TOOL_INPUT, + "post_parse_reasoning": State.POST_PARSE, + "tool_input_error": State.ERROR, + }, + State.POST_PARSE: { + "post_parse_reasoning": State.POST_PARSE, + "tool_result": State.TOOL_RESULT, + }, + State.TOOL_RESULT: { + "tool_result": State.TOOL_RESULT, + "reasoning": State.REASONING, + "answer": State.ANSWER, + }, + State.ERROR: { + "tool_input_error": State.ERROR, + "reasoning": State.REASONING, + "answer": State.ANSWER, + }, + State.ANSWER: { + "answer": State.ANSWER, + }, +} + +# FC mode: allows reasoning after tool_input for parallel tool calls. +_TRANSITIONS_FC = { + **_TRANSITIONS_WITH_TOOL_INPUT, + State.TOOL_INPUT: { + **_TRANSITIONS_WITH_TOOL_INPUT[State.TOOL_INPUT], + "reasoning": State.REASONING, + }, +} + +_TRANSITIONS_DEFAULT = { + State.INIT: { + "reasoning": State.REASONING, + "answer": State.ANSWER, + }, + State.REASONING: { + "reasoning": State.REASONING, + "post_parse_reasoning": State.POST_PARSE, + "answer": State.ANSWER, + }, + State.POST_PARSE: { + "post_parse_reasoning": State.POST_PARSE, + "tool_result": State.TOOL_RESULT, + }, + State.TOOL_RESULT: { + "tool_result": State.TOOL_RESULT, + "reasoning": State.REASONING, + "answer": State.ANSWER, + }, + State.ANSWER: { + "answer": State.ANSWER, + }, +} + +_TRANSITIONS_FINAL = { + State.INIT: { + "answer": State.ANSWER, + }, + State.ANSWER: { + "answer": State.ANSWER, + }, +} + + +# --------------------------------------------------------------------------- +# Structural validators (called inline during FSM walk) +# --------------------------------------------------------------------------- + + +def _validate_post_parse_reasoning(content, idx): + assert isinstance(content, dict), f"Event {idx}: post-parse reasoning should be dict, got {type(content)}" + for key in ("thought", "action", "tool", "action_input", "loop_num"): + assert key in content, f"Event {idx}: post-parse reasoning missing '{key}': {content}" + tool = content["tool"] + assert "name" in tool and "type" in tool, f"Event {idx}: post-parse reasoning tool missing name/type: {tool}" + + +def _validate_tool_result(content, idx): + assert isinstance(content, dict), f"Event {idx}: tool result should be dict, got {type(content)}" + for key in ("tool_run_id", "name", "result", "status"): + assert key in content, f"Event {idx}: tool result missing '{key}': {content}" + + +def _validate_tool_input_start(content, idx): + assert isinstance(content, dict), f"Event {idx}: tool_input_start should be dict, got {type(content)}" + for key in ("tool_run_id", "action", "tool"): + assert key in content, f"Event {idx}: tool_input_start missing '{key}': {content}" + + +def _validate_tool_input(content, idx): + assert isinstance(content, dict), f"Event {idx}: tool_input should be dict, got {type(content)}" + for key in ("tool_run_id", "action_input"): + assert key in content, f"Event {idx}: tool_input missing '{key}': {content}" + + +_VALIDATORS = { + "post_parse_reasoning": _validate_post_parse_reasoning, + "tool_result": _validate_tool_result, + "tool_input_start": _validate_tool_input_start, + "tool_input": _validate_tool_input, +} + + +# --------------------------------------------------------------------------- +# Shared FSM helpers (reusable across all modes) +# --------------------------------------------------------------------------- + + +def _fsm_step_transition(event_name, state, transitions, idx, step, content): + """Validate and return the next state for a single FSM step. + + Asserts that event_name is allowed from the current state, runs structural + validators, and returns next_state. + """ + assert event_name is not None, f"Event {idx}: unknown step '{step}'" + + allowed = transitions.get(state, {}) + assert event_name in allowed, ( + f"Event {idx}: unexpected '{event_name}' in state {state.name}. " + f"Allowed: {list(allowed.keys())}. " + f"Raw: step={step}, content_type={'dict' if isinstance(content, dict) else 'str'}" + ) + + validator = _VALIDATORS.get(event_name) + if validator: + validator(content, idx) + + return allowed[event_name] + + +def _track_reasoning(event_name, state, next_state, content, reasoning_blocks): + """Track reasoning block lifecycle. Call on every FSM step.""" + if next_state == State.REASONING and state != State.REASONING: + reasoning_blocks.append("") + if event_name == "reasoning" and reasoning_blocks: + if isinstance(content, str): + reasoning_blocks[-1] += content + elif isinstance(content, dict) and "thought" in content: + reasoning_blocks[-1] += content["thought"] + + +def _track_tool_input(event_name, content, tool_blocks): + """Track tool_input_start and tool_input chunk accumulation. Returns updated max delta.""" + if event_name == "tool_input_start" and isinstance(content, dict): + tid = content.get("tool_run_id") + if tid: + tool_blocks[tid] = { + "name": content.get("action"), + "action_input_chunks": [], + } + elif event_name == "tool_input" and isinstance(content, dict): + tid = content.get("tool_run_id") + if tid and tid in tool_blocks: + tool_blocks[tid]["action_input_chunks"].append(content.get("action_input", "")) + + +def _handle_tool_result(event_name, content, tool_blocks, reasoning_blocks, run_parallel_count): + """Handle tool_result / tool_input_error events. Returns updated run_parallel_count.""" + if event_name not in ("tool_result", "tool_input_error") or not isinstance(content, dict): + return run_parallel_count + + tid = content.get("tool_run_id") + result_name = content.get("name", "") + + if result_name == "run-parallel": + run_parallel_count -= 1 + assert run_parallel_count == 0, ( + f"run-parallel tool_result without matching post_parse_reasoning, " + f"run_parallel_count={run_parallel_count}" + ) + if tid and tid in tool_blocks: + tool_blocks.pop(tid) + else: + if tid and tid in tool_blocks: + tool_blocks.pop(tid) + if reasoning_blocks: + reasoning_blocks.pop(0) + + return run_parallel_count + + +def _handle_answer(event_name, reasoning_blocks): + """Pop reasoning block on answer event.""" + if event_name == "answer" and reasoning_blocks: + reasoning_blocks.pop(0) + + +def _match_action_input(accumulated: str, expected) -> bool: + """Check if accumulated tool_input string matches expected action_input. + + Handles: + 1. Direct string match. + 2. Wrapped in {"input": ...} dict — compare against inner value. + 3. Accumulated is JSON-escaped (streamed as raw JSON) — decode and compare. + 4. Both sides JSON-decoded for structural comparison. + """ + attempts = [] + + # 1. Direct match + if accumulated == expected: + return True + attempts.append(f"direct: {accumulated!r} == {expected!r} -> False") + + # 2. Wrapped dict match + if isinstance(expected, dict) and "input" in expected: + inner = expected["input"] + if accumulated == inner: + return True + attempts.append(f"wrapped: {accumulated!r} == {inner!r} -> False") + + # 3. Decode accumulated (JSON-escaped streaming) and compare. + # Structured output streams action_input as a JSON string field, so the + # accumulated text is the raw string body with escape sequences (e.g. + # {\"key\":\"val\"}). Wrap in quotes to form a valid JSON string literal + # before decoding. + decoded = None + try: + decoded = json.loads(accumulated) + except (json.JSONDecodeError, TypeError): + try: + decoded = json.loads(f'"{accumulated}"') + except (json.JSONDecodeError, TypeError): + pass + if decoded is None: + attempts.append("json.loads(accumulated) -> FAILED") + + # If decoded is itself a JSON string (structured output double-encoding), + # unwrap one more level. + if isinstance(decoded, str): + try: + decoded = json.loads(decoded) + except (json.JSONDecodeError, TypeError): + pass + + if decoded is not None: + if decoded == expected: + return True + attempts.append(f"decoded==expected: {decoded!r} == {expected!r} -> False") + + if isinstance(expected, dict) and "input" in expected: + inner = expected["input"] + if decoded == inner: + return True + attempts.append(f"decoded==inner: {decoded!r} == {inner!r} -> False") + + # 4. Both sides JSON-decoded (inner may also be a JSON string) + if isinstance(inner, str): + try: + inner_decoded = json.loads(inner) + if decoded == inner_decoded: + return True + attempts.append(f"decoded==inner_decoded: {decoded!r} == {inner_decoded!r} -> False") + except (json.JSONDecodeError, TypeError): + attempts.append("json.loads(inner) -> FAILED") + + logger.debug("[_match_action_input] all attempts failed:\n" + "\n".join(f" {a}" for a in attempts)) + return False + + +def _validate_single_post_parse(content, tool_blocks, reasoning_blocks): + """Validate an individual (non-run-parallel) post_parse_reasoning event. + + Checks: + - Must not appear when multiple tools are in-flight. + - Accumulated tool_input matches action_input (direct or {"input": ...} wrapped). + - Accumulated reasoning matches thought. + """ + tid = content.get("tool_run_id") + tool_name = content.get("action", "") + + assert len(tool_blocks) <= 1, ( + f"Individual post_parse_reasoning for '{tool_name}' appeared with " + f"{len(tool_blocks)} tools in-flight. Individual post_parse events " + f"should not be emitted during parallel execution: " + f"{[b['name'] for b in tool_blocks.values()]}" + ) + + if tid and tid in tool_blocks: + block = tool_blocks[tid] + accumulated_input = "".join(block["action_input_chunks"]) + expected_input = content.get("action_input") + if accumulated_input and expected_input is not None: + assert _match_action_input(accumulated_input, expected_input), ( + f"tool_run_id {tid} ({tool_name}): accumulated tool_input " + f"does not match post_parse action_input. " + f"Accumulated: {accumulated_input!r}\n" + f"Expected: {expected_input!r}" + ) + + expected_thought = content.get("thought") + if reasoning_blocks and expected_thought: + accumulated_thought = reasoning_blocks[0] + # Streaming emits raw JSON string content (e.g. \n as two chars), + # while post_parse decodes via json.loads (real newline). + # Encode expected to raw JSON form for a fair comparison. + expected_thought_raw = json.dumps(expected_thought)[1:-1] + assert accumulated_thought == expected_thought or accumulated_thought == expected_thought_raw, ( + f"tool_run_id {tid} ({tool_name}): accumulated reasoning " + f"({len(accumulated_thought)} chars) does not match post_parse thought. " + f"Accumulated: {accumulated_thought!r}\n" + f"Expected: {expected_thought!r}" + ) + logger.debug(f"[reasoning_block] {tool_name} ({tid}): thought matched " f"({len(accumulated_thought)} chars)") + + +def _assert_fsm_end(tool_blocks, reasoning_blocks, run_parallel_count): + """Assert clean end state: no in-flight tool blocks, reasoning, or parallel batches.""" + assert len(tool_blocks) == 0, ( + f"Unresolved tool blocks at end of stream: " f"{[{tid: b['name']} for tid, b in tool_blocks.items()]}" + ) + assert len(reasoning_blocks) == 0, ( + f"Unresolved reasoning blocks at end of stream: {len(reasoning_blocks)} remaining. " + f"Previews: {[b[:80] + '...' if len(b) > 80 else b for b in reasoning_blocks]}" + ) + assert run_parallel_count == 0, ( + f"Unresolved run-parallel events: run_parallel_count={run_parallel_count} " + f"(post_parse_reasoning without matching tool_result or vice versa)" + ) + + +def _log_event(idx, step, event_name, state, content): + """Debug-log a single FSM event.""" + content_preview = content + if isinstance(content, str): + content_preview = repr(content[:80]) if len(content) > 80 else repr(content) + logger.debug( + f"[FSM] Event {idx}: step={step}, event={event_name}, " f"state={state.name}, content={content_preview}" + ) + + +# --------------------------------------------------------------------------- +# Per-mode FSM runners +# --------------------------------------------------------------------------- + + +def _run_fsm_fc(ordered_events, streaming_mode): + """FSM for FUNCTION_CALLING mode. + + FC streams per-tool (reasoning → tool_input_start → tool_input) blocks, + then a single run-parallel post_parse_reasoning (no individual post_parses + for parallel tools). Validates per-tool action_input inside run-parallel. + """ + transitions = _TRANSITIONS_FINAL if streaming_mode == StreamingMode.FINAL else _TRANSITIONS_FC + state = State.INIT + visited = {state} + reasoning_blocks: list[str] = [] + tool_blocks: dict[str, dict] = {} + run_parallel_count = 0 + + for idx, (step, content) in enumerate(ordered_events): + event_name = _classify_event(step, content) + _log_event(idx, step, event_name, state, content) + next_state = _fsm_step_transition(event_name, state, transitions, idx, step, content) + + _track_reasoning(event_name, state, next_state, content, reasoning_blocks) + _track_tool_input(event_name, content, tool_blocks) + + if event_name == "post_parse_reasoning" and isinstance(content, dict): + tool_name = content.get("action", "") + is_run_parallel = tool_name == "run-parallel" + + if is_run_parallel: + run_parallel_count += 1 + # FC parallel: multiple per-tool tool_blocks must be in-flight + assert len(tool_blocks) > 1, ( + f"run-parallel post_parse_reasoning but only " + f"{len(tool_blocks)} tool(s) in-flight: " + f"{[b['name'] for b in tool_blocks.values()]}" + ) + assert run_parallel_count == 1, ( + f"Expected exactly 1 run-parallel event per parallel batch, " f"got {run_parallel_count}" + ) + + # Validate per-tool entries inside action_input + action_input = content.get("action_input") + if isinstance(action_input, list): + tool_entries = [e for e in action_input if isinstance(e, dict)] + + # Each tool must have its own reasoning block, validated by index + assert len(reasoning_blocks) >= len(tool_entries), ( + f"run-parallel: {len(tool_entries)} tool entries but only " + f"{len(reasoning_blocks)} reasoning blocks accumulated" + ) + + for i, entry in enumerate(tool_entries): + entry_tid = entry.get("tool_run_id") + entry_input = entry.get("action_input") + entry_thought = entry.get("thought", "") + tool_label = entry.get("action", "?") + + # Validate action_input match + if entry_tid and entry_tid in tool_blocks: + block = tool_blocks[entry_tid] + accumulated = "".join(block["action_input_chunks"]) + if accumulated and entry_input is not None: + assert _match_action_input(accumulated, entry_input), ( + f"run-parallel tool[{i}] {entry_tid} " + f"({tool_label}): accumulated tool_input " + f"does not match action_input. " + f"Accumulated: {accumulated!r}\n" + f"Expected: {entry_input!r}" + ) + + # Validate thought against reasoning_blocks[i]. + # Streaming may emit raw JSON escapes (e.g. \n as two chars), + # so also compare against the JSON-encoded form. + accumulated_thought = reasoning_blocks[i] + expected_thought_raw = json.dumps(entry_thought)[1:-1] + assert accumulated_thought == entry_thought or accumulated_thought == expected_thought_raw, ( + f"run-parallel tool[{i}] {entry_tid} " + f"({tool_label}): reasoning_blocks[{i}] " + f"does not match per-tool thought. " + f"Accumulated: {accumulated_thought!r}\n" + f"Expected: {entry_thought!r}" + ) + else: + _validate_single_post_parse(content, tool_blocks, reasoning_blocks) + + run_parallel_count = _handle_tool_result(event_name, content, tool_blocks, reasoning_blocks, run_parallel_count) + _handle_answer(event_name, reasoning_blocks) + + state = next_state + visited.add(state) + + _assert_fsm_end(tool_blocks, reasoning_blocks, run_parallel_count) + return state, visited, reasoning_blocks + + +def _run_fsm_blob(ordered_events, streaming_mode): + """FSM for STRUCTURED_OUTPUT and XML modes. + + SO/XML stream a single run-parallel blob: + tool_input_start(action=run-parallel) → tool_input chunks → post_parse → tool results + + No per-tool tool_input_start events inside a parallel batch. + """ + transitions = _TRANSITIONS_FINAL if streaming_mode == StreamingMode.FINAL else _TRANSITIONS_WITH_TOOL_INPUT + state = State.INIT + visited = {state} + reasoning_blocks: list[str] = [] + tool_blocks: dict[str, dict] = {} + run_parallel_count = 0 + + for idx, (step, content) in enumerate(ordered_events): + event_name = _classify_event(step, content) + _log_event(idx, step, event_name, state, content) + next_state = _fsm_step_transition(event_name, state, transitions, idx, step, content) + + _track_reasoning(event_name, state, next_state, content, reasoning_blocks) + _track_tool_input(event_name, content, tool_blocks) + + if event_name == "post_parse_reasoning" and isinstance(content, dict): + tool_name = content.get("action", "") + is_run_parallel = tool_name == "run-parallel" + + if is_run_parallel: + run_parallel_count += 1 + # SO/XML parallel: a single run-parallel blob in tool_blocks + has_parallel_context = len(tool_blocks) > 1 or any( + b["name"] == "run-parallel" for b in tool_blocks.values() + ) + assert has_parallel_context, ( + f"run-parallel post_parse_reasoning but no parallel context: " + f"{len(tool_blocks)} tool(s) in-flight: " + f"{[b['name'] for b in tool_blocks.values()]}" + ) + assert run_parallel_count == 1, ( + f"Expected exactly 1 run-parallel event per parallel batch, " f"got {run_parallel_count}" + ) + + # Validate thought matches accumulated reasoning + expected_thought = content.get("thought", "") + if reasoning_blocks and expected_thought: + accumulated_thought = reasoning_blocks[0] + if accumulated_thought: + assert accumulated_thought == expected_thought, ( + f"run-parallel: accumulated reasoning " + f"({len(accumulated_thought)} chars) does not match thought. " + f"Accumulated: {accumulated_thought!r}\n" + f"Expected: {expected_thought!r}" + ) + else: + _validate_single_post_parse(content, tool_blocks, reasoning_blocks) + + run_parallel_count = _handle_tool_result(event_name, content, tool_blocks, reasoning_blocks, run_parallel_count) + _handle_answer(event_name, reasoning_blocks) + + state = next_state + visited.add(state) + + _assert_fsm_end(tool_blocks, reasoning_blocks, run_parallel_count) + return state, visited, reasoning_blocks + + +def _run_fsm_default(ordered_events, streaming_mode): + """FSM for DEFAULT mode. + + DEFAULT mode has no tool_input streaming phase. Only validates + transitions and reasoning/tool_result lifecycle. + """ + transitions = _TRANSITIONS_FINAL if streaming_mode == StreamingMode.FINAL else _TRANSITIONS_DEFAULT + state = State.INIT + visited = {state} + reasoning_blocks: list[str] = [] + run_parallel_count = 0 + tool_blocks: dict[str, dict] = {} + + for idx, (step, content) in enumerate(ordered_events): + event_name = _classify_event(step, content) + _log_event(idx, step, event_name, state, content) + next_state = _fsm_step_transition(event_name, state, transitions, idx, step, content) + + _track_reasoning(event_name, state, next_state, content, reasoning_blocks) + + if event_name == "post_parse_reasoning" and isinstance(content, dict): + tool_name = content.get("action", "") + if tool_name == "run-parallel": + run_parallel_count += 1 + + run_parallel_count = _handle_tool_result(event_name, content, tool_blocks, reasoning_blocks, run_parallel_count) + _handle_answer(event_name, reasoning_blocks) + + state = next_state + visited.add(state) + + _assert_fsm_end(tool_blocks, reasoning_blocks, run_parallel_count) + return state, visited, reasoning_blocks + + +# --------------------------------------------------------------------------- +# High-level assertion +# --------------------------------------------------------------------------- + + +def assert_streaming_events( + ordered_events: list, + inference_mode: InferenceMode, + streaming_mode: StreamingMode = StreamingMode.ALL, +): + """Validate ordered streaming events against the FSM event policy. + + Args: + ordered_events: List of (step, content) tuples from collect_streaming_events(). + inference_mode: The InferenceMode the agent was configured with. + streaming_mode: The StreamingMode used during the run. + """ + assert len(ordered_events) > 0, "No streaming events collected" + + steps = [s for s, _ in ordered_events] + step_counts = {} + for s in steps: + step_counts[s] = step_counts.get(s, 0) + 1 + + logger.info( + f"Asserting streaming FSM {inference_mode.value}/{streaming_mode.value}: " + f"{len(ordered_events)} events, counts = {step_counts}" + ) + + if inference_mode == InferenceMode.FUNCTION_CALLING: + final_state, visited, reasoning_blocks = _run_fsm_fc(ordered_events, streaming_mode) + elif inference_mode in (InferenceMode.STRUCTURED_OUTPUT, InferenceMode.XML): + final_state, visited, reasoning_blocks = _run_fsm_blob(ordered_events, streaming_mode) + else: + final_state, visited, reasoning_blocks = _run_fsm_default(ordered_events, streaming_mode) + + logger.info(f"Reasoning blocks: {reasoning_blocks}") + + assert final_state == State.ANSWER, ( + f"FSM ended in {final_state.name}, expected ANSWER. " f"Last event: {ordered_events[-1]}" + ) + + if streaming_mode == StreamingMode.ALL: + assert State.REASONING in visited, ( + f"{inference_mode.value}/ALL: never entered REASONING state. " f"Visited: {[s.name for s in visited]}" + ) diff --git a/tests/integration_with_creds/agents/test_agent_native_parallel.py b/tests/integration_with_creds/agents/test_agent_native_parallel.py index 1f55e7fae..8b5305515 100644 --- a/tests/integration_with_creds/agents/test_agent_native_parallel.py +++ b/tests/integration_with_creds/agents/test_agent_native_parallel.py @@ -6,12 +6,15 @@ from dynamiq import Workflow, connections from dynamiq.callbacks import TracingCallbackHandler +from dynamiq.callbacks.streaming import StreamingIteratorCallbackHandler from dynamiq.flows import Flow from dynamiq.nodes.agents import Agent from dynamiq.nodes.llms import OpenAI from dynamiq.nodes.tools.python import Python from dynamiq.nodes.types import InferenceMode from dynamiq.runnables import RunnableConfig, RunnableStatus +from dynamiq.types.streaming import StreamingConfig, StreamingMode +from tests.integration_with_creds.agents.streaming_assertions import assert_streaming_events, collect_streaming_events def _make_tool(name: str, code: str) -> Python: @@ -25,7 +28,9 @@ def _make_tool(name: str, code: str) -> Python: @pytest.mark.integration @pytest.mark.flaky(reruns=3) -@pytest.mark.parametrize("inference_mode", [InferenceMode.FUNCTION_CALLING, InferenceMode.XML]) +@pytest.mark.parametrize( + "inference_mode", [InferenceMode.XML, InferenceMode.STRUCTURED_OUTPUT, InferenceMode.FUNCTION_CALLING] +) def test_parallel_tool_calling(inference_mode: InferenceMode): """Agent with parallel_tool_calls_enabled calls two tools in parallel for both inference modes.""" if not os.getenv("OPENAI_API_KEY"): @@ -33,25 +38,28 @@ def test_parallel_tool_calling(inference_mode: InferenceMode): llm = OpenAI(model="gpt-5.4-mini", connection=connections.OpenAI()) - tool_a = _make_tool("CatFacts", 'output = "Cats sleep 12-16 hours per day."') - tool_b = _make_tool("DogFacts", 'output = "Dogs have a sense of smell 40x better than humans."') + tool_a = _make_tool("CatFacts", 'def run(input_data):\n return "Cats sleep 12-16 hours per day."') + tool_b = _make_tool( + "DogFacts", 'def run(input_data):\n return "Dogs have a sense of smell 40x better than humans."' + ) agent = Agent( - name="ParallelAgent", - role="You have two fact tools. When asked about multiple animals, call both tools simultaneously.", + name="parallel_tools_agent", + role="Researcher agent.", llm=llm, tools=[tool_a, tool_b], - inference_mode=inference_mode, parallel_tool_calls_enabled=True, - max_loops=5, + streaming=StreamingConfig(enabled=True, mode=StreamingMode.ALL), + inference_mode=inference_mode, ) tracing = TracingCallbackHandler() + streaming_handler = StreamingIteratorCallbackHandler() wf = Workflow(flow=Flow(nodes=[agent])) result = wf.run( - input_data={"input": "Tell me a fact about cats and a fact about dogs."}, - config=RunnableConfig(callbacks=[tracing]), + input_data={"input": "Call a cat tool and dog tool in parallel"}, + config=RunnableConfig(callbacks=[tracing, streaming_handler]), ) assert result.status == RunnableStatus.SUCCESS @@ -63,3 +71,6 @@ def test_parallel_tool_calling(inference_mode: InferenceMode): run for run in tracing.runs.values() if getattr(run, "metadata", {}).get("node", {}).get("name") == "OpenAI" ] assert len(llm_runs) <= 2, f"Expected at most 2 LLM loops (parallel tools + final answer), got {len(llm_runs)}" + + ordered_events = collect_streaming_events(streaming_handler, agent.id) + assert_streaming_events(ordered_events, inference_mode) diff --git a/tests/integration_with_creds/agents/test_agent_python_tool.py b/tests/integration_with_creds/agents/test_agent_python_tool.py index 54f434a33..0e4eb128e 100644 --- a/tests/integration_with_creds/agents/test_agent_python_tool.py +++ b/tests/integration_with_creds/agents/test_agent_python_tool.py @@ -4,16 +4,22 @@ import pytest from pydantic import BaseModel, ConfigDict, Field +from dynamiq import Workflow +from dynamiq.callbacks.streaming import StreamingIteratorCallbackHandler from dynamiq.connections import Anthropic as AnthropicConnection from dynamiq.connections import OpenAI as OpenAIConnection +from dynamiq.flows import Flow from dynamiq.nodes import Node, NodeGroup from dynamiq.nodes.agents import Agent from dynamiq.nodes.llms import Anthropic, OpenAI from dynamiq.nodes.tools.python import Python from dynamiq.nodes.types import InferenceMode from dynamiq.runnables import RunnableConfig, RunnableStatus +from dynamiq.types.streaming import StreamingConfig, StreamingMode from dynamiq.utils.logger import logger +from .streaming_assertions import assert_streaming_events, collect_streaming_events + class OutputFormat(str, Enum): SUMMARY = "summary" @@ -176,7 +182,7 @@ def llm_instance(): connection = OpenAIConnection() llm = OpenAI( connection=connection, - model="gpt-5-mini", + model="gpt-5.4-mini", max_tokens=5000, temperature=0, ) @@ -212,28 +218,29 @@ def comprehensive_tool(): def run_and_assert_agent(agent: Agent, agent_input, expected_length, run_config): - """Helper function to run agent and perform common assertions.""" + """Helper function to run agent and perform common assertions including streaming validation.""" logger.info(f"\n--- Running Agent: {agent.name} (Mode: {agent.inference_mode.value}) ---") - agent_output = None - try: - result = agent.run(input_data=agent_input, config=run_config) - logger.info(f"Agent raw result object: {result}") - if result.status != RunnableStatus.SUCCESS: - pytest.fail(f"Agent run failed with status '{result.status}'. Output: {result.output}.") + streaming = StreamingIteratorCallbackHandler() + wf = Workflow(flow=Flow(nodes=[agent])) + result = wf.run( + input_data=agent_input, + config=RunnableConfig(callbacks=[streaming], request_timeout=120), + ) - if isinstance(result.output, dict) and "content" in result.output: - agent_output = result.output["content"] - else: - agent_output = result.output - logger.info(f"Warning: Agent output structure unexpected: {type(result.output)}") + assert ( + result.status == RunnableStatus.SUCCESS + ), f"Agent run failed with status '{result.status}'. Output: {result.output}." - logger.info(f"Agent final output content: {agent_output}") + agent_result = result.output.get(agent.id, {}).get("output", {}) + if isinstance(agent_result, dict) and "content" in agent_result: + agent_output = agent_result["content"] + else: + agent_output = agent_result + logger.info(f"Warning: Agent output structure unexpected: {type(agent_result)}") - except Exception as e: - pytest.fail(f"Agent run failed with exception: {e}") + logger.info(f"Agent final output content: {agent_output}") - logger.info("Asserting results...") assert agent_output is not None, "Agent output content should not be None" assert isinstance(agent_output, str), f"Agent output content should be a string, got {type(agent_output)}" @@ -242,6 +249,9 @@ def run_and_assert_agent(agent: Agent, agent_input, expected_length, run_config) expected_length_str in agent_output ), f"Expected length '{expected_length_str}' not found in agent output: '{agent_output}'" + ordered_events = collect_streaming_events(streaming, agent.id) + assert_streaming_events(ordered_events, agent.inference_mode, agent.streaming.mode) + logger.info(f"--- Test Passed for Mode: {agent.inference_mode.value} ---") @@ -267,12 +277,16 @@ def test_react_agent_inference_modes( role=agent_role, inference_mode=inference_mode, verbose=True, + streaming=StreamingConfig( + enabled=True, + mode=StreamingMode.ALL, + ), ) run_and_assert_agent(agent, agent_input, expected_length, run_config) def _run_comprehensive_schema_test(llm, comprehensive_tool, run_config, inference_mode, label): - """Helper: run agent with ComprehensiveTool and assert success.""" + """Helper: run agent with ComprehensiveTool and assert success with streaming validation.""" agent = Agent( name=f"Comprehensive Schema Test ({label})", llm=llm, @@ -284,18 +298,30 @@ def _run_comprehensive_schema_test(llm, comprehensive_tool, run_config, inferenc inference_mode=inference_mode, max_loops=5, verbose=True, + streaming=StreamingConfig( + enabled=True, + mode=StreamingMode.ALL, + ), ) - result = agent.run( + + streaming = StreamingIteratorCallbackHandler() + wf = Workflow(flow=Flow(nodes=[agent])) + result = wf.run( input_data={ "input": "Analyze the text 'Hello world from Python testing' in summary format with limit " "5 and high priority" }, - config=run_config, + config=RunnableConfig(callbacks=[streaming], request_timeout=120), ) + assert result.status == RunnableStatus.SUCCESS, f"Agent run failed for {label}: {result.output}" - content = result.output.get("content", "") + agent_result = result.output.get(agent.id, {}).get("output", {}) + content = agent_result.get("content", "") if isinstance(agent_result, dict) else agent_result assert isinstance(content, str) and len(content) > 0, f"Expected non-empty string for {label}, got: {content!r}" + ordered_events = collect_streaming_events(streaming, agent.id) + assert_streaming_events(ordered_events, inference_mode, agent.streaming.mode) + @pytest.mark.integration @pytest.mark.flaky(reruns=3)