From b5ba6ee0ad5253993fb9c40962607f614b4eae09 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 9 Feb 2026 14:18:00 -0500 Subject: [PATCH 1/8] feat: add TraceProvider interface and trace data types Introduce an abstract TraceProvider base class for retrieving agent trace data from observability backends for evaluation. This includes: - TraceProvider ABC with get_session, list_sessions, and get_session_by_trace_id methods - SessionFilter dataclass for filtering session discovery - Custom error hierarchy (TraceProviderError, SessionNotFoundError, TraceNotFoundError, ProviderError) - Session and Trace data types with span tree construction and convenience accessors (input/output messages, token usage, duration) - New providers module exposed at package level - Comprehensive unit tests for providers and trace types --- src/strands_evals/__init__.py | 3 +- src/strands_evals/providers/__init__.py | 19 ++ src/strands_evals/providers/exceptions.py | 25 +++ src/strands_evals/providers/trace_provider.py | 92 +++++++++ tests/strands_evals/providers/__init__.py | 0 .../providers/test_trace_provider.py | 176 ++++++++++++++++++ 6 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 src/strands_evals/providers/__init__.py create mode 100644 src/strands_evals/providers/exceptions.py create mode 100644 src/strands_evals/providers/trace_provider.py create mode 100644 tests/strands_evals/providers/__init__.py create mode 100644 tests/strands_evals/providers/test_trace_provider.py diff --git a/src/strands_evals/__init__.py b/src/strands_evals/__init__.py index 09d4526..f5c600c 100644 --- a/src/strands_evals/__init__.py +++ b/src/strands_evals/__init__.py @@ -1,4 +1,4 @@ -from . import evaluators, extractors, generators, simulation, telemetry, types +from . 
import evaluators, extractors, generators, providers, simulation, telemetry, types from .case import Case from .experiment import Experiment from .simulation import ActorSimulator, UserSimulator @@ -9,6 +9,7 @@ "Case", "evaluators", "extractors", + "providers", "types", "generators", "simulation", diff --git a/src/strands_evals/providers/__init__.py b/src/strands_evals/providers/__init__.py new file mode 100644 index 0000000..d7babcd --- /dev/null +++ b/src/strands_evals/providers/__init__.py @@ -0,0 +1,19 @@ +from .exceptions import ( + ProviderError, + SessionNotFoundError, + TraceNotFoundError, + TraceProviderError, +) +from .trace_provider import ( + SessionFilter, + TraceProvider, +) + +__all__ = [ + "ProviderError", + "SessionFilter", + "SessionNotFoundError", + "TraceNotFoundError", + "TraceProvider", + "TraceProviderError", +] diff --git a/src/strands_evals/providers/exceptions.py b/src/strands_evals/providers/exceptions.py new file mode 100644 index 0000000..e5ee890 --- /dev/null +++ b/src/strands_evals/providers/exceptions.py @@ -0,0 +1,25 @@ +"""Exceptions for trace providers.""" + + +class TraceProviderError(Exception): + """Base exception for trace provider errors.""" + + pass + + +class SessionNotFoundError(TraceProviderError): + """No traces found for the given session ID.""" + + pass + + +class TraceNotFoundError(TraceProviderError): + """Trace with the given ID not found.""" + + pass + + +class ProviderError(TraceProviderError): + """Provider is unreachable or returned an error.""" + + pass diff --git a/src/strands_evals/providers/trace_provider.py b/src/strands_evals/providers/trace_provider.py new file mode 100644 index 0000000..d9d7ed0 --- /dev/null +++ b/src/strands_evals/providers/trace_provider.py @@ -0,0 +1,92 @@ +"""TraceProvider interface for retrieving agent trace data from observability backends.""" + +from abc import ABC, abstractmethod +from collections.abc import Iterator +from dataclasses import dataclass, field +from datetime import 
datetime +from typing import Any + +from ..types.trace import Session + + +@dataclass +class SessionFilter: + """Filter criteria for discovering sessions. + + Universal fields are defined here. Provider-specific parameters + go in `additional_fields`. + """ + + start_time: datetime | None = None + end_time: datetime | None = None + limit: int | None = None + additional_fields: dict[str, Any] = field(default_factory=dict) + + +class TraceProvider(ABC): + """Retrieves agent trace data from observability backends for evaluation. + + Implementations handle authentication, pagination, and conversion from + provider-native formats to the Session/Trace types the evals system consumes. + """ + + @abstractmethod + def get_session(self, session_id: str) -> Session: + """Retrieve all traces for a session. + + Args: + session_id: The session identifier (maps to Strands session_id) + + Returns: + Session object containing all traces for the session + + Raises: + SessionNotFoundError: If no traces found for session_id + ProviderError: If the provider is unreachable or returns an error + """ + ... + + def list_sessions( + self, + session_filter: SessionFilter | None = None, + ) -> Iterator[str]: + """Discover session IDs matching filter criteria. + + Returns session IDs that can be fed to get_session(). + Not abstract — providers override to enable session discovery. + + Args: + session_filter: Optional filter. If None, provider-specific defaults apply. + + Yields: + Session ID strings + + Raises: + NotImplementedError: If the provider does not support session discovery + ProviderError: If the provider is unreachable or returns an error + """ + raise NotImplementedError( + "This provider does not support session discovery. Use get_session() with a known session_id instead." + ) + + def get_session_by_trace_id(self, trace_id: str) -> Session: + """Fetch a single trace and wrap it in a Session. 
+ + Useful when someone has a trace_id but not a session_id, or for + single-shot agent runs without sessions. + Not abstract — providers override to enable trace-level retrieval. + + Args: + trace_id: The unique trace identifier + + Returns: + Session object containing the single trace + + Raises: + NotImplementedError: If the provider does not support trace-level retrieval + TraceNotFoundError: If trace doesn't exist + ProviderError: If the provider is unreachable or returns an error + """ + raise NotImplementedError( + "This provider does not support trace-level retrieval. Use get_session() with a session_id instead." + ) diff --git a/tests/strands_evals/providers/__init__.py b/tests/strands_evals/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strands_evals/providers/test_trace_provider.py b/tests/strands_evals/providers/test_trace_provider.py new file mode 100644 index 0000000..737e9e3 --- /dev/null +++ b/tests/strands_evals/providers/test_trace_provider.py @@ -0,0 +1,176 @@ +"""Tests for TraceProvider ABC, SessionFilter, and exception hierarchy.""" + +from collections.abc import Iterator +from datetime import datetime, timezone + +import pytest + +from strands_evals.providers.exceptions import ( + ProviderError, + SessionNotFoundError, + TraceNotFoundError, + TraceProviderError, +) +from strands_evals.providers.trace_provider import ( + SessionFilter, + TraceProvider, +) +from strands_evals.types.trace import Session + + +class ConcreteProvider(TraceProvider): + """Minimal concrete implementation for testing the ABC.""" + + def __init__(self, session: Session | None = None): + self._session = session + + def get_session(self, session_id: str) -> Session: + if self._session is None: + raise SessionNotFoundError(f"No session found: {session_id}") + return self._session + + +class FullProvider(TraceProvider): + """Provider that overrides all optional methods.""" + + def __init__(self, sessions: dict[str, Session] | None = 
None, session_ids: list[str] | None = None): + self._sessions = sessions or {} + self._session_ids = session_ids or [] + + def get_session(self, session_id: str) -> Session: + if session_id not in self._sessions: + raise SessionNotFoundError(f"No session found: {session_id}") + return self._sessions[session_id] + + def list_sessions(self, session_filter: SessionFilter | None = None) -> Iterator[str]: + yield from self._session_ids + + def get_session_by_trace_id(self, trace_id: str) -> Session: + for session in self._sessions.values(): + for trace in session.traces: + if trace.trace_id == trace_id: + return session + raise TraceNotFoundError(f"No trace found: {trace_id}") + + +# --- SessionFilter tests --- + + +class TestSessionFilter: + def test_defaults(self): + f = SessionFilter() + assert f.start_time is None + assert f.end_time is None + assert f.limit is None + assert f.additional_fields == {} + + def test_with_all_fields(self): + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 31, tzinfo=timezone.utc) + f = SessionFilter( + start_time=start, + end_time=end, + limit=50, + additional_fields={"environment": "production"}, + ) + assert f.start_time == start + assert f.end_time == end + assert f.limit == 50 + assert f.additional_fields == {"environment": "production"} + + def test_additional_fields_default_is_independent(self): + """Each instance gets its own dict (no shared mutable default).""" + f1 = SessionFilter() + f2 = SessionFilter() + f1.additional_fields["key"] = "value" + assert "key" not in f2.additional_fields + + +# --- Exception hierarchy tests --- + + +class TestExceptionHierarchy: + def test_trace_provider_error_is_exception(self): + assert issubclass(TraceProviderError, Exception) + + def test_session_not_found_is_trace_provider_error(self): + assert issubclass(SessionNotFoundError, TraceProviderError) + + def test_trace_not_found_is_trace_provider_error(self): + assert issubclass(TraceNotFoundError, 
TraceProviderError) + + def test_provider_error_is_trace_provider_error(self): + assert issubclass(ProviderError, TraceProviderError) + + def test_exceptions_carry_message(self): + err = SessionNotFoundError("session-123 not found") + assert "session-123 not found" in str(err) + + def test_catching_base_catches_all(self): + """All provider exceptions can be caught with TraceProviderError.""" + for exc_class in (SessionNotFoundError, TraceNotFoundError, ProviderError): + with pytest.raises(TraceProviderError): + raise exc_class("test") + + +# --- TraceProvider ABC tests --- + + +class TestTraceProviderABC: + def test_cannot_instantiate_without_get_session(self): + with pytest.raises(TypeError): + TraceProvider() # type: ignore[abstract] + + def test_concrete_provider_instantiates(self): + provider = ConcreteProvider() + assert isinstance(provider, TraceProvider) + + def test_get_session_returns_session(self): + session = Session(session_id="s1", traces=[]) + provider = ConcreteProvider(session=session) + result = provider.get_session("s1") + assert result == session + + def test_get_session_raises_session_not_found(self): + provider = ConcreteProvider(session=None) + with pytest.raises(SessionNotFoundError, match="No session found"): + provider.get_session("missing") + + def test_list_sessions_default_raises_not_implemented(self): + provider = ConcreteProvider() + with pytest.raises(NotImplementedError, match="does not support session discovery"): + list(provider.list_sessions()) + + def test_get_session_by_trace_id_default_raises_not_implemented(self): + provider = ConcreteProvider() + with pytest.raises(NotImplementedError, match="does not support trace-level retrieval"): + provider.get_session_by_trace_id("trace-123") + + +class TestFullProvider: + def test_list_sessions_yields_ids(self): + provider = FullProvider(session_ids=["s1", "s2", "s3"]) + result = list(provider.list_sessions()) + assert result == ["s1", "s2", "s3"] + + def 
test_list_sessions_with_filter(self): + provider = FullProvider(session_ids=["s1"]) + f = SessionFilter(limit=10) + result = list(provider.list_sessions(session_filter=f)) + assert result == ["s1"] + + def test_get_session_by_trace_id(self): + from strands_evals.types.trace import Trace + + session = Session( + session_id="s1", + traces=[Trace(trace_id="t1", session_id="s1", spans=[])], + ) + provider = FullProvider(sessions={"s1": session}) + result = provider.get_session_by_trace_id("t1") + assert result == session + + def test_get_session_by_trace_id_not_found(self): + provider = FullProvider(sessions={}) + with pytest.raises(TraceNotFoundError, match="No trace found"): + provider.get_session_by_trace_id("missing") From 3e8eb8982a3765b03f673d153e5d6b40023e0019 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 16 Feb 2026 12:05:24 -0500 Subject: [PATCH 2/8] feat(providers): Add TraceProvider interface for observability backends Add abstract TraceProvider that retrieves agent trace data from observability backends and returns Session/Trace types the evals system already consumes. 
- TraceProvider ABC with get_evaluation_data() (required), list_sessions() and get_evaluation_data_by_trace_id() (optional, raise NotImplementedError) - SessionFilter dataclass for time-range and limit-based discovery - Exception hierarchy: TraceProviderError base with SessionNotFoundError, TraceNotFoundError, ProviderError - Export providers module from strands_evals package --- src/strands_evals/providers/trace_provider.py | 33 +++++++----- src/strands_evals/types/evaluation.py | 7 +-- .../providers/test_trace_provider.py | 54 +++++++++++-------- 3 files changed, 56 insertions(+), 38 deletions(-) diff --git a/src/strands_evals/providers/trace_provider.py b/src/strands_evals/providers/trace_provider.py index d9d7ed0..f40d95e 100644 --- a/src/strands_evals/providers/trace_provider.py +++ b/src/strands_evals/providers/trace_provider.py @@ -6,7 +6,7 @@ from datetime import datetime from typing import Any -from ..types.trace import Session +from ..types.evaluation import TaskOutput @dataclass @@ -27,18 +27,23 @@ class TraceProvider(ABC): """Retrieves agent trace data from observability backends for evaluation. Implementations handle authentication, pagination, and conversion from - provider-native formats to the Session/Trace types the evals system consumes. + provider-native formats to the types the evals system consumes. """ @abstractmethod - def get_session(self, session_id: str) -> Session: - """Retrieve all traces for a session. + def get_evaluation_data(self, session_id: str) -> TaskOutput: + """Retrieve all data needed to evaluate a session. + + This is the primary access pattern — given a session ID, fetch all + traces, extract the agent output and trajectory, and return them + in a format ready for evaluation. 
Args: session_id: The session identifier (maps to Strands session_id) Returns: - Session object containing all traces for the session + TaskOutput with 'output' (final agent response) and + 'trajectory' (Session containing all traces/spans) Raises: SessionNotFoundError: If no traces found for session_id @@ -52,7 +57,7 @@ def list_sessions( ) -> Iterator[str]: """Discover session IDs matching filter criteria. - Returns session IDs that can be fed to get_session(). + Returns session IDs that can be fed to get_evaluation_data(). Not abstract — providers override to enable session discovery. Args: @@ -66,21 +71,22 @@ def list_sessions( ProviderError: If the provider is unreachable or returns an error """ raise NotImplementedError( - "This provider does not support session discovery. Use get_session() with a known session_id instead." + "This provider does not support session discovery. " + "Use get_evaluation_data() with a known session_id instead." ) - def get_session_by_trace_id(self, trace_id: str) -> Session: - """Fetch a single trace and wrap it in a Session. + def get_evaluation_data_by_trace_id(self, trace_id: str) -> TaskOutput: + """Fetch a single trace and return its evaluation data. - Useful when someone has a trace_id but not a session_id, or for - single-shot agent runs without sessions. + Useful when someone has a trace_id (e.g., from a Langfuse link) but + not a session_id, or for single-shot agent runs without sessions. Not abstract — providers override to enable trace-level retrieval. 
Args: trace_id: The unique trace identifier Returns: - Session object containing the single trace + TaskOutput with 'output' and 'trajectory' for the single trace Raises: NotImplementedError: If the provider does not support trace-level retrieval @@ -88,5 +94,6 @@ def get_session_by_trace_id(self, trace_id: str) -> Session: ProviderError: If the provider is unreachable or returns an error """ raise NotImplementedError( - "This provider does not support trace-level retrieval. Use get_session() with a session_id instead." + "This provider does not support trace-level retrieval. " + "Use get_evaluation_data() with a session_id instead." ) diff --git a/src/strands_evals/types/evaluation.py b/src/strands_evals/types/evaluation.py index 05d596e..8a3c028 100644 --- a/src/strands_evals/types/evaluation.py +++ b/src/strands_evals/types/evaluation.py @@ -3,13 +3,14 @@ from .trace import Session + InputT = TypeVar("InputT") OutputT = TypeVar("OutputT") class Interaction(TypedDict, total=False): - """ - Represents a single interaction in a multi-agent or multi-step system. + """ Represents a single interaction in a multi-agent or multi-step system. + Used to capture the communication flow and dependencies between different components (agents, tools, or processing nodes) during task execution. 
@@ -56,7 +57,7 @@ class TaskOutput(TypedDict, total=False): """ output: Any - trajectory: list[Any] + trajectory: Union[list[Any], Session, None] interactions: list[Interaction] input: Any diff --git a/tests/strands_evals/providers/test_trace_provider.py b/tests/strands_evals/providers/test_trace_provider.py index 737e9e3..07e7d30 100644 --- a/tests/strands_evals/providers/test_trace_provider.py +++ b/tests/strands_evals/providers/test_trace_provider.py @@ -15,7 +15,8 @@ SessionFilter, TraceProvider, ) -from strands_evals.types.trace import Session +from strands_evals.types.evaluation import TaskOutput +from strands_evals.types.trace import Session, Trace class ConcreteProvider(TraceProvider): @@ -24,10 +25,13 @@ class ConcreteProvider(TraceProvider): def __init__(self, session: Session | None = None): self._session = session - def get_session(self, session_id: str) -> Session: + def get_evaluation_data(self, session_id: str) -> TaskOutput: if self._session is None: raise SessionNotFoundError(f"No session found: {session_id}") - return self._session + return TaskOutput( + output="test response", + trajectory=self._session, + ) class FullProvider(TraceProvider): @@ -37,19 +41,25 @@ def __init__(self, sessions: dict[str, Session] | None = None, session_ids: list self._sessions = sessions or {} self._session_ids = session_ids or [] - def get_session(self, session_id: str) -> Session: + def get_evaluation_data(self, session_id: str) -> TaskOutput: if session_id not in self._sessions: raise SessionNotFoundError(f"No session found: {session_id}") - return self._sessions[session_id] + return TaskOutput( + output="test response", + trajectory=self._sessions[session_id], + ) def list_sessions(self, session_filter: SessionFilter | None = None) -> Iterator[str]: yield from self._session_ids - def get_session_by_trace_id(self, trace_id: str) -> Session: + def get_evaluation_data_by_trace_id(self, trace_id: str) -> TaskOutput: for session in self._sessions.values(): for trace 
in session.traces: if trace.trace_id == trace_id: - return session + return TaskOutput( + output="test response", + trajectory=session, + ) raise TraceNotFoundError(f"No trace found: {trace_id}") @@ -117,7 +127,7 @@ def test_catching_base_catches_all(self): class TestTraceProviderABC: - def test_cannot_instantiate_without_get_session(self): + def test_cannot_instantiate_without_get_evaluation_data(self): with pytest.raises(TypeError): TraceProvider() # type: ignore[abstract] @@ -125,26 +135,27 @@ def test_concrete_provider_instantiates(self): provider = ConcreteProvider() assert isinstance(provider, TraceProvider) - def test_get_session_returns_session(self): + def test_get_evaluation_data_returns_task_output(self): session = Session(session_id="s1", traces=[]) provider = ConcreteProvider(session=session) - result = provider.get_session("s1") - assert result == session + result = provider.get_evaluation_data("s1") + assert result["output"] == "test response" + assert result["trajectory"] == session - def test_get_session_raises_session_not_found(self): + def test_get_evaluation_data_raises_session_not_found(self): provider = ConcreteProvider(session=None) with pytest.raises(SessionNotFoundError, match="No session found"): - provider.get_session("missing") + provider.get_evaluation_data("missing") def test_list_sessions_default_raises_not_implemented(self): provider = ConcreteProvider() with pytest.raises(NotImplementedError, match="does not support session discovery"): list(provider.list_sessions()) - def test_get_session_by_trace_id_default_raises_not_implemented(self): + def test_get_evaluation_data_by_trace_id_default_raises_not_implemented(self): provider = ConcreteProvider() with pytest.raises(NotImplementedError, match="does not support trace-level retrieval"): - provider.get_session_by_trace_id("trace-123") + provider.get_evaluation_data_by_trace_id("trace-123") class TestFullProvider: @@ -159,18 +170,17 @@ def test_list_sessions_with_filter(self): result = 
list(provider.list_sessions(session_filter=f)) assert result == ["s1"] - def test_get_session_by_trace_id(self): - from strands_evals.types.trace import Trace - + def test_get_evaluation_data_by_trace_id(self): session = Session( session_id="s1", traces=[Trace(trace_id="t1", session_id="s1", spans=[])], ) provider = FullProvider(sessions={"s1": session}) - result = provider.get_session_by_trace_id("t1") - assert result == session + result = provider.get_evaluation_data_by_trace_id("t1") + assert result["output"] == "test response" + assert result["trajectory"] == session - def test_get_session_by_trace_id_not_found(self): + def test_get_evaluation_data_by_trace_id_not_found(self): provider = FullProvider(sessions={}) with pytest.raises(TraceNotFoundError, match="No trace found"): - provider.get_session_by_trace_id("missing") + provider.get_evaluation_data_by_trace_id("missing") From 73bd27af642a1bc1e5c3dbaf554f68880e4f97c9 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Mon, 16 Feb 2026 12:16:50 -0500 Subject: [PATCH 3/8] feat(providers): Add TraceProvider interface for observability backends Add abstract TraceProvider that retrieves agent trace data from observability backends and returns Session/Trace types the evals system already consumes. 
- TraceProvider ABC with get_evaluation_data() (required), list_sessions() and get_evaluation_data_by_trace_id() (optional, raise NotImplementedError) - SessionFilter dataclass for time-range and limit-based discovery - Exception hierarchy: TraceProviderError base with SessionNotFoundError, TraceNotFoundError, ProviderError - Export providers module from strands_evals package --- src/strands_evals/providers/trace_provider.py | 3 +-- src/strands_evals/types/evaluation.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/strands_evals/providers/trace_provider.py b/src/strands_evals/providers/trace_provider.py index f40d95e..ed34189 100644 --- a/src/strands_evals/providers/trace_provider.py +++ b/src/strands_evals/providers/trace_provider.py @@ -94,6 +94,5 @@ def get_evaluation_data_by_trace_id(self, trace_id: str) -> TaskOutput: ProviderError: If the provider is unreachable or returns an error """ raise NotImplementedError( - "This provider does not support trace-level retrieval. " - "Use get_evaluation_data() with a session_id instead." + "This provider does not support trace-level retrieval. Use get_evaluation_data() with a session_id instead." ) diff --git a/src/strands_evals/types/evaluation.py b/src/strands_evals/types/evaluation.py index 8a3c028..8188998 100644 --- a/src/strands_evals/types/evaluation.py +++ b/src/strands_evals/types/evaluation.py @@ -3,13 +3,12 @@ from .trace import Session - InputT = TypeVar("InputT") OutputT = TypeVar("OutputT") class Interaction(TypedDict, total=False): - """ Represents a single interaction in a multi-agent or multi-step system. + """Represents a single interaction in a multi-agent or multi-step system. 
Used to capture the communication flow and dependencies between different From d70baebf52d7baadf546be4499f6f48eca9c408f Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Tue, 17 Feb 2026 10:48:47 -0500 Subject: [PATCH 4/8] feat(providers): Add LangfuseProvider for remote trace evaluation Implement LangfuseProvider that fetches agent traces from Langfuse and converts them to Session objects for the evals pipeline. Supports session-level and trace-level retrieval with paginated API calls. - get_evaluation_data(): fetch traces by session ID, convert Langfuse observations to typed spans (InferenceSpan, ToolExecutionSpan, AgentInvocationSpan), extract output from last agent invocation - list_sessions(): paginated session discovery with time-range filtering - get_evaluation_data_by_trace_id(): single trace retrieval - Host resolution: explicit param > LANGFUSE_HOST env var > cloud default - 30 unit tests (mocked SDK), 15 integration tests (real Langfuse + evaluators) --- src/strands_evals/providers/__init__.py | 2 + .../providers/langfuse_provider.py | 431 ++++++++++++++++++ .../providers/test_langfuse_provider.py | 376 +++++++++++++++ tests_integ/test_langfuse_provider.py | 242 ++++++++++ 4 files changed, 1051 insertions(+) create mode 100644 src/strands_evals/providers/langfuse_provider.py create mode 100644 tests/strands_evals/providers/test_langfuse_provider.py create mode 100644 tests_integ/test_langfuse_provider.py diff --git a/src/strands_evals/providers/__init__.py b/src/strands_evals/providers/__init__.py index d7babcd..fcbab85 100644 --- a/src/strands_evals/providers/__init__.py +++ b/src/strands_evals/providers/__init__.py @@ -4,12 +4,14 @@ TraceNotFoundError, TraceProviderError, ) +from .langfuse_provider import LangfuseProvider from .trace_provider import ( SessionFilter, TraceProvider, ) __all__ = [ + "LangfuseProvider", "ProviderError", "SessionFilter", "SessionNotFoundError", diff --git a/src/strands_evals/providers/langfuse_provider.py 
b/src/strands_evals/providers/langfuse_provider.py new file mode 100644 index 0000000..33c64f6 --- /dev/null +++ b/src/strands_evals/providers/langfuse_provider.py @@ -0,0 +1,431 @@ +"""Langfuse trace provider for retrieving agent traces from Langfuse.""" + +import json +import logging +import os +from collections.abc import Iterator +from typing import Any + +from ..providers.exceptions import ( + ProviderError, + SessionNotFoundError, + TraceNotFoundError, +) +from ..providers.trace_provider import SessionFilter, TraceProvider +from ..types.evaluation import TaskOutput +from ..types.trace import ( + AgentInvocationSpan, + AssistantMessage, + InferenceSpan, + Session, + SpanInfo, + TextContent, + ToolCall, + ToolCallContent, + ToolConfig, + ToolExecutionSpan, + ToolResult, + ToolResultContent, + Trace, + UserMessage, +) + +try: + from langfuse import Langfuse +except ImportError: + Langfuse = None # type: ignore[assignment, misc] + +logger = logging.getLogger(__name__) + +_PAGE_SIZE = 100 + + +class LangfuseProvider(TraceProvider): + """Retrieves agent trace data from Langfuse for evaluation. + + Fetches traces and observations via the Langfuse Python SDK, + converts Langfuse observations to typed evals spans, and returns + Session objects ready for the evaluation pipeline. + """ + + def __init__( + self, + public_key: str | None = None, + secret_key: str | None = None, + host: str | None = None, + ): + if Langfuse is None: + raise ProviderError( + "Langfuse SDK is not installed. Install it with: pip install 'strands-evals[langfuse]'" + ) + + resolved_public_key = public_key or os.environ.get("LANGFUSE_PUBLIC_KEY") + resolved_secret_key = secret_key or os.environ.get("LANGFUSE_SECRET_KEY") + resolved_host = host or os.environ.get("LANGFUSE_HOST", "https://us.cloud.langfuse.com") + + if not resolved_public_key or not resolved_secret_key: + raise ProviderError( + "Langfuse credentials required. 
Provide public_key/secret_key or set " + "LANGFUSE_PUBLIC_KEY/LANGFUSE_SECRET_KEY environment variables." + ) + + self._client = Langfuse( + public_key=resolved_public_key, + secret_key=resolved_secret_key, + host=resolved_host, + ) + + def get_evaluation_data(self, session_id: str) -> TaskOutput: + """Fetch all traces for a session and return evaluation data.""" + try: + all_traces = self._fetch_traces_for_session(session_id) + except (SessionNotFoundError, ProviderError): + raise + except Exception as e: + raise ProviderError(f"Langfuse: failed to fetch traces for session '{session_id}': {e}") from e + + if not all_traces: + raise SessionNotFoundError(f"Langfuse: no traces found for session_id='{session_id}'") + + session = self._build_session(session_id, all_traces) + + if not session.traces: + raise SessionNotFoundError( + f"Langfuse: traces found for session_id='{session_id}' but none contained convertible observations" + ) + + output = self._extract_output(session) + + return TaskOutput(output=output, trajectory=session) + + def list_sessions(self, session_filter: SessionFilter | None = None) -> Iterator[str]: + """Yield session IDs from Langfuse, with optional time filtering.""" + try: + page = 1 + while True: + kwargs: dict[str, Any] = {"page": page, "limit": _PAGE_SIZE} + if session_filter: + if session_filter.start_time: + kwargs["from_timestamp"] = session_filter.start_time + if session_filter.end_time: + kwargs["to_timestamp"] = session_filter.end_time + + response = self._client.api.sessions.list(**kwargs) + + for s in response.data: + yield s.id + + if page >= response.meta.total_pages: + break + page += 1 + except ProviderError: + raise + except Exception as e: + raise ProviderError(f"Langfuse: failed to list sessions: {e}") from e + + def get_evaluation_data_by_trace_id(self, trace_id: str) -> TaskOutput: + """Fetch a single trace by ID and return evaluation data.""" + try: + trace_detail = self._client.api.trace.get(trace_id) + except Exception as 
e: + raise TraceNotFoundError(f"Langfuse: trace not found for trace_id='{trace_id}': {e}") from e + + session_id = trace_detail.session_id or trace_id + observations = trace_detail.observations or [] + + spans = self._convert_observations(observations, session_id) + trace = Trace(trace_id=trace_id, session_id=session_id, spans=spans) + session = Session(session_id=session_id, traces=[trace] if trace.spans else []) + output = self._extract_output(session) + + return TaskOutput(output=output, trajectory=session) + + # --- Internal: fetching --- + + def _fetch_traces_for_session(self, session_id: str) -> list: + """Fetch all trace metadata for a session, handling pagination.""" + all_traces = [] + page = 1 + while True: + response = self._client.api.trace.list( + session_id=session_id, page=page, limit=_PAGE_SIZE + ) + all_traces.extend(response.data) + if page >= response.meta.total_pages: + break + page += 1 + return all_traces + + def _fetch_observations(self, trace_id: str) -> list: + """Fetch all observations for a trace, handling pagination.""" + all_observations = [] + page = 1 + while True: + response = self._client.api.observations.get_many( + trace_id=trace_id, page=page, limit=_PAGE_SIZE + ) + all_observations.extend(response.data) + if page >= response.meta.total_pages: + break + page += 1 + return all_observations + + # --- Internal: building Session --- + + def _build_session(self, session_id: str, langfuse_traces: list) -> Session: + """Convert Langfuse traces + observations into an evals Session.""" + traces = [] + for lf_trace in langfuse_traces: + observations = self._fetch_observations(lf_trace.id) + spans = self._convert_observations(observations, session_id) + if spans: + traces.append(Trace(trace_id=lf_trace.id, session_id=session_id, spans=spans)) + return Session(session_id=session_id, traces=traces) + + def _convert_observations(self, observations: list, session_id: str) -> list: + """Convert a list of Langfuse observations to typed evals 
def _convert_observation(self, obs: Any, session_id: str) -> Any:
    """Route one Langfuse observation to the matching span converter.

    ``GENERATION`` observations become inference spans.  ``SPAN``
    observations are routed by name prefix: ``execute_tool`` to the tool
    converter and ``invoke_agent`` to the agent converter.  Everything else
    is logged at debug level and skipped.

    Returns:
        The converted span, or ``None`` when the observation is skipped.
    """
    kind = obs.type
    label = obs.name or ""

    if kind == "GENERATION":
        return self._convert_generation(obs, session_id)

    if kind != "SPAN":
        logger.debug("Skipping observation with type: %s", kind)
        return None

    if label.startswith("execute_tool"):
        return self._convert_tool_execution(obs, session_id)
    if label.startswith("invoke_agent"):
        return self._convert_agent_invocation(obs, session_id)

    logger.debug("Skipping SPAN with unrecognized name: %s", label)
    return None


def _create_span_info(self, obs: Any, session_id: str) -> SpanInfo:
    """Copy the identifying and timing fields of an observation into a ``SpanInfo``."""
    fields = {
        "trace_id": obs.trace_id,
        "span_id": obs.id,
        "session_id": session_id,
        "parent_span_id": obs.parent_observation_id,
        "start_time": obs.start_time,
        "end_time": obs.end_time,
    }
    return SpanInfo(**fields)


# --- Internal: conversion methods ---


def _convert_generation(self, obs: Any, session_id: str) -> InferenceSpan:
    """Build an ``InferenceSpan`` from a GENERATION observation.

    A missing or empty ``obs.metadata`` is replaced by an empty dict.
    """
    return InferenceSpan(
        span_info=self._create_span_info(obs, session_id),
        messages=self._extract_messages_from_generation(obs),
        metadata=obs.metadata or {},
    )
def _convert_message(self, msg: dict) -> UserMessage | AssistantMessage | None:
    """Convert one Langfuse message dict into a typed message.

    Args:
        msg: Message dict; its ``role`` and ``content`` entries are read.

    Returns:
        An ``AssistantMessage`` for role ``"assistant"``; a ``UserMessage``
        for role ``"user"``, or for any other role whose list content holds
        ``toolResult`` blocks; ``None`` when nothing usable was parsed.
    """
    role = msg.get("role", "")
    content_data = msg.get("content", [])

    if role == "assistant":
        content = self._parse_assistant_content(content_data)
        return AssistantMessage(content=content) if content else None

    if role == "user":
        content = self._parse_user_content(content_data)
        return UserMessage(content=content) if content else None

    # Any other role (for example a tool message) is kept only when its
    # content carries tool results, which are surfaced as a UserMessage.
    if isinstance(content_data, list):
        tool_results = self._parse_tool_result_content(content_data)
        if tool_results:
            return UserMessage(content=tool_results)
    return None


def _parse_user_content(self, content_data: Any) -> list[TextContent | ToolResultContent]:
    """Parse the content of a user message into typed content items.

    Args:
        content_data: A plain string or a list of content blocks.

    Returns:
        ``TextContent`` for a plain string and for every ``{"text": ...}``
        block, plus the parsed tool results of every ``{"toolResult": ...}``
        block, in their original order.  Other items are ignored.
    """
    if isinstance(content_data, str):
        return [TextContent(text=content_data)]

    result: list[TextContent | ToolResultContent] = []
    if isinstance(content_data, list):
        for item in content_data:
            if not isinstance(item, dict):
                continue
            if "text" in item:
                result.append(TextContent(text=item["text"]))
            elif "toolResult" in item:
                # The return type promises ToolResultContent, but only text
                # blocks used to be recognised, so tool results inside a
                # user-role message were silently dropped.
                # NOTE(review): assumes user-role messages can carry
                # toolResult blocks for this backend - confirm on real data.
                result.extend(self._parse_tool_result_content([item]))
    return result
def _parse_tool_result_content(self, content_data: list) -> list[TextContent | ToolResultContent]:
    """Extract ``toolResult`` blocks from a message's content list.

    Args:
        content_data: Content blocks; only dicts with a ``toolResult`` key
            are considered.

    Returns:
        One ``ToolResultContent`` per tool result.  Its text comes from the
        first element when the result's ``content`` is a list, or from
        ``str(content)`` otherwise; it is ``""`` when there is no content.
    """
    result: list[TextContent | ToolResultContent] = []
    for item in content_data:
        if not (isinstance(item, dict) and "toolResult" in item):
            continue
        tr = item["toolResult"]
        payload = tr.get("content")
        text = ""
        if payload:
            if isinstance(payload, list):
                first = payload[0]
                # A non-dict first element (e.g. a bare string) used to raise
                # AttributeError here, failing the whole observation.
                text = first.get("text", "") if isinstance(first, dict) else str(first)
            else:
                text = str(payload)
        result.append(
            ToolResultContent(
                content=text,
                error=tr.get("error"),
                tool_call_id=tr.get("toolUseId"),
            )
        )
    return result


def _convert_tool_execution(self, obs: Any, session_id: str) -> ToolExecutionSpan:
    """Convert an ``execute_tool`` SPAN observation into a ``ToolExecutionSpan``.

    The tool call is read from ``obs.input`` (keys ``name``, ``arguments``,
    ``toolUseId``) when that is a dict; otherwise an empty call is recorded.

    The tool result is read from ``obs.output``:
      * a string is used verbatim as the content;
      * a dict contributes its ``result`` entry (or ``str`` of the dict when
        absent), and a truthy ``status`` other than ``"success"`` becomes the
        error text;
      * anything else is stringified, with ``None`` mapping to ``""``.
    """
    span_info = self._create_span_info(obs, session_id)
    obs_input = obs.input or {}
    obs_output = obs.output

    # Tool call: defaults apply when the input is not a dict.
    tool_name, tool_arguments, tool_call_id = "", {}, None
    if isinstance(obs_input, dict):
        tool_name = obs_input.get("name", "")
        tool_arguments = obs_input.get("arguments", {})
        tool_call_id = obs_input.get("toolUseId")

    # Tool result: only a dict output can report an error status.
    result_error = None
    if isinstance(obs_output, str):
        result_content = obs_output
    elif isinstance(obs_output, dict):
        result_content = obs_output.get("result", str(obs_output))
        status = obs_output.get("status", "")
        if status and status != "success":
            result_error = str(status)
    else:
        result_content = "" if obs_output is None else str(obs_output)

    tool_call = ToolCall(name=tool_name, arguments=tool_arguments, tool_call_id=tool_call_id)
    tool_result = ToolResult(content=result_content, error=result_error, tool_call_id=tool_call_id)

    return ToolExecutionSpan(
        span_info=span_info, tool_call=tool_call, tool_result=tool_result, metadata=obs.metadata or {}
    )
def _convert_agent_invocation(self, obs: Any, session_id: str) -> AgentInvocationSpan:
    """Convert an ``invoke_agent`` SPAN observation into an ``AgentInvocationSpan``.

    The prompt is taken from ``obs.input``, the response from ``obs.output``
    and the available tools from ``obs.metadata``; a missing or empty
    metadata value is stored as an empty dict.
    """
    span_info = self._create_span_info(obs, session_id)
    user_prompt = self._extract_user_prompt(obs.input)
    agent_response = self._extract_agent_response(obs.output)
    available_tools = self._extract_available_tools(obs.metadata)

    return AgentInvocationSpan(
        span_info=span_info,
        user_prompt=user_prompt,
        agent_response=agent_response,
        available_tools=available_tools,
        metadata=obs.metadata or {},
    )


def _extract_user_prompt(self, obs_input: Any) -> str:
    """Extract the user prompt from an observation's input.

    Resolution order:
      1. a string is returned unchanged;
      2. the first ``{"text": ...}`` block of a list (or the dict itself);
      3. the first message-shaped entry (``{"role": ..., "content": ...}``)
         whose role is ``"user"`` or absent, using its string content or the
         first text block of its list content;
      4. ``str(obs_input)``, or ``""`` for empty/``None`` input.

    Step 3 mirrors the ``content`` handling of ``_extract_agent_response``;
    without it such input fell through to step 4 and yielded a Python repr.
    """
    if isinstance(obs_input, str):
        return obs_input

    if isinstance(obs_input, dict):
        entries = [obs_input]
    elif isinstance(obs_input, list):
        entries = [entry for entry in obs_input if isinstance(entry, dict)]
    else:
        entries = []

    # Flat content blocks take precedence, exactly as before.
    for entry in entries:
        if "text" in entry:
            return entry["text"]

    # Message-shaped entries: skip non-user roles such as a system prompt.
    for entry in entries:
        if entry.get("role", "user") != "user":
            continue
        content = entry.get("content")
        if isinstance(content, str) and content:
            return content
        if isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and "text" in part:
                    return part["text"]

    return str(obs_input) if obs_input else ""


def _extract_agent_response(self, obs_output: Any) -> str:
    """Extract the agent response from an observation's output.

    A string is returned unchanged.  For a dict, ``text`` wins; otherwise a
    string ``content`` or the first text block of a list ``content`` is used.
    Anything else falls back to ``str(obs_output)``, or ``""`` when the
    output is empty or ``None``.
    """
    if isinstance(obs_output, str):
        return obs_output

    if isinstance(obs_output, dict):
        if "text" in obs_output:
            return obs_output["text"]
        content = obs_output.get("content") if "content" in obs_output else None
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and "text" in part:
                    return part["text"]

    return str(obs_output) if obs_output else ""
def _extract_output(self, session: Session) -> str:
    """Return the response of the last agent invocation in the session.

    Traces and their spans are scanned from the end of each list backwards;
    the ``agent_response`` of the first ``AgentInvocationSpan`` met that way
    is returned, or ``""`` when the session contains none.
    """
    newest_first = (
        span
        for trace in reversed(session.traces)
        for span in reversed(trace.spans)
    )
    return next(
        (span.agent_response for span in newest_first if isinstance(span, AgentInvocationSpan)),
        "",
    )
def _paginated(data, page=1, total_pages=1):
    """Build a fake paginated API response wrapping *data*."""
    meta = _meta(page=page, total_pages=total_pages, total_items=len(data))
    return MagicMock(data=data, meta=meta)


def _lf_session(sid):
    """Build a fake Langfuse session record with the given id."""
    return MagicMock(
        id=sid,
        created_at=datetime(2025, 1, 15, tzinfo=timezone.utc),
        project_id="p1",
        environment="prod",
    )


@pytest.fixture
def mock_client():
    """Stand-in for the Langfuse SDK client."""
    return MagicMock()


@pytest.fixture
def provider(mock_client):
    """LangfuseProvider whose ``Langfuse`` constructor is patched to return ``mock_client``."""
    target = "strands_evals.providers.langfuse_provider.Langfuse"
    with patch(target, return_value=mock_client):
        from strands_evals.providers.langfuse_provider import LangfuseProvider

        return LangfuseProvider(public_key="pk-test", secret_key="sk-test")
cls, + ): + from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider() + cls.assert_called_once_with(public_key="pk-env", secret_key="sk-env", host="https://us.cloud.langfuse.com") + + def test_host_env_var_fallback(self, mock_client): + env = { + "LANGFUSE_PUBLIC_KEY": "pk-env", + "LANGFUSE_SECRET_KEY": "sk-env", + "LANGFUSE_HOST": "https://my-langfuse.example.com", + } + with ( + patch.dict(os.environ, env), + patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client) as cls, + ): + from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider() + cls.assert_called_once_with( + public_key="pk-env", secret_key="sk-env", host="https://my-langfuse.example.com" + ) + + def test_missing_credentials_raises(self): + env = os.environ.copy() + env.pop("LANGFUSE_PUBLIC_KEY", None) + env.pop("LANGFUSE_SECRET_KEY", None) + with patch.dict(os.environ, env, clear=True), pytest.raises(ProviderError, match="Langfuse credentials"): + from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider() + + def test_default_host(self, mock_client): + # Remove LANGFUSE_HOST so we get the default + clean_env = os.environ.copy() + clean_env.pop("LANGFUSE_HOST", None) + with ( + patch.dict(os.environ, clean_env, clear=True), + patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client) as cls, + ): + from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider(public_key="pk", secret_key="sk") + assert cls.call_args[1]["host"] == "https://us.cloud.langfuse.com" + + def test_explicit_host_overrides_env(self, mock_client): + env = {"LANGFUSE_HOST": "https://env-host.example.com"} + with ( + patch.dict(os.environ, env), + patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client) as cls, + ): + from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider(public_key="pk", 
secret_key="sk", host="https://explicit.example.com") + assert cls.call_args[1]["host"] == "https://explicit.example.com" + + +# --- get_evaluation_data --- + + +class TestGetEvaluationData: + def test_happy_path(self, provider, mock_client): + mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) + mock_client.api.observations.get_many.return_value = _paginated([ + _obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a"), + ]) + result = provider.get_evaluation_data("s1") + assert isinstance(result["trajectory"], Session) + assert result["trajectory"].session_id == "s1" + assert len(result["trajectory"].traces) == 1 + + def test_empty_session_raises(self, provider, mock_client): + mock_client.api.trace.list.return_value = _paginated([]) + with pytest.raises(SessionNotFoundError, match="s-missing"): + provider.get_evaluation_data("s-missing") + + def test_output_from_last_agent_invocation(self, provider, mock_client): + mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) + mock_client.api.observations.get_many.return_value = _paginated([ + _obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q1"}], obs_output="first", + start_time=datetime(2025, 1, 15, 10, 0, 0, tzinfo=timezone.utc)), + _obs("o2", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q2"}], obs_output="second", + start_time=datetime(2025, 1, 15, 10, 1, 0, tzinfo=timezone.utc)), + ]) + assert provider.get_evaluation_data("s1")["output"] == "second" + + def test_paginates_traces(self, provider, mock_client): + mock_client.api.trace.list.side_effect = [ + _paginated([_trace("t1", "s1")], page=1, total_pages=2), + _paginated([_trace("t2", "s1")], page=2, total_pages=2), + ] + mock_client.api.observations.get_many.side_effect = [ + _paginated([_obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a")]), + _paginated([_obs("o2", "t2", "SPAN", name="invoke_agent a", 
obs_input=[{"text": "q"}], obs_output="b")]), + ] + assert len(provider.get_evaluation_data("s1")["trajectory"].traces) == 2 + + def test_paginates_observations(self, provider, mock_client): + mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) + mock_client.api.observations.get_many.side_effect = [ + _paginated([_obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a")], + page=1, total_pages=2), + _paginated([_obs("o2", "t1", "GENERATION", name="chat", + obs_input=[{"role": "user", "content": [{"text": "q"}]}], + obs_output={"role": "assistant", "content": [{"text": "a"}]})], + page=2, total_pages=2), + ] + assert len(provider.get_evaluation_data("s1")["trajectory"].traces[0].spans) == 2 + + def test_wraps_api_error(self, provider, mock_client): + mock_client.api.trace.list.side_effect = Exception("Connection refused") + with pytest.raises(ProviderError, match="Connection refused"): + provider.get_evaluation_data("s1") + + def test_unconvertible_observations_excluded(self, provider, mock_client): + mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) + mock_client.api.observations.get_many.return_value = _paginated([ + _obs("o1", "t1", "EVENT", name="some_event"), + ]) + with pytest.raises(SessionNotFoundError): + provider.get_evaluation_data("s1") + + +# --- Observation conversion --- + + +class TestConversion: + def _get_spans(self, provider, mock_client, observations): + mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) + mock_client.api.observations.get_many.return_value = _paginated(observations) + return provider.get_evaluation_data("s1")["trajectory"].traces[0].spans + + def test_generation_to_inference_span(self, provider, mock_client): + spans = self._get_spans(provider, mock_client, [ + _obs("o-gen", "t1", "GENERATION", name="chat", + obs_input=[{"role": "user", "content": [{"text": "q"}]}], + obs_output={"role": "assistant", "content": [{"text": 
"a"}]}), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", + obs_input=[{"text": "q"}], obs_output="a"), + ]) + inference = [s for s in spans if isinstance(s, InferenceSpan)] + assert len(inference) == 1 + assert inference[0].span_info.span_id == "o-gen" + + def test_execute_tool_to_tool_execution_span(self, provider, mock_client): + spans = self._get_spans(provider, mock_client, [ + _obs("o-tool", "t1", "SPAN", name="execute_tool calc", + obs_input={"name": "calc", "arguments": {"x": "2+2"}, "toolUseId": "c1"}, + obs_output={"result": "4", "status": "success"}), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", + obs_input=[{"text": "q"}], obs_output="a"), + ]) + tools = [s for s in spans if isinstance(s, ToolExecutionSpan)] + assert len(tools) == 1 + assert tools[0].tool_call.name == "calc" + assert tools[0].tool_call.arguments == {"x": "2+2"} + + def test_invoke_agent_to_agent_invocation_span(self, provider, mock_client): + spans = self._get_spans(provider, mock_client, [ + _obs("o-agent", "t1", "SPAN", name="invoke_agent my_agent", + obs_input=[{"text": "Hello"}], obs_output="Hi there!"), + ]) + agents = [s for s in spans if isinstance(s, AgentInvocationSpan)] + assert len(agents) == 1 + assert agents[0].user_prompt == "Hello" + assert agents[0].agent_response == "Hi there!" 
+ + def test_unknown_type_skipped(self, provider, mock_client): + spans = self._get_spans(provider, mock_client, [ + _obs("o-event", "t1", "EVENT", name="log"), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", + obs_input=[{"text": "q"}], obs_output="a"), + ]) + assert len(spans) == 1 + assert isinstance(spans[0], AgentInvocationSpan) + + def test_unknown_span_name_skipped(self, provider, mock_client): + spans = self._get_spans(provider, mock_client, [ + _obs("o-unk", "t1", "SPAN", name="some_other_op"), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", + obs_input=[{"text": "q"}], obs_output="a"), + ]) + assert len(spans) == 1 + + def test_span_info_populated(self, provider, mock_client): + start = datetime(2025, 6, 1, 12, 0, 0, tzinfo=timezone.utc) + end = datetime(2025, 6, 1, 12, 0, 10, tzinfo=timezone.utc) + spans = self._get_spans(provider, mock_client, [ + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", + obs_input=[{"text": "q"}], obs_output="a", + start_time=start, end_time=end, parent_observation_id="o-parent"), + ]) + si = spans[0].span_info + assert si.trace_id == "t1" + assert si.span_id == "o-agent" + assert si.session_id == "s1" + assert si.parent_span_id == "o-parent" + assert si.start_time == start + assert si.end_time == end + + def test_string_input_for_agent(self, provider, mock_client): + spans = self._get_spans(provider, mock_client, [ + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", + obs_input="plain string prompt", obs_output="response"), + ]) + assert spans[0].user_prompt == "plain string prompt" + + def test_string_output_for_tool(self, provider, mock_client): + spans = self._get_spans(provider, mock_client, [ + _obs("o-tool", "t1", "SPAN", name="execute_tool calc", + obs_input={"name": "calc", "arguments": {"x": 1}}, obs_output="42"), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", + obs_input=[{"text": "q"}], obs_output="a"), + ]) + tools = [s for s in spans if isinstance(s, ToolExecutionSpan)] + assert 
class TestListSessions:
    """list_sessions(): ID streaming, pagination, time filtering and error wrapping."""

    def test_yields_ids(self, provider, mock_client):
        records = [_lf_session(sid) for sid in ("s1", "s2", "s3")]
        mock_client.api.sessions.list.return_value = _paginated(records)

        assert list(provider.list_sessions()) == ["s1", "s2", "s3"]

    def test_paginates(self, provider, mock_client):
        first_page = _paginated([_lf_session("s1")], page=1, total_pages=2)
        second_page = _paginated([_lf_session("s2")], page=2, total_pages=2)
        mock_client.api.sessions.list.side_effect = [first_page, second_page]

        assert list(provider.list_sessions()) == ["s1", "s2"]

    def test_time_filter(self, provider, mock_client):
        from strands_evals.providers.trace_provider import SessionFilter

        window_start = datetime(2025, 1, 1, tzinfo=timezone.utc)
        window_end = datetime(2025, 1, 31, tzinfo=timezone.utc)
        mock_client.api.sessions.list.return_value = _paginated([_lf_session("s1")])

        time_filter = SessionFilter(start_time=window_start, end_time=window_end)
        # The generator must be consumed for the API call to happen.
        list(provider.list_sessions(session_filter=time_filter))

        _, sent_kwargs = mock_client.api.sessions.list.call_args
        assert sent_kwargs["from_timestamp"] == window_start
        assert sent_kwargs["to_timestamp"] == window_end

    def test_empty(self, provider, mock_client):
        mock_client.api.sessions.list.return_value = _paginated([])

        assert list(provider.list_sessions()) == []

    def test_wraps_error(self, provider, mock_client):
        mock_client.api.sessions.list.side_effect = Exception("API error")

        with pytest.raises(ProviderError, match="API error"):
            list(provider.list_sessions())
mock_client.api.trace.get.return_value = self._trace_detail() + result = provider.get_evaluation_data_by_trace_id("t1") + assert isinstance(result["trajectory"], Session) + assert result["trajectory"].traces[0].trace_id == "t1" + + def test_not_found_raises(self, provider, mock_client): + mock_client.api.trace.get.side_effect = Exception("Not found") + with pytest.raises(TraceNotFoundError, match="t-missing"): + provider.get_evaluation_data_by_trace_id("t-missing") + + def test_uses_trace_session_id(self, provider, mock_client): + mock_client.api.trace.get.return_value = self._trace_detail(session_id="from-trace") + result = provider.get_evaluation_data_by_trace_id("t1") + assert result["trajectory"].session_id == "from-trace" + + def test_no_session_id_falls_back_to_trace_id(self, provider, mock_client): + mock_client.api.trace.get.return_value = self._trace_detail(session_id=None) + result = provider.get_evaluation_data_by_trace_id("t1") + assert result["trajectory"].session_id == "t1" diff --git a/tests_integ/test_langfuse_provider.py b/tests_integ/test_langfuse_provider.py new file mode 100644 index 0000000..cdb5e3f --- /dev/null +++ b/tests_integ/test_langfuse_provider.py @@ -0,0 +1,242 @@ +"""Integration tests for LangfuseProvider against a real Langfuse instance. + +Requires LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment variables. 
@pytest.fixture(scope="module")
def provider():
    """Create a LangfuseProvider from env var credentials, skipping when it cannot be built."""
    try:
        return LangfuseProvider()
    except ProviderError as e:
        pytest.skip(f"Langfuse credentials not available: {e}")


@pytest.fixture(scope="module")
def discovered_session_id(provider):
    """Return the first listed session ID whose data can be loaded.

    Sessions that raise ``SessionNotFoundError`` are passed over; the tests
    are skipped when no session qualifies.
    """
    for session_id in provider.list_sessions():
        try:
            provider.get_evaluation_data(session_id)
        except SessionNotFoundError:
            # Session exists but has no convertible observations, try next
            continue
        return session_id
    pytest.skip("No sessions with convertible observations found in Langfuse")


@pytest.fixture(scope="module")
def evaluation_data(provider, discovered_session_id):
    """Fetch evaluation data for the discovered session."""
    return provider.get_evaluation_data(discovered_session_id)


class TestListSessions:
    def test_returns_at_least_one_session(self, provider):
        # list_sessions() is a generator, so taking a single item avoids
        # paging through every session just to prove the listing is non-empty.
        no_session = object()
        first = next(iter(provider.list_sessions()), no_session)
        assert first is not no_session, "Expected at least one session in Langfuse"

    def test_session_ids_are_strings(self, provider):
        for session_id in provider.list_sessions():
            assert isinstance(session_id, str)
            assert len(session_id) > 0
            break  # Only check the first one
evaluation_data["trajectory"] + assert isinstance(session, Session) + assert session.session_id == discovered_session_id + assert len(session.traces) > 0 + + def test_traces_have_spans(self, evaluation_data): + session = evaluation_data["trajectory"] + for trace in session.traces: + assert isinstance(trace, Trace) + assert isinstance(trace.trace_id, str) + assert len(trace.trace_id) > 0 + assert len(trace.spans) > 0 + + def test_spans_are_typed(self, evaluation_data): + """All spans should be one of the three known types.""" + valid_types = (AgentInvocationSpan, InferenceSpan, ToolExecutionSpan) + session = evaluation_data["trajectory"] + for trace in session.traces: + for span in trace.spans: + assert isinstance(span, valid_types), ( + f"Unexpected span type: {type(span).__name__}" + ) + + def test_has_agent_invocation_span(self, evaluation_data): + """At least one trace should have an AgentInvocationSpan.""" + session = evaluation_data["trajectory"] + agent_spans = [ + span + for trace in session.traces + for span in trace.spans + if isinstance(span, AgentInvocationSpan) + ] + assert len(agent_spans) > 0, "Expected at least one AgentInvocationSpan" + + def test_agent_invocation_has_prompt_and_response(self, evaluation_data): + session = evaluation_data["trajectory"] + for trace in session.traces: + for span in trace.spans: + if isinstance(span, AgentInvocationSpan): + assert isinstance(span.user_prompt, str) + assert isinstance(span.agent_response, str) + assert len(span.user_prompt) > 0 + assert len(span.agent_response) > 0 + return + pytest.fail("No AgentInvocationSpan found") + + def test_output_is_nonempty_string(self, evaluation_data): + assert isinstance(evaluation_data["output"], str) + assert len(evaluation_data["output"]) > 0 + + def test_span_info_populated(self, evaluation_data): + session = evaluation_data["trajectory"] + for trace in session.traces: + for span in trace.spans: + si = span.span_info + assert si.trace_id is not None + assert si.span_id 
class TestGetEvaluationDataByTraceId:
    """Trace-level retrieval against the live backend."""

    def test_fetches_by_trace_id(self, provider, evaluation_data):
        """Re-fetch the first trace of the discovered session by its ID."""
        wanted_id = evaluation_data["trajectory"].traces[0].trace_id

        fetched = provider.get_evaluation_data_by_trace_id(wanted_id)
        trajectory = fetched["trajectory"]

        assert isinstance(trajectory, Session)
        assert len(trajectory.traces) > 0
        assert trajectory.traces[0].trace_id == wanted_id

    def test_nonexistent_trace_raises(self, provider):
        # Either error type is accepted for an unknown trace ID.
        accepted_errors = (TraceNotFoundError, ProviderError)
        with pytest.raises(accepted_errors):
            provider.get_evaluation_data_by_trace_id("nonexistent-trace-id-12345")
" + "Score 0.0 if the output is empty or clearly broken.", + ) + + experiment = Experiment(cases=cases, evaluators=[evaluator]) + reports = experiment.run_evaluations(task) + + assert len(reports) == 1 + report = reports[0] + assert report.score is not None + assert 0.0 <= report.score <= 1.0 + assert len(report.case_results) == 1 + + def test_coherence_evaluator_on_remote_trace(self, provider, discovered_session_id): + """CoherenceEvaluator produces a valid score from a Langfuse session.""" + + def task(case: Case) -> dict: + return provider.get_evaluation_data(case.input) + + cases = [ + Case( + name="langfuse_session", + input=discovered_session_id, + expected_output="any agent response", + ), + ] + + evaluator = CoherenceEvaluator() + + experiment = Experiment(cases=cases, evaluators=[evaluator]) + reports = experiment.run_evaluations(task) + + assert len(reports) == 1 + report = reports[0] + assert report.score is not None + assert 0.0 <= report.score <= 1.0 + + def test_multiple_evaluators_on_remote_trace(self, provider, discovered_session_id): + """Multiple evaluators can all run on the same Langfuse session data.""" + + def task(case: Case) -> dict: + return provider.get_evaluation_data(case.input) + + cases = [ + Case( + name="langfuse_session", + input=discovered_session_id, + expected_output="any agent response", + ), + ] + + evaluators = [ + OutputEvaluator(rubric="Score 1.0 if the output is coherent. 
Score 0.0 otherwise."), + CoherenceEvaluator(), + HelpfulnessEvaluator(), + ] + + experiment = Experiment(cases=cases, evaluators=evaluators) + reports = experiment.run_evaluations(task) + + assert len(reports) == 3 + for report in reports: + assert report.score is not None + assert 0.0 <= report.score <= 1.0 + assert len(report.case_results) == 1 From 285a4344d1176710558e4b6b85ca505cd0366cd2 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Wed, 18 Feb 2026 16:56:44 -0500 Subject: [PATCH 5/8] feat: add trace providers for CloudWatch and Langfuse backends Add TraceProvider interface and implementations for fetching agent execution data from observability backends (CloudWatch Logs and Langfuse). This enables running evaluators against production/staging traces without re-executing agents. - Add CloudWatchProvider for Bedrock AgentCore runtime logs - Add LangfuseProvider for Langfuse-hosted traces - Add TraceProvider base interface with get_evaluation_data API - Add comprehensive test suites for both providers - Add providers README with usage documentation - Fix minor whitespace issues in CoherenceEvaluator docstring --- .../evaluators/coherence_evaluator.py | 4 +- src/strands_evals/providers/README.md | 119 +++ src/strands_evals/providers/__init__.py | 2 + .../providers/cloudwatch_provider.py | 547 +++++++++++++ .../providers/langfuse_provider.py | 38 +- .../providers/test_cloudwatch_provider.py | 734 ++++++++++++++++++ .../providers/test_langfuse_provider.py | 254 ++++-- tests_integ/test_cloudwatch_provider.py | 271 +++++++ tests_integ/test_langfuse_provider.py | 22 +- 9 files changed, 1886 insertions(+), 105 deletions(-) create mode 100644 src/strands_evals/providers/README.md create mode 100644 src/strands_evals/providers/cloudwatch_provider.py create mode 100644 tests/strands_evals/providers/test_cloudwatch_provider.py create mode 100644 tests_integ/test_cloudwatch_provider.py diff --git a/src/strands_evals/evaluators/coherence_evaluator.py 
b/src/strands_evals/evaluators/coherence_evaluator.py index 3813eec..8f95fb7 100644 --- a/src/strands_evals/evaluators/coherence_evaluator.py +++ b/src/strands_evals/evaluators/coherence_evaluator.py @@ -31,11 +31,11 @@ class CoherenceRating(BaseModel): class CoherenceEvaluator(Evaluator[InputT, OutputT]): """Evaluates the logical cohesion of the assistant's response. - + This evaluator assesses whether the assistant's response maintains logical consistency, flows naturally, and presents ideas in a well-organized manner. It uses an LLM-as-judge approach to provide categorical ratings that are then normalized to numeric scores. - + Scores: - NOT_AT_ALL (0.0): Response is completely incoherent or contradictory - NOT_GENERALLY (0.25): Response has significant logical gaps or inconsistencies diff --git a/src/strands_evals/providers/README.md b/src/strands_evals/providers/README.md new file mode 100644 index 0000000..72bb4bb --- /dev/null +++ b/src/strands_evals/providers/README.md @@ -0,0 +1,119 @@ +# Trace Providers + +Trace providers fetch agent execution data from observability backends and convert it into the format the evaluation pipeline expects. This lets you run evaluators against traces from production or staging agents without re-running them. 
+ +## Available Providers + +| Provider | Backend | Auth | +|----------|---------|------| +| `CloudWatchProvider` | AWS CloudWatch Logs (Bedrock AgentCore runtime logs) | AWS credentials (boto3) | +| `LangfuseProvider` | Langfuse | API keys | + +## Quick Start + +### CloudWatch + +```python +from strands_evals.providers import CloudWatchProvider + +# Option 1: Provide the log group directly +provider = CloudWatchProvider( + log_group="/aws/bedrock-agentcore/runtimes/my-agent-abc123-DEFAULT", + region="us-east-1", +) + +# Option 2: Discover the log group from the agent name +provider = CloudWatchProvider(agent_name="my-agent", region="us-east-1") +``` + +### Langfuse + +```python +from strands_evals.providers import LangfuseProvider + +# Reads LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY from env by default +provider = LangfuseProvider() + +# Or pass credentials explicitly +provider = LangfuseProvider( + public_key="pk-...", + secret_key="sk-...", + host="https://us.cloud.langfuse.com", +) +``` + +## Core API + +All providers implement the `TraceProvider` interface: + +```python +# Fetch traces for a session, ready for evaluation +data = provider.get_evaluation_data(session_id="my-session-id") +# data["output"] -> str (final agent response) +# data["trajectory"] -> Session (traces and spans) + +# Discover session IDs +for session_id in provider.list_sessions(): + print(session_id) + +# Fetch a single trace by ID +data = provider.get_evaluation_data_by_trace_id(trace_id="abc123") +``` + +## Running Evaluators on Remote Traces + +Pass the provider's data into the standard `Experiment` pipeline: + +```python +from strands_evals import Case, Experiment +from strands_evals.evaluators import CoherenceEvaluator, OutputEvaluator +from strands_evals.providers import CloudWatchProvider + +provider = CloudWatchProvider(log_group="/aws/...", region="us-east-1") + +def task(case: Case) -> dict: + return provider.get_evaluation_data(case.input) + +cases = [Case(name="session_1", 
input="my-session-id", expected_output="any")] +evaluators = [ + OutputEvaluator(rubric="Score 1.0 if the output is coherent. Score 0.0 otherwise."), + CoherenceEvaluator(), +] + +experiment = Experiment(cases=cases, evaluators=evaluators) +reports = experiment.run_evaluations(task) + +for report in reports: + print(f"{report.overall_score:.2f} - {report.reasons}") +``` + +## Error Handling + +```python +from strands_evals.providers import SessionNotFoundError, TraceNotFoundError, ProviderError + +try: + data = provider.get_evaluation_data("unknown-session") +except SessionNotFoundError: + print("No traces found for that session") +except ProviderError: + print("Provider unreachable or query failed") +``` + +## Implementing a Custom Provider + +Subclass `TraceProvider` and implement `get_evaluation_data`: + +```python +from strands_evals.providers import TraceProvider + +class MyProvider(TraceProvider): + def get_evaluation_data(self, session_id: str) -> dict: + # Fetch traces from your backend, return: + # {"output": "final response text", "trajectory": Session(...)} + ... + + def list_sessions(self, session_filter=None): + # Optional: yield session ID strings + ... 
+``` diff --git a/src/strands_evals/providers/__init__.py b/src/strands_evals/providers/__init__.py index fcbab85..ed92af7 100644 --- a/src/strands_evals/providers/__init__.py +++ b/src/strands_evals/providers/__init__.py @@ -1,3 +1,4 @@ +from .cloudwatch_provider import CloudWatchProvider from .exceptions import ( ProviderError, SessionNotFoundError, @@ -11,6 +12,7 @@ ) __all__ = [ + "CloudWatchProvider", "LangfuseProvider", "ProviderError", "SessionFilter", diff --git a/src/strands_evals/providers/cloudwatch_provider.py b/src/strands_evals/providers/cloudwatch_provider.py new file mode 100644 index 0000000..d8e90ab --- /dev/null +++ b/src/strands_evals/providers/cloudwatch_provider.py @@ -0,0 +1,547 @@ +"""CloudWatch trace provider for retrieving agent traces from AWS CloudWatch Logs.""" + +import json +import logging +import os +import time +from collections import defaultdict +from collections.abc import Iterator +from datetime import datetime, timedelta, timezone +from typing import Any + +import boto3 + +from ..providers.exceptions import ProviderError, SessionNotFoundError, TraceNotFoundError +from ..providers.trace_provider import SessionFilter, TraceProvider +from ..types.evaluation import TaskOutput +from ..types.trace import ( + AgentInvocationSpan, + AssistantMessage, + InferenceSpan, + Session, + SpanInfo, + TextContent, + ToolCall, + ToolCallContent, + ToolConfig, + ToolExecutionSpan, + ToolResult, + ToolResultContent, + Trace, + UserMessage, +) + +logger = logging.getLogger(__name__) + + +class CloudWatchProvider(TraceProvider): + """Retrieves agent trace data from AWS CloudWatch Logs for evaluation. + + Queries CloudWatch Logs Insights to fetch OTEL log records from an + agent-specific runtime log group, parses body.input/output messages, + and returns Session objects ready for the evaluation pipeline. 
+ """ + + def __init__( + self, + region: str | None = None, + log_group: str | None = None, + agent_name: str | None = None, + lookback_days: int = 30, + query_timeout_seconds: float = 60.0, + ): + resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") + + try: + self._client = boto3.client("logs", region_name=resolved_region) + except Exception as e: + raise ProviderError(f"CloudWatch: failed to create boto3 logs client: {e}") from e + + if log_group: + self._log_group = log_group + elif agent_name: + self._log_group = self._discover_log_group(agent_name) + else: + raise ProviderError("CloudWatch: either log_group or agent_name must be provided") + + self._lookback_days = lookback_days + self._query_timeout_seconds = query_timeout_seconds + + def _discover_log_group(self, agent_name: str) -> str: + """Discover the runtime log group for an agent via describe_log_groups.""" + prefix = f"/aws/bedrock-agentcore/runtimes/{agent_name}" + response = self._client.describe_log_groups(logGroupNamePrefix=prefix) + log_groups = response.get("logGroups", []) + if not log_groups: + raise ProviderError(f"CloudWatch: no log group found for agent_name='{agent_name}' (prefix={prefix})") + return log_groups[0]["logGroupName"] + + def get_evaluation_data(self, session_id: str) -> TaskOutput: + """Fetch all traces for a session and return evaluation data.""" + query = f"fields @message | filter attributes.session.id = '{session_id}' | sort @timestamp asc | limit 10000" + + try: + span_dicts = self._run_logs_insights_query(query) + except ProviderError: + raise + except Exception as e: + raise ProviderError(f"CloudWatch: failed to query spans for session '{session_id}': {e}") from e + + if not span_dicts: + raise SessionNotFoundError(f"CloudWatch: no spans found for session_id='{session_id}'") + + session = self._build_session(session_id, span_dicts) + + if not session.traces: + raise SessionNotFoundError( + f"CloudWatch: spans 
found for session_id='{session_id}' but none contained convertible spans" + ) + + output = self._extract_output(session) + return TaskOutput(output=output, trajectory=session) + + def get_evaluation_data_by_trace_id(self, trace_id: str) -> TaskOutput: + """Fetch a single trace by ID and return evaluation data.""" + query = f'fields @message | filter traceId = "{trace_id}" | sort @timestamp asc | limit 10000' + + try: + span_dicts = self._run_logs_insights_query(query) + except ProviderError: + raise + except Exception as e: + raise ProviderError(f"CloudWatch: failed to query trace '{trace_id}': {e}") from e + + if not span_dicts: + raise TraceNotFoundError(f"CloudWatch: no spans found for trace_id='{trace_id}'") + + # Extract session_id from the first record's attributes + first_attrs = span_dicts[0].get("attributes", {}) + session_id = first_attrs.get("session.id") or trace_id + + session = self._build_session(session_id, span_dicts) + output = self._extract_output(session) + return TaskOutput(output=output, trajectory=session) + + def list_sessions(self, session_filter: SessionFilter | None = None) -> Iterator[str]: + """Yield distinct session IDs from CloudWatch Logs.""" + limit = session_filter.limit if session_filter and session_filter.limit else 1000 + start_time = session_filter.start_time if session_filter else None + end_time = session_filter.end_time if session_filter else None + + query = ( + "fields attributes.session.id as sessionId" + " | filter ispresent(attributes.session.id)" + " | stats count(*) as span_count by sessionId" + " | sort sessionId asc" + f" | limit {limit}" + ) + + try: + results = self._run_raw_logs_insights_query(query, start_time=start_time, end_time=end_time) + except ProviderError: + raise + except Exception as e: + raise ProviderError(f"CloudWatch: failed to list sessions: {e}") from e + + for row in results: + for field in row: + if field.get("field") == "sessionId": + yield field["value"] + + # --- Internal: CW Logs Insights 
query execution --- + + def _run_logs_insights_query( + self, query: str, start_time: datetime | None = None, end_time: datetime | None = None + ) -> list[dict[str, Any]]: + """Execute a CW Logs Insights query and return parsed span dicts from @message fields.""" + raw_results = self._run_raw_logs_insights_query(query, start_time=start_time, end_time=end_time) + return self._parse_query_results(raw_results) + + def _run_raw_logs_insights_query( + self, + query: str, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> list[list[dict[str, str]]]: + """Execute a CW Logs Insights query and return raw result rows.""" + now = datetime.now(tz=timezone.utc) + if end_time is None: + end_time = now + if start_time is None: + start_time = now - timedelta(days=self._lookback_days) + + try: + response = self._client.start_query( + logGroupName=self._log_group, + startTime=int(start_time.timestamp()), + endTime=int(end_time.timestamp()), + queryString=query, + ) + except Exception as e: + raise ProviderError(f"CloudWatch: failed to start query: {e}") from e + + query_id = response["queryId"] + return self._poll_query_results(query_id) + + def _poll_query_results(self, query_id: str) -> list[list[dict[str, str]]]: + """Poll for query completion with exponential backoff. 
Returns raw result rows.""" + delay = 0.5 + max_delay = 8.0 + deadline = time.monotonic() + self._query_timeout_seconds + + while True: + response = self._client.get_query_results(queryId=query_id) + status = response.get("status", "") + + if status == "Complete": + return response.get("results", []) + elif status in ("Failed", "Cancelled", "Timeout"): + raise ProviderError(f"CloudWatch: query {status}") + + if time.monotonic() >= deadline: + raise ProviderError(f"CloudWatch: query timed out after {self._query_timeout_seconds}s") + + time.sleep(delay) + delay = min(delay * 2, max_delay) + + @staticmethod + def _parse_query_results(results: list[list[dict[str, str]]]) -> list[dict[str, Any]]: + """Parse @message fields from CW Logs Insights results into span dicts.""" + span_dicts = [] + for row in results: + for field in row: + if field.get("field") == "@message": + try: + span_dicts.append(json.loads(field["value"])) + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Failed to parse @message: %s", e) + return span_dicts + + # --- Internal: session building (body-based parsing) --- + + def _build_session(self, session_id: str, records: list[dict[str, Any]]) -> Session: + """Group log records by traceId, convert each group to a Trace, return a Session.""" + traces_by_id: dict[str, list[dict[str, Any]]] = defaultdict(list) + for record in records: + trace_id = record.get("traceId", "") + if not trace_id: + continue + traces_by_id[trace_id].append(record) + + traces: list[Trace] = [] + for trace_id, trace_records in traces_by_id.items(): + trace = self._convert_trace(trace_id, trace_records, session_id) + if trace.spans: + traces.append(trace) + + return Session(session_id=session_id, traces=traces) + + def _convert_trace(self, trace_id: str, records: list[dict[str, Any]], session_id: str) -> Trace: + """Convert a group of log records (same traceId) into a Trace with typed spans.""" + sorted_records = sorted(records, key=lambda r: r.get("timeUnixNano", 
0)) + + spans: list[InferenceSpan | ToolExecutionSpan | AgentInvocationSpan] = [] + + # Collect all tool calls and results across records + all_tool_calls: dict[str, ToolCall] = {} + all_tool_results: dict[str, ToolResult] = {} + + for record in sorted_records: + if not isinstance(record.get("body"), dict): + continue + + for tc in self._extract_tool_calls(record): + if tc.tool_call_id: + all_tool_calls[tc.tool_call_id] = tc + + for tr in self._extract_tool_results(record): + if tr.tool_call_id: + all_tool_results[tr.tool_call_id] = tr + + # Create InferenceSpans (one per record with parseable body) + for record in sorted_records: + if not isinstance(record.get("body"), dict): + continue + + try: + messages = self._record_to_messages(record) + if messages: + span_info = self._create_span_info(record, session_id) + spans.append(InferenceSpan(span_info=span_info, messages=messages, metadata={})) + except Exception as e: + logger.warning("Failed to create inference span from record %s: %s", record.get("spanId"), e) + + # Create ToolExecutionSpans by matching calls to results + seen_tool_ids: set[str] = set() + for record in sorted_records: + for tc in self._extract_tool_calls(record): + if tc.tool_call_id and tc.tool_call_id not in seen_tool_ids: + seen_tool_ids.add(tc.tool_call_id) + tr = all_tool_results.get(tc.tool_call_id, ToolResult(content="", tool_call_id=tc.tool_call_id)) + span_info = self._create_span_info(record, session_id) + spans.append(ToolExecutionSpan(span_info=span_info, tool_call=tc, tool_result=tr, metadata={})) + + # Create AgentInvocationSpan from first user prompt + last agent response + agent_span = self._create_agent_invocation_span(sorted_records, all_tool_calls, session_id) + if agent_span: + spans.append(agent_span) + + return Trace(spans=spans, trace_id=trace_id, session_id=session_id) + + def _create_agent_invocation_span( + self, records: list[dict[str, Any]], tool_calls: dict[str, ToolCall], session_id: str + ) -> AgentInvocationSpan | 
None: + """Create an AgentInvocationSpan from the first user prompt and last agent response.""" + user_prompt = None + for record in records: + prompt = self._extract_user_prompt(record) + if prompt: + user_prompt = prompt + break + + if not user_prompt: + return None + + agent_response = None + best_record = None + for record in reversed(records): + response = self._extract_agent_response(record) + if response: + agent_response = response + best_record = record + break + + if not agent_response or not best_record: + return None + + available_tools = [ToolConfig(name=name) for name in sorted({tc.name for tc in tool_calls.values()})] + span_info = self._create_span_info(best_record, session_id) + + return AgentInvocationSpan( + span_info=span_info, + user_prompt=user_prompt, + agent_response=agent_response, + available_tools=available_tools, + metadata={}, + ) + + def _extract_output(self, session: Session) -> str: + """Extract the final agent response from the session for TaskOutput.output.""" + for trace in reversed(session.traces): + for span in reversed(trace.spans): + if isinstance(span, AgentInvocationSpan): + return span.agent_response + return "" + + # --- Internal: span info --- + + def _create_span_info(self, record: dict[str, Any], session_id: str) -> SpanInfo: + time_nano = record.get("timeUnixNano", 0) + ts = datetime.fromtimestamp(time_nano / 1e9, tz=timezone.utc) + + return SpanInfo( + trace_id=record.get("traceId", ""), + span_id=record.get("spanId", ""), + session_id=session_id, + parent_span_id=record.get("parentSpanId") or None, + start_time=ts, + end_time=ts, + ) + + # --- Internal: body-based content extraction --- + + def _parse_message_content(self, raw: str) -> list[dict[str, Any]] | None: + """Parse double-encoded message content into a list of content blocks.""" + if not isinstance(raw, str): + return None + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, list) else None + except (json.JSONDecodeError, TypeError): + 
return None + + def _extract_content_field(self, content: dict[str, Any]) -> str | None: + """Extract the raw content field from a message.""" + if not isinstance(content, dict): + return None + return content.get("content") or content.get("message") + + def _extract_text_from_content(self, content: Any) -> str | None: + """Extract text from a content field, handling double-encoded JSON strings.""" + raw = self._extract_content_field(content) + if not raw: + return None + + parsed = self._parse_message_content(raw) + if parsed: + texts = [item["text"] for item in parsed if isinstance(item, dict) and "text" in item] + return " ".join(texts) if texts else None + + return raw if isinstance(raw, str) else None + + def _extract_message_text(self, record: dict[str, Any], message_type: str, role: str) -> str | None: + """Extract text from a specific message type and role in a log record.""" + body = record.get("body", {}) + if not isinstance(body, dict): + return None + + messages = body.get(message_type, {}).get("messages", []) + for msg in messages: + if msg.get("role") == role: + text = self._extract_text_from_content(msg.get("content", {})) + if text: + return text + return None + + def _extract_user_prompt(self, record: dict[str, Any]) -> str | None: + """Extract user prompt text from a log record's body.input.messages.""" + return self._extract_message_text(record, "input", "user") + + def _extract_agent_response(self, record: dict[str, Any]) -> str | None: + """Extract assistant text response from a log record's body.output.messages.""" + return self._extract_message_text(record, "output", "assistant") + + def _extract_tool_calls(self, record: dict[str, Any]) -> list[ToolCall]: + """Extract tool calls from a log record's body.output.messages.""" + tool_calls: list[ToolCall] = [] + body = record.get("body", {}) + if not isinstance(body, dict): + return tool_calls + + for msg in body.get("output", {}).get("messages", []): + if msg.get("role") != "assistant": + 
continue + + raw = self._extract_content_field(msg.get("content", {})) + parsed = self._parse_message_content(raw) if raw else None + if not parsed: + continue + + for item in parsed: + if isinstance(item, dict) and "toolUse" in item: + tu = item["toolUse"] + tool_calls.append( + ToolCall( + name=tu.get("name", ""), + arguments=tu.get("input", {}), + tool_call_id=tu.get("toolUseId"), + ) + ) + + return tool_calls + + def _extract_tool_results(self, record: dict[str, Any]) -> list[ToolResult]: + """Extract tool results from a log record's body.input.messages.""" + tool_results: list[ToolResult] = [] + body = record.get("body", {}) + if not isinstance(body, dict): + return tool_results + + for msg in body.get("input", {}).get("messages", []): + raw = self._extract_content_field(msg.get("content", {})) + parsed = self._parse_message_content(raw) if raw else None + if not parsed: + continue + + for item in parsed: + if isinstance(item, dict) and "toolResult" in item: + tr_data = item["toolResult"] + result_text = self._extract_tool_result_text(tr_data.get("content")) + tool_results.append( + ToolResult( + content=result_text, + error=tr_data.get("error"), + tool_call_id=tr_data.get("toolUseId"), + ) + ) + + return tool_results + + def _extract_tool_result_text(self, content: Any) -> str: + """Extract text from tool result content.""" + if not content: + return "" + if isinstance(content, list) and content: + return content[0].get("text", "") + return str(content) + + # --- Internal: record-to-messages conversion --- + + def _record_to_messages(self, record: dict[str, Any]) -> list[UserMessage | AssistantMessage]: + """Convert a log record's body into a list of typed messages for InferenceSpan.""" + messages: list[UserMessage | AssistantMessage] = [] + body = record.get("body", {}) + if not isinstance(body, dict): + return messages + + # Process input messages + for msg in body.get("input", {}).get("messages", []): + role = msg.get("role", "") + raw = 
self._extract_content_field(msg.get("content", {})) + parsed = self._parse_message_content(raw) if raw else None + if not parsed: + continue + + if role == "user": + user_content = self._process_user_message(parsed) + if user_content: + messages.append(UserMessage(content=user_content)) + elif role == "tool": + tool_content = self._process_tool_results(parsed) + if tool_content: + messages.append(UserMessage(content=tool_content)) + + # Process output messages + for msg in body.get("output", {}).get("messages", []): + if msg.get("role") != "assistant": + continue + + raw = self._extract_content_field(msg.get("content", {})) + parsed = self._parse_message_content(raw) if raw else None + if not parsed: + continue + + assistant_content = self._process_assistant_content(parsed) + if assistant_content: + messages.append(AssistantMessage(content=assistant_content)) + + return messages + + # --- Internal: content parsing helpers (Bedrock Converse format) --- + + @staticmethod + def _process_user_message(content_list: list[dict[str, Any]]) -> list[TextContent | ToolResultContent]: + return [TextContent(text=item["text"]) for item in content_list if "text" in item] + + @staticmethod + def _process_assistant_content(content_list: list[dict[str, Any]]) -> list[TextContent | ToolCallContent]: + result: list[TextContent | ToolCallContent] = [] + for item in content_list: + if "text" in item: + result.append(TextContent(text=item["text"])) + elif "toolUse" in item: + tool_use = item["toolUse"] + result.append( + ToolCallContent( + name=tool_use["name"], + arguments=tool_use.get("input", {}), + tool_call_id=tool_use.get("toolUseId"), + ) + ) + return result + + def _process_tool_results(self, content_list: list[dict[str, Any]]) -> list[TextContent | ToolResultContent]: + result: list[TextContent | ToolResultContent] = [] + for item in content_list: + if "toolResult" not in item: + continue + tool_result = item["toolResult"] + result_text = 
self._extract_tool_result_text(tool_result.get("content")) + result.append( + ToolResultContent( + content=result_text, + error=tool_result.get("error"), + tool_call_id=tool_result.get("toolUseId"), + ) + ) + return result diff --git a/src/strands_evals/providers/langfuse_provider.py b/src/strands_evals/providers/langfuse_provider.py index 33c64f6..02db3f0 100644 --- a/src/strands_evals/providers/langfuse_provider.py +++ b/src/strands_evals/providers/langfuse_provider.py @@ -55,9 +55,7 @@ def __init__( host: str | None = None, ): if Langfuse is None: - raise ProviderError( - "Langfuse SDK is not installed. Install it with: pip install 'strands-evals[langfuse]'" - ) + raise ProviderError("Langfuse SDK is not installed. Install it with: pip install 'strands-evals[langfuse]'") resolved_public_key = public_key or os.environ.get("LANGFUSE_PUBLIC_KEY") resolved_secret_key = secret_key or os.environ.get("LANGFUSE_SECRET_KEY") @@ -147,9 +145,7 @@ def _fetch_traces_for_session(self, session_id: str) -> list: all_traces = [] page = 1 while True: - response = self._client.api.trace.list( - session_id=session_id, page=page, limit=_PAGE_SIZE - ) + response = self._client.api.trace.list(session_id=session_id, page=page, limit=_PAGE_SIZE) all_traces.extend(response.data) if page >= response.meta.total_pages: break @@ -161,9 +157,7 @@ def _fetch_observations(self, trace_id: str) -> list: all_observations = [] page = 1 while True: - response = self._client.api.observations.get_many( - trace_id=trace_id, page=page, limit=_PAGE_SIZE - ) + response = self._client.api.observations.get_many(trace_id=trace_id, page=page, limit=_PAGE_SIZE) all_observations.extend(response.data) if page >= response.meta.total_pages: break @@ -293,11 +287,13 @@ def _parse_assistant_content(self, content_data: Any) -> list[TextContent | Tool result.append(TextContent(text=item["text"])) elif "toolUse" in item: tu = item["toolUse"] - result.append(ToolCallContent( - name=tu["name"], - 
arguments=tu.get("input", {}), - tool_call_id=tu.get("toolUseId"), - )) + result.append( + ToolCallContent( + name=tu["name"], + arguments=tu.get("input", {}), + tool_call_id=tu.get("toolUseId"), + ) + ) elif isinstance(content_data, str): result.append(TextContent(text=content_data)) return result @@ -312,11 +308,13 @@ def _parse_tool_result_content(self, content_data: list) -> list[TextContent | T if "content" in tr and tr["content"]: c = tr["content"] text = c[0].get("text", "") if isinstance(c, list) else str(c) - result.append(ToolResultContent( - content=text, - error=tr.get("error"), - tool_call_id=tr.get("toolUseId"), - )) + result.append( + ToolResultContent( + content=text, + error=tr.get("error"), + tool_call_id=tr.get("toolUseId"), + ) + ) return result def _convert_tool_execution(self, obs: Any, session_id: str) -> ToolExecutionSpan: @@ -428,4 +426,4 @@ def _extract_output(self, session: Session) -> str: for span in reversed(trace.spans): if isinstance(span, AgentInvocationSpan): return span.agent_response - return "" \ No newline at end of file + return "" diff --git a/tests/strands_evals/providers/test_cloudwatch_provider.py b/tests/strands_evals/providers/test_cloudwatch_provider.py new file mode 100644 index 0000000..ef0dac6 --- /dev/null +++ b/tests/strands_evals/providers/test_cloudwatch_provider.py @@ -0,0 +1,734 @@ +"""Tests for CloudWatchProvider — mocked boto3 CloudWatch Logs client.""" + +import json +import os +from unittest.mock import MagicMock, patch + +import pytest + +from strands_evals.providers.cloudwatch_provider import CloudWatchProvider +from strands_evals.providers.exceptions import ( + ProviderError, + SessionNotFoundError, + TraceNotFoundError, +) +from strands_evals.types.trace import ( + AgentInvocationSpan, + InferenceSpan, + Session, + ToolExecutionSpan, +) + +# --- Fixtures --- + + +@pytest.fixture +def mock_logs_client(): + return MagicMock() + + +@pytest.fixture +def provider(mock_logs_client): + with 
patch("boto3.client", return_value=mock_logs_client): + return CloudWatchProvider(log_group="/test/group") + + +# --- Constructor --- + + +class TestConstructor: + def test_explicit_log_group(self, mock_logs_client): + with patch("boto3.client", return_value=mock_logs_client): + p = CloudWatchProvider(log_group="/custom/group") + assert p._log_group == "/custom/group" + assert p._lookback_days == 30 + assert p._query_timeout_seconds == 60.0 + + def test_custom_params(self, mock_logs_client): + with patch("boto3.client", return_value=mock_logs_client) as mock_boto: + p = CloudWatchProvider( + region="eu-west-1", + log_group="/custom/log-group", + lookback_days=7, + query_timeout_seconds=120.0, + ) + mock_boto.assert_called_once_with("logs", region_name="eu-west-1") + assert p._log_group == "/custom/log-group" + assert p._lookback_days == 7 + assert p._query_timeout_seconds == 120.0 + + def test_agent_name_discovery(self, mock_logs_client): + mock_logs_client.describe_log_groups.return_value = { + "logGroups": [{"logGroupName": "/aws/bedrock-agentcore/runtimes/my-agent-abc-DEFAULT"}] + } + with patch("boto3.client", return_value=mock_logs_client): + p = CloudWatchProvider(agent_name="my-agent") + assert p._log_group == "/aws/bedrock-agentcore/runtimes/my-agent-abc-DEFAULT" + mock_logs_client.describe_log_groups.assert_called_once_with( + logGroupNamePrefix="/aws/bedrock-agentcore/runtimes/my-agent" + ) + + def test_agent_name_no_match_raises(self, mock_logs_client): + mock_logs_client.describe_log_groups.return_value = {"logGroups": []} + with ( + patch("boto3.client", return_value=mock_logs_client), + pytest.raises(ProviderError, match="no log group found"), + ): + CloudWatchProvider(agent_name="nonexistent-agent") + + def test_neither_log_group_nor_agent_name_raises(self, mock_logs_client): + with ( + patch("boto3.client", return_value=mock_logs_client), + pytest.raises(ProviderError, match="log_group.*agent_name"), + ): + CloudWatchProvider() + + def 
test_region_from_aws_region_env(self, mock_logs_client): + env = os.environ.copy() + env.pop("AWS_DEFAULT_REGION", None) + env["AWS_REGION"] = "ap-southeast-1" + with ( + patch.dict(os.environ, env, clear=True), + patch("boto3.client", return_value=mock_logs_client) as mock_boto, + ): + CloudWatchProvider(log_group="/test/group") + mock_boto.assert_called_once_with("logs", region_name="ap-southeast-1") + + def test_region_from_aws_default_region_env(self, mock_logs_client): + env = os.environ.copy() + env.pop("AWS_REGION", None) + env["AWS_DEFAULT_REGION"] = "us-west-2" + with ( + patch.dict(os.environ, env, clear=True), + patch("boto3.client", return_value=mock_logs_client) as mock_boto, + ): + CloudWatchProvider(log_group="/test/group") + mock_boto.assert_called_once_with("logs", region_name="us-west-2") + + def test_aws_region_takes_precedence_over_default_region(self, mock_logs_client): + env = {"AWS_REGION": "eu-central-1", "AWS_DEFAULT_REGION": "us-west-2"} + with ( + patch.dict(os.environ, env), + patch("boto3.client", return_value=mock_logs_client) as mock_boto, + ): + CloudWatchProvider(log_group="/test/group") + mock_boto.assert_called_once_with("logs", region_name="eu-central-1") + + def test_explicit_region_overrides_env(self, mock_logs_client): + env = {"AWS_REGION": "eu-central-1"} + with ( + patch.dict(os.environ, env), + patch("boto3.client", return_value=mock_logs_client) as mock_boto, + ): + CloudWatchProvider(region="ca-central-1", log_group="/test/group") + mock_boto.assert_called_once_with("logs", region_name="ca-central-1") + + def test_boto3_client_creation_failure(self): + with ( + patch("boto3.client", side_effect=Exception("bad credentials")), + pytest.raises(ProviderError, match="bad credentials"), + ): + CloudWatchProvider(log_group="/test/group") + + +# --- Helpers for body-format log records --- + + +def _setup_query_results(mock_logs_client, records): + """Wire up mock to return records from a CW Logs Insights query.""" + 
mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": [[{"field": "@message", "value": json.dumps(r)}] for r in records], + } + + +def _make_log_record( + trace_id="abc123", + span_id="span-1", + input_messages=None, + output_messages=None, + session_id="sess-1", + time_nano=1000000000000000000, +): + """Build a body-format OTEL log record dict as found in runtime log groups.""" + record = { + "traceId": trace_id, + "spanId": span_id, + "timeUnixNano": time_nano, + "body": { + "input": {"messages": input_messages or []}, + "output": {"messages": output_messages or []}, + }, + "attributes": {"session.id": session_id}, + } + return record + + +def _make_user_message(text): + """Build a user input message with double-encoded content.""" + return {"role": "user", "content": {"content": json.dumps([{"text": text}])}} + + +def _make_assistant_text_message(text): + """Build an assistant output message with double-encoded text content.""" + return { + "role": "assistant", + "content": {"message": json.dumps([{"text": text}]), "finish_reason": "end_turn"}, + } + + +def _make_assistant_tool_use_message(tool_name, tool_input, tool_use_id): + """Build an assistant output message with a toolUse block.""" + return { + "role": "assistant", + "content": { + "message": json.dumps([{"toolUse": {"name": tool_name, "input": tool_input, "toolUseId": tool_use_id}}]), + "finish_reason": "tool_use", + }, + } + + +def _make_tool_result_message(tool_use_id, result_text): + """Build a tool result input message with double-encoded content.""" + return { + "role": "tool", + "content": { + "content": json.dumps([{"toolResult": {"content": [{"text": result_text}], "toolUseId": tool_use_id}}]) + }, + } + + +# --- Span conversion (body-based parsing) --- + + +class TestSpanConversion: + def test_single_record_produces_inference_span(self, provider): + """One log record produces an InferenceSpan with 
input/output messages.""" + record = _make_log_record( + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("Hello!")], + ) + session = provider._build_session("sess-1", [record]) + spans = session.traces[0].spans + inference_spans = [s for s in spans if isinstance(s, InferenceSpan)] + assert len(inference_spans) == 1 + assert inference_spans[0].messages[0].content[0].text == "Hi" + assert inference_spans[0].messages[1].content[0].text == "Hello!" + + def test_record_with_tool_use_and_result(self, provider): + """toolUse in output + toolResult in next record's input → ToolExecutionSpan.""" + # Record 1: user asks, assistant calls tool + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Calculate 6*7")], + output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")], + time_nano=1000, + ) + # Record 2: tool result comes back, assistant responds + record2 = _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[ + _make_user_message("Calculate 6*7"), + _make_tool_result_message("tu-1", "42"), + ], + output_messages=[_make_assistant_text_message("The answer is 42.")], + time_nano=2000, + ) + session = provider._build_session("sess-1", [record1, record2]) + tool_spans = [s for s in session.traces[0].spans if isinstance(s, ToolExecutionSpan)] + assert len(tool_spans) == 1 + assert tool_spans[0].tool_call.name == "calculator" + assert tool_spans[0].tool_call.arguments == {"expr": "6*7"} + assert tool_spans[0].tool_result.content == "42" + + def test_agent_invocation_from_trace(self, provider): + """User prompt from first record, response from last → AgentInvocationSpan.""" + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Tell me a joke")], + output_messages=[_make_assistant_text_message("Why did the chicken cross the road?")], + time_nano=1000, + ) + session = 
provider._build_session("sess-1", [record1]) + agent_spans = [s for s in session.traces[0].spans if isinstance(s, AgentInvocationSpan)] + assert len(agent_spans) == 1 + assert agent_spans[0].user_prompt == "Tell me a joke" + assert agent_spans[0].agent_response == "Why did the chicken cross the road?" + + def test_agent_invocation_extracts_tools(self, provider): + """available_tools populated from tool call names in the trace.""" + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Search for X")], + output_messages=[_make_assistant_tool_use_message("web_search", {"q": "X"}, "tu-1")], + time_nano=1000, + ) + record2 = _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[_make_user_message("Search for X"), _make_tool_result_message("tu-1", "found X")], + output_messages=[_make_assistant_text_message("Here's what I found about X.")], + time_nano=2000, + ) + session = provider._build_session("sess-1", [record1, record2]) + agent_spans = [s for s in session.traces[0].spans if isinstance(s, AgentInvocationSpan)] + assert len(agent_spans) == 1 + tool_names = [t.name for t in agent_spans[0].available_tools] + assert "web_search" in tool_names + + def test_double_encoded_content_parsed(self, provider): + """Content field is a JSON string that must be parsed to get content blocks.""" + record = _make_log_record( + input_messages=[_make_user_message("test double encoding")], + output_messages=[_make_assistant_text_message("parsed correctly")], + ) + session = provider._build_session("sess-1", [record]) + inference_spans = [s for s in session.traces[0].spans if isinstance(s, InferenceSpan)] + assert inference_spans[0].messages[0].content[0].text == "test double encoding" + assert inference_spans[0].messages[1].content[0].text == "parsed correctly" + + def test_tool_call_matched_to_result_by_id(self, provider): + """toolUseId matching works across records in the same trace.""" + record1 = _make_log_record( + 
trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Do two things")], + output_messages=[ + _make_assistant_tool_use_message("tool_a", {"x": 1}, "tu-a"), + ], + time_nano=1000, + ) + record2 = _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[ + _make_user_message("Do two things"), + _make_tool_result_message("tu-a", "result-a"), + ], + output_messages=[ + _make_assistant_tool_use_message("tool_b", {"y": 2}, "tu-b"), + ], + time_nano=2000, + ) + record3 = _make_log_record( + trace_id="t1", + span_id="s3", + input_messages=[ + _make_user_message("Do two things"), + _make_tool_result_message("tu-a", "result-a"), + _make_tool_result_message("tu-b", "result-b"), + ], + output_messages=[_make_assistant_text_message("Both done.")], + time_nano=3000, + ) + session = provider._build_session("sess-1", [record1, record2, record3]) + tool_spans = [s for s in session.traces[0].spans if isinstance(s, ToolExecutionSpan)] + assert len(tool_spans) == 2 + tool_span_by_name = {ts.tool_call.name: ts for ts in tool_spans} + assert tool_span_by_name["tool_a"].tool_result.content == "result-a" + assert tool_span_by_name["tool_b"].tool_result.content == "result-b" + + +# --- Session building --- + + +class TestSessionBuilding: + def test_multiple_records_grouped_by_trace_id(self, provider): + """Records with different traceIds become separate Trace objects.""" + records = [ + _make_log_record( + trace_id="t1", + input_messages=[_make_user_message("q1")], + output_messages=[_make_assistant_text_message("a1")], + ), + _make_log_record( + trace_id="t2", + input_messages=[_make_user_message("q2")], + output_messages=[_make_assistant_text_message("a2")], + ), + ] + session = provider._build_session("sess-1", records) + assert session.session_id == "sess-1" + assert len(session.traces) == 2 + trace_ids = {t.trace_id for t in session.traces} + assert trace_ids == {"t1", "t2"} + + def test_multi_step_agent_loop(self, provider): + """user→LLM→tool→LLM→response 
produces InferenceSpan + ToolExecutionSpan + AgentInvocationSpan.""" + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("What is 6*7?")], + output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")], + time_nano=1000, + ) + record2 = _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[ + _make_user_message("What is 6*7?"), + _make_tool_result_message("tu-1", "42"), + ], + output_messages=[_make_assistant_text_message("The answer is 42.")], + time_nano=2000, + ) + session = provider._build_session("sess-1", [record1, record2]) + assert len(session.traces) == 1 + spans = session.traces[0].spans + span_types = [type(s).__name__ for s in spans] + assert "InferenceSpan" in span_types + assert "ToolExecutionSpan" in span_types + assert "AgentInvocationSpan" in span_types + + def test_empty_records_list(self, provider): + session = provider._build_session("sess-1", []) + assert session.session_id == "sess-1" + assert session.traces == [] + + def test_extract_output_from_agent_response(self, provider): + """_extract_output returns last agent response text.""" + records = [ + _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("First response")], + time_nano=1000, + ), + _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("Final response")], + time_nano=2000, + ), + ] + session = provider._build_session("sess-1", records) + output = provider._extract_output(session) + assert output == "Final response" + + def test_record_with_no_body_skipped(self, provider): + """Malformed records without body don't crash.""" + records = [ + {"traceId": "t1", "spanId": "s1", "timeUnixNano": 1000}, + _make_log_record( + trace_id="t1", + input_messages=[_make_user_message("Hi")], + 
output_messages=[_make_assistant_text_message("Hello!")], + time_nano=2000, + ), + ] + session = provider._build_session("sess-1", records) + assert len(session.traces) == 1 + assert len(session.traces[0].spans) > 0 + + +# --- CW Logs Insights polling --- + + +class TestLogsInsightsPolling: + def _make_record_json(self, trace_id="t1", span_id="s1"): + """Return a JSON-serialized body-format log record for use in CW Logs @message fields.""" + return json.dumps( + _make_log_record( + trace_id=trace_id, + span_id=span_id, + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("Hello!")], + ) + ) + + def test_happy_path(self, provider, mock_logs_client): + """Query starts, one poll, completes with results.""" + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": [ + [{"field": "@message", "value": self._make_record_json()}], + ], + } + results = provider._run_logs_insights_query("fields @message") + assert len(results) == 1 + assert results[0]["traceId"] == "t1" + + def test_polls_through_intermediate_statuses(self, provider, mock_logs_client): + """Query goes through Scheduled → Running → Complete.""" + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.side_effect = [ + {"status": "Scheduled", "results": []}, + {"status": "Running", "results": []}, + { + "status": "Complete", + "results": [ + [{"field": "@message", "value": self._make_record_json()}], + ], + }, + ] + with patch("time.sleep"): + results = provider._run_logs_insights_query("fields @message") + assert len(results) == 1 + + def test_failed_status_raises(self, provider, mock_logs_client): + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Failed", + "results": [], + } + with pytest.raises(ProviderError, match="Failed"): + 
provider._run_logs_insights_query("fields @message") + + def test_timeout_raises(self, provider, mock_logs_client): + """If query doesn't complete within timeout, raises ProviderError.""" + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Running", + "results": [], + } + provider._query_timeout_seconds = 0.01 + with ( + patch("time.sleep"), + patch("time.monotonic", side_effect=[0.0, 0.0, 1.0]), + pytest.raises(ProviderError, match="timed out"), + ): + provider._run_logs_insights_query("fields @message") + + def test_empty_results(self, provider, mock_logs_client): + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": [], + } + results = provider._run_logs_insights_query("fields @message") + assert results == [] + + def test_parses_message_field(self, provider, mock_logs_client): + """Each result row's @message field is parsed as JSON into a record dict.""" + record1 = _make_log_record(trace_id="t1", span_id="s1") + record2 = _make_log_record(trace_id="t1", span_id="s2") + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": [ + [{"field": "@message", "value": json.dumps(record1)}], + [{"field": "@message", "value": json.dumps(record2)}], + ], + } + results = provider._run_logs_insights_query("fields @message") + assert len(results) == 2 + assert results[0]["spanId"] == "s1" + assert results[1]["spanId"] == "s2" + + def test_start_query_failure_raises(self, provider, mock_logs_client): + mock_logs_client.start_query.side_effect = Exception("access denied") + with pytest.raises(ProviderError, match="access denied"): + provider._run_logs_insights_query("fields @message") + + +# --- get_evaluation_data --- + + +class TestGetEvaluationData: + def test_happy_path(self, provider, 
mock_logs_client): + records = [ + _make_log_record( + trace_id="t1", + span_id="s1", + session_id="sess-1", + input_messages=[_make_user_message("What is 6*7?")], + output_messages=[_make_assistant_text_message("The answer is 42.")], + ) + ] + _setup_query_results(mock_logs_client, records) + + result = provider.get_evaluation_data("sess-1") + assert isinstance(result["trajectory"], Session) + assert result["trajectory"].session_id == "sess-1" + assert len(result["trajectory"].traces) == 1 + assert result["output"] == "The answer is 42." + + def test_no_results_raises_session_not_found(self, provider, mock_logs_client): + _setup_query_results(mock_logs_client, []) + with pytest.raises(SessionNotFoundError, match="sess-missing"): + provider.get_evaluation_data("sess-missing") + + def test_query_failure_raises_provider_error(self, provider, mock_logs_client): + mock_logs_client.start_query.side_effect = Exception("throttled") + with pytest.raises(ProviderError, match="throttled"): + provider.get_evaluation_data("sess-1") + + def test_query_uses_session_id_filter(self, provider, mock_logs_client): + """Verify the query string uses attributes.session.id filter.""" + records = [ + _make_log_record( + session_id="sess-1", + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("Hello")], + ) + ] + _setup_query_results(mock_logs_client, records) + provider.get_evaluation_data("sess-1") + query_string = mock_logs_client.start_query.call_args[1]["queryString"] + assert "attributes.session.id" in query_string + assert "sess-1" in query_string + + def test_multiple_traces(self, provider, mock_logs_client): + records = [ + _make_log_record( + trace_id="t1", + input_messages=[_make_user_message("q1")], + output_messages=[_make_assistant_text_message("first")], + ), + _make_log_record( + trace_id="t2", + input_messages=[_make_user_message("q2")], + output_messages=[_make_assistant_text_message("second")], + ), + ] + 
_setup_query_results(mock_logs_client, records) + result = provider.get_evaluation_data("sess-1") + assert len(result["trajectory"].traces) == 2 + assert result["output"] == "second" + + def test_output_from_last_agent_invocation(self, provider, mock_logs_client): + records = [ + _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("first")], + time_nano=1000, + ), + _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("last")], + time_nano=2000, + ), + ] + _setup_query_results(mock_logs_client, records) + assert provider.get_evaluation_data("sess-1")["output"] == "last" + + +# --- get_evaluation_data_by_trace_id --- + + +class TestGetEvaluationDataByTraceId: + def test_happy_path(self, provider, mock_logs_client): + records = [ + _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("What is 6*7?")], + output_messages=[_make_assistant_text_message("The answer is 42.")], + ) + ] + _setup_query_results(mock_logs_client, records) + result = provider.get_evaluation_data_by_trace_id("t1") + assert isinstance(result["trajectory"], Session) + assert result["trajectory"].traces[0].trace_id == "t1" + assert result["output"] == "The answer is 42." 
+ + def test_not_found_raises(self, provider, mock_logs_client): + _setup_query_results(mock_logs_client, []) + with pytest.raises(TraceNotFoundError, match="t-missing"): + provider.get_evaluation_data_by_trace_id("t-missing") + + def test_query_failure_raises(self, provider, mock_logs_client): + mock_logs_client.start_query.side_effect = Exception("throttled") + with pytest.raises(ProviderError, match="throttled"): + provider.get_evaluation_data_by_trace_id("t1") + + def test_query_uses_trace_id_filter(self, provider, mock_logs_client): + records = [ + _make_log_record( + trace_id="t-abc", + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("Hello")], + ) + ] + _setup_query_results(mock_logs_client, records) + provider.get_evaluation_data_by_trace_id("t-abc") + query_string = mock_logs_client.start_query.call_args[1]["queryString"] + assert "t-abc" in query_string + + def test_session_id_from_record_attributes(self, provider, mock_logs_client): + """Session ID is taken from record attributes when available.""" + records = [ + _make_log_record( + trace_id="t1", + session_id="sess-from-record", + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("Hello")], + ) + ] + _setup_query_results(mock_logs_client, records) + result = provider.get_evaluation_data_by_trace_id("t1") + assert result["trajectory"].session_id == "sess-from-record" + + +# --- list_sessions --- + + +def _setup_session_query(mock_logs_client, session_ids): + """Wire up mock to return session IDs from a stats aggregation query.""" + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + results = [] + for sid in session_ids: + results.append( + [ + {"field": "sessionId", "value": sid}, + {"field": "span_count", "value": "5"}, + ] + ) + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": results, + } + + +class TestListSessions: + def test_returns_session_ids(self, provider, 
mock_logs_client): + _setup_session_query(mock_logs_client, ["s1", "s2", "s3"]) + assert list(provider.list_sessions()) == ["s1", "s2", "s3"] + + def test_empty_results(self, provider, mock_logs_client): + _setup_session_query(mock_logs_client, []) + assert list(provider.list_sessions()) == [] + + def test_time_filter_applied(self, provider, mock_logs_client): + from datetime import datetime, timezone + + from strands_evals.providers.trace_provider import SessionFilter + + start = datetime(2025, 1, 1, tzinfo=timezone.utc) + end = datetime(2025, 1, 31, tzinfo=timezone.utc) + _setup_session_query(mock_logs_client, ["s1"]) + list(provider.list_sessions(session_filter=SessionFilter(start_time=start, end_time=end))) + kw = mock_logs_client.start_query.call_args[1] + assert kw["startTime"] == int(start.timestamp()) + assert kw["endTime"] == int(end.timestamp()) + + def test_limit_applied(self, provider, mock_logs_client): + from strands_evals.providers.trace_provider import SessionFilter + + _setup_session_query(mock_logs_client, ["s1"]) + list(provider.list_sessions(session_filter=SessionFilter(limit=50))) + query_string = mock_logs_client.start_query.call_args[1]["queryString"] + assert "limit 50" in query_string + + def test_default_limit(self, provider, mock_logs_client): + _setup_session_query(mock_logs_client, ["s1"]) + list(provider.list_sessions()) + query_string = mock_logs_client.start_query.call_args[1]["queryString"] + assert "limit 1000" in query_string + + def test_query_failure_raises(self, provider, mock_logs_client): + mock_logs_client.start_query.side_effect = Exception("access denied") + with pytest.raises(ProviderError, match="access denied"): + list(provider.list_sessions()) diff --git a/tests/strands_evals/providers/test_langfuse_provider.py b/tests/strands_evals/providers/test_langfuse_provider.py index 08eb268..c03e3e0 100644 --- a/tests/strands_evals/providers/test_langfuse_provider.py +++ b/tests/strands_evals/providers/test_langfuse_provider.py 
@@ -18,7 +18,6 @@ ToolExecutionSpan, ) - # --- Helpers --- @@ -36,8 +35,19 @@ def _trace(trace_id, session_id, output=None): return t -def _obs(obs_id, trace_id, obs_type, name=None, obs_input=None, obs_output=None, - start_time=None, end_time=None, parent_observation_id=None, metadata=None, model=None): +def _obs( + obs_id, + trace_id, + obs_type, + name=None, + obs_input=None, + obs_output=None, + start_time=None, + end_time=None, + parent_observation_id=None, + metadata=None, + model=None, +): o = MagicMock() o.id, o.trace_id, o.type, o.name = obs_id, trace_id, obs_type, name o.start_time = start_time or datetime(2025, 1, 15, 10, 0, 0, tzinfo=timezone.utc) @@ -70,6 +80,7 @@ def mock_client(): def provider(mock_client): with patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client): from strands_evals.providers.langfuse_provider import LangfuseProvider + return LangfuseProvider(public_key="pk-test", secret_key="sk-test") @@ -80,6 +91,7 @@ class TestConstructor: def test_explicit_credentials(self, mock_client): with patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client) as cls: from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider(public_key="pk-1", secret_key="sk-2", host="https://custom.langfuse.com") cls.assert_called_once_with(public_key="pk-1", secret_key="sk-2", host="https://custom.langfuse.com") @@ -94,6 +106,7 @@ def test_env_var_fallback(self, mock_client): patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client) as cls, ): from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider() cls.assert_called_once_with(public_key="pk-env", secret_key="sk-env", host="https://us.cloud.langfuse.com") @@ -108,6 +121,7 @@ def test_host_env_var_fallback(self, mock_client): patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client) as cls, ): from strands_evals.providers.langfuse_provider 
import LangfuseProvider + LangfuseProvider() cls.assert_called_once_with( public_key="pk-env", secret_key="sk-env", host="https://my-langfuse.example.com" @@ -117,8 +131,13 @@ def test_missing_credentials_raises(self): env = os.environ.copy() env.pop("LANGFUSE_PUBLIC_KEY", None) env.pop("LANGFUSE_SECRET_KEY", None) - with patch.dict(os.environ, env, clear=True), pytest.raises(ProviderError, match="Langfuse credentials"): + with ( + patch.dict(os.environ, env, clear=True), + patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=MagicMock()), + pytest.raises(ProviderError, match="Langfuse credentials"), + ): from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider() def test_default_host(self, mock_client): @@ -130,6 +149,7 @@ def test_default_host(self, mock_client): patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client) as cls, ): from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider(public_key="pk", secret_key="sk") assert cls.call_args[1]["host"] == "https://us.cloud.langfuse.com" @@ -140,6 +160,7 @@ def test_explicit_host_overrides_env(self, mock_client): patch("strands_evals.providers.langfuse_provider.Langfuse", return_value=mock_client) as cls, ): from strands_evals.providers.langfuse_provider import LangfuseProvider + LangfuseProvider(public_key="pk", secret_key="sk", host="https://explicit.example.com") assert cls.call_args[1]["host"] == "https://explicit.example.com" @@ -150,9 +171,11 @@ def test_explicit_host_overrides_env(self, mock_client): class TestGetEvaluationData: def test_happy_path(self, provider, mock_client): mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) - mock_client.api.observations.get_many.return_value = _paginated([ - _obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a"), - ]) + mock_client.api.observations.get_many.return_value = _paginated( + [ + 
_obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a"), + ] + ) result = provider.get_evaluation_data("s1") assert isinstance(result["trajectory"], Session) assert result["trajectory"].session_id == "s1" @@ -165,12 +188,28 @@ def test_empty_session_raises(self, provider, mock_client): def test_output_from_last_agent_invocation(self, provider, mock_client): mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) - mock_client.api.observations.get_many.return_value = _paginated([ - _obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q1"}], obs_output="first", - start_time=datetime(2025, 1, 15, 10, 0, 0, tzinfo=timezone.utc)), - _obs("o2", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q2"}], obs_output="second", - start_time=datetime(2025, 1, 15, 10, 1, 0, tzinfo=timezone.utc)), - ]) + mock_client.api.observations.get_many.return_value = _paginated( + [ + _obs( + "o1", + "t1", + "SPAN", + name="invoke_agent a", + obs_input=[{"text": "q1"}], + obs_output="first", + start_time=datetime(2025, 1, 15, 10, 0, 0, tzinfo=timezone.utc), + ), + _obs( + "o2", + "t1", + "SPAN", + name="invoke_agent a", + obs_input=[{"text": "q2"}], + obs_output="second", + start_time=datetime(2025, 1, 15, 10, 1, 0, tzinfo=timezone.utc), + ), + ] + ) assert provider.get_evaluation_data("s1")["output"] == "second" def test_paginates_traces(self, provider, mock_client): @@ -187,12 +226,25 @@ def test_paginates_traces(self, provider, mock_client): def test_paginates_observations(self, provider, mock_client): mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) mock_client.api.observations.get_many.side_effect = [ - _paginated([_obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a")], - page=1, total_pages=2), - _paginated([_obs("o2", "t1", "GENERATION", name="chat", - obs_input=[{"role": "user", "content": [{"text": "q"}]}], - obs_output={"role": "assistant", 
"content": [{"text": "a"}]})], - page=2, total_pages=2), + _paginated( + [_obs("o1", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a")], + page=1, + total_pages=2, + ), + _paginated( + [ + _obs( + "o2", + "t1", + "GENERATION", + name="chat", + obs_input=[{"role": "user", "content": [{"text": "q"}]}], + obs_output={"role": "assistant", "content": [{"text": "a"}]}, + ) + ], + page=2, + total_pages=2, + ), ] assert len(provider.get_evaluation_data("s1")["trajectory"].traces[0].spans) == 2 @@ -203,9 +255,11 @@ def test_wraps_api_error(self, provider, mock_client): def test_unconvertible_observations_excluded(self, provider, mock_client): mock_client.api.trace.list.return_value = _paginated([_trace("t1", "s1")]) - mock_client.api.observations.get_many.return_value = _paginated([ - _obs("o1", "t1", "EVENT", name="some_event"), - ]) + mock_client.api.observations.get_many.return_value = _paginated( + [ + _obs("o1", "t1", "EVENT", name="some_event"), + ] + ) with pytest.raises(SessionNotFoundError): provider.get_evaluation_data("s1") @@ -220,65 +274,109 @@ def _get_spans(self, provider, mock_client, observations): return provider.get_evaluation_data("s1")["trajectory"].traces[0].spans def test_generation_to_inference_span(self, provider, mock_client): - spans = self._get_spans(provider, mock_client, [ - _obs("o-gen", "t1", "GENERATION", name="chat", - obs_input=[{"role": "user", "content": [{"text": "q"}]}], - obs_output={"role": "assistant", "content": [{"text": "a"}]}), - _obs("o-agent", "t1", "SPAN", name="invoke_agent a", - obs_input=[{"text": "q"}], obs_output="a"), - ]) + spans = self._get_spans( + provider, + mock_client, + [ + _obs( + "o-gen", + "t1", + "GENERATION", + name="chat", + obs_input=[{"role": "user", "content": [{"text": "q"}]}], + obs_output={"role": "assistant", "content": [{"text": "a"}]}, + ), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a"), + ], + ) inference = [s for s 
in spans if isinstance(s, InferenceSpan)] assert len(inference) == 1 assert inference[0].span_info.span_id == "o-gen" def test_execute_tool_to_tool_execution_span(self, provider, mock_client): - spans = self._get_spans(provider, mock_client, [ - _obs("o-tool", "t1", "SPAN", name="execute_tool calc", - obs_input={"name": "calc", "arguments": {"x": "2+2"}, "toolUseId": "c1"}, - obs_output={"result": "4", "status": "success"}), - _obs("o-agent", "t1", "SPAN", name="invoke_agent a", - obs_input=[{"text": "q"}], obs_output="a"), - ]) + spans = self._get_spans( + provider, + mock_client, + [ + _obs( + "o-tool", + "t1", + "SPAN", + name="execute_tool calc", + obs_input={"name": "calc", "arguments": {"x": "2+2"}, "toolUseId": "c1"}, + obs_output={"result": "4", "status": "success"}, + ), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a"), + ], + ) tools = [s for s in spans if isinstance(s, ToolExecutionSpan)] assert len(tools) == 1 assert tools[0].tool_call.name == "calc" assert tools[0].tool_call.arguments == {"x": "2+2"} def test_invoke_agent_to_agent_invocation_span(self, provider, mock_client): - spans = self._get_spans(provider, mock_client, [ - _obs("o-agent", "t1", "SPAN", name="invoke_agent my_agent", - obs_input=[{"text": "Hello"}], obs_output="Hi there!"), - ]) + spans = self._get_spans( + provider, + mock_client, + [ + _obs( + "o-agent", + "t1", + "SPAN", + name="invoke_agent my_agent", + obs_input=[{"text": "Hello"}], + obs_output="Hi there!", + ), + ], + ) agents = [s for s in spans if isinstance(s, AgentInvocationSpan)] assert len(agents) == 1 assert agents[0].user_prompt == "Hello" assert agents[0].agent_response == "Hi there!" 
def test_unknown_type_skipped(self, provider, mock_client): - spans = self._get_spans(provider, mock_client, [ - _obs("o-event", "t1", "EVENT", name="log"), - _obs("o-agent", "t1", "SPAN", name="invoke_agent a", - obs_input=[{"text": "q"}], obs_output="a"), - ]) + spans = self._get_spans( + provider, + mock_client, + [ + _obs("o-event", "t1", "EVENT", name="log"), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a"), + ], + ) assert len(spans) == 1 assert isinstance(spans[0], AgentInvocationSpan) def test_unknown_span_name_skipped(self, provider, mock_client): - spans = self._get_spans(provider, mock_client, [ - _obs("o-unk", "t1", "SPAN", name="some_other_op"), - _obs("o-agent", "t1", "SPAN", name="invoke_agent a", - obs_input=[{"text": "q"}], obs_output="a"), - ]) + spans = self._get_spans( + provider, + mock_client, + [ + _obs("o-unk", "t1", "SPAN", name="some_other_op"), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a"), + ], + ) assert len(spans) == 1 def test_span_info_populated(self, provider, mock_client): start = datetime(2025, 6, 1, 12, 0, 0, tzinfo=timezone.utc) end = datetime(2025, 6, 1, 12, 0, 10, tzinfo=timezone.utc) - spans = self._get_spans(provider, mock_client, [ - _obs("o-agent", "t1", "SPAN", name="invoke_agent a", - obs_input=[{"text": "q"}], obs_output="a", - start_time=start, end_time=end, parent_observation_id="o-parent"), - ]) + spans = self._get_spans( + provider, + mock_client, + [ + _obs( + "o-agent", + "t1", + "SPAN", + name="invoke_agent a", + obs_input=[{"text": "q"}], + obs_output="a", + start_time=start, + end_time=end, + parent_observation_id="o-parent", + ), + ], + ) si = spans[0].span_info assert si.trace_id == "t1" assert si.span_id == "o-agent" @@ -288,19 +386,38 @@ def test_span_info_populated(self, provider, mock_client): assert si.end_time == end def test_string_input_for_agent(self, provider, mock_client): - spans = 
self._get_spans(provider, mock_client, [ - _obs("o-agent", "t1", "SPAN", name="invoke_agent a", - obs_input="plain string prompt", obs_output="response"), - ]) + spans = self._get_spans( + provider, + mock_client, + [ + _obs( + "o-agent", + "t1", + "SPAN", + name="invoke_agent a", + obs_input="plain string prompt", + obs_output="response", + ), + ], + ) assert spans[0].user_prompt == "plain string prompt" def test_string_output_for_tool(self, provider, mock_client): - spans = self._get_spans(provider, mock_client, [ - _obs("o-tool", "t1", "SPAN", name="execute_tool calc", - obs_input={"name": "calc", "arguments": {"x": 1}}, obs_output="42"), - _obs("o-agent", "t1", "SPAN", name="invoke_agent a", - obs_input=[{"text": "q"}], obs_output="a"), - ]) + spans = self._get_spans( + provider, + mock_client, + [ + _obs( + "o-tool", + "t1", + "SPAN", + name="execute_tool calc", + obs_input={"name": "calc", "arguments": {"x": 1}}, + obs_output="42", + ), + _obs("o-agent", "t1", "SPAN", name="invoke_agent a", obs_input=[{"text": "q"}], obs_output="a"), + ], + ) tools = [s for s in spans if isinstance(s, ToolExecutionSpan)] assert tools[0].tool_result.content == "42" @@ -311,7 +428,8 @@ def test_string_output_for_tool(self, provider, mock_client): class TestListSessions: def test_yields_ids(self, provider, mock_client): mock_client.api.sessions.list.return_value = _paginated( - [_lf_session("s1"), _lf_session("s2"), _lf_session("s3")]) + [_lf_session("s1"), _lf_session("s2"), _lf_session("s3")] + ) assert list(provider.list_sessions()) == ["s1", "s2", "s3"] def test_paginates(self, provider, mock_client): @@ -323,6 +441,7 @@ def test_paginates(self, provider, mock_client): def test_time_filter(self, provider, mock_client): from strands_evals.providers.trace_provider import SessionFilter + start = datetime(2025, 1, 1, tzinfo=timezone.utc) end = datetime(2025, 1, 31, tzinfo=timezone.utc) mock_client.api.sessions.list.return_value = _paginated([_lf_session("s1")]) @@ -349,8 +468,7 
"""Integration tests for CloudWatchProvider against real CloudWatch Logs data.

Requires AWS credentials with CloudWatch Logs read access. Uses the
'strands_github_bot' AWS profile if available, otherwise falls back to default credentials.
Run with: pytest tests_integ/test_cloudwatch_provider.py -v
"""

from typing import Any

import boto3
import pytest

from strands_evals import Case, Experiment
from strands_evals.evaluators import (
    CoherenceEvaluator,
    HelpfulnessEvaluator,
    OutputEvaluator,
)
from strands_evals.providers.cloudwatch_provider import CloudWatchProvider
from strands_evals.providers.exceptions import (
    ProviderError,
    SessionNotFoundError,
    TraceNotFoundError,
)
from strands_evals.types.trace import (
    AgentInvocationSpan,
    InferenceSpan,
    Session,
    ToolExecutionSpan,
    Trace,
)

LOG_GROUP = "/aws/bedrock-agentcore/runtimes/github_issue_handler-zf6fZR2saQ-DEFAULT"

# Session IDs probed in order by the ``session_id`` fixture.
KNOWN_SESSION_IDS = [
    "github_issue_68_20260218_172558_d27beb07",
    "github-issue-68-1771435544179-d5gc8kk4xan",
]

EXPECTED_ACCOUNT_ID = "249746592913"
AWS_PROFILE = "strands_github_bot"
AWS_REGION = "us-east-1"


def _create_logs_client() -> Any | None:
    """Return a CloudWatch Logs client for the expected account, or None.

    Tries the named profile first, then falls back to default credentials. A
    candidate is accepted only when STS reports ``EXPECTED_ACCOUNT_ID``.

    The return annotation is intentionally ``Any``: ``boto3.client`` is a
    factory function, not a type, so ``boto3.client | None`` is evaluated at
    definition time and raises ``TypeError`` on import.
    """
    for profile in (AWS_PROFILE, None):
        try:
            session = boto3.Session(profile_name=profile, region_name=AWS_REGION)
            identity = session.client("sts").get_caller_identity()
            if identity["Account"] == EXPECTED_ACCOUNT_ID:
                return session.client("logs")
        except Exception:
            # Deliberately broad best-effort probe: any failure for this
            # credential source simply means "try the next one".
            continue
    return None


def _run_session_experiment(provider, session_id, evaluators):
    """Run ``evaluators`` against one CloudWatch session and return the reports.

    Builds a single Case whose input is the session ID; the task fetches the
    evaluation data for that ID from ``provider``.
    """

    def task(case: Case) -> dict:
        return provider.get_evaluation_data(case.input)

    cases = [
        Case(
            name="cloudwatch_session",
            input=session_id,
            expected_output="any agent response",
        ),
    ]
    experiment = Experiment(cases=cases, evaluators=evaluators)
    return experiment.run_evaluations(task)


@pytest.fixture(scope="module")
def provider():
    """Create a CloudWatchProvider targeting the correct AWS account."""
    client = _create_logs_client()
    if client is None:
        pytest.skip(f"No AWS credentials found for account {EXPECTED_ACCOUNT_ID}")

    try:
        cw = CloudWatchProvider(log_group=LOG_GROUP, region=AWS_REGION)
    except ProviderError as e:
        pytest.skip(f"CloudWatch provider creation failed: {e}")

    # Inject the verified client
    cw._client = client
    return cw


@pytest.fixture(scope="module")
def session_id(provider):
    """Try known session IDs; skip if none found."""
    for sid in KNOWN_SESSION_IDS:
        try:
            provider.get_evaluation_data(sid)
        except (SessionNotFoundError, ProviderError):
            continue
        return sid
    pytest.skip("No known sessions found in CloudWatch")


@pytest.fixture(scope="module")
def evaluation_data(provider, session_id):
    """Fetch evaluation data for the discovered session."""
    return provider.get_evaluation_data(session_id)


class TestListSessions:
    def test_returns_at_least_one_session(self, provider):
        sessions = list(provider.list_sessions())
        assert len(sessions) > 0, "Expected at least one session in CloudWatch"

    def test_session_ids_are_strings(self, provider):
        for session_id in provider.list_sessions():
            assert isinstance(session_id, str)
            assert len(session_id) > 0
            break  # Only check the first one


class TestGetEvaluationData:
    def test_returns_session_with_traces(self, evaluation_data, session_id):
        session = evaluation_data["trajectory"]
        assert isinstance(session, Session)
        assert session.session_id == session_id
        assert len(session.traces) > 0

    def test_traces_have_spans(self, evaluation_data):
        session = evaluation_data["trajectory"]
        for trace in session.traces:
            assert isinstance(trace, Trace)
            assert isinstance(trace.trace_id, str)
            assert len(trace.trace_id) > 0
            assert len(trace.spans) > 0

    def test_spans_are_typed(self, evaluation_data):
        """All spans should be one of the three known types."""
        valid_types = (AgentInvocationSpan, InferenceSpan, ToolExecutionSpan)
        session = evaluation_data["trajectory"]
        for trace in session.traces:
            for span in trace.spans:
                assert isinstance(span, valid_types), f"Unexpected span type: {type(span).__name__}"

    def test_has_agent_invocation_span(self, evaluation_data):
        """At least one trace should have an AgentInvocationSpan."""
        session = evaluation_data["trajectory"]
        agent_spans = [
            span for trace in session.traces for span in trace.spans if isinstance(span, AgentInvocationSpan)
        ]
        assert len(agent_spans) > 0, "Expected at least one AgentInvocationSpan"

    def test_agent_invocation_has_prompt_and_response(self, evaluation_data):
        """AgentInvocationSpan should have user prompt, agent response, and tools list."""
        session = evaluation_data["trajectory"]
        for trace in session.traces:
            for span in trace.spans:
                if isinstance(span, AgentInvocationSpan):
                    assert isinstance(span.user_prompt, str)
                    assert len(span.user_prompt) > 0
                    assert isinstance(span.agent_response, str)
                    assert len(span.agent_response) > 0
                    assert isinstance(span.available_tools, list)
                    return  # Only the first agent span is checked
        pytest.fail("No AgentInvocationSpan found")

    def test_output_is_nonempty_string(self, evaluation_data):
        output = evaluation_data["output"]
        assert isinstance(output, str)
        assert len(output) > 0, "Expected non-empty output from agent response"

    def test_span_info_populated(self, evaluation_data):
        session = evaluation_data["trajectory"]
        for trace in session.traces:
            for span in trace.spans:
                si = span.span_info
                assert si.trace_id is not None
                assert si.span_id is not None
                assert si.session_id == session.session_id
                assert si.start_time is not None
                assert si.end_time is not None
                return  # Only the first span is checked
        pytest.fail("No spans found to check")

    def test_nonexistent_session_raises(self, provider):
        with pytest.raises(SessionNotFoundError):
            provider.get_evaluation_data("nonexistent-session-id-that-does-not-exist-12345")


class TestGetEvaluationDataByTraceId:
    def test_fetches_by_trace_id(self, provider, evaluation_data):
        """Use a trace_id from the discovered session to test trace-level retrieval."""
        session = evaluation_data["trajectory"]
        trace_id = session.traces[0].trace_id

        result = provider.get_evaluation_data_by_trace_id(trace_id)

        assert isinstance(result["trajectory"], Session)
        assert len(result["trajectory"].traces) > 0
        assert result["trajectory"].traces[0].trace_id == trace_id

    def test_nonexistent_trace_raises(self, provider):
        with pytest.raises((TraceNotFoundError, ProviderError)):
            provider.get_evaluation_data_by_trace_id("nonexistent-trace-id-12345")


# --- End-to-end: CloudWatch → Evaluator pipeline ---


class TestEndToEnd:
    """Fetch traces from CloudWatch and run real evaluators on them."""

    def test_output_evaluator_on_remote_trace(self, provider, session_id):
        """OutputEvaluator produces a valid score from a CloudWatch session."""
        evaluator = OutputEvaluator(
            rubric="Score 1.0 if the output is a coherent response from an AI agent. "
            "Score 0.0 if the output is empty or clearly broken.",
        )

        reports = _run_session_experiment(provider, session_id, [evaluator])

        assert len(reports) == 1
        report = reports[0]
        assert 0.0 <= report.overall_score <= 1.0
        assert len(report.scores) == 1

    def test_coherence_evaluator_on_remote_trace(self, provider, session_id):
        """CoherenceEvaluator produces a valid score from a CloudWatch session."""
        reports = _run_session_experiment(provider, session_id, [CoherenceEvaluator()])

        assert len(reports) == 1
        report = reports[0]
        assert 0.0 <= report.overall_score <= 1.0

    def test_multiple_evaluators_on_remote_trace(self, provider, session_id):
        """Multiple evaluators can all run on the same CloudWatch session data."""
        evaluators = [
            OutputEvaluator(rubric="Score 1.0 if the output is coherent. Score 0.0 otherwise."),
            CoherenceEvaluator(),
            HelpfulnessEvaluator(),
        ]

        reports = _run_session_experiment(provider, session_id, evaluators)

        assert len(reports) == 3
        for report in reports:
            assert 0.0 <= report.overall_score <= 1.0
            assert len(report.scores) == 1
test_multiple_evaluators_on_remote_trace(self, provider, discovered_session_id): """Multiple evaluators can all run on the same Langfuse session data.""" @@ -237,6 +230,5 @@ def task(case: Case) -> dict: assert len(reports) == 3 for report in reports: - assert report.score is not None - assert 0.0 <= report.score <= 1.0 - assert len(report.case_results) == 1 + assert 0.0 <= report.overall_score <= 1.0 + assert len(report.scores) == 1 From 61599d2f1a81e0f7af6097da4cacee97a83137a5 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Thu, 26 Feb 2026 09:25:45 -0500 Subject: [PATCH 6/8] feat: add CloudWatchSessionMapper for CW Logs Insights records Add a new session mapper that converts CloudWatch Logs Insights OTEL log records into typed Session objects for evaluation. The mapper parses body.input/output messages, extracts tool calls/results, and builds InferenceSpan, ToolExecutionSpan, and AgentInvocationSpan instances grouped by traceId. Includes comprehensive unit tests. --- src/strands_evals/mappers/__init__.py | 3 +- .../mappers/cloudwatch_session_mapper.py | 354 +++++++++++++++ src/strands_evals/providers/README.md | 13 +- src/strands_evals/providers/__init__.py | 2 - .../providers/cloudwatch_provider.py | 414 +----------------- src/strands_evals/providers/exceptions.py | 4 - .../mappers/test_cloudwatch_session_mapper.py | 279 ++++++++++++ .../providers/test_cloudwatch_provider.py | 356 +-------------- .../providers/test_trace_provider.py | 6 +- tests_integ/test_cloudwatch_provider.py | 30 -- 10 files changed, 652 insertions(+), 809 deletions(-) create mode 100644 src/strands_evals/mappers/cloudwatch_session_mapper.py create mode 100644 tests/strands_evals/mappers/test_cloudwatch_session_mapper.py diff --git a/src/strands_evals/mappers/__init__.py b/src/strands_evals/mappers/__init__.py index 8a27078..5325de4 100644 --- a/src/strands_evals/mappers/__init__.py +++ b/src/strands_evals/mappers/__init__.py @@ -1,6 +1,7 @@ """Converters for transforming telemetry data 
"""CloudWatch session mapper — converts CW Logs Insights records to Session format."""

import json
import logging
from collections import defaultdict
from collections.abc import Iterator
from datetime import datetime, timezone
from typing import Any

from ..mappers.session_mapper import SessionMapper
from ..types.trace import (
    AgentInvocationSpan,
    AssistantMessage,
    InferenceSpan,
    Session,
    SpanInfo,
    TextContent,
    ToolCall,
    ToolCallContent,
    ToolConfig,
    ToolExecutionSpan,
    ToolResult,
    ToolResultContent,
    Trace,
    UserMessage,
)

logger = logging.getLogger(__name__)


class CloudWatchSessionMapper(SessionMapper):
    """Maps CloudWatch Logs Insights records to Session format.

    Parses body.input/output messages from OTEL log records emitted by
    Strands agent runtimes and builds typed Session objects for evaluation.

    Malformed pieces of a record (non-dict messages, non-dict content blocks,
    unparseable timestamps) are skipped rather than aborting the whole session.
    """

    def map_to_session(self, records: list[dict[str, Any]], session_id: str) -> Session:
        """Group log records by traceId, convert each group to a Trace, return a Session.

        Records without a ``traceId`` are ignored, and traces that yield no
        spans are left out of the resulting Session.
        """
        traces_by_id: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for record in records:
            trace_id = record.get("traceId", "")
            if not trace_id:
                continue
            traces_by_id[trace_id].append(record)

        traces: list[Trace] = []
        for trace_id, trace_records in traces_by_id.items():
            trace = self._convert_trace(trace_id, trace_records, session_id)
            if trace.spans:
                traces.append(trace)

        return Session(session_id=session_id, traces=traces)

    def _convert_trace(self, trace_id: str, records: list[dict[str, Any]], session_id: str) -> Trace:
        """Convert a group of log records (same traceId) into a Trace with typed spans.

        Spans are emitted in this order: all InferenceSpans (chronological),
        then one ToolExecutionSpan per distinct tool call ID (first occurrence),
        then at most one AgentInvocationSpan.
        """
        # Numeric sort key: timeUnixNano may arrive as an int or a decimal string.
        sorted_records = sorted(records, key=self._time_unix_nano)

        spans: list[InferenceSpan | ToolExecutionSpan | AgentInvocationSpan] = []

        # Collect tool calls and results across records in a single pass.
        all_tool_calls: dict[str, ToolCall] = {}
        all_tool_results: dict[str, ToolResult] = {}
        # First occurrence of every tool call ID with the record it came from;
        # that record supplies the span info of the ToolExecutionSpan.
        first_call_sites: list[tuple[ToolCall, dict[str, Any]]] = []
        seen_tool_ids: set[str] = set()

        for record in sorted_records:
            if not isinstance(record.get("body"), dict):
                continue

            for tc in self._extract_tool_calls(record):
                if not tc.tool_call_id:
                    continue
                all_tool_calls[tc.tool_call_id] = tc
                if tc.tool_call_id not in seen_tool_ids:
                    seen_tool_ids.add(tc.tool_call_id)
                    first_call_sites.append((tc, record))

            for tr in self._extract_tool_results(record):
                if tr.tool_call_id:
                    all_tool_results[tr.tool_call_id] = tr

        # Create InferenceSpans (one per record with parseable body)
        for record in sorted_records:
            if not isinstance(record.get("body"), dict):
                continue

            try:
                messages = self._record_to_messages(record)
                if messages:
                    span_info = self._create_span_info(record, session_id)
                    spans.append(InferenceSpan(span_info=span_info, messages=messages, metadata={}))
            except Exception as e:
                # Best effort: one bad record must not discard the rest of the trace.
                logger.warning("Failed to create inference span from record %s: %s", record.get("spanId"), e)

        # Create ToolExecutionSpans by matching calls to results
        for tc, record in first_call_sites:
            tr = all_tool_results.get(tc.tool_call_id)
            if tr is None:
                # No result was logged for this call; keep the call with an empty result.
                tr = ToolResult(content="", tool_call_id=tc.tool_call_id)
            span_info = self._create_span_info(record, session_id)
            spans.append(ToolExecutionSpan(span_info=span_info, tool_call=tc, tool_result=tr, metadata={}))

        # Create AgentInvocationSpan from first user prompt + last agent response
        agent_span = self._create_agent_invocation_span(sorted_records, all_tool_calls, session_id)
        if agent_span:
            spans.append(agent_span)

        return Trace(spans=spans, trace_id=trace_id, session_id=session_id)

    def _create_agent_invocation_span(
        self, records: list[dict[str, Any]], tool_calls: dict[str, ToolCall], session_id: str
    ) -> AgentInvocationSpan | None:
        """Create an AgentInvocationSpan from the first user prompt and last agent response.

        Returns None when either the prompt or the response cannot be found.
        ``available_tools`` lists the distinct names of the tools in ``tool_calls``.
        """
        user_prompt = None
        for record in records:
            prompt = self._extract_user_prompt(record)
            if prompt:
                user_prompt = prompt
                break

        if not user_prompt:
            return None

        agent_response = None
        best_record = None
        for record in reversed(records):
            response = self._extract_agent_response(record)
            if response:
                agent_response = response
                best_record = record
                break

        if not agent_response or not best_record:
            return None

        available_tools = [ToolConfig(name=name) for name in sorted({tc.name for tc in tool_calls.values()})]
        span_info = self._create_span_info(best_record, session_id)

        return AgentInvocationSpan(
            span_info=span_info,
            user_prompt=user_prompt,
            agent_response=agent_response,
            available_tools=available_tools,
            metadata={},
        )

    # --- Span info ---

    @staticmethod
    def _time_unix_nano(record: dict[str, Any]) -> int:
        """Return the record's ``timeUnixNano`` as an int; 0 when missing or unparseable."""
        try:
            return int(record.get("timeUnixNano") or 0)
        except (TypeError, ValueError):
            return 0

    def _create_span_info(self, record: dict[str, Any], session_id: str) -> SpanInfo:
        """Build SpanInfo for a record; start and end time are both the record timestamp."""
        time_nano = self._time_unix_nano(record)
        ts = datetime.fromtimestamp(time_nano / 1e9, tz=timezone.utc)

        return SpanInfo(
            trace_id=record.get("traceId", ""),
            span_id=record.get("spanId", ""),
            session_id=session_id,
            parent_span_id=record.get("parentSpanId") or None,
            start_time=ts,
            end_time=ts,
        )

    # --- Body-based content extraction ---

    @staticmethod
    def _iter_messages(record: dict[str, Any], section: str) -> Iterator[dict[str, Any]]:
        """Yield the dict messages of ``body[section]["messages"]``; nothing if the shape is off."""
        body = record.get("body")
        if not isinstance(body, dict):
            return
        container = body.get(section)
        if not isinstance(container, dict):
            return
        messages = container.get("messages")
        if not isinstance(messages, list):
            return
        for msg in messages:
            if isinstance(msg, dict):
                yield msg

    def _message_blocks(self, msg: dict[str, Any]) -> list[dict[str, Any]] | None:
        """Return the parsed content blocks of a message, or None when absent or unparseable."""
        raw = self._extract_content_field(msg.get("content", {}))
        return self._parse_message_content(raw) if raw else None

    def _parse_message_content(self, raw: str) -> list[dict[str, Any]] | None:
        """Parse double-encoded message content into a list of content blocks."""
        if not isinstance(raw, str):
            return None
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return None
        return parsed if isinstance(parsed, list) else None

    def _extract_content_field(self, content: dict[str, Any]) -> Any:
        """Extract the raw content field (``content`` or, failing that, ``message``) from a message."""
        if not isinstance(content, dict):
            return None
        return content.get("content") or content.get("message")

    def _extract_text_from_content(self, content: Any) -> str | None:
        """Extract text from a content field, handling double-encoded JSON strings.

        Text blocks are joined with a single space. A raw string that is not a
        JSON list of blocks is returned unchanged.
        """
        raw = self._extract_content_field(content)
        if not raw:
            return None

        parsed = self._parse_message_content(raw)
        if parsed:
            texts = [item["text"] for item in parsed if isinstance(item, dict) and isinstance(item.get("text"), str)]
            return " ".join(texts) if texts else None

        return raw if isinstance(raw, str) else None

    def _extract_message_text(self, record: dict[str, Any], message_type: str, role: str) -> str | None:
        """Extract text from a specific message type and role in a log record."""
        for msg in self._iter_messages(record, message_type):
            if msg.get("role") == role:
                text = self._extract_text_from_content(msg.get("content", {}))
                if text:
                    return text
        return None

    def _extract_user_prompt(self, record: dict[str, Any]) -> str | None:
        """Extract user prompt text from a log record's body.input.messages."""
        return self._extract_message_text(record, "input", "user")

    def _extract_agent_response(self, record: dict[str, Any]) -> str | None:
        """Extract assistant text response from a log record's body.output.messages."""
        return self._extract_message_text(record, "output", "assistant")

    def _extract_tool_calls(self, record: dict[str, Any]) -> list[ToolCall]:
        """Extract tool calls from a log record's body.output.messages."""
        tool_calls: list[ToolCall] = []

        for msg in self._iter_messages(record, "output"):
            if msg.get("role") != "assistant":
                continue

            parsed = self._message_blocks(msg)
            if not parsed:
                continue

            for item in parsed:
                if not isinstance(item, dict):
                    continue
                tu = item.get("toolUse")
                if not isinstance(tu, dict):
                    continue
                tool_calls.append(
                    ToolCall(
                        name=tu.get("name", ""),
                        arguments=tu.get("input", {}),
                        tool_call_id=tu.get("toolUseId"),
                    )
                )

        return tool_calls

    def _extract_tool_results(self, record: dict[str, Any]) -> list[ToolResult]:
        """Extract tool results from a log record's body.input.messages."""
        tool_results: list[ToolResult] = []

        for msg in self._iter_messages(record, "input"):
            parsed = self._message_blocks(msg)
            if not parsed:
                continue

            for item in parsed:
                if not isinstance(item, dict):
                    continue
                tr_data = item.get("toolResult")
                if not isinstance(tr_data, dict):
                    continue
                tool_results.append(
                    ToolResult(
                        content=self._extract_tool_result_text(tr_data.get("content")),
                        error=tr_data.get("error"),
                        tool_call_id=tr_data.get("toolUseId"),
                    )
                )

        return tool_results

    def _extract_tool_result_text(self, content: Any) -> str:
        """Extract text from tool result content.

        For a list, only the first block is used: its ``text`` when it is a
        dict, otherwise its string form. Any other value is stringified.
        """
        if not content:
            return ""
        if isinstance(content, list):
            first = content[0]
            if isinstance(first, dict):
                text = first.get("text", "")
                return "" if text is None else str(text)
            return str(first)
        return str(content)

    # --- Record-to-messages conversion ---

    def _record_to_messages(self, record: dict[str, Any]) -> list[UserMessage | AssistantMessage]:
        """Convert a log record's body into a list of typed messages for InferenceSpan.

        Input messages with role ``user`` or ``tool`` become UserMessages;
        output messages with role ``assistant`` become AssistantMessages.
        Messages that produce no content are dropped.
        """
        messages: list[UserMessage | AssistantMessage] = []

        # Process input messages
        for msg in self._iter_messages(record, "input"):
            parsed = self._message_blocks(msg)
            if not parsed:
                continue

            role = msg.get("role", "")
            if role == "user":
                user_content = self._process_user_message(parsed)
                if user_content:
                    messages.append(UserMessage(content=user_content))
            elif role == "tool":
                tool_content = self._process_tool_results(parsed)
                if tool_content:
                    messages.append(UserMessage(content=tool_content))

        # Process output messages
        for msg in self._iter_messages(record, "output"):
            if msg.get("role") != "assistant":
                continue

            parsed = self._message_blocks(msg)
            if not parsed:
                continue

            assistant_content = self._process_assistant_content(parsed)
            if assistant_content:
                messages.append(AssistantMessage(content=assistant_content))

        return messages

    # --- Content parsing helpers (Bedrock Converse format) ---

    @staticmethod
    def _process_user_message(content_list: list[dict[str, Any]]) -> list[TextContent | ToolResultContent]:
        """Return a TextContent for every dict block that carries a ``text`` key."""
        # The dict check matters: on a str item, `"text" in item` is a substring test.
        return [TextContent(text=item["text"]) for item in content_list if isinstance(item, dict) and "text" in item]

    @staticmethod
    def _process_assistant_content(content_list: list[dict[str, Any]]) -> list[TextContent | ToolCallContent]:
        """Convert assistant content blocks into text and tool-call content, skipping other blocks."""
        result: list[TextContent | ToolCallContent] = []
        for item in content_list:
            if not isinstance(item, dict):
                continue
            if "text" in item:
                result.append(TextContent(text=item["text"]))
            elif isinstance(item.get("toolUse"), dict):
                tool_use = item["toolUse"]
                result.append(
                    ToolCallContent(
                        # Same missing-name default as _extract_tool_calls.
                        name=tool_use.get("name", ""),
                        arguments=tool_use.get("input", {}),
                        tool_call_id=tool_use.get("toolUseId"),
                    )
                )
        return result

    def _process_tool_results(self, content_list: list[dict[str, Any]]) -> list[TextContent | ToolResultContent]:
        """Convert ``toolResult`` blocks into ToolResultContent, skipping every other block."""
        result: list[TextContent | ToolResultContent] = []
        for item in content_list:
            if not isinstance(item, dict):
                continue
            tool_result = item.get("toolResult")
            if not isinstance(tool_result, dict):
                continue
            result.append(
                ToolResultContent(
                    content=self._extract_tool_result_text(tool_result.get("content")),
                    error=tool_result.get("error"),
                    tool_call_id=tool_result.get("toolUseId"),
                )
            )
        return result
``` diff --git a/src/strands_evals/providers/__init__.py b/src/strands_evals/providers/__init__.py index 40c9c4c..1b81ee3 100644 --- a/src/strands_evals/providers/__init__.py +++ b/src/strands_evals/providers/__init__.py @@ -3,7 +3,6 @@ from .exceptions import ( ProviderError, SessionNotFoundError, - TraceNotFoundError, TraceProviderError, ) from .trace_provider import ( @@ -14,7 +13,6 @@ "LangfuseProvider", "ProviderError", "SessionNotFoundError", - "TraceNotFoundError", "TraceProvider", "TraceProviderError", ] diff --git a/src/strands_evals/providers/cloudwatch_provider.py b/src/strands_evals/providers/cloudwatch_provider.py index d8e90ab..375ef11 100644 --- a/src/strands_evals/providers/cloudwatch_provider.py +++ b/src/strands_evals/providers/cloudwatch_provider.py @@ -4,31 +4,18 @@ import logging import os import time -from collections import defaultdict -from collections.abc import Iterator from datetime import datetime, timedelta, timezone from typing import Any import boto3 -from ..providers.exceptions import ProviderError, SessionNotFoundError, TraceNotFoundError -from ..providers.trace_provider import SessionFilter, TraceProvider +from ..mappers.cloudwatch_session_mapper import CloudWatchSessionMapper +from ..providers.exceptions import ProviderError, SessionNotFoundError +from ..providers.trace_provider import TraceProvider from ..types.evaluation import TaskOutput from ..types.trace import ( AgentInvocationSpan, - AssistantMessage, - InferenceSpan, Session, - SpanInfo, - TextContent, - ToolCall, - ToolCallContent, - ToolConfig, - ToolExecutionSpan, - ToolResult, - ToolResultContent, - Trace, - UserMessage, ) logger = logging.getLogger(__name__) @@ -66,6 +53,7 @@ def __init__( self._lookback_days = lookback_days self._query_timeout_seconds = query_timeout_seconds + self._mapper = CloudWatchSessionMapper() def _discover_log_group(self, agent_name: str) -> str: """Discover the runtime log group for an agent via describe_log_groups.""" @@ -90,7 +78,7 @@ def 
get_evaluation_data(self, session_id: str) -> TaskOutput: if not span_dicts: raise SessionNotFoundError(f"CloudWatch: no spans found for session_id='{session_id}'") - session = self._build_session(session_id, span_dicts) + session = self._mapper.map_to_session(span_dicts, session_id) if not session.traces: raise SessionNotFoundError( @@ -100,75 +88,13 @@ def get_evaluation_data(self, session_id: str) -> TaskOutput: output = self._extract_output(session) return TaskOutput(output=output, trajectory=session) - def get_evaluation_data_by_trace_id(self, trace_id: str) -> TaskOutput: - """Fetch a single trace by ID and return evaluation data.""" - query = f'fields @message | filter traceId = "{trace_id}" | sort @timestamp asc | limit 10000' - - try: - span_dicts = self._run_logs_insights_query(query) - except ProviderError: - raise - except Exception as e: - raise ProviderError(f"CloudWatch: failed to query trace '{trace_id}': {e}") from e - - if not span_dicts: - raise TraceNotFoundError(f"CloudWatch: no spans found for trace_id='{trace_id}'") - - # Extract session_id from the first record's attributes - first_attrs = span_dicts[0].get("attributes", {}) - session_id = first_attrs.get("session.id") or trace_id - - session = self._build_session(session_id, span_dicts) - output = self._extract_output(session) - return TaskOutput(output=output, trajectory=session) - - def list_sessions(self, session_filter: SessionFilter | None = None) -> Iterator[str]: - """Yield distinct session IDs from CloudWatch Logs.""" - limit = session_filter.limit if session_filter and session_filter.limit else 1000 - start_time = session_filter.start_time if session_filter else None - end_time = session_filter.end_time if session_filter else None - - query = ( - "fields attributes.session.id as sessionId" - " | filter ispresent(attributes.session.id)" - " | stats count(*) as span_count by sessionId" - " | sort sessionId asc" - f" | limit {limit}" - ) - - try: - results = 
self._run_raw_logs_insights_query(query, start_time=start_time, end_time=end_time) - except ProviderError: - raise - except Exception as e: - raise ProviderError(f"CloudWatch: failed to list sessions: {e}") from e - - for row in results: - for field in row: - if field.get("field") == "sessionId": - yield field["value"] - # --- Internal: CW Logs Insights query execution --- - def _run_logs_insights_query( - self, query: str, start_time: datetime | None = None, end_time: datetime | None = None - ) -> list[dict[str, Any]]: + def _run_logs_insights_query(self, query: str) -> list[dict[str, Any]]: """Execute a CW Logs Insights query and return parsed span dicts from @message fields.""" - raw_results = self._run_raw_logs_insights_query(query, start_time=start_time, end_time=end_time) - return self._parse_query_results(raw_results) - - def _run_raw_logs_insights_query( - self, - query: str, - start_time: datetime | None = None, - end_time: datetime | None = None, - ) -> list[list[dict[str, str]]]: - """Execute a CW Logs Insights query and return raw result rows.""" now = datetime.now(tz=timezone.utc) - if end_time is None: - end_time = now - if start_time is None: - start_time = now - timedelta(days=self._lookback_days) + end_time = now + start_time = now - timedelta(days=self._lookback_days) try: response = self._client.start_query( @@ -181,7 +107,8 @@ def _run_raw_logs_insights_query( raise ProviderError(f"CloudWatch: failed to start query: {e}") from e query_id = response["queryId"] - return self._poll_query_results(query_id) + raw_results = self._poll_query_results(query_id) + return self._parse_query_results(raw_results) def _poll_query_results(self, query_id: str) -> list[list[dict[str, str]]]: """Poll for query completion with exponential backoff. 
Returns raw result rows.""" @@ -217,113 +144,7 @@ def _parse_query_results(results: list[list[dict[str, str]]]) -> list[dict[str, logger.warning("Failed to parse @message: %s", e) return span_dicts - # --- Internal: session building (body-based parsing) --- - - def _build_session(self, session_id: str, records: list[dict[str, Any]]) -> Session: - """Group log records by traceId, convert each group to a Trace, return a Session.""" - traces_by_id: dict[str, list[dict[str, Any]]] = defaultdict(list) - for record in records: - trace_id = record.get("traceId", "") - if not trace_id: - continue - traces_by_id[trace_id].append(record) - - traces: list[Trace] = [] - for trace_id, trace_records in traces_by_id.items(): - trace = self._convert_trace(trace_id, trace_records, session_id) - if trace.spans: - traces.append(trace) - - return Session(session_id=session_id, traces=traces) - - def _convert_trace(self, trace_id: str, records: list[dict[str, Any]], session_id: str) -> Trace: - """Convert a group of log records (same traceId) into a Trace with typed spans.""" - sorted_records = sorted(records, key=lambda r: r.get("timeUnixNano", 0)) - - spans: list[InferenceSpan | ToolExecutionSpan | AgentInvocationSpan] = [] - - # Collect all tool calls and results across records - all_tool_calls: dict[str, ToolCall] = {} - all_tool_results: dict[str, ToolResult] = {} - - for record in sorted_records: - if not isinstance(record.get("body"), dict): - continue - - for tc in self._extract_tool_calls(record): - if tc.tool_call_id: - all_tool_calls[tc.tool_call_id] = tc - - for tr in self._extract_tool_results(record): - if tr.tool_call_id: - all_tool_results[tr.tool_call_id] = tr - - # Create InferenceSpans (one per record with parseable body) - for record in sorted_records: - if not isinstance(record.get("body"), dict): - continue - - try: - messages = self._record_to_messages(record) - if messages: - span_info = self._create_span_info(record, session_id) - 
spans.append(InferenceSpan(span_info=span_info, messages=messages, metadata={})) - except Exception as e: - logger.warning("Failed to create inference span from record %s: %s", record.get("spanId"), e) - - # Create ToolExecutionSpans by matching calls to results - seen_tool_ids: set[str] = set() - for record in sorted_records: - for tc in self._extract_tool_calls(record): - if tc.tool_call_id and tc.tool_call_id not in seen_tool_ids: - seen_tool_ids.add(tc.tool_call_id) - tr = all_tool_results.get(tc.tool_call_id, ToolResult(content="", tool_call_id=tc.tool_call_id)) - span_info = self._create_span_info(record, session_id) - spans.append(ToolExecutionSpan(span_info=span_info, tool_call=tc, tool_result=tr, metadata={})) - - # Create AgentInvocationSpan from first user prompt + last agent response - agent_span = self._create_agent_invocation_span(sorted_records, all_tool_calls, session_id) - if agent_span: - spans.append(agent_span) - - return Trace(spans=spans, trace_id=trace_id, session_id=session_id) - - def _create_agent_invocation_span( - self, records: list[dict[str, Any]], tool_calls: dict[str, ToolCall], session_id: str - ) -> AgentInvocationSpan | None: - """Create an AgentInvocationSpan from the first user prompt and last agent response.""" - user_prompt = None - for record in records: - prompt = self._extract_user_prompt(record) - if prompt: - user_prompt = prompt - break - - if not user_prompt: - return None - - agent_response = None - best_record = None - for record in reversed(records): - response = self._extract_agent_response(record) - if response: - agent_response = response - best_record = record - break - - if not agent_response or not best_record: - return None - - available_tools = [ToolConfig(name=name) for name in sorted({tc.name for tc in tool_calls.values()})] - span_info = self._create_span_info(best_record, session_id) - - return AgentInvocationSpan( - span_info=span_info, - user_prompt=user_prompt, - agent_response=agent_response, - 
available_tools=available_tools, - metadata={}, - ) + # --- Internal: output extraction --- def _extract_output(self, session: Session) -> str: """Extract the final agent response from the session for TaskOutput.output.""" @@ -332,216 +153,3 @@ def _extract_output(self, session: Session) -> str: if isinstance(span, AgentInvocationSpan): return span.agent_response return "" - - # --- Internal: span info --- - - def _create_span_info(self, record: dict[str, Any], session_id: str) -> SpanInfo: - time_nano = record.get("timeUnixNano", 0) - ts = datetime.fromtimestamp(time_nano / 1e9, tz=timezone.utc) - - return SpanInfo( - trace_id=record.get("traceId", ""), - span_id=record.get("spanId", ""), - session_id=session_id, - parent_span_id=record.get("parentSpanId") or None, - start_time=ts, - end_time=ts, - ) - - # --- Internal: body-based content extraction --- - - def _parse_message_content(self, raw: str) -> list[dict[str, Any]] | None: - """Parse double-encoded message content into a list of content blocks.""" - if not isinstance(raw, str): - return None - try: - parsed = json.loads(raw) - return parsed if isinstance(parsed, list) else None - except (json.JSONDecodeError, TypeError): - return None - - def _extract_content_field(self, content: dict[str, Any]) -> str | None: - """Extract the raw content field from a message.""" - if not isinstance(content, dict): - return None - return content.get("content") or content.get("message") - - def _extract_text_from_content(self, content: Any) -> str | None: - """Extract text from a content field, handling double-encoded JSON strings.""" - raw = self._extract_content_field(content) - if not raw: - return None - - parsed = self._parse_message_content(raw) - if parsed: - texts = [item["text"] for item in parsed if isinstance(item, dict) and "text" in item] - return " ".join(texts) if texts else None - - return raw if isinstance(raw, str) else None - - def _extract_message_text(self, record: dict[str, Any], message_type: str, 
role: str) -> str | None: - """Extract text from a specific message type and role in a log record.""" - body = record.get("body", {}) - if not isinstance(body, dict): - return None - - messages = body.get(message_type, {}).get("messages", []) - for msg in messages: - if msg.get("role") == role: - text = self._extract_text_from_content(msg.get("content", {})) - if text: - return text - return None - - def _extract_user_prompt(self, record: dict[str, Any]) -> str | None: - """Extract user prompt text from a log record's body.input.messages.""" - return self._extract_message_text(record, "input", "user") - - def _extract_agent_response(self, record: dict[str, Any]) -> str | None: - """Extract assistant text response from a log record's body.output.messages.""" - return self._extract_message_text(record, "output", "assistant") - - def _extract_tool_calls(self, record: dict[str, Any]) -> list[ToolCall]: - """Extract tool calls from a log record's body.output.messages.""" - tool_calls: list[ToolCall] = [] - body = record.get("body", {}) - if not isinstance(body, dict): - return tool_calls - - for msg in body.get("output", {}).get("messages", []): - if msg.get("role") != "assistant": - continue - - raw = self._extract_content_field(msg.get("content", {})) - parsed = self._parse_message_content(raw) if raw else None - if not parsed: - continue - - for item in parsed: - if isinstance(item, dict) and "toolUse" in item: - tu = item["toolUse"] - tool_calls.append( - ToolCall( - name=tu.get("name", ""), - arguments=tu.get("input", {}), - tool_call_id=tu.get("toolUseId"), - ) - ) - - return tool_calls - - def _extract_tool_results(self, record: dict[str, Any]) -> list[ToolResult]: - """Extract tool results from a log record's body.input.messages.""" - tool_results: list[ToolResult] = [] - body = record.get("body", {}) - if not isinstance(body, dict): - return tool_results - - for msg in body.get("input", {}).get("messages", []): - raw = 
self._extract_content_field(msg.get("content", {})) - parsed = self._parse_message_content(raw) if raw else None - if not parsed: - continue - - for item in parsed: - if isinstance(item, dict) and "toolResult" in item: - tr_data = item["toolResult"] - result_text = self._extract_tool_result_text(tr_data.get("content")) - tool_results.append( - ToolResult( - content=result_text, - error=tr_data.get("error"), - tool_call_id=tr_data.get("toolUseId"), - ) - ) - - return tool_results - - def _extract_tool_result_text(self, content: Any) -> str: - """Extract text from tool result content.""" - if not content: - return "" - if isinstance(content, list) and content: - return content[0].get("text", "") - return str(content) - - # --- Internal: record-to-messages conversion --- - - def _record_to_messages(self, record: dict[str, Any]) -> list[UserMessage | AssistantMessage]: - """Convert a log record's body into a list of typed messages for InferenceSpan.""" - messages: list[UserMessage | AssistantMessage] = [] - body = record.get("body", {}) - if not isinstance(body, dict): - return messages - - # Process input messages - for msg in body.get("input", {}).get("messages", []): - role = msg.get("role", "") - raw = self._extract_content_field(msg.get("content", {})) - parsed = self._parse_message_content(raw) if raw else None - if not parsed: - continue - - if role == "user": - user_content = self._process_user_message(parsed) - if user_content: - messages.append(UserMessage(content=user_content)) - elif role == "tool": - tool_content = self._process_tool_results(parsed) - if tool_content: - messages.append(UserMessage(content=tool_content)) - - # Process output messages - for msg in body.get("output", {}).get("messages", []): - if msg.get("role") != "assistant": - continue - - raw = self._extract_content_field(msg.get("content", {})) - parsed = self._parse_message_content(raw) if raw else None - if not parsed: - continue - - assistant_content = 
self._process_assistant_content(parsed) - if assistant_content: - messages.append(AssistantMessage(content=assistant_content)) - - return messages - - # --- Internal: content parsing helpers (Bedrock Converse format) --- - - @staticmethod - def _process_user_message(content_list: list[dict[str, Any]]) -> list[TextContent | ToolResultContent]: - return [TextContent(text=item["text"]) for item in content_list if "text" in item] - - @staticmethod - def _process_assistant_content(content_list: list[dict[str, Any]]) -> list[TextContent | ToolCallContent]: - result: list[TextContent | ToolCallContent] = [] - for item in content_list: - if "text" in item: - result.append(TextContent(text=item["text"])) - elif "toolUse" in item: - tool_use = item["toolUse"] - result.append( - ToolCallContent( - name=tool_use["name"], - arguments=tool_use.get("input", {}), - tool_call_id=tool_use.get("toolUseId"), - ) - ) - return result - - def _process_tool_results(self, content_list: list[dict[str, Any]]) -> list[TextContent | ToolResultContent]: - result: list[TextContent | ToolResultContent] = [] - for item in content_list: - if "toolResult" not in item: - continue - tool_result = item["toolResult"] - result_text = self._extract_tool_result_text(tool_result.get("content")) - result.append( - ToolResultContent( - content=result_text, - error=tool_result.get("error"), - tool_call_id=tool_result.get("toolUseId"), - ) - ) - return result diff --git a/src/strands_evals/providers/exceptions.py b/src/strands_evals/providers/exceptions.py index cb4f8a4..2823c1b 100644 --- a/src/strands_evals/providers/exceptions.py +++ b/src/strands_evals/providers/exceptions.py @@ -9,9 +9,5 @@ class SessionNotFoundError(TraceProviderError): """No traces found for the given session ID.""" -class TraceNotFoundError(TraceProviderError): - """Trace with the given ID not found.""" - - class ProviderError(TraceProviderError): """Provider is unreachable or returned an error.""" diff --git 
a/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py b/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py new file mode 100644 index 0000000..685ca43 --- /dev/null +++ b/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py @@ -0,0 +1,279 @@ +"""Tests for CloudWatchSessionMapper — body-format CW log record → Session conversion.""" + +import json + +from strands_evals.mappers.cloudwatch_session_mapper import CloudWatchSessionMapper +from strands_evals.types.trace import ( + AgentInvocationSpan, + InferenceSpan, + ToolExecutionSpan, +) + +# --- Helpers for body-format log records --- + + +def _make_log_record( + trace_id="abc123", + span_id="span-1", + input_messages=None, + output_messages=None, + session_id="sess-1", + time_nano=1000000000000000000, +): + """Build a body-format OTEL log record dict as found in runtime log groups.""" + record = { + "traceId": trace_id, + "spanId": span_id, + "timeUnixNano": time_nano, + "body": { + "input": {"messages": input_messages or []}, + "output": {"messages": output_messages or []}, + }, + "attributes": {"session.id": session_id}, + } + return record + + +def _make_user_message(text): + """Build a user input message with double-encoded content.""" + return {"role": "user", "content": {"content": json.dumps([{"text": text}])}} + + +def _make_assistant_text_message(text): + """Build an assistant output message with double-encoded text content.""" + return { + "role": "assistant", + "content": {"message": json.dumps([{"text": text}]), "finish_reason": "end_turn"}, + } + + +def _make_assistant_tool_use_message(tool_name, tool_input, tool_use_id): + """Build an assistant output message with a toolUse block.""" + return { + "role": "assistant", + "content": { + "message": json.dumps([{"toolUse": {"name": tool_name, "input": tool_input, "toolUseId": tool_use_id}}]), + "finish_reason": "tool_use", + }, + } + + +def _make_tool_result_message(tool_use_id, result_text): + """Build a tool result input 
message with double-encoded content.""" + return { + "role": "tool", + "content": { + "content": json.dumps([{"toolResult": {"content": [{"text": result_text}], "toolUseId": tool_use_id}}]) + }, + } + + +# --- Span conversion (body-based parsing) --- + + +class TestSpanConversion: + def setup_method(self): + self.mapper = CloudWatchSessionMapper() + + def test_single_record_produces_inference_span(self): + """One log record produces an InferenceSpan with input/output messages.""" + record = _make_log_record( + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("Hello!")], + ) + session = self.mapper.map_to_session([record], "sess-1") + spans = session.traces[0].spans + inference_spans = [s for s in spans if isinstance(s, InferenceSpan)] + assert len(inference_spans) == 1 + assert inference_spans[0].messages[0].content[0].text == "Hi" + assert inference_spans[0].messages[1].content[0].text == "Hello!" + + def test_record_with_tool_use_and_result(self): + """toolUse in output + toolResult in next record's input → ToolExecutionSpan.""" + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Calculate 6*7")], + output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")], + time_nano=1000, + ) + record2 = _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[ + _make_user_message("Calculate 6*7"), + _make_tool_result_message("tu-1", "42"), + ], + output_messages=[_make_assistant_text_message("The answer is 42.")], + time_nano=2000, + ) + session = self.mapper.map_to_session([record1, record2], "sess-1") + tool_spans = [s for s in session.traces[0].spans if isinstance(s, ToolExecutionSpan)] + assert len(tool_spans) == 1 + assert tool_spans[0].tool_call.name == "calculator" + assert tool_spans[0].tool_call.arguments == {"expr": "6*7"} + assert tool_spans[0].tool_result.content == "42" + + def test_agent_invocation_from_trace(self): + """User 
prompt from first record, response from last → AgentInvocationSpan.""" + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Tell me a joke")], + output_messages=[_make_assistant_text_message("Why did the chicken cross the road?")], + time_nano=1000, + ) + session = self.mapper.map_to_session([record1], "sess-1") + agent_spans = [s for s in session.traces[0].spans if isinstance(s, AgentInvocationSpan)] + assert len(agent_spans) == 1 + assert agent_spans[0].user_prompt == "Tell me a joke" + assert agent_spans[0].agent_response == "Why did the chicken cross the road?" + + def test_agent_invocation_extracts_tools(self): + """available_tools populated from tool call names in the trace.""" + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Search for X")], + output_messages=[_make_assistant_tool_use_message("web_search", {"q": "X"}, "tu-1")], + time_nano=1000, + ) + record2 = _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[_make_user_message("Search for X"), _make_tool_result_message("tu-1", "found X")], + output_messages=[_make_assistant_text_message("Here's what I found about X.")], + time_nano=2000, + ) + session = self.mapper.map_to_session([record1, record2], "sess-1") + agent_spans = [s for s in session.traces[0].spans if isinstance(s, AgentInvocationSpan)] + assert len(agent_spans) == 1 + tool_names = [t.name for t in agent_spans[0].available_tools] + assert "web_search" in tool_names + + def test_double_encoded_content_parsed(self): + """Content field is a JSON string that must be parsed to get content blocks.""" + record = _make_log_record( + input_messages=[_make_user_message("test double encoding")], + output_messages=[_make_assistant_text_message("parsed correctly")], + ) + session = self.mapper.map_to_session([record], "sess-1") + inference_spans = [s for s in session.traces[0].spans if isinstance(s, InferenceSpan)] + assert 
inference_spans[0].messages[0].content[0].text == "test double encoding" + assert inference_spans[0].messages[1].content[0].text == "parsed correctly" + + def test_tool_call_matched_to_result_by_id(self): + """toolUseId matching works across records in the same trace.""" + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("Do two things")], + output_messages=[ + _make_assistant_tool_use_message("tool_a", {"x": 1}, "tu-a"), + ], + time_nano=1000, + ) + record2 = _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[ + _make_user_message("Do two things"), + _make_tool_result_message("tu-a", "result-a"), + ], + output_messages=[ + _make_assistant_tool_use_message("tool_b", {"y": 2}, "tu-b"), + ], + time_nano=2000, + ) + record3 = _make_log_record( + trace_id="t1", + span_id="s3", + input_messages=[ + _make_user_message("Do two things"), + _make_tool_result_message("tu-a", "result-a"), + _make_tool_result_message("tu-b", "result-b"), + ], + output_messages=[_make_assistant_text_message("Both done.")], + time_nano=3000, + ) + session = self.mapper.map_to_session([record1, record2, record3], "sess-1") + tool_spans = [s for s in session.traces[0].spans if isinstance(s, ToolExecutionSpan)] + assert len(tool_spans) == 2 + tool_span_by_name = {ts.tool_call.name: ts for ts in tool_spans} + assert tool_span_by_name["tool_a"].tool_result.content == "result-a" + assert tool_span_by_name["tool_b"].tool_result.content == "result-b" + + +# --- Session building --- + + +class TestSessionBuilding: + def setup_method(self): + self.mapper = CloudWatchSessionMapper() + + def test_multiple_records_grouped_by_trace_id(self): + """Records with different traceIds become separate Trace objects.""" + records = [ + _make_log_record( + trace_id="t1", + input_messages=[_make_user_message("q1")], + output_messages=[_make_assistant_text_message("a1")], + ), + _make_log_record( + trace_id="t2", + 
input_messages=[_make_user_message("q2")], + output_messages=[_make_assistant_text_message("a2")], + ), + ] + session = self.mapper.map_to_session(records, "sess-1") + assert session.session_id == "sess-1" + assert len(session.traces) == 2 + trace_ids = {t.trace_id for t in session.traces} + assert trace_ids == {"t1", "t2"} + + def test_multi_step_agent_loop(self): + """user→LLM→tool→LLM→response produces InferenceSpan + ToolExecutionSpan + AgentInvocationSpan.""" + record1 = _make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[_make_user_message("What is 6*7?")], + output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")], + time_nano=1000, + ) + record2 = _make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[ + _make_user_message("What is 6*7?"), + _make_tool_result_message("tu-1", "42"), + ], + output_messages=[_make_assistant_text_message("The answer is 42.")], + time_nano=2000, + ) + session = self.mapper.map_to_session([record1, record2], "sess-1") + assert len(session.traces) == 1 + spans = session.traces[0].spans + span_types = [type(s).__name__ for s in spans] + assert "InferenceSpan" in span_types + assert "ToolExecutionSpan" in span_types + assert "AgentInvocationSpan" in span_types + + def test_empty_records_list(self): + session = self.mapper.map_to_session([], "sess-1") + assert session.session_id == "sess-1" + assert session.traces == [] + + def test_record_with_no_body_skipped(self): + """Malformed records without body don't crash.""" + records = [ + {"traceId": "t1", "spanId": "s1", "timeUnixNano": 1000}, + _make_log_record( + trace_id="t1", + input_messages=[_make_user_message("Hi")], + output_messages=[_make_assistant_text_message("Hello!")], + time_nano=2000, + ), + ] + session = self.mapper.map_to_session(records, "sess-1") + assert len(session.traces) == 1 + assert len(session.traces[0].spans) > 0 diff --git a/tests/strands_evals/providers/test_cloudwatch_provider.py 
b/tests/strands_evals/providers/test_cloudwatch_provider.py index ef0dac6..879fc33 100644 --- a/tests/strands_evals/providers/test_cloudwatch_provider.py +++ b/tests/strands_evals/providers/test_cloudwatch_provider.py @@ -10,14 +10,8 @@ from strands_evals.providers.exceptions import ( ProviderError, SessionNotFoundError, - TraceNotFoundError, -) -from strands_evals.types.trace import ( - AgentInvocationSpan, - InferenceSpan, - Session, - ToolExecutionSpan, ) +from strands_evals.types.trace import Session # --- Fixtures --- @@ -178,218 +172,10 @@ def _make_assistant_text_message(text): } -def _make_assistant_tool_use_message(tool_name, tool_input, tool_use_id): - """Build an assistant output message with a toolUse block.""" - return { - "role": "assistant", - "content": { - "message": json.dumps([{"toolUse": {"name": tool_name, "input": tool_input, "toolUseId": tool_use_id}}]), - "finish_reason": "tool_use", - }, - } - - -def _make_tool_result_message(tool_use_id, result_text): - """Build a tool result input message with double-encoded content.""" - return { - "role": "tool", - "content": { - "content": json.dumps([{"toolResult": {"content": [{"text": result_text}], "toolUseId": tool_use_id}}]) - }, - } - - -# --- Span conversion (body-based parsing) --- - - -class TestSpanConversion: - def test_single_record_produces_inference_span(self, provider): - """One log record produces an InferenceSpan with input/output messages.""" - record = _make_log_record( - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Hello!")], - ) - session = provider._build_session("sess-1", [record]) - spans = session.traces[0].spans - inference_spans = [s for s in spans if isinstance(s, InferenceSpan)] - assert len(inference_spans) == 1 - assert inference_spans[0].messages[0].content[0].text == "Hi" - assert inference_spans[0].messages[1].content[0].text == "Hello!" 
- - def test_record_with_tool_use_and_result(self, provider): - """toolUse in output + toolResult in next record's input → ToolExecutionSpan.""" - # Record 1: user asks, assistant calls tool - record1 = _make_log_record( - trace_id="t1", - span_id="s1", - input_messages=[_make_user_message("Calculate 6*7")], - output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")], - time_nano=1000, - ) - # Record 2: tool result comes back, assistant responds - record2 = _make_log_record( - trace_id="t1", - span_id="s2", - input_messages=[ - _make_user_message("Calculate 6*7"), - _make_tool_result_message("tu-1", "42"), - ], - output_messages=[_make_assistant_text_message("The answer is 42.")], - time_nano=2000, - ) - session = provider._build_session("sess-1", [record1, record2]) - tool_spans = [s for s in session.traces[0].spans if isinstance(s, ToolExecutionSpan)] - assert len(tool_spans) == 1 - assert tool_spans[0].tool_call.name == "calculator" - assert tool_spans[0].tool_call.arguments == {"expr": "6*7"} - assert tool_spans[0].tool_result.content == "42" - - def test_agent_invocation_from_trace(self, provider): - """User prompt from first record, response from last → AgentInvocationSpan.""" - record1 = _make_log_record( - trace_id="t1", - span_id="s1", - input_messages=[_make_user_message("Tell me a joke")], - output_messages=[_make_assistant_text_message("Why did the chicken cross the road?")], - time_nano=1000, - ) - session = provider._build_session("sess-1", [record1]) - agent_spans = [s for s in session.traces[0].spans if isinstance(s, AgentInvocationSpan)] - assert len(agent_spans) == 1 - assert agent_spans[0].user_prompt == "Tell me a joke" - assert agent_spans[0].agent_response == "Why did the chicken cross the road?" 
- - def test_agent_invocation_extracts_tools(self, provider): - """available_tools populated from tool call names in the trace.""" - record1 = _make_log_record( - trace_id="t1", - span_id="s1", - input_messages=[_make_user_message("Search for X")], - output_messages=[_make_assistant_tool_use_message("web_search", {"q": "X"}, "tu-1")], - time_nano=1000, - ) - record2 = _make_log_record( - trace_id="t1", - span_id="s2", - input_messages=[_make_user_message("Search for X"), _make_tool_result_message("tu-1", "found X")], - output_messages=[_make_assistant_text_message("Here's what I found about X.")], - time_nano=2000, - ) - session = provider._build_session("sess-1", [record1, record2]) - agent_spans = [s for s in session.traces[0].spans if isinstance(s, AgentInvocationSpan)] - assert len(agent_spans) == 1 - tool_names = [t.name for t in agent_spans[0].available_tools] - assert "web_search" in tool_names - - def test_double_encoded_content_parsed(self, provider): - """Content field is a JSON string that must be parsed to get content blocks.""" - record = _make_log_record( - input_messages=[_make_user_message("test double encoding")], - output_messages=[_make_assistant_text_message("parsed correctly")], - ) - session = provider._build_session("sess-1", [record]) - inference_spans = [s for s in session.traces[0].spans if isinstance(s, InferenceSpan)] - assert inference_spans[0].messages[0].content[0].text == "test double encoding" - assert inference_spans[0].messages[1].content[0].text == "parsed correctly" - - def test_tool_call_matched_to_result_by_id(self, provider): - """toolUseId matching works across records in the same trace.""" - record1 = _make_log_record( - trace_id="t1", - span_id="s1", - input_messages=[_make_user_message("Do two things")], - output_messages=[ - _make_assistant_tool_use_message("tool_a", {"x": 1}, "tu-a"), - ], - time_nano=1000, - ) - record2 = _make_log_record( - trace_id="t1", - span_id="s2", - input_messages=[ - _make_user_message("Do two 
things"), - _make_tool_result_message("tu-a", "result-a"), - ], - output_messages=[ - _make_assistant_tool_use_message("tool_b", {"y": 2}, "tu-b"), - ], - time_nano=2000, - ) - record3 = _make_log_record( - trace_id="t1", - span_id="s3", - input_messages=[ - _make_user_message("Do two things"), - _make_tool_result_message("tu-a", "result-a"), - _make_tool_result_message("tu-b", "result-b"), - ], - output_messages=[_make_assistant_text_message("Both done.")], - time_nano=3000, - ) - session = provider._build_session("sess-1", [record1, record2, record3]) - tool_spans = [s for s in session.traces[0].spans if isinstance(s, ToolExecutionSpan)] - assert len(tool_spans) == 2 - tool_span_by_name = {ts.tool_call.name: ts for ts in tool_spans} - assert tool_span_by_name["tool_a"].tool_result.content == "result-a" - assert tool_span_by_name["tool_b"].tool_result.content == "result-b" - +# --- Output extraction --- -# --- Session building --- - - -class TestSessionBuilding: - def test_multiple_records_grouped_by_trace_id(self, provider): - """Records with different traceIds become separate Trace objects.""" - records = [ - _make_log_record( - trace_id="t1", - input_messages=[_make_user_message("q1")], - output_messages=[_make_assistant_text_message("a1")], - ), - _make_log_record( - trace_id="t2", - input_messages=[_make_user_message("q2")], - output_messages=[_make_assistant_text_message("a2")], - ), - ] - session = provider._build_session("sess-1", records) - assert session.session_id == "sess-1" - assert len(session.traces) == 2 - trace_ids = {t.trace_id for t in session.traces} - assert trace_ids == {"t1", "t2"} - - def test_multi_step_agent_loop(self, provider): - """user→LLM→tool→LLM→response produces InferenceSpan + ToolExecutionSpan + AgentInvocationSpan.""" - record1 = _make_log_record( - trace_id="t1", - span_id="s1", - input_messages=[_make_user_message("What is 6*7?")], - output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")], - 
time_nano=1000, - ) - record2 = _make_log_record( - trace_id="t1", - span_id="s2", - input_messages=[ - _make_user_message("What is 6*7?"), - _make_tool_result_message("tu-1", "42"), - ], - output_messages=[_make_assistant_text_message("The answer is 42.")], - time_nano=2000, - ) - session = provider._build_session("sess-1", [record1, record2]) - assert len(session.traces) == 1 - spans = session.traces[0].spans - span_types = [type(s).__name__ for s in spans] - assert "InferenceSpan" in span_types - assert "ToolExecutionSpan" in span_types - assert "AgentInvocationSpan" in span_types - - def test_empty_records_list(self, provider): - session = provider._build_session("sess-1", []) - assert session.session_id == "sess-1" - assert session.traces == [] +class TestExtractOutput: def test_extract_output_from_agent_response(self, provider): """_extract_output returns last agent response text.""" records = [ @@ -408,25 +194,10 @@ def test_extract_output_from_agent_response(self, provider): time_nano=2000, ), ] - session = provider._build_session("sess-1", records) + session = provider._mapper.map_to_session(records, "sess-1") output = provider._extract_output(session) assert output == "Final response" - def test_record_with_no_body_skipped(self, provider): - """Malformed records without body don't crash.""" - records = [ - {"traceId": "t1", "spanId": "s1", "timeUnixNano": 1000}, - _make_log_record( - trace_id="t1", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Hello!")], - time_nano=2000, - ), - ] - session = provider._build_session("sess-1", records) - assert len(session.traces) == 1 - assert len(session.traces[0].spans) > 0 - # --- CW Logs Insights polling --- @@ -613,122 +384,3 @@ def test_output_from_last_agent_invocation(self, provider, mock_logs_client): ] _setup_query_results(mock_logs_client, records) assert provider.get_evaluation_data("sess-1")["output"] == "last" - - -# --- get_evaluation_data_by_trace_id --- - - 
-class TestGetEvaluationDataByTraceId: - def test_happy_path(self, provider, mock_logs_client): - records = [ - _make_log_record( - trace_id="t1", - span_id="s1", - input_messages=[_make_user_message("What is 6*7?")], - output_messages=[_make_assistant_text_message("The answer is 42.")], - ) - ] - _setup_query_results(mock_logs_client, records) - result = provider.get_evaluation_data_by_trace_id("t1") - assert isinstance(result["trajectory"], Session) - assert result["trajectory"].traces[0].trace_id == "t1" - assert result["output"] == "The answer is 42." - - def test_not_found_raises(self, provider, mock_logs_client): - _setup_query_results(mock_logs_client, []) - with pytest.raises(TraceNotFoundError, match="t-missing"): - provider.get_evaluation_data_by_trace_id("t-missing") - - def test_query_failure_raises(self, provider, mock_logs_client): - mock_logs_client.start_query.side_effect = Exception("throttled") - with pytest.raises(ProviderError, match="throttled"): - provider.get_evaluation_data_by_trace_id("t1") - - def test_query_uses_trace_id_filter(self, provider, mock_logs_client): - records = [ - _make_log_record( - trace_id="t-abc", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Hello")], - ) - ] - _setup_query_results(mock_logs_client, records) - provider.get_evaluation_data_by_trace_id("t-abc") - query_string = mock_logs_client.start_query.call_args[1]["queryString"] - assert "t-abc" in query_string - - def test_session_id_from_record_attributes(self, provider, mock_logs_client): - """Session ID is taken from record attributes when available.""" - records = [ - _make_log_record( - trace_id="t1", - session_id="sess-from-record", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Hello")], - ) - ] - _setup_query_results(mock_logs_client, records) - result = provider.get_evaluation_data_by_trace_id("t1") - assert result["trajectory"].session_id == "sess-from-record" - 
- -# --- list_sessions --- - - -def _setup_session_query(mock_logs_client, session_ids): - """Wire up mock to return session IDs from a stats aggregation query.""" - mock_logs_client.start_query.return_value = {"queryId": "q-1"} - results = [] - for sid in session_ids: - results.append( - [ - {"field": "sessionId", "value": sid}, - {"field": "span_count", "value": "5"}, - ] - ) - mock_logs_client.get_query_results.return_value = { - "status": "Complete", - "results": results, - } - - -class TestListSessions: - def test_returns_session_ids(self, provider, mock_logs_client): - _setup_session_query(mock_logs_client, ["s1", "s2", "s3"]) - assert list(provider.list_sessions()) == ["s1", "s2", "s3"] - - def test_empty_results(self, provider, mock_logs_client): - _setup_session_query(mock_logs_client, []) - assert list(provider.list_sessions()) == [] - - def test_time_filter_applied(self, provider, mock_logs_client): - from datetime import datetime, timezone - - from strands_evals.providers.trace_provider import SessionFilter - - start = datetime(2025, 1, 1, tzinfo=timezone.utc) - end = datetime(2025, 1, 31, tzinfo=timezone.utc) - _setup_session_query(mock_logs_client, ["s1"]) - list(provider.list_sessions(session_filter=SessionFilter(start_time=start, end_time=end))) - kw = mock_logs_client.start_query.call_args[1] - assert kw["startTime"] == int(start.timestamp()) - assert kw["endTime"] == int(end.timestamp()) - - def test_limit_applied(self, provider, mock_logs_client): - from strands_evals.providers.trace_provider import SessionFilter - - _setup_session_query(mock_logs_client, ["s1"]) - list(provider.list_sessions(session_filter=SessionFilter(limit=50))) - query_string = mock_logs_client.start_query.call_args[1]["queryString"] - assert "limit 50" in query_string - - def test_default_limit(self, provider, mock_logs_client): - _setup_session_query(mock_logs_client, ["s1"]) - list(provider.list_sessions()) - query_string = 
mock_logs_client.start_query.call_args[1]["queryString"] - assert "limit 1000" in query_string - - def test_query_failure_raises(self, provider, mock_logs_client): - mock_logs_client.start_query.side_effect = Exception("access denied") - with pytest.raises(ProviderError, match="access denied"): - list(provider.list_sessions()) diff --git a/tests/strands_evals/providers/test_trace_provider.py b/tests/strands_evals/providers/test_trace_provider.py index f89a36f..fc5ee46 100644 --- a/tests/strands_evals/providers/test_trace_provider.py +++ b/tests/strands_evals/providers/test_trace_provider.py @@ -5,7 +5,6 @@ from strands_evals.providers.exceptions import ( ProviderError, SessionNotFoundError, - TraceNotFoundError, TraceProviderError, ) from strands_evals.providers.trace_provider import ( @@ -37,9 +36,6 @@ def test_trace_provider_error_is_exception(self): def test_session_not_found_is_trace_provider_error(self): assert issubclass(SessionNotFoundError, TraceProviderError) - def test_trace_not_found_is_trace_provider_error(self): - assert issubclass(TraceNotFoundError, TraceProviderError) - def test_provider_error_is_trace_provider_error(self): assert issubclass(ProviderError, TraceProviderError) @@ -49,7 +45,7 @@ def test_exceptions_carry_message(self): def test_catching_base_catches_all(self): """All provider exceptions can be caught with TraceProviderError.""" - for exc_class in (SessionNotFoundError, TraceNotFoundError, ProviderError): + for exc_class in (SessionNotFoundError, ProviderError): with pytest.raises(TraceProviderError): raise exc_class("test") diff --git a/tests_integ/test_cloudwatch_provider.py b/tests_integ/test_cloudwatch_provider.py index f3e54cb..e6b7707 100644 --- a/tests_integ/test_cloudwatch_provider.py +++ b/tests_integ/test_cloudwatch_provider.py @@ -18,7 +18,6 @@ from strands_evals.providers.exceptions import ( ProviderError, SessionNotFoundError, - TraceNotFoundError, ) from strands_evals.types.trace import ( AgentInvocationSpan, @@ -89,18 
+88,6 @@ def evaluation_data(provider, session_id): return provider.get_evaluation_data(session_id) -class TestListSessions: - def test_returns_at_least_one_session(self, provider): - sessions = list(provider.list_sessions()) - assert len(sessions) > 0, "Expected at least one session in CloudWatch" - - def test_session_ids_are_strings(self, provider): - for session_id in provider.list_sessions(): - assert isinstance(session_id, str) - assert len(session_id) > 0 - break # Only check the first one - - class TestGetEvaluationData: def test_returns_session_with_traces(self, evaluation_data, session_id): session = evaluation_data["trajectory"] @@ -169,23 +156,6 @@ def test_nonexistent_session_raises(self, provider): provider.get_evaluation_data("nonexistent-session-id-that-does-not-exist-12345") -class TestGetEvaluationDataByTraceId: - def test_fetches_by_trace_id(self, provider, evaluation_data): - """Use a trace_id from the discovered session to test trace-level retrieval.""" - session = evaluation_data["trajectory"] - trace_id = session.traces[0].trace_id - - result = provider.get_evaluation_data_by_trace_id(trace_id) - - assert isinstance(result["trajectory"], Session) - assert len(result["trajectory"].traces) > 0 - assert result["trajectory"].traces[0].trace_id == trace_id - - def test_nonexistent_trace_raises(self, provider): - with pytest.raises((TraceNotFoundError, ProviderError)): - provider.get_evaluation_data_by_trace_id("nonexistent-trace-id-12345") - - # --- End-to-end: CloudWatch → Evaluator pipeline --- From 35c9e4e2666d188f92c4a7940f70fdec595f019d Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Thu, 26 Feb 2026 09:38:05 -0500 Subject: [PATCH 7/8] feat: add CloudWatchSessionMapper for CW Logs Insights records Add a new session mapper that converts CloudWatch Logs Insights OTEL log records into typed Session objects for evaluation. 
The mapper parses body.input/output messages, extracts tool calls/results, and builds InferenceSpan, ToolExecutionSpan, and AgentInvocationSpan instances grouped by traceId. Includes comprehensive unit tests. --- tests_integ/test_cloudwatch_provider.py | 68 ++++++++----------------- 1 file changed, 22 insertions(+), 46 deletions(-) diff --git a/tests_integ/test_cloudwatch_provider.py b/tests_integ/test_cloudwatch_provider.py index e6b7707..098e754 100644 --- a/tests_integ/test_cloudwatch_provider.py +++ b/tests_integ/test_cloudwatch_provider.py @@ -1,11 +1,17 @@ """Integration tests for CloudWatchProvider against real CloudWatch Logs data. -Requires AWS credentials with CloudWatch Logs read access. Uses the -'strands_github_bot' AWS profile if available, otherwise falls back to default credentials. +Requires the following environment variables: + CLOUDWATCH_TEST_LOG_GROUP — CW Logs group containing agent traces + CLOUDWATCH_TEST_SESSION_ID — session ID with convertible spans + AWS_REGION (optional) — defaults to us-east-1 + +AWS credentials must be configured via standard mechanisms (env vars, profile, instance role). 
+ Run with: pytest tests_integ/test_cloudwatch_provider.py -v """ -import boto3 +import os + import pytest from strands_evals import Case, Experiment @@ -27,59 +33,29 @@ Trace, ) -LOG_GROUP = "/aws/bedrock-agentcore/runtimes/github_issue_handler-zf6fZR2saQ-DEFAULT" - -KNOWN_SESSION_IDS = [ - "github_issue_68_20260218_172558_d27beb07", - "github-issue-68-1771435544179-d5gc8kk4xan", -] - -EXPECTED_ACCOUNT_ID = "249746592913" -AWS_PROFILE = "strands_github_bot" -AWS_REGION = "us-east-1" - - -def _create_logs_client() -> boto3.client | None: - """Try the named profile first, then fall back to default credentials.""" - for profile in [AWS_PROFILE, None]: - try: - session = boto3.Session(profile_name=profile, region_name=AWS_REGION) - sts = session.client("sts") - identity = sts.get_caller_identity() - if identity["Account"] == EXPECTED_ACCOUNT_ID: - return session.client("logs") - except Exception: - continue - return None - @pytest.fixture(scope="module") def provider(): - """Create a CloudWatchProvider targeting the correct AWS account.""" - client = _create_logs_client() - if client is None: - pytest.skip(f"No AWS credentials found for account {EXPECTED_ACCOUNT_ID}") + """Create a CloudWatchProvider using env var configuration.""" + log_group = os.environ.get("CLOUDWATCH_TEST_LOG_GROUP") + if not log_group: + pytest.skip("CLOUDWATCH_TEST_LOG_GROUP not set") + + region = os.environ.get("AWS_REGION", "us-east-1") try: - cw = CloudWatchProvider(log_group=LOG_GROUP, region=AWS_REGION) + return CloudWatchProvider(log_group=log_group, region=region) except ProviderError as e: pytest.skip(f"CloudWatch provider creation failed: {e}") - # Inject the verified client - cw._client = client - return cw - @pytest.fixture(scope="module") -def session_id(provider): - """Try known session IDs; skip if none found.""" - for sid in KNOWN_SESSION_IDS: - try: - provider.get_evaluation_data(sid) - return sid - except (SessionNotFoundError, ProviderError): - continue - pytest.skip("No known 
sessions found in CloudWatch") +def session_id(): + """Get a test session ID from environment variable.""" + sid = os.environ.get("CLOUDWATCH_TEST_SESSION_ID") + if not sid: + pytest.skip("CLOUDWATCH_TEST_SESSION_ID not set") + return sid @pytest.fixture(scope="module") From a6205aea81f3b2a65dd7fe87181fc0541ecce5b9 Mon Sep 17 00:00:00 2001 From: Aaron Farntrog Date: Thu, 26 Feb 2026 11:46:46 -0500 Subject: [PATCH 8/8] feat: add CloudWatchProvider to public API and improve documentation - Export CloudWatchProvider via lazy-loading in providers __init__.py - Add comprehensive docstring to CloudWatchProvider.__init__ with usage examples and parameter descriptions - Extract shared CloudWatch test helpers into a reusable cloudwatch_helpers module to reduce duplication across test files --- src/strands_evals/providers/__init__.py | 5 + .../providers/cloudwatch_provider.py | 31 +++++ tests/strands_evals/__init__.py | 0 tests/strands_evals/cloudwatch_helpers.py | 38 ++++++ .../mappers/test_cloudwatch_session_mapper.py | 116 +++++++----------- .../providers/test_cloudwatch_provider.py | 94 +++++--------- tests_integ/conftest.py | 13 ++ tests_integ/test_cloudwatch_provider.py | 6 - tests_integ/test_langfuse_provider.py | 24 ++-- 9 files changed, 167 insertions(+), 160 deletions(-) create mode 100644 tests/strands_evals/__init__.py create mode 100644 tests/strands_evals/cloudwatch_helpers.py create mode 100644 tests_integ/conftest.py diff --git a/src/strands_evals/providers/__init__.py b/src/strands_evals/providers/__init__.py index 1b81ee3..3fcae61 100644 --- a/src/strands_evals/providers/__init__.py +++ b/src/strands_evals/providers/__init__.py @@ -10,6 +10,7 @@ ) __all__ = [ + "CloudWatchProvider", "LangfuseProvider", "ProviderError", "SessionNotFoundError", @@ -20,6 +21,10 @@ def __getattr__(name: str) -> Any: """Lazy-load providers that depend on optional packages.""" + if name == "CloudWatchProvider": + from .cloudwatch_provider import CloudWatchProvider + + return 
CloudWatchProvider if name == "LangfuseProvider": from .langfuse_provider import LangfuseProvider diff --git a/src/strands_evals/providers/cloudwatch_provider.py b/src/strands_evals/providers/cloudwatch_provider.py index 375ef11..3b9d037 100644 --- a/src/strands_evals/providers/cloudwatch_provider.py +++ b/src/strands_evals/providers/cloudwatch_provider.py @@ -37,6 +37,37 @@ def __init__( lookback_days: int = 30, query_timeout_seconds: float = 60.0, ): + """Initialize the CloudWatch provider. + + The log group can be specified directly or discovered automatically + from an agent name. Region falls back to AWS_REGION, then + AWS_DEFAULT_REGION, then `us-east-1`. + + Example:: + + from strands_evals.providers import CloudWatchProvider + + # Explicit log group + provider = CloudWatchProvider( + log_group="/aws/bedrock-agentcore/runtimes/my-agent-abc123-DEFAULT", + ) + + # Discover log group from agent name + provider = CloudWatchProvider(agent_name="my-agent") + + Args: + region: AWS region. Falls back to AWS_REGION / AWS_DEFAULT_REGION env vars. + log_group: Full CloudWatch log group path. + agent_name: Agent name used to discover the runtime log group via + `describe_log_groups`. Exactly one of `log_group` or + `agent_name` must be provided. + lookback_days: How many days back to search for traces. + query_timeout_seconds: Maximum seconds to wait for a Logs Insights query. + + Raises: + ProviderError: If neither `log_group` nor `agent_name` is provided, + or if the boto3 client cannot be created. 
+ """ resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") try: diff --git a/tests/strands_evals/__init__.py b/tests/strands_evals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strands_evals/cloudwatch_helpers.py b/tests/strands_evals/cloudwatch_helpers.py new file mode 100644 index 0000000..4b91abe --- /dev/null +++ b/tests/strands_evals/cloudwatch_helpers.py @@ -0,0 +1,38 @@ +"""Shared test helpers for building CloudWatch body-format OTEL log records.""" + +import json + + +def make_log_record( + trace_id="abc123", + span_id="span-1", + input_messages=None, + output_messages=None, + session_id="sess-1", + time_nano=1000000000000000000, +): + """Build a body-format OTEL log record dict as found in runtime log groups.""" + record = { + "traceId": trace_id, + "spanId": span_id, + "timeUnixNano": time_nano, + "body": { + "input": {"messages": input_messages or []}, + "output": {"messages": output_messages or []}, + }, + "attributes": {"session.id": session_id}, + } + return record + + +def make_user_message(text): + """Build a user input message with double-encoded content.""" + return {"role": "user", "content": {"content": json.dumps([{"text": text}])}} + + +def make_assistant_text_message(text): + """Build an assistant output message with double-encoded text content.""" + return { + "role": "assistant", + "content": {"message": json.dumps([{"text": text}]), "finish_reason": "end_turn"}, + } diff --git a/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py b/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py index 685ca43..90518fc 100644 --- a/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py +++ b/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py @@ -8,45 +8,11 @@ InferenceSpan, ToolExecutionSpan, ) +from tests.strands_evals.cloudwatch_helpers import make_assistant_text_message, make_log_record, make_user_message # --- Helpers for 
body-format log records --- -def _make_log_record( - trace_id="abc123", - span_id="span-1", - input_messages=None, - output_messages=None, - session_id="sess-1", - time_nano=1000000000000000000, -): - """Build a body-format OTEL log record dict as found in runtime log groups.""" - record = { - "traceId": trace_id, - "spanId": span_id, - "timeUnixNano": time_nano, - "body": { - "input": {"messages": input_messages or []}, - "output": {"messages": output_messages or []}, - }, - "attributes": {"session.id": session_id}, - } - return record - - -def _make_user_message(text): - """Build a user input message with double-encoded content.""" - return {"role": "user", "content": {"content": json.dumps([{"text": text}])}} - - -def _make_assistant_text_message(text): - """Build an assistant output message with double-encoded text content.""" - return { - "role": "assistant", - "content": {"message": json.dumps([{"text": text}]), "finish_reason": "end_turn"}, - } - - def _make_assistant_tool_use_message(tool_name, tool_input, tool_use_id): """Build an assistant output message with a toolUse block.""" return { @@ -77,9 +43,9 @@ def setup_method(self): def test_single_record_produces_inference_span(self): """One log record produces an InferenceSpan with input/output messages.""" - record = _make_log_record( - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Hello!")], + record = make_log_record( + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("Hello!")], ) session = self.mapper.map_to_session([record], "sess-1") spans = session.traces[0].spans @@ -90,21 +56,21 @@ def test_single_record_produces_inference_span(self): def test_record_with_tool_use_and_result(self): """toolUse in output + toolResult in next record's input → ToolExecutionSpan.""" - record1 = _make_log_record( + record1 = make_log_record( trace_id="t1", span_id="s1", - input_messages=[_make_user_message("Calculate 6*7")], + 
input_messages=[make_user_message("Calculate 6*7")], output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")], time_nano=1000, ) - record2 = _make_log_record( + record2 = make_log_record( trace_id="t1", span_id="s2", input_messages=[ - _make_user_message("Calculate 6*7"), + make_user_message("Calculate 6*7"), _make_tool_result_message("tu-1", "42"), ], - output_messages=[_make_assistant_text_message("The answer is 42.")], + output_messages=[make_assistant_text_message("The answer is 42.")], time_nano=2000, ) session = self.mapper.map_to_session([record1, record2], "sess-1") @@ -116,11 +82,11 @@ def test_record_with_tool_use_and_result(self): def test_agent_invocation_from_trace(self): """User prompt from first record, response from last → AgentInvocationSpan.""" - record1 = _make_log_record( + record1 = make_log_record( trace_id="t1", span_id="s1", - input_messages=[_make_user_message("Tell me a joke")], - output_messages=[_make_assistant_text_message("Why did the chicken cross the road?")], + input_messages=[make_user_message("Tell me a joke")], + output_messages=[make_assistant_text_message("Why did the chicken cross the road?")], time_nano=1000, ) session = self.mapper.map_to_session([record1], "sess-1") @@ -131,18 +97,18 @@ def test_agent_invocation_from_trace(self): def test_agent_invocation_extracts_tools(self): """available_tools populated from tool call names in the trace.""" - record1 = _make_log_record( + record1 = make_log_record( trace_id="t1", span_id="s1", - input_messages=[_make_user_message("Search for X")], + input_messages=[make_user_message("Search for X")], output_messages=[_make_assistant_tool_use_message("web_search", {"q": "X"}, "tu-1")], time_nano=1000, ) - record2 = _make_log_record( + record2 = make_log_record( trace_id="t1", span_id="s2", - input_messages=[_make_user_message("Search for X"), _make_tool_result_message("tu-1", "found X")], - output_messages=[_make_assistant_text_message("Here's what I found 
about X.")], + input_messages=[make_user_message("Search for X"), _make_tool_result_message("tu-1", "found X")], + output_messages=[make_assistant_text_message("Here's what I found about X.")], time_nano=2000, ) session = self.mapper.map_to_session([record1, record2], "sess-1") @@ -153,9 +119,9 @@ def test_agent_invocation_extracts_tools(self): def test_double_encoded_content_parsed(self): """Content field is a JSON string that must be parsed to get content blocks.""" - record = _make_log_record( - input_messages=[_make_user_message("test double encoding")], - output_messages=[_make_assistant_text_message("parsed correctly")], + record = make_log_record( + input_messages=[make_user_message("test double encoding")], + output_messages=[make_assistant_text_message("parsed correctly")], ) session = self.mapper.map_to_session([record], "sess-1") inference_spans = [s for s in session.traces[0].spans if isinstance(s, InferenceSpan)] @@ -164,20 +130,20 @@ def test_double_encoded_content_parsed(self): def test_tool_call_matched_to_result_by_id(self): """toolUseId matching works across records in the same trace.""" - record1 = _make_log_record( + record1 = make_log_record( trace_id="t1", span_id="s1", - input_messages=[_make_user_message("Do two things")], + input_messages=[make_user_message("Do two things")], output_messages=[ _make_assistant_tool_use_message("tool_a", {"x": 1}, "tu-a"), ], time_nano=1000, ) - record2 = _make_log_record( + record2 = make_log_record( trace_id="t1", span_id="s2", input_messages=[ - _make_user_message("Do two things"), + make_user_message("Do two things"), _make_tool_result_message("tu-a", "result-a"), ], output_messages=[ @@ -185,15 +151,15 @@ def test_tool_call_matched_to_result_by_id(self): ], time_nano=2000, ) - record3 = _make_log_record( + record3 = make_log_record( trace_id="t1", span_id="s3", input_messages=[ - _make_user_message("Do two things"), + make_user_message("Do two things"), _make_tool_result_message("tu-a", "result-a"), 
_make_tool_result_message("tu-b", "result-b"), ], - output_messages=[_make_assistant_text_message("Both done.")], + output_messages=[make_assistant_text_message("Both done.")], time_nano=3000, ) session = self.mapper.map_to_session([record1, record2, record3], "sess-1") @@ -214,15 +180,15 @@ def setup_method(self): def test_multiple_records_grouped_by_trace_id(self): """Records with different traceIds become separate Trace objects.""" records = [ - _make_log_record( + make_log_record( trace_id="t1", - input_messages=[_make_user_message("q1")], - output_messages=[_make_assistant_text_message("a1")], + input_messages=[make_user_message("q1")], + output_messages=[make_assistant_text_message("a1")], ), - _make_log_record( + make_log_record( trace_id="t2", - input_messages=[_make_user_message("q2")], - output_messages=[_make_assistant_text_message("a2")], + input_messages=[make_user_message("q2")], + output_messages=[make_assistant_text_message("a2")], ), ] session = self.mapper.map_to_session(records, "sess-1") @@ -233,21 +199,21 @@ def test_multiple_records_grouped_by_trace_id(self): def test_multi_step_agent_loop(self): """user→LLM→tool→LLM→response produces InferenceSpan + ToolExecutionSpan + AgentInvocationSpan.""" - record1 = _make_log_record( + record1 = make_log_record( trace_id="t1", span_id="s1", - input_messages=[_make_user_message("What is 6*7?")], + input_messages=[make_user_message("What is 6*7?")], output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")], time_nano=1000, ) - record2 = _make_log_record( + record2 = make_log_record( trace_id="t1", span_id="s2", input_messages=[ - _make_user_message("What is 6*7?"), + make_user_message("What is 6*7?"), _make_tool_result_message("tu-1", "42"), ], - output_messages=[_make_assistant_text_message("The answer is 42.")], + output_messages=[make_assistant_text_message("The answer is 42.")], time_nano=2000, ) session = self.mapper.map_to_session([record1, record2], "sess-1") @@ 
-267,10 +233,10 @@ def test_record_with_no_body_skipped(self): """Malformed records without body don't crash.""" records = [ {"traceId": "t1", "spanId": "s1", "timeUnixNano": 1000}, - _make_log_record( + make_log_record( trace_id="t1", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Hello!")], + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("Hello!")], time_nano=2000, ), ] diff --git a/tests/strands_evals/providers/test_cloudwatch_provider.py b/tests/strands_evals/providers/test_cloudwatch_provider.py index 879fc33..d9ce0e1 100644 --- a/tests/strands_evals/providers/test_cloudwatch_provider.py +++ b/tests/strands_evals/providers/test_cloudwatch_provider.py @@ -12,6 +12,7 @@ SessionNotFoundError, ) from strands_evals.types.trace import Session +from tests.strands_evals.cloudwatch_helpers import make_assistant_text_message, make_log_record, make_user_message # --- Fixtures --- @@ -137,41 +138,6 @@ def _setup_query_results(mock_logs_client, records): } -def _make_log_record( - trace_id="abc123", - span_id="span-1", - input_messages=None, - output_messages=None, - session_id="sess-1", - time_nano=1000000000000000000, -): - """Build a body-format OTEL log record dict as found in runtime log groups.""" - record = { - "traceId": trace_id, - "spanId": span_id, - "timeUnixNano": time_nano, - "body": { - "input": {"messages": input_messages or []}, - "output": {"messages": output_messages or []}, - }, - "attributes": {"session.id": session_id}, - } - return record - - -def _make_user_message(text): - """Build a user input message with double-encoded content.""" - return {"role": "user", "content": {"content": json.dumps([{"text": text}])}} - - -def _make_assistant_text_message(text): - """Build an assistant output message with double-encoded text content.""" - return { - "role": "assistant", - "content": {"message": json.dumps([{"text": text}]), "finish_reason": "end_turn"}, - } - - # 
--- Output extraction --- @@ -179,18 +145,18 @@ class TestExtractOutput: def test_extract_output_from_agent_response(self, provider): """_extract_output returns last agent response text.""" records = [ - _make_log_record( + make_log_record( trace_id="t1", span_id="s1", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("First response")], + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("First response")], time_nano=1000, ), - _make_log_record( + make_log_record( trace_id="t1", span_id="s2", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Final response")], + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("Final response")], time_nano=2000, ), ] @@ -206,11 +172,11 @@ class TestLogsInsightsPolling: def _make_record_json(self, trace_id="t1", span_id="s1"): """Return a JSON-serialized body-format log record for use in CW Logs @message fields.""" return json.dumps( - _make_log_record( + make_log_record( trace_id=trace_id, span_id=span_id, - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Hello!")], + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("Hello!")], ) ) @@ -279,8 +245,8 @@ def test_empty_results(self, provider, mock_logs_client): def test_parses_message_field(self, provider, mock_logs_client): """Each result row's @message field is parsed as JSON into a record dict.""" - record1 = _make_log_record(trace_id="t1", span_id="s1") - record2 = _make_log_record(trace_id="t1", span_id="s2") + record1 = make_log_record(trace_id="t1", span_id="s1") + record2 = make_log_record(trace_id="t1", span_id="s2") mock_logs_client.start_query.return_value = {"queryId": "q-1"} mock_logs_client.get_query_results.return_value = { "status": "Complete", @@ -306,12 +272,12 @@ def test_start_query_failure_raises(self, provider, 
mock_logs_client): class TestGetEvaluationData: def test_happy_path(self, provider, mock_logs_client): records = [ - _make_log_record( + make_log_record( trace_id="t1", span_id="s1", session_id="sess-1", - input_messages=[_make_user_message("What is 6*7?")], - output_messages=[_make_assistant_text_message("The answer is 42.")], + input_messages=[make_user_message("What is 6*7?")], + output_messages=[make_assistant_text_message("The answer is 42.")], ) ] _setup_query_results(mock_logs_client, records) @@ -335,10 +301,10 @@ def test_query_failure_raises_provider_error(self, provider, mock_logs_client): def test_query_uses_session_id_filter(self, provider, mock_logs_client): """Verify the query string uses attributes.session.id filter.""" records = [ - _make_log_record( + make_log_record( session_id="sess-1", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("Hello")], + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("Hello")], ) ] _setup_query_results(mock_logs_client, records) @@ -349,15 +315,15 @@ def test_query_uses_session_id_filter(self, provider, mock_logs_client): def test_multiple_traces(self, provider, mock_logs_client): records = [ - _make_log_record( + make_log_record( trace_id="t1", - input_messages=[_make_user_message("q1")], - output_messages=[_make_assistant_text_message("first")], + input_messages=[make_user_message("q1")], + output_messages=[make_assistant_text_message("first")], ), - _make_log_record( + make_log_record( trace_id="t2", - input_messages=[_make_user_message("q2")], - output_messages=[_make_assistant_text_message("second")], + input_messages=[make_user_message("q2")], + output_messages=[make_assistant_text_message("second")], ), ] _setup_query_results(mock_logs_client, records) @@ -367,18 +333,18 @@ def test_multiple_traces(self, provider, mock_logs_client): def test_output_from_last_agent_invocation(self, provider, mock_logs_client): records = [ - 
_make_log_record( + make_log_record( trace_id="t1", span_id="s1", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("first")], + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("first")], time_nano=1000, ), - _make_log_record( + make_log_record( trace_id="t1", span_id="s2", - input_messages=[_make_user_message("Hi")], - output_messages=[_make_assistant_text_message("last")], + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("last")], time_nano=2000, ), ] diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py new file mode 100644 index 0000000..d0d69f6 --- /dev/null +++ b/tests_integ/conftest.py @@ -0,0 +1,13 @@ +"""Shared fixtures for trace provider integration tests. + +Each provider test module defines its own `provider` and `session_id` fixtures. +This conftest provides common fixtures that build on those. +""" + +import pytest + + +@pytest.fixture(scope="module") +def evaluation_data(provider, session_id): + """Fetch evaluation data for the test session.""" + return provider.get_evaluation_data(session_id) diff --git a/tests_integ/test_cloudwatch_provider.py b/tests_integ/test_cloudwatch_provider.py index 098e754..0ac2208 100644 --- a/tests_integ/test_cloudwatch_provider.py +++ b/tests_integ/test_cloudwatch_provider.py @@ -58,12 +58,6 @@ def session_id(): return sid -@pytest.fixture(scope="module") -def evaluation_data(provider, session_id): - """Fetch evaluation data for the discovered session.""" - return provider.get_evaluation_data(session_id) - - class TestGetEvaluationData: def test_returns_session_with_traces(self, evaluation_data, session_id): session = evaluation_data["trajectory"] diff --git a/tests_integ/test_langfuse_provider.py b/tests_integ/test_langfuse_provider.py index 2593660..8fddb36 100644 --- a/tests_integ/test_langfuse_provider.py +++ b/tests_integ/test_langfuse_provider.py @@ -41,7 +41,7 @@ def provider(): 
@pytest.fixture(scope="module") -def discovered_session_id(): +def session_id(): """Get a test session ID from environment variable.""" session_id = os.environ.get("LANGFUSE_TEST_SESSION_ID") if not session_id: @@ -49,17 +49,11 @@ def discovered_session_id(): return session_id -@pytest.fixture(scope="module") -def evaluation_data(provider, discovered_session_id): - """Fetch evaluation data for the discovered session.""" - return provider.get_evaluation_data(discovered_session_id) - - class TestGetEvaluationData: - def test_returns_session_with_traces(self, evaluation_data, discovered_session_id): + def test_returns_session_with_traces(self, evaluation_data, session_id): session = evaluation_data["trajectory"] assert isinstance(session, Session) - assert session.session_id == discovered_session_id + assert session.session_id == session_id assert len(session.traces) > 0 def test_traces_have_spans(self, evaluation_data): @@ -126,7 +120,7 @@ def test_nonexistent_session_raises(self, provider): class TestEndToEnd: """Fetch traces from Langfuse and run real evaluators on them.""" - def test_output_evaluator_on_remote_trace(self, provider, discovered_session_id): + def test_output_evaluator_on_remote_trace(self, provider, session_id): """OutputEvaluator produces a valid score from a Langfuse session.""" def task(case: Case) -> dict: @@ -135,7 +129,7 @@ def task(case: Case) -> dict: cases = [ Case( name="langfuse_session", - input=discovered_session_id, + input=session_id, expected_output="any agent response", ), ] @@ -154,7 +148,7 @@ def task(case: Case) -> dict: assert 0.0 <= report.score <= 1.0 assert len(report.case_results) == 1 - def test_coherence_evaluator_on_remote_trace(self, provider, discovered_session_id): + def test_coherence_evaluator_on_remote_trace(self, provider, session_id): """CoherenceEvaluator produces a valid score from a Langfuse session.""" def task(case: Case) -> dict: @@ -163,7 +157,7 @@ def task(case: Case) -> dict: cases = [ Case( 
name="langfuse_session", - input=discovered_session_id, + input=session_id, expected_output="any agent response", ), ] @@ -178,7 +172,7 @@ def task(case: Case) -> dict: assert report.score is not None assert 0.0 <= report.score <= 1.0 - def test_multiple_evaluators_on_remote_trace(self, provider, discovered_session_id): + def test_multiple_evaluators_on_remote_trace(self, provider, session_id): """Multiple evaluators can all run on the same Langfuse session data.""" def task(case: Case) -> dict: @@ -187,7 +181,7 @@ def task(case: Case) -> dict: cases = [ Case( name="langfuse_session", - input=discovered_session_id, + input=session_id, expected_output="any agent response", ), ]