From b0071d9637cbde79ae4d8c4f49e120517d0c3647 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Mon, 29 Dec 2025 21:19:27 +0000 Subject: [PATCH 1/9] improve async call stability --- grafi/workflows/impl/async_output_queue.py | 10 +- tests/assistants/test_assistant.py | 398 ++++++++++++++++++++- tests/workflow/test_async_output_queue.py | 145 ++++++++ 3 files changed, 548 insertions(+), 5 deletions(-) diff --git a/grafi/workflows/impl/async_output_queue.py b/grafi/workflows/impl/async_output_queue.py index 3607126..32145a6 100644 --- a/grafi/workflows/impl/async_output_queue.py +++ b/grafi/workflows/impl/async_output_queue.py @@ -1,5 +1,4 @@ import asyncio -from typing import AsyncGenerator from typing import List from grafi.common.events.topic_events.topic_event import TopicEvent @@ -84,8 +83,9 @@ async def _output_listener(self, topic: TopicBase) -> None: for t in pending: t.cancel() - def __aiter__(self) -> AsyncGenerator[TopicEvent, None]: + def __aiter__(self) -> "AsyncOutputQueue": """Make AsyncOutputQueue async iterable.""" + self._last_activity_count = 0 return self async def __anext__(self) -> TopicEvent: @@ -114,4 +114,8 @@ async def __anext__(self) -> TopicEvent: await asyncio.sleep(0) # one event‑loop tick if self.tracker.is_idle() and self.queue.empty(): - raise StopAsyncIteration + current_activity = self.tracker.get_activity_count() + # Only terminate if no new activity since last check + if current_activity == self._last_activity_count: + raise StopAsyncIteration + self._last_activity_count = current_activity diff --git a/tests/assistants/test_assistant.py b/tests/assistants/test_assistant.py index cdbfbb3..294417c 100644 --- a/tests/assistants/test_assistant.py +++ b/tests/assistants/test_assistant.py @@ -1,7 +1,7 @@ +import asyncio import json import os -from unittest.mock import Mock -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from openinference.semconv.trace import OpenInferenceSpanKindValues @@ -14,6 +14,8 @@ from grafi.common.models.invoke_context import InvokeContext from grafi.common.models.message import Message from grafi.topics.topic_types import TopicType +from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker +from grafi.workflows.impl.async_output_queue import AsyncOutputQueue from grafi.workflows.workflow import Workflow @@ -314,6 +316,251 @@ def test_generate_manifest_custom_directory(self, mock_assistant, tmp_path): }, } + @pytest.mark.asyncio + async def test_eight_node_dag_workflow(self): + """ + Test a DAG workflow with 8 nodes: A->B, B->C, C->D, C->E, C->F, D->G, E->G, F->G, G->H. + + Each node concatenates the previous input with its own label. + For example, node B receives "A" and outputs "AB". + + The topology creates a fan-out at C (to D, E, F) and a fan-in at G (from D, E, F). 
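+
+        Topology sketch:
+
+            A -> B -> C -+-> D -+
+                         +-> E -+-> G -> H
+                         +-> F -+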
+ """ + from grafi.common.events.topic_events.publish_to_topic_event import ( + PublishToTopicEvent, + ) + from grafi.common.models.invoke_context import InvokeContext + from grafi.common.models.message import Message + from grafi.nodes.node import Node + from grafi.tools.functions.function_tool import FunctionTool + from grafi.topics.expressions.subscription_builder import SubscriptionBuilder + from grafi.topics.topic_impl.input_topic import InputTopic + from grafi.topics.topic_impl.output_topic import OutputTopic + from grafi.topics.topic_impl.topic import Topic + from grafi.workflows.impl.event_driven_workflow import EventDrivenWorkflow + + # Define the concatenation function for each node + def make_concat_func(label: str): + def concat_func(messages): + # Collect all content from input messages + contents = [] + for msg in messages: + if msg.content: + contents.append(msg.content) + # Sort to ensure deterministic ordering for fan-in scenarios + contents.sort() + combined = "".join(contents) + return f"{combined}{label}" + + return concat_func + + # Create topics + # Input/Output topics for the workflow + agent_input_topic = InputTopic(name="agent_input") + agent_output_topic = OutputTopic(name="agent_output") + + # Intermediate topics for connecting nodes + topic_a_out = Topic(name="topic_a_out") + topic_b_out = Topic(name="topic_b_out") + topic_c_out = Topic(name="topic_c_out") + topic_d_out = Topic(name="topic_d_out") + topic_e_out = Topic(name="topic_e_out") + topic_f_out = Topic(name="topic_f_out") + topic_g_out = Topic(name="topic_g_out") + + # Create nodes + # Node A: subscribes to agent_input, publishes to topic_a_out + node_a = ( + Node.builder() + .name("NodeA") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input_topic).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolA") + .function(make_concat_func("A")) + .build() + ) + .publish_to(topic_a_out) + .build() + ) + + # Node B: subscribes to topic_a_out, publishes to topic_b_out + node_b = ( + Node.builder() + .name("NodeB") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_a_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolB") + .function(make_concat_func("B")) + .build() + ) + .publish_to(topic_b_out) + .build() + ) + + # Node C: subscribes to topic_b_out, publishes to topic_c_out + node_c = ( + Node.builder() + .name("NodeC") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_b_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolC") + .function(make_concat_func("C")) + .build() + ) + .publish_to(topic_c_out) + .build() + ) + + # Node D: subscribes to topic_c_out, publishes to topic_d_out (fan-out from C) + node_d = ( + Node.builder() + .name("NodeD") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_c_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolD") + .function(make_concat_func("D")) + .build() + ) + .publish_to(topic_d_out) + .build() + ) + + # Node E: subscribes to topic_c_out, publishes to topic_e_out (fan-out from C) + node_e = ( + Node.builder() + .name("NodeE") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_c_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolE") + .function(make_concat_func("E")) + .build() + ) + .publish_to(topic_e_out) + .build() + ) + + # Node F: subscribes to topic_c_out, publishes to topic_f_out (fan-out from C) + node_f = ( + Node.builder() + .name("NodeF") + 
.type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_c_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolF") + .function(make_concat_func("F")) + .build() + ) + .publish_to(topic_f_out) + .build() + ) + + # Node G: subscribes to topic_d_out AND topic_e_out AND topic_f_out (fan-in) + node_g = ( + Node.builder() + .name("NodeG") + .type("ConcatNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(topic_d_out) + .and_() + .subscribed_to(topic_e_out) + .and_() + .subscribed_to(topic_f_out) + .build() + ) + .tool( + FunctionTool.builder() + .name("ConcatToolG") + .function(make_concat_func("G")) + .build() + ) + .publish_to(topic_g_out) + .build() + ) + + # Node H: subscribes to topic_g_out, publishes to agent_output + node_h = ( + Node.builder() + .name("NodeH") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_g_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolH") + .function(make_concat_func("H")) + .build() + ) + .publish_to(agent_output_topic) + .build() + ) + + # Build the workflow + workflow = ( + EventDrivenWorkflow.builder() + .name("EightNodeDAGWorkflow") + .node(node_a) + .node(node_b) + .node(node_c) + .node(node_d) + .node(node_e) + .node(node_f) + .node(node_g) + .node(node_h) + .build() + ) + + # Create assistant with the workflow + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant( + name="EightNodeDAGAssistant", + workflow=workflow, + ) + + # Create invoke context and input + invoke_context = InvokeContext( + conversation_id="test_dag_conversation", + invoke_id="test_dag_invoke", + assistant_request_id="test_dag_request", + ) + + # Start with empty input - each node adds its label + input_messages = [Message(content="", role="user")] + input_data = PublishToTopicEvent( + invoke_context=invoke_context, data=input_messages + ) + + # Invoke the workflow (using default parallel mode) + result_events = [] + async for event in assistant.invoke(input_data): + result_events.append(event) + + # Verify we get exactly 1 event from the agent_output topic + assert len(result_events) == 1, f"Expected 1 event, got {len(result_events)}" + assert result_events[0].name == "agent_output" + + # The expected output path is: + # A: "" -> "A" + # B: "A" -> "AB" + # C: "AB" -> "ABC" + # D: "ABC" -> "ABCD" + # E: "ABC" -> "ABCE" + # F: "ABC" -> "ABCF" + # G: combines "ABCD", "ABCE", "ABCF" (sorted) -> "ABCDABCEABCFG" + # H: "ABCDABCEABCFG" -> "ABCDABCEABCFGH" + expected_output = "ABCDABCEABCFGH" + assert result_events[0].data[0].content == expected_output + def test_generate_manifest_file_write_error(self, mock_assistant): """Test manifest generation with file write error.""" with patch("builtins.open", side_effect=IOError("Permission denied")): @@ -533,3 +780,150 @@ async def test_from_dict_with_defaults(self): # This should fail because EventDrivenWorkflow needs input/output topics with pytest.raises(Exception): # Will raise WorkflowError await Assistant.from_dict(data) + + +class TestAsyncOutputQueue: + """Tests for AsyncOutputQueue race condition handling.""" + + @pytest.mark.asyncio + async def test_anext_waits_for_activity_count_stabilization(self): + """ + Test that __anext__ doesn't prematurely terminate when activity count changes. + + This tests the race condition fix where the output queue could terminate + before downstream nodes finish processing. 
+ """ + tracker = AsyncNodeTracker() + + output_queue = AsyncOutputQueue( + output_topics=[], # Empty - we'll put events directly in queue + consumer_name="test_consumer", + tracker=tracker, + ) + + # Simulate: node enters, adds item to queue, leaves + # Then another node should enter before we terminate + + async def simulate_node_activity(): + """Simulate node activity that should prevent premature termination.""" + # First node processes + await tracker.enter("node_1") + await output_queue.queue.put(Mock(name="event_1")) + await tracker.leave("node_1") + + # Yield control - simulates realistic timing where next node + # starts within the same event loop cycle + await asyncio.sleep(0) + + # Second node picks up and processes + await tracker.enter("node_2") + await output_queue.queue.put(Mock(name="event_2")) + await tracker.leave("node_2") + + # Start the activity simulation + activity_task = asyncio.create_task(simulate_node_activity()) + + # Iterate over the queue + events = [] + async for event in output_queue: + events.append(event) + if len(events) >= 2: + break + + await activity_task + + # Should have received both events + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_anext_terminates_when_truly_idle(self): + """ + Test that __anext__ correctly terminates when no more activity. + """ + tracker = AsyncNodeTracker() + + output_queue = AsyncOutputQueue( + output_topics=[], # Empty - we'll put events directly in queue + consumer_name="test_consumer", + tracker=tracker, + ) + + # Single node processes and finishes + async def simulate_single_node(): + await tracker.enter("node_1") + await output_queue.queue.put(Mock(name="event_1")) + await tracker.leave("node_1") + + activity_task = asyncio.create_task(simulate_single_node()) + + events = [] + async for event in output_queue: + events.append(event) + + await activity_task + + # Should terminate after receiving the single event + assert len(events) == 1 + + @pytest.mark.asyncio + async def test_activity_count_prevents_premature_exit(self): + """ + Test specifically that activity count tracking prevents race condition. + + Scenario: + 1. Node A finishes and tracker goes idle + 2. __anext__ sees idle but activity count changed + 3. Node B starts before __anext__ decides to terminate + 4. 
All events are properly yielded + """ + tracker = AsyncNodeTracker() + + output_queue = AsyncOutputQueue( + output_topics=[], # Empty - we'll put events directly in queue + consumer_name="test_consumer", + tracker=tracker, + ) + + events_received = [] + iteration_complete = asyncio.Event() + + async def consumer(): + async for event in output_queue: + events_received.append(event) + iteration_complete.set() + + async def producer(): + # Node A processes + await tracker.enter("node_a") + await output_queue.queue.put(Mock(name="event_a")) + await tracker.leave("node_a") + + # Critical timing window - yield to let consumer check idle state + await asyncio.sleep(0) + + # Node B starts before consumer terminates (if fix works) + await tracker.enter("node_b") + await output_queue.queue.put(Mock(name="event_b")) + await tracker.leave("node_b") + + consumer_task = asyncio.create_task(consumer()) + producer_task = asyncio.create_task(producer()) + + # Wait for producer to finish + await producer_task + + # Wait a bit for consumer to process + try: + await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) + except asyncio.TimeoutError: + consumer_task.cancel() + try: + await consumer_task + except asyncio.CancelledError: + pass + + # With the fix, we should receive both events + assert len(events_received) == 2, ( + f"Expected 2 events but got {len(events_received)}. " + "Race condition may have caused premature termination." + ) diff --git a/tests/workflow/test_async_output_queue.py b/tests/workflow/test_async_output_queue.py index afd4455..54a63ce 100644 --- a/tests/workflow/test_async_output_queue.py +++ b/tests/workflow/test_async_output_queue.py @@ -1,4 +1,5 @@ import asyncio +from unittest.mock import Mock import pytest @@ -268,3 +269,147 @@ async def test_concurrent_listeners(self, tracker): # Should have collected all events assert len(collected) == 3 assert all(isinstance(e, PublishToTopicEvent) for e in collected) + + + @pytest.mark.asyncio + async def test_anext_waits_for_activity_count_stabilization(self): + """ + Test that __anext__ doesn't prematurely terminate when activity count changes. + + This tests the race condition fix where the output queue could terminate + before downstream nodes finish processing. 
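+
+        Timeline: node_1 leaves and the tracker goes momentarily idle, but
+        node_2 enters within the same event-loop tick. The activity count
+        changes, so __anext__ keeps polling instead of raising
+        StopAsyncIteration.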
+ """ + tracker = AsyncNodeTracker() + + output_queue = AsyncOutputQueue( + output_topics=[], # Empty - we'll put events directly in queue + consumer_name="test_consumer", + tracker=tracker, + ) + + # Simulate: node enters, adds item to queue, leaves + # Then another node should enter before we terminate + + async def simulate_node_activity(): + """Simulate node activity that should prevent premature termination.""" + # First node processes + await tracker.enter("node_1") + await output_queue.queue.put(Mock(name="event_1")) + await tracker.leave("node_1") + + # Yield control - simulates realistic timing where next node + # starts within the same event loop cycle + await asyncio.sleep(0) + + # Second node picks up and processes + await tracker.enter("node_2") + await output_queue.queue.put(Mock(name="event_2")) + await tracker.leave("node_2") + + # Start the activity simulation + activity_task = asyncio.create_task(simulate_node_activity()) + + # Iterate over the queue + events = [] + async for event in output_queue: + events.append(event) + if len(events) >= 2: + break + + await activity_task + + # Should have received both events + assert len(events) == 2 + + @pytest.mark.asyncio + async def test_anext_terminates_when_truly_idle(self): + """ + Test that __anext__ correctly terminates when no more activity. + """ + tracker = AsyncNodeTracker() + + output_queue = AsyncOutputQueue( + output_topics=[], # Empty - we'll put events directly in queue + consumer_name="test_consumer", + tracker=tracker, + ) + + # Single node processes and finishes + async def simulate_single_node(): + await tracker.enter("node_1") + await output_queue.queue.put(Mock(name="event_1")) + await tracker.leave("node_1") + + activity_task = asyncio.create_task(simulate_single_node()) + + events = [] + async for event in output_queue: + events.append(event) + + await activity_task + + # Should terminate after receiving the single event + assert len(events) == 1 + + @pytest.mark.asyncio + async def test_activity_count_prevents_premature_exit(self): + """ + Test specifically that activity count tracking prevents race condition. + + Scenario: + 1. Node A finishes and tracker goes idle + 2. __anext__ sees idle but activity count changed + 3. Node B starts before __anext__ decides to terminate + 4. 
All events are properly yielded + """ + tracker = AsyncNodeTracker() + + output_queue = AsyncOutputQueue( + output_topics=[], # Empty - we'll put events directly in queue + consumer_name="test_consumer", + tracker=tracker, + ) + + events_received = [] + iteration_complete = asyncio.Event() + + async def consumer(): + async for event in output_queue: + events_received.append(event) + iteration_complete.set() + + async def producer(): + # Node A processes + await tracker.enter("node_a") + await output_queue.queue.put(Mock(name="event_a")) + await tracker.leave("node_a") + + # Critical timing window - yield to let consumer check idle state + await asyncio.sleep(0) + + # Node B starts before consumer terminates (if fix works) + await tracker.enter("node_b") + await output_queue.queue.put(Mock(name="event_b")) + await tracker.leave("node_b") + + consumer_task = asyncio.create_task(consumer()) + producer_task = asyncio.create_task(producer()) + + # Wait for producer to finish + await producer_task + + # Wait a bit for consumer to process + try: + await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) + except asyncio.TimeoutError: + consumer_task.cancel() + try: + await consumer_task + except asyncio.CancelledError: + pass + + # With the fix, we should receive both events + assert len(events_received) == 2, ( + f"Expected 2 events but got {len(events_received)}. " + "Race condition may have caused premature termination." + ) From 7d6982519459f372ed6b169853a9a981dfe17648 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Mon, 29 Dec 2025 22:47:14 +0000 Subject: [PATCH 2/9] fix lint --- tests/assistants/test_assistant.py | 3 ++- tests/workflow/test_async_output_queue.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/assistants/test_assistant.py b/tests/assistants/test_assistant.py index 294417c..8d9dd03 100644 --- a/tests/assistants/test_assistant.py +++ b/tests/assistants/test_assistant.py @@ -1,7 +1,8 @@ import asyncio import json import os -from unittest.mock import Mock, patch +from unittest.mock import Mock +from unittest.mock import patch import pytest from openinference.semconv.trace import OpenInferenceSpanKindValues diff --git a/tests/workflow/test_async_output_queue.py b/tests/workflow/test_async_output_queue.py index 54a63ce..049178f 100644 --- a/tests/workflow/test_async_output_queue.py +++ b/tests/workflow/test_async_output_queue.py @@ -270,7 +270,6 @@ async def test_concurrent_listeners(self, tracker): assert len(collected) == 3 assert all(isinstance(e, PublishToTopicEvent) for e in collected) - @pytest.mark.asyncio async def test_anext_waits_for_activity_count_stabilization(self): """ From 003288d7d72b8aa8ff261568b11758903636e9e4 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Mon, 29 Dec 2025 22:57:00 +0000 Subject: [PATCH 3/9] update version and remove dup code --- pyproject.toml | 2 +- tests/assistants/test_assistant.py | 150 ----------------------------- uv.lock | 2 +- 3 files changed, 2 insertions(+), 152 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 606d2b8..92315dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "grafi" -version = "0.0.33" +version = "0.0.34" description = "Grafi - a flexible, event-driven framework that enables the creation of domain-specific AI agents through composable agentic workflows." 
authors = [{name = "Craig Li", email = "craig@binome.dev"}] license = {text = "Mozilla Public License Version 2.0"} diff --git a/tests/assistants/test_assistant.py b/tests/assistants/test_assistant.py index 8d9dd03..e181c9a 100644 --- a/tests/assistants/test_assistant.py +++ b/tests/assistants/test_assistant.py @@ -1,4 +1,3 @@ -import asyncio import json import os from unittest.mock import Mock @@ -15,8 +14,6 @@ from grafi.common.models.invoke_context import InvokeContext from grafi.common.models.message import Message from grafi.topics.topic_types import TopicType -from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker -from grafi.workflows.impl.async_output_queue import AsyncOutputQueue from grafi.workflows.workflow import Workflow @@ -781,150 +778,3 @@ async def test_from_dict_with_defaults(self): # This should fail because EventDrivenWorkflow needs input/output topics with pytest.raises(Exception): # Will raise WorkflowError await Assistant.from_dict(data) - - -class TestAsyncOutputQueue: - """Tests for AsyncOutputQueue race condition handling.""" - - @pytest.mark.asyncio - async def test_anext_waits_for_activity_count_stabilization(self): - """ - Test that __anext__ doesn't prematurely terminate when activity count changes. - - This tests the race condition fix where the output queue could terminate - before downstream nodes finish processing. - """ - tracker = AsyncNodeTracker() - - output_queue = AsyncOutputQueue( - output_topics=[], # Empty - we'll put events directly in queue - consumer_name="test_consumer", - tracker=tracker, - ) - - # Simulate: node enters, adds item to queue, leaves - # Then another node should enter before we terminate - - async def simulate_node_activity(): - """Simulate node activity that should prevent premature termination.""" - # First node processes - await tracker.enter("node_1") - await output_queue.queue.put(Mock(name="event_1")) - await tracker.leave("node_1") - - # Yield control - simulates realistic timing where next node - # starts within the same event loop cycle - await asyncio.sleep(0) - - # Second node picks up and processes - await tracker.enter("node_2") - await output_queue.queue.put(Mock(name="event_2")) - await tracker.leave("node_2") - - # Start the activity simulation - activity_task = asyncio.create_task(simulate_node_activity()) - - # Iterate over the queue - events = [] - async for event in output_queue: - events.append(event) - if len(events) >= 2: - break - - await activity_task - - # Should have received both events - assert len(events) == 2 - - @pytest.mark.asyncio - async def test_anext_terminates_when_truly_idle(self): - """ - Test that __anext__ correctly terminates when no more activity. - """ - tracker = AsyncNodeTracker() - - output_queue = AsyncOutputQueue( - output_topics=[], # Empty - we'll put events directly in queue - consumer_name="test_consumer", - tracker=tracker, - ) - - # Single node processes and finishes - async def simulate_single_node(): - await tracker.enter("node_1") - await output_queue.queue.put(Mock(name="event_1")) - await tracker.leave("node_1") - - activity_task = asyncio.create_task(simulate_single_node()) - - events = [] - async for event in output_queue: - events.append(event) - - await activity_task - - # Should terminate after receiving the single event - assert len(events) == 1 - - @pytest.mark.asyncio - async def test_activity_count_prevents_premature_exit(self): - """ - Test specifically that activity count tracking prevents race condition. - - Scenario: - 1. 
Node A finishes and tracker goes idle - 2. __anext__ sees idle but activity count changed - 3. Node B starts before __anext__ decides to terminate - 4. All events are properly yielded - """ - tracker = AsyncNodeTracker() - - output_queue = AsyncOutputQueue( - output_topics=[], # Empty - we'll put events directly in queue - consumer_name="test_consumer", - tracker=tracker, - ) - - events_received = [] - iteration_complete = asyncio.Event() - - async def consumer(): - async for event in output_queue: - events_received.append(event) - iteration_complete.set() - - async def producer(): - # Node A processes - await tracker.enter("node_a") - await output_queue.queue.put(Mock(name="event_a")) - await tracker.leave("node_a") - - # Critical timing window - yield to let consumer check idle state - await asyncio.sleep(0) - - # Node B starts before consumer terminates (if fix works) - await tracker.enter("node_b") - await output_queue.queue.put(Mock(name="event_b")) - await tracker.leave("node_b") - - consumer_task = asyncio.create_task(consumer()) - producer_task = asyncio.create_task(producer()) - - # Wait for producer to finish - await producer_task - - # Wait a bit for consumer to process - try: - await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) - except asyncio.TimeoutError: - consumer_task.cancel() - try: - await consumer_task - except asyncio.CancelledError: - pass - - # With the fix, we should receive both events - assert len(events_received) == 2, ( - f"Expected 2 events but got {len(events_received)}. " - "Race condition may have caused premature termination." - ) diff --git a/uv.lock b/uv.lock index 0cc3438..f8079d1 100644 --- a/uv.lock +++ b/uv.lock @@ -1283,7 +1283,7 @@ wheels = [ [[package]] name = "grafi" -version = "0.0.33" +version = "0.0.34" source = { editable = "." } dependencies = [ { name = "anyio" }, From 03445083415ceca0c7630f807da4ec38959167a2 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Tue, 30 Dec 2025 17:24:05 +0000 Subject: [PATCH 4/9] update integration tests output --- tests_integration/agents/run_agents.py | 64 +++++++-- .../run_embedding_assistant.py | 64 +++++++-- .../run_event_store_postgres.py | 64 +++++++-- .../run_function_assistant.py | 64 +++++++-- .../run_function_call_assistant.py | 64 +++++++-- .../hith_assistant/run_hith_assistant.py | 64 +++++++-- .../run_input_output_topics.py | 64 +++++++-- .../invoke_kwargs/run_invoke_kwargs.py | 64 +++++++-- .../mcp_assistant/run_mcp_assistant.py | 64 +++++++-- .../run_multimodal_assistant.py | 64 +++++++-- .../rag_assistant/run_rag_assistant.py | 64 +++++++-- .../react_assistant/run_react_assistant.py | 64 +++++++-- tests_integration/run_all.py | 123 +++++++++++++----- .../run_simple_llm_assistant.py | 64 +++++++-- .../run_simple_stream_assistant.py | 64 +++++++-- 15 files changed, 794 insertions(+), 225 deletions(-) diff --git a/tests_integration/agents/run_agents.py b/tests_integration/agents/run_agents.py index 16d3c21..883c0b5 100644 --- a/tests_integration/agents/run_agents.py +++ b/tests_integration/agents/run_agents.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. 
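+
+    When collect=True, each result is a dict with keys "name", "status"
+    ("passed", "failed", or "skipped"), "output", and "error".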
Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). """ python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/embedding_assistant/run_embedding_assistant.py b/tests_integration/embedding_assistant/run_embedding_assistant.py index 30325ab..10aae9f 100644 --- a/tests_integration/embedding_assistant/run_embedding_assistant.py +++ b/tests_integration/embedding_assistant/run_embedding_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
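+        Each result is a dict with keys "name", "status" ("passed",
+        "failed", or "skipped"), "output", and "error".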
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/event_store_postgres/run_event_store_postgres.py b/tests_integration/event_store_postgres/run_event_store_postgres.py index fdb4ffd..c9f60ff 100644 --- a/tests_integration/event_store_postgres/run_event_store_postgres.py +++ b/tests_integration/event_store_postgres/run_event_store_postgres.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/function_assistant/run_function_assistant.py b/tests_integration/function_assistant/run_function_assistant.py index f059cbd..dbe48bb 100644 --- a/tests_integration/function_assistant/run_function_assistant.py +++ b/tests_integration/function_assistant/run_function_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/function_call_assistant/run_function_call_assistant.py b/tests_integration/function_call_assistant/run_function_call_assistant.py index 8765391..f53579f 100644 --- a/tests_integration/function_call_assistant/run_function_call_assistant.py +++ b/tests_integration/function_call_assistant/run_function_call_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/hith_assistant/run_hith_assistant.py b/tests_integration/hith_assistant/run_hith_assistant.py index 9db1407..5ac8497 100644 --- a/tests_integration/hith_assistant/run_hith_assistant.py +++ b/tests_integration/hith_assistant/run_hith_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/input_output_topics/run_input_output_topics.py b/tests_integration/input_output_topics/run_input_output_topics.py index 1e663a3..1873e53 100644 --- a/tests_integration/input_output_topics/run_input_output_topics.py +++ b/tests_integration/input_output_topics/run_input_output_topics.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/invoke_kwargs/run_invoke_kwargs.py b/tests_integration/invoke_kwargs/run_invoke_kwargs.py index 3fed0d4..bbf29f6 100644 --- a/tests_integration/invoke_kwargs/run_invoke_kwargs.py +++ b/tests_integration/invoke_kwargs/run_invoke_kwargs.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/mcp_assistant/run_mcp_assistant.py b/tests_integration/mcp_assistant/run_mcp_assistant.py index 8a06451..81b17c4 100644 --- a/tests_integration/mcp_assistant/run_mcp_assistant.py +++ b/tests_integration/mcp_assistant/run_mcp_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/multimodal_assistant/run_multimodal_assistant.py b/tests_integration/multimodal_assistant/run_multimodal_assistant.py index cfb0830..7d83370 100644 --- a/tests_integration/multimodal_assistant/run_multimodal_assistant.py +++ b/tests_integration/multimodal_assistant/run_multimodal_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/rag_assistant/run_rag_assistant.py b/tests_integration/rag_assistant/run_rag_assistant.py index 5b04b6c..d1f047b 100644 --- a/tests_integration/rag_assistant/run_rag_assistant.py +++ b/tests_integration/rag_assistant/run_rag_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/react_assistant/run_react_assistant.py b/tests_integration/react_assistant/run_react_assistant.py index 2af2eda..d31af28 100644 --- a/tests_integration/react_assistant/run_react_assistant.py +++ b/tests_integration/react_assistant/run_react_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/run_all.py b/tests_integration/run_all.py index 9dcdee5..6a821c5 100644 --- a/tests_integration/run_all.py +++ b/tests_integration/run_all.py @@ -2,13 +2,29 @@ """Run all integration tests by executing run_*.py scripts in each subfolder.""" import argparse +import importlib.util import io -import subprocess import sys from pathlib import Path +from textwrap import indent -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) + + + +def _load_runner_module(script: Path): + """Load a run_*.py file as a module so we can call run_scripts directly.""" + module_name = f"tests_integration.{script.parent.name}.{script.stem}_runner" + spec = importlib.util.spec_from_file_location(module_name, script) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load spec for {script}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module def run_all_scripts(pass_local: bool = True) -> int: @@ -23,12 +39,14 @@ def run_all_scripts(pass_local: bool = True) -> int: """ python_executable = sys.executable current_directory = Path(__file__).parent + repo_root = current_directory.parent # Find all run_*.py scripts in subdirectories run_scripts = sorted(current_directory.glob("*/run_*.py")) - passed_folders = [] - failed_folders = {} + passed_examples = [] + failed_examples = {} + 
skipped_examples = [] print(f"Found {len(run_scripts)} test runners:") for script in run_scripts: @@ -42,37 +60,84 @@ def run_all_scripts(pass_local: bool = True) -> int: print(f"Running tests in: {folder_name}") print(f"{'=' * 60}") - cmd = [python_executable, str(script)] - if not pass_local: - cmd.append("--no-pass-local") - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=True, - cwd=script.parent, - ) - print(result.stdout) - passed_folders.append(folder_name) - except subprocess.CalledProcessError as e: - print(f"Output:\n{e.stdout}") - print(f"Error:\n{e.stderr}") - failed_folders[folder_name] = e.stderr + runner_module = _load_runner_module(script) + runner_results = runner_module.run_scripts(pass_local=pass_local, collect=True) + except Exception as exc: # noqa: BLE001 + example_rel = script.relative_to(repo_root) + error_message = f"Runner failed before executing examples: {exc}" + print(f" ✗ {example_rel}") + print(f" Error: {error_message}") + failed_examples[example_rel] = { + "error": error_message, + "output": "", + "rerun_cmd": f"{python_executable} {example_rel}", + } + continue + + if not isinstance(runner_results, list): + example_rel = script.relative_to(repo_root) + error_message = "Runner did not return result details." + print(f" ✗ {example_rel}") + print(f" Error: {error_message}") + failed_examples[example_rel] = { + "error": error_message, + "output": "", + "rerun_cmd": f"{python_executable} {example_rel}", + } + continue + + for result in runner_results: + example_rel = (script.parent / result["name"]).relative_to(repo_root) + status = result.get("status", "unknown") + output = result.get("output", "").rstrip() + error = result.get("error", "").rstrip() + + if status == "passed": + print(f" ✓ {example_rel}") + if output: + print(indent(output, " ")) + passed_examples.append(example_rel) + elif status == "failed": + print(f" ✗ {example_rel}") + if output: + print(" Output:") + print(indent(output, " ")) + if error: + print(" Error:") + print(indent(error, " ")) + rerun_cmd = f"{python_executable} {example_rel}" + print(f" Rerun with: {rerun_cmd}") + failed_examples[example_rel] = { + "error": error, + "output": output, + "rerun_cmd": rerun_cmd, + } + else: + print(f" - {example_rel} (skipped)") + if error: + print(f" Reason: {error}") + skipped_examples.append(example_rel) # Summary print("\n" + "=" * 60) print("FINAL SUMMARY") print("=" * 60) - print(f"\nPassed folders: {len(passed_folders)}") - for folder in passed_folders: - print(f" ✓ {folder}") - - if failed_folders: - print(f"\nFailed folders: {len(failed_folders)}") - for folder in failed_folders: - print(f" ✗ {folder}") + print(f"\nPassed examples: {len(passed_examples)}") + for example in passed_examples: + print(f" ✓ {example}") + + if skipped_examples: + print(f"\nSkipped examples: {len(skipped_examples)}") + for example in skipped_examples: + print(f" - {example}") + + if failed_examples: + print(f"\nFailed examples: {len(failed_examples)}") + for example, data in failed_examples.items(): + print(f" ✗ {example}") + if data.get("rerun_cmd"): + print(f" Rerun with: {data['rerun_cmd']}") return 1 print("\nAll integration tests passed!") diff --git a/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py b/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py index f51edd2..14a27ea 100644 --- a/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py +++ b/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py @@ -7,17 +7,22 
@@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). """ python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py b/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py index 16c23df..3b71737 100644 --- a/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py +++ b/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py @@ -7,17 +7,22 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) -def run_scripts(pass_local: bool = True) -> int: + +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. 
Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). """ python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +30,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +58,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 From c69667a5a12f04e77e52737469b791585f845e4f Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Tue, 30 Dec 2025 20:58:20 +0000 Subject: [PATCH 5/9] update async call --- .../queue_impl/in_mem_topic_event_queue.py | 6 +- grafi/workflows/impl/async_node_tracker.py | 203 +++++++++++++++--- grafi/workflows/impl/async_output_queue.py | 155 ++++++------- grafi/workflows/impl/event_driven_workflow.py | 63 ++++-- grafi/workflows/impl/utils.py | 87 +++++--- tests/workflow/test_async_node_tracker.py | 181 +++++++--------- tests/workflow/test_async_output_queue.py | 194 ++++++++++++----- tests/workflow/test_event_driven_workflow.py | 19 +- tests/workflow/test_utils.py | 15 +- .../hith_assistant/kyc_assistant_example.py | 5 +- 10 files changed, 609 insertions(+), 319 deletions(-) diff --git a/grafi/topics/queue_impl/in_mem_topic_event_queue.py b/grafi/topics/queue_impl/in_mem_topic_event_queue.py index 3af9b15..8bf6d1e 100644 --- a/grafi/topics/queue_impl/in_mem_topic_event_queue.py +++ b/grafi/topics/queue_impl/in_mem_topic_event_queue.py @@ -68,9 +68,9 @@ async def fetch( # If timeout is 0 or None and no data, return immediately while not await self.can_consume(consumer_id): try: - logger.debug( - f"Consumer {consumer_id} waiting for new 
messages with timeout={timeout}" - ) + # logger.debug( + # f"Consumer {consumer_id} waiting for new messages with timeout={timeout}" + # ) await asyncio.wait_for(self._cond.wait(), timeout) except asyncio.TimeoutError: return [] diff --git a/grafi/workflows/impl/async_node_tracker.py b/grafi/workflows/impl/async_node_tracker.py index 8c9e751..11662fe 100644 --- a/grafi/workflows/impl/async_node_tracker.py +++ b/grafi/workflows/impl/async_node_tracker.py @@ -3,54 +3,207 @@ # ────────────────────────────────────────────────────────────────────────────── import asyncio from collections import defaultdict -from typing import Dict +from typing import Dict, Optional, Set + +from loguru import logger class AsyncNodeTracker: + """ + Central tracker for workflow activity and quiescence detection. + + Design: All tracking calls come from the ORCHESTRATOR layer, + not from TopicBase. This keeps topics as pure message queues. + + Quiescence = (no active nodes) AND (no uncommitted messages) AND (work done) + + Usage in workflow: + # In publish_events(): + tracker.on_messages_published(len(published_events)) + + # In _commit_events(): + tracker.on_messages_committed(len(events)) + + # In node processing: + await tracker.enter(node_name) + ... process ... + await tracker.leave(node_name) + """ + def __init__(self) -> None: - self._active: set[str] = set() - self._processing_count: Dict[str, int] = defaultdict( - int - ) # Track how many times each node processed + # Node activity tracking + self._active: Set[str] = set() + self._processing_count: Dict[str, int] = defaultdict(int) + + # Message tracking (uncommitted = published but not yet committed) + self._uncommitted_messages: int = 0 + + # Synchronization self._cond = asyncio.Condition() - self._idle_event = asyncio.Event() - # Set the event initially since we start in idle state - self._idle_event.set() + self._quiescence_event = asyncio.Event() + + # Work tracking (prevents premature quiescence before any work) + self._total_committed: int = 0 + self._has_started: bool = False + + # Force stop flag (for explicit workflow stop) + self._force_stopped: bool = False def reset(self) -> None: - """ - Reset the tracker to its initial state. 
- """ - self._active = set() - self._processing_count = defaultdict(int) + """Reset for a new workflow run.""" + self._active.clear() + self._processing_count.clear() + self._uncommitted_messages = 0 self._cond = asyncio.Condition() - self._idle_event = asyncio.Event() - # Set the event initially since we start in idle state - self._idle_event.set() + self._quiescence_event = asyncio.Event() + self._total_committed = 0 + self._has_started = False + self._force_stopped = False + + # ───────────────────────────────────────────────────────────────────────── + # Node Lifecycle (called from _invoke_node) + # ───────────────────────────────────────────────────────────────────────── async def enter(self, node_name: str) -> None: + """Called when a node begins processing.""" async with self._cond: - self._idle_event.clear() + self._has_started = True + self._quiescence_event.clear() self._active.add(node_name) self._processing_count[node_name] += 1 async def leave(self, node_name: str) -> None: + """Called when a node finishes processing.""" async with self._cond: self._active.discard(node_name) - if not self._active: - self._idle_event.set() - self._cond.notify_all() + self._check_quiescence() + self._cond.notify_all() - async def wait_idle_event(self) -> None: + # ───────────────────────────────────────────────────────────────────────── + # Message Tracking (called from orchestrator utilities) + # ───────────────────────────────────────────────────────────────────────── + + def on_messages_published(self, count: int = 1, source: str = "") -> None: + """ + Called when messages are published to topics. + + Call site: publish_events() in utils.py + """ + if count <= 0: + return + self._has_started = True + self._quiescence_event.clear() + self._uncommitted_messages += count + + logger.debug(f"Tracker: {count} messages published from {source} (uncommitted={self._uncommitted_messages})") + + def on_messages_committed(self, count: int = 1, source: str = "") -> None: + """ + Called when messages are committed (consumed and acknowledged). + + Call site: _commit_events() in EventDrivenWorkflow """ - Wait until the tracker is idle, meaning no active nodes. - This is useful for synchronization points in workflows. 
+ if count <= 0: + return + self._uncommitted_messages = max(0, self._uncommitted_messages - count) + self._total_committed += count + self._check_quiescence() + + logger.debug( + f"Tracker: {count} messages committed from {source} " + f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})" + ) + + # Aliases for clarity + def on_message_published(self) -> None: + """Single message version.""" + self.on_messages_published(1) + + def on_message_committed(self) -> None: + """Single message version.""" + self.on_messages_committed(1) + + # ───────────────────────────────────────────────────────────────────────── + # Quiescence Detection + # ───────────────────────────────────────────────────────────────────────── + + def _check_quiescence(self) -> None: + """Check and signal quiescence if all conditions met.""" + logger.debug( + f"Tracker: checking quiescence - active={list(self._active)}, " + f"uncommitted={self._uncommitted_messages}, " + f"has_started={self._has_started}, " + f"total_committed={self._total_committed}, " + f"is_quiescent={self.is_quiescent}" + ) + if self.is_quiescent: + logger.info(f"Tracker: quiescence detected (committed={self._total_committed})") + self._quiescence_event.set() + + @property + def is_quiescent(self) -> bool: + """ + True when workflow is truly idle: + - No nodes actively processing + - No messages waiting to be committed + - At least some work was done + """ + return ( + not self._active + and self._uncommitted_messages == 0 + and self._has_started + and self._total_committed > 0 + ) + + @property + def should_terminate(self) -> bool: """ - await self._idle_event.wait() + True when workflow should stop iteration. + Either natural quiescence or explicit force stop. + """ + return self.is_quiescent or self._force_stopped + + def force_stop(self) -> None: + """ + Force the workflow to stop immediately. + Called when workflow.stop() is invoked. + """ + logger.info("Tracker: force stop requested") + self._force_stopped = True + self._quiescence_event.set() def is_idle(self) -> bool: + """Legacy: just checks if no active nodes.""" return not self._active + async def wait_for_quiescence(self, timeout: Optional[float] = None) -> bool: + """Wait until quiescent. 
Returns False on timeout."""
+        try:
+            if timeout is not None:
+                await asyncio.wait_for(self._quiescence_event.wait(), timeout)
+            else:
+                await self._quiescence_event.wait()
+            return True
+        except asyncio.TimeoutError:
+            return False
+
+    async def wait_idle_event(self) -> None:
+        """Legacy compatibility."""
+        await self._quiescence_event.wait()
+
+    # ─────────────────────────────────────────────────────────────────────────
+    # Metrics
+    # ─────────────────────────────────────────────────────────────────────────
+
     def get_activity_count(self) -> int:
-        """Get total processing count across all nodes"""
+        """Total processing count across all nodes."""
         return sum(self._processing_count.values())
+
+    def get_metrics(self) -> Dict:
+        """Detailed metrics for debugging."""
+        return {
+            "active_nodes": list(self._active),
+            "uncommitted_messages": self._uncommitted_messages,
+            "total_committed": self._total_committed,
+            "is_quiescent": self.is_quiescent,
+        }
\ No newline at end of file
diff --git a/grafi/workflows/impl/async_output_queue.py b/grafi/workflows/impl/async_output_queue.py
index 32145a6..df6ead3 100644
--- a/grafi/workflows/impl/async_output_queue.py
+++ b/grafi/workflows/impl/async_output_queue.py
@@ -1,6 +1,8 @@
 import asyncio
 from typing import List
 
+from loguru import logger
+
 from grafi.common.events.topic_events.topic_event import TopicEvent
 from grafi.topics.topic_base import TopicBase
 from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker
@@ -8,8 +10,9 @@
 
 class AsyncOutputQueue:
     """
-    Manages output topics and their listeners for async workflow execution.
-    Wraps output_topics, listener_tasks, and tracker functionality.
+    Manages output topics and provides async iteration over output events.
+
+    Simplified: all quiescence detection is delegated to AsyncNodeTracker.
     """
 
     def __init__(
@@ -22,100 +25,106 @@ def __init__(
         self.consumer_name = consumer_name
         self.tracker = tracker
         self.queue: asyncio.Queue[TopicEvent] = asyncio.Queue()
-        self.listener_tasks: List[asyncio.Task] = []
+        self._listener_tasks: List[asyncio.Task] = []
+        self._stopped = False
 
     async def start_listeners(self) -> None:
         """Start listener tasks for all output topics."""
-        self.listener_tasks = [
+        self._stopped = False
+        self._listener_tasks = [
             asyncio.create_task(self._output_listener(topic))
             for topic in self.output_topics
         ]
 
     async def stop_listeners(self) -> None:
         """Stop all listener tasks."""
-        for task in self.listener_tasks:
+        self._stopped = True
+        for task in self._listener_tasks:
             task.cancel()
-        await asyncio.gather(*self.listener_tasks, return_exceptions=True)
+        await asyncio.gather(*self._listener_tasks, return_exceptions=True)
+        self._listener_tasks.clear()
 
     async def _output_listener(self, topic: TopicBase) -> None:
         """
-        Streams *matching* records from `topic` into `queue`.
-        Exits when the graph is idle *and* the topic has no more unseen data,
-        with proper handling for downstream node activation.
-        """
-        last_activity_count = 0
-
-        while True:
-            # waiter 1: "some records arrived"
-            topic_task = asyncio.create_task(topic.consume(self.consumer_name))
-            # waiter 2: "graph just became idle"
-            idle_event_waiter = asyncio.create_task(self.tracker.wait_idle_event())
-
-            done, pending = await asyncio.wait(
-                {topic_task, idle_event_waiter},
-                return_when=asyncio.FIRST_COMPLETED,
-            )
-
-            # ---- If records arrived -----------------------------------------
-            if topic_task in done:
-                output_events = topic_task.result()
+        Forward events to queue and track message consumption.
- for output_event in output_events: - await self.queue.put(output_event) - - # ---- Check for workflow completion ---------------- - if idle_event_waiter in done and self.tracker.is_idle(): - current_activity = self.tracker.get_activity_count() - - # If no new activity since last check and no data, we're done - if ( - current_activity == last_activity_count - and not await topic.can_consume(self.consumer_name) - ): - # cancel an unfinished waiter (if any) to avoid warnings - for t in pending: - t.cancel() - break - - last_activity_count = current_activity - - # Cancel the topic task since we're checking idle state - for t in pending: - t.cancel() + When events are consumed from output topics, they've reached their + destination (the output queue), so we mark them as committed. + """ + while not self._stopped: + try: + events = await topic.consume(self.consumer_name, timeout=0.1) + for event in events: + await self.queue.put(event) + # Mark messages as committed when they reach the output queue + if events: + logger.debug(f"Output listener: consumed {len(events)} events from {topic.name}") + self.tracker.on_messages_committed(len(events), source=f"output_listener:{topic.name}") + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Output listener error for {topic.name}: {e}") + await asyncio.sleep(0.1) def __aiter__(self) -> "AsyncOutputQueue": - """Make AsyncOutputQueue async iterable.""" - self._last_activity_count = 0 return self async def __anext__(self) -> TopicEvent: - """Async iteration implementation with idle detection.""" - # two parallel waiters + """ + SIMPLIFIED: Delegates quiescence check entirely to tracker. + + Removed: + - last_activity_count tracking + - asyncio.sleep(0) hack + - duplicated idle detection logic + """ + check_count = 0 while True: + check_count += 1 + if check_count % 20 == 0: # Log every ~10 seconds (20 * 0.5s) + logger.debug( + f"AsyncOutputQueue: still waiting - " + f"queue_empty={self.queue.empty()}, " + f"tracker_metrics={self.tracker.get_metrics()}" + ) + # Fast path: queue has items + if not self.queue.empty(): + try: + return self.queue.get_nowait() + except asyncio.QueueEmpty: + pass + + # Check for completion (natural quiescence or force stop) + if self.tracker.should_terminate and self.queue.empty(): + raise StopAsyncIteration + + # Wait for queue item or quiescence queue_task = asyncio.create_task(self.queue.get()) - idle_task = asyncio.create_task(self.tracker._idle_event.wait()) + quiescent_task = asyncio.create_task( + self.tracker.wait_for_quiescence(timeout=0.5) + ) done, pending = await asyncio.wait( - {queue_task, idle_task}, + {queue_task, quiescent_task}, return_when=asyncio.FIRST_COMPLETED, ) - # Case A: we got a queue item first → stream it - if queue_task in done: - idle_task.cancel() - await asyncio.gather(idle_task, return_exceptions=True) - return queue_task.result() - - # Case B: pipeline went idle first - queue_task.cancel() - await asyncio.gather(queue_task, return_exceptions=True) - - # Give downstream consumers one chance to register activity. 
- await asyncio.sleep(0) # one event‑loop tick - - if self.tracker.is_idle() and self.queue.empty(): - current_activity = self.tracker.get_activity_count() - # Only terminate if no new activity since last check - if current_activity == self._last_activity_count: - raise StopAsyncIteration - self._last_activity_count = current_activity + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Got queue item + if queue_task in done and not queue_task.cancelled(): + try: + return queue_task.result() + except asyncio.QueueEmpty: + continue + + # Quiescence or force stop detected + if self.tracker.should_terminate and self.queue.empty(): + raise StopAsyncIteration \ No newline at end of file diff --git a/grafi/workflows/impl/event_driven_workflow.py b/grafi/workflows/impl/event_driven_workflow.py index d337314..abcfd9f 100644 --- a/grafi/workflows/impl/event_driven_workflow.py +++ b/grafi/workflows/impl/event_driven_workflow.py @@ -1,10 +1,6 @@ import asyncio from collections import deque -from typing import Any -from typing import AsyncGenerator -from typing import Dict -from typing import List -from typing import Set +from typing import Any, AsyncGenerator, Dict, List, Set from loguru import logger from openinference.semconv.trace import OpenInferenceSpanKindValues @@ -18,8 +14,7 @@ ) from grafi.common.events.topic_events.publish_to_topic_event import PublishToTopicEvent from grafi.common.events.topic_events.topic_event import TopicEvent -from grafi.common.exceptions import NodeExecutionError -from grafi.common.exceptions import WorkflowError +from grafi.common.exceptions import NodeExecutionError, WorkflowError from grafi.common.models.invoke_context import InvokeContext from grafi.nodes.node import Node from grafi.nodes.node_base import NodeBase @@ -33,11 +28,12 @@ from grafi.topics.topic_types import TopicType from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker from grafi.workflows.impl.async_output_queue import AsyncOutputQueue -from grafi.workflows.impl.utils import get_async_output_events -from grafi.workflows.impl.utils import get_node_input -from grafi.workflows.impl.utils import publish_events -from grafi.workflows.workflow import Workflow -from grafi.workflows.workflow import WorkflowBuilder +from grafi.workflows.impl.utils import ( + get_async_output_events, + get_node_input, + publish_events, +) +from grafi.workflows.workflow import Workflow, WorkflowBuilder class EventDrivenWorkflow(Workflow): @@ -64,7 +60,7 @@ class EventDrivenWorkflow(Workflow): # Queue of nodes that are ready to invoke (in response to published events) _invoke_queue: deque[NodeBase] = PrivateAttr(default=deque()) - _tracker: AsyncNodeTracker = AsyncNodeTracker() + _tracker: AsyncNodeTracker = PrivateAttr(default_factory=AsyncNodeTracker) # Optional callback that handles output events # Including agent output event, stream event and hil event @@ -73,6 +69,14 @@ def model_post_init(self, _context: Any) -> None: self._add_topics() self._handle_function_calling_nodes() + def stop(self) -> None: + """ + Stop the workflow execution. + Overrides base class to also trigger force stop on the tracker. 
+ """ + super().stop() + self._tracker.force_stop() + @classmethod def builder(cls) -> WorkflowBuilder: """ @@ -205,12 +209,13 @@ async def _get_output_events(self) -> List[ConsumeFromTopicEvent]: return consumed_events + async def _commit_events( - self, consumer_name: str, topic_events: List[ConsumeFromTopicEvent] + self, consumer_name: str, topic_events: List[ConsumeFromTopicEvent], + track_commit: bool = True, ) -> None: if not topic_events: return - # commit all consumed events topic_max_offset: Dict[str, int] = {} for topic_event in topic_events: @@ -221,6 +226,13 @@ async def _commit_events( for topic, offset in topic_max_offset.items(): await self._topics[topic].commit(consumer_name, offset) + # Notify tracker that messages have been committed + # (skip if already tracked elsewhere, e.g., by output listener) + if track_commit: + logger.debug(f"Committing {len(topic_events)} events for {consumer_name}, track_commit={track_commit}") + self._tracker.on_messages_committed(len(topic_events), source=f"commit:{consumer_name}") + + async def _add_to_invoke_queue(self, event: TopicEvent) -> None: topic_name = event.name @@ -271,7 +283,7 @@ async def invoke_sequential( async for result in node.invoke( invoke_context, node_consumed_events ): - published_events.extend(await publish_events(node, result)) + published_events.extend(await publish_events(node, result, self._tracker)) for event in published_events: await self._add_to_invoke_queue(event) @@ -301,6 +313,7 @@ async def invoke_parallel( self, input_data: PublishToTopicEvent ) -> AsyncGenerator[ConsumeFromTopicEvent, None]: invoke_context = input_data.invoke_context + logger.debug(f"invoke_parallel: tracker_id={id(self._tracker)}, metrics={self._tracker.get_metrics()}") # Start a background task to process all nodes (including streaming generators) node_processing_task = [ @@ -374,9 +387,11 @@ async def invoke_parallel( finally: await output_queue.stop_listeners() - # Commit all consumed output events + # Commit all consumed output events to topics + # (tracking already done by output listener, so skip tracker update) await self._commit_events( - consumer_name=self.name, topic_events=consumed_output_events + consumer_name=self.name, topic_events=consumed_output_events, + track_commit=False, ) # 4. 
graceful shutdown all the nodes @@ -500,7 +515,7 @@ def _cancel_all_active_tasks() -> None: if consumed_events: async for event in node.invoke(invoke_context, consumed_events): node_output_events.extend( - await publish_events(node=node, publish_event=event) + await publish_events(node=node, publish_event=event, tracker=self._tracker) ) await self._commit_events( @@ -576,6 +591,7 @@ async def init_workflow( self, input_data: PublishToTopicEvent, is_sequential: bool = False ) -> Any: # 1 – initial seeding + logger.debug(f"init_workflow: is_sequential={is_sequential}, tracker_id={id(self._tracker)}") if not is_sequential: self._tracker.reset() @@ -615,7 +631,13 @@ async def init_workflow( if is_sequential: await self._add_to_invoke_queue(event) + logger.debug(f"init_workflow: events_to_record={len(events_to_record)}, input_topics={len(input_topics)}") if events_to_record: + # Track initial input messages for quiescence detection + if not is_sequential: + logger.debug(f"init_workflow: calling on_messages_published({len(events_to_record)})") + self._tracker.on_messages_published(len(events_to_record), source="init_workflow") + logger.debug(f"init_workflow: tracker after publish: {self._tracker.get_metrics()}") await container.event_store.record_events(events_to_record) else: # When there is unfinished workflow, we need to restore the workflow topics @@ -667,6 +689,9 @@ async def init_workflow( ) ) if paired_event: + # Track the published message for quiescence detection + if not is_sequential: + self._tracker.on_messages_published(1, source="restore_paired_input") if is_sequential: await self._add_to_invoke_queue(paired_event) await container.event_store.record_event(paired_event) diff --git a/grafi/workflows/impl/utils.py b/grafi/workflows/impl/utils.py index 6c9df6f..4190625 100644 --- a/grafi/workflows/impl/utils.py +++ b/grafi/workflows/impl/utils.py @@ -1,5 +1,4 @@ -from typing import Dict -from typing import List +from typing import TYPE_CHECKING, Dict, List, Optional from grafi.common.events.topic_events.consume_from_topic_event import ( ConsumeFromTopicEvent, @@ -9,16 +8,15 @@ from grafi.common.models.message import Message from grafi.nodes.node_base import NodeBase +if TYPE_CHECKING: + from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker + def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: """ Process a list of TopicEvents, grouping by name and aggregating streaming messages. - - Args: - events: List of TopicEvents to process - - Returns: - List of processed TopicEvents with streaming messages aggregated + + NO CHANGES NEEDED - this is read-only aggregation. 
""" # Group events by name events_by_topic: Dict[str, List[TopicEvent]] = {} @@ -35,9 +33,7 @@ def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: non_streaming_events: List[TopicEvent] = [] for event in topic_events: - # Check if event.data contains streaming messages is_streaming_event = False - # Handle both single message and list of messages messages = event.data if messages and len(messages) > 0 and messages[0].is_streaming: is_streaming_event = True @@ -47,25 +43,18 @@ def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: else: non_streaming_events.append(event) - # Add non-streaming events as-is output_events.extend(non_streaming_events) - # Aggregate streaming events if any exist if streaming_events: - # Use the first streaming event as the base for creating the aggregated event base_event = streaming_events[0] - - # Aggregate content from all streaming messages aggregated_content_parts = [] for event in streaming_events: messages = event.data if isinstance(event.data, list) else [event.data] for message in messages: if message.content: aggregated_content_parts.append(message.content) - aggregated_content = "".join(aggregated_content_parts) # type: ignore[arg-type] + aggregated_content = "".join(aggregated_content_parts) - # Create a new message with aggregated content - # Copy properties from the first message but update content and streaming flag first_message = ( base_event.data if isinstance(base_event.data, list) @@ -74,40 +63,53 @@ def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: aggregated_message = Message( role=first_message.role, content=aggregated_content, - is_streaming=False, # Aggregated message is no longer streaming + is_streaming=False, ) - # Create new event based on the base event type aggregated_event = base_event aggregated_event.data = [aggregated_message] - output_events.append(aggregated_event) return output_events async def publish_events( - node: NodeBase, publish_event: PublishToTopicEvent + node: NodeBase, + publish_event: PublishToTopicEvent, + tracker: Optional["AsyncNodeTracker"] = None, # NEW: Optional tracker ) -> List[PublishToTopicEvent]: + """ + Publish events to all topics the node publishes to. + + CHANGE: Added optional tracker parameter. + When provided, notifies tracker of published messages. + """ published_events: List[PublishToTopicEvent] = [] + for topic in node.publish_to: event = await topic.publish_data(publish_event) - if event: published_events.append(event) + # NEW: Notify tracker of published messages + if tracker and published_events: + tracker.on_messages_published(len(published_events), source=f"node:{node.name}") + return published_events async def get_node_input(node: NodeBase) -> List[ConsumeFromTopicEvent]: + """ + Get input events for a node from its subscribed topics. + + NO CHANGES NEEDED - consumption tracking happens at commit time. 
+ """ consumed_events: List[ConsumeFromTopicEvent] = [] node_subscribed_topics = node._subscribed_topics.values() - # Process each topic the node is subscribed to for subscribed_topic in node_subscribed_topics: if await subscribed_topic.can_consume(node.name): - # Get messages from topic and create consume events node_consumed_events = await subscribed_topic.consume(node.name) for event in node_consumed_events: consumed_event = ConsumeFromTopicEvent( @@ -122,3 +124,38 @@ async def get_node_input(node: NodeBase) -> List[ConsumeFromTopicEvent]: consumed_events.append(consumed_event) return consumed_events + + +# ============================================================================= +# Alternative: Wrapper approach if you can't modify function signatures +# ============================================================================= + +class TrackedPublisher: + """ + Wrapper that adds tracking to publish_events without changing its signature. + + Usage: + publisher = TrackedPublisher(tracker) + events = await publisher.publish(node, event) + """ + + def __init__(self, tracker: "AsyncNodeTracker"): + self.tracker = tracker + + async def publish( + self, + node: NodeBase, + publish_event: PublishToTopicEvent, + ) -> List[PublishToTopicEvent]: + """Publish with automatic tracking.""" + published_events: List[PublishToTopicEvent] = [] + + for topic in node.publish_to: + event = await topic.publish_data(publish_event) + if event: + published_events.append(event) + + if published_events: + self.tracker.on_messages_published(len(published_events)) + + return published_events \ No newline at end of file diff --git a/tests/workflow/test_async_node_tracker.py b/tests/workflow/test_async_node_tracker.py index 43bc138..c0a5118 100644 --- a/tests/workflow/test_async_node_tracker.py +++ b/tests/workflow/test_async_node_tracker.py @@ -13,155 +13,124 @@ def tracker(self): @pytest.mark.asyncio async def test_initial_state(self, tracker): - """Test that tracker starts in idle state.""" + """Tracker starts idle with no work recorded.""" assert tracker.is_idle() + assert tracker.is_quiescent is False assert tracker.get_activity_count() == 0 - assert tracker._idle_event.is_set() + assert tracker.get_metrics()["uncommitted_messages"] == 0 @pytest.mark.asyncio - async def test_enter_makes_tracker_active(self, tracker): - """Test that entering a node makes the tracker active.""" + async def test_enter_and_leave_updates_activity(self, tracker): + """Entering and leaving nodes updates activity counts.""" await tracker.enter("node1") assert not tracker.is_idle() - assert not tracker._idle_event.is_set() assert tracker.get_activity_count() == 1 assert "node1" in tracker._active - @pytest.mark.asyncio - async def test_leave_makes_tracker_idle(self, tracker): - """Test that leaving the last node makes the tracker idle.""" - await tracker.enter("node1") await tracker.leave("node1") assert tracker.is_idle() - assert tracker._idle_event.is_set() - assert tracker.get_activity_count() == 1 # Count persists - assert "node1" not in tracker._active + # No commits yet so quiescence is still False + assert tracker.is_quiescent is False + assert tracker.get_activity_count() == 1 @pytest.mark.asyncio - async def test_multiple_nodes_tracking(self, tracker): - """Test tracking multiple nodes.""" - await tracker.enter("node1") - await tracker.enter("node2") - - assert not tracker.is_idle() - assert tracker.get_activity_count() == 2 - assert "node1" in tracker._active - assert "node2" in tracker._active - - await 
tracker.leave("node1") - assert not tracker.is_idle() # Still has node2 - - await tracker.leave("node2") - assert tracker.is_idle() + async def test_message_tracking_and_quiescence(self, tracker): + """Published/committed message tracking drives quiescence detection.""" + tracker.on_messages_published(2) + assert tracker.is_quiescent is False + assert tracker.get_metrics()["uncommitted_messages"] == 2 - @pytest.mark.asyncio - async def test_reentrant_node_increases_count(self, tracker): - """Test that entering the same node multiple times increases count.""" - await tracker.enter("node1") - await tracker.enter("node1") + tracker.on_messages_committed(1) + assert tracker.is_quiescent is False + assert tracker.get_metrics()["uncommitted_messages"] == 1 - assert tracker.get_activity_count() == 2 - assert len(tracker._active) == 1 # Still just one node in active set + tracker.on_messages_committed(1) + assert tracker.is_quiescent is True + assert tracker.get_metrics()["uncommitted_messages"] == 0 @pytest.mark.asyncio - async def test_wait_idle_event(self, tracker): - """Test waiting for idle event.""" - # Initially idle - await asyncio.wait_for(tracker.wait_idle_event(), timeout=0.1) - - # Enter a node - await tracker.enter("node1") + async def test_wait_for_quiescence(self, tracker): + """wait_for_quiescence resolves when work finishes.""" + tracker.on_messages_published(1) - # Create a task that waits for idle - idle_task = asyncio.create_task(tracker.wait_idle_event()) + async def finish_work(): + await asyncio.sleep(0.01) + tracker.on_messages_committed(1) - # Should not be done yet - await asyncio.sleep(0.01) - assert not idle_task.done() + asyncio.create_task(finish_work()) - # Leave node to trigger idle - await tracker.leave("node1") + result = await tracker.wait_for_quiescence(timeout=0.5) + assert result is True + assert tracker.is_quiescent is True - # Now the wait should complete - await asyncio.wait_for(idle_task, timeout=0.1) + @pytest.mark.asyncio + async def test_wait_for_quiescence_timeout(self, tracker): + """wait_for_quiescence returns False on timeout.""" + result = await tracker.wait_for_quiescence(timeout=0.01) + assert result is False + assert tracker.is_quiescent is False @pytest.mark.asyncio async def test_reset(self, tracker): - """Test reset functionality.""" - # Add some activity + """Reset clears activity and quiescence state.""" await tracker.enter("node1") - await tracker.enter("node2") - await tracker.leave("node1") - - assert not tracker.is_idle() - assert tracker.get_activity_count() > 0 + tracker.on_messages_published(1) + tracker.on_messages_committed(1) - # Reset tracker.reset() - # Should be back to initial state assert tracker.is_idle() + assert tracker.is_quiescent is False assert tracker.get_activity_count() == 0 - assert tracker._idle_event.is_set() - assert len(tracker._active) == 0 + assert tracker.get_metrics()["total_committed"] == 0 @pytest.mark.asyncio - async def test_concurrent_enter_leave(self, tracker): - """Test concurrent enter/leave operations.""" - - async def enter_leave_cycle(node_name: str, cycles: int): - for _ in range(cycles): - await tracker.enter(node_name) - await asyncio.sleep(0.001) # Small delay - await tracker.leave(node_name) + async def test_force_stop(self, tracker): + """Force stop terminates workflow even with uncommitted messages.""" + tracker.on_messages_published(2) + assert tracker.is_quiescent is False + assert tracker.should_terminate is False - # Run multiple concurrent cycles - tasks = [ - 
asyncio.create_task(enter_leave_cycle(f"node{i}", 10)) for i in range(5) - ] + tracker.force_stop() - await asyncio.gather(*tasks) - - # Should be idle after all complete - assert tracker.is_idle() - assert tracker.get_activity_count() == 50 # 5 nodes * 10 cycles + # Not quiescent (uncommitted messages still exist) + assert tracker.is_quiescent is False + # But should_terminate is True due to force stop + assert tracker.should_terminate is True + assert tracker._force_stopped is True @pytest.mark.asyncio - async def test_leave_nonexistent_node(self, tracker): - """Test leaving a node that was never entered.""" - # Should not raise an error - await tracker.leave("nonexistent") - assert tracker.is_idle() + async def test_should_terminate_on_quiescence(self, tracker): + """should_terminate is True when naturally quiescent.""" + tracker.on_messages_published(1) + tracker.on_messages_committed(1) - @pytest.mark.asyncio - async def test_condition_notification(self, tracker): - """Test that condition is properly notified on idle.""" - await tracker.enter("node1") + assert tracker.is_quiescent is True + assert tracker.should_terminate is True + assert tracker._force_stopped is False - # Create a flag to verify notification happened - notified = False + @pytest.mark.asyncio + async def test_force_stop_triggers_quiescence_event(self, tracker): + """Force stop sets the quiescence event so waiters can proceed.""" + tracker.on_messages_published(1) - async def wait_for_notification(): - nonlocal notified - async with tracker._cond: - await tracker._cond.wait() - notified = True + # Event should not be set yet + assert not tracker._quiescence_event.is_set() - wait_task = asyncio.create_task(wait_for_notification()) + tracker.force_stop() - # Give task time to start waiting - await asyncio.sleep(0.01) + # Event should now be set + assert tracker._quiescence_event.is_set() - # Leave node to trigger notification - await tracker.leave("node1") + @pytest.mark.asyncio + async def test_reset_clears_force_stop(self, tracker): + """Reset clears the force stop flag.""" + tracker.force_stop() + assert tracker._force_stopped is True - # Wait should complete - try: - await asyncio.wait_for(wait_task, timeout=0.1) - except asyncio.TimeoutError: - pass # It's ok if it times out, we just check if notified + tracker.reset() - # Check that notification happened - assert tracker.is_idle() + assert tracker._force_stopped is False + assert tracker.should_terminate is False diff --git a/tests/workflow/test_async_output_queue.py b/tests/workflow/test_async_output_queue.py index 049178f..9fb626d 100644 --- a/tests/workflow/test_async_output_queue.py +++ b/tests/workflow/test_async_output_queue.py @@ -22,10 +22,10 @@ def __init__(self, name: str): self._events = [] self._consumed_offset = -1 - async def consume(self, consumer_name: str): + async def consume(self, consumer_name: str, timeout: float | None = None): """Mock async consume that returns events.""" - # Simulate waiting for events - await asyncio.sleep(0.01) + if timeout and timeout > 0: + await asyncio.sleep(timeout) # Return events after consumed offset new_events = [e for e in self._events if e.offset > self._consumed_offset] @@ -64,15 +64,15 @@ def test_initialization(self, output_queue, mock_topics, tracker): assert output_queue.consumer_name == "test_consumer" assert output_queue.tracker == tracker assert isinstance(output_queue.queue, asyncio.Queue) - assert output_queue.listener_tasks == [] + assert output_queue._listener_tasks == [] @pytest.mark.asyncio async 
def test_start_listeners(self, output_queue, mock_topics): """Test starting listener tasks.""" await output_queue.start_listeners() - assert len(output_queue.listener_tasks) == len(mock_topics) - for task in output_queue.listener_tasks: + assert len(output_queue._listener_tasks) == len(mock_topics) + for task in output_queue._listener_tasks: assert isinstance(task, asyncio.Task) assert not task.done() @@ -83,7 +83,7 @@ async def test_start_listeners(self, output_queue, mock_topics): async def test_stop_listeners(self, output_queue): """Test stopping listener tasks.""" await output_queue.start_listeners() - tasks = output_queue.listener_tasks.copy() + tasks = output_queue._listener_tasks.copy() await output_queue.stop_listeners() @@ -97,9 +97,6 @@ async def test_output_listener_receives_events(self, mock_topics, tracker): topic = mock_topics[0] queue = AsyncOutputQueue([topic], "test_consumer", tracker) - # Add activity to prevent listener from exiting - await tracker.enter("test_node") - # Add test event test_event = PublishToTopicEvent( name="output1", @@ -118,7 +115,7 @@ async def test_output_listener_receives_events(self, mock_topics, tracker): listener_task = asyncio.create_task(queue._output_listener(topic)) # Wait a bit for event to be processed - await asyncio.sleep(0.05) + await asyncio.sleep(0.15) # Check event was queued assert not queue.queue.empty() @@ -126,7 +123,6 @@ async def test_output_listener_receives_events(self, mock_topics, tracker): assert queued_event == test_event # Clean up - await tracker.leave("test_node") listener_task.cancel() await asyncio.gather(listener_task, return_exceptions=True) @@ -164,56 +160,51 @@ async def collect_events(): assert events[0] == test_event @pytest.mark.asyncio - async def test_async_iteration_stops_on_idle(self, output_queue, tracker): - """Test that async iteration stops when tracker is idle and queue is empty.""" - # Make tracker idle - assert tracker.is_idle() - - # Ensure queue is empty - assert output_queue.queue.empty() + async def test_async_iteration_stops_after_quiescence(self, output_queue, tracker): + """Async iteration ends when tracker reports quiescence and queue is empty.""" + tracker.on_messages_published(1) + tracker.on_messages_committed(1) - # Iteration should stop events = [] async for event in output_queue: events.append(event) - assert len(events) == 0 + assert events == [] @pytest.mark.asyncio - async def test_listener_exits_on_idle_and_no_data(self, mock_topics, tracker): - """Test that listener exits when workflow is idle and no more data.""" - topic = mock_topics[0] - queue = AsyncOutputQueue([topic], "test_consumer", tracker) + async def test_async_iteration_waits_for_quiescence_or_events(self, tracker): + """__anext__ waits for either new queue data or quiescence signal.""" + queue = AsyncOutputQueue([], "test_consumer", tracker) - # Ensure tracker is idle - assert tracker.is_idle() + async def signal_quiescence(): + tracker.on_messages_published(1) + await asyncio.sleep(0.02) + tracker.on_messages_committed(1) - # Run listener - should exit quickly since idle and no data - await queue._output_listener(topic) + signal_task = asyncio.create_task(signal_quiescence()) - # Should complete without hanging - assert True - - @pytest.mark.asyncio - async def test_listener_continues_with_activity(self, mock_topics, tracker): - """Test that listener continues when there's activity.""" - topic = mock_topics[0] - queue = AsyncOutputQueue([topic], "test_consumer", tracker) + events = [] + async for event in queue: + 
events.append(event) - # Add activity - await tracker.enter("node1") + await signal_task + assert events == [] - # Start listener - listener_task = asyncio.create_task(queue._output_listener(topic)) + @pytest.mark.asyncio + async def test_event_emitted_before_quiescence(self, tracker): + """Events in the queue are yielded even if quiescence follows immediately.""" + queue = AsyncOutputQueue([], "test_consumer", tracker) + queued_event = Mock() + queued_event.name = "queued_event" + await queue.queue.put(queued_event) - # Should still be running - await asyncio.sleep(0.05) - assert not listener_task.done() + tracker.on_messages_published(1) + tracker.on_messages_committed(1) - # Clean up - listener_task.cancel() - await asyncio.gather(listener_task, return_exceptions=True) - await tracker.leave("node1") + events = [] + async for event in queue: + events.append(event) + assert [e.name for e in events] == ["queued_event"] @pytest.mark.asyncio async def test_type_annotations(self, output_queue): @@ -291,19 +282,23 @@ async def test_anext_waits_for_activity_count_stabilization(self): async def simulate_node_activity(): """Simulate node activity that should prevent premature termination.""" - # First node processes + # First node processes - simulate full message lifecycle + tracker.on_messages_published(1) await tracker.enter("node_1") await output_queue.queue.put(Mock(name="event_1")) await tracker.leave("node_1") + tracker.on_messages_committed(1) # Yield control - simulates realistic timing where next node # starts within the same event loop cycle await asyncio.sleep(0) - # Second node picks up and processes + # Second node picks up and processes - simulate full message lifecycle + tracker.on_messages_published(1) await tracker.enter("node_2") await output_queue.queue.put(Mock(name="event_2")) await tracker.leave("node_2") + tracker.on_messages_committed(1) # Start the activity simulation activity_task = asyncio.create_task(simulate_node_activity()) @@ -333,11 +328,13 @@ async def test_anext_terminates_when_truly_idle(self): tracker=tracker, ) - # Single node processes and finishes + # Single node processes and finishes - simulate full message lifecycle async def simulate_single_node(): + tracker.on_messages_published(1) await tracker.enter("node_1") await output_queue.queue.put(Mock(name="event_1")) await tracker.leave("node_1") + tracker.on_messages_committed(1) activity_task = asyncio.create_task(simulate_single_node()) @@ -378,18 +375,23 @@ async def consumer(): iteration_complete.set() async def producer(): - # Node A processes + # Node A processes - simulate full message lifecycle + tracker.on_messages_published(1) await tracker.enter("node_a") await output_queue.queue.put(Mock(name="event_a")) await tracker.leave("node_a") + tracker.on_messages_committed(1) # Critical timing window - yield to let consumer check idle state await asyncio.sleep(0) # Node B starts before consumer terminates (if fix works) + # simulate full message lifecycle + tracker.on_messages_published(1) await tracker.enter("node_b") await output_queue.queue.put(Mock(name="event_b")) await tracker.leave("node_b") + tracker.on_messages_committed(1) consumer_task = asyncio.create_task(consumer()) producer_task = asyncio.create_task(producer()) @@ -412,3 +414,91 @@ async def producer(): f"Expected 2 events but got {len(events_received)}. " "Race condition may have caused premature termination." 
) + + @pytest.mark.asyncio + async def test_force_stop_terminates_iteration(self): + """ + Test that force_stop terminates iteration even with uncommitted messages. + """ + tracker = AsyncNodeTracker() + output_queue = AsyncOutputQueue( + output_topics=[], + consumer_name="test_consumer", + tracker=tracker, + ) + + # Publish messages but don't commit them (simulates incomplete work) + tracker.on_messages_published(5) + + # Not quiescent because uncommitted > 0 + assert not tracker.is_quiescent + assert tracker.get_metrics()["uncommitted_messages"] == 5 + + # Start iteration in background + events = [] + iteration_complete = asyncio.Event() + + async def iterate(): + async for event in output_queue: + events.append(event) + iteration_complete.set() + + iteration_task = asyncio.create_task(iterate()) + + # Give iteration a chance to start waiting + await asyncio.sleep(0.05) + + # Force stop should terminate iteration + tracker.force_stop() + + # Wait for iteration to complete + try: + await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) + except asyncio.TimeoutError: + iteration_task.cancel() + pytest.fail("Force stop did not terminate iteration within timeout") + + await iteration_task + assert events == [] + + @pytest.mark.asyncio + async def test_force_stop_yields_queued_events_before_terminating(self): + """ + Test that force_stop yields any queued events before terminating. + """ + tracker = AsyncNodeTracker() + output_queue = AsyncOutputQueue( + output_topics=[], + consumer_name="test_consumer", + tracker=tracker, + ) + + # Simulate work with uncommitted messages + tracker.on_messages_published(5) + + # Queue some events + await output_queue.queue.put(Mock(name="event_1")) + await output_queue.queue.put(Mock(name="event_2")) + + events = [] + iteration_complete = asyncio.Event() + + async def iterate(): + async for event in output_queue: + events.append(event) + iteration_complete.set() + + iteration_task = asyncio.create_task(iterate()) + + # Give iteration a chance to get the queued events + await asyncio.sleep(0.05) + + # Force stop + tracker.force_stop() + + # Wait for iteration to complete + await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) + await iteration_task + + # Should have received the queued events before terminating + assert len(events) == 2 diff --git a/tests/workflow/test_event_driven_workflow.py b/tests/workflow/test_event_driven_workflow.py index f6b5405..8313727 100644 --- a/tests/workflow/test_event_driven_workflow.py +++ b/tests/workflow/test_event_driven_workflow.py @@ -1,7 +1,5 @@ import asyncio -from unittest.mock import AsyncMock -from unittest.mock import Mock -from unittest.mock import patch +from unittest.mock import AsyncMock, Mock, patch import pytest from openinference.semconv.trace import OpenInferenceSpanKindValues @@ -310,10 +308,10 @@ def workflow_for_initial_test(self): return EventDrivenWorkflow(nodes={"test_node": node}) - def test_initial_workflow_method_exists(self, workflow_for_initial_test): - """Test that initial_workflow method exists.""" - assert hasattr(workflow_for_initial_test, "initial_workflow") - assert callable(workflow_for_initial_test.initial_workflow) + def test_init_workflow_method_exists(self, workflow_for_initial_test): + """init_workflow should be available for restoring workflow state.""" + assert hasattr(workflow_for_initial_test, "init_workflow") + assert callable(workflow_for_initial_test.init_workflow) class TestEventDrivenWorkflowToDict: @@ -375,7 +373,12 @@ async def 
test_tracker_reset_on_init(self, workflow_with_tracker): invoke_context = InvokeContext( conversation_id="test", invoke_id="test", assistant_request_id="test" ) - with patch("grafi.common.containers.container.container"): + with patch("grafi.workflows.impl.event_driven_workflow.container") as mock_container: + mock_event_store = Mock() + mock_event_store.get_agent_events = AsyncMock(return_value=[]) + mock_event_store.record_events = AsyncMock() + mock_event_store.record_event = AsyncMock() + mock_container.event_store = mock_event_store await workflow_with_tracker.init_workflow( PublishToTopicEvent(invoke_context=invoke_context, data=[]) ) diff --git a/tests/workflow/test_utils.py b/tests/workflow/test_utils.py index ebdd7e5..50f882a 100644 --- a/tests/workflow/test_utils.py +++ b/tests/workflow/test_utils.py @@ -1,5 +1,4 @@ -from unittest.mock import AsyncMock -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -12,9 +11,11 @@ from grafi.nodes.node import Node from grafi.topics.topic_base import TopicBase from grafi.topics.topic_types import TopicType -from grafi.workflows.impl.utils import get_async_output_events -from grafi.workflows.impl.utils import get_node_input -from grafi.workflows.impl.utils import publish_events +from grafi.workflows.impl.utils import ( + get_async_output_events, + get_node_input, + publish_events, +) class TestGetAsyncOutputEvents: @@ -200,6 +201,7 @@ async def test_publish_events(self): # Mock node and topics mock_topic1 = AsyncMock(spec=TopicBase) mock_topic2 = AsyncMock(spec=TopicBase) + tracker = MagicMock() node = MagicMock(spec=Node) node.name = "test_node" @@ -237,13 +239,14 @@ async def test_publish_events(self): consumed_event_ids=[event.event_id for event in consumed_events], ) - published_events = await publish_events(node, publish_to_event) + published_events = await publish_events(node, publish_to_event, tracker=tracker) assert len(published_events) == 1 assert published_events[0] == mock_event1 # Verify topics were called correctly mock_topic1.publish_data.assert_called_once_with(publish_to_event) + tracker.on_messages_published.assert_called_once_with(1, source="node:test_node") class TestGetNodeInput: diff --git a/tests_integration/hith_assistant/kyc_assistant_example.py b/tests_integration/hith_assistant/kyc_assistant_example.py index 9a465ce..4a11bfe 100644 --- a/tests_integration/hith_assistant/kyc_assistant_example.py +++ b/tests_integration/hith_assistant/kyc_assistant_example.py @@ -108,7 +108,7 @@ async def test_kyc_assistant() -> None: ) ] - output = await async_func_wrapper( + outputs = await async_func_wrapper( assistant.invoke( PublishToTopicEvent( invoke_context=get_invoke_context(), @@ -117,7 +117,7 @@ async def test_kyc_assistant() -> None: ) ) - print(output) + print(outputs) human_input = [ Message( @@ -131,6 +131,7 @@ async def test_kyc_assistant() -> None: PublishToTopicEvent( invoke_context=get_invoke_context(), data=human_input, + consumed_event_ids=[event.event_id for event in outputs], ) ) ) From ad8795688f2da48a0e877f7879497b1af438c23b Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Tue, 30 Dec 2025 21:20:30 +0000 Subject: [PATCH 6/9] fix lint --- grafi/workflows/impl/async_node_tracker.py | 24 +++--- grafi/workflows/impl/async_output_queue.py | 12 ++- grafi/workflows/impl/event_driven_workflow.py | 75 +++++++++++++------ grafi/workflows/impl/utils.py | 29 ++++--- tests/workflow/test_event_driven_workflow.py | 8 +- tests/workflow/test_utils.py | 15 ++-- 
tests_integration/agents/run_agents.py | 5 +- .../run_embedding_assistant.py | 5 +- .../run_event_store_postgres.py | 5 +- .../run_function_assistant.py | 5 +- .../run_function_call_assistant.py | 5 +- .../hith_assistant/run_hith_assistant.py | 5 +- .../run_input_output_topics.py | 5 +- .../invoke_kwargs/run_invoke_kwargs.py | 5 +- .../mcp_assistant/run_mcp_assistant.py | 5 +- .../run_multimodal_assistant.py | 5 +- .../rag_assistant/run_rag_assistant.py | 5 +- .../react_assistant/run_react_assistant.py | 5 +- tests_integration/run_all.py | 9 ++- .../run_simple_llm_assistant.py | 5 +- .../run_simple_stream_assistant.py | 5 +- 21 files changed, 154 insertions(+), 88 deletions(-) diff --git a/grafi/workflows/impl/async_node_tracker.py b/grafi/workflows/impl/async_node_tracker.py index 11662fe..814f06f 100644 --- a/grafi/workflows/impl/async_node_tracker.py +++ b/grafi/workflows/impl/async_node_tracker.py @@ -3,7 +3,9 @@ # ────────────────────────────────────────────────────────────────────────────── import asyncio from collections import defaultdict -from typing import Dict, Optional, Set +from typing import Dict +from typing import Optional +from typing import Set from loguru import logger @@ -11,19 +13,19 @@ class AsyncNodeTracker: """ Central tracker for workflow activity and quiescence detection. - + Design: All tracking calls come from the ORCHESTRATOR layer, not from TopicBase. This keeps topics as pure message queues. - + Quiescence = (no active nodes) AND (no uncommitted messages) AND (work done) - + Usage in workflow: # In publish_events(): tracker.on_messages_published(len(published_events)) - + # In _commit_events(): tracker.on_messages_committed(len(events)) - + # In node processing: await tracker.enter(node_name) ... process ... @@ -95,7 +97,9 @@ def on_messages_published(self, count: int = 1, source: str = "") -> None: self._quiescence_event.clear() self._uncommitted_messages += count - logger.debug(f"Tracker: {count} messages published from {source} (uncommitted={self._uncommitted_messages})") + logger.debug( + f"Tracker: {count} messages published from {source} (uncommitted={self._uncommitted_messages})" + ) def on_messages_committed(self, count: int = 1, source: str = "") -> None: """ @@ -137,7 +141,9 @@ def _check_quiescence(self) -> None: f"is_quiescent={self.is_quiescent}" ) if self.is_quiescent: - logger.info(f"Tracker: quiescence detected (committed={self._total_committed})") + logger.info( + f"Tracker: quiescence detected (committed={self._total_committed})" + ) self._quiescence_event.set() @property @@ -206,4 +212,4 @@ def get_metrics(self) -> Dict: "uncommitted_messages": self._uncommitted_messages, "total_committed": self._total_committed, "is_quiescent": self.is_quiescent, - } \ No newline at end of file + } diff --git a/grafi/workflows/impl/async_output_queue.py b/grafi/workflows/impl/async_output_queue.py index df6ead3..79eca20 100644 --- a/grafi/workflows/impl/async_output_queue.py +++ b/grafi/workflows/impl/async_output_queue.py @@ -11,7 +11,7 @@ class AsyncOutputQueue: """ Manages output topics and provides async iteration over output events. - + Simplified: All quiescence detection delegated to AsyncNodeTracker. 
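    A minimal consumption sketch (illustrative; "handle_event" is a
    placeholder for caller code, and "tracker" is a shared AsyncNodeTracker):

        queue = AsyncOutputQueue(output_topics, "workflow_consumer", tracker)
        await queue.start_listeners()
        async for event in queue:
            handle_event(event)  # loop ends once the tracker reports termination
        await queue.stop_listeners()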
""" @@ -58,8 +58,12 @@ async def _output_listener(self, topic: TopicBase) -> None: await self.queue.put(event) # Mark messages as committed when they reach the output queue if events: - logger.debug(f"Output listener: consumed {len(events)} events from {topic.name}") - self.tracker.on_messages_committed(len(events), source=f"output_listener:{topic.name}") + logger.debug( + f"Output listener: consumed {len(events)} events from {topic.name}" + ) + self.tracker.on_messages_committed( + len(events), source=f"output_listener:{topic.name}" + ) except asyncio.TimeoutError: continue except asyncio.CancelledError: @@ -127,4 +131,4 @@ async def __anext__(self) -> TopicEvent: # Quiescence or force stop detected if self.tracker.should_terminate and self.queue.empty(): - raise StopAsyncIteration \ No newline at end of file + raise StopAsyncIteration diff --git a/grafi/workflows/impl/event_driven_workflow.py b/grafi/workflows/impl/event_driven_workflow.py index abcfd9f..ced825c 100644 --- a/grafi/workflows/impl/event_driven_workflow.py +++ b/grafi/workflows/impl/event_driven_workflow.py @@ -1,6 +1,10 @@ import asyncio from collections import deque -from typing import Any, AsyncGenerator, Dict, List, Set +from typing import Any +from typing import AsyncGenerator +from typing import Dict +from typing import List +from typing import Set from loguru import logger from openinference.semconv.trace import OpenInferenceSpanKindValues @@ -14,7 +18,8 @@ ) from grafi.common.events.topic_events.publish_to_topic_event import PublishToTopicEvent from grafi.common.events.topic_events.topic_event import TopicEvent -from grafi.common.exceptions import NodeExecutionError, WorkflowError +from grafi.common.exceptions import NodeExecutionError +from grafi.common.exceptions import WorkflowError from grafi.common.models.invoke_context import InvokeContext from grafi.nodes.node import Node from grafi.nodes.node_base import NodeBase @@ -28,12 +33,11 @@ from grafi.topics.topic_types import TopicType from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker from grafi.workflows.impl.async_output_queue import AsyncOutputQueue -from grafi.workflows.impl.utils import ( - get_async_output_events, - get_node_input, - publish_events, -) -from grafi.workflows.workflow import Workflow, WorkflowBuilder +from grafi.workflows.impl.utils import get_async_output_events +from grafi.workflows.impl.utils import get_node_input +from grafi.workflows.impl.utils import publish_events +from grafi.workflows.workflow import Workflow +from grafi.workflows.workflow import WorkflowBuilder class EventDrivenWorkflow(Workflow): @@ -209,9 +213,10 @@ async def _get_output_events(self) -> List[ConsumeFromTopicEvent]: return consumed_events - async def _commit_events( - self, consumer_name: str, topic_events: List[ConsumeFromTopicEvent], + self, + consumer_name: str, + topic_events: List[ConsumeFromTopicEvent], track_commit: bool = True, ) -> None: if not topic_events: @@ -229,9 +234,12 @@ async def _commit_events( # Notify tracker that messages have been committed # (skip if already tracked elsewhere, e.g., by output listener) if track_commit: - logger.debug(f"Committing {len(topic_events)} events for {consumer_name}, track_commit={track_commit}") - self._tracker.on_messages_committed(len(topic_events), source=f"commit:{consumer_name}") - + logger.debug( + f"Committing {len(topic_events)} events for {consumer_name}, track_commit={track_commit}" + ) + self._tracker.on_messages_committed( + len(topic_events), source=f"commit:{consumer_name}" + ) async 
def _add_to_invoke_queue(self, event: TopicEvent) -> None: topic_name = event.name @@ -283,7 +291,9 @@ async def invoke_sequential( async for result in node.invoke( invoke_context, node_consumed_events ): - published_events.extend(await publish_events(node, result, self._tracker)) + published_events.extend( + await publish_events(node, result, self._tracker) + ) for event in published_events: await self._add_to_invoke_queue(event) @@ -313,7 +323,9 @@ async def invoke_parallel( self, input_data: PublishToTopicEvent ) -> AsyncGenerator[ConsumeFromTopicEvent, None]: invoke_context = input_data.invoke_context - logger.debug(f"invoke_parallel: tracker_id={id(self._tracker)}, metrics={self._tracker.get_metrics()}") + logger.debug( + f"invoke_parallel: tracker_id={id(self._tracker)}, metrics={self._tracker.get_metrics()}" + ) # Start a background task to process all nodes (including streaming generators) node_processing_task = [ @@ -390,7 +402,8 @@ async def invoke_parallel( # Commit all consumed output events to topics # (tracking already done by output listener, so skip tracker update) await self._commit_events( - consumer_name=self.name, topic_events=consumed_output_events, + consumer_name=self.name, + topic_events=consumed_output_events, track_commit=False, ) @@ -515,7 +528,11 @@ def _cancel_all_active_tasks() -> None: if consumed_events: async for event in node.invoke(invoke_context, consumed_events): node_output_events.extend( - await publish_events(node=node, publish_event=event, tracker=self._tracker) + await publish_events( + node=node, + publish_event=event, + tracker=self._tracker, + ) ) await self._commit_events( @@ -591,7 +608,9 @@ async def init_workflow( self, input_data: PublishToTopicEvent, is_sequential: bool = False ) -> Any: # 1 – initial seeding - logger.debug(f"init_workflow: is_sequential={is_sequential}, tracker_id={id(self._tracker)}") + logger.debug( + f"init_workflow: is_sequential={is_sequential}, tracker_id={id(self._tracker)}" + ) if not is_sequential: self._tracker.reset() @@ -631,13 +650,21 @@ async def init_workflow( if is_sequential: await self._add_to_invoke_queue(event) - logger.debug(f"init_workflow: events_to_record={len(events_to_record)}, input_topics={len(input_topics)}") + logger.debug( + f"init_workflow: events_to_record={len(events_to_record)}, input_topics={len(input_topics)}" + ) if events_to_record: # Track initial input messages for quiescence detection if not is_sequential: - logger.debug(f"init_workflow: calling on_messages_published({len(events_to_record)})") - self._tracker.on_messages_published(len(events_to_record), source="init_workflow") - logger.debug(f"init_workflow: tracker after publish: {self._tracker.get_metrics()}") + logger.debug( + f"init_workflow: calling on_messages_published({len(events_to_record)})" + ) + self._tracker.on_messages_published( + len(events_to_record), source="init_workflow" + ) + logger.debug( + f"init_workflow: tracker after publish: {self._tracker.get_metrics()}" + ) await container.event_store.record_events(events_to_record) else: # When there is unfinished workflow, we need to restore the workflow topics @@ -691,7 +718,9 @@ async def init_workflow( if paired_event: # Track the published message for quiescence detection if not is_sequential: - self._tracker.on_messages_published(1, source="restore_paired_input") + self._tracker.on_messages_published( + 1, source="restore_paired_input" + ) if is_sequential: await self._add_to_invoke_queue(paired_event) await container.event_store.record_event(paired_event) diff 
--git a/grafi/workflows/impl/utils.py b/grafi/workflows/impl/utils.py index 4190625..2381450 100644 --- a/grafi/workflows/impl/utils.py +++ b/grafi/workflows/impl/utils.py @@ -1,4 +1,7 @@ -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING +from typing import Dict +from typing import List +from typing import Optional from grafi.common.events.topic_events.consume_from_topic_event import ( ConsumeFromTopicEvent, @@ -8,6 +11,7 @@ from grafi.common.models.message import Message from grafi.nodes.node_base import NodeBase + if TYPE_CHECKING: from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker @@ -15,7 +19,7 @@ def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: """ Process a list of TopicEvents, grouping by name and aggregating streaming messages. - + NO CHANGES NEEDED - this is read-only aggregation. """ # Group events by name @@ -80,12 +84,12 @@ async def publish_events( ) -> List[PublishToTopicEvent]: """ Publish events to all topics the node publishes to. - + CHANGE: Added optional tracker parameter. When provided, notifies tracker of published messages. """ published_events: List[PublishToTopicEvent] = [] - + for topic in node.publish_to: event = await topic.publish_data(publish_event) if event: @@ -101,7 +105,7 @@ async def publish_events( async def get_node_input(node: NodeBase) -> List[ConsumeFromTopicEvent]: """ Get input events for a node from its subscribed topics. - + NO CHANGES NEEDED - consumption tracking happens at commit time. """ consumed_events: List[ConsumeFromTopicEvent] = [] @@ -130,18 +134,19 @@ async def get_node_input(node: NodeBase) -> List[ConsumeFromTopicEvent]: # Alternative: Wrapper approach if you can't modify function signatures # ============================================================================= + class TrackedPublisher: """ Wrapper that adds tracking to publish_events without changing its signature. 
- + Usage: publisher = TrackedPublisher(tracker) events = await publisher.publish(node, event) """ - + def __init__(self, tracker: "AsyncNodeTracker"): self.tracker = tracker - + async def publish( self, node: NodeBase, @@ -149,13 +154,13 @@ async def publish( ) -> List[PublishToTopicEvent]: """Publish with automatic tracking.""" published_events: List[PublishToTopicEvent] = [] - + for topic in node.publish_to: event = await topic.publish_data(publish_event) if event: published_events.append(event) - + if published_events: self.tracker.on_messages_published(len(published_events)) - - return published_events \ No newline at end of file + + return published_events diff --git a/tests/workflow/test_event_driven_workflow.py b/tests/workflow/test_event_driven_workflow.py index 8313727..017ed66 100644 --- a/tests/workflow/test_event_driven_workflow.py +++ b/tests/workflow/test_event_driven_workflow.py @@ -1,5 +1,7 @@ import asyncio -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock +from unittest.mock import Mock +from unittest.mock import patch import pytest from openinference.semconv.trace import OpenInferenceSpanKindValues @@ -373,7 +375,9 @@ async def test_tracker_reset_on_init(self, workflow_with_tracker): invoke_context = InvokeContext( conversation_id="test", invoke_id="test", assistant_request_id="test" ) - with patch("grafi.workflows.impl.event_driven_workflow.container") as mock_container: + with patch( + "grafi.workflows.impl.event_driven_workflow.container" + ) as mock_container: mock_event_store = Mock() mock_event_store.get_agent_events = AsyncMock(return_value=[]) mock_event_store.record_events = AsyncMock() diff --git a/tests/workflow/test_utils.py b/tests/workflow/test_utils.py index 50f882a..1ece46f 100644 --- a/tests/workflow/test_utils.py +++ b/tests/workflow/test_utils.py @@ -1,4 +1,5 @@ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock +from unittest.mock import MagicMock import pytest @@ -11,11 +12,9 @@ from grafi.nodes.node import Node from grafi.topics.topic_base import TopicBase from grafi.topics.topic_types import TopicType -from grafi.workflows.impl.utils import ( - get_async_output_events, - get_node_input, - publish_events, -) +from grafi.workflows.impl.utils import get_async_output_events +from grafi.workflows.impl.utils import get_node_input +from grafi.workflows.impl.utils import publish_events class TestGetAsyncOutputEvents: @@ -246,7 +245,9 @@ async def test_publish_events(self): # Verify topics were called correctly mock_topic1.publish_data.assert_called_once_with(publish_to_event) - tracker.on_messages_published.assert_called_once_with(1, source="node:test_node") + tracker.on_messages_published.assert_called_once_with( + 1, source="node:test_node" + ) class TestGetNodeInput: diff --git a/tests_integration/agents/run_agents.py b/tests_integration/agents/run_agents.py index 883c0b5..0989203 100644 --- a/tests_integration/agents/run_agents.py +++ b/tests_integration/agents/run_agents.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/embedding_assistant/run_embedding_assistant.py b/tests_integration/embedding_assistant/run_embedding_assistant.py index 10aae9f..5f6858f 100644 --- 
a/tests_integration/embedding_assistant/run_embedding_assistant.py +++ b/tests_integration/embedding_assistant/run_embedding_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/event_store_postgres/run_event_store_postgres.py b/tests_integration/event_store_postgres/run_event_store_postgres.py index c9f60ff..1b45356 100644 --- a/tests_integration/event_store_postgres/run_event_store_postgres.py +++ b/tests_integration/event_store_postgres/run_event_store_postgres.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/function_assistant/run_function_assistant.py b/tests_integration/function_assistant/run_function_assistant.py index dbe48bb..496c367 100644 --- a/tests_integration/function_assistant/run_function_assistant.py +++ b/tests_integration/function_assistant/run_function_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/function_call_assistant/run_function_call_assistant.py b/tests_integration/function_call_assistant/run_function_call_assistant.py index f53579f..ac84de7 100644 --- a/tests_integration/function_call_assistant/run_function_call_assistant.py +++ b/tests_integration/function_call_assistant/run_function_call_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/hith_assistant/run_hith_assistant.py b/tests_integration/hith_assistant/run_hith_assistant.py index 5ac8497..f6cd9d1 100644 --- a/tests_integration/hith_assistant/run_hith_assistant.py +++ b/tests_integration/hith_assistant/run_hith_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/input_output_topics/run_input_output_topics.py b/tests_integration/input_output_topics/run_input_output_topics.py index 1873e53..7f39b9c 100644 --- a/tests_integration/input_output_topics/run_input_output_topics.py +++ b/tests_integration/input_output_topics/run_input_output_topics.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, 
encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/invoke_kwargs/run_invoke_kwargs.py b/tests_integration/invoke_kwargs/run_invoke_kwargs.py index bbf29f6..a7c2212 100644 --- a/tests_integration/invoke_kwargs/run_invoke_kwargs.py +++ b/tests_integration/invoke_kwargs/run_invoke_kwargs.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/mcp_assistant/run_mcp_assistant.py b/tests_integration/mcp_assistant/run_mcp_assistant.py index 81b17c4..1d43a05 100644 --- a/tests_integration/mcp_assistant/run_mcp_assistant.py +++ b/tests_integration/mcp_assistant/run_mcp_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/multimodal_assistant/run_multimodal_assistant.py b/tests_integration/multimodal_assistant/run_multimodal_assistant.py index 7d83370..ff921f1 100644 --- a/tests_integration/multimodal_assistant/run_multimodal_assistant.py +++ b/tests_integration/multimodal_assistant/run_multimodal_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/rag_assistant/run_rag_assistant.py b/tests_integration/rag_assistant/run_rag_assistant.py index d1f047b..bc98a15 100644 --- a/tests_integration/rag_assistant/run_rag_assistant.py +++ b/tests_integration/rag_assistant/run_rag_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/react_assistant/run_react_assistant.py b/tests_integration/react_assistant/run_react_assistant.py index d31af28..73247a4 100644 --- a/tests_integration/react_assistant/run_react_assistant.py +++ b/tests_integration/react_assistant/run_react_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/run_all.py b/tests_integration/run_all.py index 6a821c5..8dcaec9 100644 --- a/tests_integration/run_all.py +++ b/tests_integration/run_all.py @@ -12,8 +12,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except 
AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def _load_runner_module(script: Path): @@ -62,7 +63,9 @@ def run_all_scripts(pass_local: bool = True) -> int: try: runner_module = _load_runner_module(script) - runner_results = runner_module.run_scripts(pass_local=pass_local, collect=True) + runner_results = runner_module.run_scripts( + pass_local=pass_local, collect=True + ) except Exception as exc: # noqa: BLE001 example_rel = script.relative_to(repo_root) error_message = f"Runner failed before executing examples: {exc}" diff --git a/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py b/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py index 14a27ea..8e6c714 100644 --- a/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py +++ b/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): diff --git a/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py b/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py index 3b71737..990d5ae 100644 --- a/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py +++ b/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py @@ -10,8 +10,9 @@ try: sys.stdout.reconfigure(encoding="utf-8") except AttributeError: - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", write_through=True) - + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) def run_scripts(pass_local: bool = True, collect: bool = False): From aab0492f97a6b887de411ef5be02eb88c9121001 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Tue, 30 Dec 2025 21:25:25 +0000 Subject: [PATCH 7/9] revert change --- grafi/topics/queue_impl/in_mem_topic_event_queue.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grafi/topics/queue_impl/in_mem_topic_event_queue.py b/grafi/topics/queue_impl/in_mem_topic_event_queue.py index 8bf6d1e..3af9b15 100644 --- a/grafi/topics/queue_impl/in_mem_topic_event_queue.py +++ b/grafi/topics/queue_impl/in_mem_topic_event_queue.py @@ -68,9 +68,9 @@ async def fetch( # If timeout is 0 or None and no data, return immediately while not await self.can_consume(consumer_id): try: - # logger.debug( - # f"Consumer {consumer_id} waiting for new messages with timeout={timeout}" - # ) + logger.debug( + f"Consumer {consumer_id} waiting for new messages with timeout={timeout}" + ) await asyncio.wait_for(self._cond.wait(), timeout) except asyncio.TimeoutError: return [] From 1ec989458188440f145f3d337a08cc5069747d14 Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Tue, 30 Dec 2025 21:37:57 +0000 Subject: [PATCH 8/9] address comments --- grafi/workflows/impl/async_output_queue.py | 8 +--- grafi/workflows/impl/utils.py | 45 +--------------------- tests/workflow/test_utils.py | 2 +- 3 files changed, 5 insertions(+), 50 deletions(-) diff --git a/grafi/workflows/impl/async_output_queue.py b/grafi/workflows/impl/async_output_queue.py index 79eca20..4e8eaf1 100644 --- 
a/grafi/workflows/impl/async_output_queue.py +++ b/grafi/workflows/impl/async_output_queue.py @@ -87,12 +87,7 @@ async def __anext__(self) -> TopicEvent: check_count = 0 while True: check_count += 1 - if check_count % 20 == 0: # Log every ~10 seconds (20 * 0.5s) - logger.debug( - f"AsyncOutputQueue: still waiting - " - f"queue_empty={self.queue.empty()}, " - f"tracker_metrics={self.tracker.get_metrics()}" - ) + # Fast path: queue has items if not self.queue.empty(): try: @@ -127,6 +122,7 @@ async def __anext__(self) -> TopicEvent: try: return queue_task.result() except asyncio.QueueEmpty: + # Task was cancelled as part of normal cleanup; ignore. continue # Quiescence or force stop detected diff --git a/grafi/workflows/impl/utils.py b/grafi/workflows/impl/utils.py index 2381450..de60872 100644 --- a/grafi/workflows/impl/utils.py +++ b/grafi/workflows/impl/utils.py @@ -1,7 +1,5 @@ -from typing import TYPE_CHECKING from typing import Dict from typing import List -from typing import Optional from grafi.common.events.topic_events.consume_from_topic_event import ( ConsumeFromTopicEvent, @@ -10,10 +8,7 @@ from grafi.common.events.topic_events.topic_event import TopicEvent from grafi.common.models.message import Message from grafi.nodes.node_base import NodeBase - - -if TYPE_CHECKING: - from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker +from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: @@ -80,7 +75,7 @@ def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: async def publish_events( node: NodeBase, publish_event: PublishToTopicEvent, - tracker: Optional["AsyncNodeTracker"] = None, # NEW: Optional tracker + tracker: AsyncNodeTracker, ) -> List[PublishToTopicEvent]: """ Publish events to all topics the node publishes to. @@ -128,39 +123,3 @@ async def get_node_input(node: NodeBase) -> List[ConsumeFromTopicEvent]: consumed_events.append(consumed_event) return consumed_events - - -# ============================================================================= -# Alternative: Wrapper approach if you can't modify function signatures -# ============================================================================= - - -class TrackedPublisher: - """ - Wrapper that adds tracking to publish_events without changing its signature. 
- - Usage: - publisher = TrackedPublisher(tracker) - events = await publisher.publish(node, event) - """ - - def __init__(self, tracker: "AsyncNodeTracker"): - self.tracker = tracker - - async def publish( - self, - node: NodeBase, - publish_event: PublishToTopicEvent, - ) -> List[PublishToTopicEvent]: - """Publish with automatic tracking.""" - published_events: List[PublishToTopicEvent] = [] - - for topic in node.publish_to: - event = await topic.publish_data(publish_event) - if event: - published_events.append(event) - - if published_events: - self.tracker.on_messages_published(len(published_events)) - - return published_events diff --git a/tests/workflow/test_utils.py b/tests/workflow/test_utils.py index 1ece46f..947326b 100644 --- a/tests/workflow/test_utils.py +++ b/tests/workflow/test_utils.py @@ -238,7 +238,7 @@ async def test_publish_events(self): consumed_event_ids=[event.event_id for event in consumed_events], ) - published_events = await publish_events(node, publish_to_event, tracker=tracker) + published_events = await publish_events(node, publish_to_event, tracker) assert len(published_events) == 1 assert published_events[0] == mock_event1 From c5a643532ba94ee4fd09600303927ab7bcb7670e Mon Sep 17 00:00:00 2001 From: GuanyiLi-Craig Date: Tue, 30 Dec 2025 22:25:01 +0000 Subject: [PATCH 9/9] address comments --- grafi/workflows/impl/async_node_tracker.py | 141 ++++++++++++------ grafi/workflows/impl/async_output_queue.py | 20 ++- grafi/workflows/impl/event_driven_workflow.py | 12 +- grafi/workflows/impl/utils.py | 4 +- tests/workflow/test_async_node_tracker.py | 84 +++++------ tests/workflow/test_async_output_queue.py | 44 +++--- tests/workflow/test_event_driven_workflow.py | 4 +- tests/workflow/test_utils.py | 4 +- 8 files changed, 189 insertions(+), 124 deletions(-) diff --git a/grafi/workflows/impl/async_node_tracker.py b/grafi/workflows/impl/async_node_tracker.py index 814f06f..44f7033 100644 --- a/grafi/workflows/impl/async_node_tracker.py +++ b/grafi/workflows/impl/async_node_tracker.py @@ -78,14 +78,14 @@ async def leave(self, node_name: str) -> None: """Called when a node finishes processing.""" async with self._cond: self._active.discard(node_name) - self._check_quiescence() + self._check_quiescence_unlocked() self._cond.notify_all() # ───────────────────────────────────────────────────────────────────────── # Message Tracking (called from orchestrator utilities) # ───────────────────────────────────────────────────────────────────────── - def on_messages_published(self, count: int = 1, source: str = "") -> None: + async def on_messages_published(self, count: int = 1, source: str = "") -> None: """ Called when messages are published to topics. @@ -93,15 +93,16 @@ def on_messages_published(self, count: int = 1, source: str = "") -> None: """ if count <= 0: return - self._has_started = True - self._quiescence_event.clear() - self._uncommitted_messages += count + async with self._cond: + self._has_started = True + self._quiescence_event.clear() + self._uncommitted_messages += count - logger.debug( - f"Tracker: {count} messages published from {source} (uncommitted={self._uncommitted_messages})" - ) + logger.debug( + f"Tracker: {count} messages published from {source} (uncommitted={self._uncommitted_messages})" + ) - def on_messages_committed(self, count: int = 1, source: str = "") -> None: + async def on_messages_committed(self, count: int = 1, source: str = "") -> None: """ Called when messages are committed (consumed and acknowledged). 
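A minimal lifecycle sketch of the now-async hooks, assuming a single node "A" that publishes one message and commits it (names and counts are illustrative):

    tracker = AsyncNodeTracker()
    await tracker.on_messages_published(1, source="node:A")   # work enters the system
    await tracker.enter("A")
    # ... node A processes ...
    await tracker.leave("A")
    await tracker.on_messages_committed(1, source="commit:A")  # work acknowledged
    assert await tracker.is_quiescent()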
@@ -109,46 +110,56 @@ def on_messages_committed(self, count: int = 1, source: str = "") -> None: """ if count <= 0: return - self._uncommitted_messages = max(0, self._uncommitted_messages - count) - self._total_committed += count - self._check_quiescence() + async with self._cond: + self._uncommitted_messages = max(0, self._uncommitted_messages - count) + self._total_committed += count + self._check_quiescence_unlocked() - logger.debug( - f"Tracker: {count} messages committed from {source} " - f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})" - ) + logger.debug( + f"Tracker: {count} messages committed from {source} " + f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})" + ) + self._cond.notify_all() # Aliases for clarity - def on_message_published(self) -> None: + async def on_message_published(self) -> None: """Single message version.""" - self.on_messages_published(1) + await self.on_messages_published(1) - def on_message_committed(self) -> None: + async def on_message_committed(self) -> None: """Single message version.""" - self.on_messages_committed(1) + await self.on_messages_committed(1) # ───────────────────────────────────────────────────────────────────────── # Quiescence Detection # ───────────────────────────────────────────────────────────────────────── - def _check_quiescence(self) -> None: - """Check and signal quiescence if all conditions met.""" + def _check_quiescence_unlocked(self) -> None: + """ + Check and signal quiescence if all conditions met. + + MUST be called with self._cond lock held. + """ + is_quiescent = self._is_quiescent_unlocked() logger.debug( f"Tracker: checking quiescence - active={list(self._active)}, " f"uncommitted={self._uncommitted_messages}, " f"has_started={self._has_started}, " f"total_committed={self._total_committed}, " - f"is_quiescent={self.is_quiescent}" + f"is_quiescent={is_quiescent}" ) - if self.is_quiescent: + if is_quiescent: logger.info( f"Tracker: quiescence detected (committed={self._total_committed})" ) self._quiescence_event.set() - @property - def is_quiescent(self) -> bool: + def _is_quiescent_unlocked(self) -> bool: """ + Internal quiescence check without lock. + + MUST be called with self._cond lock held. + True when workflow is truly idle: - No nodes actively processing - No messages waiting to be committed @@ -161,26 +172,66 @@ def is_quiescent(self) -> bool: and self._total_committed > 0 ) - @property - def should_terminate(self) -> bool: + async def is_quiescent(self) -> bool: + """ + True when workflow is truly idle: + - No nodes actively processing + - No messages waiting to be committed + - At least some work was done + + This method acquires the lock to ensure consistent reads. + """ + async with self._cond: + return self._is_quiescent_unlocked() + + def _should_terminate_unlocked(self) -> bool: + """ + Internal termination check without lock. + + MUST be called with self._cond lock held. + """ + return self._is_quiescent_unlocked() or self._force_stopped + + async def should_terminate(self) -> bool: """ True when workflow should stop iteration. Either natural quiescence or explicit force stop. + + This method acquires the lock to ensure consistent reads. """ - return self.is_quiescent or self._force_stopped + async with self._cond: + return self._should_terminate_unlocked() + + async def force_stop(self) -> None: + """ + Force the workflow to stop immediately (async version with lock). + Called when workflow.stop() is invoked from async context. 
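+
+        Illustrative usage from a coroutine (a sketch, not prescriptive):
+
+            await tracker.force_stop()
+            assert await tracker.should_terminate()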
+ """ + async with self._cond: + logger.info("Tracker: force stop requested") + self._force_stopped = True + self._quiescence_event.set() + self._cond.notify_all() - def force_stop(self) -> None: + def force_stop_sync(self) -> None: """ - Force the workflow to stop immediately. - Called when workflow.stop() is invoked. + Force the workflow to stop immediately (sync version). + + This is a synchronous version for use from sync contexts (e.g., stop() method). + It sets the stop flag and event without acquiring the async lock. + This is safe because: + 1. Setting _force_stopped to True is atomic for the stop signal + 2. asyncio.Event.set() is thread-safe + 3. Readers will see the updated state on their next lock acquisition """ - logger.info("Tracker: force stop requested") + logger.info("Tracker: force stop requested (sync)") self._force_stopped = True self._quiescence_event.set() - def is_idle(self) -> bool: + async def is_idle(self) -> bool: """Legacy: just checks if no active nodes.""" - return not self._active + async with self._cond: + return not self._active async def wait_for_quiescence(self, timeout: Optional[float] = None) -> bool: """Wait until quiescent. Returns False on timeout.""" @@ -201,15 +252,17 @@ async def wait_idle_event(self) -> None: # Metrics # ───────────────────────────────────────────────────────────────────────── - def get_activity_count(self) -> int: + async def get_activity_count(self) -> int: """Total processing count across all nodes.""" - return sum(self._processing_count.values()) + async with self._cond: + return sum(self._processing_count.values()) - def get_metrics(self) -> Dict: + async def get_metrics(self) -> Dict: """Detailed metrics for debugging.""" - return { - "active_nodes": list(self._active), - "uncommitted_messages": self._uncommitted_messages, - "total_committed": self._total_committed, - "is_quiescent": self.is_quiescent, - } + async with self._cond: + return { + "active_nodes": list(self._active), + "uncommitted_messages": self._uncommitted_messages, + "total_committed": self._total_committed, + "is_quiescent": self._is_quiescent_unlocked(), + } diff --git a/grafi/workflows/impl/async_output_queue.py b/grafi/workflows/impl/async_output_queue.py index 4e8eaf1..e3c72d2 100644 --- a/grafi/workflows/impl/async_output_queue.py +++ b/grafi/workflows/impl/async_output_queue.py @@ -61,7 +61,7 @@ async def _output_listener(self, topic: TopicBase) -> None: logger.debug( f"Output listener: consumed {len(events)} events from {topic.name}" ) - self.tracker.on_messages_committed( + await self.tracker.on_messages_committed( len(events), source=f"output_listener:{topic.name}" ) except asyncio.TimeoutError: @@ -96,8 +96,13 @@ async def __anext__(self) -> TopicEvent: pass # Check for completion (natural quiescence or force stop) - if self.tracker.should_terminate and self.queue.empty(): - raise StopAsyncIteration + if await self.tracker.should_terminate(): + # Final drain attempt - try to get any remaining items before stopping + # This avoids race where item is added between empty() check and raising + try: + return self.queue.get_nowait() + except asyncio.QueueEmpty: + raise StopAsyncIteration # Wait for queue item or quiescence queue_task = asyncio.create_task(self.queue.get()) @@ -126,5 +131,10 @@ async def __anext__(self) -> TopicEvent: continue # Quiescence or force stop detected - if self.tracker.should_terminate and self.queue.empty(): - raise StopAsyncIteration + if await self.tracker.should_terminate(): + # Final drain attempt - try to get any 
remaining items before stopping + # This avoids race where item is added between empty() check and raising + try: + return self.queue.get_nowait() + except asyncio.QueueEmpty: + raise StopAsyncIteration diff --git a/grafi/workflows/impl/event_driven_workflow.py b/grafi/workflows/impl/event_driven_workflow.py index ced825c..92d8fd6 100644 --- a/grafi/workflows/impl/event_driven_workflow.py +++ b/grafi/workflows/impl/event_driven_workflow.py @@ -79,7 +79,7 @@ def stop(self) -> None: Overrides base class to also trigger force stop on the tracker. """ super().stop() - self._tracker.force_stop() + self._tracker.force_stop_sync() @classmethod def builder(cls) -> WorkflowBuilder: @@ -237,7 +237,7 @@ async def _commit_events( logger.debug( f"Committing {len(topic_events)} events for {consumer_name}, track_commit={track_commit}" ) - self._tracker.on_messages_committed( + await self._tracker.on_messages_committed( len(topic_events), source=f"commit:{consumer_name}" ) @@ -324,7 +324,7 @@ async def invoke_parallel( ) -> AsyncGenerator[ConsumeFromTopicEvent, None]: invoke_context = input_data.invoke_context logger.debug( - f"invoke_parallel: tracker_id={id(self._tracker)}, metrics={self._tracker.get_metrics()}" + f"invoke_parallel: tracker_id={id(self._tracker)}, metrics={await self._tracker.get_metrics()}" ) # Start a background task to process all nodes (including streaming generators) @@ -659,11 +659,11 @@ async def init_workflow( logger.debug( f"init_workflow: calling on_messages_published({len(events_to_record)})" ) - self._tracker.on_messages_published( + await self._tracker.on_messages_published( len(events_to_record), source="init_workflow" ) logger.debug( - f"init_workflow: tracker after publish: {self._tracker.get_metrics()}" + f"init_workflow: tracker after publish: {await self._tracker.get_metrics()}" ) await container.event_store.record_events(events_to_record) else: @@ -718,7 +718,7 @@ async def init_workflow( if paired_event: # Track the published message for quiescence detection if not is_sequential: - self._tracker.on_messages_published( + await self._tracker.on_messages_published( 1, source="restore_paired_input" ) if is_sequential: diff --git a/grafi/workflows/impl/utils.py b/grafi/workflows/impl/utils.py index de60872..8052c16 100644 --- a/grafi/workflows/impl/utils.py +++ b/grafi/workflows/impl/utils.py @@ -92,7 +92,9 @@ async def publish_events( # NEW: Notify tracker of published messages if tracker and published_events: - tracker.on_messages_published(len(published_events), source=f"node:{node.name}") + await tracker.on_messages_published( + len(published_events), source=f"node:{node.name}" + ) return published_events diff --git a/tests/workflow/test_async_node_tracker.py b/tests/workflow/test_async_node_tracker.py index c0a5118..e348740 100644 --- a/tests/workflow/test_async_node_tracker.py +++ b/tests/workflow/test_async_node_tracker.py @@ -14,112 +14,112 @@ def tracker(self): @pytest.mark.asyncio async def test_initial_state(self, tracker): """Tracker starts idle with no work recorded.""" - assert tracker.is_idle() - assert tracker.is_quiescent is False - assert tracker.get_activity_count() == 0 - assert tracker.get_metrics()["uncommitted_messages"] == 0 + assert await tracker.is_idle() + assert await tracker.is_quiescent() is False + assert await tracker.get_activity_count() == 0 + assert (await tracker.get_metrics())["uncommitted_messages"] == 0 @pytest.mark.asyncio async def test_enter_and_leave_updates_activity(self, tracker): """Entering and leaving nodes updates activity 
counts.""" await tracker.enter("node1") - assert not tracker.is_idle() - assert tracker.get_activity_count() == 1 + assert not await tracker.is_idle() + assert await tracker.get_activity_count() == 1 assert "node1" in tracker._active await tracker.leave("node1") - assert tracker.is_idle() + assert await tracker.is_idle() # No commits yet so quiescence is still False - assert tracker.is_quiescent is False - assert tracker.get_activity_count() == 1 + assert await tracker.is_quiescent() is False + assert await tracker.get_activity_count() == 1 @pytest.mark.asyncio async def test_message_tracking_and_quiescence(self, tracker): """Published/committed message tracking drives quiescence detection.""" - tracker.on_messages_published(2) - assert tracker.is_quiescent is False - assert tracker.get_metrics()["uncommitted_messages"] == 2 + await tracker.on_messages_published(2) + assert await tracker.is_quiescent() is False + assert (await tracker.get_metrics())["uncommitted_messages"] == 2 - tracker.on_messages_committed(1) - assert tracker.is_quiescent is False - assert tracker.get_metrics()["uncommitted_messages"] == 1 + await tracker.on_messages_committed(1) + assert await tracker.is_quiescent() is False + assert (await tracker.get_metrics())["uncommitted_messages"] == 1 - tracker.on_messages_committed(1) - assert tracker.is_quiescent is True - assert tracker.get_metrics()["uncommitted_messages"] == 0 + await tracker.on_messages_committed(1) + assert await tracker.is_quiescent() is True + assert (await tracker.get_metrics())["uncommitted_messages"] == 0 @pytest.mark.asyncio async def test_wait_for_quiescence(self, tracker): """wait_for_quiescence resolves when work finishes.""" - tracker.on_messages_published(1) + await tracker.on_messages_published(1) async def finish_work(): await asyncio.sleep(0.01) - tracker.on_messages_committed(1) + await tracker.on_messages_committed(1) asyncio.create_task(finish_work()) result = await tracker.wait_for_quiescence(timeout=0.5) assert result is True - assert tracker.is_quiescent is True + assert await tracker.is_quiescent() is True @pytest.mark.asyncio async def test_wait_for_quiescence_timeout(self, tracker): """wait_for_quiescence returns False on timeout.""" result = await tracker.wait_for_quiescence(timeout=0.01) assert result is False - assert tracker.is_quiescent is False + assert await tracker.is_quiescent() is False @pytest.mark.asyncio async def test_reset(self, tracker): """Reset clears activity and quiescence state.""" await tracker.enter("node1") - tracker.on_messages_published(1) - tracker.on_messages_committed(1) + await tracker.on_messages_published(1) + await tracker.on_messages_committed(1) tracker.reset() - assert tracker.is_idle() - assert tracker.is_quiescent is False - assert tracker.get_activity_count() == 0 - assert tracker.get_metrics()["total_committed"] == 0 + assert await tracker.is_idle() + assert await tracker.is_quiescent() is False + assert await tracker.get_activity_count() == 0 + assert (await tracker.get_metrics())["total_committed"] == 0 @pytest.mark.asyncio async def test_force_stop(self, tracker): """Force stop terminates workflow even with uncommitted messages.""" - tracker.on_messages_published(2) - assert tracker.is_quiescent is False - assert tracker.should_terminate is False + await tracker.on_messages_published(2) + assert await tracker.is_quiescent() is False + assert await tracker.should_terminate() is False - tracker.force_stop() + await tracker.force_stop() # Not quiescent (uncommitted messages still exist) - 
assert tracker.is_quiescent is False + assert await tracker.is_quiescent() is False # But should_terminate is True due to force stop - assert tracker.should_terminate is True + assert await tracker.should_terminate() is True assert tracker._force_stopped is True @pytest.mark.asyncio async def test_should_terminate_on_quiescence(self, tracker): """should_terminate is True when naturally quiescent.""" - tracker.on_messages_published(1) - tracker.on_messages_committed(1) + await tracker.on_messages_published(1) + await tracker.on_messages_committed(1) - assert tracker.is_quiescent is True - assert tracker.should_terminate is True + assert await tracker.is_quiescent() is True + assert await tracker.should_terminate() is True assert tracker._force_stopped is False @pytest.mark.asyncio async def test_force_stop_triggers_quiescence_event(self, tracker): """Force stop sets the quiescence event so waiters can proceed.""" - tracker.on_messages_published(1) + await tracker.on_messages_published(1) # Event should not be set yet assert not tracker._quiescence_event.is_set() - tracker.force_stop() + await tracker.force_stop() # Event should now be set assert tracker._quiescence_event.is_set() @@ -127,10 +127,10 @@ async def test_force_stop_triggers_quiescence_event(self, tracker): @pytest.mark.asyncio async def test_reset_clears_force_stop(self, tracker): """Reset clears the force stop flag.""" - tracker.force_stop() + await tracker.force_stop() assert tracker._force_stopped is True tracker.reset() assert tracker._force_stopped is False - assert tracker.should_terminate is False + assert await tracker.should_terminate() is False diff --git a/tests/workflow/test_async_output_queue.py b/tests/workflow/test_async_output_queue.py index 9fb626d..e28c963 100644 --- a/tests/workflow/test_async_output_queue.py +++ b/tests/workflow/test_async_output_queue.py @@ -162,8 +162,8 @@ async def collect_events(): @pytest.mark.asyncio async def test_async_iteration_stops_after_quiescence(self, output_queue, tracker): """Async iteration ends when tracker reports quiescence and queue is empty.""" - tracker.on_messages_published(1) - tracker.on_messages_committed(1) + await tracker.on_messages_published(1) + await tracker.on_messages_committed(1) events = [] async for event in output_queue: @@ -177,9 +177,9 @@ async def test_async_iteration_waits_for_quiescence_or_events(self, tracker): queue = AsyncOutputQueue([], "test_consumer", tracker) async def signal_quiescence(): - tracker.on_messages_published(1) + await tracker.on_messages_published(1) await asyncio.sleep(0.02) - tracker.on_messages_committed(1) + await tracker.on_messages_committed(1) signal_task = asyncio.create_task(signal_quiescence()) @@ -198,8 +198,8 @@ async def test_event_emitted_before_quiescence(self, tracker): queued_event.name = "queued_event" await queue.queue.put(queued_event) - tracker.on_messages_published(1) - tracker.on_messages_committed(1) + await tracker.on_messages_published(1) + await tracker.on_messages_committed(1) events = [] async for event in queue: @@ -283,22 +283,22 @@ async def test_anext_waits_for_activity_count_stabilization(self): async def simulate_node_activity(): """Simulate node activity that should prevent premature termination.""" # First node processes - simulate full message lifecycle - tracker.on_messages_published(1) + await tracker.on_messages_published(1) await tracker.enter("node_1") await output_queue.queue.put(Mock(name="event_1")) await tracker.leave("node_1") - tracker.on_messages_committed(1) + await 
tracker.on_messages_committed(1) # Yield control - simulates realistic timing where next node # starts within the same event loop cycle await asyncio.sleep(0) # Second node picks up and processes - simulate full message lifecycle - tracker.on_messages_published(1) + await tracker.on_messages_published(1) await tracker.enter("node_2") await output_queue.queue.put(Mock(name="event_2")) await tracker.leave("node_2") - tracker.on_messages_committed(1) + await tracker.on_messages_committed(1) # Start the activity simulation activity_task = asyncio.create_task(simulate_node_activity()) @@ -330,11 +330,11 @@ async def test_anext_terminates_when_truly_idle(self): # Single node processes and finishes - simulate full message lifecycle async def simulate_single_node(): - tracker.on_messages_published(1) + await tracker.on_messages_published(1) await tracker.enter("node_1") await output_queue.queue.put(Mock(name="event_1")) await tracker.leave("node_1") - tracker.on_messages_committed(1) + await tracker.on_messages_committed(1) activity_task = asyncio.create_task(simulate_single_node()) @@ -376,22 +376,22 @@ async def consumer(): async def producer(): # Node A processes - simulate full message lifecycle - tracker.on_messages_published(1) + await tracker.on_messages_published(1) await tracker.enter("node_a") await output_queue.queue.put(Mock(name="event_a")) await tracker.leave("node_a") - tracker.on_messages_committed(1) + await tracker.on_messages_committed(1) # Critical timing window - yield to let consumer check idle state await asyncio.sleep(0) # Node B starts before consumer terminates (if fix works) # simulate full message lifecycle - tracker.on_messages_published(1) + await tracker.on_messages_published(1) await tracker.enter("node_b") await output_queue.queue.put(Mock(name="event_b")) await tracker.leave("node_b") - tracker.on_messages_committed(1) + await tracker.on_messages_committed(1) consumer_task = asyncio.create_task(consumer()) producer_task = asyncio.create_task(producer()) @@ -428,11 +428,11 @@ async def test_force_stop_terminates_iteration(self): ) # Publish messages but don't commit them (simulates incomplete work) - tracker.on_messages_published(5) + await tracker.on_messages_published(5) # Not quiescent because uncommitted > 0 - assert not tracker.is_quiescent - assert tracker.get_metrics()["uncommitted_messages"] == 5 + assert not await tracker.is_quiescent() + assert (await tracker.get_metrics())["uncommitted_messages"] == 5 # Start iteration in background events = [] @@ -449,7 +449,7 @@ async def iterate(): await asyncio.sleep(0.05) # Force stop should terminate iteration - tracker.force_stop() + await tracker.force_stop() # Wait for iteration to complete try: @@ -474,7 +474,7 @@ async def test_force_stop_yields_queued_events_before_terminating(self): ) # Simulate work with uncommitted messages - tracker.on_messages_published(5) + await tracker.on_messages_published(5) # Queue some events await output_queue.queue.put(Mock(name="event_1")) @@ -494,7 +494,7 @@ async def iterate(): await asyncio.sleep(0.05) # Force stop - tracker.force_stop() + await tracker.force_stop() # Wait for iteration to complete await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) diff --git a/tests/workflow/test_event_driven_workflow.py b/tests/workflow/test_event_driven_workflow.py index 017ed66..510dba0 100644 --- a/tests/workflow/test_event_driven_workflow.py +++ b/tests/workflow/test_event_driven_workflow.py @@ -369,7 +369,7 @@ async def test_tracker_reset_on_init(self, 
workflow_with_tracker): """Test that tracker is reset on workflow initialization.""" # Add some activity to tracker await workflow_with_tracker._tracker.enter("test_node") - assert not workflow_with_tracker._tracker.is_idle() + assert not await workflow_with_tracker._tracker.is_idle() # Call init_workflow which should reset tracker invoke_context = InvokeContext( @@ -388,7 +388,7 @@ async def test_tracker_reset_on_init(self, workflow_with_tracker): ) # Tracker should be reset - assert workflow_with_tracker._tracker.is_idle() + assert await workflow_with_tracker._tracker.is_idle() class TestEventDrivenWorkflowStopFlag: diff --git a/tests/workflow/test_utils.py b/tests/workflow/test_utils.py index 947326b..fb8853e 100644 --- a/tests/workflow/test_utils.py +++ b/tests/workflow/test_utils.py @@ -200,7 +200,7 @@ async def test_publish_events(self): # Mock node and topics mock_topic1 = AsyncMock(spec=TopicBase) mock_topic2 = AsyncMock(spec=TopicBase) - tracker = MagicMock() + tracker = AsyncMock() node = MagicMock(spec=Node) node.name = "test_node" @@ -245,7 +245,7 @@ async def test_publish_events(self): # Verify topics were called correctly mock_topic1.publish_data.assert_called_once_with(publish_to_event) - tracker.on_messages_published.assert_called_once_with( + tracker.on_messages_published.assert_awaited_once_with( 1, source="node:test_node" )
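Note: the AsyncNodeTracker implementation itself is not part of these hunks; the tests above only define its awaited surface (enter/leave, on_messages_published/on_messages_committed, is_idle, is_quiescent, should_terminate, force_stop, wait_for_quiescence, reset). The sketch below is a minimal illustration of one tracker that would satisfy those tests, serializing all state behind a single asyncio.Lock; the class name and internals are assumptions, not the library's actual code.

# Hypothetical sketch only - not the real grafi AsyncNodeTracker.
import asyncio
from typing import Set


class SketchTracker:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._active: Set[str] = set()        # nodes currently processing
        self._activity_count = 0              # cumulative enter() calls
        self._published = 0
        self._committed = 0
        self._force_stopped = False
        self._quiescence_event = asyncio.Event()

    async def enter(self, node: str) -> None:
        async with self._lock:
            self._active.add(node)
            self._activity_count += 1

    async def leave(self, node: str) -> None:
        async with self._lock:
            self._active.discard(node)

    async def is_idle(self) -> bool:
        async with self._lock:
            return not self._active

    async def get_activity_count(self) -> int:
        async with self._lock:
            return self._activity_count

    async def on_messages_published(self, n: int, source: str = "") -> None:
        async with self._lock:
            self._published += n
            self._quiescence_event.clear()

    async def on_messages_committed(self, n: int) -> None:
        async with self._lock:
            self._committed += n
            # Quiescent once every published message has been committed.
            if self._published and self._published == self._committed:
                self._quiescence_event.set()

    async def is_quiescent(self) -> bool:
        async with self._lock:
            return bool(self._published) and self._published == self._committed

    async def should_terminate(self) -> bool:
        return self._force_stopped or await self.is_quiescent()

    async def force_stop(self) -> None:
        async with self._lock:
            self._force_stopped = True
            self._quiescence_event.set()  # release any waiters

    async def wait_for_quiescence(self, timeout: float) -> bool:
        try:
            await asyncio.wait_for(self._quiescence_event.wait(), timeout)
            return True
        except asyncio.TimeoutError:
            return False

    def reset(self) -> None:
        # Synchronous, matching tracker.reset() as called in the tests.
        self._active.clear()
        self._activity_count = 0
        self._published = 0
        self._committed = 0
        self._force_stopped = False
        self._quiescence_event = asyncio.Event()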