1 change: 1 addition & 0 deletions beeai/Containerfile
@@ -28,6 +28,7 @@ RUN pip3 install --no-cache-dir \
beeai-framework[mcp,duckduckgo]==0.1.31 \
openinference-instrumentation-beeai \
arize-phoenix-otel \
deepeval \
&& cd /usr/local/lib/python3.13/site-packages \
&& patch -p2 -i /tmp/beeai-gemini.patch \
&& patch -p2 -i /tmp/beeai-gemini-malformed-function-call.patch \
4 changes: 2 additions & 2 deletions beeai/agents/backport_agent.py
@@ -144,7 +144,7 @@ def prompt(self) -> str:
6. {{ backport_git_steps }}
"""

async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
async def run_with_schema(self, input: TInputSchema, capture_raw_response: bool = False) -> TOutputSchema:
async with mcp_tools(
os.getenv("MCP_GATEWAY_URL"),
filter=lambda t: t
@@ -153,7 +153,7 @@ async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
tools = self._tools.copy()
try:
self._tools.extend(gateway_tools)
return await self._run_with_schema(input)
return await self._run_with_schema(input, capture_raw_response=capture_raw_response)
finally:
self._tools = tools
# disassociate removed tools from requirements
15 changes: 12 additions & 3 deletions beeai/agents/base_agent.py
@@ -14,6 +14,8 @@


class BaseAgent(RequirementAgent, ABC):
last_raw_response: RequirementAgentRunOutput | None = None

@property
@abstractmethod
def input_schema(self) -> type[TInputSchema]: ...
@@ -32,7 +34,9 @@ def _render_prompt(self, input: TInputSchema) -> str:
)
return template.render(input)

async def _run_with_schema(self, input: TInputSchema) -> TOutputSchema:
async def _run_with_schema(
self, input: TInputSchema, capture_raw_response: bool = False
) -> TOutputSchema:
max_retries_per_step = int(os.getenv("BEEAI_MAX_RETRIES_PER_STEP", 5))
total_max_retries = int(os.getenv("BEEAI_TOTAL_MAX_RETRIES", 10))
max_iterations = int(os.getenv("BEEAI_MAX_ITERATIONS", 100))
@@ -46,10 +50,14 @@ async def _run_with_schema(self, input: TInputSchema) -> TOutputSchema:
max_iterations=max_iterations,
),
)
if capture_raw_response:
self.last_raw_response = response
return self.output_schema.model_validate_json(response.result.text)

async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
return await self._run_with_schema(input)
async def run_with_schema(
self, input: TInputSchema, capture_raw_response: bool = False
) -> TOutputSchema:
return await self._run_with_schema(input, capture_raw_response)


if os.getenv("LITELLM_DEBUG"):
@@ -58,4 +66,5 @@ async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
import beeai_framework.adapters.litellm.chat
import beeai_framework.adapters.litellm.embedding
from beeai_framework.adapters.litellm.utils import litellm_debug

litellm_debug(True)
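The new capture_raw_response flag is what lets evaluation code keep the full RequirementAgentRunOutput around instead of only the validated output schema. A minimal usage sketch, assuming a hypothetical BaseAgent subclass MyAgent with a hypothetical MyInput schema (neither is defined in this PR):

```python
# Sketch only, not part of the diff. MyAgent / MyInput are hypothetical
# stand-ins for a concrete BaseAgent subclass and its input schema.
import asyncio


async def main() -> None:
    agent = MyAgent()  # hypothetical agent built on BaseAgent
    result = await agent.run_with_schema(
        MyInput(package="example"),  # hypothetical structured input
        capture_raw_response=True,   # flag added in this change
    )
    # With the flag set, the raw RequirementAgentRunOutput stays available on
    # the agent, so eval code can inspect intermediate steps and tool calls.
    raw = agent.last_raw_response
    print(result, raw is not None)


if __name__ == "__main__":
    asyncio.run(main())
```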
4 changes: 2 additions & 2 deletions beeai/agents/rebase_agent.py
@@ -173,7 +173,7 @@ def prompt(self) -> str:
- Any validation issues found with rpmlint
"""

async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
async def run_with_schema(self, input: TInputSchema, capture_raw_response: bool = False) -> TOutputSchema:
async with mcp_tools(
os.getenv("MCP_GATEWAY_URL"),
filter=lambda t: t
@@ -182,7 +182,7 @@ async def run_with_schema(self, input: TInputSchema) -> TOutputSchema:
tools = self._tools.copy()
try:
self._tools.extend(gateway_tools)
return await self._run_with_schema(input)
return await self._run_with_schema(input, capture_raw_response=capture_raw_response)
finally:
self._tools = tools
# disassociate removed tools from requirements
166 changes: 166 additions & 0 deletions beeai/agents/tests/_utils.py
@@ -0,0 +1,166 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

import asyncio
import os
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import TypeVar

import pytest
from deepeval import evaluate
from deepeval.dataset import EvaluationDataset, Golden
from deepeval.evaluate import DisplayConfig
from deepeval.metrics import BaseMetric
from deepeval.test_case import LLMTestCase
from deepeval.test_run.test_run import TestRunResultDisplay
from rich.console import Console, Group
from rich.panel import Panel
from rich.table import Table

from beeai_framework.agents import AnyAgent

ROOT_CACHE_DIR = "/tmp/.cache"
Path(ROOT_CACHE_DIR).mkdir(parents=True, exist_ok=True)


T = TypeVar("T", bound=AnyAgent)


async def create_dataset(
*,
name: str,
agent_factory: Callable[[], T],
agent_run: Callable[[T, LLMTestCase], Awaitable[None]],
goldens: list[Golden],
cache: bool | None = None,
) -> EvaluationDataset:
dataset = EvaluationDataset()

cache_dir = Path(f"{ROOT_CACHE_DIR}/{name}")
if cache is None:
cache = os.getenv("EVAL_CACHE_DATASET", "").lower() == "true"

if cache and cache_dir.exists():
for file_path in cache_dir.glob("*.json"):
dataset.add_test_cases_from_json_file(
file_path=str(file_path.absolute().resolve()),
input_key_name="input",
actual_output_key_name="actual_output",
expected_output_key_name="expected_output",
context_key_name="context",
tools_called_key_name="tools_called",
expected_tools_key_name="expected_tools",
retrieval_context_key_name="retrieval_context",
)
else:

async def process_golden(golden: Golden) -> LLMTestCase:
agent = agent_factory()
case = LLMTestCase(
input=golden.input,
expected_tools=golden.expected_tools,
actual_output="",
expected_output=golden.expected_output,
comments=golden.comments,
context=golden.context,
tools_called=golden.tools_called,
retrieval_context=golden.retrieval_context,
additional_metadata=golden.additional_metadata,
)
await agent_run(agent, case)
return case

for test_case in await asyncio.gather(*[process_golden(golden) for golden in goldens], return_exceptions=False):
dataset.add_test_case(test_case)

if cache:
dataset.save_as(file_type="json", directory=str(cache_dir.absolute()), include_test_cases=True)

for case in dataset.test_cases:
case.name = f"{name} - {case.input[0:128].strip()}" # type: ignore

return dataset


def evaluate_dataset(
dataset: EvaluationDataset, metrics: list[BaseMetric], display_mode: TestRunResultDisplay | None = None
) -> None:
console = Console()
console.print("[bold green]Evaluating dataset[/bold green]")

if display_mode is None:
display_mode = TestRunResultDisplay(os.environ.get("EVAL_DISPLAY_MODE", "all"))

output = evaluate(
test_cases=dataset.test_cases, # type: ignore
metrics=metrics,
display_config=DisplayConfig(
show_indicator=False, print_results=False, verbose_mode=False, display_option=None
),
)

# Calculate pass/fail counts
total = len(output.test_results)
passed = sum(
bool(test_result.metrics_data) and all(md.success for md in (test_result.metrics_data or []))
for test_result in output.test_results
)
failed = total - passed

# Print summary table
summary_table = Table(title="Test Results Summary", show_header=True, header_style="bold cyan")
summary_table.add_column("Total", justify="right")
summary_table.add_column("Passed", justify="right", style="green")
summary_table.add_column("Failed", justify="right", style="red")
summary_table.add_row(str(total), str(passed), str(failed))
console.print(summary_table)

for test_result in output.test_results:
if display_mode != TestRunResultDisplay.ALL and (
(display_mode == TestRunResultDisplay.FAILING and test_result.success)
or (display_mode == TestRunResultDisplay.PASSING and not test_result.success)
):
continue

# Info Table
info_table = Table(show_header=False, box=None, pad_edge=False)
info_table.add_row("Input", str(test_result.input))
info_table.add_row("Expected Output", str(test_result.expected_output))
info_table.add_row("Actual Output", str(test_result.actual_output))

# Metrics Table
metrics_table = Table(title="Metrics", show_header=True, header_style="bold magenta")
metrics_table.add_column("Metric")
metrics_table.add_column("Success")
metrics_table.add_column("Score")
metrics_table.add_column("Threshold")
metrics_table.add_column("Reason")
metrics_table.add_column("Error")
# metrics_table.add_column("Verbose Log")

for metric_data in test_result.metrics_data or []:
metrics_table.add_row(
str(metric_data.name),
str(metric_data.success),
str(metric_data.score),
str(metric_data.threshold),
str(metric_data.reason),
str(metric_data.error) if metric_data.error else "",
# str(metric_data.verbose_logs),
)

# Print the panel with info and metrics table
console.print(
Panel(
Group(info_table, metrics_table),
title=f"[bold blue]{test_result.name}[/bold blue]",
border_style="blue",
)
)

# Gather failed tests
if failed:
pytest.fail(f"{failed}/{total} tests failed. See the summary table above for more details.", pytrace=False)
else:
assert 1 == 1
Review comment (severity: medium):

The use of assert 1 == 1 is unconventional for indicating a successful test completion. It works, but it's not very expressive. A more standard and clearer way to indicate that a test has passed successfully (when no exceptions were raised) is to simply have no assertions at the end, or use pass. Since pytest.fail is used for the failure case, the function will implicitly pass if it completes without hitting that line.

Suggested change:
- assert 1 == 1
+ pass
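Taken together, create_dataset builds (or reloads from the JSON cache) the LLMTestCases by running the agent on each Golden, and evaluate_dataset scores them and renders the rich summary. A sketch of how a test might wire the two helpers together; SomeAgent, SomeInput, the golden content, and the pytest-asyncio marker are assumptions for illustration only:

```python
# Sketch only, not part of the diff. SomeAgent / SomeInput are hypothetical
# placeholders for a concrete BaseAgent subclass and its input schema; the
# coroutine test assumes pytest-asyncio (or an equivalent plugin) is configured.
import pytest
from deepeval.dataset import Golden
from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams

from ._utils import create_dataset, evaluate_dataset  # assumed relative import


@pytest.mark.asyncio
async def test_agent_against_goldens() -> None:
    async def run_agent(agent, case: LLMTestCase) -> None:
        # Run the agent on the golden input and record its structured answer.
        output = await agent.run_with_schema(SomeInput(request=case.input))
        case.actual_output = output.model_dump_json()

    dataset = await create_dataset(
        name="smoke",
        agent_factory=lambda: SomeAgent(),
        agent_run=run_agent,
        goldens=[
            Golden(
                input="Backport fix X to branch Y",
                expected_output="A merge request with the backported fix is opened",
            )
        ],
    )

    evaluate_dataset(
        dataset,
        metrics=[
            GEval(
                name="Correctness",
                criteria="The actual output satisfies the request given in the input.",
                evaluation_params=[LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT],
            )
        ],
    )
```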

59 changes: 59 additions & 0 deletions beeai/agents/tests/model.py
@@ -0,0 +1,59 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, TypeVar

from deepeval.key_handler import KEY_FILE_HANDLER, KeyValues
from deepeval.models import DeepEvalBaseLLM
from dotenv import load_dotenv
from pydantic import BaseModel

from beeai_framework.backend import ChatModel, ChatModelParameters
from beeai_framework.backend.constants import ProviderName
from beeai_framework.backend.message import UserMessage
from beeai_framework.middleware.trajectory import GlobalTrajectoryMiddleware
from beeai_framework.utils import ModelLike

TSchema = TypeVar("TSchema", bound=BaseModel)


load_dotenv()


class DeepEvalLLM(DeepEvalBaseLLM):
def __init__(self, model: ChatModel, *args: Any, **kwargs: Any) -> None:
self._model = model
super().__init__(model.model_id, *args, **kwargs)

def load_model(self, *args: Any, **kwargs: Any) -> None:
return None

def generate(self, prompt: str, schema: BaseModel | None = None) -> str:
raise NotImplementedError()

async def a_generate(self, prompt: str, schema: TSchema | None = None) -> str:
input_msg = UserMessage(prompt)
response = await self._model.create(
messages=[input_msg],
response_format=schema.model_json_schema(mode="serialization") if schema is not None else None,
stream=False,
temperature=0,
).middleware(
GlobalTrajectoryMiddleware(
pretty=True, exclude_none=True, enabled=os.environ.get("EVAL_LOG_LLM_CALLS", "").lower() == "true"
)
)
text = response.get_text_content()
return schema.model_validate_json(text) if schema else text # type: ignore
Review comment (severity: high):

The a_generate method in the base class DeepEvalBaseLLM is typed to return a str. This implementation, however, returns a Pydantic model object when a schema is provided, which violates the Liskov substitution principle. The # type: ignore comment is hiding this type mismatch.

To conform to the base class's method signature and ensure type safety, you should serialize the Pydantic model into a JSON string before returning it.

Suggested change:
- return schema.model_validate_json(text) if schema else text # type: ignore
+ return schema.model_validate_json(text).model_dump_json() if schema else text


def get_model_name(self) -> str:
return f"{self._model.model_id} ({self._model.provider_id})"

@staticmethod
def from_name(
name: str | ProviderName | None = None, options: ModelLike[ChatModelParameters] | None = None, **kwargs: Any
) -> "DeepEvalLLM":
name = name or KEY_FILE_HANDLER.fetch_data(KeyValues.LOCAL_MODEL_NAME)
model = ChatModel.from_name(name, options, **kwargs)
return DeepEvalLLM(model)
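DeepEvalLLM exists so the deepeval judge can run on the same beeai-framework ChatModel backend the agents already use. A sketch of plugging it into a metric; the model identifier string is illustrative, and the model argument on deepeval metrics such as GEval accepts a DeepEvalBaseLLM instance:

```python
# Sketch only; the model name below is an illustrative ChatModel identifier,
# not something this PR pins down.
from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCaseParams

from .model import DeepEvalLLM  # assumed relative import within the tests package

# Wrap a beeai-framework ChatModel so deepeval can use it as the judge.
judge = DeepEvalLLM.from_name("ollama:granite3.3:8b")

correctness = GEval(
    name="Correctness",
    criteria="The actual output matches the expected output in substance.",
    evaluation_params=[
        LLMTestCaseParams.ACTUAL_OUTPUT,
        LLMTestCaseParams.EXPECTED_OUTPUT,
    ],
    model=judge,
)
```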