From 4c59e3dd436826290a0199013d6b308689ce2986 Mon Sep 17 00:00:00 2001 From: Peter Wilson Date: Fri, 16 Jan 2026 12:00:46 +0000 Subject: [PATCH 1/2] fix(core): Unwrap AgentCancel from framework exception chains Some frameworks (e.g., smolagents) catch callback exceptions and re-raise them wrapped in their own error types using `raise ... from e`. This caused AgentCancel to be wrapped in AgentRunError instead of propagating directly. Add _unwrap_agent_cancel() helper that traverses Python's exception chain (__cause__ and __context__) to find any wrapped AgentCancel, allowing it to propagate correctly to the caller. --- src/any_agent/frameworks/any_agent.py | 50 ++++++++++++++++ tests/unit/frameworks/test_agent_cancel.py | 69 ++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/src/any_agent/frameworks/any_agent.py b/src/any_agent/frameworks/any_agent.py index a8780b0b..e7ab6485 100644 --- a/src/any_agent/frameworks/any_agent.py +++ b/src/any_agent/frameworks/any_agent.py @@ -151,6 +151,51 @@ def __repr__(self) -> str: return f"AgentRunError({self._original_exception!r})" +def _unwrap_agent_cancel(exc: BaseException) -> AgentCancel | None: + """Traverse an exception chain to find an AgentCancel if present. + + When callbacks raise AgentCancel subclasses, some frameworks catch and + re-raise them wrapped in their own error types. For example: + + - smolagents wraps with AgentGenerationError using `raise ... from e` + - Other frameworks may use similar patterns + + Python's exception chaining stores the original exception in __cause__ + (explicit: `raise X from Y`) or __context__ (implicit: `raise X` inside + an except block). This function walks that chain to find any AgentCancel. + + Note: + This is a defensive catch-all for frameworks that properly chain + exceptions. Some frameworks may swallow exceptions entirely (e.g., + LangChain's default callback behavior) and require framework-specific + fixes to ensure AgentCancel propagates. See wrapper implementations + for details. + + Args: + exc: The exception to inspect. + + Returns: + The first AgentCancel found in the exception chain, or None if the + chain contains no AgentCancel instances. + + Example: + try: + framework.run() # Raises FrameworkError from AgentCancel + except Exception as e: + if cancel := _unwrap_agent_cancel(e): + # Found the wrapped AgentCancel, re-raise it directly. + raise cancel from e + + """ + current: BaseException | None = exc + while current is not None: + if isinstance(current, AgentCancel): + return current + # Check both explicit (raise from) and implicit (raise in except) chaining. + current = current.__cause__ or current.__context__ + return None + + class AnyAgent(ABC): """Base abstract class for all agent implementations. @@ -355,6 +400,11 @@ async def run_async(self, prompt: str, **kwargs: Any) -> AgentTrace: e._trace = trace raise + # Check if the framework wrapped an AgentCancel in its own error type. + if cancel := _unwrap_agent_cancel(e): + cancel._trace = trace + raise cancel from e + raise AgentRunError(trace, e) from e async with self._lock: diff --git a/tests/unit/frameworks/test_agent_cancel.py b/tests/unit/frameworks/test_agent_cancel.py index fe86d3be..26ac1946 100644 --- a/tests/unit/frameworks/test_agent_cancel.py +++ b/tests/unit/frameworks/test_agent_cancel.py @@ -8,6 +8,7 @@ from any_agent import AgentCancel, AgentConfig, AgentFramework, AgentRunError, AnyAgent from any_agent.callbacks import Callback, Context +from any_agent.frameworks.any_agent import _unwrap_agent_cancel from any_agent.testing.helpers import DEFAULT_SMALL_MODEL_ID, LLM_IMPORT_PATHS from any_agent.tracing.agent_trace import AgentTrace @@ -148,3 +149,71 @@ async def test_regular_exception_wrapped_in_agent_run_error(self) -> None: assert str(exc_info.value.original_exception) == "Unexpected error" assert exc_info.value.trace is not None assert len(exc_info.value.trace.spans) > 0 + + +class TestUnwrapAgentCancel: + """Tests for _unwrap_agent_cancel helper function.""" + + def test_returns_none_for_regular_exception(self) -> None: + """Returns None when exception chain contains no AgentCancel.""" + exc = RuntimeError("regular error") + assert _unwrap_agent_cancel(exc) is None + + def test_returns_none_for_chained_regular_exceptions(self) -> None: + """Returns None when chained exceptions contain no AgentCancel.""" + inner = ValueError("inner") + outer = RuntimeError("outer") + outer.__cause__ = inner + assert _unwrap_agent_cancel(outer) is None + + def test_finds_direct_agent_cancel(self) -> None: + """Returns the exception itself if it is an AgentCancel.""" + exc = StopAgent("direct") + result = _unwrap_agent_cancel(exc) + assert result is exc + + def test_finds_agent_cancel_via_cause(self) -> None: + """Finds AgentCancel in __cause__ (explicit raise from).""" + cancel = StopAgent("wrapped") + wrapper = RuntimeError("framework error") + wrapper.__cause__ = cancel + result = _unwrap_agent_cancel(wrapper) + assert result is cancel + + def test_finds_agent_cancel_via_context(self) -> None: + """Finds AgentCancel in __context__ (implicit chaining).""" + cancel = StopAgent("wrapped") + wrapper = RuntimeError("framework error") + wrapper.__context__ = cancel + result = _unwrap_agent_cancel(wrapper) + assert result is cancel + + def test_finds_deeply_nested_agent_cancel(self) -> None: + """Finds AgentCancel nested multiple levels deep.""" + cancel = StopAgent("deep") + middle = ValueError("middle") + middle.__cause__ = cancel + outer = RuntimeError("outer") + outer.__cause__ = middle + result = _unwrap_agent_cancel(outer) + assert result is cancel + + def test_prefers_cause_over_context(self) -> None: + """When both __cause__ and __context__ exist, follows __cause__ first.""" + cause_cancel = StopAgent("from cause") + context_cancel = SpecificStopAgent("from context") + wrapper = RuntimeError("wrapper") + wrapper.__cause__ = cause_cancel + wrapper.__context__ = context_cancel + result = _unwrap_agent_cancel(wrapper) + assert result is cause_cancel + + def test_finds_subclass_of_agent_cancel(self) -> None: + """Finds subclasses of AgentCancel (e.g., SpecificStopAgent).""" + cancel = SpecificStopAgent("specific") + wrapper = RuntimeError("wrapper") + wrapper.__cause__ = cancel + result = _unwrap_agent_cancel(wrapper) + assert result is cancel + assert isinstance(result, StopAgent) + assert isinstance(result, AgentCancel) From 5c33b5c0bbafb0de72b522aae6b4fb376fb74ebd Mon Sep 17 00:00:00 2001 From: Peter Wilson Date: Fri, 16 Jan 2026 12:01:37 +0000 Subject: [PATCH 2/2] fix(langchain): Propagate callback exceptions instead of swallowing LangChain's BaseCallbackHandler defaults to raise_error=False, which catches exceptions in callbacks, logs a warning, and continues execution. This silently broke the circuit-breaker pattern for callback exceptions. Set raise_error=True on our callback handler to ensure exceptions (including AgentCancel subclasses) propagate correctly to run_async. --- src/any_agent/callbacks/wrappers/langchain.py | 6 +++ .../wrappers/test_get_wrapper_and_unwrap.py | 45 ++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/src/any_agent/callbacks/wrappers/langchain.py b/src/any_agent/callbacks/wrappers/langchain.py index 207490c4..532d9450 100644 --- a/src/any_agent/callbacks/wrappers/langchain.py +++ b/src/any_agent/callbacks/wrappers/langchain.py @@ -55,6 +55,12 @@ def after_tool_execution(*args, **kwargs): context = callback.after_tool_execution(context, *args, **kwargs) class _LangChainTracingCallback(BaseCallbackHandler): + # Propagate exceptions from callbacks instead of swallowing them. + # LangChain defaults to raise_error=False which logs warnings but + # continues execution. We need exceptions (especially AgentCancel) + # to propagate so they can be handled by run_async. + raise_error = True + def on_chat_model_start( self, serialized: dict[str, Any], diff --git a/tests/unit/callbacks/wrappers/test_get_wrapper_and_unwrap.py b/tests/unit/callbacks/wrappers/test_get_wrapper_and_unwrap.py index f2fed665..287b84f0 100644 --- a/tests/unit/callbacks/wrappers/test_get_wrapper_and_unwrap.py +++ b/tests/unit/callbacks/wrappers/test_get_wrapper_and_unwrap.py @@ -1,7 +1,9 @@ -from unittest.mock import MagicMock +from typing import Any +from unittest.mock import AsyncMock, MagicMock from any_agent import AgentFramework from any_agent.callbacks.wrappers import _get_wrapper_by_framework +from any_agent.callbacks.wrappers.langchain import _LangChainWrapper async def test_unwrap_before_wrap(agent_framework: AgentFramework) -> None: @@ -30,3 +32,44 @@ async def test_google_instrument_uninstrument() -> None: assert agent._agent.after_model_callback is None assert agent._agent.before_tool_callback is None assert agent._agent.after_tool_callback is None + + +async def test_langchain_callback_raises_errors() -> None: + """LangChain callback handler must have raise_error=True to propagate AgentCancel. + + By default, LangChain swallows exceptions in callback handlers and only logs + warnings. Setting raise_error=True ensures exceptions (especially AgentCancel + subclasses) propagate so they can be handled by run_async. + """ + agent = MagicMock() + agent._agent = MagicMock() + agent._agent.ainvoke = AsyncMock() + agent.config = MagicMock() + agent.config.callbacks = [] + + wrapper = _LangChainWrapper() + await wrapper.wrap(agent) + + # Call the wrapped ainvoke to trigger callback injection. + captured_kwargs: dict[str, Any] = {} + + async def capture_ainvoke(*args: Any, **kwargs: Any) -> MagicMock: + captured_kwargs.update(kwargs) + return MagicMock() + + # Replace the mock's original ainvoke to capture the kwargs. + wrapper._original_ainvoke = capture_ainvoke + await agent._agent.ainvoke("test") + + # Verify the callback was added with raise_error=True. + assert "config" in captured_kwargs + config = captured_kwargs["config"] + # Config can be a dict or RunnableConfig, handle both. + callbacks = ( + config.get("callbacks") if isinstance(config, dict) else config.callbacks + ) + assert callbacks is not None + assert len(callbacks) == 1 + callback = callbacks[0] + assert hasattr(callback, "raise_error") + assert callback.raise_error is True