diff --git a/grafi/workflows/impl/async_node_tracker.py b/grafi/workflows/impl/async_node_tracker.py
index 8c9e751..44f7033 100644
--- a/grafi/workflows/impl/async_node_tracker.py
+++ b/grafi/workflows/impl/async_node_tracker.py
@@ -4,53 +4,265 @@
 import asyncio
 from collections import defaultdict
 from typing import Dict
+from typing import Optional
+from typing import Set
+
+from loguru import logger
 
 
 class AsyncNodeTracker:
+    """
+    Central tracker for workflow activity and quiescence detection.
+
+    Design: All tracking calls come from the ORCHESTRATOR layer,
+    not from TopicBase. This keeps topics as pure message queues.
+
+    Quiescence = (no active nodes) AND (no uncommitted messages) AND (work done)
+
+    Usage in workflow:
+        # In publish_events():
+        tracker.on_messages_published(len(published_events))
+
+        # In _commit_events():
+        tracker.on_messages_committed(len(events))
+
+        # In node processing:
+        await tracker.enter(node_name)
+        ... process ...
+        await tracker.leave(node_name)
+    """
+
     def __init__(self) -> None:
-        self._active: set[str] = set()
-        self._processing_count: Dict[str, int] = defaultdict(
-            int
-        )  # Track how many times each node processed
+        # Node activity tracking
+        self._active: Set[str] = set()
+        self._processing_count: Dict[str, int] = defaultdict(int)
+
+        # Message tracking (uncommitted = published but not yet committed)
+        self._uncommitted_messages: int = 0
+
+        # Synchronization
         self._cond = asyncio.Condition()
-        self._idle_event = asyncio.Event()
-        # Set the event initially since we start in idle state
-        self._idle_event.set()
+        self._quiescence_event = asyncio.Event()
+
+        # Work tracking (prevents premature quiescence before any work)
+        self._total_committed: int = 0
+        self._has_started: bool = False
+
+        # Force stop flag (for explicit workflow stop)
+        self._force_stopped: bool = False
 
     def reset(self) -> None:
-        """
-        Reset the tracker to its initial state.
-        """
-        self._active = set()
-        self._processing_count = defaultdict(int)
+        """Reset for a new workflow run."""
+        self._active.clear()
+        self._processing_count.clear()
+        self._uncommitted_messages = 0
         self._cond = asyncio.Condition()
-        self._idle_event = asyncio.Event()
-        # Set the event initially since we start in idle state
-        self._idle_event.set()
+        self._quiescence_event = asyncio.Event()
+        self._total_committed = 0
+        self._has_started = False
+        self._force_stopped = False
+
+    # ─────────────────────────────────────────────────────────────────────────
+    # Node Lifecycle (called from _invoke_node)
+    # ─────────────────────────────────────────────────────────────────────────
 
     async def enter(self, node_name: str) -> None:
+        """Called when a node begins processing."""
         async with self._cond:
-            self._idle_event.clear()
+            self._has_started = True
+            self._quiescence_event.clear()
             self._active.add(node_name)
             self._processing_count[node_name] += 1
 
     async def leave(self, node_name: str) -> None:
+        """Called when a node finishes processing."""
         async with self._cond:
             self._active.discard(node_name)
-            if not self._active:
-                self._idle_event.set()
-                self._cond.notify_all()
+            self._check_quiescence_unlocked()
+            self._cond.notify_all()
 
-    async def wait_idle_event(self) -> None:
+    # ─────────────────────────────────────────────────────────────────────────
+    # Message Tracking (called from orchestrator utilities)
+    # ─────────────────────────────────────────────────────────────────────────
+
+    async def on_messages_published(self, count: int = 1, source: str = "") -> None:
+        """
+        Called when messages are published to topics.
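+
+        For example, a caller that has just published three events might
+        invoke ``await tracker.on_messages_published(3, source="node:NodeA")``
+        (illustrative values; ``source`` is a free-form label used in logs).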
+
+        Call site: publish_events() in utils.py
+        """
+        if count <= 0:
+            return
+        async with self._cond:
+            self._has_started = True
+            self._quiescence_event.clear()
+            self._uncommitted_messages += count
+
+        logger.debug(
+            f"Tracker: {count} messages published from {source} (uncommitted={self._uncommitted_messages})"
+        )
+
+    async def on_messages_committed(self, count: int = 1, source: str = "") -> None:
         """
-        Wait until the tracker is idle, meaning no active nodes.
-        This is useful for synchronization points in workflows.
+        Called when messages are committed (consumed and acknowledged).
+
+        Call site: _commit_events() in EventDrivenWorkflow
+        """
+        if count <= 0:
+            return
+        async with self._cond:
+            self._uncommitted_messages = max(0, self._uncommitted_messages - count)
+            self._total_committed += count
+            self._check_quiescence_unlocked()
+
+            logger.debug(
+                f"Tracker: {count} messages committed from {source} "
+                f"(uncommitted={self._uncommitted_messages}, total={self._total_committed})"
+            )
+            self._cond.notify_all()
+
+    # Aliases for clarity
+    async def on_message_published(self) -> None:
+        """Single message version."""
+        await self.on_messages_published(1)
+
+    async def on_message_committed(self) -> None:
+        """Single message version."""
+        await self.on_messages_committed(1)
+
+    # ─────────────────────────────────────────────────────────────────────────
+    # Quiescence Detection
+    # ─────────────────────────────────────────────────────────────────────────
+
+    def _check_quiescence_unlocked(self) -> None:
+        """
+        Check and signal quiescence if all conditions met.
+
+        MUST be called with self._cond lock held.
+        """
+        is_quiescent = self._is_quiescent_unlocked()
+        logger.debug(
+            f"Tracker: checking quiescence - active={list(self._active)}, "
+            f"uncommitted={self._uncommitted_messages}, "
+            f"has_started={self._has_started}, "
+            f"total_committed={self._total_committed}, "
+            f"is_quiescent={is_quiescent}"
+        )
+        if is_quiescent:
+            logger.info(
+                f"Tracker: quiescence detected (committed={self._total_committed})"
+            )
+            self._quiescence_event.set()
+
+    def _is_quiescent_unlocked(self) -> bool:
         """
-        await self._idle_event.wait()
-
-    def is_idle(self) -> bool:
-        return not self._active
-
-    def get_activity_count(self) -> int:
-        """Get total processing count across all nodes"""
-        return sum(self._processing_count.values())
+        Internal quiescence check without lock.
+
+        MUST be called with self._cond lock held.
+
+        True when workflow is truly idle:
+        - No nodes actively processing
+        - No messages waiting to be committed
+        - At least some work was done
+        """
+        return (
+            not self._active
+            and self._uncommitted_messages == 0
+            and self._has_started
+            and self._total_committed > 0
+        )
+
+    async def is_quiescent(self) -> bool:
+        """
+        True when workflow is truly idle:
+        - No nodes actively processing
+        - No messages waiting to be committed
+        - At least some work was done
+
+        This method acquires the lock to ensure consistent reads.
+        """
+        async with self._cond:
+            return self._is_quiescent_unlocked()
+
+    def _should_terminate_unlocked(self) -> bool:
+        """
+        Internal termination check without lock.
+
+        MUST be called with self._cond lock held.
+        """
+        return self._is_quiescent_unlocked() or self._force_stopped
+
+    async def should_terminate(self) -> bool:
+        """
+        True when workflow should stop iteration.
+        Either natural quiescence or explicit force stop.
+
+        This method acquires the lock to ensure consistent reads.
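+
+        Illustrative polling loop (a sketch; ``process_next_batch`` is a
+        hypothetical caller-side coroutine, not part of this module):
+
+            while not await tracker.should_terminate():
+                await process_next_batch()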
+ """ + async with self._cond: + return self._should_terminate_unlocked() + + async def force_stop(self) -> None: + """ + Force the workflow to stop immediately (async version with lock). + Called when workflow.stop() is invoked from async context. + """ + async with self._cond: + logger.info("Tracker: force stop requested") + self._force_stopped = True + self._quiescence_event.set() + self._cond.notify_all() + + def force_stop_sync(self) -> None: + """ + Force the workflow to stop immediately (sync version). + + This is a synchronous version for use from sync contexts (e.g., stop() method). + It sets the stop flag and event without acquiring the async lock. + This is safe because: + 1. Setting _force_stopped to True is atomic for the stop signal + 2. asyncio.Event.set() is thread-safe + 3. Readers will see the updated state on their next lock acquisition + """ + logger.info("Tracker: force stop requested (sync)") + self._force_stopped = True + self._quiescence_event.set() + + async def is_idle(self) -> bool: + """Legacy: just checks if no active nodes.""" + async with self._cond: + return not self._active + + async def wait_for_quiescence(self, timeout: Optional[float] = None) -> bool: + """Wait until quiescent. Returns False on timeout.""" + try: + if timeout: + await asyncio.wait_for(self._quiescence_event.wait(), timeout) + else: + await self._quiescence_event.wait() + return True + except asyncio.TimeoutError: + return False + + async def wait_idle_event(self) -> None: + """Legacy compatibility.""" + await self._quiescence_event.wait() + + # ───────────────────────────────────────────────────────────────────────── + # Metrics + # ───────────────────────────────────────────────────────────────────────── + + async def get_activity_count(self) -> int: + """Total processing count across all nodes.""" + async with self._cond: + return sum(self._processing_count.values()) + + async def get_metrics(self) -> Dict: + """Detailed metrics for debugging.""" + async with self._cond: + return { + "active_nodes": list(self._active), + "uncommitted_messages": self._uncommitted_messages, + "total_committed": self._total_committed, + "is_quiescent": self._is_quiescent_unlocked(), + } diff --git a/grafi/workflows/impl/async_output_queue.py b/grafi/workflows/impl/async_output_queue.py index 3607126..e3c72d2 100644 --- a/grafi/workflows/impl/async_output_queue.py +++ b/grafi/workflows/impl/async_output_queue.py @@ -1,7 +1,8 @@ import asyncio -from typing import AsyncGenerator from typing import List +from loguru import logger + from grafi.common.events.topic_events.topic_event import TopicEvent from grafi.topics.topic_base import TopicBase from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker @@ -9,8 +10,9 @@ class AsyncOutputQueue: """ - Manages output topics and their listeners for async workflow execution. - Wraps output_topics, listener_tasks, and tracker functionality. + Manages output topics and provides async iteration over output events. + + Simplified: All quiescence detection delegated to AsyncNodeTracker. 
""" def __init__( @@ -23,95 +25,116 @@ def __init__( self.consumer_name = consumer_name self.tracker = tracker self.queue: asyncio.Queue[TopicEvent] = asyncio.Queue() - self.listener_tasks: List[asyncio.Task] = [] + self._listener_tasks: List[asyncio.Task] = [] + self._stopped = False async def start_listeners(self) -> None: """Start listener tasks for all output topics.""" - self.listener_tasks = [ + self._stopped = False + self._listener_tasks = [ asyncio.create_task(self._output_listener(topic)) for topic in self.output_topics ] async def stop_listeners(self) -> None: """Stop all listener tasks.""" - for task in self.listener_tasks: + self._stopped = True + for task in self._listener_tasks: task.cancel() - await asyncio.gather(*self.listener_tasks, return_exceptions=True) + await asyncio.gather(*self._listener_tasks, return_exceptions=True) + self._listener_tasks.clear() async def _output_listener(self, topic: TopicBase) -> None: """ - Streams *matching* records from `topic` into `queue`. - Exits when the graph is idle *and* the topic has no more unseen data, - with proper handling for downstream node activation. - """ - last_activity_count = 0 - - while True: - # waiter 1: "some records arrived" - topic_task = asyncio.create_task(topic.consume(self.consumer_name)) - # waiter 2: "graph just became idle" - idle_event_waiter = asyncio.create_task(self.tracker.wait_idle_event()) - - done, pending = await asyncio.wait( - {topic_task, idle_event_waiter}, - return_when=asyncio.FIRST_COMPLETED, - ) - - # ---- If records arrived ----------------------------------------- - if topic_task in done: - output_events = topic_task.result() - - for output_event in output_events: - await self.queue.put(output_event) - - # ---- Check for workflow completion ---------------- - if idle_event_waiter in done and self.tracker.is_idle(): - current_activity = self.tracker.get_activity_count() + Forward events to queue and track message consumption. - # If no new activity since last check and no data, we're done - if ( - current_activity == last_activity_count - and not await topic.can_consume(self.consumer_name) - ): - # cancel an unfinished waiter (if any) to avoid warnings - for t in pending: - t.cancel() - break - - last_activity_count = current_activity - - # Cancel the topic task since we're checking idle state - for t in pending: - t.cancel() - - def __aiter__(self) -> AsyncGenerator[TopicEvent, None]: - """Make AsyncOutputQueue async iterable.""" + When events are consumed from output topics, they've reached their + destination (the output queue), so we mark them as committed. + """ + while not self._stopped: + try: + events = await topic.consume(self.consumer_name, timeout=0.1) + for event in events: + await self.queue.put(event) + # Mark messages as committed when they reach the output queue + if events: + logger.debug( + f"Output listener: consumed {len(events)} events from {topic.name}" + ) + await self.tracker.on_messages_committed( + len(events), source=f"output_listener:{topic.name}" + ) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Output listener error for {topic.name}: {e}") + await asyncio.sleep(0.1) + + def __aiter__(self) -> "AsyncOutputQueue": return self async def __anext__(self) -> TopicEvent: - """Async iteration implementation with idle detection.""" - # two parallel waiters + """ + SIMPLIFIED: Delegates quiescence check entirely to tracker. 
+
+        Removed:
+        - last_activity_count tracking
+        - asyncio.sleep(0) hack
+        - duplicated idle detection logic
+        """
+        check_count = 0
         while True:
+            check_count += 1
+
+            # Fast path: queue has items
+            if not self.queue.empty():
+                try:
+                    return self.queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    pass
+
+            # Check for completion (natural quiescence or force stop)
+            if await self.tracker.should_terminate():
+                # Final drain attempt - try to get any remaining items before stopping
+                # This avoids race where item is added between empty() check and raising
+                try:
+                    return self.queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    raise StopAsyncIteration
+
+            # Wait for queue item or quiescence
             queue_task = asyncio.create_task(self.queue.get())
-            idle_task = asyncio.create_task(self.tracker._idle_event.wait())
+            quiescent_task = asyncio.create_task(
+                self.tracker.wait_for_quiescence(timeout=0.5)
+            )
 
             done, pending = await asyncio.wait(
-                {queue_task, idle_task},
+                {queue_task, quiescent_task},
                 return_when=asyncio.FIRST_COMPLETED,
             )
 
-            # Case A: we got a queue item first → stream it
-            if queue_task in done:
-                idle_task.cancel()
-                await asyncio.gather(idle_task, return_exceptions=True)
-                return queue_task.result()
-
-            # Case B: pipeline went idle first
-            queue_task.cancel()
-            await asyncio.gather(queue_task, return_exceptions=True)
-
-            # Give downstream consumers one chance to register activity.
-            await asyncio.sleep(0)  # one event‑loop tick
-
-            if self.tracker.is_idle() and self.queue.empty():
-                raise StopAsyncIteration
+            for task in pending:
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
+            # Got queue item
+            if queue_task in done and not queue_task.cancelled():
+                try:
+                    return queue_task.result()
+                except asyncio.QueueEmpty:
+                    # Task was cancelled as part of normal cleanup; ignore.
+                    continue
+
+            # Quiescence or force stop detected
+            if await self.tracker.should_terminate():
+                # Final drain attempt - try to get any remaining items before stopping
+                # This avoids race where item is added between empty() check and raising
+                try:
+                    return self.queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    raise StopAsyncIteration
diff --git a/grafi/workflows/impl/event_driven_workflow.py b/grafi/workflows/impl/event_driven_workflow.py
index d337314..92d8fd6 100644
--- a/grafi/workflows/impl/event_driven_workflow.py
+++ b/grafi/workflows/impl/event_driven_workflow.py
@@ -64,7 +64,7 @@ class EventDrivenWorkflow(Workflow):
     # Queue of nodes that are ready to invoke (in response to published events)
     _invoke_queue: deque[NodeBase] = PrivateAttr(default=deque())
 
-    _tracker: AsyncNodeTracker = AsyncNodeTracker()
+    _tracker: AsyncNodeTracker = PrivateAttr(default_factory=AsyncNodeTracker)
 
     # Optional callback that handles output events
     # Including agent output event, stream event and hil event
@@ -73,6 +73,14 @@ def model_post_init(self, _context: Any) -> None:
         self._add_topics()
         self._handle_function_calling_nodes()
 
+    def stop(self) -> None:
+        """
+        Stop the workflow execution.
+        Overrides base class to also trigger force stop on the tracker.
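+
+        Safe to call from synchronous code: it delegates to the tracker's
+        force_stop_sync(), which sets the stop flag and event without
+        needing the async lock.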
+ """ + super().stop() + self._tracker.force_stop_sync() + @classmethod def builder(cls) -> WorkflowBuilder: """ @@ -206,11 +214,13 @@ async def _get_output_events(self) -> List[ConsumeFromTopicEvent]: return consumed_events async def _commit_events( - self, consumer_name: str, topic_events: List[ConsumeFromTopicEvent] + self, + consumer_name: str, + topic_events: List[ConsumeFromTopicEvent], + track_commit: bool = True, ) -> None: if not topic_events: return - # commit all consumed events topic_max_offset: Dict[str, int] = {} for topic_event in topic_events: @@ -221,6 +231,16 @@ async def _commit_events( for topic, offset in topic_max_offset.items(): await self._topics[topic].commit(consumer_name, offset) + # Notify tracker that messages have been committed + # (skip if already tracked elsewhere, e.g., by output listener) + if track_commit: + logger.debug( + f"Committing {len(topic_events)} events for {consumer_name}, track_commit={track_commit}" + ) + await self._tracker.on_messages_committed( + len(topic_events), source=f"commit:{consumer_name}" + ) + async def _add_to_invoke_queue(self, event: TopicEvent) -> None: topic_name = event.name @@ -271,7 +291,9 @@ async def invoke_sequential( async for result in node.invoke( invoke_context, node_consumed_events ): - published_events.extend(await publish_events(node, result)) + published_events.extend( + await publish_events(node, result, self._tracker) + ) for event in published_events: await self._add_to_invoke_queue(event) @@ -301,6 +323,9 @@ async def invoke_parallel( self, input_data: PublishToTopicEvent ) -> AsyncGenerator[ConsumeFromTopicEvent, None]: invoke_context = input_data.invoke_context + logger.debug( + f"invoke_parallel: tracker_id={id(self._tracker)}, metrics={await self._tracker.get_metrics()}" + ) # Start a background task to process all nodes (including streaming generators) node_processing_task = [ @@ -374,9 +399,12 @@ async def invoke_parallel( finally: await output_queue.stop_listeners() - # Commit all consumed output events + # Commit all consumed output events to topics + # (tracking already done by output listener, so skip tracker update) await self._commit_events( - consumer_name=self.name, topic_events=consumed_output_events + consumer_name=self.name, + topic_events=consumed_output_events, + track_commit=False, ) # 4. 
graceful shutdown all the nodes @@ -500,7 +528,11 @@ def _cancel_all_active_tasks() -> None: if consumed_events: async for event in node.invoke(invoke_context, consumed_events): node_output_events.extend( - await publish_events(node=node, publish_event=event) + await publish_events( + node=node, + publish_event=event, + tracker=self._tracker, + ) ) await self._commit_events( @@ -576,6 +608,9 @@ async def init_workflow( self, input_data: PublishToTopicEvent, is_sequential: bool = False ) -> Any: # 1 – initial seeding + logger.debug( + f"init_workflow: is_sequential={is_sequential}, tracker_id={id(self._tracker)}" + ) if not is_sequential: self._tracker.reset() @@ -615,7 +650,21 @@ async def init_workflow( if is_sequential: await self._add_to_invoke_queue(event) + logger.debug( + f"init_workflow: events_to_record={len(events_to_record)}, input_topics={len(input_topics)}" + ) if events_to_record: + # Track initial input messages for quiescence detection + if not is_sequential: + logger.debug( + f"init_workflow: calling on_messages_published({len(events_to_record)})" + ) + await self._tracker.on_messages_published( + len(events_to_record), source="init_workflow" + ) + logger.debug( + f"init_workflow: tracker after publish: {await self._tracker.get_metrics()}" + ) await container.event_store.record_events(events_to_record) else: # When there is unfinished workflow, we need to restore the workflow topics @@ -667,6 +716,11 @@ async def init_workflow( ) ) if paired_event: + # Track the published message for quiescence detection + if not is_sequential: + await self._tracker.on_messages_published( + 1, source="restore_paired_input" + ) if is_sequential: await self._add_to_invoke_queue(paired_event) await container.event_store.record_event(paired_event) diff --git a/grafi/workflows/impl/utils.py b/grafi/workflows/impl/utils.py index 6c9df6f..8052c16 100644 --- a/grafi/workflows/impl/utils.py +++ b/grafi/workflows/impl/utils.py @@ -8,17 +8,14 @@ from grafi.common.events.topic_events.topic_event import TopicEvent from grafi.common.models.message import Message from grafi.nodes.node_base import NodeBase +from grafi.workflows.impl.async_node_tracker import AsyncNodeTracker def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: """ Process a list of TopicEvents, grouping by name and aggregating streaming messages. - Args: - events: List of TopicEvents to process - - Returns: - List of processed TopicEvents with streaming messages aggregated + NO CHANGES NEEDED - this is read-only aggregation. 
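+
+    For example, two streaming events on the same topic carrying "Hel" and
+    "lo" are folded into one non-streaming event whose content is "Hello";
+    non-streaming events pass through unchanged.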
""" # Group events by name events_by_topic: Dict[str, List[TopicEvent]] = {} @@ -35,9 +32,7 @@ def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: non_streaming_events: List[TopicEvent] = [] for event in topic_events: - # Check if event.data contains streaming messages is_streaming_event = False - # Handle both single message and list of messages messages = event.data if messages and len(messages) > 0 and messages[0].is_streaming: is_streaming_event = True @@ -47,25 +42,18 @@ def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: else: non_streaming_events.append(event) - # Add non-streaming events as-is output_events.extend(non_streaming_events) - # Aggregate streaming events if any exist if streaming_events: - # Use the first streaming event as the base for creating the aggregated event base_event = streaming_events[0] - - # Aggregate content from all streaming messages aggregated_content_parts = [] for event in streaming_events: messages = event.data if isinstance(event.data, list) else [event.data] for message in messages: if message.content: aggregated_content_parts.append(message.content) - aggregated_content = "".join(aggregated_content_parts) # type: ignore[arg-type] + aggregated_content = "".join(aggregated_content_parts) - # Create a new message with aggregated content - # Copy properties from the first message but update content and streaming flag first_message = ( base_event.data if isinstance(base_event.data, list) @@ -74,40 +62,55 @@ def get_async_output_events(events: List[TopicEvent]) -> List[TopicEvent]: aggregated_message = Message( role=first_message.role, content=aggregated_content, - is_streaming=False, # Aggregated message is no longer streaming + is_streaming=False, ) - # Create new event based on the base event type aggregated_event = base_event aggregated_event.data = [aggregated_message] - output_events.append(aggregated_event) return output_events async def publish_events( - node: NodeBase, publish_event: PublishToTopicEvent + node: NodeBase, + publish_event: PublishToTopicEvent, + tracker: AsyncNodeTracker, ) -> List[PublishToTopicEvent]: + """ + Publish events to all topics the node publishes to. + + CHANGE: Added optional tracker parameter. + When provided, notifies tracker of published messages. + """ published_events: List[PublishToTopicEvent] = [] + for topic in node.publish_to: event = await topic.publish_data(publish_event) - if event: published_events.append(event) + # NEW: Notify tracker of published messages + if tracker and published_events: + await tracker.on_messages_published( + len(published_events), source=f"node:{node.name}" + ) + return published_events async def get_node_input(node: NodeBase) -> List[ConsumeFromTopicEvent]: + """ + Get input events for a node from its subscribed topics. + + NO CHANGES NEEDED - consumption tracking happens at commit time. 
+ """ consumed_events: List[ConsumeFromTopicEvent] = [] node_subscribed_topics = node._subscribed_topics.values() - # Process each topic the node is subscribed to for subscribed_topic in node_subscribed_topics: if await subscribed_topic.can_consume(node.name): - # Get messages from topic and create consume events node_consumed_events = await subscribed_topic.consume(node.name) for event in node_consumed_events: consumed_event = ConsumeFromTopicEvent( diff --git a/pyproject.toml b/pyproject.toml index 606d2b8..92315dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "grafi" -version = "0.0.33" +version = "0.0.34" description = "Grafi - a flexible, event-driven framework that enables the creation of domain-specific AI agents through composable agentic workflows." authors = [{name = "Craig Li", email = "craig@binome.dev"}] license = {text = "Mozilla Public License Version 2.0"} diff --git a/tests/assistants/test_assistant.py b/tests/assistants/test_assistant.py index cdbfbb3..e181c9a 100644 --- a/tests/assistants/test_assistant.py +++ b/tests/assistants/test_assistant.py @@ -314,6 +314,251 @@ def test_generate_manifest_custom_directory(self, mock_assistant, tmp_path): }, } + @pytest.mark.asyncio + async def test_eight_node_dag_workflow(self): + """ + Test a DAG workflow with 8 nodes: A->B, B->C, C->D, C->E, C->F, D->G, E->G, F->G, G->H. + + Each node concatenates the previous input with its own label. + For example, node B receives "A" and outputs "AB". + + The topology creates a fan-out at C (to D, E, F) and a fan-in at G (from D, E, F). + """ + from grafi.common.events.topic_events.publish_to_topic_event import ( + PublishToTopicEvent, + ) + from grafi.common.models.invoke_context import InvokeContext + from grafi.common.models.message import Message + from grafi.nodes.node import Node + from grafi.tools.functions.function_tool import FunctionTool + from grafi.topics.expressions.subscription_builder import SubscriptionBuilder + from grafi.topics.topic_impl.input_topic import InputTopic + from grafi.topics.topic_impl.output_topic import OutputTopic + from grafi.topics.topic_impl.topic import Topic + from grafi.workflows.impl.event_driven_workflow import EventDrivenWorkflow + + # Define the concatenation function for each node + def make_concat_func(label: str): + def concat_func(messages): + # Collect all content from input messages + contents = [] + for msg in messages: + if msg.content: + contents.append(msg.content) + # Sort to ensure deterministic ordering for fan-in scenarios + contents.sort() + combined = "".join(contents) + return f"{combined}{label}" + + return concat_func + + # Create topics + # Input/Output topics for the workflow + agent_input_topic = InputTopic(name="agent_input") + agent_output_topic = OutputTopic(name="agent_output") + + # Intermediate topics for connecting nodes + topic_a_out = Topic(name="topic_a_out") + topic_b_out = Topic(name="topic_b_out") + topic_c_out = Topic(name="topic_c_out") + topic_d_out = Topic(name="topic_d_out") + topic_e_out = Topic(name="topic_e_out") + topic_f_out = Topic(name="topic_f_out") + topic_g_out = Topic(name="topic_g_out") + + # Create nodes + # Node A: subscribes to agent_input, publishes to topic_a_out + node_a = ( + Node.builder() + .name("NodeA") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(agent_input_topic).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolA") + .function(make_concat_func("A")) + .build() + ) + .publish_to(topic_a_out) + .build() + ) + + # 
Node B: subscribes to topic_a_out, publishes to topic_b_out + node_b = ( + Node.builder() + .name("NodeB") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_a_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolB") + .function(make_concat_func("B")) + .build() + ) + .publish_to(topic_b_out) + .build() + ) + + # Node C: subscribes to topic_b_out, publishes to topic_c_out + node_c = ( + Node.builder() + .name("NodeC") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_b_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolC") + .function(make_concat_func("C")) + .build() + ) + .publish_to(topic_c_out) + .build() + ) + + # Node D: subscribes to topic_c_out, publishes to topic_d_out (fan-out from C) + node_d = ( + Node.builder() + .name("NodeD") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_c_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolD") + .function(make_concat_func("D")) + .build() + ) + .publish_to(topic_d_out) + .build() + ) + + # Node E: subscribes to topic_c_out, publishes to topic_e_out (fan-out from C) + node_e = ( + Node.builder() + .name("NodeE") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_c_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolE") + .function(make_concat_func("E")) + .build() + ) + .publish_to(topic_e_out) + .build() + ) + + # Node F: subscribes to topic_c_out, publishes to topic_f_out (fan-out from C) + node_f = ( + Node.builder() + .name("NodeF") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_c_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolF") + .function(make_concat_func("F")) + .build() + ) + .publish_to(topic_f_out) + .build() + ) + + # Node G: subscribes to topic_d_out AND topic_e_out AND topic_f_out (fan-in) + node_g = ( + Node.builder() + .name("NodeG") + .type("ConcatNode") + .subscribe( + SubscriptionBuilder() + .subscribed_to(topic_d_out) + .and_() + .subscribed_to(topic_e_out) + .and_() + .subscribed_to(topic_f_out) + .build() + ) + .tool( + FunctionTool.builder() + .name("ConcatToolG") + .function(make_concat_func("G")) + .build() + ) + .publish_to(topic_g_out) + .build() + ) + + # Node H: subscribes to topic_g_out, publishes to agent_output + node_h = ( + Node.builder() + .name("NodeH") + .type("ConcatNode") + .subscribe(SubscriptionBuilder().subscribed_to(topic_g_out).build()) + .tool( + FunctionTool.builder() + .name("ConcatToolH") + .function(make_concat_func("H")) + .build() + ) + .publish_to(agent_output_topic) + .build() + ) + + # Build the workflow + workflow = ( + EventDrivenWorkflow.builder() + .name("EightNodeDAGWorkflow") + .node(node_a) + .node(node_b) + .node(node_c) + .node(node_d) + .node(node_e) + .node(node_f) + .node(node_g) + .node(node_h) + .build() + ) + + # Create assistant with the workflow + with patch.object(Assistant, "_construct_workflow"): + assistant = Assistant( + name="EightNodeDAGAssistant", + workflow=workflow, + ) + + # Create invoke context and input + invoke_context = InvokeContext( + conversation_id="test_dag_conversation", + invoke_id="test_dag_invoke", + assistant_request_id="test_dag_request", + ) + + # Start with empty input - each node adds its label + input_messages = [Message(content="", role="user")] + input_data = PublishToTopicEvent( + invoke_context=invoke_context, data=input_messages + ) + + # Invoke the workflow (using default parallel mode) + result_events = [] + async for 
event in assistant.invoke(input_data): + result_events.append(event) + + # Verify we get exactly 1 event from the agent_output topic + assert len(result_events) == 1, f"Expected 1 event, got {len(result_events)}" + assert result_events[0].name == "agent_output" + + # The expected output path is: + # A: "" -> "A" + # B: "A" -> "AB" + # C: "AB" -> "ABC" + # D: "ABC" -> "ABCD" + # E: "ABC" -> "ABCE" + # F: "ABC" -> "ABCF" + # G: combines "ABCD", "ABCE", "ABCF" (sorted) -> "ABCDABCEABCFG" + # H: "ABCDABCEABCFG" -> "ABCDABCEABCFGH" + expected_output = "ABCDABCEABCFGH" + assert result_events[0].data[0].content == expected_output + def test_generate_manifest_file_write_error(self, mock_assistant): """Test manifest generation with file write error.""" with patch("builtins.open", side_effect=IOError("Permission denied")): diff --git a/tests/workflow/test_async_node_tracker.py b/tests/workflow/test_async_node_tracker.py index 43bc138..e348740 100644 --- a/tests/workflow/test_async_node_tracker.py +++ b/tests/workflow/test_async_node_tracker.py @@ -13,155 +13,124 @@ def tracker(self): @pytest.mark.asyncio async def test_initial_state(self, tracker): - """Test that tracker starts in idle state.""" - assert tracker.is_idle() - assert tracker.get_activity_count() == 0 - assert tracker._idle_event.is_set() + """Tracker starts idle with no work recorded.""" + assert await tracker.is_idle() + assert await tracker.is_quiescent() is False + assert await tracker.get_activity_count() == 0 + assert (await tracker.get_metrics())["uncommitted_messages"] == 0 @pytest.mark.asyncio - async def test_enter_makes_tracker_active(self, tracker): - """Test that entering a node makes the tracker active.""" + async def test_enter_and_leave_updates_activity(self, tracker): + """Entering and leaving nodes updates activity counts.""" await tracker.enter("node1") - assert not tracker.is_idle() - assert not tracker._idle_event.is_set() - assert tracker.get_activity_count() == 1 + assert not await tracker.is_idle() + assert await tracker.get_activity_count() == 1 assert "node1" in tracker._active - @pytest.mark.asyncio - async def test_leave_makes_tracker_idle(self, tracker): - """Test that leaving the last node makes the tracker idle.""" - await tracker.enter("node1") await tracker.leave("node1") - assert tracker.is_idle() - assert tracker._idle_event.is_set() - assert tracker.get_activity_count() == 1 # Count persists - assert "node1" not in tracker._active + assert await tracker.is_idle() + # No commits yet so quiescence is still False + assert await tracker.is_quiescent() is False + assert await tracker.get_activity_count() == 1 @pytest.mark.asyncio - async def test_multiple_nodes_tracking(self, tracker): - """Test tracking multiple nodes.""" - await tracker.enter("node1") - await tracker.enter("node2") + async def test_message_tracking_and_quiescence(self, tracker): + """Published/committed message tracking drives quiescence detection.""" + await tracker.on_messages_published(2) + assert await tracker.is_quiescent() is False + assert (await tracker.get_metrics())["uncommitted_messages"] == 2 - assert not tracker.is_idle() - assert tracker.get_activity_count() == 2 - assert "node1" in tracker._active - assert "node2" in tracker._active + await tracker.on_messages_committed(1) + assert await tracker.is_quiescent() is False + assert (await tracker.get_metrics())["uncommitted_messages"] == 1 - await tracker.leave("node1") - assert not tracker.is_idle() # Still has node2 - - await tracker.leave("node2") - assert tracker.is_idle() 
+        await tracker.on_messages_committed(1)
+        assert await tracker.is_quiescent() is True
+        assert (await tracker.get_metrics())["uncommitted_messages"] == 0
 
     @pytest.mark.asyncio
-    async def test_reentrant_node_increases_count(self, tracker):
-        """Test that entering the same node multiple times increases count."""
-        await tracker.enter("node1")
-        await tracker.enter("node1")
-
-        assert tracker.get_activity_count() == 2
-        assert len(tracker._active) == 1  # Still just one node in active set
-
-    @pytest.mark.asyncio
-    async def test_wait_idle_event(self, tracker):
-        """Test waiting for idle event."""
-        # Initially idle
-        await asyncio.wait_for(tracker.wait_idle_event(), timeout=0.1)
+    async def test_wait_for_quiescence(self, tracker):
+        """wait_for_quiescence resolves when work finishes."""
+        await tracker.on_messages_published(1)
 
-        # Enter a node
-        await tracker.enter("node1")
+        async def finish_work():
+            await asyncio.sleep(0.01)
+            await tracker.on_messages_committed(1)
 
-        # Create a task that waits for idle
-        idle_task = asyncio.create_task(tracker.wait_idle_event())
+        asyncio.create_task(finish_work())
 
-        # Should not be done yet
-        await asyncio.sleep(0.01)
-        assert not idle_task.done()
+        result = await tracker.wait_for_quiescence(timeout=0.5)
+        assert result is True
+        assert await tracker.is_quiescent() is True
 
-        # Leave node to trigger idle
-        await tracker.leave("node1")
-
-        # Now the wait should complete
-        await asyncio.wait_for(idle_task, timeout=0.1)
+    @pytest.mark.asyncio
+    async def test_wait_for_quiescence_timeout(self, tracker):
+        """wait_for_quiescence returns False on timeout."""
+        result = await tracker.wait_for_quiescence(timeout=0.01)
+        assert result is False
+        assert await tracker.is_quiescent() is False
 
     @pytest.mark.asyncio
     async def test_reset(self, tracker):
-        """Test reset functionality."""
-        # Add some activity
+        """Reset clears activity and quiescence state."""
         await tracker.enter("node1")
-        await tracker.enter("node2")
-        await tracker.leave("node1")
+        await tracker.on_messages_published(1)
+        await tracker.on_messages_committed(1)
 
-        assert not tracker.is_idle()
-        assert tracker.get_activity_count() > 0
-
-        # Reset
         tracker.reset()
 
-        # Should be back to initial state
-        assert tracker.is_idle()
-        assert tracker.get_activity_count() == 0
-        assert tracker._idle_event.is_set()
-        assert len(tracker._active) == 0
+        assert await tracker.is_idle()
+        assert await tracker.is_quiescent() is False
+        assert await tracker.get_activity_count() == 0
+        assert (await tracker.get_metrics())["total_committed"] == 0
 
     @pytest.mark.asyncio
-    async def test_concurrent_enter_leave(self, tracker):
-        """Test concurrent enter/leave operations."""
-
-        async def enter_leave_cycle(node_name: str, cycles: int):
-            for _ in range(cycles):
-                await tracker.enter(node_name)
-                await asyncio.sleep(0.001)  # Small delay
-                await tracker.leave(node_name)
-
-        # Run multiple concurrent cycles
-        tasks = [
-            asyncio.create_task(enter_leave_cycle(f"node{i}", 10)) for i in range(5)
-        ]
+    async def test_force_stop(self, tracker):
+        """Force stop terminates workflow even with uncommitted messages."""
+        await tracker.on_messages_published(2)
+        assert await tracker.is_quiescent() is False
+        assert await tracker.should_terminate() is False
 
-        await asyncio.gather(*tasks)
+        await tracker.force_stop()
 
-        # Should be idle after all complete
-        assert tracker.is_idle()
-        assert tracker.get_activity_count() == 50  # 5 nodes * 10 cycles
+        # Not quiescent (uncommitted messages still exist)
+        assert await tracker.is_quiescent() is False
+        # But should_terminate is True due to force stop
+        assert await tracker.should_terminate() is True
+        assert tracker._force_stopped is True
 
     @pytest.mark.asyncio
-    async def test_leave_nonexistent_node(self, tracker):
-        """Test leaving a node that was never entered."""
-        # Should not raise an error
-        await tracker.leave("nonexistent")
-        assert tracker.is_idle()
+    async def test_should_terminate_on_quiescence(self, tracker):
+        """should_terminate is True when naturally quiescent."""
+        await tracker.on_messages_published(1)
+        await tracker.on_messages_committed(1)
 
-    @pytest.mark.asyncio
-    async def test_condition_notification(self, tracker):
-        """Test that condition is properly notified on idle."""
-        await tracker.enter("node1")
+        assert await tracker.is_quiescent() is True
+        assert await tracker.should_terminate() is True
+        assert tracker._force_stopped is False
 
-        # Create a flag to verify notification happened
-        notified = False
+    @pytest.mark.asyncio
+    async def test_force_stop_triggers_quiescence_event(self, tracker):
+        """Force stop sets the quiescence event so waiters can proceed."""
+        await tracker.on_messages_published(1)
 
-        async def wait_for_notification():
-            nonlocal notified
-            async with tracker._cond:
-                await tracker._cond.wait()
-                notified = True
+        # Event should not be set yet
+        assert not tracker._quiescence_event.is_set()
 
-        wait_task = asyncio.create_task(wait_for_notification())
+        await tracker.force_stop()
 
-        # Give task time to start waiting
-        await asyncio.sleep(0.01)
+        # Event should now be set
+        assert tracker._quiescence_event.is_set()
 
-        # Leave node to trigger notification
-        await tracker.leave("node1")
+    @pytest.mark.asyncio
+    async def test_reset_clears_force_stop(self, tracker):
+        """Reset clears the force stop flag."""
+        await tracker.force_stop()
+        assert tracker._force_stopped is True
 
-        # Wait should complete
-        try:
-            await asyncio.wait_for(wait_task, timeout=0.1)
-        except asyncio.TimeoutError:
-            pass  # It's ok if it times out, we just check if notified
+        tracker.reset()
 
-        # Check that notification happened
-        assert tracker.is_idle()
+        assert tracker._force_stopped is False
+        assert await tracker.should_terminate() is False
diff --git a/tests/workflow/test_async_output_queue.py b/tests/workflow/test_async_output_queue.py
index afd4455..e28c963 100644
--- a/tests/workflow/test_async_output_queue.py
+++ b/tests/workflow/test_async_output_queue.py
@@ -1,4 +1,5 @@
 import asyncio
+from unittest.mock import Mock
 
 import pytest
 
@@ -21,10 +22,10 @@ def __init__(self, name: str):
         self._events = []
         self._consumed_offset = -1
 
-    async def consume(self, consumer_name: str):
+    async def consume(self, consumer_name: str, timeout: float | None = None):
         """Mock async consume that returns events."""
-        # Simulate waiting for events
-        await asyncio.sleep(0.01)
+        if timeout and timeout > 0:
+            await asyncio.sleep(timeout)
 
         # Return events after consumed offset
         new_events = [e for e in self._events if e.offset > self._consumed_offset]
@@ -63,15 +64,15 @@ def test_initialization(self, output_queue, mock_topics, tracker):
         assert output_queue.consumer_name == "test_consumer"
         assert output_queue.tracker == tracker
         assert isinstance(output_queue.queue, asyncio.Queue)
-        assert output_queue.listener_tasks == []
+        assert output_queue._listener_tasks == []
 
     @pytest.mark.asyncio
     async def test_start_listeners(self, output_queue, mock_topics):
         """Test starting listener tasks."""
         await output_queue.start_listeners()
 
-        assert len(output_queue.listener_tasks) == len(mock_topics)
-        for task in output_queue.listener_tasks:
+        assert len(output_queue._listener_tasks) == len(mock_topics)
+        for task in output_queue._listener_tasks:
             assert isinstance(task, asyncio.Task)
             assert not task.done()
 
@@ -82,7 +83,7 @@ async def test_start_listeners(self, output_queue, mock_topics):
     async def test_stop_listeners(self, output_queue):
         """Test stopping listener tasks."""
         await output_queue.start_listeners()
-        tasks = output_queue.listener_tasks.copy()
+        tasks = output_queue._listener_tasks.copy()
 
         await output_queue.stop_listeners()
 
@@ -96,9 +97,6 @@ async def test_output_listener_receives_events(self, mock_topics, tracker):
         topic = mock_topics[0]
         queue = AsyncOutputQueue([topic], "test_consumer", tracker)
 
-        # Add activity to prevent listener from exiting
-        await tracker.enter("test_node")
-
         # Add test event
         test_event = PublishToTopicEvent(
             name="output1",
@@ -117,7 +115,7 @@ async def test_output_listener_receives_events(self, mock_topics, tracker):
         listener_task = asyncio.create_task(queue._output_listener(topic))
 
         # Wait a bit for event to be processed
-        await asyncio.sleep(0.05)
+        await asyncio.sleep(0.15)
 
         # Check event was queued
         assert not queue.queue.empty()
@@ -125,7 +123,6 @@ async def test_output_listener_receives_events(self, mock_topics, tracker):
         assert queued_event == test_event
 
         # Clean up
-        await tracker.leave("test_node")
         listener_task.cancel()
         await asyncio.gather(listener_task, return_exceptions=True)
 
@@ -163,56 +160,51 @@ async def collect_events():
         assert events[0] == test_event
 
     @pytest.mark.asyncio
-    async def test_async_iteration_stops_on_idle(self, output_queue, tracker):
-        """Test that async iteration stops when tracker is idle and queue is empty."""
-        # Make tracker idle
-        assert tracker.is_idle()
-
-        # Ensure queue is empty
-        assert output_queue.queue.empty()
+    async def test_async_iteration_stops_after_quiescence(self, output_queue, tracker):
+        """Async iteration ends when tracker reports quiescence and queue is empty."""
+        await tracker.on_messages_published(1)
+        await tracker.on_messages_committed(1)
 
-        # Iteration should stop
         events = []
         async for event in output_queue:
             events.append(event)
 
-        assert len(events) == 0
+        assert events == []
 
     @pytest.mark.asyncio
-    async def test_listener_exits_on_idle_and_no_data(self, mock_topics, tracker):
-        """Test that listener exits when workflow is idle and no more data."""
-        topic = mock_topics[0]
-        queue = AsyncOutputQueue([topic], "test_consumer", tracker)
-
-        # Ensure tracker is idle
-        assert tracker.is_idle()
+    async def test_async_iteration_waits_for_quiescence_or_events(self, tracker):
+        """__anext__ waits for either new queue data or quiescence signal."""
+        queue = AsyncOutputQueue([], "test_consumer", tracker)
 
-        # Run listener - should exit quickly since idle and no data
-        await queue._output_listener(topic)
+        async def signal_quiescence():
+            await tracker.on_messages_published(1)
+            await asyncio.sleep(0.02)
+            await tracker.on_messages_committed(1)
 
-        # Should complete without hanging
-        assert True
+        signal_task = asyncio.create_task(signal_quiescence())
 
-    @pytest.mark.asyncio
-    async def test_listener_continues_with_activity(self, mock_topics, tracker):
-        """Test that listener continues when there's activity."""
-        topic = mock_topics[0]
-        queue = AsyncOutputQueue([topic], "test_consumer", tracker)
+        events = []
+        async for event in queue:
+            events.append(event)
 
-        # Add activity
-        await tracker.enter("node1")
+        await signal_task
+        assert events == []
 
-        # Start listener
-        listener_task = asyncio.create_task(queue._output_listener(topic))
+    @pytest.mark.asyncio
+    async def test_event_emitted_before_quiescence(self, tracker):
+        """Events in the queue are yielded even if quiescence follows immediately."""
+        queue = AsyncOutputQueue([], "test_consumer", tracker)
+        queued_event = Mock()
+        queued_event.name = "queued_event"
+        await queue.queue.put(queued_event)
 
-        # Should still be running
-        await asyncio.sleep(0.05)
-        assert not listener_task.done()
+        await tracker.on_messages_published(1)
+        await tracker.on_messages_committed(1)
 
-        # Clean up
-        listener_task.cancel()
-        await asyncio.gather(listener_task, return_exceptions=True)
-        await tracker.leave("node1")
+        events = []
+        async for event in queue:
+            events.append(event)
+
+        assert [e.name for e in events] == ["queued_event"]
 
     @pytest.mark.asyncio
     async def test_type_annotations(self, output_queue):
@@ -268,3 +260,245 @@ async def test_concurrent_listeners(self, tracker):
         # Should have collected all events
         assert len(collected) == 3
         assert all(isinstance(e, PublishToTopicEvent) for e in collected)
+
+    @pytest.mark.asyncio
+    async def test_anext_waits_for_activity_count_stabilization(self):
+        """
+        Test that __anext__ doesn't prematurely terminate when activity count changes.
+
+        This tests the race condition fix where the output queue could terminate
+        before downstream nodes finish processing.
+        """
+        tracker = AsyncNodeTracker()
+
+        output_queue = AsyncOutputQueue(
+            output_topics=[],  # Empty - we'll put events directly in queue
+            consumer_name="test_consumer",
+            tracker=tracker,
+        )
+
+        # Simulate: node enters, adds item to queue, leaves
+        # Then another node should enter before we terminate
+
+        async def simulate_node_activity():
+            """Simulate node activity that should prevent premature termination."""
+            # First node processes - simulate full message lifecycle
+            await tracker.on_messages_published(1)
+            await tracker.enter("node_1")
+            await output_queue.queue.put(Mock(name="event_1"))
+            await tracker.leave("node_1")
+            await tracker.on_messages_committed(1)
+
+            # Yield control - simulates realistic timing where next node
+            # starts within the same event loop cycle
+            await asyncio.sleep(0)
+
+            # Second node picks up and processes - simulate full message lifecycle
+            await tracker.on_messages_published(1)
+            await tracker.enter("node_2")
+            await output_queue.queue.put(Mock(name="event_2"))
+            await tracker.leave("node_2")
+            await tracker.on_messages_committed(1)
+
+        # Start the activity simulation
+        activity_task = asyncio.create_task(simulate_node_activity())
+
+        # Iterate over the queue
+        events = []
+        async for event in output_queue:
+            events.append(event)
+            if len(events) >= 2:
+                break
+
+        await activity_task
+
+        # Should have received both events
+        assert len(events) == 2
+
+    @pytest.mark.asyncio
+    async def test_anext_terminates_when_truly_idle(self):
+        """
+        Test that __anext__ correctly terminates when no more activity.
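+
+        Once the single message is committed and the node has left, the
+        iterator should drain the queue and raise StopAsyncIteration.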
+ """ + tracker = AsyncNodeTracker() + + output_queue = AsyncOutputQueue( + output_topics=[], # Empty - we'll put events directly in queue + consumer_name="test_consumer", + tracker=tracker, + ) + + # Single node processes and finishes - simulate full message lifecycle + async def simulate_single_node(): + await tracker.on_messages_published(1) + await tracker.enter("node_1") + await output_queue.queue.put(Mock(name="event_1")) + await tracker.leave("node_1") + await tracker.on_messages_committed(1) + + activity_task = asyncio.create_task(simulate_single_node()) + + events = [] + async for event in output_queue: + events.append(event) + + await activity_task + + # Should terminate after receiving the single event + assert len(events) == 1 + + @pytest.mark.asyncio + async def test_activity_count_prevents_premature_exit(self): + """ + Test specifically that activity count tracking prevents race condition. + + Scenario: + 1. Node A finishes and tracker goes idle + 2. __anext__ sees idle but activity count changed + 3. Node B starts before __anext__ decides to terminate + 4. All events are properly yielded + """ + tracker = AsyncNodeTracker() + + output_queue = AsyncOutputQueue( + output_topics=[], # Empty - we'll put events directly in queue + consumer_name="test_consumer", + tracker=tracker, + ) + + events_received = [] + iteration_complete = asyncio.Event() + + async def consumer(): + async for event in output_queue: + events_received.append(event) + iteration_complete.set() + + async def producer(): + # Node A processes - simulate full message lifecycle + await tracker.on_messages_published(1) + await tracker.enter("node_a") + await output_queue.queue.put(Mock(name="event_a")) + await tracker.leave("node_a") + await tracker.on_messages_committed(1) + + # Critical timing window - yield to let consumer check idle state + await asyncio.sleep(0) + + # Node B starts before consumer terminates (if fix works) + # simulate full message lifecycle + await tracker.on_messages_published(1) + await tracker.enter("node_b") + await output_queue.queue.put(Mock(name="event_b")) + await tracker.leave("node_b") + await tracker.on_messages_committed(1) + + consumer_task = asyncio.create_task(consumer()) + producer_task = asyncio.create_task(producer()) + + # Wait for producer to finish + await producer_task + + # Wait a bit for consumer to process + try: + await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) + except asyncio.TimeoutError: + consumer_task.cancel() + try: + await consumer_task + except asyncio.CancelledError: + pass + + # With the fix, we should receive both events + assert len(events_received) == 2, ( + f"Expected 2 events but got {len(events_received)}. " + "Race condition may have caused premature termination." + ) + + @pytest.mark.asyncio + async def test_force_stop_terminates_iteration(self): + """ + Test that force_stop terminates iteration even with uncommitted messages. 
+ """ + tracker = AsyncNodeTracker() + output_queue = AsyncOutputQueue( + output_topics=[], + consumer_name="test_consumer", + tracker=tracker, + ) + + # Publish messages but don't commit them (simulates incomplete work) + await tracker.on_messages_published(5) + + # Not quiescent because uncommitted > 0 + assert not await tracker.is_quiescent() + assert (await tracker.get_metrics())["uncommitted_messages"] == 5 + + # Start iteration in background + events = [] + iteration_complete = asyncio.Event() + + async def iterate(): + async for event in output_queue: + events.append(event) + iteration_complete.set() + + iteration_task = asyncio.create_task(iterate()) + + # Give iteration a chance to start waiting + await asyncio.sleep(0.05) + + # Force stop should terminate iteration + await tracker.force_stop() + + # Wait for iteration to complete + try: + await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) + except asyncio.TimeoutError: + iteration_task.cancel() + pytest.fail("Force stop did not terminate iteration within timeout") + + await iteration_task + assert events == [] + + @pytest.mark.asyncio + async def test_force_stop_yields_queued_events_before_terminating(self): + """ + Test that force_stop yields any queued events before terminating. + """ + tracker = AsyncNodeTracker() + output_queue = AsyncOutputQueue( + output_topics=[], + consumer_name="test_consumer", + tracker=tracker, + ) + + # Simulate work with uncommitted messages + await tracker.on_messages_published(5) + + # Queue some events + await output_queue.queue.put(Mock(name="event_1")) + await output_queue.queue.put(Mock(name="event_2")) + + events = [] + iteration_complete = asyncio.Event() + + async def iterate(): + async for event in output_queue: + events.append(event) + iteration_complete.set() + + iteration_task = asyncio.create_task(iterate()) + + # Give iteration a chance to get the queued events + await asyncio.sleep(0.05) + + # Force stop + await tracker.force_stop() + + # Wait for iteration to complete + await asyncio.wait_for(iteration_complete.wait(), timeout=1.0) + await iteration_task + + # Should have received the queued events before terminating + assert len(events) == 2 diff --git a/tests/workflow/test_event_driven_workflow.py b/tests/workflow/test_event_driven_workflow.py index f6b5405..510dba0 100644 --- a/tests/workflow/test_event_driven_workflow.py +++ b/tests/workflow/test_event_driven_workflow.py @@ -310,10 +310,10 @@ def workflow_for_initial_test(self): return EventDrivenWorkflow(nodes={"test_node": node}) - def test_initial_workflow_method_exists(self, workflow_for_initial_test): - """Test that initial_workflow method exists.""" - assert hasattr(workflow_for_initial_test, "initial_workflow") - assert callable(workflow_for_initial_test.initial_workflow) + def test_init_workflow_method_exists(self, workflow_for_initial_test): + """init_workflow should be available for restoring workflow state.""" + assert hasattr(workflow_for_initial_test, "init_workflow") + assert callable(workflow_for_initial_test.init_workflow) class TestEventDrivenWorkflowToDict: @@ -369,19 +369,26 @@ async def test_tracker_reset_on_init(self, workflow_with_tracker): """Test that tracker is reset on workflow initialization.""" # Add some activity to tracker await workflow_with_tracker._tracker.enter("test_node") - assert not workflow_with_tracker._tracker.is_idle() + assert not await workflow_with_tracker._tracker.is_idle() # Call init_workflow which should reset tracker invoke_context = InvokeContext( 
conversation_id="test", invoke_id="test", assistant_request_id="test" ) - with patch("grafi.common.containers.container.container"): + with patch( + "grafi.workflows.impl.event_driven_workflow.container" + ) as mock_container: + mock_event_store = Mock() + mock_event_store.get_agent_events = AsyncMock(return_value=[]) + mock_event_store.record_events = AsyncMock() + mock_event_store.record_event = AsyncMock() + mock_container.event_store = mock_event_store await workflow_with_tracker.init_workflow( PublishToTopicEvent(invoke_context=invoke_context, data=[]) ) # Tracker should be reset - assert workflow_with_tracker._tracker.is_idle() + assert await workflow_with_tracker._tracker.is_idle() class TestEventDrivenWorkflowStopFlag: diff --git a/tests/workflow/test_utils.py b/tests/workflow/test_utils.py index ebdd7e5..fb8853e 100644 --- a/tests/workflow/test_utils.py +++ b/tests/workflow/test_utils.py @@ -200,6 +200,7 @@ async def test_publish_events(self): # Mock node and topics mock_topic1 = AsyncMock(spec=TopicBase) mock_topic2 = AsyncMock(spec=TopicBase) + tracker = AsyncMock() node = MagicMock(spec=Node) node.name = "test_node" @@ -237,13 +238,16 @@ async def test_publish_events(self): consumed_event_ids=[event.event_id for event in consumed_events], ) - published_events = await publish_events(node, publish_to_event) + published_events = await publish_events(node, publish_to_event, tracker) assert len(published_events) == 1 assert published_events[0] == mock_event1 # Verify topics were called correctly mock_topic1.publish_data.assert_called_once_with(publish_to_event) + tracker.on_messages_published.assert_awaited_once_with( + 1, source="node:test_node" + ) class TestGetNodeInput: diff --git a/tests_integration/agents/run_agents.py b/tests_integration/agents/run_agents.py index 16d3c21..0989203 100644 --- a/tests_integration/agents/run_agents.py +++ b/tests_integration/agents/run_agents.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
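+
+    Example (illustrative):
+        results = run_scripts(collect=True)
+        failed = [r for r in results if r["status"] == "failed"]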
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/embedding_assistant/run_embedding_assistant.py b/tests_integration/embedding_assistant/run_embedding_assistant.py index 30325ab..5f6858f 100644 --- a/tests_integration/embedding_assistant/run_embedding_assistant.py +++ b/tests_integration/embedding_assistant/run_embedding_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/event_store_postgres/run_event_store_postgres.py b/tests_integration/event_store_postgres/run_event_store_postgres.py index fdb4ffd..1b45356 100644 --- a/tests_integration/event_store_postgres/run_event_store_postgres.py +++ b/tests_integration/event_store_postgres/run_event_store_postgres.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/function_assistant/run_function_assistant.py b/tests_integration/function_assistant/run_function_assistant.py index f059cbd..496c367 100644 --- a/tests_integration/function_assistant/run_function_assistant.py +++ b/tests_integration/function_assistant/run_function_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/function_call_assistant/run_function_call_assistant.py b/tests_integration/function_call_assistant/run_function_call_assistant.py index 8765391..ac84de7 100644 --- a/tests_integration/function_call_assistant/run_function_call_assistant.py +++ b/tests_integration/function_call_assistant/run_function_call_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/hith_assistant/kyc_assistant_example.py b/tests_integration/hith_assistant/kyc_assistant_example.py index 9a465ce..4a11bfe 100644 --- a/tests_integration/hith_assistant/kyc_assistant_example.py +++ b/tests_integration/hith_assistant/kyc_assistant_example.py @@ -108,7 +108,7 @@ async def test_kyc_assistant() -> None: ) ] - output = await async_func_wrapper( + outputs = await async_func_wrapper( assistant.invoke( PublishToTopicEvent( invoke_context=get_invoke_context(), @@ -117,7 +117,7 @@ async def test_kyc_assistant() -> None: ) ) - print(output) + print(outputs) human_input = [ Message( @@ -131,6 +131,7 @@ async def test_kyc_assistant() -> None: PublishToTopicEvent( invoke_context=get_invoke_context(), data=human_input, + consumed_event_ids=[event.event_id for event in outputs], ) ) ) diff --git a/tests_integration/hith_assistant/run_hith_assistant.py b/tests_integration/hith_assistant/run_hith_assistant.py index 9db1407..f6cd9d1 100644 --- a/tests_integration/hith_assistant/run_hith_assistant.py +++ b/tests_integration/hith_assistant/run_hith_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. 
Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). """ python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/input_output_topics/run_input_output_topics.py b/tests_integration/input_output_topics/run_input_output_topics.py index 1e663a3..7f39b9c 100644 --- a/tests_integration/input_output_topics/run_input_output_topics.py +++ b/tests_integration/input_output_topics/run_input_output_topics.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/invoke_kwargs/run_invoke_kwargs.py b/tests_integration/invoke_kwargs/run_invoke_kwargs.py index 3fed0d4..a7c2212 100644 --- a/tests_integration/invoke_kwargs/run_invoke_kwargs.py +++ b/tests_integration/invoke_kwargs/run_invoke_kwargs.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/mcp_assistant/run_mcp_assistant.py b/tests_integration/mcp_assistant/run_mcp_assistant.py index 8a06451..1d43a05 100644 --- a/tests_integration/mcp_assistant/run_mcp_assistant.py +++ b/tests_integration/mcp_assistant/run_mcp_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/multimodal_assistant/run_multimodal_assistant.py b/tests_integration/multimodal_assistant/run_multimodal_assistant.py index cfb0830..ff921f1 100644 --- a/tests_integration/multimodal_assistant/run_multimodal_assistant.py +++ b/tests_integration/multimodal_assistant/run_multimodal_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/rag_assistant/run_rag_assistant.py b/tests_integration/rag_assistant/run_rag_assistant.py index 5b04b6c..bc98a15 100644 --- a/tests_integration/rag_assistant/run_rag_assistant.py +++ b/tests_integration/rag_assistant/run_rag_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/react_assistant/run_react_assistant.py b/tests_integration/react_assistant/run_react_assistant.py index 2af2eda..73247a4 100644 --- a/tests_integration/react_assistant/run_react_assistant.py +++ b/tests_integration/react_assistant/run_react_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). 
""" python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/run_all.py b/tests_integration/run_all.py index 9dcdee5..8dcaec9 100644 --- a/tests_integration/run_all.py +++ b/tests_integration/run_all.py @@ -2,13 +2,30 @@ """Run all integration tests by executing run_*.py scripts in each subfolder.""" import argparse +import importlib.util import io -import subprocess import sys from pathlib import Path +from textwrap import indent -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) + + +def _load_runner_module(script: Path): + """Load a run_*.py file as a module so we can call run_scripts directly.""" + module_name = f"tests_integration.{script.parent.name}.{script.stem}_runner" + spec = importlib.util.spec_from_file_location(module_name, script) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to load spec for {script}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module def run_all_scripts(pass_local: bool = True) -> int: @@ -23,12 +40,14 @@ def run_all_scripts(pass_local: bool = True) -> int: """ python_executable = sys.executable current_directory = Path(__file__).parent + repo_root = current_directory.parent # Find all run_*.py scripts in subdirectories run_scripts = sorted(current_directory.glob("*/run_*.py")) - passed_folders = [] - failed_folders = {} + passed_examples = [] + failed_examples = 
{} + skipped_examples = [] print(f"Found {len(run_scripts)} test runners:") for script in run_scripts: @@ -42,37 +61,86 @@ def run_all_scripts(pass_local: bool = True) -> int: print(f"Running tests in: {folder_name}") print(f"{'=' * 60}") - cmd = [python_executable, str(script)] - if not pass_local: - cmd.append("--no-pass-local") - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=True, - cwd=script.parent, + runner_module = _load_runner_module(script) + runner_results = runner_module.run_scripts( + pass_local=pass_local, collect=True ) - print(result.stdout) - passed_folders.append(folder_name) - except subprocess.CalledProcessError as e: - print(f"Output:\n{e.stdout}") - print(f"Error:\n{e.stderr}") - failed_folders[folder_name] = e.stderr + except Exception as exc: # noqa: BLE001 + example_rel = script.relative_to(repo_root) + error_message = f"Runner failed before executing examples: {exc}" + print(f" ✗ {example_rel}") + print(f" Error: {error_message}") + failed_examples[example_rel] = { + "error": error_message, + "output": "", + "rerun_cmd": f"{python_executable} {example_rel}", + } + continue + + if not isinstance(runner_results, list): + example_rel = script.relative_to(repo_root) + error_message = "Runner did not return result details." + print(f" ✗ {example_rel}") + print(f" Error: {error_message}") + failed_examples[example_rel] = { + "error": error_message, + "output": "", + "rerun_cmd": f"{python_executable} {example_rel}", + } + continue + + for result in runner_results: + example_rel = (script.parent / result["name"]).relative_to(repo_root) + status = result.get("status", "unknown") + output = result.get("output", "").rstrip() + error = result.get("error", "").rstrip() + + if status == "passed": + print(f" ✓ {example_rel}") + if output: + print(indent(output, " ")) + passed_examples.append(example_rel) + elif status == "failed": + print(f" ✗ {example_rel}") + if output: + print(" Output:") + print(indent(output, " ")) + if error: + print(" Error:") + print(indent(error, " ")) + rerun_cmd = f"{python_executable} {example_rel}" + print(f" Rerun with: {rerun_cmd}") + failed_examples[example_rel] = { + "error": error, + "output": output, + "rerun_cmd": rerun_cmd, + } + else: + print(f" - {example_rel} (skipped)") + if error: + print(f" Reason: {error}") + skipped_examples.append(example_rel) # Summary print("\n" + "=" * 60) print("FINAL SUMMARY") print("=" * 60) - print(f"\nPassed folders: {len(passed_folders)}") - for folder in passed_folders: - print(f" ✓ {folder}") - - if failed_folders: - print(f"\nFailed folders: {len(failed_folders)}") - for folder in failed_folders: - print(f" ✗ {folder}") + print(f"\nPassed examples: {len(passed_examples)}") + for example in passed_examples: + print(f" ✓ {example}") + + if skipped_examples: + print(f"\nSkipped examples: {len(skipped_examples)}") + for example in skipped_examples: + print(f" - {example}") + + if failed_examples: + print(f"\nFailed examples: {len(failed_examples)}") + for example, data in failed_examples.items(): + print(f" ✗ {example}") + if data.get("rerun_cmd"): + print(f" Rerun with: {data['rerun_cmd']}") return 1 print("\nAll integration tests passed!") diff --git a/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py b/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py index f51edd2..8e6c714 100644 --- a/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py +++ b/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py @@ -7,17 
+7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). """ python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py b/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py index 16c23df..990d5ae 100644 --- a/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py +++ b/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py @@ -7,17 +7,23 @@ from pathlib import Path -sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") +try: + sys.stdout.reconfigure(encoding="utf-8") +except AttributeError: + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, encoding="utf-8", write_through=True + ) -def run_scripts(pass_local: bool = True) -> int: +def run_scripts(pass_local: bool = True, collect: bool = False): """Run all example scripts in this directory. 
Args: pass_local: If True, skip tests with 'ollama' or 'local' in their name. + collect: If True, return per-script results without printing. Returns: - Exit code (0 for success, 1 for failure). + List of per-script results if collect is True, otherwise exit code (0 for success, 1 for failure). """ python_executable = sys.executable current_directory = Path(__file__).parent @@ -25,16 +31,26 @@ def run_scripts(pass_local: bool = True) -> int: # Find all example files example_files = sorted(current_directory.glob("*_example.py")) - passed_scripts = [] - failed_scripts = {} + results = [] for file in example_files: filename = file.name if pass_local and ("ollama" in filename or "_local" in filename): - print(f"Skipping {filename} (local test)") + message = f"Skipping {filename} (local test)" + if not collect: + print(message) + results.append( + { + "name": filename, + "status": "skipped", + "output": "", + "error": message, + } + ) continue - print(f"Running {filename}...") + if not collect: + print(f"Running {filename}...") try: result = subprocess.run( [python_executable, str(file)], @@ -43,23 +59,44 @@ def run_scripts(pass_local: bool = True) -> int: check=True, cwd=current_directory, ) - print(f"Output of {filename}:\n{result.stdout}") - passed_scripts.append(filename) + if not collect: + print(f"Output of {filename}:\n{result.stdout}") + results.append( + { + "name": filename, + "status": "passed", + "output": result.stdout, + "error": "", + } + ) except subprocess.CalledProcessError as e: - print(f"Error running {filename}:\n{e.stderr}") - failed_scripts[filename] = e.stderr + if not collect: + print(f"Error running {filename}:\n{e.stderr}") + results.append( + { + "name": filename, + "status": "failed", + "output": e.stdout, + "error": e.stderr, + } + ) + + if collect: + return results + + passed_scripts = [r for r in results if r["status"] == "passed"] + failed_scripts = [r for r in results if r["status"] == "failed"] - # Summary print("\n" + "=" * 50) print("Summary:") print(f"Passed: {len(passed_scripts)}") for script in passed_scripts: - print(f" ✓ {script}") + print(f" ✓ {script['name']}") if failed_scripts: print(f"\nFailed: {len(failed_scripts)}") for script in failed_scripts: - print(f" ✗ {script}") + print(f" ✗ {script['name']}") return 1 return 0 diff --git a/uv.lock b/uv.lock index 0cc3438..f8079d1 100644 --- a/uv.lock +++ b/uv.lock @@ -1283,7 +1283,7 @@ wheels = [ [[package]] name = "grafi" -version = "0.0.33" +version = "0.0.34" source = { editable = "." } dependencies = [ { name = "anyio" },