From 0e88ec91f9eb029074f4283b9df52ef42d0a8a9a Mon Sep 17 00:00:00 2001
From: GuanyiLi-Craig
Date: Wed, 31 Dec 2025 14:04:12 +0000
Subject: [PATCH 1/7] improve async call and tests

---
 grafi/tools/llms/impl/gemini_tool.py          |    2 +-
 grafi/workflows/impl/event_driven_workflow.py |    2 +
 tests/assistants/test_assistant_mock_llm.py   | 2282 +++++++++++++++++
 .../simple_gemini_function_call_assistant.py  |    2 +-
 .../gemini_tool_example.py                    |    5 +-
 5 files changed, 2289 insertions(+), 4 deletions(-)
 create mode 100644 tests/assistants/test_assistant_mock_llm.py

diff --git a/grafi/tools/llms/impl/gemini_tool.py b/grafi/tools/llms/impl/gemini_tool.py
index 04ef60e..ee11320 100644
--- a/grafi/tools/llms/impl/gemini_tool.py
+++ b/grafi/tools/llms/impl/gemini_tool.py
@@ -60,7 +60,7 @@ class GeminiTool(LLM):
     name: str = Field(default="GeminiTool")
     type: str = Field(default="GeminiTool")
     api_key: Optional[str] = Field(default_factory=lambda: os.getenv("GEMINI_API_KEY"))
-    model: str = Field(default="gemini-2.0-flash-lite")
+    model: str = Field(default="gemini-2.5-flash-lite")
 
     @classmethod
     def builder(cls) -> "GeminiToolBuilder":
diff --git a/grafi/workflows/impl/event_driven_workflow.py b/grafi/workflows/impl/event_driven_workflow.py
index 92d8fd6..1efc108 100644
--- a/grafi/workflows/impl/event_driven_workflow.py
+++ b/grafi/workflows/impl/event_driven_workflow.py
@@ -543,6 +543,8 @@ def _cancel_all_active_tasks() -> None:
 
         except Exception as node_error:
             logger.error(f"Error processing node {node.name}: {node_error}")
+            # Force stop the tracker so the workflow terminates
+            await self._tracker.force_stop()
             raise NodeExecutionError(
                 node_name=node.name,
                 message=f"Async node execution failed: {node_error}",
diff --git a/tests/assistants/test_assistant_mock_llm.py b/tests/assistants/test_assistant_mock_llm.py
new file mode 100644
index 0000000..923ebac
--- /dev/null
+++ b/tests/assistants/test_assistant_mock_llm.py
@@ -0,0 +1,2282 @@
+"""
+Unit tests for assistants using FunctionTool to simulate LLM behavior.
+
+This module provides tests for:
+1. Human-in-the-Loop (HITL) workflows
+2. ReAct agent patterns with function calling
+
+All tests use FunctionTool to deterministically mock LLM responses,
+enabling reliable unit testing without real API calls.
+""" + +import base64 +import inspect +import uuid +from typing import Any +from typing import Callable +from typing import List +from typing import Union +from unittest.mock import patch + +import cloudpickle +import pytest +from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, +) +from openai.types.chat.chat_completion_message_tool_call import Function +from openinference.semconv.trace import OpenInferenceSpanKindValues + +from grafi.assistants.assistant import Assistant +from grafi.common.decorators.record_decorators import record_tool_invoke +from grafi.common.events.topic_events.publish_to_topic_event import PublishToTopicEvent +from grafi.common.exceptions import FunctionToolException +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.common.models.message import Messages +from grafi.common.models.message import MsgsAGen +from grafi.nodes.node import Node +from grafi.tools.command import Command +from grafi.tools.command import use_command +from grafi.tools.function_calls.function_call_tool import FunctionCallTool +from grafi.tools.tool import Tool +from grafi.topics.expressions.subscription_builder import SubscriptionBuilder +from grafi.topics.topic_impl.in_workflow_input_topic import InWorkflowInputTopic +from grafi.topics.topic_impl.in_workflow_output_topic import InWorkflowOutputTopic +from grafi.topics.topic_impl.input_topic import InputTopic +from grafi.topics.topic_impl.output_topic import OutputTopic +from grafi.topics.topic_impl.topic import Topic +from grafi.workflows.impl.event_driven_workflow import EventDrivenWorkflow + + +@use_command(Command) +class LLMMockTool(Tool): + name: str = "LLMMockTool" + type: str = "LLMMockTool" + role: str = "assistant" + function: Callable[[Messages], Union[Message, Messages]] + oi_span_type: OpenInferenceSpanKindValues = OpenInferenceSpanKindValues.TOOL + + @record_tool_invoke + async def invoke( + self, invoke_context: InvokeContext, input_data: Messages + ) -> MsgsAGen: + try: + response = self.function(input_data) + if inspect.isasyncgen(response): + async for item in response: + # Ensure item is always a list + if isinstance(item, list): + yield item + else: + yield [item] + return + if inspect.isawaitable(response): + response = await response + + # Ensure response is always a list + if isinstance(response, list): + yield response + else: + yield [response] + except Exception as e: + raise FunctionToolException( + tool_name=self.name, + operation="invoke", + message=f"Async function execution failed: {e}", + invoke_context=invoke_context, + cause=e, + ) from e + + + def to_dict(self) -> dict[str, Any]: + """ + Convert the tool instance to a dictionary representation. + + Returns: + Dict[str, Any]: A dictionary representation of the tool. + """ + return { + **super().to_dict(), + "role": self.role, + "base_class": "FunctionTool", + "function": base64.b64encode(cloudpickle.dumps(self.function)).decode( + "utf-8" + ), + } + + @classmethod + async def from_dict(cls, data: dict[str, Any]) -> "LLMMockTool": + """ + Create a FunctionTool instance from a dictionary representation. + + Args: + data (dict[str, Any]): A dictionary representation of the FunctionTool. + + Returns: + FunctionTool: A FunctionTool instance created from the dictionary. + + Note: + Functions cannot be fully reconstructed from serialized data as they + contain executable code. 
This method creates an instance with a + placeholder function that needs to be re-registered after deserialization. + """ + + return cls( + name=data.get("name", "LLMMockTool"), + type=data.get("type", "LLMMockTool"), + role=data.get("role", "assistant"), + function=lambda messages: messages, + oi_span_type=OpenInferenceSpanKindValues.TOOL, + ) + + + +def make_tool_call(call_id: str, name: str, arguments: str) -> ChatCompletionMessageToolCall: + """Helper to create tool calls.""" + return ChatCompletionMessageToolCall( + id=call_id, + type="function", + function=Function(name=name, arguments=arguments), + ) + + +class TestReActAgentWithMockLLM: + """ + Test ReAct agent patterns using FunctionTool to simulate LLM behavior. + + ReAct (Reasoning and Acting) agent pattern: + 1. LLM receives input and decides whether to call a function or respond + 2. If function call -> execute function -> return result to LLM + 3. LLM processes function result and decides next action + 4. Loop continues until LLM generates final response (no function call) + """ + + @pytest.fixture + def invoke_context(self): + """Create a test invoke context.""" + return InvokeContext( + conversation_id=uuid.uuid4().hex, + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + @pytest.mark.asyncio + async def test_react_agent_no_function_call(self, invoke_context): + """ + Test ReAct agent when LLM directly responds without function calls. + + Flow: Input -> LLM (no function call) -> Output + """ + # Mock LLM that always responds directly without function calls + def mock_llm(messages: List[Message]) -> List[Message]: + user_content = messages[-1].content if messages else "" + return [Message(role="assistant", content=f"Direct response to: {user_content}")] + + # Create topics + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + # Only output when there's content and no tool calls + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + + # Create LLM node + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input).build()) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + # Build workflow + workflow = ( + EventDrivenWorkflow.builder() + .name("react_no_func_workflow") + .node(llm_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestReActAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Hello, how are you?")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "Direct response to: Hello, how are you?" in results[0].data[0].content + + @pytest.mark.asyncio + async def test_react_agent_single_function_call(self, invoke_context): + """ + Test ReAct agent with a single function call. 
+ + Flow: Input -> LLM (function call) -> Function -> LLM (response) -> Output + """ + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + """Mock LLM that calls function on first call, responds on second.""" + call_count["llm"] += 1 + + if call_count["llm"] == 1: + # First call: make a function call + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "call_1", + "search", + '{"query": "weather today"}', + ) + ], + )] + else: + # Second call: respond with the function result + last_msg = messages[-1] if messages else Message(role="user", content="") + return [Message( + role="assistant", + content=f"Based on search: {last_msg.content}", + )] + + def search(self, query: str) -> str: + """Mock search function.""" + return "The weather is sunny, 72°F" + + # Create topics + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + # Create LLM node + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + # Create function call node + function_node = ( + Node.builder() + .name("SearchNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool(FunctionCallTool.builder().name("SearchTool").function(search).build()) + .publish_to(function_result_topic) + .build() + ) + + # Build workflow + workflow = ( + EventDrivenWorkflow.builder() + .name("react_single_func_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestReActAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="What's the weather?")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "Based on search" in results[0].data[0].content + assert call_count["llm"] == 2 + + @pytest.mark.asyncio + async def test_react_agent_multiple_function_calls(self, invoke_context): + """ + Test ReAct agent with multiple sequential function calls. + + Flow: Input -> LLM (func1) -> Func1 -> LLM (func2) -> Func2 -> LLM (response) -> Output + """ + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + """Mock LLM that makes multiple function calls.""" + call_count["llm"] += 1 + + if call_count["llm"] == 1: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("call_1", "get_user", '{"id": "123"}') + ], + )] + elif call_count["llm"] == 2: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "call_2", + "get_orders", + '{"user_id": "123"}', + ) + ], + )] + else: + return [Message( + role="assistant", content="User John has 3 orders totaling $150." 
+ )] + + def get_user(self, id: str) -> str: + """Mock get_user function.""" + return '{"name": "John", "id": "123"}' + + def get_orders(self, user_id: str) -> str: + """Mock get_orders function.""" + return '{"orders": 3, "total": "$150"}' + + # Create topics + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("FunctionNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("MockFunction") + .function(get_user) + .function(get_orders) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("react_multi_func_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestReActAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Get user 123's order summary")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "John" in results[0].data[0].content + assert "3 orders" in results[0].data[0].content + assert call_count["llm"] == 3 + + +class TestHumanInTheLoopWithMockLLM: + """ + Test Human-in-the-Loop (HITL) workflows using FunctionTool to simulate LLM behavior. + + HITL workflow pattern: + 1. LLM processes input and decides to request human input + 2. Workflow pauses, emits event for human response + 3. Human provides input via InWorkflowInputTopic + 4. Workflow continues with human input + 5. LLM generates final response + """ + + @pytest.fixture + def invoke_context(self): + """Create a test invoke context.""" + return InvokeContext( + conversation_id=uuid.uuid4().hex, + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + @pytest.mark.asyncio + async def test_hitl_workflow_no_human_input_needed(self, invoke_context): + """ + Test HITL workflow when LLM can respond without human input. 
+ + Flow: Input -> LLM (direct response) -> Output + """ + def mock_llm(messages: List[Message]) -> List[Message]: + return [Message(role="assistant", content="I can answer this directly!")] + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + hitl_call_topic = Topic( + name="hitl_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input).build()) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(hitl_call_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("hitl_no_human_workflow") + .node(llm_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestHITLAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Simple question")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert results[0].data[0].content == "I can answer this directly!" + + @pytest.mark.asyncio + async def test_hitl_workflow_with_human_approval(self, invoke_context): + """ + Test HITL workflow that requests and receives human approval. + + Flow: + 1. Input -> LLM (requests approval) -> HITL Output + 2. Human approval -> HITL Input -> LLM (processes approval) -> Output + """ + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + """Mock LLM that requests approval on first call.""" + call_count["llm"] += 1 + + if call_count["llm"] == 1: + # First call: request human approval + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "approval_1", + "request_approval", + '{"action": "delete_account", "reason": "user requested"}', + ) + ], + )] + else: + # Second call: process approval result + last_content = messages[-1].content if messages else "" + if "approved" in last_content.lower(): + return [Message( + role="assistant", + content="Account deletion has been approved and completed.", + )] + else: + return [Message( + role="assistant", + content="Account deletion was rejected.", + )] + + def request_approval(self, action: str, reason: str) -> str: + """Mock HITL request that simulates human approval.""" + # In a real scenario, this would pause and wait for human input + # For testing, we simulate automatic approval + return "Action APPROVED by human reviewer" + + # Create topics + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + hitl_call_topic = Topic( + name="hitl_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + + # HITL topics for human interaction + human_response_topic = InWorkflowInputTopic(name="human_response") + human_request_topic = InWorkflowOutputTopic( + name="human_request", + paired_in_workflow_input_topic_names=["human_response"], + ) + hitl_result_topic = Topic(name="hitl_result") + + # LLM node + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(hitl_result_topic) + .build() + ) + 
.tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(hitl_call_topic) + .build() + ) + + # HITL request node + hitl_node = ( + Node.builder() + .name("HITLRequestNode") + .subscribe(SubscriptionBuilder().subscribed_to(hitl_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("HITLRequest") + .function(request_approval) + .build() + ) + .publish_to(hitl_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("hitl_approval_workflow") + .node(llm_node) + .node(hitl_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestHITLAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Please delete my account")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "approved" in results[0].data[0].content.lower() + assert call_count["llm"] == 2 + + @pytest.mark.asyncio + async def test_hitl_workflow_with_rejection(self, invoke_context): + """ + Test HITL workflow where human rejects the action. + """ + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + call_count["llm"] += 1 + + if call_count["llm"] == 1: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "approval_1", + "request_approval", + '{"action": "transfer_funds", "amount": "$10000"}', + ) + ], + )] + else: + last_content = messages[-1].content if messages else "" + if "rejected" in last_content.lower(): + return [Message( + role="assistant", + content="The fund transfer was not approved. No action taken.", + )] + return [Message(role="assistant", content="Transfer completed.")] + + def request_approval(self, action: str, amount: str) -> str: + """Simulate human rejection.""" + return "Action REJECTED - amount too high" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + hitl_call_topic = Topic( + name="hitl_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + hitl_result_topic = Topic(name="hitl_result") + + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(hitl_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(hitl_call_topic) + .build() + ) + + hitl_node = ( + Node.builder() + .name("HITLRejectNode") + .subscribe(SubscriptionBuilder().subscribed_to(hitl_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("HITLReject") + .function(request_approval) + .build() + ) + .publish_to(hitl_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("hitl_rejection_workflow") + .node(llm_node) + .node(hitl_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestHITLAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Transfer $10000 to account X")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "not approved" in results[0].data[0].content.lower() + + @pytest.mark.asyncio + async def 
test_hitl_workflow_multi_step_approval(self, invoke_context): + """ + Test HITL workflow with multiple approval steps. + + Flow: Input -> LLM (approval1) -> Human1 -> LLM (approval2) -> Human2 -> LLM -> Output + """ + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + call_count["llm"] += 1 + + if call_count["llm"] == 1: + # First approval: manager + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "approval_1", + "request_manager_approval", + '{"action": "large_purchase", "amount": "$5000"}', + ) + ], + )] + elif call_count["llm"] == 2: + # Second approval: finance + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "approval_2", + "request_finance_approval", + '{"action": "large_purchase", "amount": "$5000"}', + ) + ], + )] + else: + # Final response + return [Message( + role="assistant", + content="Purchase approved by manager and finance. Order placed!", + )] + + approval_count = {"count": 0} + + def request_manager_approval(self, action: str, amount: str) -> str: + approval_count["count"] += 1 + return "Manager APPROVED" + + def request_finance_approval(self, action: str, amount: str) -> str: + approval_count["count"] += 1 + return "Finance APPROVED" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + approval_call_topic = Topic( + name="approval_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + approval_result_topic = Topic(name="approval_result") + + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(approval_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(approval_call_topic) + .build() + ) + + approval_node = ( + Node.builder() + .name("ApprovalNode") + .subscribe(SubscriptionBuilder().subscribed_to(approval_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("ApprovalTool") + .function(request_manager_approval) + .function(request_finance_approval) + .build() + ) + .publish_to(approval_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("hitl_multi_approval_workflow") + .node(llm_node) + .node(approval_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestHITLAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="I need to purchase equipment for $5000")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "manager" in results[0].data[0].content.lower() + assert "finance" in results[0].data[0].content.lower() + assert call_count["llm"] == 3 + assert approval_count["count"] == 2 + + +class TestComplexWorkflowPatterns: + """ + Test more complex workflow patterns combining multiple features. + """ + + @pytest.fixture + def invoke_context(self): + return InvokeContext( + conversation_id=uuid.uuid4().hex, + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + @pytest.mark.asyncio + async def test_conditional_branching_workflow(self, invoke_context): + """ + Test workflow with conditional branching based on LLM output. 
+ + Flow: + - Input -> Router LLM + - If question about weather -> Weather function -> Response LLM + - If question about math -> Math function -> Response LLM + - Otherwise -> Direct response + """ + def mock_router(messages: List[Message]) -> List[Message]: + """Route based on input content.""" + content = messages[-1].content.lower() if messages else "" + if "weather" in content: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("w1", "weather", '{"location": "NYC"}') + ], + )] + elif "math" in content or "calculate" in content: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("m1", "math", '{"expr": "2+2"}') + ], + )] + else: + return [Message( + role="assistant", content="I can help with weather or math questions!" + )] + + def weather(self, location: str) -> str: + return "Weather in NYC: Sunny, 75°F" + + def math(self, expr: str) -> str: + return "Result: 4" + + def mock_response(messages: List[Message]) -> List[Message]: + """Generate final response from function result.""" + last_content = messages[-1].content if messages else "" + return [Message(role="assistant", content=f"Here's what I found: {last_content}")] + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + weather_topic = Topic( + name="weather_call", + condition=lambda event: ( + event.data[-1].tool_calls is not None + and any(tc.function.name == "weather" for tc in event.data[-1].tool_calls) + ), + ) + math_topic = Topic( + name="math_call", + condition=lambda event: ( + event.data[-1].tool_calls is not None + and any(tc.function.name == "math" for tc in event.data[-1].tool_calls) + ), + ) + function_result_topic = Topic(name="function_result") + + router_node = ( + Node.builder() + .name("RouterNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input).build()) + .tool(LLMMockTool(function=mock_router)) + .publish_to(agent_output) + .publish_to(weather_topic) + .publish_to(math_topic) + .build() + ) + + weather_node = ( + Node.builder() + .name("WeatherNode") + .subscribe(SubscriptionBuilder().subscribed_to(weather_topic).build()) + .tool(FunctionCallTool.builder().name("Weather").function(weather).build()) + .publish_to(function_result_topic) + .build() + ) + + math_node = ( + Node.builder() + .name("MathNode") + .subscribe(SubscriptionBuilder().subscribed_to(math_topic).build()) + .tool(FunctionCallTool.builder().name("Math").function(math).build()) + .publish_to(function_result_topic) + .build() + ) + + response_node = ( + Node.builder() + .name("ResponseNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_result_topic).build()) + .tool(LLMMockTool(function=mock_response)) + .publish_to(agent_output) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("conditional_workflow") + .node(router_node) + .node(weather_node) + .node(math_node) + .node(response_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="ConditionalAgent", workflow=workflow) + + # Test weather branch + weather_input = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="What's the weather in NYC?")], + ) + weather_results = [] + async for event in assistant.invoke(weather_input): + weather_results.append(event) + + assert len(weather_results) == 1 + assert "Sunny" in 
weather_results[0].data[0].content + + @pytest.mark.asyncio + async def test_parallel_function_execution(self, invoke_context): + """ + Test workflow where LLM calls multiple functions that can execute in parallel. + + The FunctionTool will handle multiple tool_calls in a single message. + """ + + def mock_llm_parallel(messages: List[Message]) -> List[Message]: + """LLM that requests multiple functions at once.""" + # Check if we have function results + has_results = any(msg.role == "tool" for msg in messages) + if has_results: + return [Message( + role="assistant", + content="Combined weather and news: Great day, no major events!", + )] + else: + # Request both functions at once + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("w1", "weather", "{}"), + make_tool_call("n1", "news", "{}"), + ], + )] + + def weather(self) -> str: + """Handle weather function call.""" + return "Weather: Sunny" + + def news(self) -> str: + """Handle news function call.""" + return "News: Markets up 2%" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm_parallel)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("FunctionNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("CombinedFunc") + .function(weather) + .function(news) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("parallel_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="ParallelAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="What's the weather and news?")], + ) + + results = [] + + print("starting invocation") + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "weather" in results[0].data[0].content.lower() + + @pytest.mark.asyncio + async def test_error_handling_in_function_call(self, invoke_context): + """ + Test workflow handles function errors gracefully. + """ + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + call_count["llm"] += 1 + if call_count["llm"] == 1: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("f1", "failing_func", "{}") + ], + )] + else: + # Handle error from function + last_content = messages[-1].content if messages else "" + if "error" in last_content.lower(): + return [Message( + role="assistant", + content="I encountered an error. 
Let me try a different approach.", + )] + return [Message(role="assistant", content="Success!")] + + def failing_func(self) -> str: + """Function that returns an error.""" + return "Error: Service unavailable" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("FailingNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("FailingFunc") + .function(failing_func) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("error_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="ErrorHandlingAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Do something that might fail")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "error" in results[0].data[0].content.lower() + + @pytest.mark.asyncio + async def test_context_preservation_across_turns(self, invoke_context): + """ + Test that context is properly passed through multiple turns. 
+ """ + accumulated_context = [] + + def mock_llm_with_context(messages: List[Message]) -> List[Message]: + """LLM that tracks conversation context.""" + accumulated_context.append([m.content for m in messages if m.content]) + + if len(accumulated_context) == 1: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("c1", "context_func", "{}") + ], + )] + else: + # Return summary of all seen content + all_content = [c for ctx in accumulated_context for c in ctx] + return [Message( + role="assistant", + content=f"Processed {len(all_content)} messages", + )] + + def context_func(self) -> str: + return "Context function executed" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("ContextLLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm_with_context)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("ContextFuncNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("ContextFunc") + .function(context_func) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("context_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="ContextAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Test context preservation")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "Processed" in results[0].data[0].content + # Verify context was accumulated + assert len(accumulated_context) == 2 + + +class TestEdgeCasesAndExceptions: + """ + Test edge cases, error handling, and exception scenarios. + + These tests ensure the workflow handles: + 1. Exceptions during tool execution + 2. Exceptions in LLM mock functions + 3. Empty/invalid message handling + 4. Workflow stop on error + 5. Serialization/deserialization + 6. Edge cases in data flow + """ + + @pytest.fixture + def invoke_context(self): + return InvokeContext( + conversation_id=uuid.uuid4().hex, + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + @pytest.mark.asyncio + async def test_exception_in_function_call_tool(self, invoke_context): + """ + Test that exceptions in FunctionCallTool are properly propagated. 
+ """ + def mock_llm(messages: List[Message]) -> List[Message]: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("err1", "raise_error", '{}') + ], + )] + + def raise_error(self) -> str: + raise ValueError("Intentional test error in function call") + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input).build()) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("ErrorNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("ErrorTool") + .function(raise_error) + .build() + ) + .publish_to(agent_output) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("error_test_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="ErrorTestAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Trigger error")], + ) + + from grafi.common.exceptions import NodeExecutionError + with pytest.raises(NodeExecutionError) as exc_info: + async for _ in assistant.invoke(input_data): + pass + + assert "ErrorNode" in str(exc_info.value) or "Intentional test error" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_exception_in_llm_mock_tool(self, invoke_context): + """ + Test that exceptions in LLMMockTool are properly propagated. + """ + def failing_llm(messages: List[Message]) -> List[Message]: + raise RuntimeError("LLM processing failed") + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: event.data[-1].content is not None, + ) + + llm_node = ( + Node.builder() + .name("FailingLLMNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input).build()) + .tool(LLMMockTool(function=failing_llm)) + .publish_to(agent_output) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("failing_llm_workflow") + .node(llm_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="FailingLLMAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Trigger LLM error")], + ) + + from grafi.common.exceptions import NodeExecutionError + with pytest.raises(NodeExecutionError) as exc_info: + async for _ in assistant.invoke(input_data): + pass + + assert "LLM processing failed" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_llm_returns_empty_content(self, invoke_context): + """ + Test handling when LLM returns a message with empty content but no tool calls. 
+ """ + def empty_content_llm(messages: List[Message]) -> List[Message]: + return [Message(role="assistant", content="")] + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: event.data[-1].content is not None, + ) + + llm_node = ( + Node.builder() + .name("EmptyLLMNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input).build()) + .tool(LLMMockTool(function=empty_content_llm)) + .publish_to(agent_output) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("empty_content_workflow") + .node(llm_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="EmptyContentAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Get empty response")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert results[0].data[0].content == "" + + @pytest.mark.asyncio + async def test_llm_returns_single_message_not_list(self, invoke_context): + """ + Test that LLMMockTool properly wraps single Message in a list. + """ + def single_message_llm(messages: List[Message]) -> Message: + # Return single Message, not list + return Message(role="assistant", content="Single message response") + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: event.data[-1].content is not None, + ) + + llm_node = ( + Node.builder() + .name("SingleMsgLLMNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input).build()) + .tool(LLMMockTool(function=single_message_llm)) + .publish_to(agent_output) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("single_msg_workflow") + .node(llm_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="SingleMsgAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Get single message")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert results[0].data[0].content == "Single message response" + + @pytest.mark.asyncio + async def test_function_call_with_invalid_json_arguments(self, invoke_context): + """ + Test handling of tool calls with malformed JSON arguments. 
+ """ + def mock_llm(messages: List[Message]) -> List[Message]: + return [Message( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id="bad_json", + type="function", + function=Function(name="some_func", arguments="not valid json{"), + ) + ], + )] + + def some_func(self) -> str: + return "Should not reach here" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input).build()) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("FuncNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("SomeFunc") + .function(some_func) + .build() + ) + .publish_to(agent_output) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("invalid_json_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="InvalidJsonAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Trigger invalid JSON")], + ) + + from grafi.common.exceptions import NodeExecutionError + with pytest.raises(NodeExecutionError): + async for _ in assistant.invoke(input_data): + pass + + @pytest.mark.asyncio + async def test_function_not_found_in_tool(self, invoke_context): + """ + Test handling when LLM calls a function that isn't registered. 
+ """ + + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + """Mock LLM that calls function on first call, responds on second.""" + call_count["llm"] += 1 + + if call_count["llm"] == 1: + # First call: make a function call + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("missing", "nonexistent_function", '{}') + ], + ) + ] + else: + return [Message( + role="assistant", + content="Function not found.", + )] + + def existing_func(self) -> str: + return "This is an existing function" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("FuncNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("ExistingFunc") + .function(existing_func) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("missing_func_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="MissingFuncAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Call missing function")], + ) + + # When function is not found, FunctionCallTool yields empty messages + # This may cause workflow to hang or complete without output + # The test verifies this edge case is handled + results = [] + + try: + async for event in assistant.invoke(input_data): + results.append(event) + except Exception: + # Either an exception or empty results is acceptable + pass + + @pytest.mark.asyncio + async def test_workflow_stops_on_node_exception(self, invoke_context): + """ + Test that workflow stops processing when a node raises an exception. + This verifies the force_stop fix in _invoke_node. 
+ """ + call_count = {"count": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + call_count["count"] += 1 + if call_count["count"] == 1: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("fail", "fail_func", '{}') + ], + )] + # Should not reach here if workflow stops on error + return [Message(role="assistant", content="Should not see this")] + + def fail_func(self) -> str: + raise Exception("Node failure - workflow should stop") + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("FailNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("FailTool") + .function(fail_func) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("stop_on_error_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="StopOnErrorAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Trigger failure")], + ) + + from grafi.common.exceptions import NodeExecutionError + + # Workflow should raise exception and stop, not hang + with pytest.raises(NodeExecutionError): + async for _ in assistant.invoke(input_data): + pass + + # Verify LLM was only called once (workflow stopped after error) + assert call_count["count"] == 1 + + @pytest.mark.asyncio + async def test_llm_mock_tool_serialization(self, invoke_context): + """ + Test LLMMockTool to_dict and from_dict methods. + """ + def sample_llm(messages: List[Message]) -> List[Message]: + return [Message(role="assistant", content="Serialization test")] + + tool = LLMMockTool( + name="SerializationTestTool", + function=sample_llm, + ) + + # Test to_dict + tool_dict = tool.to_dict() + assert tool_dict["name"] == "SerializationTestTool" + assert "function" in tool_dict + + # Test from_dict + restored_tool = await LLMMockTool.from_dict(tool_dict) + assert restored_tool.name == "SerializationTestTool" + + @pytest.mark.asyncio + async def test_multiple_tool_calls_in_single_message(self, invoke_context): + """ + Test handling multiple tool calls in a single LLM response. 
+ """ + def mock_llm(messages: List[Message]) -> List[Message]: + has_results = any(msg.role == "tool" for msg in messages) + if has_results: + return [Message( + role="assistant", + content="Got results from both functions", + )] + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("t1", "func_a", '{}'), + make_tool_call("t2", "func_b", '{}'), + ], + )] + + def func_a(self) -> str: + return "Result A" + + def func_b(self) -> str: + return "Result B" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("MultiFunc") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("MultiFuncTool") + .function(func_a) + .function(func_b) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("multi_tool_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="MultiToolAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Call multiple functions")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "results" in results[0].data[0].content.lower() + + @pytest.mark.asyncio + async def test_function_returns_complex_json(self, invoke_context): + """ + Test function that returns complex JSON data. 
+ """ + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + call_count["llm"] += 1 + if call_count["llm"] == 1: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("json1", "get_complex_data", '{}') + ], + )] + # Check if we received the complex data + last_content = messages[-1].content if messages else "" + return [Message( + role="assistant", + content=f"Received complex data: {last_content[:50]}...", + )] + + def get_complex_data(self) -> str: + import json + return json.dumps({ + "users": [ + {"id": 1, "name": "Alice", "roles": ["admin", "user"]}, + {"id": 2, "name": "Bob", "roles": ["user"]}, + ], + "metadata": { + "total": 2, + "page": 1, + "nested": {"deep": {"value": True}} + } + }) + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("JsonFunc") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("JsonTool") + .function(get_complex_data) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("complex_json_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="JsonAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Get complex data")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "complex data" in results[0].data[0].content.lower() + + @pytest.mark.asyncio + async def test_function_with_special_characters_in_args(self, invoke_context): + """ + Test function call with special characters in arguments. 
+ """ + received_args = {} + + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + call_count["llm"] += 1 + if call_count["llm"] == 1: + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "special", + "process_text", + '{"text": "Hello\\nWorld\\twith\\ttabs", "query": "test \\"quoted\\""}' + ) + ], + )] + return [Message(role="assistant", content="Processed special chars")] + + def process_text(self, text: str, query: str) -> str: + received_args["text"] = text + received_args["query"] = query + return f"Processed: {text}" + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + llm_node = ( + Node.builder() + .name("LLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + function_node = ( + Node.builder() + .name("TextFunc") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("TextTool") + .function(process_text) + .build() + ) + .publish_to(function_result_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("special_chars_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="SpecialCharsAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Test special characters")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + # Verify special characters were properly parsed + assert "\n" in received_args.get("text", "") + assert "\t" in received_args.get("text", "") + assert '"' in received_args.get("query", "") + + + @pytest.mark.asyncio + async def test_react_agent_single_function_call_twice(self): + """ + Test ReAct agent with a single function call. 
+ + Flow: Input -> LLM (function call) -> Function -> LLM (response) -> Output + """ + call_count = {"llm": 0} + + invoke_context = InvokeContext( + conversation_id=uuid.uuid4().hex, + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + def mock_llm(messages: List[Message]) -> List[Message]: + """Mock LLM that calls function on first call, responds on second.""" + call_count["llm"] += 1 + + if call_count["llm"] == 1: + # First call: make a function call + return [Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "call_1", + "search", + '{"query": "weather today"}', + ) + ], + )] + else: + # Second call: respond with the function result + last_msg = messages[-1] if messages else Message(role="user", content="") + return [Message( + role="assistant", + content=f"Based on search: {last_msg.content}", + )] + + def search(self, query: str) -> str: + """Mock search function.""" + return "The weather is sunny, 72°F" + + # Create topics + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None + and event.data[-1].tool_calls is None + ), + ) + function_call_topic = Topic( + name="function_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + function_result_topic = Topic(name="function_result") + + # Create LLM node + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(function_result_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(function_call_topic) + .build() + ) + + # Create function call node + function_node = ( + Node.builder() + .name("SearchNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) + .tool(FunctionCallTool.builder().name("SearchTool").function(search).build()) + .publish_to(function_result_topic) + .build() + ) + + # Build workflow + workflow = ( + EventDrivenWorkflow.builder() + .name("react_single_func_workflow") + .node(llm_node) + .node(function_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestReActAgent", workflow=workflow) + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="What's the weather?")], + ) + + results = [] + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "Based on search" in results[0].data[0].content + assert call_count["llm"] == 2 + + input_data = PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="What's the weather again?")], + ) + + async for event in assistant.invoke(input_data): + results.append(event) + + assert len(results) == 1 + assert "Based on search" in results[0].data[0].content + assert call_count["llm"] == 2 \ No newline at end of file diff --git a/tests_integration/function_call_assistant/simple_gemini_function_call_assistant.py b/tests_integration/function_call_assistant/simple_gemini_function_call_assistant.py index 0edb6ce..6144c00 100644 --- a/tests_integration/function_call_assistant/simple_gemini_function_call_assistant.py +++ b/tests_integration/function_call_assistant/simple_gemini_function_call_assistant.py @@ -39,7 +39,7 @@ class SimpleGeminiFunctionCallAssistant(Assistant): name: str = Field(default="SimpleGeminiFunctionCallAssistant") type: str = 
Field(default="SimpleGeminiFunctionCallAssistant") api_key: str = Field(default_factory=lambda: os.getenv("GEMINI_API_KEY", "")) - model: str = Field(default="gemini-2.0-flash-lite") + model: str = Field(default="gemini-2.5-flash-lite") function_call_llm_system_message: Optional[str] = Field(default=None) summary_llm_system_message: Optional[str] = Field(default=None) function_tool: FunctionCallTool diff --git a/tests_integration/simple_llm_assistant/gemini_tool_example.py b/tests_integration/simple_llm_assistant/gemini_tool_example.py index 74dc44a..47d17d4 100644 --- a/tests_integration/simple_llm_assistant/gemini_tool_example.py +++ b/tests_integration/simple_llm_assistant/gemini_tool_example.py @@ -72,7 +72,8 @@ async def test_gemini_tool_with_chat_param() -> None: # 15 tokens ~ < 120 chars in normal language if isinstance(message.content, str): # Ensure the content length is within the expected range - assert len(message.content) < 150 + print(len(message.content)) + assert len(message.content) < 300 assert len(await event_store.get_events()) == 2 @@ -202,7 +203,7 @@ async def test_gemini_tool_with_chat_param_serialization() -> None: assert message.content and "Grafi" in message.content print(message.content) if isinstance(message.content, str): - assert len(message.content) < 150 + assert len(message.content) < 300 assert len(await event_store.get_events()) == 2 From f9cdf398f0f8dbd93facdcc09d04a875e37b7bc9 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Wed, 31 Dec 2025 15:11:53 +0000 Subject: [PATCH 2/7] add more tests and improve invoke finish signal --- grafi/topics/topic_base.py | 13 +++++++- grafi/workflows/impl/async_node_tracker.py | 22 +++----------- grafi/workflows/impl/async_output_queue.py | 8 ++++- tests/assistants/test_assistant_mock_llm.py | 9 +++--- tests/workflow/test_async_node_tracker.py | 21 ++++++++----- tests/workflow/test_async_output_queue.py | 33 +++++++++++---------- 6 files changed, 57 insertions(+), 49 deletions(-) diff --git a/grafi/topics/topic_base.py b/grafi/topics/topic_base.py index 837d657..88b3695 100644 --- a/grafi/topics/topic_base.py +++ b/grafi/topics/topic_base.py @@ -79,7 +79,18 @@ async def publish_data( """ Publish data to the topic if it meets the condition. """ - if self.condition(publish_event): + try: + condition_met = self.condition(publish_event) + except Exception as e: + # Condition evaluation failed (e.g., IndexError on empty data) + # Treat as condition not met + logger.debug( + f"[{self.name}] Condition evaluation failed: {e}. " + "Treating as condition not met." 
+ ) + condition_met = False + + if condition_met: event = publish_event.model_copy( update={ "name": self.name, diff --git a/grafi/workflows/impl/async_node_tracker.py b/grafi/workflows/impl/async_node_tracker.py index 44f7033..38ae3d4 100644 --- a/grafi/workflows/impl/async_node_tracker.py +++ b/grafi/workflows/impl/async_node_tracker.py @@ -44,10 +44,6 @@ def __init__(self) -> None: self._cond = asyncio.Condition() self._quiescence_event = asyncio.Event() - # Work tracking (prevents premature quiescence before any work) - self._total_committed: int = 0 - self._has_started: bool = False - # Force stop flag (for explicit workflow stop) self._force_stopped: bool = False @@ -58,8 +54,6 @@ def reset(self) -> None: self._uncommitted_messages = 0 self._cond = asyncio.Condition() self._quiescence_event = asyncio.Event() - self._total_committed = 0 - self._has_started = False self._force_stopped = False # ───────────────────────────────────────────────────────────────────────── @@ -69,7 +63,6 @@ def reset(self) -> None: async def enter(self, node_name: str) -> None: """Called when a node begins processing.""" async with self._cond: - self._has_started = True self._quiescence_event.clear() self._active.add(node_name) self._processing_count[node_name] += 1 @@ -94,7 +87,6 @@ async def on_messages_published(self, count: int = 1, source: str = "") -> None: if count <= 0: return async with self._cond: - self._has_started = True self._quiescence_event.clear() self._uncommitted_messages += count @@ -112,12 +104,10 @@ async def on_messages_committed(self, count: int = 1, source: str = "") -> None: return async with self._cond: self._uncommitted_messages = max(0, self._uncommitted_messages - count) - self._total_committed += count self._check_quiescence_unlocked() logger.debug( f"Tracker: {count} messages committed from {source} " - f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})" ) self._cond.notify_all() @@ -144,14 +134,9 @@ def _check_quiescence_unlocked(self) -> None: logger.debug( f"Tracker: checking quiescence - active={list(self._active)}, " f"uncommitted={self._uncommitted_messages}, " - f"has_started={self._has_started}, " - f"total_committed={self._total_committed}, " f"is_quiescent={is_quiescent}" ) if is_quiescent: - logger.info( - f"Tracker: quiescence detected (committed={self._total_committed})" - ) self._quiescence_event.set() def _is_quiescent_unlocked(self) -> bool: @@ -165,11 +150,13 @@ def _is_quiescent_unlocked(self) -> bool: - No messages waiting to be committed - At least some work was done """ + logger.debug( + f"Tracker: _is_quiescent_unlocked check - active={list(self._active)}, " + f"uncommitted={self._uncommitted_messages}, " + ) return ( not self._active and self._uncommitted_messages == 0 - and self._has_started - and self._total_committed > 0 ) async def is_quiescent(self) -> bool: @@ -263,6 +250,5 @@ async def get_metrics(self) -> Dict: return { "active_nodes": list(self._active), "uncommitted_messages": self._uncommitted_messages, - "total_committed": self._total_committed, "is_quiescent": self._is_quiescent_unlocked(), } diff --git a/grafi/workflows/impl/async_output_queue.py b/grafi/workflows/impl/async_output_queue.py index e3c72d2..5b8d434 100644 --- a/grafi/workflows/impl/async_output_queue.py +++ b/grafi/workflows/impl/async_output_queue.py @@ -54,6 +54,12 @@ async def _output_listener(self, topic: TopicBase) -> None: while not self._stopped: try: events = await topic.consume(self.consumer_name, timeout=0.1) + + if len(events) == 0: + # No 
events fetched within timeout, check if all node quiescence + if await self.tracker.should_terminate(): + break + for event in events: await self.queue.put(event) # Mark messages as committed when they reach the output queue @@ -70,7 +76,7 @@ async def _output_listener(self, topic: TopicBase) -> None: break except Exception as e: logger.error(f"Output listener error for {topic.name}: {e}") - await asyncio.sleep(0.1) + raise e def __aiter__(self) -> "AsyncOutputQueue": return self diff --git a/tests/assistants/test_assistant_mock_llm.py b/tests/assistants/test_assistant_mock_llm.py index 923ebac..84888cb 100644 --- a/tests/assistants/test_assistant_mock_llm.py +++ b/tests/assistants/test_assistant_mock_llm.py @@ -2274,9 +2274,8 @@ def search(self, query: str) -> str: data=[Message(role="user", content="What's the weather again?")], ) - async for event in assistant.invoke(input_data): - results.append(event) + secound_results = [] + async for event in assistant.invoke(input_data, is_sequential=True): + secound_results.append(event) - assert len(results) == 1 - assert "Based on search" in results[0].data[0].content - assert call_count["llm"] == 2 \ No newline at end of file + assert len(secound_results) == 0 diff --git a/tests/workflow/test_async_node_tracker.py b/tests/workflow/test_async_node_tracker.py index e348740..64cce37 100644 --- a/tests/workflow/test_async_node_tracker.py +++ b/tests/workflow/test_async_node_tracker.py @@ -13,9 +13,10 @@ def tracker(self): @pytest.mark.asyncio async def test_initial_state(self, tracker): - """Tracker starts idle with no work recorded.""" + """Tracker starts idle with no active nodes and no uncommitted messages.""" assert await tracker.is_idle() - assert await tracker.is_quiescent() is False + # Fresh tracker is quiescent (no active nodes, no uncommitted messages) + assert await tracker.is_quiescent() is True assert await tracker.get_activity_count() == 0 assert (await tracker.get_metrics())["uncommitted_messages"] == 0 @@ -31,8 +32,8 @@ async def test_enter_and_leave_updates_activity(self, tracker): await tracker.leave("node1") assert await tracker.is_idle() - # No commits yet so quiescence is still False - assert await tracker.is_quiescent() is False + # After leaving with no uncommitted messages, tracker is quiescent + assert await tracker.is_quiescent() is True assert await tracker.get_activity_count() == 1 @pytest.mark.asyncio @@ -67,7 +68,9 @@ async def finish_work(): @pytest.mark.asyncio async def test_wait_for_quiescence_timeout(self, tracker): - """wait_for_quiescence returns False on timeout.""" + """wait_for_quiescence returns False on timeout when not quiescent.""" + # Make tracker not quiescent by adding uncommitted messages + await tracker.on_messages_published(1) result = await tracker.wait_for_quiescence(timeout=0.01) assert result is False assert await tracker.is_quiescent() is False @@ -82,9 +85,10 @@ async def test_reset(self, tracker): tracker.reset() assert await tracker.is_idle() - assert await tracker.is_quiescent() is False + # After reset, tracker is quiescent (no active nodes, no uncommitted messages) + assert await tracker.is_quiescent() is True assert await tracker.get_activity_count() == 0 - assert (await tracker.get_metrics())["total_committed"] == 0 + assert (await tracker.get_metrics())["uncommitted_messages"] == 0 @pytest.mark.asyncio async def test_force_stop(self, tracker): @@ -133,4 +137,5 @@ async def test_reset_clears_force_stop(self, tracker): tracker.reset() assert tracker._force_stopped is False - assert await 
tracker.should_terminate() is False + # After reset with no uncommitted messages, should_terminate is True (quiescent) + assert await tracker.should_terminate() is True diff --git a/tests/workflow/test_async_output_queue.py b/tests/workflow/test_async_output_queue.py index e28c963..974359d 100644 --- a/tests/workflow/test_async_output_queue.py +++ b/tests/workflow/test_async_output_queue.py @@ -277,27 +277,25 @@ async def test_anext_waits_for_activity_count_stabilization(self): tracker=tracker, ) - # Simulate: node enters, adds item to queue, leaves - # Then another node should enter before we terminate + # Make tracker not quiescent first by publishing a message + await tracker.on_messages_published(1) async def simulate_node_activity(): """Simulate node activity that should prevent premature termination.""" - # First node processes - simulate full message lifecycle - await tracker.on_messages_published(1) + # First node processes await tracker.enter("node_1") await output_queue.queue.put(Mock(name="event_1")) await tracker.leave("node_1") - await tracker.on_messages_committed(1) - # Yield control - simulates realistic timing where next node - # starts within the same event loop cycle + # Yield control - simulates realistic timing await asyncio.sleep(0) - # Second node picks up and processes - simulate full message lifecycle - await tracker.on_messages_published(1) + # Second node processes await tracker.enter("node_2") await output_queue.queue.put(Mock(name="event_2")) await tracker.leave("node_2") + + # Finally commit the initial message to allow quiescence await tracker.on_messages_committed(1) # Start the activity simulation @@ -328,12 +326,15 @@ async def test_anext_terminates_when_truly_idle(self): tracker=tracker, ) - # Single node processes and finishes - simulate full message lifecycle + # Make tracker not quiescent first by publishing a message + await tracker.on_messages_published(1) + + # Single node processes and finishes async def simulate_single_node(): - await tracker.on_messages_published(1) await tracker.enter("node_1") await output_queue.queue.put(Mock(name="event_1")) await tracker.leave("node_1") + # Commit the message to allow quiescence await tracker.on_messages_committed(1) activity_task = asyncio.create_task(simulate_single_node()) @@ -366,6 +367,9 @@ async def test_activity_count_prevents_premature_exit(self): tracker=tracker, ) + # Make tracker not quiescent first by publishing messages + await tracker.on_messages_published(2) + events_received = [] iteration_complete = asyncio.Event() @@ -375,8 +379,7 @@ async def consumer(): iteration_complete.set() async def producer(): - # Node A processes - simulate full message lifecycle - await tracker.on_messages_published(1) + # Node A processes await tracker.enter("node_a") await output_queue.queue.put(Mock(name="event_a")) await tracker.leave("node_a") @@ -385,9 +388,7 @@ async def producer(): # Critical timing window - yield to let consumer check idle state await asyncio.sleep(0) - # Node B starts before consumer terminates (if fix works) - # simulate full message lifecycle - await tracker.on_messages_published(1) + # Node B processes await tracker.enter("node_b") await output_queue.queue.put(Mock(name="event_b")) await tracker.leave("node_b") From 4ae74af16714a82406ffceda77e4773d6fc13164 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Wed, 31 Dec 2025 15:47:27 +0000 Subject: [PATCH 3/7] improve codebase --- grafi/common/containers/container.py | 54 +- grafi/common/models/async_result.py | 14 +- 
grafi/tools/llms/impl/claude_tool.py | 42 +- grafi/tools/llms/impl/openai_tool.py | 49 +- .../queue_impl/in_mem_topic_event_queue.py | 13 +- grafi/workflows/impl/async_node_tracker.py | 41 +- tests/assistants/test_assistant_mock_llm.py | 675 ++++++++++-------- tests/tools/llms/test_claude_tool.py | 54 +- tests/tools/llms/test_openai_tool.py | 30 +- 9 files changed, 577 insertions(+), 395 deletions(-) diff --git a/grafi/common/containers/container.py b/grafi/common/containers/container.py index a5695eb..3c703c7 100644 --- a/grafi/common/containers/container.py +++ b/grafi/common/containers/container.py @@ -29,38 +29,52 @@ def __init__(self) -> None: # Per-instance attributes: self._event_store: Optional[EventStore] = None self._tracer: Optional[Tracer] = None + # Lock for thread-safe lazy initialization of properties + self._init_lock: threading.Lock = threading.Lock() def register_event_store(self, event_store: EventStore) -> None: """Override the default EventStore implementation.""" - if isinstance(event_store, EventStoreInMemory): - logger.warning( - "Using EventStoreInMemory. This is ONLY suitable for local testing but not for production." - ) - self._event_store = event_store + with self._init_lock: + if isinstance(event_store, EventStoreInMemory): + logger.warning( + "Using EventStoreInMemory. This is ONLY suitable for local testing but not for production." + ) + self._event_store = event_store def register_tracer(self, tracer: Tracer) -> None: """Override the default Tracer implementation.""" - self._tracer = tracer + with self._init_lock: + self._tracer = tracer @property def event_store(self) -> EventStore: - if self._event_store is None: - logger.warning( - "Using EventStoreInMemory. This is ONLY suitable for local testing but not for production." - ) - self._event_store = EventStoreInMemory() - return self._event_store + # Fast path: already initialized + if self._event_store is not None: + return self._event_store + # Slow path: initialize with lock (double-checked locking) + with self._init_lock: + if self._event_store is None: + logger.warning( + "Using EventStoreInMemory. This is ONLY suitable for local testing but not for production." 
+ ) + self._event_store = EventStoreInMemory() + return self._event_store @property def tracer(self) -> Tracer: - if self._tracer is None: - self._tracer = setup_tracing( - tracing_options=TracingOptions.AUTO, - collector_endpoint="localhost", - collector_port=4317, - project_name="grafi-trace", - ) - return self._tracer + # Fast path: already initialized + if self._tracer is not None: + return self._tracer + # Slow path: initialize with lock (double-checked locking) + with self._init_lock: + if self._tracer is None: + self._tracer = setup_tracing( + tracing_options=TracingOptions.AUTO, + collector_endpoint="localhost", + collector_port=4317, + project_name="grafi-trace", + ) + return self._tracer container: Container = Container() diff --git a/grafi/common/models/async_result.py b/grafi/common/models/async_result.py index f9bfdfa..d6b1370 100644 --- a/grafi/common/models/async_result.py +++ b/grafi/common/models/async_result.py @@ -43,11 +43,12 @@ def __init__(self, source: AsyncGenerator[ConsumeFromTopicEvent, None]): self._done = asyncio.Event() self._started = False self._exc: Optional[BaseException] = None + self._producer_task: Optional[asyncio.Task] = None def _ensure_started(self) -> None: if not self._started: loop = asyncio.get_running_loop() - loop.create_task(self._producer()) + self._producer_task = loop.create_task(self._producer()) self._started = True async def _producer(self) -> None: @@ -94,7 +95,16 @@ async def to_list(self) -> list[ConsumeFromTopicEvent]: return result if isinstance(result, list) else [result] async def aclose(self) -> None: - """Attempt to close the underlying async generator (if any).""" + """Cancel producer task and close the underlying async generator.""" + # Cancel the producer task if it's running + if self._producer_task is not None and not self._producer_task.done(): + self._producer_task.cancel() + try: + await self._producer_task + except asyncio.CancelledError: + pass + + # Close the underlying source generator try: await self._source.aclose() except Exception: diff --git a/grafi/tools/llms/impl/claude_tool.py b/grafi/tools/llms/impl/claude_tool.py index fa28747..28d9ad0 100644 --- a/grafi/tools/llms/impl/claude_tool.py +++ b/grafi/tools/llms/impl/claude_tool.py @@ -102,29 +102,29 @@ async def invoke( input_data: Messages, ) -> MsgsAGen: messages, tools = self.prepare_api_input(input_data) - client = AsyncAnthropic(api_key=self.api_key) try: - if self.is_streaming: - async with client.messages.stream( - max_tokens=self.max_tokens, - model=self.model, - messages=messages, - tools=tools, - **self.chat_params, - ) as stream: - async for event in stream: - if event.type == "text": - yield self.to_stream_messages(event.text) - else: - resp: AnthropicMessage = await client.messages.create( - max_tokens=self.max_tokens, - model=self.model, - messages=messages, - tools=tools, - **self.chat_params, - ) - yield self.to_messages(resp) + async with AsyncAnthropic(api_key=self.api_key) as client: + if self.is_streaming: + async with client.messages.stream( + max_tokens=self.max_tokens, + model=self.model, + messages=messages, + tools=tools, + **self.chat_params, + ) as stream: + async for event in stream: + if event.type == "text": + yield self.to_stream_messages(event.text) + else: + resp: AnthropicMessage = await client.messages.create( + max_tokens=self.max_tokens, + model=self.model, + messages=messages, + tools=tools, + **self.chat_params, + ) + yield self.to_messages(resp) except asyncio.CancelledError: raise diff --git 
a/grafi/tools/llms/impl/openai_tool.py b/grafi/tools/llms/impl/openai_tool.py index ba9fa40..7be8aa3 100644 --- a/grafi/tools/llms/impl/openai_tool.py +++ b/grafi/tools/llms/impl/openai_tool.py @@ -107,31 +107,30 @@ async def invoke( ) -> MsgsAGen: api_messages, api_tools = self.prepare_api_input(input_data) try: - client = AsyncClient(api_key=self.api_key) - - if self.is_streaming: - async for chunk in await client.chat.completions.create( - model=self.model, - messages=api_messages, - tools=api_tools, - stream=True, - **self.chat_params, - ): - yield self.to_stream_messages(chunk) - else: - req_func = ( - client.chat.completions.create - if not self.structured_output - else client.beta.chat.completions.parse - ) - response: ChatCompletion = await req_func( - model=self.model, - messages=api_messages, - tools=api_tools, - **self.chat_params, - ) - - yield self.to_messages(response) + async with AsyncClient(api_key=self.api_key) as client: + if self.is_streaming: + async for chunk in await client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=api_tools, + stream=True, + **self.chat_params, + ): + yield self.to_stream_messages(chunk) + else: + req_func = ( + client.chat.completions.create + if not self.structured_output + else client.beta.chat.completions.parse + ) + response: ChatCompletion = await req_func( + model=self.model, + messages=api_messages, + tools=api_tools, + **self.chat_params, + ) + + yield self.to_messages(response) except asyncio.CancelledError: raise # let caller handle except OpenAIError as exc: diff --git a/grafi/topics/queue_impl/in_mem_topic_event_queue.py b/grafi/topics/queue_impl/in_mem_topic_event_queue.py index 3af9b15..25e148e 100644 --- a/grafi/topics/queue_impl/in_mem_topic_event_queue.py +++ b/grafi/topics/queue_impl/in_mem_topic_event_queue.py @@ -66,7 +66,7 @@ async def fetch( async with self._cond: # If timeout is 0 or None and no data, return immediately - while not await self.can_consume(consumer_id): + while not self._can_consume_unlocked(consumer_id): try: logger.debug( f"Consumer {consumer_id} waiting for new messages with timeout={timeout}" @@ -109,8 +109,17 @@ async def reset(self) -> None: self._consumed = defaultdict(int) self._committed = defaultdict(lambda: -1) + def _can_consume_unlocked(self, consumer_id: str) -> bool: + """ + Internal check without lock. MUST be called with self._cond held. + """ + return self._consumed[consumer_id] < len(self._records) + async def can_consume(self, consumer_id: str) -> bool: """ Check if there are events available for consumption by a consumer asynchronously. + + This method acquires the lock to ensure consistent reads of shared state. """ - return self._consumed[consumer_id] < len(self._records) + async with self._cond: + return self._can_consume_unlocked(consumer_id) diff --git a/grafi/workflows/impl/async_node_tracker.py b/grafi/workflows/impl/async_node_tracker.py index 38ae3d4..abf821d 100644 --- a/grafi/workflows/impl/async_node_tracker.py +++ b/grafi/workflows/impl/async_node_tracker.py @@ -48,7 +48,13 @@ def __init__(self) -> None: self._force_stopped: bool = False def reset(self) -> None: - """Reset for a new workflow run.""" + """ + Reset for a new workflow run. + + Note: This is a sync reset that replaces primitives. It should only be + called when no coroutines are waiting on the old primitives (e.g., at + the start of a new workflow invocation before any tasks are spawned). 
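+
+        A minimal sketch of the assumed call pattern (illustrative only):
+
+            tracker.reset()              # safe only before node tasks start
+            ...
+            await tracker.reset_async()  # safe even while waiters exist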
+ """ self._active.clear() self._processing_count.clear() self._uncommitted_messages = 0 @@ -56,6 +62,30 @@ def reset(self) -> None: self._quiescence_event = asyncio.Event() self._force_stopped = False + async def reset_async(self) -> None: + """ + Reset for a new workflow run (async version). + + This version properly wakes any waiting coroutines before resetting, + preventing deadlocks if called while the workflow is still running. + """ + async with self._cond: + # Wake all waiters so they can exit gracefully + self._force_stopped = True + self._quiescence_event.set() + self._cond.notify_all() + + # Give waiters a chance to wake up and exit + await asyncio.sleep(0) + + # Now safe to reset state + async with self._cond: + self._active.clear() + self._processing_count.clear() + self._uncommitted_messages = 0 + self._force_stopped = False + self._quiescence_event.clear() + # ───────────────────────────────────────────────────────────────────────── # Node Lifecycle (called from _invoke_node) # ───────────────────────────────────────────────────────────────────────── @@ -106,9 +136,7 @@ async def on_messages_committed(self, count: int = 1, source: str = "") -> None: self._uncommitted_messages = max(0, self._uncommitted_messages - count) self._check_quiescence_unlocked() - logger.debug( - f"Tracker: {count} messages committed from {source} " - ) + logger.debug(f"Tracker: {count} messages committed from {source} ") self._cond.notify_all() # Aliases for clarity @@ -154,10 +182,7 @@ def _is_quiescent_unlocked(self) -> bool: f"Tracker: _is_quiescent_unlocked check - active={list(self._active)}, " f"uncommitted={self._uncommitted_messages}, " ) - return ( - not self._active - and self._uncommitted_messages == 0 - ) + return not self._active and self._uncommitted_messages == 0 async def is_quiescent(self) -> bool: """ diff --git a/tests/assistants/test_assistant_mock_llm.py b/tests/assistants/test_assistant_mock_llm.py index 84888cb..7cca349 100644 --- a/tests/assistants/test_assistant_mock_llm.py +++ b/tests/assistants/test_assistant_mock_llm.py @@ -86,7 +86,6 @@ async def invoke( invoke_context=invoke_context, cause=e, ) from e - def to_dict(self) -> dict[str, Any]: """ @@ -130,8 +129,9 @@ async def from_dict(cls, data: dict[str, Any]) -> "LLMMockTool": ) - -def make_tool_call(call_id: str, name: str, arguments: str) -> ChatCompletionMessageToolCall: +def make_tool_call( + call_id: str, name: str, arguments: str +) -> ChatCompletionMessageToolCall: """Helper to create tool calls.""" return ChatCompletionMessageToolCall( id=call_id, @@ -167,10 +167,13 @@ async def test_react_agent_no_function_call(self, invoke_context): Flow: Input -> LLM (no function call) -> Output """ + # Mock LLM that always responds directly without function calls def mock_llm(messages: List[Message]) -> List[Message]: user_content = messages[-1].content if messages else "" - return [Message(role="assistant", content=f"Direct response to: {user_content}")] + return [ + Message(role="assistant", content=f"Direct response to: {user_content}") + ] # Create topics agent_input = InputTopic(name="agent_input") @@ -178,8 +181,7 @@ def mock_llm(messages: List[Message]) -> List[Message]: name="agent_output", # Only output when there's content and no tool calls condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -236,24 +238,30 @@ def mock_llm(messages: List[Message]) -> 
List[Message]: if call_count["llm"] == 1: # First call: make a function call - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call( - "call_1", - "search", - '{"query": "weather today"}', - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "call_1", + "search", + '{"query": "weather today"}', + ) + ], + ) + ] else: # Second call: respond with the function result - last_msg = messages[-1] if messages else Message(role="user", content="") - return [Message( - role="assistant", - content=f"Based on search: {last_msg.content}", - )] + last_msg = ( + messages[-1] if messages else Message(role="user", content="") + ) + return [ + Message( + role="assistant", + content=f"Based on search: {last_msg.content}", + ) + ] def search(self, query: str) -> str: """Mock search function.""" @@ -264,8 +272,7 @@ def search(self, query: str) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -296,7 +303,9 @@ def search(self, query: str) -> str: Node.builder() .name("SearchNode") .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) - .tool(FunctionCallTool.builder().name("SearchTool").function(search).build()) + .tool( + FunctionCallTool.builder().name("SearchTool").function(search).build() + ) .publish_to(function_result_topic) .build() ) @@ -340,29 +349,36 @@ def mock_llm(messages: List[Message]) -> List[Message]: call_count["llm"] += 1 if call_count["llm"] == 1: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("call_1", "get_user", '{"id": "123"}') - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("call_1", "get_user", '{"id": "123"}') + ], + ) + ] elif call_count["llm"] == 2: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call( - "call_2", - "get_orders", - '{"user_id": "123"}', - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "call_2", + "get_orders", + '{"user_id": "123"}', + ) + ], + ) + ] else: - return [Message( - role="assistant", content="User John has 3 orders totaling $150." 
- )] + return [ + Message( + role="assistant", + content="User John has 3 orders totaling $150.", + ) + ] def get_user(self, id: str) -> str: """Mock get_user function.""" @@ -377,8 +393,7 @@ def get_orders(self, user_id: str) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -472,6 +487,7 @@ async def test_hitl_workflow_no_human_input_needed(self, invoke_context): Flow: Input -> LLM (direct response) -> Output """ + def mock_llm(messages: List[Message]) -> List[Message]: return [Message(role="assistant", content="I can answer this directly!")] @@ -479,8 +495,7 @@ def mock_llm(messages: List[Message]) -> List[Message]: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) hitl_call_topic = Topic( @@ -537,30 +552,36 @@ def mock_llm(messages: List[Message]) -> List[Message]: if call_count["llm"] == 1: # First call: request human approval - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call( - "approval_1", - "request_approval", - '{"action": "delete_account", "reason": "user requested"}', - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "approval_1", + "request_approval", + '{"action": "delete_account", "reason": "user requested"}', + ) + ], + ) + ] else: # Second call: process approval result last_content = messages[-1].content if messages else "" if "approved" in last_content.lower(): - return [Message( - role="assistant", - content="Account deletion has been approved and completed.", - )] + return [ + Message( + role="assistant", + content="Account deletion has been approved and completed.", + ) + ] else: - return [Message( - role="assistant", - content="Account deletion was rejected.", - )] + return [ + Message( + role="assistant", + content="Account deletion was rejected.", + ) + ] def request_approval(self, action: str, reason: str) -> str: """Mock HITL request that simulates human approval.""" @@ -573,8 +594,7 @@ def request_approval(self, action: str, reason: str) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) hitl_call_topic = Topic( @@ -657,24 +677,28 @@ def mock_llm(messages: List[Message]) -> List[Message]: call_count["llm"] += 1 if call_count["llm"] == 1: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call( - "approval_1", - "request_approval", - '{"action": "transfer_funds", "amount": "$10000"}', - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "approval_1", + "request_approval", + '{"action": "transfer_funds", "amount": "$10000"}', + ) + ], + ) + ] else: last_content = messages[-1].content if messages else "" if "rejected" in last_content.lower(): - return [Message( - role="assistant", - content="The fund transfer was not approved. No action taken.", - )] + return [ + Message( + role="assistant", + content="The fund transfer was not approved. 
No action taken.", + ) + ] return [Message(role="assistant", content="Transfer completed.")] def request_approval(self, action: str, amount: str) -> str: @@ -685,8 +709,7 @@ def request_approval(self, action: str, amount: str) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) hitl_call_topic = Topic( @@ -762,36 +785,42 @@ def mock_llm(messages: List[Message]) -> List[Message]: if call_count["llm"] == 1: # First approval: manager - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call( - "approval_1", - "request_manager_approval", - '{"action": "large_purchase", "amount": "$5000"}', - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "approval_1", + "request_manager_approval", + '{"action": "large_purchase", "amount": "$5000"}', + ) + ], + ) + ] elif call_count["llm"] == 2: # Second approval: finance - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call( - "approval_2", - "request_finance_approval", - '{"action": "large_purchase", "amount": "$5000"}', - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "approval_2", + "request_finance_approval", + '{"action": "large_purchase", "amount": "$5000"}', + ) + ], + ) + ] else: # Final response - return [Message( - role="assistant", - content="Purchase approved by manager and finance. Order placed!", - )] + return [ + Message( + role="assistant", + content="Purchase approved by manager and finance. Order placed!", + ) + ] approval_count = {"count": 0} @@ -807,8 +836,7 @@ def request_finance_approval(self, action: str, amount: str) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) approval_call_topic = Topic( @@ -861,7 +889,9 @@ def request_finance_approval(self, action: str, amount: str) -> str: input_data = PublishToTopicEvent( invoke_context=invoke_context, - data=[Message(role="user", content="I need to purchase equipment for $5000")], + data=[ + Message(role="user", content="I need to purchase equipment for $5000") + ], ) results = [] @@ -899,29 +929,35 @@ async def test_conditional_branching_workflow(self, invoke_context): - If question about math -> Math function -> Response LLM - Otherwise -> Direct response """ + def mock_router(messages: List[Message]) -> List[Message]: """Route based on input content.""" content = messages[-1].content.lower() if messages else "" if "weather" in content: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("w1", "weather", '{"location": "NYC"}') - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("w1", "weather", '{"location": "NYC"}') + ], + ) + ] elif "math" in content or "calculate" in content: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("m1", "math", '{"expr": "2+2"}') - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[make_tool_call("m1", "math", '{"expr": "2+2"}')], + ) + ] else: - return [Message( - role="assistant", content="I can help with weather or math questions!" 
- )] + return [ + Message( + role="assistant", + content="I can help with weather or math questions!", + ) + ] def weather(self, location: str) -> str: return "Weather in NYC: Sunny, 75°F" @@ -932,21 +968,26 @@ def math(self, expr: str) -> str: def mock_response(messages: List[Message]) -> List[Message]: """Generate final response from function result.""" last_content = messages[-1].content if messages else "" - return [Message(role="assistant", content=f"Here's what I found: {last_content}")] + return [ + Message( + role="assistant", content=f"Here's what I found: {last_content}" + ) + ] agent_input = InputTopic(name="agent_input") agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) weather_topic = Topic( name="weather_call", condition=lambda event: ( event.data[-1].tool_calls is not None - and any(tc.function.name == "weather" for tc in event.data[-1].tool_calls) + and any( + tc.function.name == "weather" for tc in event.data[-1].tool_calls + ) ), ) math_topic = Topic( @@ -990,7 +1031,9 @@ def mock_response(messages: List[Message]) -> List[Message]: response_node = ( Node.builder() .name("ResponseNode") - .subscribe(SubscriptionBuilder().subscribed_to(function_result_topic).build()) + .subscribe( + SubscriptionBuilder().subscribed_to(function_result_topic).build() + ) .tool(LLMMockTool(function=mock_response)) .publish_to(agent_output) .build() @@ -1034,20 +1077,24 @@ def mock_llm_parallel(messages: List[Message]) -> List[Message]: # Check if we have function results has_results = any(msg.role == "tool" for msg in messages) if has_results: - return [Message( - role="assistant", - content="Combined weather and news: Great day, no major events!", - )] + return [ + Message( + role="assistant", + content="Combined weather and news: Great day, no major events!", + ) + ] else: # Request both functions at once - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("w1", "weather", "{}"), - make_tool_call("n1", "news", "{}"), - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("w1", "weather", "{}"), + make_tool_call("n1", "news", "{}"), + ], + ) + ] def weather(self) -> str: """Handle weather function call.""" @@ -1061,8 +1108,7 @@ def news(self) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -1137,21 +1183,23 @@ async def test_error_handling_in_function_call(self, invoke_context): def mock_llm(messages: List[Message]) -> List[Message]: call_count["llm"] += 1 if call_count["llm"] == 1: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("f1", "failing_func", "{}") - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[make_tool_call("f1", "failing_func", "{}")], + ) + ] else: # Handle error from function last_content = messages[-1].content if messages else "" if "error" in last_content.lower(): - return [Message( - role="assistant", - content="I encountered an error. Let me try a different approach.", - )] + return [ + Message( + role="assistant", + content="I encountered an error. 
Let me try a different approach.", + ) + ] return [Message(role="assistant", content="Success!")] def failing_func(self) -> str: @@ -1162,8 +1210,7 @@ def failing_func(self) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -1237,20 +1284,22 @@ def mock_llm_with_context(messages: List[Message]) -> List[Message]: accumulated_context.append([m.content for m in messages if m.content]) if len(accumulated_context) == 1: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("c1", "context_func", "{}") - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[make_tool_call("c1", "context_func", "{}")], + ) + ] else: # Return summary of all seen content all_content = [c for ctx in accumulated_context for c in ctx] - return [Message( - role="assistant", - content=f"Processed {len(all_content)} messages", - )] + return [ + Message( + role="assistant", + content=f"Processed {len(all_content)} messages", + ) + ] def context_func(self) -> str: return "Context function executed" @@ -1259,8 +1308,7 @@ def context_func(self) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -1351,14 +1399,15 @@ async def test_exception_in_function_call_tool(self, invoke_context): """ Test that exceptions in FunctionCallTool are properly propagated. """ + def mock_llm(messages: List[Message]) -> List[Message]: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("err1", "raise_error", '{}') - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[make_tool_call("err1", "raise_error", "{}")], + ) + ] def raise_error(self) -> str: raise ValueError("Intentional test error in function call") @@ -1367,8 +1416,7 @@ def raise_error(self) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -1417,17 +1465,21 @@ def raise_error(self) -> str: ) from grafi.common.exceptions import NodeExecutionError + with pytest.raises(NodeExecutionError) as exc_info: async for _ in assistant.invoke(input_data): pass - assert "ErrorNode" in str(exc_info.value) or "Intentional test error" in str(exc_info.value) + assert "ErrorNode" in str(exc_info.value) or "Intentional test error" in str( + exc_info.value + ) @pytest.mark.asyncio async def test_exception_in_llm_mock_tool(self, invoke_context): """ Test that exceptions in LLMMockTool are properly propagated. """ + def failing_llm(messages: List[Message]) -> List[Message]: raise RuntimeError("LLM processing failed") @@ -1462,6 +1514,7 @@ def failing_llm(messages: List[Message]) -> List[Message]: ) from grafi.common.exceptions import NodeExecutionError + with pytest.raises(NodeExecutionError) as exc_info: async for _ in assistant.invoke(input_data): pass @@ -1473,6 +1526,7 @@ async def test_llm_returns_empty_content(self, invoke_context): """ Test handling when LLM returns a message with empty content but no tool calls. 
""" + def empty_content_llm(messages: List[Message]) -> List[Message]: return [Message(role="assistant", content="")] @@ -1518,6 +1572,7 @@ async def test_llm_returns_single_message_not_list(self, invoke_context): """ Test that LLMMockTool properly wraps single Message in a list. """ + def single_message_llm(messages: List[Message]) -> Message: # Return single Message, not list return Message(role="assistant", content="Single message response") @@ -1564,18 +1619,23 @@ async def test_function_call_with_invalid_json_arguments(self, invoke_context): """ Test handling of tool calls with malformed JSON arguments. """ + def mock_llm(messages: List[Message]) -> List[Message]: - return [Message( - role="assistant", - content=None, - tool_calls=[ - ChatCompletionMessageToolCall( - id="bad_json", - type="function", - function=Function(name="some_func", arguments="not valid json{"), - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionMessageToolCall( + id="bad_json", + type="function", + function=Function( + name="some_func", arguments="not valid json{" + ), + ) + ], + ) + ] def some_func(self) -> str: return "Should not reach here" @@ -1584,8 +1644,7 @@ def some_func(self) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -1608,10 +1667,7 @@ def some_func(self) -> str: .name("FuncNode") .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) .tool( - FunctionCallTool.builder() - .name("SomeFunc") - .function(some_func) - .build() + FunctionCallTool.builder().name("SomeFunc").function(some_func).build() ) .publish_to(agent_output) .build() @@ -1634,6 +1690,7 @@ def some_func(self) -> str: ) from grafi.common.exceptions import NodeExecutionError + with pytest.raises(NodeExecutionError): async for _ in assistant.invoke(input_data): pass @@ -1652,19 +1709,22 @@ def mock_llm(messages: List[Message]) -> List[Message]: if call_count["llm"] == 1: # First call: make a function call - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("missing", "nonexistent_function", '{}') - ], - ) - ] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call("missing", "nonexistent_function", "{}") + ], + ) + ] else: - return [Message( - role="assistant", - content="Function not found.", - )] + return [ + Message( + role="assistant", + content="Function not found.", + ) + ] def existing_func(self) -> str: return "This is an existing function" @@ -1673,8 +1733,7 @@ def existing_func(self) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -1752,13 +1811,13 @@ async def test_workflow_stops_on_node_exception(self, invoke_context): def mock_llm(messages: List[Message]) -> List[Message]: call_count["count"] += 1 if call_count["count"] == 1: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("fail", "fail_func", '{}') - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[make_tool_call("fail", "fail_func", "{}")], + ) + ] # Should not reach here if workflow stops on error return 
[Message(role="assistant", content="Should not see this")] @@ -1769,8 +1828,7 @@ def fail_func(self) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -1800,10 +1858,7 @@ def fail_func(self) -> str: .name("FailNode") .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) .tool( - FunctionCallTool.builder() - .name("FailTool") - .function(fail_func) - .build() + FunctionCallTool.builder().name("FailTool").function(fail_func).build() ) .publish_to(function_result_topic) .build() @@ -1840,6 +1895,7 @@ async def test_llm_mock_tool_serialization(self, invoke_context): """ Test LLMMockTool to_dict and from_dict methods. """ + def sample_llm(messages: List[Message]) -> List[Message]: return [Message(role="assistant", content="Serialization test")] @@ -1862,21 +1918,26 @@ async def test_multiple_tool_calls_in_single_message(self, invoke_context): """ Test handling multiple tool calls in a single LLM response. """ + def mock_llm(messages: List[Message]) -> List[Message]: has_results = any(msg.role == "tool" for msg in messages) if has_results: - return [Message( + return [ + Message( + role="assistant", + content="Got results from both functions", + ) + ] + return [ + Message( role="assistant", - content="Got results from both functions", - )] - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("t1", "func_a", '{}'), - make_tool_call("t2", "func_b", '{}'), - ], - )] + content=None, + tool_calls=[ + make_tool_call("t1", "func_a", "{}"), + make_tool_call("t2", "func_b", "{}"), + ], + ) + ] def func_a(self) -> str: return "Result A" @@ -1888,8 +1949,7 @@ def func_b(self) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -1962,40 +2022,44 @@ async def test_function_returns_complex_json(self, invoke_context): def mock_llm(messages: List[Message]) -> List[Message]: call_count["llm"] += 1 if call_count["llm"] == 1: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call("json1", "get_complex_data", '{}') - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[make_tool_call("json1", "get_complex_data", "{}")], + ) + ] # Check if we received the complex data last_content = messages[-1].content if messages else "" - return [Message( - role="assistant", - content=f"Received complex data: {last_content[:50]}...", - )] + return [ + Message( + role="assistant", + content=f"Received complex data: {last_content[:50]}...", + ) + ] def get_complex_data(self) -> str: import json - return json.dumps({ - "users": [ - {"id": 1, "name": "Alice", "roles": ["admin", "user"]}, - {"id": 2, "name": "Bob", "roles": ["user"]}, - ], - "metadata": { - "total": 2, - "page": 1, - "nested": {"deep": {"value": True}} + + return json.dumps( + { + "users": [ + {"id": 1, "name": "Alice", "roles": ["admin", "user"]}, + {"id": 2, "name": "Bob", "roles": ["user"]}, + ], + "metadata": { + "total": 2, + "page": 1, + "nested": {"deep": {"value": True}}, + }, } - }) + ) agent_input = InputTopic(name="agent_input") agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - 
event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -2069,17 +2133,19 @@ async def test_function_with_special_characters_in_args(self, invoke_context): def mock_llm(messages: List[Message]) -> List[Message]: call_count["llm"] += 1 if call_count["llm"] == 1: - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call( - "special", - "process_text", - '{"text": "Hello\\nWorld\\twith\\ttabs", "query": "test \\"quoted\\""}' - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "special", + "process_text", + '{"text": "Hello\\nWorld\\twith\\ttabs", "query": "test \\"quoted\\""}', + ) + ], + ) + ] return [Message(role="assistant", content="Processed special chars")] def process_text(self, text: str, query: str) -> str: @@ -2091,8 +2157,7 @@ def process_text(self, text: str, query: str) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -2157,7 +2222,6 @@ def process_text(self, text: str, query: str) -> str: assert "\t" in received_args.get("text", "") assert '"' in received_args.get("query", "") - @pytest.mark.asyncio async def test_react_agent_single_function_call_twice(self): """ @@ -2179,24 +2243,30 @@ def mock_llm(messages: List[Message]) -> List[Message]: if call_count["llm"] == 1: # First call: make a function call - return [Message( - role="assistant", - content=None, - tool_calls=[ - make_tool_call( - "call_1", - "search", - '{"query": "weather today"}', - ) - ], - )] + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "call_1", + "search", + '{"query": "weather today"}', + ) + ], + ) + ] else: # Second call: respond with the function result - last_msg = messages[-1] if messages else Message(role="user", content="") - return [Message( - role="assistant", - content=f"Based on search: {last_msg.content}", - )] + last_msg = ( + messages[-1] if messages else Message(role="user", content="") + ) + return [ + Message( + role="assistant", + content=f"Based on search: {last_msg.content}", + ) + ] def search(self, query: str) -> str: """Mock search function.""" @@ -2207,8 +2277,7 @@ def search(self, query: str) -> str: agent_output = OutputTopic( name="agent_output", condition=lambda event: ( - event.data[-1].content is not None - and event.data[-1].tool_calls is None + event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) function_call_topic = Topic( @@ -2239,7 +2308,9 @@ def search(self, query: str) -> str: Node.builder() .name("SearchNode") .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) - .tool(FunctionCallTool.builder().name("SearchTool").function(search).build()) + .tool( + FunctionCallTool.builder().name("SearchTool").function(search).build() + ) .publish_to(function_result_topic) .build() ) diff --git a/tests/tools/llms/test_claude_tool.py b/tests/tools/llms/test_claude_tool.py index 8823c5e..b895de5 100644 --- a/tests/tools/llms/test_claude_tool.py +++ b/tests/tools/llms/test_claude_tool.py @@ -64,10 +64,20 @@ async def test_invoke_simple_response(monkeypatch, claude_instance, invoke_conte mock_client = MagicMock() mock_client.messages.create = 
AsyncMock(return_value=fake_response) - # patch AsyncAnthropic constructor - monkeypatch.setattr( - cl_module, "AsyncAnthropic", MagicMock(return_value=mock_client) - ) + # Create async context manager mock for AsyncAnthropic + async def mock_aenter(self): + return mock_client + + async def mock_aexit(self, *args): + pass + + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = mock_aenter + mock_context_manager.__aexit__ = mock_aexit + + # patch AsyncAnthropic constructor to return context manager + mock_async_anthropic_cls = MagicMock(return_value=mock_context_manager) + monkeypatch.setattr(cl_module, "AsyncAnthropic", mock_async_anthropic_cls) input_data = [Message(role="user", content="Say hello")] result = [] @@ -79,7 +89,7 @@ async def test_invoke_simple_response(monkeypatch, claude_instance, invoke_conte assert result[0].content == "Hello, world!" # verify constructor args - cl_module.AsyncAnthropic.assert_called_once_with(api_key="test_api_key") + mock_async_anthropic_cls.assert_called_once_with(api_key="test_api_key") # verify create() called with right kwargs kwargs = mock_client.messages.create.call_args[1] @@ -108,9 +118,20 @@ async def test_invoke_function_call(monkeypatch, claude_instance, invoke_context mock_client = MagicMock() mock_client.messages.create = AsyncMock(return_value=fake_response) - monkeypatch.setattr( - cl_module, "AsyncAnthropic", MagicMock(return_value=mock_client) - ) + + # Create async context manager mock for AsyncAnthropic + async def mock_aenter(self): + return mock_client + + async def mock_aexit(self, *args): + pass + + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = mock_aenter + mock_context_manager.__aexit__ = mock_aexit + + mock_async_anthropic_cls = MagicMock(return_value=mock_context_manager) + monkeypatch.setattr(cl_module, "AsyncAnthropic", mock_async_anthropic_cls) tools = [ FunctionSpec( @@ -144,9 +165,20 @@ def _raise(*_a, **_kw): # pragma: no cover mock_client = MagicMock() mock_client.messages.create.side_effect = _raise - monkeypatch.setattr( - cl_module, "AsyncAnthropic", MagicMock(return_value=mock_client) - ) + + # Create async context manager mock for AsyncAnthropic + async def mock_aenter(self): + return mock_client + + async def mock_aexit(self, *args): + pass + + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = mock_aenter + mock_context_manager.__aexit__ = mock_aexit + + mock_async_anthropic_cls = MagicMock(return_value=mock_context_manager) + monkeypatch.setattr(cl_module, "AsyncAnthropic", mock_async_anthropic_cls) from grafi.common.exceptions import LLMToolException diff --git a/tests/tools/llms/test_openai_tool.py b/tests/tools/llms/test_openai_tool.py index 8683fb2..8e3f205 100644 --- a/tests/tools/llms/test_openai_tool.py +++ b/tests/tools/llms/test_openai_tool.py @@ -61,8 +61,19 @@ async def mock_create(*args, **kwargs): mock_client = MagicMock() mock_client.chat.completions.create = mock_create - # Mock the AsyncClient constructor - mock_async_client_cls = MagicMock(return_value=mock_client) + # Create async context manager mock + async def mock_aenter(self): + return mock_client + + async def mock_aexit(self, *args): + pass + + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = mock_aenter + mock_context_manager.__aexit__ = mock_aexit + + # Mock the AsyncClient constructor to return our context manager + mock_async_client_cls = MagicMock(return_value=mock_context_manager) monkeypatch.setattr( grafi.tools.llms.impl.openai_tool, 
"AsyncClient", mock_async_client_cls ) @@ -115,8 +126,19 @@ async def mock_create(*args, **kwargs): mock_client = MagicMock() mock_client.chat.completions.create = mock_create - # Mock the AsyncClient constructor - mock_async_client_cls = MagicMock(return_value=mock_client) + # Create async context manager mock + async def mock_aenter(self): + return mock_client + + async def mock_aexit(self, *args): + pass + + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = mock_aenter + mock_context_manager.__aexit__ = mock_aexit + + # Mock the AsyncClient constructor to return our context manager + mock_async_client_cls = MagicMock(return_value=mock_context_manager) monkeypatch.setattr( grafi.tools.llms.impl.openai_tool, "AsyncClient", mock_async_client_cls ) From 1329795d000590c01013e6586522171865072a3d Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Wed, 31 Dec 2025 15:54:24 +0000 Subject: [PATCH 4/7] update unit tests --- tests/assistants/test_assistant_mock_llm.py | 498 ++++++++++++++------ 1 file changed, 360 insertions(+), 138 deletions(-) diff --git a/tests/assistants/test_assistant_mock_llm.py b/tests/assistants/test_assistant_mock_llm.py index 7cca349..70de5dd 100644 --- a/tests/assistants/test_assistant_mock_llm.py +++ b/tests/assistants/test_assistant_mock_llm.py @@ -463,12 +463,18 @@ class TestHumanInTheLoopWithMockLLM: """ Test Human-in-the-Loop (HITL) workflows using FunctionTool to simulate LLM behavior. - HITL workflow pattern: - 1. LLM processes input and decides to request human input - 2. Workflow pauses, emits event for human response - 3. Human provides input via InWorkflowInputTopic - 4. Workflow continues with human input - 5. LLM generates final response + HITL workflow pattern (following tests_integration/hith_assistant concepts): + 1. LLM processes input and decides to request human input via tool call + 2. Function node executes and publishes to InWorkflowOutputTopic + 3. Workflow pauses, emits event for human response + 4. Human provides input via new invoke with consumed_event_ids + 5. Input goes to InWorkflowInputTopic, LLM continues processing + 6. LLM generates final response or requests more human input + + Key components: + - InWorkflowOutputTopic: Pauses workflow and emits to external consumer (human) + - InWorkflowInputTopic: Receives human response to continue workflow + - consumed_event_ids: Links human response to previous outputs when resuming """ @pytest.fixture @@ -485,7 +491,7 @@ async def test_hitl_workflow_no_human_input_needed(self, invoke_context): """ Test HITL workflow when LLM can respond without human input. - Flow: Input -> LLM (direct response) -> Output + Flow: Input -> LLM (direct response, no tool call) -> Output """ def mock_llm(messages: List[Message]) -> List[Message]: @@ -536,60 +542,61 @@ def mock_llm(messages: List[Message]) -> List[Message]: assert results[0].data[0].content == "I can answer this directly!" @pytest.mark.asyncio - async def test_hitl_workflow_with_human_approval(self, invoke_context): + async def test_hitl_workflow_with_in_workflow_topics(self, invoke_context): """ - Test HITL workflow that requests and receives human approval. + Test proper HITL workflow using InWorkflowInputTopic and InWorkflowOutputTopic. + + This follows the pattern from tests_integration/hith_assistant: + 1. First invoke: LLM requests human info -> pauses at InWorkflowOutputTopic + 2. Second invoke: Human provides response via consumed_event_ids -> continues + 3. LLM generates final response Flow: - 1. 
Input -> LLM (requests approval) -> HITL Output - 2. Human approval -> HITL Input -> LLM (processes approval) -> Output + Invoke 1: Input -> LLM (tool call) -> FunctionNode -> InWorkflowOutputTopic (pauses) + Invoke 2: Human response (with consumed_event_ids) -> InWorkflowInputTopic -> LLM -> Output """ call_count = {"llm": 0} def mock_llm(messages: List[Message]) -> List[Message]: - """Mock LLM that requests approval on first call.""" + """Mock LLM that requests human info on first call, responds on second.""" call_count["llm"] += 1 if call_count["llm"] == 1: - # First call: request human approval + # First call: request human information return [ Message( role="assistant", content=None, tool_calls=[ make_tool_call( - "approval_1", - "request_approval", - '{"action": "delete_account", "reason": "user requested"}', + "info_1", + "request_human_information", + '{"question": "What is your name?"}', ) ], ) ] else: - # Second call: process approval result - last_content = messages[-1].content if messages else "" - if "approved" in last_content.lower(): - return [ - Message( - role="assistant", - content="Account deletion has been approved and completed.", - ) - ] - else: - return [ - Message( - role="assistant", - content="Account deletion was rejected.", - ) - ] + # Second call: process human response and generate final answer + # Find the user's response in messages + user_response = "" + for msg in messages: + if msg.role == "user" and msg.content: + user_response = msg.content + return [ + Message( + role="assistant", + content=f"Thank you! I received your response: {user_response}", + ) + ] + + def request_human_information(self, question: str) -> str: + """Mock function that returns a schema for human to fill.""" + import json - def request_approval(self, action: str, reason: str) -> str: - """Mock HITL request that simulates human approval.""" - # In a real scenario, this would pause and wait for human input - # For testing, we simulate automatic approval - return "Action APPROVED by human reviewer" + return json.dumps({"question": question, "answer": "string"}) - # Create topics + # Create topics following integration test pattern agent_input = InputTopic(name="agent_input") agent_output = OutputTopic( name="agent_output", @@ -602,15 +609,14 @@ def request_approval(self, action: str, reason: str) -> str: condition=lambda event: event.data[-1].tool_calls is not None, ) - # HITL topics for human interaction + # HITL topics - the key components for true HITL pattern human_response_topic = InWorkflowInputTopic(name="human_response") human_request_topic = InWorkflowOutputTopic( name="human_request", paired_in_workflow_input_topic_names=["human_response"], ) - hitl_result_topic = Topic(name="hitl_result") - # LLM node + # LLM node subscribes to both initial input AND human responses llm_node = ( Node.builder() .name("MockLLMNode") @@ -618,7 +624,7 @@ def request_approval(self, action: str, reason: str) -> str: SubscriptionBuilder() .subscribed_to(agent_input) .or_() - .subscribed_to(hitl_result_topic) + .subscribed_to(human_response_topic) .build() ) .tool(LLMMockTool(function=mock_llm)) @@ -627,84 +633,130 @@ def request_approval(self, action: str, reason: str) -> str: .build() ) - # HITL request node - hitl_node = ( + # Function node publishes to InWorkflowOutputTopic to pause for human + function_node = ( Node.builder() .name("HITLRequestNode") .subscribe(SubscriptionBuilder().subscribed_to(hitl_call_topic).build()) .tool( FunctionCallTool.builder() .name("HITLRequest") - 
.function(request_approval) + .function(request_human_information) .build() ) - .publish_to(hitl_result_topic) + .publish_to(human_request_topic) # InWorkflowOutputTopic - pauses here .build() ) workflow = ( EventDrivenWorkflow.builder() - .name("hitl_approval_workflow") + .name("hitl_in_workflow_topics") .node(llm_node) - .node(hitl_node) + .node(function_node) .build() ) with patch.object(Assistant, "_construct_workflow"): assistant = Assistant(name="TestHITLAgent", workflow=workflow) - input_data = PublishToTopicEvent( + # First invoke: should pause at InWorkflowOutputTopic + first_input = PublishToTopicEvent( invoke_context=invoke_context, - data=[Message(role="user", content="Please delete my account")], + data=[Message(role="user", content="I want to register")], ) - results = [] - async for event in assistant.invoke(input_data): - results.append(event) + first_outputs = [] + async for event in assistant.invoke(first_input): + first_outputs.append(event) - assert len(results) == 1 - assert "approved" in results[0].data[0].content.lower() + # Should get output from InWorkflowOutputTopic (the HITL request) + assert len(first_outputs) == 1 + assert call_count["llm"] == 1 + + # Second invoke: human provides response with consumed_event_ids + human_response = PublishToTopicEvent( + invoke_context=InvokeContext( + conversation_id=invoke_context.conversation_id, + invoke_id=uuid.uuid4().hex, + assistant_request_id=invoke_context.assistant_request_id, + ), + data=[Message(role="user", content="My name is Alice")], + consumed_event_ids=[event.event_id for event in first_outputs], + ) + + second_outputs = [] + async for event in assistant.invoke(human_response): + second_outputs.append(event) + + # Should get final response from LLM + assert len(second_outputs) == 1 + assert "Alice" in second_outputs[0].data[0].content assert call_count["llm"] == 2 @pytest.mark.asyncio - async def test_hitl_workflow_with_rejection(self, invoke_context): + async def test_hitl_workflow_multi_turn_human_input(self, invoke_context): """ - Test HITL workflow where human rejects the action. + Test HITL workflow requiring multiple rounds of human input. + + This simulates a registration flow requiring name and age separately. + + Flow: + Invoke 1: Input -> LLM (request name) -> pause + Invoke 2: Name response -> LLM (request age) -> pause + Invoke 3: Age response -> LLM (complete registration) -> Output """ call_count = {"llm": 0} def mock_llm(messages: List[Message]) -> List[Message]: + """Mock LLM that collects info step by step.""" call_count["llm"] += 1 if call_count["llm"] == 1: + # First: request name return [ Message( role="assistant", content=None, tool_calls=[ make_tool_call( - "approval_1", - "request_approval", - '{"action": "transfer_funds", "amount": "$10000"}', + "name_req", + "request_info", + '{"field": "name"}', + ) + ], + ) + ] + elif call_count["llm"] == 2: + # Second: got name, request age + return [ + Message( + role="assistant", + content=None, + tool_calls=[ + make_tool_call( + "age_req", + "request_info", + '{"field": "age"}', ) ], ) ] else: - last_content = messages[-1].content if messages else "" - if "rejected" in last_content.lower(): - return [ - Message( - role="assistant", - content="The fund transfer was not approved. No action taken.", - ) - ] - return [Message(role="assistant", content="Transfer completed.")] + # Third: got all info, complete registration + return [ + Message( + role="assistant", + content="Registration complete! 
Welcome to the gym.", + ) + ] - def request_approval(self, action: str, amount: str) -> str: - """Simulate human rejection.""" - return "Action REJECTED - amount too high" + def request_info(self, field: str) -> str: + """Request a specific piece of information.""" + import json + + return json.dumps({"requested_field": field}) + # Topics agent_input = InputTopic(name="agent_input") agent_output = OutputTopic( name="agent_output", @@ -716,7 +768,11 @@ def request_approval(self, action: str, amount: str) -> str: name="hitl_call", condition=lambda event: event.data[-1].tool_calls is not None, ) - hitl_result_topic = Topic(name="hitl_result") + human_response_topic = InWorkflowInputTopic(name="human_response") + human_request_topic = InWorkflowOutputTopic( + name="human_request", + paired_in_workflow_input_topic_names=["human_response"], + ) llm_node = ( Node.builder() @@ -725,7 +781,7 @@ def request_approval(self, action: str, amount: str) -> str: SubscriptionBuilder() .subscribed_to(agent_input) .or_() - .subscribed_to(hitl_result_topic) + .subscribed_to(human_response_topic) .build() ) .tool(LLMMockTool(function=mock_llm)) @@ -734,49 +790,87 @@ def request_approval(self, action: str, amount: str) -> str: .build() ) - hitl_node = ( + function_node = ( Node.builder() - .name("HITLRejectNode") + .name("InfoRequestNode") .subscribe(SubscriptionBuilder().subscribed_to(hitl_call_topic).build()) .tool( FunctionCallTool.builder() - .name("HITLReject") - .function(request_approval) + .name("InfoRequest") + .function(request_info) .build() ) - .publish_to(hitl_result_topic) + .publish_to(human_request_topic) .build() ) workflow = ( EventDrivenWorkflow.builder() - .name("hitl_rejection_workflow") + .name("hitl_multi_turn") .node(llm_node) - .node(hitl_node) + .node(function_node) .build() ) with patch.object(Assistant, "_construct_workflow"): assistant = Assistant(name="TestHITLAgent", workflow=workflow) - input_data = PublishToTopicEvent( - invoke_context=invoke_context, - data=[Message(role="user", content="Transfer $10000 to account X")], - ) + # Invoke 1: Initial request + outputs_1 = [] + async for event in assistant.invoke( + PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Register me for the gym")], + ) + ): + outputs_1.append(event) + + assert len(outputs_1) == 1 + assert call_count["llm"] == 1 + + # Invoke 2: Provide name + outputs_2 = [] + async for event in assistant.invoke( + PublishToTopicEvent( + invoke_context=InvokeContext( + conversation_id=invoke_context.conversation_id, + invoke_id=uuid.uuid4().hex, + assistant_request_id=invoke_context.assistant_request_id, + ), + data=[Message(role="user", content="My name is Bob")], + consumed_event_ids=[e.event_id for e in outputs_1], + ) + ): + outputs_2.append(event) - results = [] - async for event in assistant.invoke(input_data): - results.append(event) + assert len(outputs_2) == 1 + assert call_count["llm"] == 2 - assert len(results) == 1 - assert "not approved" in results[0].data[0].content.lower() + # Invoke 3: Provide age + outputs_3 = [] + async for event in assistant.invoke( + PublishToTopicEvent( + invoke_context=InvokeContext( + conversation_id=invoke_context.conversation_id, + invoke_id=uuid.uuid4().hex, + assistant_request_id=invoke_context.assistant_request_id, + ), + data=[Message(role="user", content="My age is 25")], + consumed_event_ids=[e.event_id for e in outputs_2], + ) + ): + outputs_3.append(event) + + assert len(outputs_3) == 1 + assert "Registration complete" in 
outputs_3[0].data[0].content + assert call_count["llm"] == 3 @pytest.mark.asyncio - async def test_hitl_workflow_multi_step_approval(self, invoke_context): + async def test_hitl_workflow_with_approval_rejection(self, invoke_context): """ - Test HITL workflow with multiple approval steps. + Test HITL workflow where human can approve or reject an action. - Flow: Input -> LLM (approval1) -> Human1 -> LLM (approval2) -> Human2 -> LLM -> Output + This tests the approval pattern with InWorkflowOutputTopic. """ call_count = {"llm": 0} @@ -784,7 +878,6 @@ def mock_llm(messages: List[Message]) -> List[Message]: call_count["llm"] += 1 if call_count["llm"] == 1: - # First approval: manager return [ Message( role="assistant", @@ -792,45 +885,179 @@ def mock_llm(messages: List[Message]) -> List[Message]: tool_calls=[ make_tool_call( "approval_1", - "request_manager_approval", - '{"action": "large_purchase", "amount": "$5000"}', + "request_approval", + '{"action": "delete_account"}', ) ], ) ] - elif call_count["llm"] == 2: - # Second approval: finance + else: + # Check last user message for approval decision + last_user_msg = "" + for msg in reversed(messages): + if msg.role == "user" and msg.content: + last_user_msg = msg.content.lower() + break + + if "approve" in last_user_msg or "yes" in last_user_msg: + return [ + Message( + role="assistant", + content="Account deletion approved and completed.", + ) + ] + else: + return [ + Message( + role="assistant", + content="Account deletion was rejected. No action taken.", + ) + ] + + def request_approval(self, action: str) -> str: + """Request human approval for an action.""" + import json + + return json.dumps( + { + "action": action, + "message": f"Do you approve: {action}?", + "options": ["approve", "reject"], + } + ) + + agent_input = InputTopic(name="agent_input") + agent_output = OutputTopic( + name="agent_output", + condition=lambda event: ( + event.data[-1].content is not None and event.data[-1].tool_calls is None + ), + ) + hitl_call_topic = Topic( + name="hitl_call", + condition=lambda event: event.data[-1].tool_calls is not None, + ) + human_response_topic = InWorkflowInputTopic(name="human_response") + human_request_topic = InWorkflowOutputTopic( + name="human_request", + paired_in_workflow_input_topic_names=["human_response"], + ) + + llm_node = ( + Node.builder() + .name("MockLLMNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(agent_input) + .or_() + .subscribed_to(human_response_topic) + .build() + ) + .tool(LLMMockTool(function=mock_llm)) + .publish_to(agent_output) + .publish_to(hitl_call_topic) + .build() + ) + + hitl_node = ( + Node.builder() + .name("ApprovalNode") + .subscribe(SubscriptionBuilder().subscribed_to(hitl_call_topic).build()) + .tool( + FunctionCallTool.builder() + .name("ApprovalTool") + .function(request_approval) + .build() + ) + .publish_to(human_request_topic) + .build() + ) + + workflow = ( + EventDrivenWorkflow.builder() + .name("hitl_approval_workflow") + .node(llm_node) + .node(hitl_node) + .build() + ) + + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant(name="TestHITLAgent", workflow=workflow) + + # First invoke: request approval + outputs_1 = [] + async for event in assistant.invoke( + PublishToTopicEvent( + invoke_context=invoke_context, + data=[Message(role="user", content="Delete my account")], + ) + ): + outputs_1.append(event) + + assert len(outputs_1) == 1 + assert call_count["llm"] == 1 + + # Second invoke: human rejects + outputs_2 = [] + async for event in 
assistant.invoke( + PublishToTopicEvent( + invoke_context=InvokeContext( + conversation_id=invoke_context.conversation_id, + invoke_id=uuid.uuid4().hex, + assistant_request_id=invoke_context.assistant_request_id, + ), + data=[Message(role="user", content="reject")], + consumed_event_ids=[e.event_id for e in outputs_1], + ) + ): + outputs_2.append(event) + + assert len(outputs_2) == 1 + assert "rejected" in outputs_2[0].data[0].content.lower() + assert call_count["llm"] == 2 + + @pytest.mark.asyncio + async def test_hitl_legacy_auto_approval_pattern(self, invoke_context): + """ + Test legacy HITL pattern where function auto-responds (no real human pause). + + This is the simpler pattern where the function immediately returns a result + without pausing for actual human input. Useful for testing function call flows. + """ + call_count = {"llm": 0} + + def mock_llm(messages: List[Message]) -> List[Message]: + """Mock LLM that requests approval on first call.""" + call_count["llm"] += 1 + + if call_count["llm"] == 1: return [ Message( role="assistant", content=None, tool_calls=[ make_tool_call( - "approval_2", - "request_finance_approval", - '{"action": "large_purchase", "amount": "$5000"}', + "approval_1", + "auto_approve", + '{"action": "test_action"}', ) ], ) ] else: - # Final response - return [ - Message( - role="assistant", - content="Purchase approved by manager and finance. Order placed!", - ) - ] - - approval_count = {"count": 0} - - def request_manager_approval(self, action: str, amount: str) -> str: - approval_count["count"] += 1 - return "Manager APPROVED" + last_content = messages[-1].content if messages else "" + if "approved" in last_content.lower(): + return [ + Message( + role="assistant", + content="Action was automatically approved.", + ) + ] + return [Message(role="assistant", content="Action completed.")] - def request_finance_approval(self, action: str, amount: str) -> str: - approval_count["count"] += 1 - return "Finance APPROVED" + def auto_approve(self, action: str) -> str: + """Simulate automatic approval without human intervention.""" + return "Action APPROVED automatically" agent_input = InputTopic(name="agent_input") agent_output = OutputTopic( @@ -839,11 +1066,11 @@ def request_finance_approval(self, action: str, amount: str) -> str: event.data[-1].content is not None and event.data[-1].tool_calls is None ), ) - approval_call_topic = Topic( - name="approval_call", + function_call_topic = Topic( + name="function_call", condition=lambda event: event.data[-1].tool_calls is not None, ) - approval_result_topic = Topic(name="approval_result") + function_result_topic = Topic(name="function_result") llm_node = ( Node.builder() @@ -852,35 +1079,34 @@ def request_finance_approval(self, action: str, amount: str) -> str: SubscriptionBuilder() .subscribed_to(agent_input) .or_() - .subscribed_to(approval_result_topic) + .subscribed_to(function_result_topic) .build() ) .tool(LLMMockTool(function=mock_llm)) .publish_to(agent_output) - .publish_to(approval_call_topic) + .publish_to(function_call_topic) .build() ) - approval_node = ( + function_node = ( Node.builder() - .name("ApprovalNode") - .subscribe(SubscriptionBuilder().subscribed_to(approval_call_topic).build()) + .name("AutoApproveNode") + .subscribe(SubscriptionBuilder().subscribed_to(function_call_topic).build()) .tool( FunctionCallTool.builder() - .name("ApprovalTool") - .function(request_manager_approval) - .function(request_finance_approval) + .name("AutoApprove") + .function(auto_approve) .build() ) - 
.publish_to(approval_result_topic) + .publish_to(function_result_topic) .build() ) workflow = ( EventDrivenWorkflow.builder() - .name("hitl_multi_approval_workflow") + .name("legacy_auto_approval") .node(llm_node) - .node(approval_node) + .node(function_node) .build() ) @@ -889,9 +1115,7 @@ def request_finance_approval(self, action: str, amount: str) -> str: input_data = PublishToTopicEvent( invoke_context=invoke_context, - data=[ - Message(role="user", content="I need to purchase equipment for $5000") - ], + data=[Message(role="user", content="Do something that needs approval")], ) results = [] @@ -899,10 +1123,8 @@ def request_finance_approval(self, action: str, amount: str) -> str: results.append(event) assert len(results) == 1 - assert "manager" in results[0].data[0].content.lower() - assert "finance" in results[0].data[0].content.lower() - assert call_count["llm"] == 3 - assert approval_count["count"] == 2 + assert "approved" in results[0].data[0].content.lower() + assert call_count["llm"] == 2 class TestComplexWorkflowPatterns: From 85cbf618a65875e37b5093c5371602bd73349ea2 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Wed, 31 Dec 2025 16:25:55 +0000 Subject: [PATCH 5/7] update hitl tests --- tests/assistants/test_assistant_mock_llm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/assistants/test_assistant_mock_llm.py b/tests/assistants/test_assistant_mock_llm.py index 70de5dd..19cc806 100644 --- a/tests/assistants/test_assistant_mock_llm.py +++ b/tests/assistants/test_assistant_mock_llm.py @@ -2113,7 +2113,7 @@ def fail_func(self) -> str: assert call_count["count"] == 1 @pytest.mark.asyncio - async def test_llm_mock_tool_serialization(self, invoke_context): + async def test_llm_mock_tool_serialization(self): """ Test LLMMockTool to_dict and from_dict methods. """ @@ -2567,8 +2567,8 @@ def search(self, query: str) -> str: data=[Message(role="user", content="What's the weather again?")], ) - secound_results = [] + second_results = [] async for event in assistant.invoke(input_data, is_sequential=True): - secound_results.append(event) + second_results.append(event) - assert len(secound_results) == 0 + assert len(second_results) == 0 From 2b2ce864f7132fe5c5210c0c13244105a4741559 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Wed, 31 Dec 2025 17:52:40 +0000 Subject: [PATCH 6/7] address comments --- grafi/common/models/async_result.py | 3 ++- grafi/workflows/impl/async_node_tracker.py | 7 ++++--- tests/assistants/test_assistant_mock_llm.py | 3 ++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/grafi/common/models/async_result.py b/grafi/common/models/async_result.py index d6b1370..ebb859d 100644 --- a/grafi/common/models/async_result.py +++ b/grafi/common/models/async_result.py @@ -102,12 +102,13 @@ async def aclose(self) -> None: try: await self._producer_task except asyncio.CancelledError: + # The task was cancelled by aclose(); a CancelledError here is expected. pass - # Close the underlying source generator try: await self._source.aclose() except Exception: + # Best-effort cleanup: ignore errors from closing the underlying source. 
pass diff --git a/grafi/workflows/impl/async_node_tracker.py b/grafi/workflows/impl/async_node_tracker.py index abf821d..67e0ada 100644 --- a/grafi/workflows/impl/async_node_tracker.py +++ b/grafi/workflows/impl/async_node_tracker.py @@ -136,7 +136,7 @@ async def on_messages_committed(self, count: int = 1, source: str = "") -> None: self._uncommitted_messages = max(0, self._uncommitted_messages - count) self._check_quiescence_unlocked() - logger.debug(f"Tracker: {count} messages committed from {source} ") + logger.debug(f"Tracker: {count} messages committed from {source}") self._cond.notify_all() # Aliases for clarity @@ -178,11 +178,12 @@ def _is_quiescent_unlocked(self) -> bool: - No messages waiting to be committed - At least some work was done """ + is_quiescent = not self._active and self._uncommitted_messages == 0 logger.debug( f"Tracker: _is_quiescent_unlocked check - active={list(self._active)}, " - f"uncommitted={self._uncommitted_messages}, " + f"uncommitted={self._uncommitted_messages}, is_quiescent={is_quiescent}" ) - return not self._active and self._uncommitted_messages == 0 + return is_quiescent async def is_quiescent(self) -> bool: """ diff --git a/tests/assistants/test_assistant_mock_llm.py b/tests/assistants/test_assistant_mock_llm.py index 19cc806..2bf4455 100644 --- a/tests/assistants/test_assistant_mock_llm.py +++ b/tests/assistants/test_assistant_mock_llm.py @@ -1388,7 +1388,6 @@ def news(self) -> str: results = [] - print("starting invocation") async for event in assistant.invoke(input_data): results.append(event) @@ -2571,4 +2570,6 @@ def search(self, query: str) -> str: async for event in assistant.invoke(input_data, is_sequential=True): second_results.append(event) + + # The second invocation should not produce any output as the workflow completes after first assert len(second_results) == 0 From ef303ccacc577d598956fc8c09cd68672cc15703 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Wed, 31 Dec 2025 18:03:46 +0000 Subject: [PATCH 7/7] fix lint --- tests/assistants/test_assistant_mock_llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/assistants/test_assistant_mock_llm.py b/tests/assistants/test_assistant_mock_llm.py index 2bf4455..55d1c61 100644 --- a/tests/assistants/test_assistant_mock_llm.py +++ b/tests/assistants/test_assistant_mock_llm.py @@ -2570,6 +2570,5 @@ def search(self, query: str) -> str: async for event in assistant.invoke(input_data, is_sequential=True): second_results.append(event) - # The second invocation should not produce any output as the workflow completes after first assert len(second_results) == 0
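Note on the mocking pattern introduced above (this note and sketch are not part of the patch series): the revised Claude and OpenAI tool tests patch the SDK constructor to return an async context manager rather than the client itself, because the tools now open the client with "async with". The standalone sketch below shows that technique in isolation under stated assumptions: call_sdk, the fake response object, and the model/messages arguments are illustrative placeholders, not grafi or SDK APIs. It also relies on the fact that, since Python 3.8, MagicMock pre-configures __aenter__/__aexit__ as AsyncMocks, so setting __aenter__.return_value is enough to simulate the context manager (the explicit mock_aenter/mock_aexit functions used in the patch are an equivalent, more verbose route).

# Illustrative sketch only -- names below are placeholders, not grafi APIs.
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def call_sdk(client_cls, api_key: str) -> str:
    # Stand-in for production code that opens the SDK client as an async context manager.
    async with client_cls(api_key=api_key) as client:
        response = await client.messages.create(model="placeholder", messages=[])
        return response.content


async def main() -> None:
    fake_response = MagicMock()
    fake_response.content = "Hello, world!"

    # Inner client: .messages.create(...) is awaitable and returns the fake response.
    mock_client = MagicMock()
    mock_client.messages.create = AsyncMock(return_value=fake_response)

    # Context manager: "async with" yields the inner client
    # (MagicMock provides __aenter__/__aexit__ as AsyncMocks on Python 3.8+).
    mock_context_manager = MagicMock()
    mock_context_manager.__aenter__.return_value = mock_client

    # Constructor: called once with the API key, returns the context manager.
    mock_client_cls = MagicMock(return_value=mock_context_manager)

    assert await call_sdk(mock_client_cls, api_key="test_api_key") == "Hello, world!"
    mock_client_cls.assert_called_once_with(api_key="test_api_key")
    mock_client.messages.create.assert_awaited_once()


asyncio.run(main())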