diff --git a/dynamiq/cache/utils.py b/dynamiq/cache/utils.py index d07af6840..25331bb25 100644 --- a/dynamiq/cache/utils.py +++ b/dynamiq/cache/utils.py @@ -1,3 +1,4 @@ +import asyncio from functools import wraps from typing import Any, Callable @@ -82,3 +83,65 @@ def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, bool]: return wrapper return _cache + + +def cache_wf_entity_async( + entity_id: str, + cache_enabled: bool = False, + cache_manager_cls: type[WorkflowCacheManager] = WorkflowCacheManager, + cache_config: CacheConfig | None = None, + func_kwargs_to_remove: tuple[str] = FUNC_KWARGS_TO_REMOVE, +) -> Callable: + """Async decorator to cache workflow entity outputs. + + Like cache_wf_entity but wraps an async function. Cache I/O (Redis get/set) + is offloaded to threads via asyncio.to_thread to avoid blocking the event loop. + + Args: + entity_id (str): Identifier for the entity. + cache_enabled (bool): Flag to enable caching. + cache_manager_cls (type[WorkflowCacheManager]): Cache manager class. + cache_config (CacheConfig | None): Cache configuration. + func_kwargs_to_remove (tuple[str]): List of params to remove from callable function kwargs. + + Returns: + Callable: Wrapped async function with caching. + """ + def _cache(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> tuple[Any, bool]: + cache_manager = None + from_cache = False + input_data = kwargs.pop("input_data", args[0] if args else {}) + input_data = dict(input_data) if isinstance(input_data, BaseModel) else input_data + + cleaned_kwargs = {k: v for k, v in kwargs.items() if k not in func_kwargs_to_remove} + if cache_enabled and cache_config: + logger.debug(f"Entity_id {entity_id}: async cache used") + cache_manager = cache_manager_cls(config=cache_config) + output = await asyncio.to_thread( + cache_manager.get_entity_output, + entity_id=entity_id, + input_data=input_data, + **cleaned_kwargs, + ) + if output is not None: + from_cache = True + return output, from_cache + + output = await func(*args, **kwargs) + + if cache_manager: + await asyncio.to_thread( + cache_manager.set_entity_output, + entity_id=entity_id, + input_data=input_data, + output_data=output, + **cleaned_kwargs, + ) + + return output, from_cache + + return wrapper + + return _cache diff --git a/dynamiq/flows/flow.py b/dynamiq/flows/flow.py index 26347f3ff..14c666587 100644 --- a/dynamiq/flows/flow.py +++ b/dynamiq/flows/flow.py @@ -11,6 +11,7 @@ from dynamiq.connections.managers import ConnectionManager from dynamiq.executors.base import BaseExecutor +from dynamiq.executors.context import ContextAwareThreadPoolExecutor from dynamiq.executors.pool import ThreadExecutor from dynamiq.flows.base import BaseFlow from dynamiq.nodes.node import Node, NodeReadyToRun @@ -360,13 +361,8 @@ async def run_async(self, input_data: Any, config: RunnableConfig = None, **kwar """ Run the flow asynchronously with the given input data and configuration. - Args: - input_data (Any): Input data for the flow. - config (RunnableConfig, optional): Configuration for the run. Defaults to None. - **kwargs: Additional keyword arguments. - - Returns: - RunnableResult: Result of the flow execution. + Creates a dedicated ContextAwareThreadPoolExecutor for this flow run, + isolating it from other concurrent flow executions. 
""" self.reset_run_state() run_id = uuid4() @@ -375,6 +371,9 @@ async def run_async(self, input_data: Any, config: RunnableConfig = None, **kwar "parent_run_id": kwargs.get("parent_run_id", run_id), } + max_workers = (config.max_node_workers if config else None) or self.max_node_workers + executor = ContextAwareThreadPoolExecutor(max_workers=max_workers) + logger.info(f"Flow {self.id}: execution started.") self.run_on_flow_start(input_data, config, **merged_kwargs) time_start = datetime.now() @@ -391,6 +390,7 @@ async def run_async(self, input_data: Any, config: RunnableConfig = None, **kwar input_data=node.input_data, depends_result=node.depends_result, config=config, + executor=executor, **(merged_kwargs | {"parent_run_id": run_id}), ) for node in nodes_to_run @@ -435,6 +435,8 @@ async def run_async(self, input_data: Any, config: RunnableConfig = None, **kwar error=RunnableResultError.from_exception(e, failed_nodes=failed_nodes), ) finally: + # wait=False is safe: all node tasks have been awaited via asyncio.gather() + executor.shutdown(wait=False) try: await self._cleanup_dry_run_async(config) except Exception as e: diff --git a/dynamiq/nodes/llms/base.py b/dynamiq/nodes/llms/base.py index 6227e6b30..893d64976 100644 --- a/dynamiq/nodes/llms/base.py +++ b/dynamiq/nodes/llms/base.py @@ -28,6 +28,8 @@ from dynamiq.utils.logger import logger if TYPE_CHECKING: + from concurrent.futures import ThreadPoolExecutor + from litellm import CustomStreamWrapper, ModelResponse @@ -210,6 +212,7 @@ class BaseLLM(ConnectionNode): model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) _completion: Callable = PrivateAttr() + _acompletion: Callable = PrivateAttr() _stream_chunk_builder: Callable = PrivateAttr() _is_fallback_run: bool = PrivateAttr(default=False) _json_schema_fields: ClassVar[list[str]] = ["model", "temperature", "max_tokens", "prompt"] @@ -261,10 +264,11 @@ def __init__(self, **kwargs): super().__init__(**kwargs) # Save a bit of loading time as litellm is slow - from litellm import completion, stream_chunk_builder + from litellm import acompletion, completion, stream_chunk_builder # Avoid the same imports multiple times and for future usage in execute self._completion = completion + self._acompletion = acompletion self._stream_chunk_builder = stream_chunk_builder def init_components(self, connection_manager=None): @@ -499,6 +503,36 @@ def _handle_streaming_completion_response( full_response = self._stream_chunk_builder(chunks=chunks, messages=messages) return self._handle_completion_response(response=full_response, config=config, **kwargs) + async def _handle_streaming_completion_response_async( + self, + response: Union["ModelResponse", "CustomStreamWrapper"], + messages: list[dict], + config: RunnableConfig = None, + **kwargs, + ): + """Handle async streaming completion response. + + Args: + response (ModelResponse | CustomStreamWrapper): The async streaming response from the LLM. + messages (list[dict]): The messages used for the LLM. + config (RunnableConfig, optional): The configuration for the execution. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + dict: A dictionary containing the generated content and tool calls. 
+ """ + chunks = [] + async for chunk in response: + chunks.append(chunk) + self.run_on_node_execute_stream( + config.callbacks, + chunk.model_dump(), + **kwargs, + ) + + full_response = self._stream_chunk_builder(chunks=chunks, messages=messages) + return self._handle_completion_response(response=full_response, config=config, **kwargs) + def _get_response_format_and_tools( self, prompt: Prompt | None = None, @@ -559,43 +593,34 @@ def update_completion_params(self, params: dict[str, Any]) -> dict[str, Any]: params["stream_options"]["include_usage"] = True return params - def execute( + def _build_completion_params( self, - input_data: BaseLLMInputSchema, - config: RunnableConfig = None, + messages: list[dict], + config: RunnableConfig, prompt: Prompt | None = None, tools: list[Tool | dict] | None = None, response_format: dict[str, Any] | None = None, parallel_tool_calls: bool | None = None, - **kwargs, - ): - """Execute the LLM node. - - This method processes the input data, formats the prompt, and generates a response using - the configured LLM. + include_sync_client: bool = True, + ) -> dict[str, Any]: + """Build the common parameter dict for litellm completion/acompletion calls. Args: - input_data (BaseLLMInputSchema): The input data for the LLM. - config (RunnableConfig, optional): The configuration for the execution. Defaults to None. - prompt (Prompt, optional): The prompt to use for this execution. Defaults to None. - tools (list[Tool|dict]): List of tools that llm can call. - response_format (dict[str, Any]): JSON schema that specifies the structure of the llm's output - parallel_tool_calls (bool | None): Whether to allow the LLM to return multiple tool calls - in a single response. None means provider decides. - **kwargs: Additional keyword arguments. + messages: Formatted prompt messages. + config: Runnable configuration (used to detect streaming callbacks). + prompt: Prompt with optional tools/response_format overrides. + tools: Explicit tool list override. + response_format: Explicit response format override. + parallel_tool_calls: Whether to allow parallel tool calls. + include_sync_client: If True and self.client exists, include it in params. + Set to False for async calls that should not receive the sync client. Returns: - dict: A dictionary containing the generated content and tool calls. + Dict of params ready to pass to _completion or _acompletion. """ - config = ensure_config(config) - self.reset_run_state() - prompt = prompt or self.prompt or Prompt(messages=[], tools=None, response_format=None) - messages = self.get_messages(prompt, input_data) - self.run_on_node_execute_run(callbacks=config.callbacks, prompt_messages=messages, **kwargs) - extra = copy.deepcopy(self.__pydantic_extra__) params = self.connection.conn_params.copy() - if self.client and not isinstance(self.connection, HttpApiKey): + if include_sync_client and self.client and not isinstance(self.connection, HttpApiKey): params.update({"client": self.client}) if self.thinking_enabled: params.update({"thinking": {"type": "enabled", "budget_tokens": self.budget_tokens}}) @@ -607,8 +632,8 @@ def execute( tools=tools, response_format=response_format, ) - # Check if a streaming callback is available in the config and enable streaming only if it is - # This is to avoid unnecessary streaming to reduce CPU usage + # Check if a streaming callback is available in the config and enable streaming only if it is. + # This is to avoid unnecessary streaming to reduce CPU usage. 
is_streaming_callback_available = any( isinstance(callback, BaseStreamingCallbackHandler) for callback in config.callbacks ) @@ -632,12 +657,56 @@ def execute( if parallel_tool_calls is not None: common_params["parallel_tool_calls"] = parallel_tool_calls - common_params = self.update_completion_params(common_params) + return self.update_completion_params(common_params) + + def execute( + self, + input_data: BaseLLMInputSchema, + config: RunnableConfig = None, + prompt: Prompt | None = None, + tools: list[Tool | dict] | None = None, + response_format: dict[str, Any] | None = None, + parallel_tool_calls: bool | None = None, + **kwargs, + ): + """Execute the LLM node. + + This method processes the input data, formats the prompt, and generates a response using + the configured LLM. + + Args: + input_data (BaseLLMInputSchema): The input data for the LLM. + config (RunnableConfig, optional): The configuration for the execution. Defaults to None. + prompt (Prompt, optional): The prompt to use for this execution. Defaults to None. + tools (list[Tool|dict]): List of tools that llm can call. + response_format (dict[str, Any]): JSON schema that specifies the structure of the llm's output + parallel_tool_calls (bool | None): Whether to allow the LLM to return multiple tool calls + in a single response. None means provider decides. + **kwargs: Additional keyword arguments. + + Returns: + dict: A dictionary containing the generated content and tool calls. + """ + config = ensure_config(config) + self.reset_run_state() + prompt = prompt or self.prompt or Prompt(messages=[], tools=None, response_format=None) + messages = self.get_messages(prompt, input_data) + self.run_on_node_execute_run(callbacks=config.callbacks, prompt_messages=messages, **kwargs) + + common_params = self._build_completion_params( + messages=messages, + config=config, + prompt=prompt, + tools=tools, + response_format=response_format, + parallel_tool_calls=parallel_tool_calls, + include_sync_client=True, + ) response = self._completion(**common_params) handle_completion = ( self._handle_streaming_completion_response - if self.streaming.enabled and is_streaming_callback_available + if common_params.get("stream") else self._handle_completion_response ) @@ -645,6 +714,60 @@ def execute( response=response, messages=messages, config=config, input_data=dict(input_data), **kwargs ) + async def execute_async( + self, + input_data: BaseLLMInputSchema, + config: RunnableConfig = None, + prompt: Prompt | None = None, + tools: list[Tool | dict] | None = None, + response_format: dict[str, Any] | None = None, + parallel_tool_calls: bool | None = None, + **kwargs, + ): + """Execute the LLM node asynchronously using litellm.acompletion. + + This method mirrors execute() but uses await self._acompletion(...) + and async streaming iteration. + + Args: + input_data (BaseLLMInputSchema): The input data for the LLM. + config (RunnableConfig, optional): The configuration for the execution. Defaults to None. + prompt (Prompt, optional): The prompt to use for this execution. Defaults to None. + tools (list[Tool|dict]): List of tools that llm can call. + response_format (dict[str, Any]): JSON schema that specifies the structure of the llm's output. + parallel_tool_calls (bool | None): Whether to allow the LLM to return multiple tool calls + in a single response. None means provider decides. + **kwargs: Additional keyword arguments. + + Returns: + dict: A dictionary containing the generated content and tool calls. 
+ """ + config = ensure_config(config) + self.reset_run_state() + prompt = prompt or self.prompt or Prompt(messages=[], tools=None, response_format=None) + messages = self.get_messages(prompt, input_data) + self.run_on_node_execute_run(callbacks=config.callbacks, prompt_messages=messages, **kwargs) + + common_params = self._build_completion_params( + messages=messages, + config=config, + prompt=prompt, + tools=tools, + response_format=response_format, + parallel_tool_calls=parallel_tool_calls, + include_sync_client=False, + ) + + response = await self._acompletion(**common_params) + + if common_params.get("stream"): + return await self._handle_streaming_completion_response_async( + response=response, messages=messages, config=config, input_data=dict(input_data), **kwargs + ) + return self._handle_completion_response( + response=response, messages=messages, config=config, input_data=dict(input_data), **kwargs + ) + def _is_rate_limit_error(self, exception_type: type[Exception], error_str: str) -> bool: """Check if the error is a rate limit error. @@ -772,3 +895,78 @@ def run_sync( logger.error(f"LLM {self.name} - {self.id}: Fallback LLM ({fallback_llm.model}) failed.") return result + + async def run_async( + self, + input_data: dict, + config: RunnableConfig = None, + depends_result: dict = None, + executor: "ThreadPoolExecutor | None" = None, + **kwargs, + ) -> RunnableResult: + """Run the LLM asynchronously with fallback support. + + If the primary LLM fails and a fallback is configured, the primary failure + is traced first, then the fallback LLM is executed separately. + + The fallback receives the same transformed input that the primary received, + and the primary's output_transformer is applied to the fallback's output. + + Args: + input_data: Input data for the LLM. + config: Configuration for the run. + depends_result: Results of dependent nodes. + executor: Optional thread pool executor for sync fallback. + **kwargs: Additional keyword arguments. + + Returns: + RunnableResult: Result of the LLM execution. + """ + result = await super().run_async( + input_data=input_data, config=config, depends_result=depends_result, + executor=executor, **kwargs + ) + + if result.status != RunnableStatus.FAILURE: + return result + + if not self.fallback or not self.fallback.llm: + return result + + if not result.error: + return result + + if not self._should_trigger_fallback(result.error.type, result.error.message): + return result + + fallback_llm = self.fallback.llm + fallback_llm._is_fallback_run = True + logger.warning( + f"LLM {self.name} - {self.id}: Primary LLM ({self.model}) failed. " + f"Error: {result.error.type.__name__}: {result.error.message}. 
" + f"Attempting fallback to {fallback_llm.name} - {fallback_llm.id}" + ) + + fallback_kwargs = {k: v for k, v in kwargs.items() if k != "run_depends"} + fallback_kwargs["parent_run_id"] = kwargs.get("parent_run_id") + + fallback_input = result.input.model_dump() if hasattr(result.input, "model_dump") else result.input + fallback_result = await fallback_llm.run_async( + input_data=fallback_input, + config=config, + depends_result=None, + executor=executor, + **fallback_kwargs, + ) + + if fallback_result.status == RunnableStatus.SUCCESS: + logger.info(f"LLM {self.name} - {self.id}: Fallback LLM ({fallback_llm.model}) succeeded") + transformed_output = self.transform_output(fallback_result.output, config=config, **kwargs) + return RunnableResult( + status=RunnableStatus.SUCCESS, + input=result.input, + output=transformed_output, + ) + + logger.error(f"LLM {self.name} - {self.id}: Fallback LLM ({fallback_llm.model}) failed.") + return result diff --git a/dynamiq/nodes/node.py b/dynamiq/nodes/node.py index a0b3d0c1a..5e298ec0e 100644 --- a/dynamiq/nodes/node.py +++ b/dynamiq/nodes/node.py @@ -1,5 +1,7 @@ import asyncio +import contextvars import copy +import functools import inspect import time from abc import ABC, abstractmethod @@ -8,13 +10,13 @@ from functools import cached_property from queue import Empty from types import FunctionType, ModuleType -from typing import Any, Callable, ClassVar, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Union from uuid import uuid4 from jinja2 import Template from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, computed_field, create_model, model_validator -from dynamiq.cache.utils import cache_wf_entity +from dynamiq.cache.utils import cache_wf_entity, cache_wf_entity_async from dynamiq.callbacks import BaseCallbackHandler, NodeCallbackHandler, TracingCallbackHandler from dynamiq.connections import BaseConnection from dynamiq.connections.managers import ConnectionManager, ConnectionManagerException @@ -47,6 +49,9 @@ from dynamiq.utils.logger import logger from dynamiq.utils.utils import clear_annotation +if TYPE_CHECKING: + from concurrent.futures import ThreadPoolExecutor + def ensure_config(config: RunnableConfig = None) -> RunnableConfig: """ @@ -267,6 +272,18 @@ class Node(BaseModel, Runnable, DryRunMixin, ABC): _output_references: NodeOutputReferences = PrivateAttr() + # Set to True in subclasses that manage their own background event loop + # (e.g. CuaDesktopTool, E2BDesktopTool) to force run_async to offload + # via run_in_executor instead of calling execute_async on the main loop. + _force_thread_executor: ClassVar[bool] = False + + @property + def has_native_async(self) -> bool: + """Check if the subclass provides a native async execute implementation.""" + if self._force_thread_executor: + return False + return type(self).execute_async is not Node.execute_async + model_config = ConfigDict(arbitrary_types_allowed=True) input_schema: type[BaseModel] | None = None callbacks: list[NodeCallbackHandler] = [] @@ -476,8 +493,6 @@ def validate_input_schema(self, input_data: dict[str, Any], **kwargs) -> dict[st Raises: NodeException: If input data does not match the input schema. 
""" - from dynamiq.nodes.agents.exceptions import RecoverableAgentException - if self.input_schema: try: return self.input_schema.model_validate( @@ -485,6 +500,8 @@ def validate_input_schema(self, input_data: dict[str, Any], **kwargs) -> dict[st ) except Exception as e: if kwargs.get("recoverable_error", False): + from dynamiq.nodes.agents.exceptions import RecoverableAgentException + raise RecoverableAgentException(f"Input data validation failed: {e}") raise e @@ -841,6 +858,102 @@ def get_approved_data_or_origin( return input_data + def _prepare_execution( + self, + input_data: dict, + config: RunnableConfig, + depends_result: dict | None, + **kwargs, + ) -> tuple[RunnableConfig, dict, dict]: + """Shared pre-execution setup for run_sync and _run_async_native. + + Returns: + Tuple of (config, merged_kwargs, depends_result). + """ + config = ensure_config(config) + run_id = uuid4() + merged_kwargs = merge(kwargs, {"run_id": run_id, "parent_run_id": kwargs.get("parent_run_id", None)}) + if depends_result is None: + depends_result = {} + return config, merged_kwargs, depends_result + + def _handle_skip( + self, + e: "NodeException", + input_data: dict, + depends_result: dict, + config: RunnableConfig, + **merged_kwargs, + ) -> RunnableResult: + """Handle node skip due to failed dependency or approval rejection.""" + transformed_input = input_data | { + k: result.to_tracing_depend_dict() for k, result in depends_result.items() + } + skip_data = {"failed_dependency": e.failed_depend.to_dict(for_tracing=True)} + self.run_on_node_skip( + callbacks=config.callbacks, + skip_data=skip_data, + input_data=transformed_input, + human_feedback=getattr(e, "human_feedback", None), + **merged_kwargs, + ) + logger.info(f"Node {self.name} - {self.id}: execution skipped.") + return RunnableResult( + status=RunnableStatus.SKIP, + input=transformed_input, + output=None, + error=RunnableResultError.from_exception(e, recoverable=e.recoverable), + ) + + def _handle_success( + self, + output: Any, + from_cache: bool, + transformed_input: dict, + config: RunnableConfig, + time_start: datetime, + log_prefix: str, + merged_kwargs: dict, + **kwargs, + ) -> RunnableResult: + """Handle successful execution — transform output, fire callbacks, return result.""" + callback_kwargs = {**merged_kwargs, "is_output_from_cache": from_cache} + # transform_output uses original kwargs; run_on_node_end needs merged_kwargs with run_id + transformed_output = self.transform_output(output, config=config, **kwargs) + self.run_on_node_end(config.callbacks, transformed_output, **callback_kwargs) + logger.info( + f"Node {self.name} - {self.id}: {log_prefix}execution succeeded in " + f"{format_duration(time_start, datetime.now())}." + ) + return RunnableResult( + status=RunnableStatus.SUCCESS, input=dict(transformed_input), output=transformed_output + ) + + def _handle_failure( + self, + e: Exception, + transformed_input: dict, + config: RunnableConfig, + time_start: datetime, + log_prefix: str, + **merged_kwargs, + ) -> RunnableResult: + """Handle execution failure — fire error callbacks, return failure result.""" + from dynamiq.nodes.agents.exceptions import RecoverableAgentException + + self.run_on_node_error(callbacks=config.callbacks, error=e, input_data=transformed_input, **merged_kwargs) + logger.error( + f"Node {self.name} - {self.id}: {log_prefix}execution failed in " + f"{format_duration(time_start, datetime.now())}. 
{e}" + ) + recoverable = isinstance(e, RecoverableAgentException) + return RunnableResult( + status=RunnableStatus.FAILURE, + input=transformed_input, + output=None, + error=RunnableResultError.from_exception(e, recoverable=recoverable), + ) + def run_sync( self, input_data: dict, @@ -860,42 +973,19 @@ def run_sync( Returns: RunnableResult: Result of the node execution. """ - from dynamiq.nodes.agents.exceptions import RecoverableAgentException - logger.info(f"Node {self.name} - {self.id}: execution started.") transformed_input = input_data time_start = datetime.now() - - config = ensure_config(config) - - run_id = uuid4() - merged_kwargs = merge(kwargs, {"run_id": run_id, "parent_run_id": kwargs.get("parent_run_id", None)}) - if depends_result is None: - depends_result = {} + config, merged_kwargs, depends_result = self._prepare_execution( + input_data, config, depends_result, **kwargs + ) try: try: self.validate_depends(depends_result) input_data = self.get_approved_data_or_origin(input_data, config=config, **merged_kwargs) except NodeException as e: - transformed_input = input_data | { - k: result.to_tracing_depend_dict() for k, result in depends_result.items() - } - skip_data = {"failed_dependency": e.failed_depend.to_dict(for_tracing=True)} - self.run_on_node_skip( - callbacks=config.callbacks, - skip_data=skip_data, - input_data=transformed_input, - human_feedback=getattr(e, "human_feedback", None), - **merged_kwargs, - ) - logger.info(f"Node {self.name} - {self.id}: execution skipped.") - return RunnableResult( - status=RunnableStatus.SKIP, - input=transformed_input, - output=None, - error=RunnableResultError.from_exception(e, recoverable=e.recoverable), - ) + return self._handle_skip(e, input_data, depends_result, config, **merged_kwargs) transformed_input = self.validate_input_schema( self.transform_input(input_data=input_data, depends_result=depends_result, config=config, **kwargs), @@ -910,57 +1000,105 @@ def run_sync( output, from_cache = cache(self.execute_with_retry)(transformed_input, config, **merged_kwargs) - merged_kwargs["is_output_from_cache"] = from_cache - transformed_output = self.transform_output(output, config=config, **kwargs) + return self._handle_success( + output, from_cache, transformed_input, config, time_start, "", + merged_kwargs=merged_kwargs, **kwargs + ) + except Exception as e: + return self._handle_failure(e, transformed_input, config, time_start, "", **merged_kwargs) + + async def _run_async_native( + self, + input_data: dict, + config: RunnableConfig = None, + depends_result: dict = None, + **kwargs, + ) -> RunnableResult: + """ + Run the node asynchronously using native async execute. + Mirrors run_sync() lifecycle but calls execute_async_with_retry(). 
+ """ + logger.info(f"Node {self.name} - {self.id}: async execution started.") + transformed_input = input_data + time_start = datetime.now() + config, merged_kwargs, depends_result = self._prepare_execution( + input_data, config, depends_result, **kwargs + ) - self.run_on_node_end(config.callbacks, transformed_output, **merged_kwargs) + try: + try: + self.validate_depends(depends_result) + # Offload blocking approval queue read to a thread to avoid blocking the event loop + input_data = await asyncio.to_thread( + self.get_approved_data_or_origin, input_data, config=config, **merged_kwargs + ) + except NodeException as e: + return self._handle_skip(e, input_data, depends_result, config, **merged_kwargs) - logger.info( - f"Node {self.name} - {self.id}: execution succeeded in " - f"{format_duration(time_start, datetime.now())}." + transformed_input = self.validate_input_schema( + self.transform_input(input_data=input_data, depends_result=depends_result, config=config, **kwargs), + **kwargs, ) - return RunnableResult( - status=RunnableStatus.SUCCESS, input=dict(transformed_input), output=transformed_output + self.run_on_node_start(config.callbacks, dict(transformed_input), **merged_kwargs) + + cache = cache_wf_entity_async( + entity_id=self.id, + cache_enabled=self.caching.enabled, + cache_config=config.cache, ) - except Exception as e: - self.run_on_node_error(callbacks=config.callbacks, error=e, input_data=transformed_input, **merged_kwargs) - logger.error( - f"Node {self.name} - {self.id}: execution failed in " - f"{format_duration(time_start, datetime.now())}. {e}" + output, from_cache = await cache(self.execute_async_with_retry)( + transformed_input, config, **merged_kwargs ) - recoverable = isinstance(e, RecoverableAgentException) - result = RunnableResult( - status=RunnableStatus.FAILURE, - input=transformed_input, - output=None, - error=RunnableResultError.from_exception(e, recoverable=recoverable), + return self._handle_success( + output, from_cache, transformed_input, config, time_start, "async ", + merged_kwargs=merged_kwargs, **kwargs ) - return result + except Exception as e: + return self._handle_failure(e, transformed_input, config, time_start, "async ", **merged_kwargs) async def run_async( self, input_data: dict, config: RunnableConfig = None, depends_result: dict = None, + executor: "ThreadPoolExecutor | None" = None, **kwargs, ) -> RunnableResult: """ Run the node asynchronously with given input data and configuration. - This runs the synchronous implementation in a thread pool to avoid blocking the event loop. + + If the node has a native async execute implementation (has_native_async), + runs directly on the event loop. Otherwise, offloads sync execution to + the provided executor (or the default asyncio executor if None). Args: input_data (Any): Input data for the node. - config (RunnableConfig, optional): Configuration for the run. Defaults to None. - depends_result (dict, optional): Results of dependent nodes. Defaults to None. + config (RunnableConfig, optional): Configuration for the run. + depends_result (dict, optional): Results of dependent nodes. + executor (ThreadPoolExecutor, optional): Thread pool executor for sync fallback. **kwargs: Additional keyword arguments. Returns: RunnableResult: Result of the node execution. 
""" - return await asyncio.to_thread( - self.run_sync, input_data=input_data, config=config, depends_result=depends_result, **kwargs - ) + if self.has_native_async: + return await self._run_async_native( + input_data=input_data, config=config, depends_result=depends_result, **kwargs + ) + else: + loop = asyncio.get_running_loop() + fn = functools.partial( + self.run_sync, input_data=input_data, config=config, + depends_result=depends_result, **kwargs + ) + # ContextAwareThreadPoolExecutor already propagates contextvars + # in its submit() method — no need to copy context again. + if isinstance(executor, ContextAwareThreadPoolExecutor): + return await loop.run_in_executor(executor, fn) + else: + ctx = contextvars.copy_context() + return await loop.run_in_executor(executor, ctx.run, fn) def ensure_client(self) -> None: """ @@ -1092,6 +1230,74 @@ def execute_with_timeout( future.cancel() raise + async def execute_async_with_retry( + self, input_data: dict[str, Any] | BaseModel, config: RunnableConfig = None, **kwargs + ): + """ + Execute the node asynchronously with retry logic. + Uses asyncio.wait_for for timeout instead of thread-based timeout. + Uses asyncio.sleep for non-blocking retry backoff. + """ + config = ensure_config(config) + timeout = self.error_handling.timeout_seconds + error = None + n_attempt = self.error_handling.max_retries + 1 + + for attempt in range(n_attempt): + merged_kwargs = merge(kwargs, {"execution_run_id": uuid4()}) + + try: + # Offload blocking client initialization to a thread to avoid blocking the event loop + await asyncio.to_thread(self.ensure_client) + except Exception as conn_error: + logger.error( + f"Node {self.name} - {self.id}: Failed to ensure client connection: {conn_error}" + ) + error = conn_error + if attempt < n_attempt - 1: + time_to_sleep = self.error_handling.retry_interval_seconds * ( + self.error_handling.backoff_rate ** attempt + ) + logger.info( + f"Node {self.name} - {self.id}: retrying connection in {time_to_sleep} seconds." 
+ ) + await asyncio.sleep(time_to_sleep) + continue + else: + raise + + self.run_on_node_execute_start(config.callbacks, input_data, **merged_kwargs) + + try: + if timeout is not None: + output = await asyncio.wait_for( + self.execute_async(input_data=input_data, config=config, **merged_kwargs), + timeout=timeout, + ) + else: + output = await self.execute_async(input_data=input_data, config=config, **merged_kwargs) + + self.run_on_node_execute_end(config.callbacks, output, **merged_kwargs) + return output + except asyncio.TimeoutError as e: + error = e + self.run_on_node_execute_error(config.callbacks, error, **merged_kwargs) + logger.warning(f"Node {self.name} - {self.id}: timeout.") + except Exception as e: + error = e + self.run_on_node_execute_error(config.callbacks, error, **merged_kwargs) + logger.error(f"Node {self.name} - {self.id}: execution error: {e}") + + if attempt < n_attempt - 1: + time_to_sleep = self.error_handling.retry_interval_seconds * ( + self.error_handling.backoff_rate ** attempt + ) + logger.info(f"Node {self.name} - {self.id}: retrying in {time_to_sleep} seconds.") + await asyncio.sleep(time_to_sleep) + + logger.error(f"Node {self.name} - {self.id}: execution failed after {n_attempt} attempts.") + raise error + def get_context_for_input_schema(self) -> dict: """Provides context for input schema that is required for proper validation.""" return {} @@ -1396,6 +1602,15 @@ def execute(self, input_data: dict[str, Any] | BaseModel, config: RunnableConfig """ pass + async def execute_async( + self, input_data: dict[str, Any] | BaseModel, config: RunnableConfig = None, **kwargs + ) -> Any: + """ + Async execution of the node. Override in subclasses for native async support. + Returns NotImplemented to signal fallback to sync execute() in a thread. + """ + return NotImplemented + def depends_on(self, nodes: Union["Node", list["Node"]], condition: ChoiceCondition | None = None) -> "Node": """ Add dependencies for this node. Accepts either a single node or a list of nodes. 
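A minimal sketch of the node-level async contract introduced by the node.py changes above, assuming only the public pieces exercised by the new unit tests further down (the subclass, its name, and the input values are illustrative, not part of the patch):

import asyncio

from dynamiq.nodes.node import Node
from dynamiq.nodes.types import NodeGroup
from dynamiq.runnables import RunnableConfig


class MyAsyncNode(Node):
    group: NodeGroup = NodeGroup.UTILS
    name: str = "MyAsyncNode"

    def execute(self, input_data, config=None, **kwargs):
        # Sync path: used by run_sync, and by run_async when the node is
        # offloaded to the per-flow thread pool executor.
        return {"result": "sync"}

    async def execute_async(self, input_data, config=None, **kwargs):
        # Overriding execute_async flips has_native_async to True, so
        # run_async awaits this directly on the event loop instead of
        # submitting run_sync to an executor.
        await asyncio.sleep(0)
        return {"result": "async"}


async def main():
    node = MyAsyncNode()
    result = await node.run_async(input_data={}, config=RunnableConfig(callbacks=[]))
    assert result.output == {"result": "async"}


asyncio.run(main())

Nodes that manage their own background event loop (CuaDesktopTool and E2BDesktopTool in the diffs below) instead set _force_thread_executor = True, so run_async keeps offloading them to a thread.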
diff --git a/dynamiq/nodes/tools/cua_desktop/cua_desktop.py b/dynamiq/nodes/tools/cua_desktop/cua_desktop.py index 47ba7b39b..5e7605c97 100644 --- a/dynamiq/nodes/tools/cua_desktop/cua_desktop.py +++ b/dynamiq/nodes/tools/cua_desktop/cua_desktop.py @@ -320,6 +320,7 @@ class CuaDesktopTool(ConnectionNode): input_schema: ClassVar[type[CuaDesktopToolInputSchema]] = CuaDesktopToolInputSchema timeout: int = 3600 is_files_allowed: bool = True + _force_thread_executor: ClassVar[bool] = True _computer: Any | None = PrivateAttr(default=None) _loop = PrivateAttr(default=None) diff --git a/dynamiq/nodes/tools/e2b_desktop/e2b_desktop.py b/dynamiq/nodes/tools/e2b_desktop/e2b_desktop.py index 018c897e3..4a073a099 100644 --- a/dynamiq/nodes/tools/e2b_desktop/e2b_desktop.py +++ b/dynamiq/nodes/tools/e2b_desktop/e2b_desktop.py @@ -253,6 +253,7 @@ class E2BDesktopTool(ConnectionNode): input_schema: ClassVar[type[E2BDesktopToolInputSchema]] = E2BDesktopToolInputSchema timeout: int = 3600 is_files_allowed: bool = True + _force_thread_executor: ClassVar[bool] = True _desktop: Sandbox | None = PrivateAttr(default=None) _sandbox_id: str | None = PrivateAttr(default=None) diff --git a/tests/conftest.py b/tests/conftest.py index 38bfe46a7..391c948a7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,6 +55,11 @@ def response(stream: bool, *args, **kwargs): return model_r mock_llm = mocker.patch("dynamiq.nodes.llms.base.BaseLLM._completion", side_effect=response) + + async def async_response(*args, **kwargs): + return mock_llm(*args, **kwargs) + + mocker.patch("dynamiq.nodes.llms.base.BaseLLM._acompletion", side_effect=async_response) yield mock_llm diff --git a/tests/integration/flows/test_flow.py b/tests/integration/flows/test_flow.py index 26e15a5bb..af99605b4 100644 --- a/tests/integration/flows/test_flow.py +++ b/tests/integration/flows/test_flow.py @@ -382,40 +382,39 @@ async def test_workflow_with_depend_nodes_with_tracing_async( assert mock_llm_executor.call_count == 2 assert mock_llm_executor.call_args_list == [ mock.call( - tools=None, - tool_choice=None, model=openai_node.model, messages=expected_openai_messages, stream=False, temperature=openai_node.temperature, max_tokens=None, + tools=None, + tool_choice=None, stop=None, + top_p=None, seed=None, presence_penalty=None, frequency_penalty=None, - top_p=None, - api_key=openai_node.connection.api_key, - client=ANY, response_format=None, drop_params=True, api_base="https://api.openai.com/v1", + api_key=openai_node.connection.api_key, ), mock.call( - tools=None, - tool_choice=None, model=anthropic_node_with_dependency.model, messages=expected_anthropic_messages, stream=False, temperature=anthropic_node_with_dependency.temperature, max_tokens=None, + tools=None, + tool_choice=None, stop=None, + top_p=None, seed=None, presence_penalty=None, frequency_penalty=None, - top_p=None, - api_key=anthropic_node_with_dependency.connection.api_key, response_format=None, drop_params=True, + api_key=anthropic_node_with_dependency.connection.api_key, ), ] @@ -701,25 +700,25 @@ async def test_workflow_with_depend_nodes_and_depend_fail_async( assert response.error.failed_nodes[0].name == openai_node.name assert mock_llm_executor.call_count == 1 + # Async path uses _acompletion which does not pass `client` assert mock_llm_executor.call_args_list == [ mock.call( - tools=None, - tool_choice=None, model=openai_node.model, messages=expected_openai_messages, stream=False, temperature=openai_node.temperature, - api_key=openai_node.connection.api_key, - client=ANY, 
max_tokens=None, + tools=None, + tool_choice=None, stop=None, + top_p=None, seed=None, presence_penalty=None, frequency_penalty=None, - top_p=None, response_format=None, drop_params=True, api_base="https://api.openai.com/v1", + api_key=openai_node.connection.api_key, ) ] @@ -988,42 +987,42 @@ async def test_workflow_with_conditional_depend_nodes_with_tracing_async( assert response == RunnableResult(status=RunnableStatus.SUCCESS, input=input_data, output=expected_output) assert mock_llm_executor.call_count == 2 + # Async path uses _acompletion which does not pass `client` assert mock_llm_executor.call_args_list == [ mock.call( - tools=None, - tool_choice=None, model=openai_node_with_return_behavior.model, messages=expected_openai_messages, stream=False, temperature=openai_node_with_return_behavior.temperature, max_tokens=None, + tools=None, + tool_choice=None, stop=None, + top_p=None, seed=None, presence_penalty=None, frequency_penalty=None, - top_p=None, - api_key=openai_node_with_return_behavior.connection.api_key, - client=ANY, response_format=None, drop_params=True, api_base="https://api.openai.com/v1", + api_key=openai_node_with_return_behavior.connection.api_key, ), mock.call( - tools=None, - tool_choice=None, model=anthropic_node_with_success_status_conditional_depend.model, messages=expected_anthropic_messages, stream=False, temperature=anthropic_node_with_success_status_conditional_depend.temperature, max_tokens=None, + tools=None, + tool_choice=None, stop=None, + top_p=None, seed=None, presence_penalty=None, frequency_penalty=None, - top_p=None, - api_key=anthropic_node_with_success_status_conditional_depend.connection.api_key, response_format=None, drop_params=True, + api_key=anthropic_node_with_success_status_conditional_depend.connection.api_key, ), ] diff --git a/tests/integration/nodes/tools/test_mcp_tool.py b/tests/integration/nodes/tools/test_mcp_tool.py index 140470ec9..5064cef62 100644 --- a/tests/integration/nodes/tools/test_mcp_tool.py +++ b/tests/integration/nodes/tools/test_mcp_tool.py @@ -146,7 +146,7 @@ async def test_mock_tool_execute(mcp_server_tool): mock_exec.assert_called_once() assert result == mocked_result - with patch("dynamiq.nodes.tools.mcp.MCPTool.execute", return_value=mocked_result) as mock_exec: + with patch.object(MCPTool, "execute_async", new_callable=AsyncMock, return_value=mocked_result) as mock_exec: result = await tool.run(input_data={"a": 20, "b": 22}) mock_exec.assert_called_once() assert result.output == mocked_result diff --git a/tests/unit/cache/test_cache_utils_async.py b/tests/unit/cache/test_cache_utils_async.py new file mode 100644 index 000000000..c916f377b --- /dev/null +++ b/tests/unit/cache/test_cache_utils_async.py @@ -0,0 +1,68 @@ +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from dynamiq.cache.utils import cache_wf_entity_async + + +class TestCacheWfEntityAsync: + @pytest.mark.asyncio + async def test_cache_miss_calls_async_func(self): + """On cache miss, the async wrapper should await the wrapped coroutine.""" + async def my_async_func(*args, **kwargs): + return {"result": "computed"} + + cache = cache_wf_entity_async( + entity_id="node-1", + cache_enabled=False, + ) + wrapped = cache(my_async_func) + output, from_cache = await wrapped({"key": "val"}, config=None) + + assert output == {"result": "computed"} + assert from_cache is False + + @pytest.mark.asyncio + async def test_cache_hit_returns_cached(self): + """On cache hit, should return cached output without calling the function.""" + mock_func = 
AsyncMock(return_value={"result": "computed"}) + + mock_cache_manager = MagicMock() + mock_cache_manager.get_entity_output.return_value = {"result": "cached"} + mock_cls = MagicMock(return_value=mock_cache_manager) + + cache = cache_wf_entity_async( + entity_id="node-1", + cache_enabled=True, + cache_manager_cls=mock_cls, + cache_config=MagicMock(), + ) + wrapped = cache(mock_func) + output, from_cache = await wrapped({"key": "val"}, config=None) + + assert output == {"result": "cached"} + assert from_cache is True + mock_func.assert_not_called() + + @pytest.mark.asyncio + async def test_cache_miss_stores_result(self): + """On cache miss with caching enabled, should store the result.""" + async def my_async_func(*args, **kwargs): + return {"result": "computed"} + + mock_cache_manager = MagicMock() + mock_cache_manager.get_entity_output.return_value = None + mock_cls = MagicMock(return_value=mock_cache_manager) + + cache = cache_wf_entity_async( + entity_id="node-1", + cache_enabled=True, + cache_manager_cls=mock_cls, + cache_config=MagicMock(), + ) + wrapped = cache(my_async_func) + output, from_cache = await wrapped({"key": "val"}, config=None) + + assert output == {"result": "computed"} + assert from_cache is False + mock_cache_manager.set_entity_output.assert_called_once() diff --git a/tests/unit/nodes/llms/test_llm_async.py b/tests/unit/nodes/llms/test_llm_async.py new file mode 100644 index 000000000..fe35e9575 --- /dev/null +++ b/tests/unit/nodes/llms/test_llm_async.py @@ -0,0 +1,199 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from dynamiq.connections import OpenAI as OpenAIConnection +from dynamiq.nodes.llms.openai import OpenAI +from dynamiq.prompts import Prompt +from dynamiq.nodes.llms.base import FallbackConfig +from dynamiq.runnables import RunnableConfig, RunnableStatus + + +def make_mock_response(content="test response"): + """Create a mock litellm ModelResponse.""" + choice = MagicMock() + choice.message.content = content + choice.message.tool_calls = None + response = MagicMock() + response.choices = [choice] + return response + + +class TestBaseLLMAsync: + def test_base_llm_has_native_async(self): + """BaseLLM should report has_native_async=True after we add execute_async.""" + with patch("litellm.completion"), \ + patch("litellm.stream_chunk_builder"): + node = OpenAI( + model="gpt-4o-mini", + connection=OpenAIConnection(api_key="test-key"), + ) + assert node.has_native_async is True + + @pytest.mark.asyncio + async def test_execute_async_calls_acompletion(self): + """execute_async should call litellm.acompletion, not completion.""" + mock_response = make_mock_response("async response") + + with patch("litellm.completion"), \ + patch("litellm.stream_chunk_builder"): + node = OpenAI( + model="gpt-4o-mini", + connection=OpenAIConnection(api_key="test-key"), + prompt=Prompt(messages=[{"role": "user", "content": "Hello"}]), + ) + node._acompletion = AsyncMock(return_value=mock_response) + + result = await node.execute_async( + input_data=MagicMock(messages=None, files=None), + config=RunnableConfig(callbacks=[]), + ) + + node._acompletion.assert_called_once() + assert result["content"] == "async response" + + @pytest.mark.asyncio + async def test_execute_async_streaming(self): + """execute_async should handle streaming via async for.""" + chunk1 = MagicMock() + chunk1.model_dump.return_value = {"choices": [{"delta": {"content": "hel"}}]} + chunk2 = MagicMock() + chunk2.model_dump.return_value = {"choices": [{"delta": {"content": "lo"}}]} + + 
async def async_chunk_iter(): + for chunk in [chunk1, chunk2]: + yield chunk + + full_response = make_mock_response("hello") + + from dynamiq.callbacks.streaming import StreamingIteratorCallbackHandler + from dynamiq.types.streaming import StreamingConfig + + with patch("litellm.completion"), \ + patch("litellm.stream_chunk_builder"): + node = OpenAI( + model="gpt-4o-mini", + connection=OpenAIConnection(api_key="test-key"), + prompt=Prompt(messages=[{"role": "user", "content": "Hello"}]), + streaming=StreamingConfig(enabled=True), + ) + node._acompletion = AsyncMock(return_value=async_chunk_iter()) + node._stream_chunk_builder = MagicMock(return_value=full_response) + + streaming_handler = StreamingIteratorCallbackHandler() + result = await node.execute_async( + input_data=MagicMock(messages=None, files=None), + config=RunnableConfig(callbacks=[streaming_handler]), + ) + + assert result["content"] == "hello" + node._stream_chunk_builder.assert_called_once() + + +class TestBuildCompletionParams: + def test_build_completion_params_returns_expected_keys(self): + """_build_completion_params should return dict with model, messages, stream, etc.""" + with patch("litellm.completion"), \ + patch("litellm.stream_chunk_builder"): + node = OpenAI( + model="gpt-4o-mini", + connection=OpenAIConnection(api_key="test-key"), + prompt=Prompt(messages=[{"role": "user", "content": "Hello"}]), + ) + input_data = MagicMock(messages=None, files=None) + config = RunnableConfig(callbacks=[]) + prompt = node.prompt or Prompt(messages=[], tools=None, response_format=None) + messages = node.get_messages(prompt, input_data) + + params = node._build_completion_params( + messages=messages, + config=config, + prompt=prompt, + include_sync_client=True, + ) + + assert params["model"] == "openai/gpt-4o-mini" + assert params["messages"] == messages + assert "stream" in params + assert "temperature" in params + assert "drop_params" in params + + def test_build_completion_params_excludes_client_when_not_sync(self): + """When include_sync_client=False, client should not be in params.""" + with patch("litellm.completion"), \ + patch("litellm.stream_chunk_builder"): + node = OpenAI( + model="gpt-4o-mini", + connection=OpenAIConnection(api_key="test-key"), + prompt=Prompt(messages=[{"role": "user", "content": "Hello"}]), + ) + node.client = MagicMock() # Simulate a sync client + input_data = MagicMock(messages=None, files=None) + config = RunnableConfig(callbacks=[]) + prompt = node.prompt or Prompt(messages=[], tools=None, response_format=None) + messages = node.get_messages(prompt, input_data) + + params = node._build_completion_params( + messages=messages, + config=config, + prompt=prompt, + include_sync_client=False, + ) + + assert "client" not in params + + +class TestBaseLLMAsyncFallback: + @pytest.mark.asyncio + async def test_run_async_no_fallback_on_success(self): + """Successful run should not trigger fallback.""" + mock_response = make_mock_response("primary response") + + node = OpenAI( + model="gpt-4o-mini", + connection=OpenAIConnection(api_key="test-key"), + prompt=Prompt(messages=[{"role": "user", "content": "Hello"}]), + ) + node._acompletion = AsyncMock(return_value=mock_response) + + result = await node.run_async( + input_data={"input": "test"}, + config=RunnableConfig(callbacks=[]), + ) + assert result.status == RunnableStatus.SUCCESS + + @pytest.mark.asyncio + async def test_run_async_triggers_fallback_on_rate_limit(self): + """Failed primary with rate limit should trigger fallback LLM via async path.""" + 
mock_fallback_response = make_mock_response("fallback response") + + primary = OpenAI( + model="gpt-4o-mini", + connection=OpenAIConnection(api_key="test-key"), + prompt=Prompt(messages=[{"role": "user", "content": "Hello"}]), + ) + fallback_llm = OpenAI( + model="gpt-4o", + connection=OpenAIConnection(api_key="test-key"), + prompt=Prompt(messages=[{"role": "user", "content": "Hello"}]), + ) + primary.fallback = FallbackConfig(llm=fallback_llm, enabled=True) + + # Primary raises a rate limit error + from litellm.exceptions import RateLimitError + primary._acompletion = AsyncMock( + side_effect=RateLimitError( + message="Rate limit exceeded", + model="gpt-4o-mini", + llm_provider="openai", + ) + ) + # Fallback succeeds + fallback_llm._acompletion = AsyncMock(return_value=mock_fallback_response) + + result = await primary.run_async( + input_data={"input": "test"}, + config=RunnableConfig(callbacks=[]), + ) + assert result.status == RunnableStatus.SUCCESS + assert result.output["content"] == "fallback response" diff --git a/tests/unit/nodes/test_flow_async.py b/tests/unit/nodes/test_flow_async.py new file mode 100644 index 000000000..712d9b088 --- /dev/null +++ b/tests/unit/nodes/test_flow_async.py @@ -0,0 +1,73 @@ +import asyncio +import time +from unittest.mock import patch + +import pytest + +from dynamiq.flows.flow import Flow +from dynamiq.nodes.node import Node +from dynamiq.nodes.types import NodeGroup +from dynamiq.runnables import RunnableConfig, RunnableStatus + + +class SlowSyncNode(Node): + """Sync-only node that takes time.""" + group: NodeGroup = NodeGroup.UTILS + name: str = "SlowSync" + latency: float = 0.1 + + def execute(self, input_data, config=None, **kwargs): + time.sleep(self.latency) + return {"result": "sync_done"} + + +class FastAsyncNode(Node): + """Async node that is fast.""" + group: NodeGroup = NodeGroup.UTILS + name: str = "FastAsync" + + def execute(self, input_data, config=None, **kwargs): + time.sleep(0.1) + return {"result": "sync_done"} + + async def execute_async(self, input_data, config=None, **kwargs): + await asyncio.sleep(0.01) + return {"result": "async_done"} + + +class TestFlowAsyncExecutor: + @pytest.mark.asyncio + async def test_flow_run_async_creates_dedicated_executor(self): + """Each flow run should create its own ContextAwareThreadPoolExecutor.""" + node = SlowSyncNode(id="slow1") + flow = Flow(nodes=[node]) + + with patch("dynamiq.flows.flow.ContextAwareThreadPoolExecutor") as mock_executor_cls: + from dynamiq.executors.context import ContextAwareThreadPoolExecutor + real_executor = ContextAwareThreadPoolExecutor(max_workers=4) + mock_executor_cls.return_value = real_executor + + try: + _ = await flow.run_async(input_data={}, config=RunnableConfig(callbacks=[])) + finally: + real_executor.shutdown(wait=False) + + mock_executor_cls.assert_called_once() + + @pytest.mark.asyncio + async def test_concurrent_flows_have_separate_executors(self): + """Two concurrent flow runs should not share executors.""" + node_a = SlowSyncNode(id="a", latency=0.05) + node_b = SlowSyncNode(id="b", latency=0.05) + flow_a = Flow(nodes=[node_a]) + flow_b = Flow(nodes=[node_b]) + + config = RunnableConfig(callbacks=[]) + + results = await asyncio.gather( + flow_a.run_async(input_data={}, config=config), + flow_b.run_async(input_data={}, config=config), + ) + + assert results[0].status == RunnableStatus.SUCCESS + assert results[1].status == RunnableStatus.SUCCESS diff --git a/tests/unit/nodes/test_node.py b/tests/unit/nodes/test_node.py index c5d869285..51083c800 100644 
--- a/tests/unit/nodes/test_node.py +++ b/tests/unit/nodes/test_node.py @@ -91,7 +91,7 @@ def node_async_result(): @pytest.fixture def openai_node(mocker, node_sync_result, node_async_result): mocker.patch("dynamiq.nodes.llms.base.BaseLLM.run_sync", return_value=node_sync_result) - mocker.patch("dynamiq.nodes.node.Node.run_async", return_value=node_async_result) + mocker.patch("dynamiq.nodes.llms.base.BaseLLM.run_async", return_value=node_async_result) yield OpenAI(model="gpt-4", connection=OpenAIConnection(api_key="test_api_key")) diff --git a/tests/unit/nodes/test_node_async.py b/tests/unit/nodes/test_node_async.py new file mode 100644 index 000000000..3b5a288d3 --- /dev/null +++ b/tests/unit/nodes/test_node_async.py @@ -0,0 +1,219 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import patch, MagicMock + +import pytest + +from dynamiq.executors.context import ContextAwareThreadPoolExecutor +from dynamiq.nodes.node import Node, ErrorHandling +from dynamiq.nodes.types import NodeGroup +from dynamiq.runnables import RunnableConfig, RunnableStatus + + +class SyncOnlyNode(Node): + """Test node with only sync execute.""" + group: NodeGroup = NodeGroup.UTILS + name: str = "SyncOnly" + + def execute(self, input_data, config=None, **kwargs): + return {"result": "sync"} + + +class NativeAsyncNode(Node): + """Test node with both sync and async execute.""" + group: NodeGroup = NodeGroup.UTILS + name: str = "NativeAsync" + + def execute(self, input_data, config=None, **kwargs): + return {"result": "sync"} + + async def execute_async(self, input_data, config=None, **kwargs): + await asyncio.sleep(0.01) + return {"result": "async"} + + +class TestNodeAsyncProtocol: + def test_sync_only_node_has_no_native_async(self): + node = SyncOnlyNode() + assert node.has_native_async is False + + def test_native_async_node_has_native_async(self): + node = NativeAsyncNode() + assert node.has_native_async is True + + @pytest.mark.asyncio + async def test_base_execute_async_returns_not_implemented(self): + node = SyncOnlyNode() + result = await node.execute_async(input_data={}) + assert result is NotImplemented + + +class FailThenSucceedAsyncNode(Node): + """Test node that fails N times then succeeds.""" + group: NodeGroup = NodeGroup.UTILS + name: str = "FailThenSucceed" + attempt_count: int = 0 + fail_times: int = 2 + error_handling: ErrorHandling = ErrorHandling( + max_retries=3, retry_interval_seconds=0.01, backoff_rate=1 + ) + + def execute(self, input_data, config=None, **kwargs): + return {"result": "sync"} + + async def execute_async(self, input_data, config=None, **kwargs): + self.attempt_count += 1 + if self.attempt_count <= self.fail_times: + raise ValueError(f"Attempt {self.attempt_count} failed") + return {"result": "success", "attempts": self.attempt_count} + + +class TimeoutAsyncNode(Node): + """Test node that takes too long.""" + group: NodeGroup = NodeGroup.UTILS + name: str = "TimeoutAsync" + error_handling: ErrorHandling = ErrorHandling(timeout_seconds=0.05) + + def execute(self, input_data, config=None, **kwargs): + return {"result": "sync"} + + async def execute_async(self, input_data, config=None, **kwargs): + await asyncio.sleep(10) # Way longer than timeout + return {"result": "should not reach"} + + +class TestExecuteAsyncWithRetry: + @pytest.mark.asyncio + async def test_retry_succeeds_after_failures(self): + node = FailThenSucceedAsyncNode() + config = RunnableConfig(callbacks=[]) + result = await node.execute_async_with_retry(input_data={}, 
config=config) + assert result == {"result": "success", "attempts": 3} + assert node.attempt_count == 3 + + @pytest.mark.asyncio + async def test_retry_exhausted_raises(self): + node = FailThenSucceedAsyncNode(fail_times=10) + config = RunnableConfig(callbacks=[]) + with pytest.raises(ValueError, match="Attempt .* failed"): + await node.execute_async_with_retry(input_data={}, config=config) + + @pytest.mark.asyncio + async def test_timeout_raises(self): + node = TimeoutAsyncNode() + config = RunnableConfig(callbacks=[]) + with pytest.raises(asyncio.TimeoutError): + await node.execute_async_with_retry(input_data={}, config=config) + + +class TestRunAsyncRouting: + @pytest.mark.asyncio + async def test_sync_node_uses_executor(self): + """Sync-only node should offload to the provided executor.""" + node = SyncOnlyNode() + executor = ThreadPoolExecutor(max_workers=2) + try: + result = await node.run_async( + input_data={"input": "test"}, config=RunnableConfig(callbacks=[]), executor=executor + ) + assert result.status == RunnableStatus.SUCCESS + assert result.output == {"result": "sync"} + finally: + executor.shutdown(wait=False) + + @pytest.mark.asyncio + async def test_async_node_runs_on_event_loop(self): + """Async-native node should NOT use executor — runs directly on event loop.""" + node = NativeAsyncNode() + result = await node.run_async( + input_data={"input": "test"}, config=RunnableConfig(callbacks=[]), executor=None + ) + assert result.status == RunnableStatus.SUCCESS + assert result.output == {"result": "async"} + + @pytest.mark.asyncio + async def test_sync_node_without_executor_falls_back_to_default(self): + """Sync-only node with executor=None should use default executor (backward compat).""" + node = SyncOnlyNode() + result = await node.run_async( + input_data={"input": "test"}, config=RunnableConfig(callbacks=[]) + ) + assert result.status == RunnableStatus.SUCCESS + assert result.output == {"result": "sync"} + + +class CachingAsyncNode(Node): + """Test node that tracks whether sync or async execute was called.""" + group: NodeGroup = NodeGroup.UTILS + name: str = "CachingAsync" + sync_called: bool = False + async_called: bool = False + + def execute(self, input_data, config=None, **kwargs): + self.sync_called = True + return {"result": "sync"} + + async def execute_async(self, input_data, config=None, **kwargs): + self.async_called = True + return {"result": "async"} + + +class TestRunAsyncContextPropagation: + @pytest.mark.asyncio + async def test_context_aware_executor_does_not_double_copy(self): + """When executor is ContextAwareThreadPoolExecutor, run_async should not + wrap with ctx.run since the executor handles context propagation.""" + node = SyncOnlyNode() + executor = ContextAwareThreadPoolExecutor(max_workers=2) + try: + with patch("dynamiq.nodes.node.contextvars") as mock_contextvars: + result = await node.run_async( + input_data={"input": "test"}, + config=RunnableConfig(callbacks=[]), + executor=executor, + ) + mock_contextvars.copy_context.assert_not_called() + assert result.status == RunnableStatus.SUCCESS + finally: + executor.shutdown(wait=False) + + @pytest.mark.asyncio + async def test_regular_executor_still_copies_context(self): + """When executor is a regular ThreadPoolExecutor, run_async should + still copy context explicitly.""" + node = SyncOnlyNode() + executor = ThreadPoolExecutor(max_workers=2) + try: + with patch("dynamiq.nodes.node.contextvars") as mock_contextvars: + mock_ctx = MagicMock() + mock_ctx.run = lambda fn, *a, **kw: fn(*a, **kw) + 
mock_contextvars.copy_context.return_value = mock_ctx + result = await node.run_async( + input_data={"input": "test"}, + config=RunnableConfig(callbacks=[]), + executor=executor, + ) + mock_contextvars.copy_context.assert_called_once() + assert result.status == RunnableStatus.SUCCESS + finally: + executor.shutdown(wait=False) + + +class TestAsyncCachingPath: + @pytest.mark.asyncio + async def test_cached_async_path_uses_execute_async(self): + """When caching is enabled in _run_async_native, it should still use + execute_async_with_retry (async path), not execute_with_retry (sync path).""" + from dynamiq.nodes.node import CachingConfig + + node = CachingAsyncNode( + caching=CachingConfig(enabled=True), + ) + # Run without actual cache config so cache decorator is a passthrough + result = await node.run_async( + input_data={"input": "test"}, + config=RunnableConfig(callbacks=[]), + ) + assert result.status == RunnableStatus.SUCCESS + assert node.async_called is True + assert node.sync_called is False
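A minimal end-to-end sketch of the new async flow path, assuming a real API key in place of the placeholder and the same Flow, OpenAI node, and RunnableConfig APIs exercised by the tests above (the model and input values are illustrative):

import asyncio

from dynamiq.connections import OpenAI as OpenAIConnection
from dynamiq.flows.flow import Flow
from dynamiq.nodes.llms.openai import OpenAI
from dynamiq.prompts import Prompt
from dynamiq.runnables import RunnableConfig


async def main():
    llm = OpenAI(
        model="gpt-4o-mini",
        connection=OpenAIConnection(api_key="YOUR_API_KEY"),  # placeholder, not a real key
        prompt=Prompt(messages=[{"role": "user", "content": "Hello"}]),
    )
    flow = Flow(nodes=[llm])

    # Flow.run_async creates a dedicated ContextAwareThreadPoolExecutor for this
    # run and shuts it down in its finally block. The LLM node reports
    # has_native_async=True, so it awaits litellm.acompletion on the event loop;
    # sync-only nodes in the same flow are offloaded to the per-run executor.
    result = await flow.run_async(input_data={}, config=RunnableConfig(callbacks=[]))
    print(result.status, result.output)


asyncio.run(main())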