Skip to content
5 changes: 2 additions & 3 deletions tests/integration/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.agents.turn_create_params import Document
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig

from llama_stack.apis.agents.agents import (
Expand Down Expand Up @@ -242,7 +241,7 @@ def test_code_interpreter_for_attachments(llama_stack_client_with_mocked_inferen

codex_agent = Agent(llama_stack_client_with_mocked_inference, **agent_config)
session_id = codex_agent.create_session(f"test-session-{uuid4()}")
inflation_doc = AgentDocument(
inflation_doc = Document(
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
mime_type="text/csv",
)
Expand Down
17 changes: 15 additions & 2 deletions tests/integration/datasetio/test_datasetio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,25 @@
import os
from pathlib import Path

import pytest

# How to run this test:
#
# LLAMA_STACK_CONFIG="template-name" pytest -v tests/integration/datasetio


@pytest.fixture
def dataset_for_test(llama_stack_client):
dataset_id = "test_dataset"
register_dataset(llama_stack_client, dataset_id=dataset_id)
yield
# Teardown - this always runs, even if the test fails
try:
llama_stack_client.datasets.unregister(dataset_id)
except Exception as e:
print(f"Warning: Failed to unregister test_dataset: {e}")


def data_url_from_file(file_path: str) -> str:
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
Expand Down Expand Up @@ -80,8 +94,7 @@ def test_register_unregister_dataset(llama_stack_client):
assert len(response) == 0


def test_get_rows_paginated(llama_stack_client):
register_dataset(llama_stack_client)
def test_get_rows_paginated(llama_stack_client, dataset_for_test):
response = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
Expand Down
26 changes: 18 additions & 8 deletions tests/integration/scoring/test_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@
from ..datasetio.test_datasetio import register_dataset


@pytest.fixture
def rag_dataset_for_test(llama_stack_client):
dataset_id = "test_dataset"
register_dataset(llama_stack_client, for_rag=True, dataset_id=dataset_id)
yield # This is where the test function will run

# Teardown - this always runs, even if the test fails
try:
llama_stack_client.datasets.unregister(dataset_id)
except Exception as e:
print(f"Warning: Failed to unregister test_dataset: {e}")


@pytest.fixture
def sample_judge_prompt_template():
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
Expand Down Expand Up @@ -79,9 +92,7 @@ def test_scoring_functions_register(
# TODO: add unregister api for scoring functions


def test_scoring_score(llama_stack_client):
register_dataset(llama_stack_client, for_rag=True)

def test_scoring_score(llama_stack_client, rag_dataset_for_test):
# scoring individual rows
rows = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
Expand Down Expand Up @@ -115,9 +126,9 @@ def test_scoring_score(llama_stack_client):
assert len(response.results[x].score_rows) == 5


def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge_prompt_template, judge_model_id):
register_dataset(llama_stack_client, for_rag=True)

def test_scoring_score_with_params_llm_as_judge(
llama_stack_client, sample_judge_prompt_template, judge_model_id, rag_dataset_for_test
):
# scoring individual rows
rows = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
Expand Down Expand Up @@ -167,9 +178,8 @@ def test_scoring_score_with_params_llm_as_judge(llama_stack_client, sample_judge
],
)
def test_scoring_score_with_aggregation_functions(
llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id
llama_stack_client, sample_judge_prompt_template, judge_model_id, provider_id, rag_dataset_for_test
):
register_dataset(llama_stack_client, for_rag=True)
rows = llama_stack_client.datasetio.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
Expand Down