diff --git a/src/strands_evals/mappers/__init__.py b/src/strands_evals/mappers/__init__.py index 8a27078..5325de4 100644 --- a/src/strands_evals/mappers/__init__.py +++ b/src/strands_evals/mappers/__init__.py @@ -1,6 +1,7 @@ """Converters for transforming telemetry data to Session format.""" +from .cloudwatch_session_mapper import CloudWatchSessionMapper from .session_mapper import SessionMapper from .strands_in_memory_session_mapper import GenAIConventionVersion, StrandsInMemorySessionMapper -__all__ = ["GenAIConventionVersion", "SessionMapper", "StrandsInMemorySessionMapper"] +__all__ = ["CloudWatchSessionMapper", "GenAIConventionVersion", "SessionMapper", "StrandsInMemorySessionMapper"] diff --git a/src/strands_evals/mappers/cloudwatch_session_mapper.py b/src/strands_evals/mappers/cloudwatch_session_mapper.py new file mode 100644 index 0000000..4cdb57b --- /dev/null +++ b/src/strands_evals/mappers/cloudwatch_session_mapper.py @@ -0,0 +1,354 @@ +"""CloudWatch session mapper — converts CW Logs Insights records to Session format.""" + +import json +import logging +from collections import defaultdict +from datetime import datetime, timezone +from typing import Any + +from ..mappers.session_mapper import SessionMapper +from ..types.trace import ( + AgentInvocationSpan, + AssistantMessage, + InferenceSpan, + Session, + SpanInfo, + TextContent, + ToolCall, + ToolCallContent, + ToolConfig, + ToolExecutionSpan, + ToolResult, + ToolResultContent, + Trace, + UserMessage, +) + +logger = logging.getLogger(__name__) + + +class CloudWatchSessionMapper(SessionMapper): + """Maps CloudWatch Logs Insights records to Session format. + + Parses body.input/output messages from OTEL log records emitted by + Strands agent runtimes and builds typed Session objects for evaluation. 
+ """ + + def map_to_session(self, records: list[dict[str, Any]], session_id: str) -> Session: + """Group log records by traceId, convert each group to a Trace, return a Session.""" + traces_by_id: dict[str, list[dict[str, Any]]] = defaultdict(list) + for record in records: + trace_id = record.get("traceId", "") + if not trace_id: + continue + traces_by_id[trace_id].append(record) + + traces: list[Trace] = [] + for trace_id, trace_records in traces_by_id.items(): + trace = self._convert_trace(trace_id, trace_records, session_id) + if trace.spans: + traces.append(trace) + + return Session(session_id=session_id, traces=traces) + + def _convert_trace(self, trace_id: str, records: list[dict[str, Any]], session_id: str) -> Trace: + """Convert a group of log records (same traceId) into a Trace with typed spans.""" + sorted_records = sorted(records, key=lambda r: r.get("timeUnixNano", 0)) + + spans: list[InferenceSpan | ToolExecutionSpan | AgentInvocationSpan] = [] + + # Collect all tool calls and results across records + all_tool_calls: dict[str, ToolCall] = {} + all_tool_results: dict[str, ToolResult] = {} + + for record in sorted_records: + if not isinstance(record.get("body"), dict): + continue + + for tc in self._extract_tool_calls(record): + if tc.tool_call_id: + all_tool_calls[tc.tool_call_id] = tc + + for tr in self._extract_tool_results(record): + if tr.tool_call_id: + all_tool_results[tr.tool_call_id] = tr + + # Create InferenceSpans (one per record with parseable body) + for record in sorted_records: + if not isinstance(record.get("body"), dict): + continue + + try: + messages = self._record_to_messages(record) + if messages: + span_info = self._create_span_info(record, session_id) + spans.append(InferenceSpan(span_info=span_info, messages=messages, metadata={})) + except Exception as e: + logger.warning("Failed to create inference span from record %s: %s", record.get("spanId"), e) + + # Create ToolExecutionSpans by matching calls to results + seen_tool_ids: 
set[str] = set() + for record in sorted_records: + for tc in self._extract_tool_calls(record): + if tc.tool_call_id and tc.tool_call_id not in seen_tool_ids: + seen_tool_ids.add(tc.tool_call_id) + tr = all_tool_results.get(tc.tool_call_id, ToolResult(content="", tool_call_id=tc.tool_call_id)) + span_info = self._create_span_info(record, session_id) + spans.append(ToolExecutionSpan(span_info=span_info, tool_call=tc, tool_result=tr, metadata={})) + + # Create AgentInvocationSpan from first user prompt + last agent response + agent_span = self._create_agent_invocation_span(sorted_records, all_tool_calls, session_id) + if agent_span: + spans.append(agent_span) + + return Trace(spans=spans, trace_id=trace_id, session_id=session_id) + + def _create_agent_invocation_span( + self, records: list[dict[str, Any]], tool_calls: dict[str, ToolCall], session_id: str + ) -> AgentInvocationSpan | None: + """Create an AgentInvocationSpan from the first user prompt and last agent response.""" + user_prompt = None + for record in records: + prompt = self._extract_user_prompt(record) + if prompt: + user_prompt = prompt + break + + if not user_prompt: + return None + + agent_response = None + best_record = None + for record in reversed(records): + response = self._extract_agent_response(record) + if response: + agent_response = response + best_record = record + break + + if not agent_response or not best_record: + return None + + available_tools = [ToolConfig(name=name) for name in sorted({tc.name for tc in tool_calls.values()})] + span_info = self._create_span_info(best_record, session_id) + + return AgentInvocationSpan( + span_info=span_info, + user_prompt=user_prompt, + agent_response=agent_response, + available_tools=available_tools, + metadata={}, + ) + + # --- Span info --- + + def _create_span_info(self, record: dict[str, Any], session_id: str) -> SpanInfo: + time_nano = record.get("timeUnixNano", 0) + ts = datetime.fromtimestamp(time_nano / 1e9, tz=timezone.utc) + + return 
SpanInfo( + trace_id=record.get("traceId", ""), + span_id=record.get("spanId", ""), + session_id=session_id, + parent_span_id=record.get("parentSpanId") or None, + start_time=ts, + end_time=ts, + ) + + # --- Body-based content extraction --- + + def _parse_message_content(self, raw: str) -> list[dict[str, Any]] | None: + """Parse double-encoded message content into a list of content blocks.""" + if not isinstance(raw, str): + return None + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, list) else None + except (json.JSONDecodeError, TypeError): + return None + + def _extract_content_field(self, content: dict[str, Any]) -> str | None: + """Extract the raw content field from a message.""" + if not isinstance(content, dict): + return None + return content.get("content") or content.get("message") + + def _extract_text_from_content(self, content: Any) -> str | None: + """Extract text from a content field, handling double-encoded JSON strings.""" + raw = self._extract_content_field(content) + if not raw: + return None + + parsed = self._parse_message_content(raw) + if parsed: + texts = [item["text"] for item in parsed if isinstance(item, dict) and "text" in item] + return " ".join(texts) if texts else None + + return raw if isinstance(raw, str) else None + + def _extract_message_text(self, record: dict[str, Any], message_type: str, role: str) -> str | None: + """Extract text from a specific message type and role in a log record.""" + body = record.get("body", {}) + if not isinstance(body, dict): + return None + + messages = body.get(message_type, {}).get("messages", []) + for msg in messages: + if msg.get("role") == role: + text = self._extract_text_from_content(msg.get("content", {})) + if text: + return text + return None + + def _extract_user_prompt(self, record: dict[str, Any]) -> str | None: + """Extract user prompt text from a log record's body.input.messages.""" + return self._extract_message_text(record, "input", "user") + + def 
_extract_agent_response(self, record: dict[str, Any]) -> str | None: + """Extract assistant text response from a log record's body.output.messages.""" + return self._extract_message_text(record, "output", "assistant") + + def _extract_tool_calls(self, record: dict[str, Any]) -> list[ToolCall]: + """Extract tool calls from a log record's body.output.messages.""" + tool_calls: list[ToolCall] = [] + body = record.get("body", {}) + if not isinstance(body, dict): + return tool_calls + + for msg in body.get("output", {}).get("messages", []): + if msg.get("role") != "assistant": + continue + + raw = self._extract_content_field(msg.get("content", {})) + parsed = self._parse_message_content(raw) if raw else None + if not parsed: + continue + + for item in parsed: + if isinstance(item, dict) and "toolUse" in item: + tu = item["toolUse"] + tool_calls.append( + ToolCall( + name=tu.get("name", ""), + arguments=tu.get("input", {}), + tool_call_id=tu.get("toolUseId"), + ) + ) + + return tool_calls + + def _extract_tool_results(self, record: dict[str, Any]) -> list[ToolResult]: + """Extract tool results from a log record's body.input.messages.""" + tool_results: list[ToolResult] = [] + body = record.get("body", {}) + if not isinstance(body, dict): + return tool_results + + for msg in body.get("input", {}).get("messages", []): + raw = self._extract_content_field(msg.get("content", {})) + parsed = self._parse_message_content(raw) if raw else None + if not parsed: + continue + + for item in parsed: + if isinstance(item, dict) and "toolResult" in item: + tr_data = item["toolResult"] + result_text = self._extract_tool_result_text(tr_data.get("content")) + tool_results.append( + ToolResult( + content=result_text, + error=tr_data.get("error"), + tool_call_id=tr_data.get("toolUseId"), + ) + ) + + return tool_results + + def _extract_tool_result_text(self, content: Any) -> str: + """Extract text from tool result content.""" + if not content: + return "" + if isinstance(content, list) and 
content: + return content[0].get("text", "") + return str(content) + + # --- Record-to-messages conversion --- + + def _record_to_messages(self, record: dict[str, Any]) -> list[UserMessage | AssistantMessage]: + """Convert a log record's body into a list of typed messages for InferenceSpan.""" + messages: list[UserMessage | AssistantMessage] = [] + body = record.get("body", {}) + if not isinstance(body, dict): + return messages + + # Process input messages + for msg in body.get("input", {}).get("messages", []): + role = msg.get("role", "") + raw = self._extract_content_field(msg.get("content", {})) + parsed = self._parse_message_content(raw) if raw else None + if not parsed: + continue + + if role == "user": + user_content = self._process_user_message(parsed) + if user_content: + messages.append(UserMessage(content=user_content)) + elif role == "tool": + tool_content = self._process_tool_results(parsed) + if tool_content: + messages.append(UserMessage(content=tool_content)) + + # Process output messages + for msg in body.get("output", {}).get("messages", []): + if msg.get("role") != "assistant": + continue + + raw = self._extract_content_field(msg.get("content", {})) + parsed = self._parse_message_content(raw) if raw else None + if not parsed: + continue + + assistant_content = self._process_assistant_content(parsed) + if assistant_content: + messages.append(AssistantMessage(content=assistant_content)) + + return messages + + # --- Content parsing helpers (Bedrock Converse format) --- + + @staticmethod + def _process_user_message(content_list: list[dict[str, Any]]) -> list[TextContent | ToolResultContent]: + return [TextContent(text=item["text"]) for item in content_list if "text" in item] + + @staticmethod + def _process_assistant_content(content_list: list[dict[str, Any]]) -> list[TextContent | ToolCallContent]: + result: list[TextContent | ToolCallContent] = [] + for item in content_list: + if "text" in item: + result.append(TextContent(text=item["text"])) + 
elif "toolUse" in item: + tool_use = item["toolUse"] + result.append( + ToolCallContent( + name=tool_use["name"], + arguments=tool_use.get("input", {}), + tool_call_id=tool_use.get("toolUseId"), + ) + ) + return result + + def _process_tool_results(self, content_list: list[dict[str, Any]]) -> list[TextContent | ToolResultContent]: + result: list[TextContent | ToolResultContent] = [] + for item in content_list: + if "toolResult" not in item: + continue + tool_result = item["toolResult"] + result_text = self._extract_tool_result_text(tool_result.get("content")) + result.append( + ToolResultContent( + content=result_text, + error=tool_result.get("error"), + tool_call_id=tool_result.get("toolUseId"), + ) + ) + return result diff --git a/src/strands_evals/providers/README.md b/src/strands_evals/providers/README.md new file mode 100644 index 0000000..17b2b21 --- /dev/null +++ b/src/strands_evals/providers/README.md @@ -0,0 +1,108 @@ +# Trace Providers + +Trace providers fetch agent execution data from observability backends and convert it into the format the evaluation pipeline expects. This lets you run evaluators against traces from production or staging agents without re-running them. 
+ +## Available Providers + +| Provider | Backend | Auth | +|----------|---------|------| +| `CloudWatchProvider` | AWS CloudWatch Logs (Bedrock AgentCore runtime logs) | AWS credentials (boto3) | +| `LangfuseProvider` | Langfuse | API keys | + +## Quick Start + +### CloudWatch + +```python +from strands_evals.providers import CloudWatchProvider + +# Option 1: Provide the log group directly +provider = CloudWatchProvider( + log_group="/aws/bedrock-agentcore/runtimes/my-agent-abc123-DEFAULT", + region="us-east-1", +) + +# Option 2: Discover the log group from the agent name +provider = CloudWatchProvider(agent_name="my-agent", region="us-east-1") +``` + +### Langfuse + +```python +from strands_evals.providers import LangfuseProvider + +# Reads LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY from env by default +provider = LangfuseProvider() + +# Or pass credentials explicitly +provider = LangfuseProvider( + public_key="pk-...", + secret_key="sk-...", + host="https://us.cloud.langfuse.com", +) +``` + +## Core API + +All providers implement the `TraceProvider` interface: + +```python +# Fetch traces for a session, ready for evaluation +data = provider.get_evaluation_data(session_id="my-session-id") +# data["output"] -> str (final agent response) +# data["trajectory"] -> Session (traces and spans) +``` + +## Running Evaluators on Remote Traces + +Pass the provider's data into the standard `Experiment` pipeline: + +```python +from strands_evals import Case, Experiment +from strands_evals.evaluators import CoherenceEvaluator, OutputEvaluator +from strands_evals.providers import CloudWatchProvider + +provider = CloudWatchProvider(log_group="/aws/...", region="us-east-1") + +def task(case: Case) -> dict: + return provider.get_evaluation_data(case.input) + +cases = [Case(name="session_1", input="my-session-id", expected_output="any")] +evaluators = [ + OutputEvaluator(rubric="Score 1.0 if the output is coherent. 
Score 0.0 otherwise."), + CoherenceEvaluator(), +] + +experiment = Experiment(cases=cases, evaluators=evaluators) +reports = experiment.run_evaluations(task) + +for report in reports: + print(f"{report.overall_score:.2f} - {report.reasons}") +``` + +## Error Handling + +```python +from strands_evals.providers import SessionNotFoundError, ProviderError + +try: + data = provider.get_evaluation_data("unknown-session") +except SessionNotFoundError: + print("No traces found for that session") +except ProviderError: + print("Provider unreachable or query failed") +``` + +## Implementing a Custom Provider + +Subclass `TraceProvider` and implement `get_evaluation_data`: + +```python +from strands_evals.providers import TraceProvider + +class MyProvider(TraceProvider): + def get_evaluation_data(self, session_id: str) -> dict: + # Fetch traces from your backend, return: + # {"output": "final response text", "trajectory": Session(...)} + ... +``` diff --git a/src/strands_evals/providers/__init__.py b/src/strands_evals/providers/__init__.py index 40c9c4c..3fcae61 100644 --- a/src/strands_evals/providers/__init__.py +++ b/src/strands_evals/providers/__init__.py @@ -3,7 +3,6 @@ from .exceptions import ( ProviderError, SessionNotFoundError, - TraceNotFoundError, TraceProviderError, ) from .trace_provider import ( @@ -11,10 +10,10 @@ ) __all__ = [ + "CloudWatchProvider", "LangfuseProvider", "ProviderError", "SessionNotFoundError", - "TraceNotFoundError", "TraceProvider", "TraceProviderError", ] @@ -22,6 +21,10 @@ def __getattr__(name: str) -> Any: """Lazy-load providers that depend on optional packages.""" + if name == "CloudWatchProvider": + from .cloudwatch_provider import CloudWatchProvider + + return CloudWatchProvider if name == "LangfuseProvider": from .langfuse_provider import LangfuseProvider diff --git a/src/strands_evals/providers/cloudwatch_provider.py b/src/strands_evals/providers/cloudwatch_provider.py new file mode 100644 index 0000000..3b9d037 --- /dev/null +++ 
b/src/strands_evals/providers/cloudwatch_provider.py @@ -0,0 +1,186 @@ +"""CloudWatch trace provider for retrieving agent traces from AWS CloudWatch Logs.""" + +import json +import logging +import os +import time +from datetime import datetime, timedelta, timezone +from typing import Any + +import boto3 + +from ..mappers.cloudwatch_session_mapper import CloudWatchSessionMapper +from ..providers.exceptions import ProviderError, SessionNotFoundError +from ..providers.trace_provider import TraceProvider +from ..types.evaluation import TaskOutput +from ..types.trace import ( + AgentInvocationSpan, + Session, +) + +logger = logging.getLogger(__name__) + + +class CloudWatchProvider(TraceProvider): + """Retrieves agent trace data from AWS CloudWatch Logs for evaluation. + + Queries CloudWatch Logs Insights to fetch OTEL log records from an + agent-specific runtime log group, parses body.input/output messages, + and returns Session objects ready for the evaluation pipeline. + """ + + def __init__( + self, + region: str | None = None, + log_group: str | None = None, + agent_name: str | None = None, + lookback_days: int = 30, + query_timeout_seconds: float = 60.0, + ): + """Initialize the CloudWatch provider. + + The log group can be specified directly or discovered automatically + from an agent name. Region falls back to AWS_REGION, then + AWS_DEFAULT_REGION, then `us-east-1`. + + Example:: + + from strands_evals.providers import CloudWatchProvider + + # Explicit log group + provider = CloudWatchProvider( + log_group="/aws/bedrock-agentcore/runtimes/my-agent-abc123-DEFAULT", + ) + + # Discover log group from agent name + provider = CloudWatchProvider(agent_name="my-agent") + + Args: + region: AWS region. Falls back to AWS_REGION / AWS_DEFAULT_REGION env vars. + log_group: Full CloudWatch log group path. + agent_name: Agent name used to discover the runtime log group via + `describe_log_groups`. Exactly one of `log_group` or + `agent_name` must be provided. 
+ lookback_days: How many days back to search for traces. + query_timeout_seconds: Maximum seconds to wait for a Logs Insights query. + + Raises: + ProviderError: If neither `log_group` nor `agent_name` is provided, + or if the boto3 client cannot be created. + """ + resolved_region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION", "us-east-1") + + try: + self._client = boto3.client("logs", region_name=resolved_region) + except Exception as e: + raise ProviderError(f"CloudWatch: failed to create boto3 logs client: {e}") from e + + if log_group: + self._log_group = log_group + elif agent_name: + self._log_group = self._discover_log_group(agent_name) + else: + raise ProviderError("CloudWatch: either log_group or agent_name must be provided") + + self._lookback_days = lookback_days + self._query_timeout_seconds = query_timeout_seconds + self._mapper = CloudWatchSessionMapper() + + def _discover_log_group(self, agent_name: str) -> str: + """Discover the runtime log group for an agent via describe_log_groups.""" + prefix = f"/aws/bedrock-agentcore/runtimes/{agent_name}" + response = self._client.describe_log_groups(logGroupNamePrefix=prefix) + log_groups = response.get("logGroups", []) + if not log_groups: + raise ProviderError(f"CloudWatch: no log group found for agent_name='{agent_name}' (prefix={prefix})") + return log_groups[0]["logGroupName"] + + def get_evaluation_data(self, session_id: str) -> TaskOutput: + """Fetch all traces for a session and return evaluation data.""" + query = f"fields @message | filter attributes.session.id = '{session_id}' | sort @timestamp asc | limit 10000" + + try: + span_dicts = self._run_logs_insights_query(query) + except ProviderError: + raise + except Exception as e: + raise ProviderError(f"CloudWatch: failed to query spans for session '{session_id}': {e}") from e + + if not span_dicts: + raise SessionNotFoundError(f"CloudWatch: no spans found for session_id='{session_id}'") + + session = 
self._mapper.map_to_session(span_dicts, session_id) + + if not session.traces: + raise SessionNotFoundError( + f"CloudWatch: spans found for session_id='{session_id}' but none contained convertible spans" + ) + + output = self._extract_output(session) + return TaskOutput(output=output, trajectory=session) + + # --- Internal: CW Logs Insights query execution --- + + def _run_logs_insights_query(self, query: str) -> list[dict[str, Any]]: + """Execute a CW Logs Insights query and return parsed span dicts from @message fields.""" + now = datetime.now(tz=timezone.utc) + end_time = now + start_time = now - timedelta(days=self._lookback_days) + + try: + response = self._client.start_query( + logGroupName=self._log_group, + startTime=int(start_time.timestamp()), + endTime=int(end_time.timestamp()), + queryString=query, + ) + except Exception as e: + raise ProviderError(f"CloudWatch: failed to start query: {e}") from e + + query_id = response["queryId"] + raw_results = self._poll_query_results(query_id) + return self._parse_query_results(raw_results) + + def _poll_query_results(self, query_id: str) -> list[list[dict[str, str]]]: + """Poll for query completion with exponential backoff. 
Returns raw result rows.""" + delay = 0.5 + max_delay = 8.0 + deadline = time.monotonic() + self._query_timeout_seconds + + while True: + response = self._client.get_query_results(queryId=query_id) + status = response.get("status", "") + + if status == "Complete": + return response.get("results", []) + elif status in ("Failed", "Cancelled", "Timeout"): + raise ProviderError(f"CloudWatch: query {status}") + + if time.monotonic() >= deadline: + raise ProviderError(f"CloudWatch: query timed out after {self._query_timeout_seconds}s") + + time.sleep(delay) + delay = min(delay * 2, max_delay) + + @staticmethod + def _parse_query_results(results: list[list[dict[str, str]]]) -> list[dict[str, Any]]: + """Parse @message fields from CW Logs Insights results into span dicts.""" + span_dicts = [] + for row in results: + for field in row: + if field.get("field") == "@message": + try: + span_dicts.append(json.loads(field["value"])) + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Failed to parse @message: %s", e) + return span_dicts + + # --- Internal: output extraction --- + + def _extract_output(self, session: Session) -> str: + """Extract the final agent response from the session for TaskOutput.output.""" + for trace in reversed(session.traces): + for span in reversed(trace.spans): + if isinstance(span, AgentInvocationSpan): + return span.agent_response + return "" diff --git a/src/strands_evals/providers/exceptions.py b/src/strands_evals/providers/exceptions.py index cb4f8a4..2823c1b 100644 --- a/src/strands_evals/providers/exceptions.py +++ b/src/strands_evals/providers/exceptions.py @@ -9,9 +9,5 @@ class SessionNotFoundError(TraceProviderError): """No traces found for the given session ID.""" -class TraceNotFoundError(TraceProviderError): - """Trace with the given ID not found.""" - - class ProviderError(TraceProviderError): """Provider is unreachable or returned an error.""" diff --git a/tests/strands_evals/__init__.py b/tests/strands_evals/__init__.py 
new file mode 100644 index 0000000..e69de29 diff --git a/tests/strands_evals/cloudwatch_helpers.py b/tests/strands_evals/cloudwatch_helpers.py new file mode 100644 index 0000000..4b91abe --- /dev/null +++ b/tests/strands_evals/cloudwatch_helpers.py @@ -0,0 +1,38 @@ +"""Shared test helpers for building CloudWatch body-format OTEL log records.""" + +import json + + +def make_log_record( + trace_id="abc123", + span_id="span-1", + input_messages=None, + output_messages=None, + session_id="sess-1", + time_nano=1000000000000000000, +): + """Build a body-format OTEL log record dict as found in runtime log groups.""" + record = { + "traceId": trace_id, + "spanId": span_id, + "timeUnixNano": time_nano, + "body": { + "input": {"messages": input_messages or []}, + "output": {"messages": output_messages or []}, + }, + "attributes": {"session.id": session_id}, + } + return record + + +def make_user_message(text): + """Build a user input message with double-encoded content.""" + return {"role": "user", "content": {"content": json.dumps([{"text": text}])}} + + +def make_assistant_text_message(text): + """Build an assistant output message with double-encoded text content.""" + return { + "role": "assistant", + "content": {"message": json.dumps([{"text": text}]), "finish_reason": "end_turn"}, + } diff --git a/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py b/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py new file mode 100644 index 0000000..90518fc --- /dev/null +++ b/tests/strands_evals/mappers/test_cloudwatch_session_mapper.py @@ -0,0 +1,245 @@ +"""Tests for CloudWatchSessionMapper — body-format CW log record → Session conversion.""" + +import json + +from strands_evals.mappers.cloudwatch_session_mapper import CloudWatchSessionMapper +from strands_evals.types.trace import ( + AgentInvocationSpan, + InferenceSpan, + ToolExecutionSpan, +) +from tests.strands_evals.cloudwatch_helpers import make_assistant_text_message, make_log_record, make_user_message + 
# --- Helpers for body-format log records ---


def _make_assistant_tool_use_message(tool_name, tool_input, tool_use_id):
    """Build an assistant output message with a toolUse block."""
    return {
        "role": "assistant",
        "content": {
            "message": json.dumps([{"toolUse": {"name": tool_name, "input": tool_input, "toolUseId": tool_use_id}}]),
            "finish_reason": "tool_use",
        },
    }


def _make_tool_result_message(tool_use_id, result_text):
    """Build a tool result input message with double-encoded content."""
    return {
        "role": "tool",
        "content": {
            "content": json.dumps([{"toolResult": {"content": [{"text": result_text}], "toolUseId": tool_use_id}}])
        },
    }


# --- Span conversion (body-based parsing) ---


class TestSpanConversion:
    def setup_method(self):
        # Fresh mapper per test; the mapper is stateless but this keeps tests isolated.
        self.mapper = CloudWatchSessionMapper()

    def test_single_record_produces_inference_span(self):
        """One log record produces an InferenceSpan with input/output messages."""
        record = make_log_record(
            input_messages=[make_user_message("Hi")],
            output_messages=[make_assistant_text_message("Hello!")],
        )
        session = self.mapper.map_to_session([record], "sess-1")
        spans = session.traces[0].spans
        inference_spans = [s for s in spans if isinstance(s, InferenceSpan)]
        assert len(inference_spans) == 1
        # Message order: input (user) first, then output (assistant).
        assert inference_spans[0].messages[0].content[0].text == "Hi"
        assert inference_spans[0].messages[1].content[0].text == "Hello!"

    def test_record_with_tool_use_and_result(self):
        """toolUse in output + toolResult in next record's input → ToolExecutionSpan."""
        record1 = make_log_record(
            trace_id="t1",
            span_id="s1",
            input_messages=[make_user_message("Calculate 6*7")],
            output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")],
            time_nano=1000,
        )
        record2 = make_log_record(
            trace_id="t1",
            span_id="s2",
            input_messages=[
                make_user_message("Calculate 6*7"),
                _make_tool_result_message("tu-1", "42"),
            ],
            output_messages=[make_assistant_text_message("The answer is 42.")],
            time_nano=2000,
        )
        session = self.mapper.map_to_session([record1, record2], "sess-1")
        tool_spans = [s for s in session.traces[0].spans if isinstance(s, ToolExecutionSpan)]
        assert len(tool_spans) == 1
        assert tool_spans[0].tool_call.name == "calculator"
        assert tool_spans[0].tool_call.arguments == {"expr": "6*7"}
        assert tool_spans[0].tool_result.content == "42"

    def test_agent_invocation_from_trace(self):
        """User prompt from first record, response from last → AgentInvocationSpan."""
        record1 = make_log_record(
            trace_id="t1",
            span_id="s1",
            input_messages=[make_user_message("Tell me a joke")],
            output_messages=[make_assistant_text_message("Why did the chicken cross the road?")],
            time_nano=1000,
        )
        session = self.mapper.map_to_session([record1], "sess-1")
        agent_spans = [s for s in session.traces[0].spans if isinstance(s, AgentInvocationSpan)]
        assert len(agent_spans) == 1
        assert agent_spans[0].user_prompt == "Tell me a joke"
        assert agent_spans[0].agent_response == "Why did the chicken cross the road?"

    def test_agent_invocation_extracts_tools(self):
        """available_tools populated from tool call names in the trace."""
        record1 = make_log_record(
            trace_id="t1",
            span_id="s1",
            input_messages=[make_user_message("Search for X")],
            output_messages=[_make_assistant_tool_use_message("web_search", {"q": "X"}, "tu-1")],
            time_nano=1000,
        )
        record2 = make_log_record(
            trace_id="t1",
            span_id="s2",
            input_messages=[make_user_message("Search for X"), _make_tool_result_message("tu-1", "found X")],
            output_messages=[make_assistant_text_message("Here's what I found about X.")],
            time_nano=2000,
        )
        session = self.mapper.map_to_session([record1, record2], "sess-1")
        agent_spans = [s for s in session.traces[0].spans if isinstance(s, AgentInvocationSpan)]
        assert len(agent_spans) == 1
        tool_names = [t.name for t in agent_spans[0].available_tools]
        assert "web_search" in tool_names

    def test_double_encoded_content_parsed(self):
        """Content field is a JSON string that must be parsed to get content blocks."""
        record = make_log_record(
            input_messages=[make_user_message("test double encoding")],
            output_messages=[make_assistant_text_message("parsed correctly")],
        )
        session = self.mapper.map_to_session([record], "sess-1")
        inference_spans = [s for s in session.traces[0].spans if isinstance(s, InferenceSpan)]
        assert inference_spans[0].messages[0].content[0].text == "test double encoding"
        assert inference_spans[0].messages[1].content[0].text == "parsed correctly"

    def test_tool_call_matched_to_result_by_id(self):
        """toolUseId matching works across records in the same trace."""
        record1 = make_log_record(
            trace_id="t1",
            span_id="s1",
            input_messages=[make_user_message("Do two things")],
            output_messages=[
                _make_assistant_tool_use_message("tool_a", {"x": 1}, "tu-a"),
            ],
            time_nano=1000,
        )
        record2 = make_log_record(
            trace_id="t1",
            span_id="s2",
            input_messages=[
                make_user_message("Do two things"),
                _make_tool_result_message("tu-a", "result-a"),
            ],
            output_messages=[
                _make_assistant_tool_use_message("tool_b", {"y": 2}, "tu-b"),
            ],
            time_nano=2000,
        )
        record3 = make_log_record(
            trace_id="t1",
            span_id="s3",
            input_messages=[
                make_user_message("Do two things"),
                _make_tool_result_message("tu-a", "result-a"),
                _make_tool_result_message("tu-b", "result-b"),
            ],
            output_messages=[make_assistant_text_message("Both done.")],
            time_nano=3000,
        )
        session = self.mapper.map_to_session([record1, record2, record3], "sess-1")
        tool_spans = [s for s in session.traces[0].spans if isinstance(s, ToolExecutionSpan)]
        # Duplicate tu-a result in record3 must not create a second span for tool_a.
        assert len(tool_spans) == 2
        tool_span_by_name = {ts.tool_call.name: ts for ts in tool_spans}
        assert tool_span_by_name["tool_a"].tool_result.content == "result-a"
        assert tool_span_by_name["tool_b"].tool_result.content == "result-b"


# --- Session building ---


class TestSessionBuilding:
    def setup_method(self):
        self.mapper = CloudWatchSessionMapper()

    def test_multiple_records_grouped_by_trace_id(self):
        """Records with different traceIds become separate Trace objects."""
        records = [
            make_log_record(
                trace_id="t1",
                input_messages=[make_user_message("q1")],
                output_messages=[make_assistant_text_message("a1")],
            ),
            make_log_record(
                trace_id="t2",
                input_messages=[make_user_message("q2")],
                output_messages=[make_assistant_text_message("a2")],
            ),
        ]
        session = self.mapper.map_to_session(records, "sess-1")
        assert session.session_id == "sess-1"
        assert len(session.traces) == 2
        trace_ids = {t.trace_id for t in session.traces}
        assert trace_ids == {"t1", "t2"}

    def test_multi_step_agent_loop(self):
        """user→LLM→tool→LLM→response produces InferenceSpan + ToolExecutionSpan + AgentInvocationSpan."""
        record1 = make_log_record(
            trace_id="t1",
            span_id="s1",
            input_messages=[make_user_message("What is 6*7?")],
            output_messages=[_make_assistant_tool_use_message("calculator", {"expr": "6*7"}, "tu-1")],
            time_nano=1000,
        )
        record2 = make_log_record(
            trace_id="t1",
            span_id="s2",
            input_messages=[
                make_user_message("What is 6*7?"),
                _make_tool_result_message("tu-1", "42"),
            ],
            output_messages=[make_assistant_text_message("The answer is 42.")],
            time_nano=2000,
        )
        session = self.mapper.map_to_session([record1, record2], "sess-1")
        assert len(session.traces) == 1
        spans = session.traces[0].spans
        span_types = [type(s).__name__ for s in spans]
        assert "InferenceSpan" in span_types
        assert "ToolExecutionSpan" in span_types
        assert "AgentInvocationSpan" in span_types

    def test_empty_records_list(self):
        # No records → empty Session, no traces.
        session = self.mapper.map_to_session([], "sess-1")
        assert session.session_id == "sess-1"
        assert session.traces == []

    def test_record_with_no_body_skipped(self):
        """Malformed records without body don't crash."""
        records = [
            {"traceId": "t1", "spanId": "s1", "timeUnixNano": 1000},
            make_log_record(
                trace_id="t1",
                input_messages=[make_user_message("Hi")],
                output_messages=[make_assistant_text_message("Hello!")],
                time_nano=2000,
            ),
        ]
        session = self.mapper.map_to_session(records, "sess-1")
        assert len(session.traces) == 1
        assert len(session.traces[0].spans) > 0
diff --git a/tests/strands_evals/providers/test_cloudwatch_provider.py b/tests/strands_evals/providers/test_cloudwatch_provider.py
new file mode 100644
index 0000000..d9ce0e1
--- /dev/null
+++ b/tests/strands_evals/providers/test_cloudwatch_provider.py
@@ -0,0 +1,352 @@
"""Tests for CloudWatchProvider — mocked boto3 CloudWatch Logs client."""

import json
import os
from unittest.mock import MagicMock, patch

import pytest

from strands_evals.providers.cloudwatch_provider import CloudWatchProvider
from strands_evals.providers.exceptions import (
    ProviderError,
    SessionNotFoundError,
)
from strands_evals.types.trace import Session
from tests.strands_evals.cloudwatch_helpers import make_assistant_text_message, make_log_record,
make_user_message + +# --- Fixtures --- + + +@pytest.fixture +def mock_logs_client(): + return MagicMock() + + +@pytest.fixture +def provider(mock_logs_client): + with patch("boto3.client", return_value=mock_logs_client): + return CloudWatchProvider(log_group="/test/group") + + +# --- Constructor --- + + +class TestConstructor: + def test_explicit_log_group(self, mock_logs_client): + with patch("boto3.client", return_value=mock_logs_client): + p = CloudWatchProvider(log_group="/custom/group") + assert p._log_group == "/custom/group" + assert p._lookback_days == 30 + assert p._query_timeout_seconds == 60.0 + + def test_custom_params(self, mock_logs_client): + with patch("boto3.client", return_value=mock_logs_client) as mock_boto: + p = CloudWatchProvider( + region="eu-west-1", + log_group="/custom/log-group", + lookback_days=7, + query_timeout_seconds=120.0, + ) + mock_boto.assert_called_once_with("logs", region_name="eu-west-1") + assert p._log_group == "/custom/log-group" + assert p._lookback_days == 7 + assert p._query_timeout_seconds == 120.0 + + def test_agent_name_discovery(self, mock_logs_client): + mock_logs_client.describe_log_groups.return_value = { + "logGroups": [{"logGroupName": "/aws/bedrock-agentcore/runtimes/my-agent-abc-DEFAULT"}] + } + with patch("boto3.client", return_value=mock_logs_client): + p = CloudWatchProvider(agent_name="my-agent") + assert p._log_group == "/aws/bedrock-agentcore/runtimes/my-agent-abc-DEFAULT" + mock_logs_client.describe_log_groups.assert_called_once_with( + logGroupNamePrefix="/aws/bedrock-agentcore/runtimes/my-agent" + ) + + def test_agent_name_no_match_raises(self, mock_logs_client): + mock_logs_client.describe_log_groups.return_value = {"logGroups": []} + with ( + patch("boto3.client", return_value=mock_logs_client), + pytest.raises(ProviderError, match="no log group found"), + ): + CloudWatchProvider(agent_name="nonexistent-agent") + + def test_neither_log_group_nor_agent_name_raises(self, mock_logs_client): + with ( 
+ patch("boto3.client", return_value=mock_logs_client), + pytest.raises(ProviderError, match="log_group.*agent_name"), + ): + CloudWatchProvider() + + def test_region_from_aws_region_env(self, mock_logs_client): + env = os.environ.copy() + env.pop("AWS_DEFAULT_REGION", None) + env["AWS_REGION"] = "ap-southeast-1" + with ( + patch.dict(os.environ, env, clear=True), + patch("boto3.client", return_value=mock_logs_client) as mock_boto, + ): + CloudWatchProvider(log_group="/test/group") + mock_boto.assert_called_once_with("logs", region_name="ap-southeast-1") + + def test_region_from_aws_default_region_env(self, mock_logs_client): + env = os.environ.copy() + env.pop("AWS_REGION", None) + env["AWS_DEFAULT_REGION"] = "us-west-2" + with ( + patch.dict(os.environ, env, clear=True), + patch("boto3.client", return_value=mock_logs_client) as mock_boto, + ): + CloudWatchProvider(log_group="/test/group") + mock_boto.assert_called_once_with("logs", region_name="us-west-2") + + def test_aws_region_takes_precedence_over_default_region(self, mock_logs_client): + env = {"AWS_REGION": "eu-central-1", "AWS_DEFAULT_REGION": "us-west-2"} + with ( + patch.dict(os.environ, env), + patch("boto3.client", return_value=mock_logs_client) as mock_boto, + ): + CloudWatchProvider(log_group="/test/group") + mock_boto.assert_called_once_with("logs", region_name="eu-central-1") + + def test_explicit_region_overrides_env(self, mock_logs_client): + env = {"AWS_REGION": "eu-central-1"} + with ( + patch.dict(os.environ, env), + patch("boto3.client", return_value=mock_logs_client) as mock_boto, + ): + CloudWatchProvider(region="ca-central-1", log_group="/test/group") + mock_boto.assert_called_once_with("logs", region_name="ca-central-1") + + def test_boto3_client_creation_failure(self): + with ( + patch("boto3.client", side_effect=Exception("bad credentials")), + pytest.raises(ProviderError, match="bad credentials"), + ): + CloudWatchProvider(log_group="/test/group") + + +# --- Helpers for body-format log 
records --- + + +def _setup_query_results(mock_logs_client, records): + """Wire up mock to return records from a CW Logs Insights query.""" + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": [[{"field": "@message", "value": json.dumps(r)}] for r in records], + } + + +# --- Output extraction --- + + +class TestExtractOutput: + def test_extract_output_from_agent_response(self, provider): + """_extract_output returns last agent response text.""" + records = [ + make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("First response")], + time_nano=1000, + ), + make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("Final response")], + time_nano=2000, + ), + ] + session = provider._mapper.map_to_session(records, "sess-1") + output = provider._extract_output(session) + assert output == "Final response" + + +# --- CW Logs Insights polling --- + + +class TestLogsInsightsPolling: + def _make_record_json(self, trace_id="t1", span_id="s1"): + """Return a JSON-serialized body-format log record for use in CW Logs @message fields.""" + return json.dumps( + make_log_record( + trace_id=trace_id, + span_id=span_id, + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("Hello!")], + ) + ) + + def test_happy_path(self, provider, mock_logs_client): + """Query starts, one poll, completes with results.""" + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": [ + [{"field": "@message", "value": self._make_record_json()}], + ], + } + results = provider._run_logs_insights_query("fields @message") + assert len(results) == 1 + assert results[0]["traceId"] == "t1" + + def 
test_polls_through_intermediate_statuses(self, provider, mock_logs_client): + """Query goes through Scheduled → Running → Complete.""" + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.side_effect = [ + {"status": "Scheduled", "results": []}, + {"status": "Running", "results": []}, + { + "status": "Complete", + "results": [ + [{"field": "@message", "value": self._make_record_json()}], + ], + }, + ] + with patch("time.sleep"): + results = provider._run_logs_insights_query("fields @message") + assert len(results) == 1 + + def test_failed_status_raises(self, provider, mock_logs_client): + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Failed", + "results": [], + } + with pytest.raises(ProviderError, match="Failed"): + provider._run_logs_insights_query("fields @message") + + def test_timeout_raises(self, provider, mock_logs_client): + """If query doesn't complete within timeout, raises ProviderError.""" + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Running", + "results": [], + } + provider._query_timeout_seconds = 0.01 + with ( + patch("time.sleep"), + patch("time.monotonic", side_effect=[0.0, 0.0, 1.0]), + pytest.raises(ProviderError, match="timed out"), + ): + provider._run_logs_insights_query("fields @message") + + def test_empty_results(self, provider, mock_logs_client): + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": [], + } + results = provider._run_logs_insights_query("fields @message") + assert results == [] + + def test_parses_message_field(self, provider, mock_logs_client): + """Each result row's @message field is parsed as JSON into a record dict.""" + record1 = make_log_record(trace_id="t1", span_id="s1") + record2 = 
make_log_record(trace_id="t1", span_id="s2") + mock_logs_client.start_query.return_value = {"queryId": "q-1"} + mock_logs_client.get_query_results.return_value = { + "status": "Complete", + "results": [ + [{"field": "@message", "value": json.dumps(record1)}], + [{"field": "@message", "value": json.dumps(record2)}], + ], + } + results = provider._run_logs_insights_query("fields @message") + assert len(results) == 2 + assert results[0]["spanId"] == "s1" + assert results[1]["spanId"] == "s2" + + def test_start_query_failure_raises(self, provider, mock_logs_client): + mock_logs_client.start_query.side_effect = Exception("access denied") + with pytest.raises(ProviderError, match="access denied"): + provider._run_logs_insights_query("fields @message") + + +# --- get_evaluation_data --- + + +class TestGetEvaluationData: + def test_happy_path(self, provider, mock_logs_client): + records = [ + make_log_record( + trace_id="t1", + span_id="s1", + session_id="sess-1", + input_messages=[make_user_message("What is 6*7?")], + output_messages=[make_assistant_text_message("The answer is 42.")], + ) + ] + _setup_query_results(mock_logs_client, records) + + result = provider.get_evaluation_data("sess-1") + assert isinstance(result["trajectory"], Session) + assert result["trajectory"].session_id == "sess-1" + assert len(result["trajectory"].traces) == 1 + assert result["output"] == "The answer is 42." 
+ + def test_no_results_raises_session_not_found(self, provider, mock_logs_client): + _setup_query_results(mock_logs_client, []) + with pytest.raises(SessionNotFoundError, match="sess-missing"): + provider.get_evaluation_data("sess-missing") + + def test_query_failure_raises_provider_error(self, provider, mock_logs_client): + mock_logs_client.start_query.side_effect = Exception("throttled") + with pytest.raises(ProviderError, match="throttled"): + provider.get_evaluation_data("sess-1") + + def test_query_uses_session_id_filter(self, provider, mock_logs_client): + """Verify the query string uses attributes.session.id filter.""" + records = [ + make_log_record( + session_id="sess-1", + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("Hello")], + ) + ] + _setup_query_results(mock_logs_client, records) + provider.get_evaluation_data("sess-1") + query_string = mock_logs_client.start_query.call_args[1]["queryString"] + assert "attributes.session.id" in query_string + assert "sess-1" in query_string + + def test_multiple_traces(self, provider, mock_logs_client): + records = [ + make_log_record( + trace_id="t1", + input_messages=[make_user_message("q1")], + output_messages=[make_assistant_text_message("first")], + ), + make_log_record( + trace_id="t2", + input_messages=[make_user_message("q2")], + output_messages=[make_assistant_text_message("second")], + ), + ] + _setup_query_results(mock_logs_client, records) + result = provider.get_evaluation_data("sess-1") + assert len(result["trajectory"].traces) == 2 + assert result["output"] == "second" + + def test_output_from_last_agent_invocation(self, provider, mock_logs_client): + records = [ + make_log_record( + trace_id="t1", + span_id="s1", + input_messages=[make_user_message("Hi")], + output_messages=[make_assistant_text_message("first")], + time_nano=1000, + ), + make_log_record( + trace_id="t1", + span_id="s2", + input_messages=[make_user_message("Hi")], + 
output_messages=[make_assistant_text_message("last")], + time_nano=2000, + ), + ] + _setup_query_results(mock_logs_client, records) + assert provider.get_evaluation_data("sess-1")["output"] == "last" diff --git a/tests/strands_evals/providers/test_trace_provider.py b/tests/strands_evals/providers/test_trace_provider.py index f89a36f..fc5ee46 100644 --- a/tests/strands_evals/providers/test_trace_provider.py +++ b/tests/strands_evals/providers/test_trace_provider.py @@ -5,7 +5,6 @@ from strands_evals.providers.exceptions import ( ProviderError, SessionNotFoundError, - TraceNotFoundError, TraceProviderError, ) from strands_evals.providers.trace_provider import ( @@ -37,9 +36,6 @@ def test_trace_provider_error_is_exception(self): def test_session_not_found_is_trace_provider_error(self): assert issubclass(SessionNotFoundError, TraceProviderError) - def test_trace_not_found_is_trace_provider_error(self): - assert issubclass(TraceNotFoundError, TraceProviderError) - def test_provider_error_is_trace_provider_error(self): assert issubclass(ProviderError, TraceProviderError) @@ -49,7 +45,7 @@ def test_exceptions_carry_message(self): def test_catching_base_catches_all(self): """All provider exceptions can be caught with TraceProviderError.""" - for exc_class in (SessionNotFoundError, TraceNotFoundError, ProviderError): + for exc_class in (SessionNotFoundError, ProviderError): with pytest.raises(TraceProviderError): raise exc_class("test") diff --git a/tests_integ/conftest.py b/tests_integ/conftest.py new file mode 100644 index 0000000..d0d69f6 --- /dev/null +++ b/tests_integ/conftest.py @@ -0,0 +1,13 @@ +"""Shared fixtures for trace provider integration tests. + +Each provider test module defines its own `provider` and `session_id` fixtures. +This conftest provides common fixtures that build on those. 
+""" + +import pytest + + +@pytest.fixture(scope="module") +def evaluation_data(provider, session_id): + """Fetch evaluation data for the test session.""" + return provider.get_evaluation_data(session_id) diff --git a/tests_integ/test_cloudwatch_provider.py b/tests_integ/test_cloudwatch_provider.py new file mode 100644 index 0000000..0ac2208 --- /dev/null +++ b/tests_integ/test_cloudwatch_provider.py @@ -0,0 +1,211 @@ +"""Integration tests for CloudWatchProvider against real CloudWatch Logs data. + +Requires the following environment variables: + CLOUDWATCH_TEST_LOG_GROUP — CW Logs group containing agent traces + CLOUDWATCH_TEST_SESSION_ID — session ID with convertible spans + AWS_REGION (optional) — defaults to us-east-1 + +AWS credentials must be configured via standard mechanisms (env vars, profile, instance role). + +Run with: pytest tests_integ/test_cloudwatch_provider.py -v +""" + +import os + +import pytest + +from strands_evals import Case, Experiment +from strands_evals.evaluators import ( + CoherenceEvaluator, + HelpfulnessEvaluator, + OutputEvaluator, +) +from strands_evals.providers.cloudwatch_provider import CloudWatchProvider +from strands_evals.providers.exceptions import ( + ProviderError, + SessionNotFoundError, +) +from strands_evals.types.trace import ( + AgentInvocationSpan, + InferenceSpan, + Session, + ToolExecutionSpan, + Trace, +) + + +@pytest.fixture(scope="module") +def provider(): + """Create a CloudWatchProvider using env var configuration.""" + log_group = os.environ.get("CLOUDWATCH_TEST_LOG_GROUP") + if not log_group: + pytest.skip("CLOUDWATCH_TEST_LOG_GROUP not set") + + region = os.environ.get("AWS_REGION", "us-east-1") + + try: + return CloudWatchProvider(log_group=log_group, region=region) + except ProviderError as e: + pytest.skip(f"CloudWatch provider creation failed: {e}") + + +@pytest.fixture(scope="module") +def session_id(): + """Get a test session ID from environment variable.""" + sid = 
os.environ.get("CLOUDWATCH_TEST_SESSION_ID") + if not sid: + pytest.skip("CLOUDWATCH_TEST_SESSION_ID not set") + return sid + + +class TestGetEvaluationData: + def test_returns_session_with_traces(self, evaluation_data, session_id): + session = evaluation_data["trajectory"] + assert isinstance(session, Session) + assert session.session_id == session_id + assert len(session.traces) > 0 + + def test_traces_have_spans(self, evaluation_data): + session = evaluation_data["trajectory"] + for trace in session.traces: + assert isinstance(trace, Trace) + assert isinstance(trace.trace_id, str) + assert len(trace.trace_id) > 0 + assert len(trace.spans) > 0 + + def test_spans_are_typed(self, evaluation_data): + """All spans should be one of the three known types.""" + valid_types = (AgentInvocationSpan, InferenceSpan, ToolExecutionSpan) + session = evaluation_data["trajectory"] + for trace in session.traces: + for span in trace.spans: + assert isinstance(span, valid_types), f"Unexpected span type: {type(span).__name__}" + + def test_has_agent_invocation_span(self, evaluation_data): + """At least one trace should have an AgentInvocationSpan.""" + session = evaluation_data["trajectory"] + agent_spans = [ + span for trace in session.traces for span in trace.spans if isinstance(span, AgentInvocationSpan) + ] + assert len(agent_spans) > 0, "Expected at least one AgentInvocationSpan" + + def test_agent_invocation_has_prompt_and_response(self, evaluation_data): + """AgentInvocationSpan should have user prompt, agent response, and tools list.""" + session = evaluation_data["trajectory"] + for trace in session.traces: + for span in trace.spans: + if isinstance(span, AgentInvocationSpan): + assert isinstance(span.user_prompt, str) + assert len(span.user_prompt) > 0 + assert isinstance(span.agent_response, str) + assert len(span.agent_response) > 0 + assert isinstance(span.available_tools, list) + return + pytest.fail("No AgentInvocationSpan found") + + def 
test_output_is_nonempty_string(self, evaluation_data): + output = evaluation_data["output"] + assert isinstance(output, str) + assert len(output) > 0, "Expected non-empty output from agent response" + + def test_span_info_populated(self, evaluation_data): + session = evaluation_data["trajectory"] + for trace in session.traces: + for span in trace.spans: + si = span.span_info + assert si.trace_id is not None + assert si.span_id is not None + assert si.session_id == session.session_id + assert si.start_time is not None + assert si.end_time is not None + return + pytest.fail("No spans found to check") + + def test_nonexistent_session_raises(self, provider): + with pytest.raises(SessionNotFoundError): + provider.get_evaluation_data("nonexistent-session-id-that-does-not-exist-12345") + + +# --- End-to-end: CloudWatch → Evaluator pipeline --- + + +class TestEndToEnd: + """Fetch traces from CloudWatch and run real evaluators on them.""" + + def test_output_evaluator_on_remote_trace(self, provider, session_id): + """OutputEvaluator produces a valid score from a CloudWatch session.""" + + def task(case: Case) -> dict: + return provider.get_evaluation_data(case.input) + + cases = [ + Case( + name="cloudwatch_session", + input=session_id, + expected_output="any agent response", + ), + ] + + evaluator = OutputEvaluator( + rubric="Score 1.0 if the output is a coherent response from an AI agent. 
" + "Score 0.0 if the output is empty or clearly broken.", + ) + + experiment = Experiment(cases=cases, evaluators=[evaluator]) + reports = experiment.run_evaluations(task) + + assert len(reports) == 1 + report = reports[0] + assert 0.0 <= report.overall_score <= 1.0 + assert len(report.scores) == 1 + + def test_coherence_evaluator_on_remote_trace(self, provider, session_id): + """CoherenceEvaluator produces a valid score from a CloudWatch session.""" + + def task(case: Case) -> dict: + return provider.get_evaluation_data(case.input) + + cases = [ + Case( + name="cloudwatch_session", + input=session_id, + expected_output="any agent response", + ), + ] + + evaluator = CoherenceEvaluator() + + experiment = Experiment(cases=cases, evaluators=[evaluator]) + reports = experiment.run_evaluations(task) + + assert len(reports) == 1 + report = reports[0] + assert 0.0 <= report.overall_score <= 1.0 + + def test_multiple_evaluators_on_remote_trace(self, provider, session_id): + """Multiple evaluators can all run on the same CloudWatch session data.""" + + def task(case: Case) -> dict: + return provider.get_evaluation_data(case.input) + + cases = [ + Case( + name="cloudwatch_session", + input=session_id, + expected_output="any agent response", + ), + ] + + evaluators = [ + OutputEvaluator(rubric="Score 1.0 if the output is coherent. 
Score 0.0 otherwise."), + CoherenceEvaluator(), + HelpfulnessEvaluator(), + ] + + experiment = Experiment(cases=cases, evaluators=evaluators) + reports = experiment.run_evaluations(task) + + assert len(reports) == 3 + for report in reports: + assert 0.0 <= report.overall_score <= 1.0 + assert len(report.scores) == 1 diff --git a/tests_integ/test_langfuse_provider.py b/tests_integ/test_langfuse_provider.py index 2593660..8fddb36 100644 --- a/tests_integ/test_langfuse_provider.py +++ b/tests_integ/test_langfuse_provider.py @@ -41,7 +41,7 @@ def provider(): @pytest.fixture(scope="module") -def discovered_session_id(): +def session_id(): """Get a test session ID from environment variable.""" session_id = os.environ.get("LANGFUSE_TEST_SESSION_ID") if not session_id: @@ -49,17 +49,11 @@ def discovered_session_id(): return session_id -@pytest.fixture(scope="module") -def evaluation_data(provider, discovered_session_id): - """Fetch evaluation data for the discovered session.""" - return provider.get_evaluation_data(discovered_session_id) - - class TestGetEvaluationData: - def test_returns_session_with_traces(self, evaluation_data, discovered_session_id): + def test_returns_session_with_traces(self, evaluation_data, session_id): session = evaluation_data["trajectory"] assert isinstance(session, Session) - assert session.session_id == discovered_session_id + assert session.session_id == session_id assert len(session.traces) > 0 def test_traces_have_spans(self, evaluation_data): @@ -126,7 +120,7 @@ def test_nonexistent_session_raises(self, provider): class TestEndToEnd: """Fetch traces from Langfuse and run real evaluators on them.""" - def test_output_evaluator_on_remote_trace(self, provider, discovered_session_id): + def test_output_evaluator_on_remote_trace(self, provider, session_id): """OutputEvaluator produces a valid score from a Langfuse session.""" def task(case: Case) -> dict: @@ -135,7 +129,7 @@ def task(case: Case) -> dict: cases = [ Case( name="langfuse_session", 
- input=discovered_session_id, + input=session_id, expected_output="any agent response", ), ] @@ -154,7 +148,7 @@ def task(case: Case) -> dict: assert 0.0 <= report.score <= 1.0 assert len(report.case_results) == 1 - def test_coherence_evaluator_on_remote_trace(self, provider, discovered_session_id): + def test_coherence_evaluator_on_remote_trace(self, provider, session_id): """CoherenceEvaluator produces a valid score from a Langfuse session.""" def task(case: Case) -> dict: @@ -163,7 +157,7 @@ def task(case: Case) -> dict: cases = [ Case( name="langfuse_session", - input=discovered_session_id, + input=session_id, expected_output="any agent response", ), ] @@ -178,7 +172,7 @@ def task(case: Case) -> dict: assert report.score is not None assert 0.0 <= report.score <= 1.0 - def test_multiple_evaluators_on_remote_trace(self, provider, discovered_session_id): + def test_multiple_evaluators_on_remote_trace(self, provider, session_id): """Multiple evaluators can all run on the same Langfuse session data.""" def task(case: Case) -> dict: @@ -187,7 +181,7 @@ def task(case: Case) -> dict: cases = [ Case( name="langfuse_session", - input=discovered_session_id, + input=session_id, expected_output="any agent response", ), ]