From 004650f4f54b9fd65f134cf890d661c50761611b Mon Sep 17 00:00:00 2001 From: skamenan7 Date: Wed, 25 Mar 2026 14:11:34 -0400 Subject: [PATCH 1/4] refactor(tests): split test_openai_mixin.py by concern into 3 files Split 1438-line test_openai_mixin.py into: - openai_mixin_helpers.py: shared impl classes - conftest.py: shared fixtures (auto-discovered by pytest) - test_openai_mixin_models.py: model lifecycle tests (580L) - test_openai_mixin_provider_data.py: provider data and custom list tests (200L) - test_openai_mixin_inference.py: inference enforcement and stream options tests (520L) Removes test_openai_mixin.py from GRANDFATHERED_FILES in check_file_size.py. Signed-off-by: skamenan7 --- scripts/check_file_size.py | 1 - .../providers/utils/inference/conftest.py | 16 + .../utils/inference/openai_mixin_helpers.py | 180 +++ .../utils/inference/test_openai_mixin.py | 1438 ----------------- .../inference/test_openai_mixin_inference.py | 524 ++++++ .../inference/test_openai_mixin_models.py | 583 +++++++ .../test_openai_mixin_provider_data.py | 199 +++ 7 files changed, 1502 insertions(+), 1439 deletions(-) create mode 100644 tests/unit/providers/utils/inference/conftest.py create mode 100644 tests/unit/providers/utils/inference/openai_mixin_helpers.py delete mode 100644 tests/unit/providers/utils/inference/test_openai_mixin.py create mode 100644 tests/unit/providers/utils/inference/test_openai_mixin_inference.py create mode 100644 tests/unit/providers/utils/inference/test_openai_mixin_models.py create mode 100644 tests/unit/providers/utils/inference/test_openai_mixin_provider_data.py diff --git a/scripts/check_file_size.py b/scripts/check_file_size.py index e86b433f98..ec708ccc6e 100755 --- a/scripts/check_file_size.py +++ b/scripts/check_file_size.py @@ -38,7 +38,6 @@ "tests/integration/responses/test_openai_responses.py", "tests/integration/responses/test_tool_responses.py", "tests/unit/providers/responses/builtin/test_openai_responses.py", - 
"tests/unit/providers/utils/inference/test_openai_mixin.py", } diff --git a/tests/unit/providers/utils/inference/conftest.py b/tests/unit/providers/utils/inference/conftest.py new file mode 100644 index 0000000000..89bf62adfe --- /dev/null +++ b/tests/unit/providers/utils/inference/conftest.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from tests.unit.providers.utils.inference.openai_mixin_helpers import ( + mixin, # noqa: F401 + mixin_with_custom_model_construction, # noqa: F401 + mixin_with_embeddings, # noqa: F401 + mock_client_context, # noqa: F401 + mock_client_with_empty_models, # noqa: F401 + mock_client_with_exception, # noqa: F401 + mock_client_with_models, # noqa: F401 + mock_models, # noqa: F401 +) diff --git a/tests/unit/providers/utils/inference/openai_mixin_helpers.py b/tests/unit/providers/utils/inference/openai_mixin_helpers.py new file mode 100644 index 0000000000..9ab3e8285c --- /dev/null +++ b/tests/unit/providers/utils/inference/openai_mixin_helpers.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from collections.abc import Iterable +from typing import Any +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + +import pytest +from pydantic import BaseModel, Field, SecretStr + +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin +from llama_stack_api import Model, ModelType + + +class OpenAIMixinImpl(OpenAIMixin): + __provider_id__: str = "test-provider" + + def get_api_key(self) -> str: + return "test-api-key" + + def get_base_url(self) -> str: + return "http://test-base-url" + + +class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): + """Test implementation with embedding model metadata""" + + embedding_model_metadata: dict[str, dict[str, int]] = { + "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, + "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, + } + + +class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl): + """Test implementation that uses construct_model_from_identifier to add rerank models""" + + embedding_model_metadata: dict[str, dict[str, int]] = { + "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, + "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, + } + + # Adds rerank models via construct_model_from_identifier + rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"} + + def construct_model_from_identifier(self, identifier: str) -> Model: + if identifier in self.rerank_model_ids: + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.rerank, + ) + return super().construct_model_from_identifier(identifier) + + +class ProviderDataValidator(BaseModel): + """Validator for provider data in tests""" + + test_api_key: SecretStr | None = Field(default=None) + + 
+class OpenAIMixinWithProviderData(OpenAIMixinImpl): + """Test implementation that supports provider data API key field""" + + provider_data_api_key_field: str = "test_api_key" + + def get_api_key(self) -> str: + return "default-api-key" + + def get_base_url(self): + return "default-base-url" + + +class CustomListProviderModelIdsImplementation(OpenAIMixinImpl): + """Test implementation with custom list_provider_model_ids override""" + + custom_model_ids: Any + + async def list_provider_model_ids(self) -> Iterable[str]: + """Return custom model IDs list""" + return self.custom_model_ids + + +@pytest.fixture +def mixin(): + """Create a test instance of OpenAIMixin with mocked model_store""" + config = RemoteInferenceProviderConfig() + mixin_instance = OpenAIMixinImpl(config=config) + + # Mock model_store with async methods + mock_model_store = AsyncMock() + mock_model = MagicMock() + mock_model.provider_resource_id = "test-provider-resource-id" + mock_model_store.get_model = AsyncMock(return_value=mock_model) + mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override + mixin_instance.model_store = mock_model_store + + return mixin_instance + + +@pytest.fixture +def mixin_with_embeddings(): + """Create a test instance of OpenAIMixin with embedding model metadata""" + config = RemoteInferenceProviderConfig() + return OpenAIMixinWithEmbeddingsImpl(config=config) + + +@pytest.fixture +def mixin_with_custom_model_construction(): + """Create a test instance using custom construct_model_from_identifier""" + config = RemoteInferenceProviderConfig() + return OpenAIMixinWithCustomModelConstruction(config=config) + + +@pytest.fixture +def mock_models(): + """Create multiple mock OpenAI model objects""" + models = [MagicMock(id=id) for id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]] + return models + + +@pytest.fixture +def mock_client_with_models(mock_models): + """Create a mock client with models.list() set 
up to return mock_models""" + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + return mock_client + + +@pytest.fixture +def mock_client_with_empty_models(): + """Create a mock client with models.list() set up to return empty list""" + mock_client = MagicMock() + + async def mock_empty_models_list(): + return + yield # Make it an async generator but don't yield anything + + mock_client.models.list.return_value = mock_empty_models_list() + return mock_client + + +@pytest.fixture +def mock_client_with_exception(): + """Create a mock client with models.list() set up to raise an exception""" + mock_client = MagicMock() + mock_client.models.list.side_effect = Exception("API Error") + return mock_client + + +@pytest.fixture +def mock_client_context(): + """Fixture that provides a context manager for mocking the OpenAI client""" + + def _mock_client_context(mixin, mock_client): + return patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client) + + return _mock_client_context + + +def _assert_models_match_expected(actual_models, expected_models): + """Verify the models match expected attributes. + + Args: + actual_models: List of models to verify + expected_models: Mapping of model identifier to expected attribute values + """ + for identifier, expected_attrs in expected_models.items(): + model = next(m for m in actual_models if m.identifier == identifier) + for attr_name, expected_value in expected_attrs.items(): + assert getattr(model, attr_name) == expected_value diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py deleted file mode 100644 index 4d34b1de15..0000000000 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ /dev/null @@ -1,1438 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -from collections.abc import Iterable -from typing import Any -from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch - -import pytest -from pydantic import BaseModel, Field, SecretStr - -from llama_stack.core.request_headers import request_provider_data_context -from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig -from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin -from llama_stack_api import ( - Model, - ModelType, - OpenAIChatCompletionRequestWithExtraBody, - OpenAICompletionRequestWithExtraBody, - OpenAIEmbeddingsRequestWithExtraBody, - OpenAIUserMessageParam, -) - - -class OpenAIMixinImpl(OpenAIMixin): - __provider_id__: str = "test-provider" - - def get_api_key(self) -> str: - return "test-api-key" - - def get_base_url(self) -> str: - return "http://test-base-url" - - -class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl): - """Test implementation with embedding model metadata""" - - embedding_model_metadata: dict[str, dict[str, int]] = { - "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, - "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, - } - - -class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl): - """Test implementation that uses construct_model_from_identifier to add rerank models""" - - embedding_model_metadata: dict[str, dict[str, int]] = { - "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192}, - "text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192}, - } - - # Adds rerank models via construct_model_from_identifier - rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"} - - def construct_model_from_identifier(self, identifier: str) -> Model: - if identifier in self.rerank_model_ids: - return 
Model( - provider_id=self.__provider_id__, # type: ignore[attr-defined] - provider_resource_id=identifier, - identifier=identifier, - model_type=ModelType.rerank, - ) - return super().construct_model_from_identifier(identifier) - - -@pytest.fixture -def mixin(): - """Create a test instance of OpenAIMixin with mocked model_store""" - config = RemoteInferenceProviderConfig() - mixin_instance = OpenAIMixinImpl(config=config) - - # Mock model_store with async methods - mock_model_store = AsyncMock() - mock_model = MagicMock() - mock_model.provider_resource_id = "test-provider-resource-id" - mock_model_store.get_model = AsyncMock(return_value=mock_model) - mock_model_store.has_model = AsyncMock(return_value=False) # Default to False, tests can override - mixin_instance.model_store = mock_model_store - - return mixin_instance - - -@pytest.fixture -def mixin_with_embeddings(): - """Create a test instance of OpenAIMixin with embedding model metadata""" - config = RemoteInferenceProviderConfig() - return OpenAIMixinWithEmbeddingsImpl(config=config) - - -@pytest.fixture -def mixin_with_custom_model_construction(): - """Create a test instance using custom construct_model_from_identifier""" - config = RemoteInferenceProviderConfig() - return OpenAIMixinWithCustomModelConstruction(config=config) - - -@pytest.fixture -def mock_models(): - """Create multiple mock OpenAI model objects""" - models = [MagicMock(id=id) for id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]] - return models - - -@pytest.fixture -def mock_client_with_models(mock_models): - """Create a mock client with models.list() set up to return mock_models""" - mock_client = MagicMock() - - async def mock_models_list(): - for model in mock_models: - yield model - - mock_client.models.list.return_value = mock_models_list() - return mock_client - - -@pytest.fixture -def mock_client_with_empty_models(): - """Create a mock client with models.list() set up to return empty list""" - mock_client 
= MagicMock() - - async def mock_empty_models_list(): - return - yield # Make it an async generator but don't yield anything - - mock_client.models.list.return_value = mock_empty_models_list() - return mock_client - - -@pytest.fixture -def mock_client_with_exception(): - """Create a mock client with models.list() set up to raise an exception""" - mock_client = MagicMock() - mock_client.models.list.side_effect = Exception("API Error") - return mock_client - - -@pytest.fixture -def mock_client_context(): - """Fixture that provides a context manager for mocking the OpenAI client""" - - def _mock_client_context(mixin, mock_client): - return patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client) - - return _mock_client_context - - -def _assert_models_match_expected(actual_models, expected_models): - """Verify the models match expected attributes. - - Args: - actual_models: List of models to verify - expected_models: Mapping of model identifier to expected attribute values - """ - for identifier, expected_attrs in expected_models.items(): - model = next(m for m in actual_models if m.identifier == identifier) - for attr_name, expected_value in expected_attrs.items(): - assert getattr(model, attr_name) == expected_value - - -class TestOpenAIMixinListModels: - """Test cases for the list_models method""" - - async def test_list_models_success(self, mixin, mock_client_with_models, mock_client_context): - """Test successful model listing""" - assert len(mixin._model_cache) == 0 - - with mock_client_context(mixin, mock_client_with_models): - result = await mixin.list_models() - - assert result is not None - assert len(result) == 3 - - model_ids = [model.identifier for model in result] - assert "some-mock-model-id" in model_ids - assert "another-mock-model-id" in model_ids - assert "final-mock-model-id" in model_ids - - for model in result: - assert model.provider_id == "test-provider" - assert model.model_type == ModelType.llm - assert 
model.provider_resource_id == model.identifier - - assert len(mixin._model_cache) == 3 - for model_id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]: - assert model_id in mixin._model_cache - cached_model = mixin._model_cache[model_id] - assert cached_model.identifier == model_id - assert cached_model.provider_resource_id == model_id - - async def test_list_models_empty_response(self, mixin, mock_client_with_empty_models, mock_client_context): - """Test handling of empty model list""" - with mock_client_context(mixin, mock_client_with_empty_models): - result = await mixin.list_models() - - assert result is not None - assert len(result) == 0 - assert len(mixin._model_cache) == 0 - - -class TestOpenAIMixinCheckModelAvailability: - """Test cases for the check_model_availability method""" - - async def test_check_model_availability_with_cache(self, mixin, mock_client_with_models, mock_client_context): - """Test model availability check when cache is populated""" - with mock_client_context(mixin, mock_client_with_models): - mock_client_with_models.models.list.assert_not_called() - await mixin.list_models() - mock_client_with_models.models.list.assert_called_once() - - assert await mixin.check_model_availability("some-mock-model-id") - assert await mixin.check_model_availability("another-mock-model-id") - assert await mixin.check_model_availability("final-mock-model-id") - assert not await mixin.check_model_availability("non-existent-model") - mock_client_with_models.models.list.assert_called_once() - - async def test_check_model_availability_without_cache(self, mixin, mock_client_with_models, mock_client_context): - """Test model availability check when cache is empty (calls list_models)""" - assert len(mixin._model_cache) == 0 - - with mock_client_context(mixin, mock_client_with_models): - mock_client_with_models.models.list.assert_not_called() - assert await mixin.check_model_availability("some-mock-model-id") - 
mock_client_with_models.models.list.assert_called_once() - - assert len(mixin._model_cache) == 3 - assert "some-mock-model-id" in mixin._model_cache - - async def test_check_model_availability_model_not_found(self, mixin, mock_client_with_models, mock_client_context): - """Test model availability check for non-existent model""" - with mock_client_context(mixin, mock_client_with_models): - mock_client_with_models.models.list.assert_not_called() - assert not await mixin.check_model_availability("non-existent-model") - mock_client_with_models.models.list.assert_called_once() - - assert len(mixin._model_cache) == 3 - - async def test_check_model_availability_with_pre_registered_model( - self, mixin, mock_client_with_models, mock_client_context - ): - """Test that check_model_availability returns True for pre-registered models in model_store""" - # Mock model_store.has_model to return True for a specific model - mock_model_store = AsyncMock() - mock_model_store.has_model = AsyncMock(return_value=True) - mixin.model_store = mock_model_store - - # Test that pre-registered model is found without calling the provider's API - with mock_client_context(mixin, mock_client_with_models): - mock_client_with_models.models.list.assert_not_called() - assert await mixin.check_model_availability("pre-registered-model") - # Should not call the provider's list_models since model was found in store - mock_client_with_models.models.list.assert_not_called() - mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model") - - async def test_check_model_availability_fallback_to_provider_when_not_in_store( - self, mixin, mock_client_with_models, mock_client_context - ): - """Test that check_model_availability falls back to provider when model not in store""" - # Mock model_store.has_model to return False - mock_model_store = AsyncMock() - mock_model_store.has_model = AsyncMock(return_value=False) - mixin.model_store = mock_model_store - - # Test that it falls back to 
provider's model cache - with mock_client_context(mixin, mock_client_with_models): - mock_client_with_models.models.list.assert_not_called() - assert await mixin.check_model_availability("some-mock-model-id") - # Should call the provider's list_models since model was not found in store - mock_client_with_models.models.list.assert_called_once() - mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id") - - -class TestOpenAIMixinCacheBehavior: - """Test cases for cache behavior and edge cases""" - - async def test_cache_overwrites_on_list_models_call(self, mixin, mock_client_with_models, mock_client_context): - """Test that calling list_models overwrites existing cache""" - initial_model = Model( - provider_id="test-provider", - provider_resource_id="old-model", - identifier="old-model", - model_type=ModelType.llm, - ) - mixin._model_cache = {"old-model": initial_model} - - with mock_client_context(mixin, mock_client_with_models): - await mixin.list_models() - - assert len(mixin._model_cache) == 3 - assert "old-model" not in mixin._model_cache - assert "some-mock-model-id" in mixin._model_cache - assert "another-mock-model-id" in mixin._model_cache - assert "final-mock-model-id" in mixin._model_cache - - -class TestOpenAIMixinImagePreprocessing: - """Test cases for image preprocessing functionality""" - - async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin): - """Test that image URLs are converted to base64 when download_images is True""" - mixin.download_images = True - - message = OpenAIUserMessageParam( - role="user", - content=[ - {"type": "text", "text": "What's in this image?"}, - {"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}}, - ], - ) - - mock_client = MagicMock() - mock_response = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_response) - - with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client): - 
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: - mock_localize.return_value = (b"fake_image_data", "jpeg") - - params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message]) - await mixin.openai_chat_completion(params) - - mock_localize.assert_called_once_with("http://example.com/image.jpg") - - mock_client.chat.completions.create.assert_called_once() - call_args = mock_client.chat.completions.create.call_args - processed_messages = call_args[1]["messages"] - assert len(processed_messages) == 1 - content = processed_messages[0]["content"] - assert len(content) == 2 - assert content[0]["type"] == "text" - assert content[1]["type"] == "image_url" - assert content[1]["image_url"]["url"] == "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh" - - async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin): - """Test that image URLs are not modified when download_images is False""" - mixin.download_images = False # explicitly set to False - - message = OpenAIUserMessageParam( - role="user", - content=[ - {"type": "text", "text": "What's in this image?"}, - {"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}}, - ], - ) - - mock_client = MagicMock() - mock_response = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=mock_response) - - with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client): - with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: - params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message]) - await mixin.openai_chat_completion(params) - - mock_localize.assert_not_called() - - mock_client.chat.completions.create.assert_called_once() - call_args = mock_client.chat.completions.create.call_args - processed_messages = call_args[1]["messages"] - assert len(processed_messages) == 1 - content = 
processed_messages[0]["content"] - assert len(content) == 2 - assert content[1]["image_url"]["url"] == "http://example.com/image.jpg" - - -class TestOpenAIMixinEmbeddingModelMetadata: - """Test cases for embedding_model_metadata attribute functionality""" - - async def test_embedding_model_identified_and_augmented(self, mixin_with_embeddings, mock_client_context): - """Test that models in embedding_model_metadata are correctly identified as embeddings with metadata""" - # Create mock models: 1 embedding model and 1 LLM, while there are 2 known embedding models - mock_embedding_model = MagicMock(id="text-embedding-3-small") - mock_llm_model = MagicMock(id="gpt-4") - mock_models = [mock_embedding_model, mock_llm_model] - - mock_client = MagicMock() - - async def mock_models_list(): - for model in mock_models: - yield model - - mock_client.models.list.return_value = mock_models_list() - - with mock_client_context(mixin_with_embeddings, mock_client): - result = await mixin_with_embeddings.list_models() - - assert result is not None - assert len(result) == 2 - - expected_models = { - "text-embedding-3-small": { - "model_type": ModelType.embedding, - "metadata": {"embedding_dimension": 1536, "context_length": 8192}, - "provider_id": "test-provider", - "provider_resource_id": "text-embedding-3-small", - }, - "gpt-4": { - "model_type": ModelType.llm, - "metadata": {}, - "provider_id": "test-provider", - "provider_resource_id": "gpt-4", - }, - } - - _assert_models_match_expected(result, expected_models) - - -class TestOpenAIMixinCustomModelConstruction: - """Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier""" - - async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context): - """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata""" - # Create mock models: 1 embedding, 1 rerank, 1 LLM - mock_embedding_model = 
MagicMock(id="text-embedding-3-small") - mock_rerank_model = MagicMock(id="rerank-model-1") - mock_llm_model = MagicMock(id="gpt-4") - mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model] - - mock_client = MagicMock() - - async def mock_models_list(): - for model in mock_models: - yield model - - mock_client.models.list.return_value = mock_models_list() - - with mock_client_context(mixin_with_custom_model_construction, mock_client): - result = await mixin_with_custom_model_construction.list_models() - - assert result is not None - assert len(result) == 3 - - expected_models = { - "text-embedding-3-small": { - "model_type": ModelType.embedding, - "metadata": {"embedding_dimension": 1536, "context_length": 8192}, - "provider_id": "test-provider", - "provider_resource_id": "text-embedding-3-small", - }, - "rerank-model-1": { - "model_type": ModelType.rerank, - "metadata": {}, - "provider_id": "test-provider", - "provider_resource_id": "rerank-model-1", - }, - "gpt-4": { - "model_type": ModelType.llm, - "metadata": {}, - "provider_id": "test-provider", - "provider_resource_id": "gpt-4", - }, - } - - _assert_models_match_expected(result, expected_models) - - -class TestOpenAIMixinAllowedModels: - """Test cases for allowed_models filtering functionality""" - - async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context): - """Test that list_models filters models based on allowed_models""" - mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"] - - with mock_client_context(mixin, mock_client_with_models): - result = await mixin.list_models() - - assert result is not None - assert len(result) == 2 - - model_ids = [model.identifier for model in result] - assert "some-mock-model-id" in model_ids - assert "another-mock-model-id" in model_ids - assert "final-mock-model-id" not in model_ids - - async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, 
mock_client_context): - """Test that empty allowed_models allows no models""" - mixin.config.allowed_models = [] - - with mock_client_context(mixin, mock_client_with_models): - result = await mixin.list_models() - - assert result is not None - assert len(result) == 0 # No models should be included - - async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context): - """Test that omitted allowed_models allows all models""" - assert mixin.config.allowed_models is None - - with mock_client_context(mixin, mock_client_with_models): - result = await mixin.list_models() - - assert result is not None - assert len(result) == 3 # All models should be included - - model_ids = [model.identifier for model in result] - assert "some-mock-model-id" in model_ids - assert "another-mock-model-id" in model_ids - assert "final-mock-model-id" in model_ids - - async def test_check_model_availability_with_allowed_models( - self, mixin, mock_client_with_models, mock_client_context - ): - """Test that check_model_availability respects allowed_models""" - mixin.config.allowed_models = ["final-mock-model-id"] - - with mock_client_context(mixin, mock_client_with_models): - assert await mixin.check_model_availability("final-mock-model-id") - assert not await mixin.check_model_availability("some-mock-model-id") - assert not await mixin.check_model_availability("another-mock-model-id") - - -class TestOpenAIMixinModelRegistration: - """Test cases for model registration functionality""" - - async def test_register_model_success(self, mock_client_with_models, mock_client_context): - """Test successful model registration when model is available""" - config = RemoteInferenceProviderConfig() - mixin = OpenAIMixinImpl(config=config) - - # Enable validation for this model - model = Model( - provider_id="test-provider", - provider_resource_id="some-mock-model-id", - identifier="test-model", - model_type=ModelType.llm, - model_validation=True, - ) - - with 
mock_client_context(mixin, mock_client_with_models): - result = await mixin.register_model(model) - - assert result == model - assert result.provider_id == "test-provider" - assert result.provider_resource_id == "some-mock-model-id" - assert result.identifier == "test-model" - assert result.model_type == ModelType.llm - mock_client_with_models.models.list.assert_called_once() - - async def test_register_model_not_available(self, mock_client_with_models, mock_client_context): - """Test model registration failure when model is not available from provider""" - config = RemoteInferenceProviderConfig() - mixin = OpenAIMixinImpl(config=config) - - # Enable validation for this model - model = Model( - provider_id="test-provider", - provider_resource_id="non-existent-model", - identifier="test-model", - model_type=ModelType.llm, - model_validation=True, - ) - - with mock_client_context(mixin, mock_client_with_models): - with pytest.raises( - ValueError, match="Model non-existent-model is not available from provider test-provider" - ): - await mixin.register_model(model) - mock_client_with_models.models.list.assert_called_once() - - async def test_register_model_with_allowed_models_filter(self, mock_client_with_models, mock_client_context): - """Test model registration with allowed_models filtering""" - config = RemoteInferenceProviderConfig(allowed_models=["some-mock-model-id"]) - mixin = OpenAIMixinImpl(config=config) - - # Test with allowed model (with validation enabled) - allowed_model = Model( - provider_id="test-provider", - provider_resource_id="some-mock-model-id", - identifier="allowed-model", - model_type=ModelType.llm, - model_validation=True, - ) - - # Test with disallowed model (with validation enabled) - disallowed_model = Model( - provider_id="test-provider", - provider_resource_id="final-mock-model-id", - identifier="disallowed-model", - model_type=ModelType.llm, - model_validation=True, - ) - - with mock_client_context(mixin, mock_client_with_models): - 
result = await mixin.register_model(allowed_model) - assert result == allowed_model - with pytest.raises( - ValueError, match="Model final-mock-model-id is not available from provider test-provider" - ): - await mixin.register_model(disallowed_model) - mock_client_with_models.models.list.assert_called_once() - - async def test_register_embedding_model(self, mixin_with_embeddings, mock_client_context): - """Test registration of embedding models with metadata""" - mock_embedding_model = MagicMock(id="text-embedding-3-small") - mock_models = [mock_embedding_model] - - mock_client = MagicMock() - - async def mock_models_list(): - for model in mock_models: - yield model - - mock_client.models.list.return_value = mock_models_list() - - embedding_model = Model( - provider_id="test-provider", - provider_resource_id="text-embedding-3-small", - identifier="embedding-test", - model_type=ModelType.embedding, - ) - - with mock_client_context(mixin_with_embeddings, mock_client): - result = await mixin_with_embeddings.register_model(embedding_model) - assert result == embedding_model - assert result.model_type == ModelType.embedding - - async def test_unregister_model(self, mixin): - """Test model unregistration (should be no-op)""" - # unregister_model should not raise any exceptions and return None - result = await mixin.unregister_model("any-model-id") - assert result is None - - async def test_should_refresh_models(self, mixin): - """Test should_refresh_models method returns config value""" - # Default config has refresh_models=False - result = await mixin.should_refresh_models() - assert result is False - - # With refresh_models=True, should return True - config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True) - mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh) - result_with_refresh = await mixin_with_refresh.should_refresh_models() - assert result_with_refresh is True - - async def test_register_model_error_propagation(self, 
mock_client_with_exception, mock_client_context): - """Test that errors from provider API are properly propagated during registration""" - config = RemoteInferenceProviderConfig() - mixin = OpenAIMixinImpl(config=config) - - # Enable validation for this model - model = Model( - provider_id="test-provider", - provider_resource_id="some-model", - identifier="test-model", - model_type=ModelType.llm, - model_validation=True, - ) - - with mock_client_context(mixin, mock_client_with_exception): - # The exception from the API should be propagated - with pytest.raises(Exception, match="API Error"): - await mixin.register_model(model) - - async def test_register_model_default_behavior_no_validation(self, mock_client_with_models, mock_client_context): - """Test model registration with default behavior (no validation)""" - # Default behavior - no validation - config = RemoteInferenceProviderConfig() - mixin = OpenAIMixinImpl(config=config) - - model = Model( - provider_id="test-provider", - provider_resource_id="non-existent-model", - identifier="test-model", - model_type=ModelType.llm, - ) - - with mock_client_context(mixin, mock_client_with_models): - # Should succeed without checking model availability (default behavior) - result = await mixin.register_model(model) - - assert result == model - # Verify that models.list() was NOT called - mock_client_with_models.models.list.assert_not_called() - - async def test_register_model_with_validation_enabled(self, mock_client_with_models, mock_client_context): - """Test that model-level model_validation=True enables validation""" - # Default config (no provider-level validation setting) - config = RemoteInferenceProviderConfig() - mixin = OpenAIMixinImpl(config=config) - - # Model explicitly enables validation - model = Model( - provider_id="test-provider", - provider_resource_id="non-existent-model", - identifier="test-model", - model_type=ModelType.llm, - model_validation=True, - ) - - with mock_client_context(mixin, 
mock_client_with_models): - # Should fail because model-level validation is enabled - with pytest.raises(ValueError, match="Model non-existent-model is not available"): - await mixin.register_model(model) - # Verify that models.list() WAS called (validation happened) - mock_client_with_models.models.list.assert_called_once() - - async def test_register_model_with_validation_explicitly_disabled( - self, mock_client_with_models, mock_client_context - ): - """Test that model-level model_validation=False explicitly disables validation""" - # Default config - config = RemoteInferenceProviderConfig() - mixin = OpenAIMixinImpl(config=config) - - # Model explicitly disables validation (though this is the default anyway) - model = Model( - provider_id="test-provider", - provider_resource_id="non-existent-model", - identifier="test-model", - model_type=ModelType.llm, - model_validation=False, - ) - - with mock_client_context(mixin, mock_client_with_models): - # Should succeed because validation is disabled - result = await mixin.register_model(model) - - assert result == model - # Verify that models.list() was NOT called - mock_client_with_models.models.list.assert_not_called() - - -class ProviderDataValidator(BaseModel): - """Validator for provider data in tests""" - - test_api_key: SecretStr | None = Field(default=None) - - -class OpenAIMixinWithProviderData(OpenAIMixinImpl): - """Test implementation that supports provider data API key field""" - - provider_data_api_key_field: str = "test_api_key" - - def get_api_key(self) -> str: - return "default-api-key" - - def get_base_url(self): - return "default-base-url" - - -class CustomListProviderModelIdsImplementation(OpenAIMixinImpl): - """Test implementation with custom list_provider_model_ids override""" - - custom_model_ids: Any - - async def list_provider_model_ids(self) -> Iterable[str]: - """Return custom model IDs list""" - return self.custom_model_ids - - -class TestOpenAIMixinCustomListProviderModelIds: - """Test 
cases for custom list_provider_model_ids() implementation functionality""" - - @pytest.fixture - def custom_model_ids_list(self): - """Create a list of custom model ID strings""" - return ["custom-model-1", "custom-model-2", "custom-embedding"] - - @pytest.fixture - def config(self): - """Create RemoteInferenceProviderConfig instance""" - return RemoteInferenceProviderConfig() - - @pytest.fixture - def adapter(self, custom_model_ids_list, config): - """Create mixin instance with custom list_provider_model_ids implementation""" - mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=custom_model_ids_list) - mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}} - return mixin - - async def test_is_used(self, adapter, custom_model_ids_list): - """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()""" - result = await adapter.list_models() - - assert result is not None - assert len(result) == 3 - - assert set(custom_model_ids_list) == {m.identifier for m in result} - - async def test_populates_cache(self, adapter, custom_model_ids_list): - """Test that custom list_provider_model_ids() results are cached""" - assert len(adapter._model_cache) == 0 - - await adapter.list_models() - - assert set(custom_model_ids_list) == set(adapter._model_cache.keys()) - - async def test_respects_allowed_models(self, config): - """Test that custom list_provider_model_ids() respects allowed_models filtering""" - mixin = CustomListProviderModelIdsImplementation( - config=config, custom_model_ids=["model-1", "model-2", "model-3"] - ) - mixin.config.allowed_models = ["model-1"] - - result = await mixin.list_models() - - assert result is not None - assert len(result) == 1 - assert result[0].identifier == "model-1" - - async def test_with_empty_list(self, config): - """Test that custom list_provider_model_ids() handles empty list correctly""" - mixin = 
CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[]) - - result = await mixin.list_models() - - assert result is not None - assert len(result) == 0 - assert len(mixin._model_cache) == 0 - - async def test_wrong_type_raises_error(self, config): - """Test that list_provider_model_ids() returning unhashable items results in an error""" - mixin = CustomListProviderModelIdsImplementation( - config=config, custom_model_ids=["valid-model", ["nested", "list"]] - ) - with pytest.raises(Exception, match="is not a string"): - await mixin.list_models() - - mixin = CustomListProviderModelIdsImplementation( - config=config, custom_model_ids=[{"key": "value"}, "valid-model"] - ) - with pytest.raises(Exception, match="is not a string"): - await mixin.list_models() - - mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=["valid-model", 42.0]) - with pytest.raises(Exception, match="is not a string"): - await mixin.list_models() - - mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[None]) - with pytest.raises(Exception, match="is not a string"): - await mixin.list_models() - - async def test_non_iterable_raises_error(self, config): - """Test that list_provider_model_ids() returning non-iterable type raises error""" - mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=42) - - with pytest.raises( - TypeError, - match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int", - ): - await mixin.list_models() - - async def test_accepts_various_iterables(self, config): - """Test that list_provider_model_ids() accepts tuples, sets, generators, etc.""" - - tuples = CustomListProviderModelIdsImplementation( - config=config, custom_model_ids=("model-1", "model-2", "model-3") - ) - result = await tuples.list_models() - assert result is not None - assert len(result) == 3 - - class 
GeneratorAdapter(OpenAIMixinImpl): - async def list_provider_model_ids(self) -> Iterable[str]: - def gen(): - yield "gen-model-1" - yield "gen-model-2" - - return gen() - - mixin = GeneratorAdapter(config=config) - result = await mixin.list_models() - assert result is not None - assert len(result) == 2 - - sets = CustomListProviderModelIdsImplementation(config=config, custom_model_ids={"set-model-1", "set-model-2"}) - result = await sets.list_models() - assert result is not None - assert len(result) == 2 - - -class TestOpenAIMixinProviderDataApiKey: - """Test cases for provider_data_api_key_field functionality""" - - @pytest.fixture - def mixin_with_provider_data_field(self): - """Mixin instance with provider_data_api_key_field set""" - config = RemoteInferenceProviderConfig() - mixin_instance = OpenAIMixinWithProviderData(config=config) - - # Mock provider_spec for provider data validation - mock_provider_spec = MagicMock() - mock_provider_spec.provider_type = "test-provider-with-data" - mock_provider_spec.provider_data_validator = ( - "tests.unit.providers.utils.inference.test_openai_mixin.ProviderDataValidator" - ) - mixin_instance.__provider_spec__ = mock_provider_spec - - return mixin_instance - - @pytest.fixture - def mixin_with_provider_data_field_and_none_api_key(self, mixin_with_provider_data_field): - mixin_with_provider_data_field.get_api_key = Mock(return_value=None) - return mixin_with_provider_data_field - - def test_no_provider_data(self, mixin_with_provider_data_field): - """Test that client uses config API key when no provider data is available""" - assert mixin_with_provider_data_field.client.api_key == "default-api-key" - - def test_with_provider_data(self, mixin_with_provider_data_field): - """Test that provider data API key overrides config API key""" - with request_provider_data_context( - {"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})} - ): - assert mixin_with_provider_data_field.client.api_key == 
"provider-data-key" - - def test_with_wrong_key(self, mixin_with_provider_data_field): - """Test fallback to config when provider data doesn't have the required key""" - with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): - assert mixin_with_provider_data_field.client.api_key == "default-api-key" - - def test_error_when_no_config_and_provider_data_has_wrong_key( - self, mixin_with_provider_data_field_and_none_api_key - ): - """Test that ValueError is raised when provider data exists but doesn't have required key""" - with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): - with pytest.raises(ValueError, match="API key not provided"): - _ = mixin_with_provider_data_field_and_none_api_key.client - - def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key): - """Test that error message includes correct field name and header information""" - with pytest.raises(ValueError) as exc_info: - _ = mixin_with_provider_data_field_and_none_api_key.client - - error_message = str(exc_info.value) - assert "test_api_key" in error_message - assert "x-llamastack-provider-data" in error_message - - -class TestOpenAIMixinAllowedModelsInference: - """Test cases for allowed_models enforcement during inference requests""" - - async def test_inference_with_allowed_models(self, mixin, mock_client_context): - """Test that all inference methods succeed with allowed models""" - mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"] - - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - mock_client.completions.create = AsyncMock(return_value=MagicMock()) - mock_embedding_response = MagicMock() - mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] - mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5) - 
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response) - - with mock_client_context(mixin, mock_client): - # Test chat completion - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")] - ) - ) - mock_client.chat.completions.create.assert_called_once() - - # Test completion - await mixin.openai_completion( - OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello") - ) - mock_client.completions.create.assert_called_once() - - # Test embeddings - await mixin.openai_embeddings( - OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text") - ) - mock_client.embeddings.create.assert_called_once() - - async def test_inference_with_disallowed_models(self, mixin, mock_client_context): - """Test that all inference methods fail with disallowed models""" - mixin.config.allowed_models = ["gpt-4"] - - mock_client = MagicMock() - - with mock_client_context(mixin, mock_client): - # Test chat completion with disallowed model - with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")] - ) - ) - - # Test completion with disallowed model - with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"): - await mixin.openai_completion( - OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello") - ) - - # Test embeddings with disallowed model - with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"): - await mixin.openai_embeddings( - OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text") - ) - - mock_client.chat.completions.create.assert_not_called() - 
mock_client.completions.create.assert_not_called() - mock_client.embeddings.create.assert_not_called() - - async def test_inference_with_no_restrictions(self, mixin, mock_client_context): - """Test that inference succeeds when allowed_models is None or empty list blocks all""" - # Test with None (no restrictions) - assert mixin.config.allowed_models is None - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - with mock_client_context(mixin, mock_client): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")] - ) - ) - mock_client.chat.completions.create.assert_called_once() - - # Test with empty list (blocks all models) - mixin.config.allowed_models = [] - with mock_client_context(mixin, mock_client): - with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")] - ) - ) - - -class TestOpenAIMixinStreamOptionsInjection: - """Test cases for automatic stream_options injection when telemetry is active""" - - async def test_chat_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context): - """Test that stream_options is injected for streaming chat completion when telemetry is active""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - # Mock OpenTelemetry span as recording - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], 
stream=True - ) - ) - - mock_client.chat.completions.create.assert_called_once() - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["stream_options"] == {"include_usage": True} - - async def test_chat_completion_preserves_existing_stream_options(self, mixin, mock_client_context): - """Test that existing stream_options are preserved with include_usage added""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - stream=True, - stream_options={"other_option": True}, - ) - ) - - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True} - - async def test_chat_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context): - """Test that stream_options is NOT injected when telemetry is inactive""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - # Mock OpenTelemetry span as not recording - mock_span = MagicMock() - mock_span.is_recording.return_value = False - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True - ) - ) - - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None - - async def 
test_chat_completion_no_injection_when_not_streaming(self, mixin, mock_client_context): - """Test that stream_options is NOT injected for non-streaming requests""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=False - ) - ) - - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None - - async def test_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context): - """Test that stream_options is injected for streaming completion when telemetry is active""" - mock_client = MagicMock() - mock_client.completions.create = AsyncMock(return_value=MagicMock()) - - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_completion( - OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True) - ) - - mock_client.completions.create.assert_called_once() - call_kwargs = mock_client.completions.create.call_args[1] - assert call_kwargs["stream_options"] == {"include_usage": True} - - async def test_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context): - """Test that stream_options is NOT injected for completion when telemetry is inactive""" - mock_client = MagicMock() - mock_client.completions.create = AsyncMock(return_value=MagicMock()) - - mock_span = MagicMock() - mock_span.is_recording.return_value = False - - with 
mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_completion( - OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True) - ) - - call_kwargs = mock_client.completions.create.call_args[1] - assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None - - async def test_params_not_mutated(self, mixin, mock_client_context): - """Test that original params object is not mutated when stream_options is injected""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - original_params = OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True - ) - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion(original_params) - - # Original params should not be modified - assert original_params.stream_options is None - - async def test_chat_completion_overrides_include_usage_false(self, mixin, mock_client_context): - """Test that include_usage=False is overridden when telemetry is active""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - stream=True, - stream_options={"include_usage": False}, - ) - ) - - call_kwargs = mock_client.chat.completions.create.call_args[1] - # Telemetry must override False to ensure 
complete metrics - assert call_kwargs["stream_options"]["include_usage"] is True - - async def test_no_injection_when_provider_doesnt_support_stream_options(self, mixin, mock_client_context): - """Test that stream_options is NOT injected when provider doesn't support it""" - # Set supports_stream_options to False (like Ollama/vLLM) - mixin.supports_stream_options = False - - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - # Mock OpenTelemetry span as recording (telemetry is active) - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True - ) - ) - - call_kwargs = mock_client.chat.completions.create.call_args[1] - # Should NOT inject stream_options even though telemetry is active - assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None - - async def test_completion_no_injection_when_provider_doesnt_support_stream_options( - self, mixin, mock_client_context - ): - """Test that stream_options is NOT injected for completion when provider doesn't support it""" - # Set supports_stream_options to False (like Ollama/vLLM) - mixin.supports_stream_options = False - - mock_client = MagicMock() - mock_client.completions.create = AsyncMock(return_value=MagicMock()) - - # Mock OpenTelemetry span as recording (telemetry is active) - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_completion( - OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True) - ) - - call_kwargs = 
mock_client.completions.create.call_args[1] - # Should NOT inject stream_options even though telemetry is active - assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None - - -class TestOpenAIMixinSafetyIdentifierPassing: - """Test cases for safety_identifier parameter passing to OpenAI API""" - - async def test_chat_completion_passes_safety_identifier(self, mixin, mock_client_context): - """Test that safety_identifier is passed to OpenAI chat completions API""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - with mock_client_context(mixin, mock_client): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - safety_identifier="user-123-hashed", - ) - ) - - mock_client.chat.completions.create.assert_called_once() - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["safety_identifier"] == "user-123-hashed" - - async def test_chat_completion_with_top_p(self, mixin, mock_client_context): - """Test that top_p is properly passed to the OpenAI client""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - top_p_value = 0.9 - - with mock_client_context(mixin, mock_client): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - top_p=top_p_value, - ) - ) - - mock_client.chat.completions.create.assert_called_once() - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["top_p"] == top_p_value - - -class TestOpenAIMixinPromptCacheKey: - """Test cases for prompt_cache_key parameter propagation""" - - async def test_chat_completion_with_prompt_cache_key(self, mixin, mock_client_context): - """Test that prompt_cache_key is properly passed to the OpenAI 
client""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - cache_key = "test-cache-key-123" - - with mock_client_context(mixin, mock_client): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - prompt_cache_key=cache_key, - ) - ) - - mock_client.chat.completions.create.assert_called_once() - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["prompt_cache_key"] == cache_key - - -class TestOpenAIMixinServiceTier: - """Test cases for service_tier parameter in OpenAIMixin""" - - async def test_chat_completion_passes_service_tier_to_openai(self, mixin, mock_client_context): - """Test that service_tier parameter is passed to OpenAI client for chat completion""" - from llama_stack_api.inference import ServiceTier - - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - with mock_client_context(mixin, mock_client): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - service_tier=ServiceTier.priority, - ) - ) - - mock_client.chat.completions.create.assert_called_once() - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["service_tier"] == ServiceTier.priority - - -class TestOpenAIMixinTopLogprobs: - """Test cases for top_logprobs parameter in chat completion requests""" - - async def test_chat_completion_with_top_logprobs_value_5(self, mixin, mock_client_context): - """Test that top_logprobs=5 is properly passed to the OpenAI client""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - with mock_client_context(mixin, mock_client): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - 
model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - top_logprobs=5, - ) - ) - - mock_client.chat.completions.create.assert_called_once() - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["top_logprobs"] == 5 - - async def test_chat_completion_with_top_logprobs_boundary_min(self, mixin, mock_client_context): - """Test that top_logprobs=0 (minimum) is properly passed to the OpenAI client""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - with mock_client_context(mixin, mock_client): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - top_logprobs=0, - ) - ) - - mock_client.chat.completions.create.assert_called_once() - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["top_logprobs"] == 0 - - async def test_chat_completion_with_top_logprobs_boundary_max(self, mixin, mock_client_context): - """Test that top_logprobs=20 (maximum) is properly passed to the OpenAI client""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - with mock_client_context(mixin, mock_client): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - top_logprobs=20, - ) - ) - - mock_client.chat.completions.create.assert_called_once() - call_kwargs = mock_client.chat.completions.create.call_args[1] - assert call_kwargs["top_logprobs"] == 20 - - -class TestOpenAIMixinUserProvidedStreamOptions: - """Test cases for user-provided stream_options parameter handling""" - - async def test_user_stream_options_passed_through_when_telemetry_inactive(self, mixin, mock_client_context): - """Test that user-provided stream_options are passed through unchanged when telemetry 
is inactive""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - mock_span = MagicMock() - mock_span.is_recording.return_value = False - - # OpenAI stream_options supports include_usage (bool) and include_obfuscation (bool) - # Using dict[str, Any] allows for future extensions and provider-specific options - user_stream_options = {"include_obfuscation": True, "custom_field": 123} - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - stream=True, - stream_options=user_stream_options, - ) - ) - - call_kwargs = mock_client.chat.completions.create.call_args[1] - # User's stream_options should be passed through unchanged - assert call_kwargs["stream_options"] == user_stream_options - - async def test_user_stream_options_include_usage_false_overridden_by_telemetry(self, mixin, mock_client_context): - """Test that include_usage=False is overridden to True when telemetry is active""" - mock_client = MagicMock() - mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) - - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - with mock_client_context(mixin, mock_client): - with patch("opentelemetry.trace.get_current_span", return_value=mock_span): - await mixin.openai_chat_completion( - OpenAIChatCompletionRequestWithExtraBody( - model="gpt-4", - messages=[OpenAIUserMessageParam(role="user", content="Hello")], - stream=True, - stream_options={"include_usage": False, "other_option": True}, - ) - ) - - call_kwargs = mock_client.chat.completions.create.call_args[1] - # Telemetry must override include_usage to True - assert call_kwargs["stream_options"]["include_usage"] is True - # Other options should be preserved - assert 
call_kwargs["stream_options"]["other_option"] is True diff --git a/tests/unit/providers/utils/inference/test_openai_mixin_inference.py b/tests/unit/providers/utils/inference/test_openai_mixin_inference.py new file mode 100644 index 0000000000..5e91310931 --- /dev/null +++ b/tests/unit/providers/utils/inference/test_openai_mixin_inference.py @@ -0,0 +1,524 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from llama_stack_api import ( + OpenAIChatCompletionRequestWithExtraBody, + OpenAICompletionRequestWithExtraBody, + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIUserMessageParam, +) + + +class TestOpenAIMixinAllowedModelsInference: + """Test cases for allowed_models enforcement during inference requests""" + + async def test_inference_with_allowed_models(self, mixin, mock_client_context): + """Test that all inference methods succeed with allowed models""" + mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"] + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + mock_client.completions.create = AsyncMock(return_value=MagicMock()) + mock_embedding_response = MagicMock() + mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])] + mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5) + mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response) + + with mock_client_context(mixin, mock_client): + # Test chat completion + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + ) + mock_client.chat.completions.create.assert_called_once() + + # Test completion + await mixin.openai_completion( 
+ OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello") + ) + mock_client.completions.create.assert_called_once() + + # Test embeddings + await mixin.openai_embeddings( + OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text") + ) + mock_client.embeddings.create.assert_called_once() + + async def test_inference_with_disallowed_models(self, mixin, mock_client_context): + """Test that all inference methods fail with disallowed models""" + mixin.config.allowed_models = ["gpt-4"] + + mock_client = MagicMock() + + with mock_client_context(mixin, mock_client): + # Test chat completion with disallowed model + with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + ) + + # Test completion with disallowed model + with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"): + await mixin.openai_completion( + OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello") + ) + + # Test embeddings with disallowed model + with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"): + await mixin.openai_embeddings( + OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text") + ) + + mock_client.chat.completions.create.assert_not_called() + mock_client.completions.create.assert_not_called() + mock_client.embeddings.create.assert_not_called() + + async def test_inference_with_no_restrictions(self, mixin, mock_client_context): + """Test that inference succeeds when allowed_models is None, and that an empty list blocks all models""" + # Test with None (no restrictions) + assert mixin.config.allowed_models is None + mock_client = MagicMock() + mock_client.chat.completions.create = 
AsyncMock(return_value=MagicMock()) + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + ) + mock_client.chat.completions.create.assert_called_once() + + # Test with empty list (blocks all models) + mixin.config.allowed_models = [] + with mock_client_context(mixin, mock_client): + with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")] + ) + ) + + +class TestOpenAIMixinStreamOptionsInjection: + """Test cases for automatic stream_options injection when telemetry is active""" + + async def test_chat_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context): + """Test that stream_options is injected for streaming chat completion when telemetry is active""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + # Mock OpenTelemetry span as recording + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True + ) + ) + + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["stream_options"] == {"include_usage": True} + + async def test_chat_completion_preserves_existing_stream_options(self, mixin, mock_client_context): + """Test that existing stream_options are preserved with include_usage added""" + mock_client = MagicMock() + 
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + stream=True, + stream_options={"other_option": True}, + ) + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["stream_options"] == {"other_option": True, "include_usage": True} + + async def test_chat_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context): + """Test that stream_options is NOT injected when telemetry is inactive""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + # Mock OpenTelemetry span as not recording + mock_span = MagicMock() + mock_span.is_recording.return_value = False + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True + ) + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None + + async def test_chat_completion_no_injection_when_not_streaming(self, mixin, mock_client_context): + """Test that stream_options is NOT injected for non-streaming requests""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", 
return_value=mock_span): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=False + ) + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None + + async def test_completion_injects_stream_options_when_telemetry_active(self, mixin, mock_client_context): + """Test that stream_options is injected for streaming completion when telemetry is active""" + mock_client = MagicMock() + mock_client.completions.create = AsyncMock(return_value=MagicMock()) + + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_completion( + OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True) + ) + + mock_client.completions.create.assert_called_once() + call_kwargs = mock_client.completions.create.call_args[1] + assert call_kwargs["stream_options"] == {"include_usage": True} + + async def test_completion_no_injection_when_telemetry_inactive(self, mixin, mock_client_context): + """Test that stream_options is NOT injected for completion when telemetry is inactive""" + mock_client = MagicMock() + mock_client.completions.create = AsyncMock(return_value=MagicMock()) + + mock_span = MagicMock() + mock_span.is_recording.return_value = False + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_completion( + OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True) + ) + + call_kwargs = mock_client.completions.create.call_args[1] + assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None + + async def test_params_not_mutated(self, 
mixin, mock_client_context): + """Test that original params object is not mutated when stream_options is injected""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + original_params = OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True + ) + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_chat_completion(original_params) + + # Original params should not be modified + assert original_params.stream_options is None + + async def test_chat_completion_overrides_include_usage_false(self, mixin, mock_client_context): + """Test that include_usage=False is overridden when telemetry is active""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + stream=True, + stream_options={"include_usage": False}, + ) + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + # Telemetry must override False to ensure complete metrics + assert call_kwargs["stream_options"]["include_usage"] is True + + async def test_no_injection_when_provider_doesnt_support_stream_options(self, mixin, mock_client_context): + """Test that stream_options is NOT injected when provider doesn't support it""" + # Set supports_stream_options to False (like Ollama/vLLM) + mixin.supports_stream_options = False + + mock_client = MagicMock() + mock_client.chat.completions.create = 
AsyncMock(return_value=MagicMock()) + + # Mock OpenTelemetry span as recording (telemetry is active) + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")], stream=True + ) + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + # Should NOT inject stream_options even though telemetry is active + assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None + + async def test_completion_no_injection_when_provider_doesnt_support_stream_options( + self, mixin, mock_client_context + ): + """Test that stream_options is NOT injected for completion when provider doesn't support it""" + # Set supports_stream_options to False (like Ollama/vLLM) + mixin.supports_stream_options = False + + mock_client = MagicMock() + mock_client.completions.create = AsyncMock(return_value=MagicMock()) + + # Mock OpenTelemetry span as recording (telemetry is active) + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_completion( + OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello", stream=True) + ) + + call_kwargs = mock_client.completions.create.call_args[1] + # Should NOT inject stream_options even though telemetry is active + assert "stream_options" not in call_kwargs or call_kwargs["stream_options"] is None + + +class TestOpenAIMixinSafetyIdentifierPassing: + """Test cases for safety_identifier parameter passing to OpenAI API""" + + async def test_chat_completion_passes_safety_identifier(self, mixin, mock_client_context): + """Test that safety_identifier is 
passed to OpenAI chat completions API""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + safety_identifier="user-123-hashed", + ) + ) + + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["safety_identifier"] == "user-123-hashed" + + async def test_chat_completion_with_top_p(self, mixin, mock_client_context): + """Test that top_p is properly passed to the OpenAI client""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + top_p_value = 0.9 + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + top_p=top_p_value, + ) + ) + + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["top_p"] == top_p_value + + +class TestOpenAIMixinPromptCacheKey: + """Test cases for prompt_cache_key parameter propagation""" + + async def test_chat_completion_with_prompt_cache_key(self, mixin, mock_client_context): + """Test that prompt_cache_key is properly passed to the OpenAI client""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + cache_key = "test-cache-key-123" + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + prompt_cache_key=cache_key, + ) + ) + + 
mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["prompt_cache_key"] == cache_key + + +class TestOpenAIMixinServiceTier: + """Test cases for service_tier parameter in OpenAIMixin""" + + async def test_chat_completion_passes_service_tier_to_openai(self, mixin, mock_client_context): + """Test that service_tier parameter is passed to OpenAI client for chat completion""" + from llama_stack_api.inference import ServiceTier + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + service_tier=ServiceTier.priority, + ) + ) + + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["service_tier"] == ServiceTier.priority + + +class TestOpenAIMixinTopLogprobs: + """Test cases for top_logprobs parameter in chat completion requests""" + + async def test_chat_completion_with_top_logprobs_value_5(self, mixin, mock_client_context): + """Test that top_logprobs=5 is properly passed to the OpenAI client""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + top_logprobs=5, + ) + ) + + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["top_logprobs"] == 5 + + async def test_chat_completion_with_top_logprobs_boundary_min(self, mixin, mock_client_context): + """Test that top_logprobs=0 
(minimum) is properly passed to the OpenAI client""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + top_logprobs=0, + ) + ) + + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["top_logprobs"] == 0 + + async def test_chat_completion_with_top_logprobs_boundary_max(self, mixin, mock_client_context): + """Test that top_logprobs=20 (maximum) is properly passed to the OpenAI client""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + with mock_client_context(mixin, mock_client): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + top_logprobs=20, + ) + ) + + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args[1] + assert call_kwargs["top_logprobs"] == 20 + + +class TestOpenAIMixinUserProvidedStreamOptions: + """Test cases for user-provided stream_options parameter handling""" + + async def test_user_stream_options_passed_through_when_telemetry_inactive(self, mixin, mock_client_context): + """Test that user-provided stream_options are passed through unchanged when telemetry is inactive""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + mock_span = MagicMock() + mock_span.is_recording.return_value = False + + # OpenAI stream_options supports include_usage (bool) and include_obfuscation (bool) + # Using dict[str, Any] allows for future extensions and provider-specific options + user_stream_options = 
{"include_obfuscation": True, "custom_field": 123} + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + stream=True, + stream_options=user_stream_options, + ) + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + # User's stream_options should be passed through unchanged + assert call_kwargs["stream_options"] == user_stream_options + + async def test_user_stream_options_include_usage_false_overridden_by_telemetry(self, mixin, mock_client_context): + """Test that include_usage=False is overridden to True when telemetry is active""" + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=MagicMock()) + + mock_span = MagicMock() + mock_span.is_recording.return_value = True + + with mock_client_context(mixin, mock_client): + with patch("opentelemetry.trace.get_current_span", return_value=mock_span): + await mixin.openai_chat_completion( + OpenAIChatCompletionRequestWithExtraBody( + model="gpt-4", + messages=[OpenAIUserMessageParam(role="user", content="Hello")], + stream=True, + stream_options={"include_usage": False, "other_option": True}, + ) + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + # Telemetry must override include_usage to True + assert call_kwargs["stream_options"]["include_usage"] is True + # Other options should be preserved + assert call_kwargs["stream_options"]["other_option"] is True diff --git a/tests/unit/providers/utils/inference/test_openai_mixin_models.py b/tests/unit/providers/utils/inference/test_openai_mixin_models.py new file mode 100644 index 0000000000..50dc0d0484 --- /dev/null +++ b/tests/unit/providers/utils/inference/test_openai_mixin_models.py @@ -0,0 +1,583 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + +import pytest + +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig +from llama_stack_api import ( + Model, + ModelType, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIUserMessageParam, +) +from tests.unit.providers.utils.inference.openai_mixin_helpers import ( + OpenAIMixinImpl, + _assert_models_match_expected, +) + + +class TestOpenAIMixinListModels: + """Test cases for the list_models method""" + + async def test_list_models_success(self, mixin, mock_client_with_models, mock_client_context): + """Test successful model listing""" + assert len(mixin._model_cache) == 0 + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 3 + + model_ids = [model.identifier for model in result] + assert "some-mock-model-id" in model_ids + assert "another-mock-model-id" in model_ids + assert "final-mock-model-id" in model_ids + + for model in result: + assert model.provider_id == "test-provider" + assert model.model_type == ModelType.llm + assert model.provider_resource_id == model.identifier + + assert len(mixin._model_cache) == 3 + for model_id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]: + assert model_id in mixin._model_cache + cached_model = mixin._model_cache[model_id] + assert cached_model.identifier == model_id + assert cached_model.provider_resource_id == model_id + + async def test_list_models_empty_response(self, mixin, mock_client_with_empty_models, mock_client_context): + """Test handling of empty model list""" + with mock_client_context(mixin, mock_client_with_empty_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 0 + 
assert len(mixin._model_cache) == 0 + + +class TestOpenAIMixinCheckModelAvailability: + """Test cases for the check_model_availability method""" + + async def test_check_model_availability_with_cache(self, mixin, mock_client_with_models, mock_client_context): + """Test model availability check when cache is populated""" + with mock_client_context(mixin, mock_client_with_models): + mock_client_with_models.models.list.assert_not_called() + await mixin.list_models() + mock_client_with_models.models.list.assert_called_once() + + assert await mixin.check_model_availability("some-mock-model-id") + assert await mixin.check_model_availability("another-mock-model-id") + assert await mixin.check_model_availability("final-mock-model-id") + assert not await mixin.check_model_availability("non-existent-model") + mock_client_with_models.models.list.assert_called_once() + + async def test_check_model_availability_without_cache(self, mixin, mock_client_with_models, mock_client_context): + """Test model availability check when cache is empty (calls list_models)""" + assert len(mixin._model_cache) == 0 + + with mock_client_context(mixin, mock_client_with_models): + mock_client_with_models.models.list.assert_not_called() + assert await mixin.check_model_availability("some-mock-model-id") + mock_client_with_models.models.list.assert_called_once() + + assert len(mixin._model_cache) == 3 + assert "some-mock-model-id" in mixin._model_cache + + async def test_check_model_availability_model_not_found(self, mixin, mock_client_with_models, mock_client_context): + """Test model availability check for non-existent model""" + with mock_client_context(mixin, mock_client_with_models): + mock_client_with_models.models.list.assert_not_called() + assert not await mixin.check_model_availability("non-existent-model") + mock_client_with_models.models.list.assert_called_once() + + assert len(mixin._model_cache) == 3 + + async def test_check_model_availability_with_pre_registered_model( + self, mixin, 
mock_client_with_models, mock_client_context + ): + """Test that check_model_availability returns True for pre-registered models in model_store""" + # Mock model_store.has_model to return True for a specific model + mock_model_store = AsyncMock() + mock_model_store.has_model = AsyncMock(return_value=True) + mixin.model_store = mock_model_store + + # Test that pre-registered model is found without calling the provider's API + with mock_client_context(mixin, mock_client_with_models): + mock_client_with_models.models.list.assert_not_called() + assert await mixin.check_model_availability("pre-registered-model") + # Should not call the provider's list_models since model was found in store + mock_client_with_models.models.list.assert_not_called() + mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model") + + async def test_check_model_availability_fallback_to_provider_when_not_in_store( + self, mixin, mock_client_with_models, mock_client_context + ): + """Test that check_model_availability falls back to provider when model not in store""" + # Mock model_store.has_model to return False + mock_model_store = AsyncMock() + mock_model_store.has_model = AsyncMock(return_value=False) + mixin.model_store = mock_model_store + + # Test that it falls back to provider's model cache + with mock_client_context(mixin, mock_client_with_models): + mock_client_with_models.models.list.assert_not_called() + assert await mixin.check_model_availability("some-mock-model-id") + # Should call the provider's list_models since model was not found in store + mock_client_with_models.models.list.assert_called_once() + mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id") + + +class TestOpenAIMixinCacheBehavior: + """Test cases for cache behavior and edge cases""" + + async def test_cache_overwrites_on_list_models_call(self, mixin, mock_client_with_models, mock_client_context): + """Test that calling list_models overwrites existing 
cache""" + initial_model = Model( + provider_id="test-provider", + provider_resource_id="old-model", + identifier="old-model", + model_type=ModelType.llm, + ) + mixin._model_cache = {"old-model": initial_model} + + with mock_client_context(mixin, mock_client_with_models): + await mixin.list_models() + + assert len(mixin._model_cache) == 3 + assert "old-model" not in mixin._model_cache + assert "some-mock-model-id" in mixin._model_cache + assert "another-mock-model-id" in mixin._model_cache + assert "final-mock-model-id" in mixin._model_cache + + +class TestOpenAIMixinImagePreprocessing: + """Test cases for image preprocessing functionality""" + + async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin): + """Test that image URLs are converted to base64 when download_images is True""" + mixin.download_images = True + + message = OpenAIUserMessageParam( + role="user", + content=[ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}}, + ], + ) + + mock_client = MagicMock() + mock_response = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=mock_response) + + with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client): + with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: + mock_localize.return_value = (b"fake_image_data", "jpeg") + + params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message]) + await mixin.openai_chat_completion(params) + + mock_localize.assert_called_once_with("http://example.com/image.jpg") + + mock_client.chat.completions.create.assert_called_once() + call_args = mock_client.chat.completions.create.call_args + processed_messages = call_args[1]["messages"] + assert len(processed_messages) == 1 + content = processed_messages[0]["content"] + assert len(content) == 2 + assert content[0]["type"] == "text" + 
assert content[1]["type"] == "image_url" + assert content[1]["image_url"]["url"] == "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh" + + async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin): + """Test that image URLs are not modified when download_images is False""" + mixin.download_images = False # explicitly set to False + + message = OpenAIUserMessageParam( + role="user", + content=[ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}}, + ], + ) + + mock_client = MagicMock() + mock_response = MagicMock() + mock_client.chat.completions.create = AsyncMock(return_value=mock_response) + + with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client): + with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: + params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message]) + await mixin.openai_chat_completion(params) + + mock_localize.assert_not_called() + + mock_client.chat.completions.create.assert_called_once() + call_args = mock_client.chat.completions.create.call_args + processed_messages = call_args[1]["messages"] + assert len(processed_messages) == 1 + content = processed_messages[0]["content"] + assert len(content) == 2 + assert content[1]["image_url"]["url"] == "http://example.com/image.jpg" + + +class TestOpenAIMixinEmbeddingModelMetadata: + """Test cases for embedding_model_metadata attribute functionality""" + + async def test_embedding_model_identified_and_augmented(self, mixin_with_embeddings, mock_client_context): + """Test that models in embedding_model_metadata are correctly identified as embeddings with metadata""" + # Create mock models: 1 embedding model and 1 LLM, while there are 2 known embedding models + mock_embedding_model = MagicMock(id="text-embedding-3-small") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = 
[mock_embedding_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_embeddings, mock_client): + result = await mixin_with_embeddings.list_models() + + assert result is not None + assert len(result) == 2 + + expected_models = { + "text-embedding-3-small": { + "model_type": ModelType.embedding, + "metadata": {"embedding_dimension": 1536, "context_length": 8192}, + "provider_id": "test-provider", + "provider_resource_id": "text-embedding-3-small", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } + + _assert_models_match_expected(result, expected_models) + + +class TestOpenAIMixinCustomModelConstruction: + """Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier""" + + async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context): + """Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata""" + # Create mock models: 1 embedding, 1 rerank, 1 LLM + mock_embedding_model = MagicMock(id="text-embedding-3-small") + mock_rerank_model = MagicMock(id="rerank-model-1") + mock_llm_model = MagicMock(id="gpt-4") + mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + mock_client.models.list.return_value = mock_models_list() + + with mock_client_context(mixin_with_custom_model_construction, mock_client): + result = await mixin_with_custom_model_construction.list_models() + + assert result is not None + assert len(result) == 3 + + expected_models = { + "text-embedding-3-small": { + "model_type": ModelType.embedding, + "metadata": 
{"embedding_dimension": 1536, "context_length": 8192}, + "provider_id": "test-provider", + "provider_resource_id": "text-embedding-3-small", + }, + "rerank-model-1": { + "model_type": ModelType.rerank, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "rerank-model-1", + }, + "gpt-4": { + "model_type": ModelType.llm, + "metadata": {}, + "provider_id": "test-provider", + "provider_resource_id": "gpt-4", + }, + } + + _assert_models_match_expected(result, expected_models) + + +class TestOpenAIMixinAllowedModels: + """Test cases for allowed_models filtering functionality""" + + async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context): + """Test that list_models filters models based on allowed_models""" + mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"] + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 2 + + model_ids = [model.identifier for model in result] + assert "some-mock-model-id" in model_ids + assert "another-mock-model-id" in model_ids + assert "final-mock-model-id" not in model_ids + + async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context): + """Test that empty allowed_models allows no models""" + mixin.config.allowed_models = [] + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 0 # No models should be included + + async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context): + """Test that omitted allowed_models allows all models""" + assert mixin.config.allowed_models is None + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.list_models() + + assert result is not None + assert len(result) == 3 # All 
models should be included + + model_ids = [model.identifier for model in result] + assert "some-mock-model-id" in model_ids + assert "another-mock-model-id" in model_ids + assert "final-mock-model-id" in model_ids + + async def test_check_model_availability_with_allowed_models( + self, mixin, mock_client_with_models, mock_client_context + ): + """Test that check_model_availability respects allowed_models""" + mixin.config.allowed_models = ["final-mock-model-id"] + + with mock_client_context(mixin, mock_client_with_models): + assert await mixin.check_model_availability("final-mock-model-id") + assert not await mixin.check_model_availability("some-mock-model-id") + assert not await mixin.check_model_availability("another-mock-model-id") + + +class TestOpenAIMixinModelRegistration: + """Test cases for model registration functionality""" + + async def test_register_model_success(self, mock_client_with_models, mock_client_context): + """Test successful model registration when model is available""" + config = RemoteInferenceProviderConfig() + mixin = OpenAIMixinImpl(config=config) + + # Enable validation for this model + model = Model( + provider_id="test-provider", + provider_resource_id="some-mock-model-id", + identifier="test-model", + model_type=ModelType.llm, + model_validation=True, + ) + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.register_model(model) + + assert result == model + assert result.provider_id == "test-provider" + assert result.provider_resource_id == "some-mock-model-id" + assert result.identifier == "test-model" + assert result.model_type == ModelType.llm + mock_client_with_models.models.list.assert_called_once() + + async def test_register_model_not_available(self, mock_client_with_models, mock_client_context): + """Test model registration failure when model is not available from provider""" + config = RemoteInferenceProviderConfig() + mixin = OpenAIMixinImpl(config=config) + + # Enable validation for this 
model + model = Model( + provider_id="test-provider", + provider_resource_id="non-existent-model", + identifier="test-model", + model_type=ModelType.llm, + model_validation=True, + ) + + with mock_client_context(mixin, mock_client_with_models): + with pytest.raises( + ValueError, match="Model non-existent-model is not available from provider test-provider" + ): + await mixin.register_model(model) + mock_client_with_models.models.list.assert_called_once() + + async def test_register_model_with_allowed_models_filter(self, mock_client_with_models, mock_client_context): + """Test model registration with allowed_models filtering""" + config = RemoteInferenceProviderConfig(allowed_models=["some-mock-model-id"]) + mixin = OpenAIMixinImpl(config=config) + + # Test with allowed model (with validation enabled) + allowed_model = Model( + provider_id="test-provider", + provider_resource_id="some-mock-model-id", + identifier="allowed-model", + model_type=ModelType.llm, + model_validation=True, + ) + + # Test with disallowed model (with validation enabled) + disallowed_model = Model( + provider_id="test-provider", + provider_resource_id="final-mock-model-id", + identifier="disallowed-model", + model_type=ModelType.llm, + model_validation=True, + ) + + with mock_client_context(mixin, mock_client_with_models): + result = await mixin.register_model(allowed_model) + assert result == allowed_model + with pytest.raises( + ValueError, match="Model final-mock-model-id is not available from provider test-provider" + ): + await mixin.register_model(disallowed_model) + mock_client_with_models.models.list.assert_called_once() + + async def test_register_embedding_model(self, mixin_with_embeddings, mock_client_context): + """Test registration of embedding models with metadata""" + mock_embedding_model = MagicMock(id="text-embedding-3-small") + mock_models = [mock_embedding_model] + + mock_client = MagicMock() + + async def mock_models_list(): + for model in mock_models: + yield model + + 
mock_client.models.list.return_value = mock_models_list() + + embedding_model = Model( + provider_id="test-provider", + provider_resource_id="text-embedding-3-small", + identifier="embedding-test", + model_type=ModelType.embedding, + ) + + with mock_client_context(mixin_with_embeddings, mock_client): + result = await mixin_with_embeddings.register_model(embedding_model) + assert result == embedding_model + assert result.model_type == ModelType.embedding + + async def test_unregister_model(self, mixin): + """Test model unregistration (should be no-op)""" + # unregister_model should not raise any exceptions and return None + result = await mixin.unregister_model("any-model-id") + assert result is None + + async def test_should_refresh_models(self, mixin): + """Test should_refresh_models method returns config value""" + # Default config has refresh_models=False + result = await mixin.should_refresh_models() + assert result is False + + # With refresh_models=True, should return True + config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True) + mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh) + result_with_refresh = await mixin_with_refresh.should_refresh_models() + assert result_with_refresh is True + + async def test_register_model_error_propagation(self, mock_client_with_exception, mock_client_context): + """Test that errors from provider API are properly propagated during registration""" + config = RemoteInferenceProviderConfig() + mixin = OpenAIMixinImpl(config=config) + + # Enable validation for this model + model = Model( + provider_id="test-provider", + provider_resource_id="some-model", + identifier="test-model", + model_type=ModelType.llm, + model_validation=True, + ) + + with mock_client_context(mixin, mock_client_with_exception): + # The exception from the API should be propagated + with pytest.raises(Exception, match="API Error"): + await mixin.register_model(model) + + async def 
test_register_model_default_behavior_no_validation(self, mock_client_with_models, mock_client_context): + """Test model registration with default behavior (no validation)""" + # Default behavior - no validation + config = RemoteInferenceProviderConfig() + mixin = OpenAIMixinImpl(config=config) + + model = Model( + provider_id="test-provider", + provider_resource_id="non-existent-model", + identifier="test-model", + model_type=ModelType.llm, + ) + + with mock_client_context(mixin, mock_client_with_models): + # Should succeed without checking model availability (default behavior) + result = await mixin.register_model(model) + + assert result == model + # Verify that models.list() was NOT called + mock_client_with_models.models.list.assert_not_called() + + async def test_register_model_with_validation_enabled(self, mock_client_with_models, mock_client_context): + """Test that model-level model_validation=True enables validation""" + # Default config (no provider-level validation setting) + config = RemoteInferenceProviderConfig() + mixin = OpenAIMixinImpl(config=config) + + # Model explicitly enables validation + model = Model( + provider_id="test-provider", + provider_resource_id="non-existent-model", + identifier="test-model", + model_type=ModelType.llm, + model_validation=True, + ) + + with mock_client_context(mixin, mock_client_with_models): + # Should fail because model-level validation is enabled + with pytest.raises(ValueError, match="Model non-existent-model is not available"): + await mixin.register_model(model) + # Verify that models.list() WAS called (validation happened) + mock_client_with_models.models.list.assert_called_once() + + async def test_register_model_with_validation_explicitly_disabled( + self, mock_client_with_models, mock_client_context + ): + """Test that model-level model_validation=False explicitly disables validation""" + # Default config + config = RemoteInferenceProviderConfig() + mixin = OpenAIMixinImpl(config=config) + + # Model 
explicitly disables validation (though this is the default anyway) + model = Model( + provider_id="test-provider", + provider_resource_id="non-existent-model", + identifier="test-model", + model_type=ModelType.llm, + model_validation=False, + ) + + with mock_client_context(mixin, mock_client_with_models): + # Should succeed because validation is disabled + result = await mixin.register_model(model) + + assert result == model + # Verify that models.list() was NOT called + mock_client_with_models.models.list.assert_not_called() diff --git a/tests/unit/providers/utils/inference/test_openai_mixin_provider_data.py b/tests/unit/providers/utils/inference/test_openai_mixin_provider_data.py new file mode 100644 index 0000000000..7f25f7f98d --- /dev/null +++ b/tests/unit/providers/utils/inference/test_openai_mixin_provider_data.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import json +from collections.abc import Iterable +from unittest.mock import MagicMock, Mock + +import pytest + +from llama_stack.core.request_headers import request_provider_data_context +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig +from tests.unit.providers.utils.inference.openai_mixin_helpers import ( + CustomListProviderModelIdsImplementation, + OpenAIMixinImpl, + OpenAIMixinWithProviderData, + ProviderDataValidator, # noqa: F401 — referenced by provider_data_validator string path +) + + +class TestOpenAIMixinCustomListProviderModelIds: + """Test cases for custom list_provider_model_ids() implementation functionality""" + + @pytest.fixture + def custom_model_ids_list(self): + """Create a list of custom model ID strings""" + return ["custom-model-1", "custom-model-2", "custom-embedding"] + + @pytest.fixture + def config(self): + """Create RemoteInferenceProviderConfig instance""" + return RemoteInferenceProviderConfig() + + @pytest.fixture + def adapter(self, custom_model_ids_list, config): + """Create mixin instance with custom list_provider_model_ids implementation""" + mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=custom_model_ids_list) + mixin.embedding_model_metadata = {"custom-embedding": {"embedding_dimension": 768, "context_length": 512}} + return mixin + + async def test_is_used(self, adapter, custom_model_ids_list): + """Test that custom list_provider_model_ids() implementation is used instead of client.models.list()""" + result = await adapter.list_models() + + assert result is not None + assert len(result) == 3 + + assert set(custom_model_ids_list) == {m.identifier for m in result} + + async def test_populates_cache(self, adapter, custom_model_ids_list): + """Test that custom list_provider_model_ids() results are cached""" + assert len(adapter._model_cache) == 0 + + await adapter.list_models() + + assert set(custom_model_ids_list) == 
set(adapter._model_cache.keys()) + + async def test_respects_allowed_models(self, config): + """Test that custom list_provider_model_ids() respects allowed_models filtering""" + mixin = CustomListProviderModelIdsImplementation( + config=config, custom_model_ids=["model-1", "model-2", "model-3"] + ) + mixin.config.allowed_models = ["model-1"] + + result = await mixin.list_models() + + assert result is not None + assert len(result) == 1 + assert result[0].identifier == "model-1" + + async def test_with_empty_list(self, config): + """Test that custom list_provider_model_ids() handles empty list correctly""" + mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[]) + + result = await mixin.list_models() + + assert result is not None + assert len(result) == 0 + assert len(mixin._model_cache) == 0 + + async def test_wrong_type_raises_error(self, config): + """Test that list_provider_model_ids() returning unhashable items results in an error""" + mixin = CustomListProviderModelIdsImplementation( + config=config, custom_model_ids=["valid-model", ["nested", "list"]] + ) + with pytest.raises(Exception, match="is not a string"): + await mixin.list_models() + + mixin = CustomListProviderModelIdsImplementation( + config=config, custom_model_ids=[{"key": "value"}, "valid-model"] + ) + with pytest.raises(Exception, match="is not a string"): + await mixin.list_models() + + mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=["valid-model", 42.0]) + with pytest.raises(Exception, match="is not a string"): + await mixin.list_models() + + mixin = CustomListProviderModelIdsImplementation(config=config, custom_model_ids=[None]) + with pytest.raises(Exception, match="is not a string"): + await mixin.list_models() + + async def test_non_iterable_raises_error(self, config): + """Test that list_provider_model_ids() returning non-iterable type raises error""" + mixin = CustomListProviderModelIdsImplementation(config=config, 
custom_model_ids=42) + + with pytest.raises( + TypeError, + match=r"Failed to list models: CustomListProviderModelIdsImplementation\.list_provider_model_ids\(\) must return an iterable.*but returned int", + ): + await mixin.list_models() + + async def test_accepts_various_iterables(self, config): + """Test that list_provider_model_ids() accepts tuples, sets, generators, etc.""" + + tuples = CustomListProviderModelIdsImplementation( + config=config, custom_model_ids=("model-1", "model-2", "model-3") + ) + result = await tuples.list_models() + assert result is not None + assert len(result) == 3 + + class GeneratorAdapter(OpenAIMixinImpl): + async def list_provider_model_ids(self) -> Iterable[str]: + def gen(): + yield "gen-model-1" + yield "gen-model-2" + + return gen() + + mixin = GeneratorAdapter(config=config) + result = await mixin.list_models() + assert result is not None + assert len(result) == 2 + + sets = CustomListProviderModelIdsImplementation(config=config, custom_model_ids={"set-model-1", "set-model-2"}) + result = await sets.list_models() + assert result is not None + assert len(result) == 2 + + +class TestOpenAIMixinProviderDataApiKey: + """Test cases for provider_data_api_key_field functionality""" + + @pytest.fixture + def mixin_with_provider_data_field(self): + """Mixin instance with provider_data_api_key_field set""" + config = RemoteInferenceProviderConfig() + mixin_instance = OpenAIMixinWithProviderData(config=config) + + # Mock provider_spec for provider data validation + mock_provider_spec = MagicMock() + mock_provider_spec.provider_type = "test-provider-with-data" + mock_provider_spec.provider_data_validator = ( + "tests.unit.providers.utils.inference.openai_mixin_helpers.ProviderDataValidator" + ) + mixin_instance.__provider_spec__ = mock_provider_spec + + return mixin_instance + + @pytest.fixture + def mixin_with_provider_data_field_and_none_api_key(self, mixin_with_provider_data_field): + mixin_with_provider_data_field.get_api_key = 
Mock(return_value=None) + return mixin_with_provider_data_field + + def test_no_provider_data(self, mixin_with_provider_data_field): + """Test that client uses config API key when no provider data is available""" + assert mixin_with_provider_data_field.client.api_key == "default-api-key" + + def test_with_provider_data(self, mixin_with_provider_data_field): + """Test that provider data API key overrides config API key""" + with request_provider_data_context( + {"x-llamastack-provider-data": json.dumps({"test_api_key": "provider-data-key"})} + ): + assert mixin_with_provider_data_field.client.api_key == "provider-data-key" + + def test_with_wrong_key(self, mixin_with_provider_data_field): + """Test fallback to config when provider data doesn't have the required key""" + with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): + assert mixin_with_provider_data_field.client.api_key == "default-api-key" + + def test_error_when_no_config_and_provider_data_has_wrong_key( + self, mixin_with_provider_data_field_and_none_api_key + ): + """Test that ValueError is raised when provider data exists but doesn't have required key""" + with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}): + with pytest.raises(ValueError, match="API key not provided"): + _ = mixin_with_provider_data_field_and_none_api_key.client + + def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key): + """Test that error message includes correct field name and header information""" + with pytest.raises(ValueError) as exc_info: + _ = mixin_with_provider_data_field_and_none_api_key.client + + error_message = str(exc_info.value) + assert "test_api_key" in error_message + assert "x-llamastack-provider-data" in error_message From 879f7bf7a81a64c90995c116b8db7535f2233f90 Mon Sep 17 00:00:00 2001 From: skamenan7 Date: Wed, 25 Mar 2026 14:17:28 -0400 Subject: 
[PATCH 2/4] refactor(tests): split test_openai_responses.py unit tests by feature group Split 3226-line test_openai_responses.py into: - conftest.py: shared fixtures (116L, auto-discovered) - test_openai_responses_helpers.py: fake_stream generator (47L) - test_openai_responses_core.py: core creation, instructions, store logic (857L) - test_openai_responses_agent.py: failed streams, agent loop incomplete (500L) - test_openai_responses_tools.py: tool calls, MCP, file search (670L) - test_openai_responses_prompts.py: prompt template variants (579L) - test_openai_responses_params.py: param passthrough, service_tier, stream_options (856L) Removes test_openai_responses.py from GRANDFATHERED_FILES in check_file_size.py. 103 tests pass across all new files. Signed-off-by: skamenan7 --- scripts/check_file_size.py | 1 - .../providers/responses/builtin/conftest.py | 107 + .../builtin/test_openai_responses.py | 3226 ----------------- .../builtin/test_openai_responses_agent.py | 439 +++ .../builtin/test_openai_responses_core.py | 810 +++++ .../builtin/test_openai_responses_helpers.py | 46 + .../builtin/test_openai_responses_params.py | 811 +++++ .../builtin/test_openai_responses_prompts.py | 518 +++ .../builtin/test_openai_responses_tools.py | 621 ++++ 9 files changed, 3352 insertions(+), 3227 deletions(-) create mode 100644 tests/unit/providers/responses/builtin/conftest.py delete mode 100644 tests/unit/providers/responses/builtin/test_openai_responses.py create mode 100644 tests/unit/providers/responses/builtin/test_openai_responses_agent.py create mode 100644 tests/unit/providers/responses/builtin/test_openai_responses_core.py create mode 100644 tests/unit/providers/responses/builtin/test_openai_responses_helpers.py create mode 100644 tests/unit/providers/responses/builtin/test_openai_responses_params.py create mode 100644 tests/unit/providers/responses/builtin/test_openai_responses_prompts.py create mode 100644 
tests/unit/providers/responses/builtin/test_openai_responses_tools.py diff --git a/scripts/check_file_size.py b/scripts/check_file_size.py index ec708ccc6e..0a179edc07 100755 --- a/scripts/check_file_size.py +++ b/scripts/check_file_size.py @@ -37,7 +37,6 @@ "tests/integration/vector_io/test_openai_vector_stores.py", "tests/integration/responses/test_openai_responses.py", "tests/integration/responses/test_tool_responses.py", - "tests/unit/providers/responses/builtin/test_openai_responses.py", } diff --git a/tests/unit/providers/responses/builtin/conftest.py b/tests/unit/providers/responses/builtin/conftest.py new file mode 100644 index 0000000000..8b215799b7 --- /dev/null +++ b/tests/unit/providers/responses/builtin/conftest.py @@ -0,0 +1,107 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.providers.inline.responses.builtin.responses.openai_responses import ( + OpenAIResponsesImpl, +) +from llama_stack.providers.utils.responses.responses_store import ( + ResponsesStore, +) +from llama_stack_api import Connectors +from llama_stack_api.tools import ToolGroups, ToolRuntime + + +@pytest.fixture +def mock_inference_api(): + inference_api = AsyncMock() + return inference_api + + +@pytest.fixture +def mock_tool_groups_api(): + tool_groups_api = AsyncMock(spec=ToolGroups) + return tool_groups_api + + +@pytest.fixture +def mock_tool_runtime_api(): + tool_runtime_api = AsyncMock(spec=ToolRuntime) + return tool_runtime_api + + +@pytest.fixture +def mock_responses_store(): + responses_store = AsyncMock(spec=ResponsesStore) + return responses_store + + +@pytest.fixture +def mock_vector_io_api(): + vector_io_api = AsyncMock() + return vector_io_api + + +@pytest.fixture +def mock_conversations_api(): + """Mock conversations API for testing.""" + 
mock_api = AsyncMock() + return mock_api + + +@pytest.fixture +def mock_safety_api(): + safety_api = AsyncMock() + return safety_api + + +@pytest.fixture +def mock_prompts_api(): + prompts_api = AsyncMock() + return prompts_api + + +@pytest.fixture +def mock_files_api(): + """Mock files API for testing.""" + files_api = AsyncMock() + return files_api + + +@pytest.fixture +def mock_connectors_api(): + connectors_api = AsyncMock(spec=Connectors) + return connectors_api + + +@pytest.fixture +def openai_responses_impl( + mock_inference_api, + mock_tool_groups_api, + mock_tool_runtime_api, + mock_responses_store, + mock_vector_io_api, + mock_safety_api, + mock_conversations_api, + mock_prompts_api, + mock_files_api, + mock_connectors_api, +): + return OpenAIResponsesImpl( + inference_api=mock_inference_api, + tool_groups_api=mock_tool_groups_api, + tool_runtime_api=mock_tool_runtime_api, + responses_store=mock_responses_store, + vector_io_api=mock_vector_io_api, + safety_api=mock_safety_api, + conversations_api=mock_conversations_api, + prompts_api=mock_prompts_api, + files_api=mock_files_api, + connectors_api=mock_connectors_api, + ) diff --git a/tests/unit/providers/responses/builtin/test_openai_responses.py b/tests/unit/providers/responses/builtin/test_openai_responses.py deleted file mode 100644 index 979ae114e1..0000000000 --- a/tests/unit/providers/responses/builtin/test_openai_responses.py +++ /dev/null @@ -1,3226 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from openai.types.chat.chat_completion_chunk import ( - ChatCompletionChunk, - Choice, - ChoiceDelta, - ChoiceDeltaToolCall, - ChoiceDeltaToolCallFunction, -) - -from llama_stack.core.access_control.access_control import default_policy -from llama_stack.core.datatypes import VectorStoresConfig -from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig -from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends -from llama_stack.providers.inline.responses.builtin.responses.openai_responses import ( - OpenAIResponsesImpl, -) -from llama_stack.providers.inline.responses.builtin.responses.tool_executor import ToolExecutor -from llama_stack.providers.remote.inference.openai.config import OpenAIConfig -from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter -from llama_stack.providers.utils.responses.responses_store import ( - ResponsesStore, - _OpenAIResponseObjectWithInputAndMessages, -) -from llama_stack_api import ( - Connectors, - GetConnectorRequest, - GetPromptRequest, - InternalServerError, - InvalidParameterError, - OpenAIChatCompletionContentPartImageParam, - OpenAIFile, - OpenAIFileObject, - OpenAISystemMessageParam, - Order, - Prompt, - ResponseStreamOptions, - ResponseTruncation, -) -from llama_stack_api.inference import ( - OpenAIAssistantMessageParam, - OpenAIChatCompletionContentPartTextParam, - OpenAIChatCompletionRequestWithExtraBody, - OpenAIDeveloperMessageParam, - OpenAIJSONSchema, - OpenAIResponseFormatJSONObject, - OpenAIResponseFormatJSONSchema, - OpenAIUserMessageParam, - ServiceTier, -) -from llama_stack_api.openai_responses import ( - ListOpenAIResponseInputItem, - OpenAIResponseError, - OpenAIResponseInputMessageContentFile, - OpenAIResponseInputMessageContentImage, - OpenAIResponseInputMessageContentText, - OpenAIResponseInputToolFileSearch, - OpenAIResponseInputToolFunction, - 
OpenAIResponseInputToolMCP, - OpenAIResponseInputToolWebSearch, - OpenAIResponseMessage, - OpenAIResponseObject, - OpenAIResponseObjectStreamResponseFailed, - OpenAIResponseOutputMessageContentOutputText, - OpenAIResponseOutputMessageFunctionToolCall, - OpenAIResponseOutputMessageMCPCall, - OpenAIResponseOutputMessageWebSearchToolCall, - OpenAIResponsePrompt, - OpenAIResponseText, - OpenAIResponseTextFormat, - WebSearchToolTypes, -) -from llama_stack_api.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime -from llama_stack_api.vector_io import ( - VectorStoreContent, - VectorStoreSearchResponse, - VectorStoreSearchResponsePage, -) -from tests.unit.providers.responses.builtin.fixtures import load_chat_completion_fixture - - -@pytest.fixture -def mock_inference_api(): - inference_api = AsyncMock() - return inference_api - - -@pytest.fixture -def mock_tool_groups_api(): - tool_groups_api = AsyncMock(spec=ToolGroups) - return tool_groups_api - - -@pytest.fixture -def mock_tool_runtime_api(): - tool_runtime_api = AsyncMock(spec=ToolRuntime) - return tool_runtime_api - - -@pytest.fixture -def mock_responses_store(): - responses_store = AsyncMock(spec=ResponsesStore) - return responses_store - - -@pytest.fixture -def mock_vector_io_api(): - vector_io_api = AsyncMock() - return vector_io_api - - -@pytest.fixture -def mock_conversations_api(): - """Mock conversations API for testing.""" - mock_api = AsyncMock() - return mock_api - - -@pytest.fixture -def mock_safety_api(): - safety_api = AsyncMock() - return safety_api - - -@pytest.fixture -def mock_prompts_api(): - prompts_api = AsyncMock() - return prompts_api - - -@pytest.fixture -def mock_files_api(): - """Mock files API for testing.""" - files_api = AsyncMock() - return files_api - - -@pytest.fixture -def mock_connectors_api(): - connectors_api = AsyncMock(spec=Connectors) - return connectors_api - - -@pytest.fixture -def openai_responses_impl( - mock_inference_api, - 
mock_tool_groups_api, - mock_tool_runtime_api, - mock_responses_store, - mock_vector_io_api, - mock_safety_api, - mock_conversations_api, - mock_prompts_api, - mock_files_api, - mock_connectors_api, -): - return OpenAIResponsesImpl( - inference_api=mock_inference_api, - tool_groups_api=mock_tool_groups_api, - tool_runtime_api=mock_tool_runtime_api, - responses_store=mock_responses_store, - vector_io_api=mock_vector_io_api, - safety_api=mock_safety_api, - conversations_api=mock_conversations_api, - prompts_api=mock_prompts_api, - files_api=mock_files_api, - connectors_api=mock_connectors_api, - ) - - -async def fake_stream(fixture: str = "simple_chat_completion.yaml"): - value = load_chat_completion_fixture(fixture) - yield ChatCompletionChunk( - id=value.id, - choices=[ - Choice( - index=0, - delta=ChoiceDelta( - content=c.message.content, - role=c.message.role, - tool_calls=[ - ChoiceDeltaToolCall( - index=0, - id=t.id, - function=ChoiceDeltaToolCallFunction( - name=t.function.name, - arguments=t.function.arguments, - ), - ) - for t in (c.message.tool_calls or []) - ], - ), - ) - for c in value.choices - ], - created=1, - model=value.model, - object="chat.completion.chunk", - ) - - -async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api): - """Test creating an OpenAI response with a simple string input.""" - # Setup - input_text = "What is the capital of Ireland?" 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - # Load the chat completion fixture - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - temperature=0.1, - stream=True, # Enable streaming to test content part events - ) - - # For streaming response, collect all chunks - chunks = [chunk async for chunk in result] - - mock_inference_api.openai_chat_completion.assert_called_once_with( - OpenAIChatCompletionRequestWithExtraBody( - model=model, - messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)], - response_format=None, - tools=None, - stream=True, - temperature=0.1, - stream_options={ - "include_usage": True, - }, - ) - ) - - # Should have content part events for text streaming - # Expected: response.created, response.in_progress, content_part.added, output_text.delta, content_part.done, response.completed - assert len(chunks) >= 5 - assert chunks[0].type == "response.created" - assert any(chunk.type == "response.in_progress" for chunk in chunks) - - # Check for content part events - content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"] - content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"] - text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"] - - assert len(content_part_added_events) >= 1, "Should have content_part.added event for text" - assert len(content_part_done_events) >= 1, "Should have content_part.done event for text" - assert len(text_delta_events) >= 1, "Should have text delta events" - - added_event = content_part_added_events[0] - done_event = content_part_done_events[0] - assert added_event.content_index == 0 - assert done_event.content_index == 0 - assert added_event.output_index == done_event.output_index == 0 - assert added_event.item_id == done_event.item_id - assert 
added_event.response_id == done_event.response_id - - # Verify final event is completion - assert chunks[-1].type == "response.completed" - - # When streaming, the final response is in the last chunk - final_response = chunks[-1].response - assert final_response.model == model - assert len(final_response.output) == 1 - assert isinstance(final_response.output[0], OpenAIResponseMessage) - - -async def test_failed_stream_persists_non_system_messages(openai_responses_impl, mock_responses_store): - input_text = "Hello" - model = "meta-llama/Llama-3.1-8B-Instruct" - - failed_response = OpenAIResponseObject( - created_at=1, - id="resp_failed", - model=model, - output=[], - status="failed", - error=OpenAIResponseError(code="server_error", message="boom"), - store=True, - ) - - class FakeOrchestrator: - def __init__(self, *, ctx, **_kwargs): - self.ctx = ctx - self.final_messages = None - - async def create_response(self): - yield OpenAIResponseObjectStreamResponseFailed(response=failed_response, sequence_number=0) - - with patch( - "llama_stack.providers.inline.responses.builtin.responses.openai_responses.StreamingResponseOrchestrator", - FakeOrchestrator, - ): - stream = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - instructions="system instructions", - stream=True, - store=True, - ) - chunks = [chunk async for chunk in stream] - - assert chunks[-1].type == "response.failed" - mock_responses_store.upsert_response_object.assert_awaited() - - # Find the call that corresponds to the failed response - call_args_list = mock_responses_store.upsert_response_object.call_args_list - failed_call = None - for call in call_args_list: - _, kwargs = call - if kwargs.get("response_object") and kwargs["response_object"].status == "failed": - failed_call = call - break - - assert failed_call is not None, "Expected upsert_response_object to be called with failed response" - _, kwargs = failed_call - messages = kwargs["messages"] - assert messages, 
"Expected non-system messages to be persisted on failure" - assert all(not isinstance(m, OpenAISystemMessageParam) for m in messages) - assert any(getattr(m, "role", None) == "user" for m in messages) - - -async def test_failed_stream_raises_internal_server_error_in_non_streaming_mode(openai_responses_impl): - """Test that a response.failed event in non-streaming mode raises InternalServerError. - - When stream=False, the caller expects a fully resolved response object, not a stream. - If the underlying stream emits a response.failed event, the implementation must raise - InternalServerError so the caller gets a typed, predictable error rather than a raw - RuntimeError or ValueError. - - Unlike other InternalServerError sites in this file (which guard against internal bugs), - response.failed carries a structured, curated message from the inference backend that - may be directly actionable by the caller (e.g. context window exceeded, invalid prompt). - The message is surfaced to maintain consistency with streaming mode, where the same - response.failed event is returned directly to the caller with the error message visible. 
- """ - model = "meta-llama/Llama-3.1-8B-Instruct" - provider_error_message = "This model's maximum context length is 4096 tokens" - - failed_response = OpenAIResponseObject( - created_at=1, - id="resp_failed_nonstream", - model=model, - output=[], - status="failed", - error=OpenAIResponseError(code="server_error", message=provider_error_message), - store=False, - ) - - class FakeOrchestrator: - def __init__(self, *, ctx, **_kwargs): - self.ctx = ctx - self.final_messages = None - - async def create_response(self): - yield OpenAIResponseObjectStreamResponseFailed(response=failed_response, sequence_number=0) - - with patch( - "llama_stack.providers.inline.responses.builtin.responses.openai_responses.StreamingResponseOrchestrator", - FakeOrchestrator, - ): - with pytest.raises(InternalServerError) as exc_info: - await openai_responses_impl.create_openai_response( - input="Hello", - model=model, - stream=False, - store=False, - ) - - # The provider message is surfaced to the caller: response.failed errors are - # structured and may be actionable (e.g. context window, invalid prompt). - # This is consistent with streaming mode where the same message is visible. - assert provider_error_message in str(exc_info.value) - - -async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api): - """Test creating an OpenAI response with a simple string input and tools.""" - # Setup - input_text = "What is the capital of Ireland?" 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - openai_responses_impl.tool_groups_api.get_tool.return_value = ToolDef( - name="web_search", - toolgroup_id="web_search", - description="Search the web for information", - input_schema={ - "type": "object", - "properties": {"query": {"type": "string", "description": "The query to search for"}}, - "required": ["query"], - }, - ) - - openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult( - status="completed", - content="Dublin", - ) - - # Execute - for tool_name in WebSearchToolTypes: - # Reset mock states as we loop through each tool type - mock_inference_api.openai_chat_completion.side_effect = [ - fake_stream("tool_call_completion.yaml"), - fake_stream(), - ] - openai_responses_impl.tool_groups_api.get_tool.reset_mock() - openai_responses_impl.tool_runtime_api.invoke_tool.reset_mock() - openai_responses_impl.responses_store.upsert_response_object.reset_mock() - - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - temperature=0.1, - tools=[ - OpenAIResponseInputToolWebSearch( - name=tool_name, - ) - ], - ) - - # Verify - first_call = mock_inference_api.openai_chat_completion.call_args_list[0] - first_params = first_call.args[0] - assert first_params.messages[0].content == "What is the capital of Ireland?" 
- assert first_params.tools is not None - assert first_params.temperature == 0.1 - - second_call = mock_inference_api.openai_chat_completion.call_args_list[1] - second_params = second_call.args[0] - assert second_params.messages[-1].content == "Dublin" - assert second_params.temperature == 0.1 - - openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search") - openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with( - tool_name="web_search", - kwargs={"query": "What is the capital of Ireland?"}, - ) - - openai_responses_impl.responses_store.upsert_response_object.assert_called() - - # Check that we got the content from our mocked tool execution result - assert len(result.output) >= 1 - assert isinstance(result.output[1], OpenAIResponseMessage) - assert result.output[1].content[0].text == "Dublin" - assert result.output[1].content[0].annotations == [] - - -async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api): - """Test creating an OpenAI response with a tool call response that has a type of None.""" - # Setup - input_text = "How hot it is in San Francisco today?" 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - async def fake_stream_toolcall(): - yield ChatCompletionChunk( - id="123", - choices=[ - Choice( - index=0, - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall( - index=0, - id="tc_123", - function=ChoiceDeltaToolCallFunction(name="get_weather", arguments="{}"), - type=None, - ) - ] - ), - ), - ], - created=1, - model=model, - object="chat.completion.chunk", - ) - - mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=True, - temperature=0.1, - tools=[ - OpenAIResponseInputToolFunction( - name="get_weather", - description="Get current temperature for a given location.", - parameters={ - "location": "string", - }, - ) - ], - ) - - # Check that we got the content from our mocked tool execution result - chunks = [chunk async for chunk in result] - - # Verify event types - # Should have: response.created, response.in_progress, output_item.added, - # function_call_arguments.delta, function_call_arguments.done, output_item.done, response.completed - assert len(chunks) == 7 - - event_types = [chunk.type for chunk in chunks] - assert event_types == [ - "response.created", - "response.in_progress", - "response.output_item.added", - "response.function_call_arguments.delta", - "response.function_call_arguments.done", - "response.output_item.done", - "response.completed", - ] - - # Verify inference API was called correctly (after iterating over result) - first_call = mock_inference_api.openai_chat_completion.call_args_list[0] - first_params = first_call.args[0] - assert first_params.messages[0].content == input_text - assert first_params.tools is not None - assert first_params.temperature == 0.1 - - # Check response.created event (should have empty output) - assert len(chunks[0].response.output) == 0 - - # Check response.completed event (should have the tool call) - 
completed_chunk = chunks[-1] - assert completed_chunk.type == "response.completed" - assert len(completed_chunk.response.output) == 1 - assert completed_chunk.response.output[0].type == "function_call" - assert completed_chunk.response.output[0].name == "get_weather" - - -async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api): - """Test creating an OpenAI response with tool calls that omit arguments.""" - - input_text = "What is the time right now?" - model = "meta-llama/Llama-3.1-8B-Instruct" - - async def fake_stream_toolcall(): - yield ChatCompletionChunk( - id="123", - choices=[ - Choice( - index=0, - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall( - index=0, - id="tc_123", - function=ChoiceDeltaToolCallFunction(name="get_current_time", arguments=None), - type=None, - ) - ] - ), - ), - ], - created=1, - model=model, - object="chat.completion.chunk", - ) - - def assert_common_expectations(chunks) -> None: - first_call = mock_inference_api.openai_chat_completion.call_args_list[0] - first_params = first_call.args[0] - assert first_params.messages[0].content == input_text - assert first_params.tools is not None - assert first_params.temperature == 0.1 - assert len(chunks[0].response.output) == 0 - completed_chunk = chunks[-1] - assert completed_chunk.type == "response.completed" - assert len(completed_chunk.response.output) == 1 - assert completed_chunk.response.output[0].type == "function_call" - assert completed_chunk.response.output[0].name == "get_current_time" - assert completed_chunk.response.output[0].arguments == "{}" - - # Function does not accept arguments - mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=True, - temperature=0.1, - tools=[ - OpenAIResponseInputToolFunction( - name="get_current_time", description="Get current time for system's 
timezone", parameters={} - ) - ], - ) - chunks = [chunk async for chunk in result] - assert [chunk.type for chunk in chunks] == [ - "response.created", - "response.in_progress", - "response.output_item.added", - "response.function_call_arguments.done", - "response.output_item.done", - "response.completed", - ] - assert_common_expectations(chunks) - - # Function accepts optional arguments - mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=True, - temperature=0.1, - tools=[ - OpenAIResponseInputToolFunction( - name="get_current_time", - description="Get current time for system's timezone", - parameters={"timezone": "string"}, - ) - ], - ) - chunks = [chunk async for chunk in result] - assert [chunk.type for chunk in chunks] == [ - "response.created", - "response.in_progress", - "response.output_item.added", - "response.function_call_arguments.done", - "response.output_item.done", - "response.completed", - ] - assert_common_expectations(chunks) - - # Function accepts optional arguments with additional optional fields - mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=True, - temperature=0.1, - tools=[ - OpenAIResponseInputToolFunction( - name="get_current_time", - description="Get current time for system's timezone", - parameters={"timezone": "string", "location": "string"}, - ) - ], - ) - chunks = [chunk async for chunk in result] - assert [chunk.type for chunk in chunks] == [ - "response.created", - "response.in_progress", - "response.output_item.added", - "response.function_call_arguments.done", - "response.output_item.done", - "response.completed", - ] - assert_common_expectations(chunks) - mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() - - -async def 
test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api, mock_files_api): - """Test creating an OpenAI response with multiple messages.""" - # Setup - input_messages = [ - OpenAIResponseMessage(role="developer", content="You are a helpful assistant", name=None), - OpenAIResponseMessage(role="user", content="Name some towns in Ireland", name=None), - OpenAIResponseMessage( - role="assistant", - content=[ - OpenAIResponseInputMessageContentText(text="Galway, Longford, Sligo"), - OpenAIResponseInputMessageContentText(text="Dublin"), - ], - name=None, - ), - OpenAIResponseMessage(role="user", content="Which is the largest town in Ireland?", name=None), - ] - model = "meta-llama/Llama-3.1-8B-Instruct" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - await openai_responses_impl.create_openai_response( - input=input_messages, - model=model, - temperature=0.1, - ) - - # Verify the the correct messages were sent to the inference API i.e. 
- # All of the responses message were convered to the chat completion message objects - call_args = mock_inference_api.openai_chat_completion.call_args_list[0] - params = call_args.args[0] - inference_messages = params.messages - for i, m in enumerate(input_messages): - if isinstance(m.content, str): - assert inference_messages[i].content == m.content - else: - assert inference_messages[i].content[0].text == m.content[0].text - assert isinstance(inference_messages[i].content[0], OpenAIChatCompletionContentPartTextParam) - assert inference_messages[i].role == m.role - if m.role == "user": - assert isinstance(inference_messages[i], OpenAIUserMessageParam) - elif m.role == "assistant": - assert isinstance(inference_messages[i], OpenAIAssistantMessageParam) - else: - assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam) - - -async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store): - """Test prepending a basic previous response to a new response.""" - - input_item_message = OpenAIResponseMessage( - id="123", - content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], - role="user", - ) - response_output_message = OpenAIResponseMessage( - id="123", - content=[OpenAIResponseOutputMessageContentOutputText(text="fake_response")], - status="completed", - role="assistant", - ) - previous_response = _OpenAIResponseObjectWithInputAndMessages( - created_at=1, - id="resp_123", - model="fake_model", - output=[response_output_message], - status="completed", - text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), - input=[input_item_message], - messages=[OpenAIUserMessageParam(content="fake_previous_input")], - store=True, - ) - mock_responses_store.get_response_object.return_value = previous_response - - input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response) - - assert len(input) == 3 - # Check for previous input - assert isinstance(input[0], 
OpenAIResponseMessage) - assert input[0].content[0].text == "fake_previous_input" - # Check for previous output - assert isinstance(input[1], OpenAIResponseMessage) - assert input[1].content[0].text == "fake_response" - # Check for new input - assert isinstance(input[2], OpenAIResponseMessage) - assert input[2].content == "fake_input" - - -async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store): - """Test prepending a web search previous response to a new response.""" - input_item_message = OpenAIResponseMessage( - id="123", - content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], - role="user", - ) - output_web_search = OpenAIResponseOutputMessageWebSearchToolCall( - id="ws_123", - status="completed", - ) - output_message = OpenAIResponseMessage( - id="123", - content=[OpenAIResponseOutputMessageContentOutputText(text="fake_web_search_response")], - status="completed", - role="assistant", - ) - response = _OpenAIResponseObjectWithInputAndMessages( - created_at=1, - id="resp_123", - model="fake_model", - output=[output_web_search, output_message], - status="completed", - text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), - input=[input_item_message], - messages=[OpenAIUserMessageParam(content="test input")], - store=True, - ) - mock_responses_store.get_response_object.return_value = response - - input_messages = [OpenAIResponseMessage(content="fake_input", role="user")] - input = await openai_responses_impl._prepend_previous_response(input_messages, response) - - assert len(input) == 4 - # Check for previous input - assert isinstance(input[0], OpenAIResponseMessage) - assert input[0].content[0].text == "fake_previous_input" - # Check for previous output web search tool call - assert isinstance(input[1], OpenAIResponseOutputMessageWebSearchToolCall) - # Check for previous output web search response - assert isinstance(input[2], OpenAIResponseMessage) - assert input[2].content[0].text 
== "fake_web_search_response" - # Check for new input - assert isinstance(input[3], OpenAIResponseMessage) - assert input[3].content == "fake_input" - - -async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store): - """Test prepending a previous response which included an mcp tool call to a new response.""" - input_item_message = OpenAIResponseMessage( - id="123", - content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], - role="user", - ) - output_tool_call = OpenAIResponseOutputMessageMCPCall( - id="ws_123", - name="fake-tool", - arguments="fake-arguments", - server_label="fake-label", - ) - output_message = OpenAIResponseMessage( - id="123", - content=[OpenAIResponseOutputMessageContentOutputText(text="fake_tool_call_response")], - status="completed", - role="assistant", - ) - response = _OpenAIResponseObjectWithInputAndMessages( - created_at=1, - id="resp_123", - model="fake_model", - output=[output_tool_call, output_message], - status="completed", - text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), - input=[input_item_message], - messages=[OpenAIUserMessageParam(content="test input")], - store=True, - ) - mock_responses_store.get_response_object.return_value = response - - input_messages = [OpenAIResponseMessage(content="fake_input", role="user")] - input = await openai_responses_impl._prepend_previous_response(input_messages, response) - - assert len(input) == 4 - # Check for previous input - assert isinstance(input[0], OpenAIResponseMessage) - assert input[0].content[0].text == "fake_previous_input" - # Check for previous output MCP tool call - assert isinstance(input[1], OpenAIResponseOutputMessageMCPCall) - # Check for previous output web search response - assert isinstance(input[2], OpenAIResponseMessage) - assert input[2].content[0].text == "fake_tool_call_response" - # Check for new input - assert isinstance(input[3], OpenAIResponseMessage) - assert input[3].content == 
"fake_input" - - -async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api): - # Setup - input_text = "What is the capital of Ireland?" - model = "meta-llama/Llama-3.1-8B-Instruct" - instructions = "You are a geography expert. Provide concise answers." - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - instructions=instructions, - ) - - # Verify - mock_inference_api.openai_chat_completion.assert_called_once() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - sent_messages = params.messages - - # Check that instructions were prepended as a system message - assert len(sent_messages) == 2 - assert sent_messages[0].role == "system" - assert sent_messages[0].content == instructions - assert sent_messages[1].role == "user" - assert sent_messages[1].content == input_text - - -async def test_create_openai_response_with_instructions_and_multiple_messages( - openai_responses_impl, mock_inference_api, mock_files_api -): - # Setup - input_messages = [ - OpenAIResponseMessage(role="user", content="Name some towns in Ireland", name=None), - OpenAIResponseMessage( - role="assistant", - content="Galway, Longford, Sligo", - name=None, - ), - OpenAIResponseMessage(role="user", content="Which is the largest?", name=None), - ] - model = "meta-llama/Llama-3.1-8B-Instruct" - instructions = "You are a geography expert. Provide concise answers." 
- - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - await openai_responses_impl.create_openai_response( - input=input_messages, - model=model, - instructions=instructions, - ) - - # Verify - mock_inference_api.openai_chat_completion.assert_called_once() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - sent_messages = params.messages - - # Check that instructions were prepended as a system message - assert len(sent_messages) == 4 # 1 system + 3 input messages - assert sent_messages[0].role == "system" - assert sent_messages[0].content == instructions - - # Check the rest of the messages were converted correctly - assert sent_messages[1].role == "user" - assert sent_messages[1].content == "Name some towns in Ireland" - assert sent_messages[2].role == "assistant" - assert sent_messages[2].content == "Galway, Longford, Sligo" - assert sent_messages[3].role == "user" - assert sent_messages[3].content == "Which is the largest?" 
- - -async def test_create_openai_response_with_instructions_and_previous_response( - openai_responses_impl, mock_responses_store, mock_inference_api -): - """Test prepending both instructions and previous response.""" - - input_item_message = OpenAIResponseMessage( - id="123", - content="Name some towns in Ireland", - role="user", - ) - response_output_message = OpenAIResponseMessage( - id="123", - content="Galway, Longford, Sligo", - status="completed", - role="assistant", - ) - response = _OpenAIResponseObjectWithInputAndMessages( - created_at=1, - id="resp_123", - model="fake_model", - output=[response_output_message], - status="completed", - text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), - input=[input_item_message], - messages=[ - OpenAIUserMessageParam(content="Name some towns in Ireland"), - OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"), - ], - store=True, - ) - mock_responses_store.get_response_object.return_value = response - - model = "meta-llama/Llama-3.1-8B-Instruct" - instructions = "You are a geography expert. Provide concise answers." 
- - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - await openai_responses_impl.create_openai_response( - input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123" - ) - - # Verify - mock_inference_api.openai_chat_completion.assert_called_once() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - sent_messages = params.messages - - # Check that instructions were prepended as a system message - assert len(sent_messages) == 4, sent_messages - assert sent_messages[0].role == "system" - assert sent_messages[0].content == instructions - - # Check the rest of the messages were converted correctly - assert sent_messages[1].role == "user" - assert sent_messages[1].content == "Name some towns in Ireland" - assert sent_messages[2].role == "assistant" - assert sent_messages[2].content == "Galway, Longford, Sligo" - assert sent_messages[3].role == "user" - assert sent_messages[3].content == "Which is the largest?" 
- - -async def test_create_openai_response_with_previous_response_instructions( - openai_responses_impl, mock_responses_store, mock_inference_api -): - """Test prepending instructions and previous response with instructions.""" - - input_item_message = OpenAIResponseMessage( - id="123", - content="Name some towns in Ireland", - role="user", - ) - response_output_message = OpenAIResponseMessage( - id="123", - content="Galway, Longford, Sligo", - status="completed", - role="assistant", - ) - response = _OpenAIResponseObjectWithInputAndMessages( - created_at=1, - id="resp_123", - model="fake_model", - output=[response_output_message], - status="completed", - text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), - input=[input_item_message], - messages=[ - OpenAIUserMessageParam(content="Name some towns in Ireland"), - OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"), - ], - instructions="You are a helpful assistant.", - store=True, - ) - mock_responses_store.get_response_object.return_value = response - - model = "meta-llama/Llama-3.1-8B-Instruct" - instructions = "You are a geography expert. Provide concise answers." 
- - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - await openai_responses_impl.create_openai_response( - input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123" - ) - - # Verify - mock_inference_api.openai_chat_completion.assert_called_once() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - sent_messages = params.messages - - # Check that instructions were prepended as a system message - # and that the previous response instructions were not carried over - assert len(sent_messages) == 4, sent_messages - assert sent_messages[0].role == "system" - assert sent_messages[0].content == instructions - - # Check the rest of the messages were converted correctly - assert sent_messages[1].role == "user" - assert sent_messages[1].content == "Name some towns in Ireland" - assert sent_messages[2].role == "assistant" - assert sent_messages[2].content == "Galway, Longford, Sligo" - assert sent_messages[3].role == "user" - assert sent_messages[3].content == "Which is the largest?" 
- - -async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store): - """Test that list_openai_response_input_items properly delegates to responses_store with correct parameters.""" - # Setup - response_id = "resp_123" - after = "msg_after" - before = "msg_before" - include = ["metadata"] - limit = 5 - order = Order.asc - - input_message = OpenAIResponseMessage( - id="msg_123", - content="Test message", - role="user", - ) - - expected_result = ListOpenAIResponseInputItem(data=[input_message]) - mock_responses_store.list_response_input_items.return_value = expected_result - - # Execute with all parameters to test delegation - result = await openai_responses_impl.list_openai_response_input_items( - response_id, after=after, before=before, include=include, limit=limit, order=order - ) - - # Verify all parameters are passed through correctly to the store - mock_responses_store.list_response_input_items.assert_called_once_with( - response_id, after, before, include, limit, order - ) - - # Verify the result is returned as-is from the store - assert result.object == "list" - assert len(result.data) == 1 - assert result.data[0].id == "msg_123" - - -async def test_responses_store_list_input_items_logic(): - """Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting.""" - - # Create mock store and response store - mock_sql_store = AsyncMock() - backend_name = "sql_responses_test" - register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")}) - responses_store = ResponsesStore( - ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy() - ) - responses_store.sql_store = mock_sql_store - - # Setup test data - multiple input items - input_items = [ - OpenAIResponseMessage(id="msg_1", content="First message", role="user"), - OpenAIResponseMessage(id="msg_2", content="Second message", role="user"), - 
OpenAIResponseMessage(id="msg_3", content="Third message", role="user"), - OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"), - ] - - response_with_input = _OpenAIResponseObjectWithInputAndMessages( - id="resp_123", - model="test_model", - created_at=1234567890, - object="response", - status="completed", - output=[], - text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))), - input=input_items, - messages=[OpenAIUserMessageParam(content="First message")], - store=True, - ) - - # Mock the get_response_object method to return our test data - mock_sql_store.fetch_one.return_value = {"response_object": response_with_input.model_dump()} - - # Test 1: Default behavior (no limit, desc order) - result = await responses_store.list_response_input_items("resp_123") - assert result.object == "list" - assert len(result.data) == 4 - # Should be reversed for desc order - assert result.data[0].id == "msg_4" - assert result.data[1].id == "msg_3" - assert result.data[2].id == "msg_2" - assert result.data[3].id == "msg_1" - - # Test 2: With limit=2, desc order - result = await responses_store.list_response_input_items("resp_123", limit=2, order=Order.desc) - assert result.object == "list" - assert len(result.data) == 2 - # Should be first 2 items in desc order - assert result.data[0].id == "msg_4" - assert result.data[1].id == "msg_3" - - # Test 3: With limit=2, asc order - result = await responses_store.list_response_input_items("resp_123", limit=2, order=Order.asc) - assert result.object == "list" - assert len(result.data) == 2 - # Should be first 2 items in original order (asc) - assert result.data[0].id == "msg_1" - assert result.data[1].id == "msg_2" - - # Test 4: Asc order without limit - result = await responses_store.list_response_input_items("resp_123", order=Order.asc) - assert result.object == "list" - assert len(result.data) == 4 - # Should be in original order (asc) - assert result.data[0].id == "msg_1" - assert result.data[1].id 
== "msg_2" - assert result.data[2].id == "msg_3" - assert result.data[3].id == "msg_4" - - # Test 5: Large limit (larger than available items) - result = await responses_store.list_response_input_items("resp_123", limit=10, order=Order.desc) - assert result.object == "list" - assert len(result.data) == 4 # Should return all available items - assert result.data[0].id == "msg_4" - - # Test 6: Zero limit edge case - result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc) - assert result.object == "list" - assert len(result.data) == 0 # Should return no items - - -async def test_store_response_uses_rehydrated_input_with_previous_response( - openai_responses_impl, mock_responses_store, mock_inference_api -): - """Test that _store_response uses the full re-hydrated input (including previous responses) - rather than just the original input when previous_response_id is provided.""" - - # Setup - Create a previous response that should be included in the stored input - previous_response = _OpenAIResponseObjectWithInputAndMessages( - id="resp-previous-123", - object="response", - created_at=1234567890, - model="meta-llama/Llama-3.1-8B-Instruct", - status="completed", - text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), - input=[ - OpenAIResponseMessage( - id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")] - ) - ], - output=[ - OpenAIResponseMessage( - id="msg-prev-assistant", - role="assistant", - content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")], - ) - ], - messages=[ - OpenAIUserMessageParam(content="What is 2+2?"), - OpenAIAssistantMessageParam(content="2+2 equals 4."), - ], - store=True, - ) - - mock_responses_store.get_response_object.return_value = previous_response - - current_input = "Now what is 3+3?" 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - Create response with previous_response_id - result = await openai_responses_impl.create_openai_response( - input=current_input, - model=model, - previous_response_id="resp-previous-123", - store=True, - ) - - store_call_args = mock_responses_store.upsert_response_object.call_args - stored_input = store_call_args.kwargs["input"] - - # Verify that the stored input contains the full re-hydrated conversation: - # 1. Previous user message - # 2. Previous assistant response - # 3. Current user message - assert len(stored_input) == 3 - - assert stored_input[0].role == "user" - assert stored_input[0].content[0].text == "What is 2+2?" - - assert stored_input[1].role == "assistant" - assert stored_input[1].content[0].text == "2+2 equals 4." - - assert stored_input[2].role == "user" - assert stored_input[2].content == "Now what is 3+3?" - - # Verify the response itself is correct - assert result.model == model - assert result.status == "completed" - - -@patch("llama_stack.providers.inline.responses.builtin.responses.streaming.list_mcp_tools") -async def test_reuse_mcp_tool_list( - mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api -): - """Test that mcp_list_tools can be reused where appropriate.""" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - mock_list_mcp_tools.return_value = ListToolDefsResponse( - data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})] - ) - - res1 = await openai_responses_impl.create_openai_response( - input="What is 2+2?", - model="meta-llama/Llama-3.1-8B-Instruct", - store=True, - tools=[ - OpenAIResponseInputToolFunction(name="fake", parameters=None), - OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"), - ], - ) - args = mock_responses_store.upsert_response_object.call_args - data = 
args.kwargs["response_object"].model_dump() - data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]] - data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]] - stored = _OpenAIResponseObjectWithInputAndMessages(**data) - mock_responses_store.get_response_object.return_value = stored - - res2 = await openai_responses_impl.create_openai_response( - previous_response_id=res1.id, - input="Now what is 3+3?", - model="meta-llama/Llama-3.1-8B-Instruct", - store=True, - tools=[ - OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"), - ], - ) - assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2 - second_call = mock_inference_api.openai_chat_completion.call_args_list[1] - second_params = second_call.args[0] - tools_seen = second_params.tools - assert len(tools_seen) == 1 - assert tools_seen[0]["function"]["name"] == "test_tool" - assert tools_seen[0]["function"]["description"] == "a test tool" - - assert mock_list_mcp_tools.call_count == 1 - listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"] - assert len(listings) == 1 - assert listings[0].server_label == "alabel" - assert len(listings[0].tools) == 1 - assert listings[0].tools[0].name == "test_tool" - - -@pytest.mark.parametrize( - "text_format, response_format", - [ - (OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), None), - ( - OpenAIResponseText(format=OpenAIResponseTextFormat(name="Test", schema={"foo": "bar"}, type="json_schema")), - OpenAIResponseFormatJSONSchema(json_schema=OpenAIJSONSchema(name="Test", schema={"foo": "bar"})), - ), - (OpenAIResponseText(format=OpenAIResponseTextFormat(type="json_object")), OpenAIResponseFormatJSONObject()), - # ensure text param with no format specified defaults to None - (OpenAIResponseText(format=None), None), - # ensure text param of None defaults to None - (None, None), - ], -) -async def test_create_openai_response_with_text_format( - openai_responses_impl, 
mock_inference_api, text_format, response_format -): - """Test creating Responses with text formats.""" - # Setup - input_text = "How hot it is in San Francisco today?" - model = "meta-llama/Llama-3.1-8B-Instruct" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - _result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - text=text_format, - ) - - # Verify - first_call = mock_inference_api.openai_chat_completion.call_args_list[0] - first_params = first_call.args[0] - assert first_params.messages[0].content == input_text - assert first_params.response_format == response_format - - -async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api): - """Test creating an OpenAI response with an invalid text format.""" - # Setup - input_text = "How hot it is in San Francisco today?" - model = "meta-llama/Llama-3.1-8B-Instruct" - - # Execute - with pytest.raises(ValueError): - _result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - text=OpenAIResponseText(format={"type": "invalid"}), - ) - - -async def test_create_openai_response_with_output_types_as_input( - openai_responses_impl, mock_inference_api, mock_responses_store -): - """Test that response outputs can be used as inputs in multi-turn conversations. - - Before adding OpenAIResponseOutput types to OpenAIResponseInput, - creating a _OpenAIResponseObjectWithInputAndMessages with some output types - in the input field would fail with a Pydantic ValidationError. - - This test simulates storing a response where the input contains output message - types (MCP calls, function calls), which happens in multi-turn conversations. 
- """ - model = "meta-llama/Llama-3.1-8B-Instruct" - - # Mock the inference response - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Create a response with store=True to trigger the storage path - result = await openai_responses_impl.create_openai_response( - input="What's the weather?", - model=model, - stream=True, - temperature=0.1, - store=True, - ) - - # Consume the stream - _ = [chunk async for chunk in result] - - # Verify store was called - assert mock_responses_store.upsert_response_object.called - - # Get the stored data - store_call_args = mock_responses_store.upsert_response_object.call_args - stored_response = store_call_args.kwargs["response_object"] - - # Now simulate a multi-turn conversation where outputs become inputs - input_with_output_types = [ - OpenAIResponseMessage(role="user", content="What's the weather?", name=None), - # These output types need to be valid OpenAIResponseInput - OpenAIResponseOutputMessageFunctionToolCall( - call_id="call_123", - name="get_weather", - arguments='{"city": "Tokyo"}', - type="function_call", - ), - OpenAIResponseOutputMessageMCPCall( - id="mcp_456", - type="mcp_call", - server_label="weather_server", - name="get_temperature", - arguments='{"location": "Tokyo"}', - output="25°C", - ), - ] - - # This simulates storing a response in a multi-turn conversation - # where previous outputs are included in the input. 
- stored_with_outputs = _OpenAIResponseObjectWithInputAndMessages( - id=stored_response.id, - created_at=stored_response.created_at, - model=stored_response.model, - status=stored_response.status, - output=stored_response.output, - input=input_with_output_types, # This will trigger Pydantic validation - messages=None, - store=True, - ) - - assert stored_with_outputs.input == input_with_output_types - assert len(stored_with_outputs.input) == 3 - - -async def test_create_openai_response_with_prompt(openai_responses_impl, mock_inference_api, mock_prompts_api): - """Test creating an OpenAI response with a prompt.""" - input_text = "What is the capital of Ireland?" - model = "meta-llama/Llama-3.1-8B-Instruct" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - prompt = Prompt( - prompt="You are a helpful {{ area_name }} assistant at {{ company_name }}. Always provide accurate information.", - prompt_id=prompt_id, - version=1, - variables=["area_name", "company_name"], - is_default=True, - ) - - openai_response_prompt = OpenAIResponsePrompt( - id=prompt_id, - version="1", - variables={ - "area_name": OpenAIResponseInputMessageContentText(text="geography"), - "company_name": OpenAIResponseInputMessageContentText(text="Dummy Company"), - }, - ) - - mock_prompts_api.get_prompt.return_value = prompt - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - prompt=openai_response_prompt, - ) - - mock_prompts_api.get_prompt.assert_called_with(GetPromptRequest(prompt_id=prompt_id, version=1)) - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - sent_messages = call_args.args[0].messages - assert len(sent_messages) == 2 - - system_messages = [msg for msg in sent_messages if msg.role == "system"] - assert len(system_messages) == 1 - assert ( - system_messages[0].content - 
== "You are a helpful geography assistant at Dummy Company. Always provide accurate information." - ) - - user_messages = [msg for msg in sent_messages if msg.role == "user"] - assert len(user_messages) == 1 - assert user_messages[0].content == input_text - - assert result.model == model - assert result.status == "completed" - assert isinstance(result.prompt, OpenAIResponsePrompt) - assert result.prompt.id == prompt_id - assert result.prompt.variables == openai_response_prompt.variables - assert result.prompt.version == "1" - - -async def test_prepend_prompt_successful_without_variables(openai_responses_impl, mock_prompts_api, mock_inference_api): - """Test prepend_prompt function without variables.""" - input_text = "What is the capital of Ireland?" - model = "meta-llama/Llama-3.1-8B-Instruct" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - prompt = Prompt( - prompt="You are a helpful assistant. Always provide accurate information.", - prompt_id=prompt_id, - version=1, - variables=[], - is_default=True, - ) - - openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1") - - mock_prompts_api.get_prompt.return_value = prompt - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - prompt=openai_response_prompt, - ) - - mock_prompts_api.get_prompt.assert_called_with(GetPromptRequest(prompt_id=prompt_id, version=1)) - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - sent_messages = call_args.args[0].messages - assert len(sent_messages) == 2 - system_messages = [msg for msg in sent_messages if msg.role == "system"] - assert system_messages[0].content == "You are a helpful assistant. Always provide accurate information." 
- - -async def test_prepend_prompt_invalid_variable(openai_responses_impl, mock_prompts_api): - """Test error handling in prepend_prompt function when prompt parameters contain invalid variables.""" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - prompt = Prompt( - prompt="You are a {{ role }} assistant.", - prompt_id=prompt_id, - version=1, - variables=["role"], # Only "role" is valid - is_default=True, - ) - - openai_response_prompt = OpenAIResponsePrompt( - id=prompt_id, - version="1", - variables={ - "role": OpenAIResponseInputMessageContentText(text="helpful"), - "company": OpenAIResponseInputMessageContentText( - text="Dummy Company" - ), # company is not in prompt.variables - }, - ) - - mock_prompts_api.get_prompt.return_value = prompt - - # Initial messages - messages = [OpenAIUserMessageParam(content="Test prompt")] - - # Execute - should raise InvalidParameterError for invalid variable - with pytest.raises(InvalidParameterError) as exc_info: - await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) - assert "Invalid value for 'prompt.variables': company" in str(exc_info.value) - assert f"Variable not defined in prompt '{prompt_id}'" in str(exc_info.value) - - # Verify - mock_prompts_api.get_prompt.assert_called_once_with(GetPromptRequest(prompt_id=prompt_id, version=1)) - - -async def test_prepend_prompt_not_found(openai_responses_impl, mock_prompts_api): - """Test prepend_prompt function when prompt is not found.""" - prompt_id = "pmpt_nonexistent" - openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1") - - mock_prompts_api.get_prompt.return_value = None # Prompt not found - - # Initial messages - messages = [OpenAIUserMessageParam(content="Test prompt")] - initial_length = len(messages) - - # Execute - result = await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) - - # Verify - 
mock_prompts_api.get_prompt.assert_called_once_with(GetPromptRequest(prompt_id=prompt_id, version=1)) - - # Should return None when prompt not found - assert result is None - - # Messages should not be modified - assert len(messages) == initial_length - assert messages[0].content == "Test prompt" - - -async def test_prepend_prompt_variable_substitution(openai_responses_impl, mock_prompts_api): - """Test complex variable substitution with multiple occurrences and special characters in prepend_prompt function.""" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - - # Support all whitespace variations: {{name}}, {{ name }}, {{ name}}, {{name }}, etc. - prompt = Prompt( - prompt="Hello {{name}}! You are working at {{ company}}. Your role is {{role}} at {{company}}. Remember, {{ name }}, to be {{ tone }}.", - prompt_id=prompt_id, - version=1, - variables=["name", "company", "role", "tone"], - is_default=True, - ) - - openai_response_prompt = OpenAIResponsePrompt( - id=prompt_id, - version="1", - variables={ - "name": OpenAIResponseInputMessageContentText(text="Alice"), - "company": OpenAIResponseInputMessageContentText(text="Dummy Company"), - "role": OpenAIResponseInputMessageContentText(text="AI Assistant"), - "tone": OpenAIResponseInputMessageContentText(text="professional"), - }, - ) - - mock_prompts_api.get_prompt.return_value = prompt - - # Initial messages - messages = [OpenAIUserMessageParam(content="Test")] - - # Execute - await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) - - # Verify - assert len(messages) == 2 - assert isinstance(messages[0], OpenAISystemMessageParam) - expected_content = "Hello Alice! You are working at Dummy Company. Your role is AI Assistant at Dummy Company. Remember, Alice, to be professional." 
- assert messages[0].content == expected_content - - -async def test_prepend_prompt_with_image_variable(openai_responses_impl, mock_prompts_api, mock_files_api): - """Test prepend_prompt with image variable - should create placeholder in system message and append image as separate user message.""" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - prompt = Prompt( - prompt="Analyze this {{product_image}} and describe what you see.", - prompt_id=prompt_id, - version=1, - variables=["product_image"], - is_default=True, - ) - - # Mock file content and file metadata - mock_file_content = b"fake_image_data" - mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})() - mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject( - object="file", - id="file-abc123", - bytes=len(mock_file_content), - created_at=1234567890, - expires_at=1234567890, - filename="product.jpg", - purpose="assistants", - ) - - openai_response_prompt = OpenAIResponsePrompt( - id=prompt_id, - version="1", - variables={ - "product_image": OpenAIResponseInputMessageContentImage( - file_id="file-abc123", - detail="high", - ) - }, - ) - - mock_prompts_api.get_prompt.return_value = prompt - - # Initial messages - messages = [OpenAIUserMessageParam(content="What do you think?")] - - # Execute - await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) - - assert len(messages) == 3 - - # Check system message has placeholder - assert isinstance(messages[0], OpenAISystemMessageParam) - assert messages[0].content == "Analyze this [Image: product_image] and describe what you see." - - # Check original user message is still there - assert isinstance(messages[1], OpenAIUserMessageParam) - assert messages[1].content == "What do you think?" 
- - # Check new user message with image is appended - assert isinstance(messages[2], OpenAIUserMessageParam) - assert isinstance(messages[2].content, list) - assert len(messages[2].content) == 1 - - # Should be image with data URL - assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam) - assert messages[2].content[0].image_url.url.startswith("data:image/") - assert messages[2].content[0].image_url.detail == "high" - - -async def test_prepend_prompt_with_file_variable(openai_responses_impl, mock_prompts_api, mock_files_api): - """Test prepend_prompt with file variable - should create placeholder in system message and append file as separate user message.""" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - prompt = Prompt( - prompt="Review the document {{contract_file}} and summarize key points.", - prompt_id=prompt_id, - version=1, - variables=["contract_file"], - is_default=True, - ) - - # Mock file retrieval - mock_file_content = b"fake_pdf_content" - mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})() - mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject( - object="file", - id="file-contract-789", - bytes=len(mock_file_content), - created_at=1234567890, - expires_at=1234567890, - filename="contract.pdf", - purpose="assistants", - ) - - openai_response_prompt = OpenAIResponsePrompt( - id=prompt_id, - version="1", - variables={ - "contract_file": OpenAIResponseInputMessageContentFile( - file_id="file-contract-789", - filename="contract.pdf", - ) - }, - ) - - mock_prompts_api.get_prompt.return_value = prompt - - # Initial messages - messages = [OpenAIUserMessageParam(content="Please review this.")] - - # Execute - await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) - - assert len(messages) == 3 - - # Check system message has placeholder - assert isinstance(messages[0], OpenAISystemMessageParam) - assert 
messages[0].content == "Review the document [File: contract_file] and summarize key points." - - # Check original user message is still there - assert isinstance(messages[1], OpenAIUserMessageParam) - assert messages[1].content == "Please review this." - - # Check new user message with file is appended - assert isinstance(messages[2], OpenAIUserMessageParam) - assert isinstance(messages[2].content, list) - assert len(messages[2].content) == 1 - - # First part should be file with data URL - assert isinstance(messages[2].content[0], OpenAIFile) - assert messages[2].content[0].file.file_data.startswith("data:application/pdf;base64,") - assert messages[2].content[0].file.filename == "contract.pdf" - assert messages[2].content[0].file.file_id is None - - -async def test_prepend_prompt_with_mixed_variables(openai_responses_impl, mock_prompts_api, mock_files_api): - """Test prepend_prompt with text, image, and file variables mixed together.""" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - prompt = Prompt( - prompt="Hello {{name}}! Analyze {{photo}} and review {{document}}. 
Provide insights for {{company}}.", - prompt_id=prompt_id, - version=1, - variables=["name", "photo", "document", "company"], - is_default=True, - ) - - # Mock file retrieval for image and file - mock_image_content = b"fake_image_data" - mock_file_content = b"fake_doc_content" - - async def mock_retrieve_file_content(request): - file_id = request.file_id - if file_id == "file-photo-123": - return type("obj", (object,), {"body": mock_image_content})() - elif file_id == "file-doc-456": - return type("obj", (object,), {"body": mock_file_content})() - - mock_files_api.openai_retrieve_file_content.side_effect = mock_retrieve_file_content - - def mock_retrieve_file(request): - file_id = request.file_id - if file_id == "file-photo-123": - return OpenAIFileObject( - object="file", - id="file-photo-123", - bytes=len(mock_image_content), - created_at=1234567890, - expires_at=1234567890, - filename="photo.jpg", - purpose="assistants", - ) - elif file_id == "file-doc-456": - return OpenAIFileObject( - object="file", - id="file-doc-456", - bytes=len(mock_file_content), - created_at=1234567890, - expires_at=1234567890, - filename="doc.pdf", - purpose="assistants", - ) - - mock_files_api.openai_retrieve_file.side_effect = mock_retrieve_file - - openai_response_prompt = OpenAIResponsePrompt( - id=prompt_id, - version="1", - variables={ - "name": OpenAIResponseInputMessageContentText(text="Alice"), - "photo": OpenAIResponseInputMessageContentImage(file_id="file-photo-123", detail="auto"), - "document": OpenAIResponseInputMessageContentFile(file_id="file-doc-456", filename="doc.pdf"), - "company": OpenAIResponseInputMessageContentText(text="Acme Corp"), - }, - ) - - mock_prompts_api.get_prompt.return_value = prompt - - # Initial messages - messages = [OpenAIUserMessageParam(content="Here's my question.")] - - # Execute - await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) - - assert len(messages) == 3 - - # Check system message has text and placeholders - 
assert isinstance(messages[0], OpenAISystemMessageParam) - expected_system = "Hello Alice! Analyze [Image: photo] and review [File: document]. Provide insights for Acme Corp." - assert messages[0].content == expected_system - - # Check original user message is still there - assert isinstance(messages[1], OpenAIUserMessageParam) - assert messages[1].content == "Here's my question." - - # Check new user message with media is appended (2 media items) - assert isinstance(messages[2], OpenAIUserMessageParam) - assert isinstance(messages[2].content, list) - assert len(messages[2].content) == 2 - - # First part should be image with data URL - assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam) - assert messages[2].content[0].image_url.url.startswith("data:image/") - - # Second part should be file with data URL - assert isinstance(messages[2].content[1], OpenAIFile) - assert messages[2].content[1].file.file_data.startswith("data:application/pdf;base64,") - assert messages[2].content[1].file.filename == "doc.pdf" - assert messages[2].content[1].file.file_id is None - - -async def test_prepend_prompt_with_image_using_image_url(openai_responses_impl, mock_prompts_api): - """Test prepend_prompt with image variable using image_url instead of file_id.""" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - prompt = Prompt( - prompt="Describe {{screenshot}}.", - prompt_id=prompt_id, - version=1, - variables=["screenshot"], - is_default=True, - ) - - openai_response_prompt = OpenAIResponsePrompt( - id=prompt_id, - version="1", - variables={ - "screenshot": OpenAIResponseInputMessageContentImage( - image_url="https://example.com/screenshot.png", - detail="low", - ) - }, - ) - - mock_prompts_api.get_prompt.return_value = prompt - - # Initial messages - messages = [OpenAIUserMessageParam(content="What is this?")] - - # Execute - await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) - - assert len(messages) == 
3 - - # Check system message has placeholder - assert isinstance(messages[0], OpenAISystemMessageParam) - assert messages[0].content == "Describe [Image: screenshot]." - - # Check original user message is still there - assert isinstance(messages[1], OpenAIUserMessageParam) - assert messages[1].content == "What is this?" - - # Check new user message with image is appended - assert isinstance(messages[2], OpenAIUserMessageParam) - assert isinstance(messages[2].content, list) - - # Image should use the provided URL - assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam) - assert messages[2].content[0].image_url.url == "https://example.com/screenshot.png" - assert messages[2].content[0].image_url.detail == "low" - - -async def test_prepend_prompt_image_variable_missing_required_fields(openai_responses_impl, mock_prompts_api): - """Test prepend_prompt with image variable that has neither file_id nor image_url - should raise error.""" - prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" - prompt = Prompt( - prompt="Analyze {{bad_image}}.", - prompt_id=prompt_id, - version=1, - variables=["bad_image"], - is_default=True, - ) - - # Create image content with neither file_id nor image_url - openai_response_prompt = OpenAIResponsePrompt( - id=prompt_id, - version="1", - variables={"bad_image": OpenAIResponseInputMessageContentImage()}, # No file_id or image_url - ) - - mock_prompts_api.get_prompt.return_value = prompt - messages = [OpenAIUserMessageParam(content="Test")] - - # Execute - should raise ValueError - with pytest.raises(ValueError, match="Image content must have either 'image_url' or 'file_id'"): - await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) - - -@patch("llama_stack.providers.inline.responses.builtin.responses.streaming.list_mcp_tools") -async def test_mcp_tool_connector_id_resolved_to_server_url( - mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api, 
mock_connectors_api -): - """Test that connector_id is resolved to server_url when using MCP tools.""" - from llama_stack_api import Connector, ConnectorType - - # Setup mock connector that will be returned when resolving connector_id - mock_connector = Connector( - connector_id="my-mcp-connector", - connector_type=ConnectorType.MCP, - url="http://resolved-mcp-server:8080/mcp", - server_label="Resolved MCP Server", - ) - mock_connectors_api.get_connector.return_value = mock_connector - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - mock_list_mcp_tools.return_value = ListToolDefsResponse( - data=[ToolDef(name="resolved_tool", description="a resolved tool", input_schema={}, output_schema={})] - ) - - # Create a response using connector_id instead of server_url - result = await openai_responses_impl.create_openai_response( - input="Test connector resolution", - model="meta-llama/Llama-3.1-8B-Instruct", - store=True, - tools=[ - OpenAIResponseInputToolMCP(server_label="my-label", connector_id="my-mcp-connector"), - ], - ) - - # Verify the connector_id was resolved via the connectors API - mock_connectors_api.get_connector.assert_called_once_with(GetConnectorRequest(connector_id="my-mcp-connector")) - - # Verify list_mcp_tools was called with the resolved URL - mock_list_mcp_tools.assert_called_once() - call_kwargs = mock_list_mcp_tools.call_args.kwargs - assert call_kwargs["endpoint"] == "http://resolved-mcp-server:8080/mcp" - - # Verify the response contains the resolved tools - listings = [obj for obj in result.output if obj.type == "mcp_list_tools"] - assert len(listings) == 1 - assert listings[0].server_label == "my-label" - assert len(listings[0].tools) == 1 - assert listings[0].tools[0].name == "resolved_tool" - - -async def test_file_search_results_include_chunk_metadata_attributes(mock_vector_io_api): - """Test that file_search tool executor preserves chunk metadata attributes.""" - query = "What is machine learning?" 
- vector_store_id = "test_vector_store" - - # Mock vector_io to return search results with custom attributes - mock_vector_io_api.openai_search_vector_store.return_value = VectorStoreSearchResponsePage( - search_query=[query], - data=[ - VectorStoreSearchResponse( - file_id="doc-123", - filename="ml-intro.md", - content=[VectorStoreContent(type="text", text="Machine learning is a subset of AI")], - score=0.95, - attributes={ - "document_id": "ml-intro", - "source_url": "https://example.com/ml-guide", - "title": "Introduction to ML", - "author": "John Doe", - "year": "2024", - }, - ), - VectorStoreSearchResponse( - file_id="doc-456", - filename="dl-basics.md", - content=[VectorStoreContent(type="text", text="Deep learning uses neural networks")], - score=0.85, - attributes={ - "document_id": "dl-basics", - "source_url": "https://example.com/dl-guide", - "title": "Deep Learning Basics", - "category": "tutorial", - }, - ), - ], - ) - - # Create tool executor with mock vector_io - tool_executor = ToolExecutor( - tool_groups_api=None, # type: ignore - tool_runtime_api=None, # type: ignore - vector_io_api=mock_vector_io_api, - vector_stores_config=VectorStoresConfig(), - mcp_session_manager=None, - ) - - # Execute the file search - file_search_tool = OpenAIResponseInputToolFileSearch(vector_store_ids=[vector_store_id]) - result = await tool_executor._execute_file_search_via_vector_store( - query=query, - response_file_search_tool=file_search_tool, - ) - - mock_vector_io_api.openai_search_vector_store.assert_called_once() - - # Verify the result metadata includes chunk attributes - assert result.metadata is not None - assert "attributes" in result.metadata - attributes = result.metadata["attributes"] - assert len(attributes) == 2 - - # Verify first result has all expected attributes - attrs1 = attributes[0] - assert attrs1["document_id"] == "ml-intro" - assert attrs1["source_url"] == "https://example.com/ml-guide" - assert attrs1["title"] == "Introduction to ML" - assert 
attrs1["author"] == "John Doe" - assert attrs1["year"] == "2024" - - # Verify second result has its attributes - attrs2 = attributes[1] - assert attrs2["document_id"] == "dl-basics" - assert attrs2["source_url"] == "https://example.com/dl-guide" - assert attrs2["title"] == "Deep Learning Basics" - assert attrs2["category"] == "tutorial" - - # Verify scores and document_ids are also present - assert result.metadata["scores"] == [0.95, 0.85] - assert result.metadata["document_ids"] == ["doc-123", "doc-456"] - assert result.metadata["chunks"] == [ - "Machine learning is a subset of AI", - "Deep learning uses neural networks", - ] - - -async def test_create_openai_response_with_max_output_tokens_non_streaming( - openai_responses_impl, mock_inference_api, mock_responses_store -): - """Test that max_output_tokens is properly handled in non-streaming responses.""" - input_text = "Write a long story about AI." - model = "meta-llama/Llama-3.1-8B-Instruct" - max_tokens = 100 - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - max_output_tokens=max_tokens, - stream=False, - store=True, - ) - - # Verify response includes the max_output_tokens - assert result.max_output_tokens == max_tokens - assert result.model == model - assert result.status == "completed" - - # Verify the max_output_tokens was passed to inference API - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - assert params.max_completion_tokens == max_tokens - - # Verify the max_output_tokens was stored - mock_responses_store.upsert_response_object.assert_called() - store_call_args = mock_responses_store.upsert_response_object.call_args - stored_response = store_call_args.kwargs["response_object"] - assert stored_response.max_output_tokens == max_tokens - - -async def 
test_create_openai_response_with_max_output_tokens_streaming( - openai_responses_impl, mock_inference_api, mock_responses_store -): - """Test that max_output_tokens is properly handled in streaming responses.""" - input_text = "Explain machine learning in detail." - model = "meta-llama/Llama-3.1-8B-Instruct" - max_tokens = 200 - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - max_output_tokens=max_tokens, - stream=True, - store=True, - ) - - # Collect all chunks - chunks = [chunk async for chunk in result] - - # Verify max_output_tokens is in the created event - created_event = chunks[0] - assert created_event.type == "response.created" - assert created_event.response.max_output_tokens == max_tokens - - # Verify max_output_tokens is in the completed event - completed_event = chunks[-1] - assert completed_event.type == "response.completed" - assert completed_event.response.max_output_tokens == max_tokens - - # Verify the max_output_tokens was passed to inference API - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - assert params.max_completion_tokens == max_tokens - - # Verify the max_output_tokens was stored - mock_responses_store.upsert_response_object.assert_called() - store_call_args = mock_responses_store.upsert_response_object.call_args - stored_response = store_call_args.kwargs["response_object"] - assert stored_response.max_output_tokens == max_tokens - - -async def test_create_openai_response_with_max_output_tokens_boundary_value(openai_responses_impl, mock_inference_api): - """Test that max_output_tokens accepts the minimum valid value of 16.""" - input_text = "Hi" - model = "meta-llama/Llama-3.1-8B-Instruct" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute with minimum valid 
value - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - max_output_tokens=16, - stream=False, - ) - - # Verify it accepts 16 - assert result.max_output_tokens == 16 - assert result.status == "completed" - - # Verify the inference API was called with max_completion_tokens=16 - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - assert params.max_completion_tokens == 16 - - -async def test_create_openai_response_with_max_output_tokens_and_tools(openai_responses_impl, mock_inference_api): - """Test that max_output_tokens works correctly with tool calls.""" - input_text = "What's the weather in San Francisco?" - model = "meta-llama/Llama-3.1-8B-Instruct" - max_tokens = 150 - - openai_responses_impl.tool_groups_api.get_tool.return_value = ToolDef( - name="get_weather", - toolgroup_id="weather", - description="Get weather information", - input_schema={ - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }, - ) - - openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult( - status="completed", - content="Sunny, 72°F", - ) - - # Mock two inference calls: one for tool call, one for final response - mock_inference_api.openai_chat_completion.side_effect = [ - fake_stream("tool_call_completion.yaml"), - fake_stream(), - ] - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - max_output_tokens=max_tokens, - stream=False, - tools=[ - OpenAIResponseInputToolFunction( - name="get_weather", - description="Get weather information", - parameters={"location": "string"}, - ) - ], - ) - - # Verify max_output_tokens is preserved - assert result.max_output_tokens == max_tokens - assert result.status == "completed" - - # Verify both inference calls received max_completion_tokens - assert 
mock_inference_api.openai_chat_completion.call_count == 2 - for call in mock_inference_api.openai_chat_completion.call_args_list: - params = call.args[0] - # The first call gets the full max_tokens, subsequent calls get remaining tokens - assert params.max_completion_tokens is not None - assert params.max_completion_tokens <= max_tokens - - -@pytest.mark.parametrize("store", [False, True]) -@pytest.mark.parametrize("stream", [False, True]) -@pytest.mark.parametrize( - "param_name,param_value,backend_param_name,backend_expected_value,response_expected_value,stored_expected_value", - [ - ("temperature", 1.5, "temperature", 1.5, 1.5, 1.5), - ("safety_identifier", "user-123", "safety_identifier", "user-123", "user-123", "user-123"), - ("max_output_tokens", 500, "max_completion_tokens", 500, 500, 500), - ( - "prompt_cache_key", - "geography-cache-001", - "prompt_cache_key", - "geography-cache-001", - "geography-cache-001", - "geography-cache-001", - ), - ("service_tier", ServiceTier.flex, "service_tier", "flex", "flex", ServiceTier.default.value), - ("top_p", 0.9, "top_p", 0.9, 0.9, 0.9), - ("frequency_penalty", 0.5, "frequency_penalty", 0.5, 0.5, 0.5), - ("presence_penalty", 0.3, "presence_penalty", 0.3, 0.3, 0.3), - ("top_logprobs", 5, "top_logprobs", 5, 5, 5), - ( - "extra_body", - {"chat_template_kwargs": {"thinking": True}}, - "extra_body", - {"chat_template_kwargs": {"thinking": True}}, - None, - None, - ), - ], -) -async def test_params_passed_through_full_chain_to_backend_service( - param_name, - param_value, - backend_param_name, - backend_expected_value, - response_expected_value, - stored_expected_value, - stream, - store, - mock_responses_store, -): - """Test that parameters which pass through to the backend service are correctly propagated. - - Only parameters that are forwarded as kwargs to the underlying chat completions API belong - here. Parameters handled internally by the responses layer (e.g. 
truncation) should be - tested separately since they don't produce a backend kwarg assertion. - - This test should not act differently based on the param_name/param_value/etc. Needing changes - in behavior based on those params suggests a bug in the implementation. - - This test may act differently based on : - - stream: whether the response is streamed or not - - store: whether the response is persisted via the responses store - """ - config = OpenAIConfig(api_key="test-key") - openai_adapter = OpenAIInferenceAdapter(config=config) - openai_adapter.provider_data_api_key_field = None - - mock_model_store = AsyncMock() - mock_model_store.has_model = AsyncMock(return_value=False) - openai_adapter.model_store = mock_model_store - - openai_responses_impl = OpenAIResponsesImpl( - inference_api=openai_adapter, - tool_groups_api=AsyncMock(), - tool_runtime_api=AsyncMock(), - responses_store=mock_responses_store, - vector_io_api=AsyncMock(), - safety_api=AsyncMock(), - conversations_api=AsyncMock(), - prompts_api=AsyncMock(), - files_api=AsyncMock(), - connectors_api=AsyncMock(), - ) - - with patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class: - mock_client = MagicMock() - mock_chat_completions = AsyncMock() - mock_client.chat.completions.create = mock_chat_completions - mock_openai_class.return_value = mock_client - - if stream: - mock_chat_completions.return_value = fake_stream() - else: - mock_response = MagicMock() - mock_response.id = "chatcmpl-123" - mock_response.choices = [ - MagicMock( - index=0, - message=MagicMock(content="Test response", role="assistant", tool_calls=None), - finish_reason="stop", - ) - ] - mock_response.model = "fake-model" - mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) - mock_chat_completions.return_value = mock_response - - result = await openai_responses_impl.create_openai_response( - **{ - "input": "Test message", - "model": "fake-model", - "stream": 
stream, - "store": store, - param_name: param_value, - } - ) - if stream: - chunks = [chunk async for chunk in result] - created_event = chunks[0] - assert created_event.type == "response.created" - assert getattr(created_event.response, param_name, None) == response_expected_value, ( - f"Expected created {param_name}={response_expected_value}, got {getattr(created_event.response, param_name, None)}" - ) - completed_event = chunks[-1] - assert completed_event.type == "response.completed" - assert getattr(completed_event.response, param_name, None) == stored_expected_value, ( - f"Expected completed {param_name}={stored_expected_value}, got {getattr(completed_event.response, param_name, None)}" - ) - - mock_chat_completions.assert_called_once() - call_kwargs = mock_chat_completions.call_args[1] - - assert backend_param_name in call_kwargs, f"{backend_param_name} not found in backend call" - assert call_kwargs[backend_param_name] == backend_expected_value, ( - f"Expected {backend_param_name}={backend_expected_value}, got {call_kwargs[backend_param_name]}" - ) - - if store: - mock_responses_store.upsert_response_object.assert_called() - stored_response = mock_responses_store.upsert_response_object.call_args.kwargs["response_object"] - assert getattr(stored_response, param_name, None) == stored_expected_value, ( - f"Expected stored {param_name}={stored_expected_value}, got {getattr(stored_response, param_name, None)}" - ) - else: - mock_responses_store.upsert_response_object.assert_not_called() - - -async def test_function_tool_strict_field_excluded_when_none(openai_responses_impl, mock_inference_api): - """Test that function tool 'strict' field is excluded when None (fix for #4617).""" - input_text = "What is the weather?" 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - # Mock inference response - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute with function tool that has strict=None (default) - await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=False, - tools=[ - OpenAIResponseInputToolFunction( - type="function", - name="get_weather", - description="Get weather information", - parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, - # strict is None by default - ) - ], - ) - - # Verify the call was made - assert mock_inference_api.openai_chat_completion.call_count == 1 - params = mock_inference_api.openai_chat_completion.call_args[0][0] - - # Verify tools were passed - assert params.tools is not None - assert len(params.tools) == 1 - - # Critical: verify 'strict' field is NOT present when it's None - # This prevents "strict: null" from being sent to OpenAI API - tool_function = params.tools[0]["function"] - assert "strict" not in tool_function, ( - "strict field should be excluded when None to avoid OpenAI API validation error" - ) - - # Verify other fields are present - assert tool_function["name"] == "get_weather" - assert tool_function["description"] == "Get weather information" - assert tool_function["parameters"] is not None - - -async def test_function_tool_strict_field_included_when_set(openai_responses_impl, mock_inference_api): - """Test that function tool 'strict' field is included when explicitly set.""" - input_text = "What is the weather?" 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - # Mock inference response - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute with function tool that has strict=True - await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=False, - tools=[ - OpenAIResponseInputToolFunction( - type="function", - name="get_weather", - description="Get weather information", - parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, - strict=True, # Explicitly set to True - ) - ], - ) - - # Verify the call was made - assert mock_inference_api.openai_chat_completion.call_count == 1 - params = mock_inference_api.openai_chat_completion.call_args[0][0] - - # Verify tools were passed - assert params.tools is not None - assert len(params.tools) == 1 - - # Verify 'strict' field IS present when explicitly set - tool_function = params.tools[0]["function"] - assert "strict" in tool_function, "strict field should be included when explicitly set" - assert tool_function["strict"] is True, "strict field should have the correct value" - - # Verify other fields are present - assert tool_function["name"] == "get_weather" - assert tool_function["description"] == "Get weather information" - - -async def test_function_tool_strict_false_included(openai_responses_impl, mock_inference_api): - """Test that function tool 'strict' field is included when set to False.""" - input_text = "What is the weather?" 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - # Mock inference response - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute with function tool that has strict=False - await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=False, - tools=[ - OpenAIResponseInputToolFunction( - type="function", - name="get_weather", - description="Get weather information", - parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, - strict=False, # Explicitly set to False - ) - ], - ) - - # Verify the call was made - assert mock_inference_api.openai_chat_completion.call_count == 1 - params = mock_inference_api.openai_chat_completion.call_args[0][0] - - # Verify 'strict' field IS present and set to False - tool_function = params.tools[0]["function"] - assert "strict" in tool_function, "strict field should be included when explicitly set to False" - assert tool_function["strict"] is False, "strict field should be False" - - -async def test_create_openai_response_with_truncation_disabled_streaming( - openai_responses_impl, mock_inference_api, mock_responses_store -): - """Test that truncation='disabled' is properly handled in streaming responses.""" - input_text = "Explain machine learning comprehensively." 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - truncation=ResponseTruncation.disabled, - stream=True, - store=True, - ) - - # Collect all chunks - chunks = [chunk async for chunk in result] - - # Verify truncation is in the created event - created_event = chunks[0] - assert created_event.type == "response.created" - assert created_event.response.truncation == ResponseTruncation.disabled - - # Verify truncation is in the completed event - completed_event = chunks[-1] - assert completed_event.type == "response.completed" - assert completed_event.response.truncation == ResponseTruncation.disabled - - mock_inference_api.openai_chat_completion.assert_called() - - # Verify the truncation was stored - mock_responses_store.upsert_response_object.assert_called() - store_call_args = mock_responses_store.upsert_response_object.call_args - stored_response = store_call_args.kwargs["response_object"] - assert stored_response.truncation == ResponseTruncation.disabled - - -async def test_create_openai_response_with_truncation_auto_streaming( - openai_responses_impl, mock_inference_api, mock_responses_store -): - """Test that truncation='auto' raises an error since it's not yet supported.""" - input_text = "Tell me about quantum computing." 
- model = "meta-llama/Llama-3.1-8B-Instruct" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - truncation=ResponseTruncation.auto, - stream=True, - store=True, - ) - - # Collect all chunks - chunks = [chunk async for chunk in result] - - # Verify truncation is in the created event - created_event = chunks[0] - assert created_event.type == "response.created" - assert created_event.response.truncation == ResponseTruncation.auto - - # Verify the response failed due to unsupported truncation mode - failed_event = chunks[-1] - assert failed_event.type == "response.failed" - assert failed_event.response.truncation == ResponseTruncation.auto - assert failed_event.response.error is not None - assert failed_event.response.error.code == "server_error" - assert "Truncation mode 'auto' is not supported" in failed_event.response.error.message - - # Inference API should not be called since error occurs before inference - mock_inference_api.openai_chat_completion.assert_not_called() - - # Verify the failed response was stored - mock_responses_store.upsert_response_object.assert_called() - store_call_args = mock_responses_store.upsert_response_object.call_args - stored_response = store_call_args.kwargs["response_object"] - assert stored_response.truncation == ResponseTruncation.auto - assert stored_response.status == "failed" - - -async def test_create_openai_response_with_prompt_cache_key_and_previous_response( - openai_responses_impl, mock_responses_store, mock_inference_api -): - """Test that prompt_cache_key works correctly with previous_response_id.""" - # Setup previous response - previous_response = _OpenAIResponseObjectWithInputAndMessages( - id="resp-prev-123", - object="response", - created_at=1234567890, - model="meta-llama/Llama-3.1-8B-Instruct", - status="completed", - 
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), - input=[OpenAIResponseMessage(id="msg-1", role="user", content="First question")], - output=[OpenAIResponseMessage(id="msg-2", role="assistant", content="First answer")], - messages=[ - OpenAIUserMessageParam(content="First question"), - OpenAIAssistantMessageParam(content="First answer"), - ], - prompt_cache_key="conversation-cache-001", - store=True, - ) - - mock_responses_store.get_response_object.return_value = previous_response - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Create a new response with the same cache key - result = await openai_responses_impl.create_openai_response( - input="Second question", - model="meta-llama/Llama-3.1-8B-Instruct", - previous_response_id="resp-prev-123", - prompt_cache_key="conversation-cache-001", - store=True, - ) - - # Verify cache key is preserved - assert result.prompt_cache_key == "conversation-cache-001" - assert result.status == "completed" - - # Verify the cache key was stored - mock_responses_store.upsert_response_object.assert_called() - store_call_args = mock_responses_store.upsert_response_object.call_args - stored_response = store_call_args.kwargs["response_object"] - assert stored_response.prompt_cache_key == "conversation-cache-001" - - -async def test_create_openai_response_with_service_tier(openai_responses_impl, mock_inference_api): - """Test creating an OpenAI response with service_tier parameter.""" - # Setup - input_text = "What is the capital of France?" 
- model = "meta-llama/Llama-3.1-8B-Instruct" - service_tier = ServiceTier.flex - - # Load the chat completion fixture - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - non-streaming to get final response directly - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - service_tier=service_tier, - stream=False, - ) - - # Verify service_tier is preserved in the response (as string) - assert result.service_tier == ServiceTier.default.value - assert result.status == "completed" - - # Verify inference call received service_tier - mock_inference_api.openai_chat_completion.assert_called_once() - params = mock_inference_api.openai_chat_completion.call_args.args[0] - assert params.service_tier == service_tier - - -async def test_create_openai_response_service_tier_auto_transformation(openai_responses_impl, mock_inference_api): - """Test that service_tier 'auto' is transformed to actual tier from provider response.""" - # Setup - input_text = "Hello" - model = "meta-llama/Llama-3.1-8B-Instruct" - - # Mock a response that returns actual service tier when "auto" was requested - async def fake_stream_with_service_tier(): - yield ChatCompletionChunk( - id="chatcmpl-123", - choices=[ - Choice( - index=0, - delta=ChoiceDelta(content="Hi there!", role="assistant"), - finish_reason="stop", - ) - ], - created=1234567890, - model=model, - object="chat.completion.chunk", - service_tier="default", # Provider returns actual tier used - ) - - mock_inference_api.openai_chat_completion.return_value = fake_stream_with_service_tier() - - # Execute with "auto" service tier - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - service_tier=ServiceTier.auto, - stream=False, - ) - - # Verify the response has the actual tier from provider, not "auto" - assert result.service_tier == "default", "service_tier should be transformed from 'auto' to actual tier" - assert 
result.service_tier != ServiceTier.auto.value, "service_tier should not remain as 'auto'" - assert result.status == "completed" - - # Verify inference was called with "auto" - mock_inference_api.openai_chat_completion.assert_called_once() - params = mock_inference_api.openai_chat_completion.call_args.args[0] - assert params.service_tier == "auto" - - -async def test_create_openai_response_service_tier_propagation_streaming(openai_responses_impl, mock_inference_api): - """Test that service_tier from chat completion is propagated to response object in streaming mode.""" - # Setup - input_text = "Tell me about AI" - model = "meta-llama/Llama-3.1-8B-Instruct" - - # Mock streaming response with service_tier - async def fake_stream_with_service_tier(): - yield ChatCompletionChunk( - id="chatcmpl-456", - choices=[ - Choice( - index=0, - delta=ChoiceDelta(content="AI is", role="assistant"), - finish_reason=None, - ) - ], - created=1234567890, - model=model, - object="chat.completion.chunk", - service_tier="priority", # First chunk with service_tier - ) - yield ChatCompletionChunk( - id="chatcmpl-456", - choices=[ - Choice( - index=0, - delta=ChoiceDelta(content=" amazing!"), - finish_reason="stop", - ) - ], - created=1234567890, - model=model, - object="chat.completion.chunk", - ) - - mock_inference_api.openai_chat_completion.return_value = fake_stream_with_service_tier() - - # Execute with "auto" but provider returns "priority" - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - service_tier=ServiceTier.auto, - stream=True, - ) - - # Collect all chunks - chunks = [chunk async for chunk in result] - # Verify service_tier is propagated to all events - created_event = chunks[0] - assert created_event.type == "response.created" - # Initially should have "auto" value - assert created_event.response.service_tier == "auto" - - # Check final response has the actual tier from provider - completed_event = chunks[-1] - assert 
completed_event.type == "response.completed" - assert completed_event.response.service_tier == "priority", "Final response should have actual tier from provider" - - -def test_response_object_incomplete_details_null_when_completed(): - """Test that completed response has incomplete_details as null.""" - from llama_stack_api.openai_responses import OpenAIResponseObject - - response = OpenAIResponseObject( - created_at=1234567890, - id="resp_123", - model="gpt-4o", - object="response", - output=[], - status="completed", - store=False, - ) - - assert response.incomplete_details is None - - # Verify JSON serialization - json_data = response.model_dump(mode="json") - assert json_data["incomplete_details"] is None - - -def test_response_object_incomplete_details_with_max_output_tokens_reason(): - """Test that incomplete response has incomplete_details with max_output_tokens reason.""" - from llama_stack_api.openai_responses import OpenAIResponseIncompleteDetails, OpenAIResponseObject - - response = OpenAIResponseObject( - created_at=1234567890, - id="resp_456", - model="gpt-4o", - object="response", - output=[], - status="incomplete", - store=False, - incomplete_details=OpenAIResponseIncompleteDetails(reason="max_output_tokens"), - ) - - assert response.incomplete_details is not None - assert response.incomplete_details.reason == "max_output_tokens" - - # Verify JSON serialization - json_data = response.model_dump(mode="json") - assert json_data["incomplete_details"] == {"reason": "max_output_tokens"} - - -def test_response_object_incomplete_details_with_length_reason(): - """Test that incomplete response has incomplete_details with length reason.""" - from llama_stack_api.openai_responses import OpenAIResponseIncompleteDetails, OpenAIResponseObject - - response = OpenAIResponseObject( - created_at=1234567890, - id="resp_length", - model="gpt-4o", - object="response", - output=[], - status="incomplete", - store=False, - 
incomplete_details=OpenAIResponseIncompleteDetails(reason="length"), - ) - - assert response.incomplete_details is not None - assert response.incomplete_details.reason == "length" - - # Verify JSON serialization - json_data = response.model_dump(mode="json") - assert json_data["incomplete_details"] == {"reason": "length"} - - -def test_response_object_incomplete_details_with_max_iterations_exceeded_reason(): - """Test that incomplete response has incomplete_details with max_iterations_exceeded reason.""" - from llama_stack_api.openai_responses import OpenAIResponseIncompleteDetails, OpenAIResponseObject - - response = OpenAIResponseObject( - created_at=1234567890, - id="resp_iters", - model="gpt-4o", - object="response", - output=[], - status="incomplete", - store=False, - incomplete_details=OpenAIResponseIncompleteDetails(reason="max_iterations_exceeded"), - ) - - assert response.incomplete_details is not None - assert response.incomplete_details.reason == "max_iterations_exceeded" - - # Verify JSON serialization - json_data = response.model_dump(mode="json") - assert json_data["incomplete_details"] == {"reason": "max_iterations_exceeded"} - - -async def test_agent_loop_incomplete_due_to_max_output_tokens( - openai_responses_impl, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api -): - """Test that agent loop marks response incomplete when max_output_tokens is reached.""" - from openai.types.completion_usage import CompletionUsage - - model = "gpt-4o" - max_output_tokens = 25 # Set low enough to be exceeded by first tool call - - # Setup tool mocks - the tool call will trigger a second iteration - mock_tool_groups_api.get_tool.return_value = ToolDef( - name="web_search", description="Search the web", input_schema={} - ) - mock_tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(content="Search results") - - # First stream: returns a tool call after consuming 30 tokens (exceeds limit of 25) - async def first_stream_with_tool_call(): - 
yield ChatCompletionChunk( - id="test_123", - choices=[ - Choice( - index=0, - delta=ChoiceDelta( - role="assistant", - tool_calls=[ - ChoiceDeltaToolCall( - index=0, - id="call_abc", - function=ChoiceDeltaToolCallFunction(name="web_search", arguments='{"query":"test"}'), - ) - ], - ), - finish_reason="tool_calls", - ) - ], - created=1234567890, - model=model, - object="chat.completion.chunk", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=30, total_tokens=40), - ) - - # Second stream would consume more tokens, but should be skipped because we already have 30 tokens - # and the next iteration check will see we're at/above the limit - async def second_stream(): - yield ChatCompletionChunk( - id="test_456", - choices=[ - Choice( - index=0, - delta=ChoiceDelta(content="More content", role="assistant"), - finish_reason="stop", - ) - ], - created=1234567890, - model=model, - object="chat.completion.chunk", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=25, total_tokens=35), - ) - - # Mock returns first stream with tool call, then would return second stream but should be blocked - mock_inference_api.openai_chat_completion.side_effect = [first_stream_with_tool_call(), second_stream()] - - # Execute with max_output_tokens that will be exceeded after first tool call - result = await openai_responses_impl.create_openai_response( - input="Test input", - model=model, - max_output_tokens=max_output_tokens, - tools=[OpenAIResponseInputToolWebSearch(type="web_search")], - stream=True, - ) - - # Collect all events - events = [event async for event in result] - - # Find the final event (should be response.incomplete) - final_event = events[-1] - assert final_event.type == "response.incomplete" - assert final_event.response.status == "incomplete" - assert final_event.response.incomplete_details is not None - assert final_event.response.incomplete_details.reason == "max_output_tokens" - - -async def test_agent_loop_incomplete_due_to_max_iterations( - 
openai_responses_impl, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api -): - """Test that agent loop marks response incomplete when max iterations is exceeded via tool calls.""" - from openai.types.completion_usage import CompletionUsage - - model = "gpt-4o" - - # Setup tool mocks - mock_tool_groups_api.get_tool.return_value = ToolDef( - name="web_search", description="Search the web", input_schema={} - ) - mock_tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(content="Search results") - - # Create a stream generator factory that returns a tool call (to trigger another iteration) - call_counter = {"count": 0} - - def fake_stream_factory(): - async def fake_stream_with_tool_call(): - call_id = f"call_abc{call_counter['count']}" - call_counter["count"] += 1 - # First chunk with tool call - yield ChatCompletionChunk( - id="test_123", - choices=[ - Choice( - index=0, - delta=ChoiceDelta( - role="assistant", - tool_calls=[ - ChoiceDeltaToolCall( - index=0, - id=call_id, - function=ChoiceDeltaToolCallFunction( - name="web_search", arguments='{"query":"test"}' - ), - ) - ], - ), - finish_reason="tool_calls", - ) - ], - created=1234567890, - model=model, - object="chat.completion.chunk", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), - ) - - return fake_stream_with_tool_call() - - # Mock the inference to repeatedly return tool calls to exceed max iterations - # Default max_infer_iters is 5, so we need to keep returning tool calls - mock_inference_api.openai_chat_completion.side_effect = lambda *args, **kwargs: fake_stream_factory() - - # Execute with tool configuration - result = await openai_responses_impl.create_openai_response( - input="Test input", - model=model, - tools=[OpenAIResponseInputToolWebSearch(type="web_search")], - stream=True, - ) - - # Collect all events - events = [event async for event in result] - - # Find the final event (should be response.incomplete) - final_event = events[-1] - 
assert final_event.type == "response.incomplete" - assert final_event.response.status == "incomplete" - assert final_event.response.incomplete_details is not None - assert final_event.response.incomplete_details.reason == "max_iterations_exceeded" - - -async def test_agent_loop_incomplete_due_to_length_finish_reason(openai_responses_impl, mock_inference_api): - """Test that agent loop marks response incomplete when model returns finish_reason='length'.""" - from openai.types.completion_usage import CompletionUsage - - model = "gpt-4o" - - # Create a stream that returns finish_reason="length" - async def fake_stream_with_length_finish(): - yield ChatCompletionChunk( - id="test_123", - choices=[ - Choice( - index=0, - delta=ChoiceDelta(content="This is a response that was cut off due to", role="assistant"), - finish_reason=None, - ) - ], - created=1234567890, - model=model, - object="chat.completion.chunk", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30), - ) - # Final chunk with finish_reason="length" - yield ChatCompletionChunk( - id="test_123", - choices=[ - Choice( - index=0, - delta=ChoiceDelta(content=" length"), - finish_reason="length", # This indicates the response was truncated - ) - ], - created=1234567890, - model=model, - object="chat.completion.chunk", - usage=CompletionUsage(prompt_tokens=10, completion_tokens=25, total_tokens=35), - ) - - mock_inference_api.openai_chat_completion.return_value = fake_stream_with_length_finish() - - # Execute - result = await openai_responses_impl.create_openai_response( - input="Test input", - model=model, - stream=True, - ) - - # Collect all events - events = [event async for event in result] - - # Find the final event (should be response.incomplete) - final_event = events[-1] - assert final_event.type == "response.incomplete" - assert final_event.response.status == "incomplete" - assert final_event.response.incomplete_details is not None - assert 
final_event.response.incomplete_details.reason == "length" - - -async def test_create_openai_response_with_top_logprobs_boundary_values( - openai_responses_impl, mock_inference_api, mock_responses_store -): - """Test that top_logprobs works with boundary values (0 and 20).""" - input_text = "Test message" - model = "meta-llama/Llama-3.1-8B-Instruct" - - # Test with minimum value (0) - mock_inference_api.openai_chat_completion.return_value = fake_stream() - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - top_logprobs=0, - stream=False, - store=True, - ) - assert result.top_logprobs == 0 - - # Test with maximum value (20) - mock_inference_api.openai_chat_completion.return_value = fake_stream() - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - top_logprobs=20, - stream=False, - store=True, - ) - assert result.top_logprobs == 20 - - -async def test_create_openai_response_with_frequency_penalty_default(openai_responses_impl, mock_inference_api): - """Test that frequency_penalty defaults to 0.0 when not provided.""" - input_text = "Hello" - model = "meta-llama/Llama-3.1-8B-Instruct" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute without frequency_penalty - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=False, - ) - - # Verify response has 0.0 for frequency_penalty (non-null default for OpenResponses conformance) - assert result.frequency_penalty == 0.0 - - # Verify inference API was called with None - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - assert params.frequency_penalty is None - - -async def test_create_openai_response_with_presence_penalty_default(openai_responses_impl, mock_inference_api): - """Test that presence_penalty defaults to 0.0 when not 
provided.""" - input_text = "Hi" - model = "meta-llama/Llama-3.1-8B-Instruct" - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute without presence_penalty - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream=False, - ) - - # Verify presence_penalty is 0.0 (non-null default for OpenResponses conformance) - assert result.presence_penalty == 0.0 - assert result.status == "completed" - - # Verify the inference API was called with presence_penalty=None - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - assert params.presence_penalty is None - - -async def test_hallucinated_tool_call_does_not_cause_500(openai_responses_impl, mock_inference_api): - """Regression test: a hallucinated tool name should not produce a 500 (InternalServerError). - - When the LLM calls a tool name that is not in the registered tools list the server - was raising ValueError from _coordinate_tool_execution which then propagated as an - InternalServerError (HTTP 500). The correct behaviour is to surface the unknown call - as a regular function-tool-call output so the client can respond, exactly as OpenAI - does for any function tool call. - """ - input_text = "What is the capital of Ireland?" - model = "meta-llama/Llama-3.1-8B-Instruct" - - async def fake_stream_hallucinated_tool(): - # The LLM calls "lookup_capital_city" which is NOT in the registered tools list. 
- yield ChatCompletionChunk( - id="hallucinated-123", - choices=[ - Choice( - index=0, - delta=ChoiceDelta( - tool_calls=[ - ChoiceDeltaToolCall( - index=0, - id="tc_hall_123", - function=ChoiceDeltaToolCallFunction( - name="lookup_capital_city", - arguments='{"country": "Ireland"}', - ), - type="function", - ) - ] - ), - ), - ], - created=1, - model=model, - object="chat.completion.chunk", - ) - - mock_inference_api.openai_chat_completion.return_value = fake_stream_hallucinated_tool() - - # The only registered tool is "get_weather". The LLM hallucinated "lookup_capital_city". - # The response should complete without raising InternalServerError, and the hallucinated - # call should appear in the output as a function_call item so the client can handle it. - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - tools=[ - OpenAIResponseInputToolFunction( - name="get_weather", - description="Get current temperature for a given location.", - parameters={ - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }, - ) - ], - ) - - assert result is not None - assert result.status == "completed" - assert len(result.output) == 1 - assert result.output[0].type == "function_call" - assert result.output[0].name == "lookup_capital_city" - - -async def test_create_openai_response_with_stream_options_merges_with_default( - openai_responses_impl, mock_inference_api -): - """Test that stream_options merges with default include_usage.""" - input_text = "Test stream options" - model = "meta-llama/Llama-3.1-8B-Instruct" - stream_options = ResponseStreamOptions(include_obfuscation=False) - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream_options=stream_options, - stream=True, - ) - - # Collect chunks (consume the async iterator) - _ = [chunk async for chunk 
in result] - - # Verify the stream_options was merged properly - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - assert params.stream_options is not None - # Should have both default include_usage and user's option - assert params.stream_options["include_usage"] is True - assert params.stream_options["include_obfuscation"] is False - - -async def test_create_openai_response_with_empty_stream_options(openai_responses_impl, mock_inference_api): - """Test that default stream_options still merges with default include_usage.""" - input_text = "Test empty options" - model = "meta-llama/Llama-3.1-8B-Instruct" - stream_options = ResponseStreamOptions() # Uses default include_obfuscation=True - - mock_inference_api.openai_chat_completion.return_value = fake_stream() - - # Execute - result = await openai_responses_impl.create_openai_response( - input=input_text, - model=model, - stream_options=stream_options, - stream=True, - ) - - # Collect chunks (consume the async iterator) - _ = [chunk async for chunk in result] - - # Verify the stream_options has both defaults - mock_inference_api.openai_chat_completion.assert_called() - call_args = mock_inference_api.openai_chat_completion.call_args - params = call_args.args[0] - assert params.stream_options is not None - assert params.stream_options["include_usage"] is True - assert params.stream_options["include_obfuscation"] is True diff --git a/tests/unit/providers/responses/builtin/test_openai_responses_agent.py b/tests/unit/providers/responses/builtin/test_openai_responses_agent.py new file mode 100644 index 0000000000..82d10ebc8d --- /dev/null +++ b/tests/unit/providers/responses/builtin/test_openai_responses_agent.py @@ -0,0 +1,439 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import patch + +import pytest +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) + +from llama_stack_api import ( + InternalServerError, + OpenAISystemMessageParam, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseError, + OpenAIResponseInputToolWebSearch, + OpenAIResponseObject, + OpenAIResponseObjectStreamResponseFailed, +) +from llama_stack_api.tools import ToolDef, ToolInvocationResult + + +async def test_failed_stream_persists_non_system_messages(openai_responses_impl, mock_responses_store): + input_text = "Hello" + model = "meta-llama/Llama-3.1-8B-Instruct" + + failed_response = OpenAIResponseObject( + created_at=1, + id="resp_failed", + model=model, + output=[], + status="failed", + error=OpenAIResponseError(code="server_error", message="boom"), + store=True, + ) + + class FakeOrchestrator: + def __init__(self, *, ctx, **_kwargs): + self.ctx = ctx + self.final_messages = None + + async def create_response(self): + yield OpenAIResponseObjectStreamResponseFailed(response=failed_response, sequence_number=0) + + with patch( + "llama_stack.providers.inline.responses.builtin.responses.openai_responses.StreamingResponseOrchestrator", + FakeOrchestrator, + ): + stream = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + instructions="system instructions", + stream=True, + store=True, + ) + chunks = [chunk async for chunk in stream] + + assert chunks[-1].type == "response.failed" + mock_responses_store.upsert_response_object.assert_awaited() + + # Find the call that corresponds to the failed response + call_args_list = mock_responses_store.upsert_response_object.call_args_list + failed_call = None + for call in call_args_list: + _, kwargs = call + if 
kwargs.get("response_object") and kwargs["response_object"].status == "failed": + failed_call = call + break + + assert failed_call is not None, "Expected upsert_response_object to be called with failed response" + _, kwargs = failed_call + messages = kwargs["messages"] + assert messages, "Expected non-system messages to be persisted on failure" + assert all(not isinstance(m, OpenAISystemMessageParam) for m in messages) + assert any(getattr(m, "role", None) == "user" for m in messages) + + +async def test_failed_stream_raises_internal_server_error_in_non_streaming_mode(openai_responses_impl): + """Test that a response.failed event in non-streaming mode raises InternalServerError. + + When stream=False, the caller expects a fully resolved response object, not a stream. + If the underlying stream emits a response.failed event, the implementation must raise + InternalServerError so the caller gets a typed, predictable error rather than a raw + RuntimeError or ValueError. + + Unlike other InternalServerError sites in this file (which guard against internal bugs), + response.failed carries a structured, curated message from the inference backend that + may be directly actionable by the caller (e.g. context window exceeded, invalid prompt). + The message is surfaced to maintain consistency with streaming mode, where the same + response.failed event is returned directly to the caller with the error message visible. 
+ """ + model = "meta-llama/Llama-3.1-8B-Instruct" + provider_error_message = "This model's maximum context length is 4096 tokens" + + failed_response = OpenAIResponseObject( + created_at=1, + id="resp_failed_nonstream", + model=model, + output=[], + status="failed", + error=OpenAIResponseError(code="server_error", message=provider_error_message), + store=False, + ) + + class FakeOrchestrator: + def __init__(self, *, ctx, **_kwargs): + self.ctx = ctx + self.final_messages = None + + async def create_response(self): + yield OpenAIResponseObjectStreamResponseFailed(response=failed_response, sequence_number=0) + + with patch( + "llama_stack.providers.inline.responses.builtin.responses.openai_responses.StreamingResponseOrchestrator", + FakeOrchestrator, + ): + with pytest.raises(InternalServerError) as exc_info: + await openai_responses_impl.create_openai_response( + input="Hello", + model=model, + stream=False, + store=False, + ) + + # The provider message is surfaced to the caller: response.failed errors are + # structured and may be actionable (e.g. context window, invalid prompt). + # This is consistent with streaming mode where the same message is visible. 
+ assert provider_error_message in str(exc_info.value) + + +def test_response_object_incomplete_details_null_when_completed(): + """Test that completed response has incomplete_details as null.""" + from llama_stack_api.openai_responses import OpenAIResponseObject + + response = OpenAIResponseObject( + created_at=1234567890, + id="resp_123", + model="gpt-4o", + object="response", + output=[], + status="completed", + store=False, + ) + + assert response.incomplete_details is None + + # Verify JSON serialization + json_data = response.model_dump(mode="json") + assert json_data["incomplete_details"] is None + + +def test_response_object_incomplete_details_with_max_output_tokens_reason(): + """Test that incomplete response has incomplete_details with max_output_tokens reason.""" + from llama_stack_api.openai_responses import OpenAIResponseIncompleteDetails, OpenAIResponseObject + + response = OpenAIResponseObject( + created_at=1234567890, + id="resp_456", + model="gpt-4o", + object="response", + output=[], + status="incomplete", + store=False, + incomplete_details=OpenAIResponseIncompleteDetails(reason="max_output_tokens"), + ) + + assert response.incomplete_details is not None + assert response.incomplete_details.reason == "max_output_tokens" + + # Verify JSON serialization + json_data = response.model_dump(mode="json") + assert json_data["incomplete_details"] == {"reason": "max_output_tokens"} + + +def test_response_object_incomplete_details_with_length_reason(): + """Test that incomplete response has incomplete_details with length reason.""" + from llama_stack_api.openai_responses import OpenAIResponseIncompleteDetails, OpenAIResponseObject + + response = OpenAIResponseObject( + created_at=1234567890, + id="resp_length", + model="gpt-4o", + object="response", + output=[], + status="incomplete", + store=False, + incomplete_details=OpenAIResponseIncompleteDetails(reason="length"), + ) + + assert response.incomplete_details is not None + assert 
response.incomplete_details.reason == "length" + + # Verify JSON serialization + json_data = response.model_dump(mode="json") + assert json_data["incomplete_details"] == {"reason": "length"} + + +def test_response_object_incomplete_details_with_max_iterations_exceeded_reason(): + """Test that incomplete response has incomplete_details with max_iterations_exceeded reason.""" + from llama_stack_api.openai_responses import OpenAIResponseIncompleteDetails, OpenAIResponseObject + + response = OpenAIResponseObject( + created_at=1234567890, + id="resp_iters", + model="gpt-4o", + object="response", + output=[], + status="incomplete", + store=False, + incomplete_details=OpenAIResponseIncompleteDetails(reason="max_iterations_exceeded"), + ) + + assert response.incomplete_details is not None + assert response.incomplete_details.reason == "max_iterations_exceeded" + + # Verify JSON serialization + json_data = response.model_dump(mode="json") + assert json_data["incomplete_details"] == {"reason": "max_iterations_exceeded"} + + +async def test_agent_loop_incomplete_due_to_max_output_tokens( + openai_responses_impl, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api +): + """Test that agent loop marks response incomplete when max_output_tokens is reached.""" + from openai.types.completion_usage import CompletionUsage + + model = "gpt-4o" + max_output_tokens = 25 # Set low enough to be exceeded by first tool call + + # Setup tool mocks - the tool call will trigger a second iteration + mock_tool_groups_api.get_tool.return_value = ToolDef( + name="web_search", description="Search the web", input_schema={} + ) + mock_tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(content="Search results") + + # First stream: returns a tool call after consuming 30 tokens (exceeds limit of 25) + async def first_stream_with_tool_call(): + yield ChatCompletionChunk( + id="test_123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + 
ChoiceDeltaToolCall( + index=0, + id="call_abc", + function=ChoiceDeltaToolCallFunction(name="web_search", arguments='{"query":"test"}'), + ) + ], + ), + finish_reason="tool_calls", + ) + ], + created=1234567890, + model=model, + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=30, total_tokens=40), + ) + + # Second stream would consume more tokens, but should be skipped because we already have 30 tokens + # and the next iteration check will see we're at/above the limit + async def second_stream(): + yield ChatCompletionChunk( + id="test_456", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(content="More content", role="assistant"), + finish_reason="stop", + ) + ], + created=1234567890, + model=model, + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=25, total_tokens=35), + ) + + # Mock returns first stream with tool call, then would return second stream but should be blocked + mock_inference_api.openai_chat_completion.side_effect = [first_stream_with_tool_call(), second_stream()] + + # Execute with max_output_tokens that will be exceeded after first tool call + result = await openai_responses_impl.create_openai_response( + input="Test input", + model=model, + max_output_tokens=max_output_tokens, + tools=[OpenAIResponseInputToolWebSearch(type="web_search")], + stream=True, + ) + + # Collect all events + events = [event async for event in result] + + # Find the final event (should be response.incomplete) + final_event = events[-1] + assert final_event.type == "response.incomplete" + assert final_event.response.status == "incomplete" + assert final_event.response.incomplete_details is not None + assert final_event.response.incomplete_details.reason == "max_output_tokens" + + +async def test_agent_loop_incomplete_due_to_max_iterations( + openai_responses_impl, mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api +): + """Test that agent loop marks response incomplete 
when max iterations is exceeded via tool calls.""" + from openai.types.completion_usage import CompletionUsage + + model = "gpt-4o" + + # Setup tool mocks + mock_tool_groups_api.get_tool.return_value = ToolDef( + name="web_search", description="Search the web", input_schema={} + ) + mock_tool_runtime_api.invoke_tool.return_value = ToolInvocationResult(content="Search results") + + # Create a stream generator factory that returns a tool call (to trigger another iteration) + call_counter = {"count": 0} + + def fake_stream_factory(): + async def fake_stream_with_tool_call(): + call_id = f"call_abc{call_counter['count']}" + call_counter["count"] += 1 + # First chunk with tool call + yield ChatCompletionChunk( + id="test_123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + role="assistant", + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id=call_id, + function=ChoiceDeltaToolCallFunction( + name="web_search", arguments='{"query":"test"}' + ), + ) + ], + ), + finish_reason="tool_calls", + ) + ], + created=1234567890, + model=model, + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + return fake_stream_with_tool_call() + + # Mock the inference to repeatedly return tool calls to exceed max iterations + # Default max_infer_iters is 5, so we need to keep returning tool calls + mock_inference_api.openai_chat_completion.side_effect = lambda *args, **kwargs: fake_stream_factory() + + # Execute with tool configuration + result = await openai_responses_impl.create_openai_response( + input="Test input", + model=model, + tools=[OpenAIResponseInputToolWebSearch(type="web_search")], + stream=True, + ) + + # Collect all events + events = [event async for event in result] + + # Find the final event (should be response.incomplete) + final_event = events[-1] + assert final_event.type == "response.incomplete" + assert final_event.response.status == "incomplete" + assert 
final_event.response.incomplete_details is not None + assert final_event.response.incomplete_details.reason == "max_iterations_exceeded" + + +async def test_agent_loop_incomplete_due_to_length_finish_reason(openai_responses_impl, mock_inference_api): + """Test that agent loop marks response incomplete when model returns finish_reason='length'.""" + from openai.types.completion_usage import CompletionUsage + + model = "gpt-4o" + + # Create a stream that returns finish_reason="length" + async def fake_stream_with_length_finish(): + yield ChatCompletionChunk( + id="test_123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(content="This is a response that was cut off due to", role="assistant"), + finish_reason=None, + ) + ], + created=1234567890, + model=model, + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=20, total_tokens=30), + ) + # Final chunk with finish_reason="length" + yield ChatCompletionChunk( + id="test_123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(content=" length"), + finish_reason="length", # This indicates the response was truncated + ) + ], + created=1234567890, + model=model, + object="chat.completion.chunk", + usage=CompletionUsage(prompt_tokens=10, completion_tokens=25, total_tokens=35), + ) + + mock_inference_api.openai_chat_completion.return_value = fake_stream_with_length_finish() + + # Execute + result = await openai_responses_impl.create_openai_response( + input="Test input", + model=model, + stream=True, + ) + + # Collect all events + events = [event async for event in result] + + # Find the final event (should be response.incomplete) + final_event = events[-1] + assert final_event.type == "response.incomplete" + assert final_event.response.status == "incomplete" + assert final_event.response.incomplete_details is not None + assert final_event.response.incomplete_details.reason == "length" diff --git a/tests/unit/providers/responses/builtin/test_openai_responses_core.py 
b/tests/unit/providers/responses/builtin/test_openai_responses_core.py new file mode 100644 index 0000000000..38942b3175 --- /dev/null +++ b/tests/unit/providers/responses/builtin/test_openai_responses_core.py @@ -0,0 +1,810 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.core.access_control.access_control import default_policy +from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig +from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends +from llama_stack.providers.utils.responses.responses_store import ( + ResponsesStore, + _OpenAIResponseObjectWithInputAndMessages, +) +from llama_stack_api import ( + Order, +) +from llama_stack_api.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIDeveloperMessageParam, + OpenAIJSONSchema, + OpenAIResponseFormatJSONObject, + OpenAIResponseFormatJSONSchema, + OpenAIUserMessageParam, +) +from llama_stack_api.openai_responses import ( + ListOpenAIResponseInputItem, + OpenAIResponseInputMessageContentText, + OpenAIResponseMessage, + OpenAIResponseOutputMessageContentOutputText, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPCall, + OpenAIResponseOutputMessageWebSearchToolCall, + OpenAIResponseText, + OpenAIResponseTextFormat, +) +from tests.unit.providers.responses.builtin.test_openai_responses_helpers import fake_stream + + +async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with a simple string input.""" + # Setup + input_text = "What is the capital of Ireland?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + # Load the chat completion fixture + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + temperature=0.1, + stream=True, # Enable streaming to test content part events + ) + + # For streaming response, collect all chunks + chunks = [chunk async for chunk in result] + + mock_inference_api.openai_chat_completion.assert_called_once_with( + OpenAIChatCompletionRequestWithExtraBody( + model=model, + messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)], + response_format=None, + tools=None, + stream=True, + temperature=0.1, + stream_options={ + "include_usage": True, + }, + ) + ) + + # Should have content part events for text streaming + # Expected: response.created, response.in_progress, content_part.added, output_text.delta, content_part.done, response.completed + assert len(chunks) >= 5 + assert chunks[0].type == "response.created" + assert any(chunk.type == "response.in_progress" for chunk in chunks) + + # Check for content part events + content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"] + content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"] + text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"] + + assert len(content_part_added_events) >= 1, "Should have content_part.added event for text" + assert len(content_part_done_events) >= 1, "Should have content_part.done event for text" + assert len(text_delta_events) >= 1, "Should have text delta events" + + added_event = content_part_added_events[0] + done_event = content_part_done_events[0] + assert added_event.content_index == 0 + assert done_event.content_index == 0 + assert added_event.output_index == done_event.output_index == 0 + assert added_event.item_id == done_event.item_id + assert 
added_event.response_id == done_event.response_id
+
+    # Verify final event is completion
+    assert chunks[-1].type == "response.completed"
+
+    # When streaming, the final response is in the last chunk
+    final_response = chunks[-1].response
+    assert final_response.model == model
+    assert len(final_response.output) == 1
+    assert isinstance(final_response.output[0], OpenAIResponseMessage)
+
+
+async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api, mock_files_api):
+    """Test creating an OpenAI response with multiple messages."""
+    # Setup
+    input_messages = [
+        OpenAIResponseMessage(role="developer", content="You are a helpful assistant", name=None),
+        OpenAIResponseMessage(role="user", content="Name some towns in Ireland", name=None),
+        OpenAIResponseMessage(
+            role="assistant",
+            content=[
+                OpenAIResponseInputMessageContentText(text="Galway, Longford, Sligo"),
+                OpenAIResponseInputMessageContentText(text="Dublin"),
+            ],
+            name=None,
+        ),
+        OpenAIResponseMessage(role="user", content="Which is the largest town in Ireland?", name=None),
+    ]
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+
+    mock_inference_api.openai_chat_completion.return_value = fake_stream()
+
+    # Execute
+    await openai_responses_impl.create_openai_response(
+        input=input_messages,
+        model=model,
+        temperature=0.1,
+    )
+
+    # Verify that the correct messages were sent to the inference API, i.e.
+    # All of the response messages were converted to the chat completion message objects
+    call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
+    params = call_args.args[0]
+    inference_messages = params.messages
+    for i, m in enumerate(input_messages):
+        if isinstance(m.content, str):
+            assert inference_messages[i].content == m.content
+        else:
+            assert inference_messages[i].content[0].text == m.content[0].text
+            assert isinstance(inference_messages[i].content[0], OpenAIChatCompletionContentPartTextParam)
+        assert inference_messages[i].role == m.role
+        if m.role == "user":
+            assert isinstance(inference_messages[i], OpenAIUserMessageParam)
+        elif m.role == "assistant":
+            assert isinstance(inference_messages[i], OpenAIAssistantMessageParam)
+        else:
+            assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
+
+
+async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
+    """Test prepending a basic previous response to a new response."""
+
+    input_item_message = OpenAIResponseMessage(
+        id="123",
+        content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")],
+        role="user",
+    )
+    response_output_message = OpenAIResponseMessage(
+        id="123",
+        content=[OpenAIResponseOutputMessageContentOutputText(text="fake_response")],
+        status="completed",
+        role="assistant",
+    )
+    previous_response = _OpenAIResponseObjectWithInputAndMessages(
+        created_at=1,
+        id="resp_123",
+        model="fake_model",
+        output=[response_output_message],
+        status="completed",
+        text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
+        input=[input_item_message],
+        messages=[OpenAIUserMessageParam(content="fake_previous_input")],
+        store=True,
+    )
+    mock_responses_store.get_response_object.return_value = previous_response
+
+    input = await openai_responses_impl._prepend_previous_response("fake_input", previous_response)
+
+    assert len(input) == 3
+    # Check for previous input
+    assert isinstance(input[0],
OpenAIResponseMessage) + assert input[0].content[0].text == "fake_previous_input" + # Check for previous output + assert isinstance(input[1], OpenAIResponseMessage) + assert input[1].content[0].text == "fake_response" + # Check for new input + assert isinstance(input[2], OpenAIResponseMessage) + assert input[2].content == "fake_input" + + +async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store): + """Test prepending a web search previous response to a new response.""" + input_item_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], + role="user", + ) + output_web_search = OpenAIResponseOutputMessageWebSearchToolCall( + id="ws_123", + status="completed", + ) + output_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseOutputMessageContentOutputText(text="fake_web_search_response")], + status="completed", + role="assistant", + ) + response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[output_web_search, output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[OpenAIUserMessageParam(content="test input")], + store=True, + ) + mock_responses_store.get_response_object.return_value = response + + input_messages = [OpenAIResponseMessage(content="fake_input", role="user")] + input = await openai_responses_impl._prepend_previous_response(input_messages, response) + + assert len(input) == 4 + # Check for previous input + assert isinstance(input[0], OpenAIResponseMessage) + assert input[0].content[0].text == "fake_previous_input" + # Check for previous output web search tool call + assert isinstance(input[1], OpenAIResponseOutputMessageWebSearchToolCall) + # Check for previous output web search response + assert isinstance(input[2], OpenAIResponseMessage) + assert input[2].content[0].text 
== "fake_web_search_response" + # Check for new input + assert isinstance(input[3], OpenAIResponseMessage) + assert input[3].content == "fake_input" + + +async def test_prepend_previous_response_mcp_tool_call(openai_responses_impl, mock_responses_store): + """Test prepending a previous response which included an mcp tool call to a new response.""" + input_item_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseInputMessageContentText(text="fake_previous_input")], + role="user", + ) + output_tool_call = OpenAIResponseOutputMessageMCPCall( + id="ws_123", + name="fake-tool", + arguments="fake-arguments", + server_label="fake-label", + ) + output_message = OpenAIResponseMessage( + id="123", + content=[OpenAIResponseOutputMessageContentOutputText(text="fake_tool_call_response")], + status="completed", + role="assistant", + ) + response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[output_tool_call, output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[OpenAIUserMessageParam(content="test input")], + store=True, + ) + mock_responses_store.get_response_object.return_value = response + + input_messages = [OpenAIResponseMessage(content="fake_input", role="user")] + input = await openai_responses_impl._prepend_previous_response(input_messages, response) + + assert len(input) == 4 + # Check for previous input + assert isinstance(input[0], OpenAIResponseMessage) + assert input[0].content[0].text == "fake_previous_input" + # Check for previous output MCP tool call + assert isinstance(input[1], OpenAIResponseOutputMessageMCPCall) + # Check for previous output web search response + assert isinstance(input[2], OpenAIResponseMessage) + assert input[2].content[0].text == "fake_tool_call_response" + # Check for new input + assert isinstance(input[3], OpenAIResponseMessage) + assert input[3].content == 
"fake_input" + + +async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api): + # Setup + input_text = "What is the capital of Ireland?" + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + instructions=instructions, + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + sent_messages = params.messages + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 2 + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + assert sent_messages[1].role == "user" + assert sent_messages[1].content == input_text + + +async def test_create_openai_response_with_instructions_and_multiple_messages( + openai_responses_impl, mock_inference_api, mock_files_api +): + # Setup + input_messages = [ + OpenAIResponseMessage(role="user", content="Name some towns in Ireland", name=None), + OpenAIResponseMessage( + role="assistant", + content="Galway, Longford, Sligo", + name=None, + ), + OpenAIResponseMessage(role="user", content="Which is the largest?", name=None), + ] + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." 
+ + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input=input_messages, + model=model, + instructions=instructions, + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + sent_messages = params.messages + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 4 # 1 system + 3 input messages + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + + # Check the rest of the messages were converted correctly + assert sent_messages[1].role == "user" + assert sent_messages[1].content == "Name some towns in Ireland" + assert sent_messages[2].role == "assistant" + assert sent_messages[2].content == "Galway, Longford, Sligo" + assert sent_messages[3].role == "user" + assert sent_messages[3].content == "Which is the largest?" 
+ + +async def test_create_openai_response_with_instructions_and_previous_response( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test prepending both instructions and previous response.""" + + input_item_message = OpenAIResponseMessage( + id="123", + content="Name some towns in Ireland", + role="user", + ) + response_output_message = OpenAIResponseMessage( + id="123", + content="Galway, Longford, Sligo", + status="completed", + role="assistant", + ) + response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[response_output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[ + OpenAIUserMessageParam(content="Name some towns in Ireland"), + OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"), + ], + store=True, + ) + mock_responses_store.get_response_object.return_value = response + + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." 
+ + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123" + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + sent_messages = params.messages + + # Check that instructions were prepended as a system message + assert len(sent_messages) == 4, sent_messages + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + + # Check the rest of the messages were converted correctly + assert sent_messages[1].role == "user" + assert sent_messages[1].content == "Name some towns in Ireland" + assert sent_messages[2].role == "assistant" + assert sent_messages[2].content == "Galway, Longford, Sligo" + assert sent_messages[3].role == "user" + assert sent_messages[3].content == "Which is the largest?" 
+ + +async def test_create_openai_response_with_previous_response_instructions( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test prepending instructions and previous response with instructions.""" + + input_item_message = OpenAIResponseMessage( + id="123", + content="Name some towns in Ireland", + role="user", + ) + response_output_message = OpenAIResponseMessage( + id="123", + content="Galway, Longford, Sligo", + status="completed", + role="assistant", + ) + response = _OpenAIResponseObjectWithInputAndMessages( + created_at=1, + id="resp_123", + model="fake_model", + output=[response_output_message], + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[input_item_message], + messages=[ + OpenAIUserMessageParam(content="Name some towns in Ireland"), + OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"), + ], + instructions="You are a helpful assistant.", + store=True, + ) + mock_responses_store.get_response_object.return_value = response + + model = "meta-llama/Llama-3.1-8B-Instruct" + instructions = "You are a geography expert. Provide concise answers." 
+ + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + await openai_responses_impl.create_openai_response( + input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123" + ) + + # Verify + mock_inference_api.openai_chat_completion.assert_called_once() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + sent_messages = params.messages + + # Check that instructions were prepended as a system message + # and that the previous response instructions were not carried over + assert len(sent_messages) == 4, sent_messages + assert sent_messages[0].role == "system" + assert sent_messages[0].content == instructions + + # Check the rest of the messages were converted correctly + assert sent_messages[1].role == "user" + assert sent_messages[1].content == "Name some towns in Ireland" + assert sent_messages[2].role == "assistant" + assert sent_messages[2].content == "Galway, Longford, Sligo" + assert sent_messages[3].role == "user" + assert sent_messages[3].content == "Which is the largest?" 
+ + +async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store): + """Test that list_openai_response_input_items properly delegates to responses_store with correct parameters.""" + # Setup + response_id = "resp_123" + after = "msg_after" + before = "msg_before" + include = ["metadata"] + limit = 5 + order = Order.asc + + input_message = OpenAIResponseMessage( + id="msg_123", + content="Test message", + role="user", + ) + + expected_result = ListOpenAIResponseInputItem(data=[input_message]) + mock_responses_store.list_response_input_items.return_value = expected_result + + # Execute with all parameters to test delegation + result = await openai_responses_impl.list_openai_response_input_items( + response_id, after=after, before=before, include=include, limit=limit, order=order + ) + + # Verify all parameters are passed through correctly to the store + mock_responses_store.list_response_input_items.assert_called_once_with( + response_id, after, before, include, limit, order + ) + + # Verify the result is returned as-is from the store + assert result.object == "list" + assert len(result.data) == 1 + assert result.data[0].id == "msg_123" + + +async def test_responses_store_list_input_items_logic(): + """Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting.""" + + # Create mock store and response store + mock_sql_store = AsyncMock() + backend_name = "sql_responses_test" + register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")}) + responses_store = ResponsesStore( + ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy() + ) + responses_store.sql_store = mock_sql_store + + # Setup test data - multiple input items + input_items = [ + OpenAIResponseMessage(id="msg_1", content="First message", role="user"), + OpenAIResponseMessage(id="msg_2", content="Second message", role="user"), + 
OpenAIResponseMessage(id="msg_3", content="Third message", role="user"), + OpenAIResponseMessage(id="msg_4", content="Fourth message", role="user"), + ] + + response_with_input = _OpenAIResponseObjectWithInputAndMessages( + id="resp_123", + model="test_model", + created_at=1234567890, + object="response", + status="completed", + output=[], + text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))), + input=input_items, + messages=[OpenAIUserMessageParam(content="First message")], + store=True, + ) + + # Mock the get_response_object method to return our test data + mock_sql_store.fetch_one.return_value = {"response_object": response_with_input.model_dump()} + + # Test 1: Default behavior (no limit, desc order) + result = await responses_store.list_response_input_items("resp_123") + assert result.object == "list" + assert len(result.data) == 4 + # Should be reversed for desc order + assert result.data[0].id == "msg_4" + assert result.data[1].id == "msg_3" + assert result.data[2].id == "msg_2" + assert result.data[3].id == "msg_1" + + # Test 2: With limit=2, desc order + result = await responses_store.list_response_input_items("resp_123", limit=2, order=Order.desc) + assert result.object == "list" + assert len(result.data) == 2 + # Should be first 2 items in desc order + assert result.data[0].id == "msg_4" + assert result.data[1].id == "msg_3" + + # Test 3: With limit=2, asc order + result = await responses_store.list_response_input_items("resp_123", limit=2, order=Order.asc) + assert result.object == "list" + assert len(result.data) == 2 + # Should be first 2 items in original order (asc) + assert result.data[0].id == "msg_1" + assert result.data[1].id == "msg_2" + + # Test 4: Asc order without limit + result = await responses_store.list_response_input_items("resp_123", order=Order.asc) + assert result.object == "list" + assert len(result.data) == 4 + # Should be in original order (asc) + assert result.data[0].id == "msg_1" + assert result.data[1].id 
== "msg_2" + assert result.data[2].id == "msg_3" + assert result.data[3].id == "msg_4" + + # Test 5: Large limit (larger than available items) + result = await responses_store.list_response_input_items("resp_123", limit=10, order=Order.desc) + assert result.object == "list" + assert len(result.data) == 4 # Should return all available items + assert result.data[0].id == "msg_4" + + # Test 6: Zero limit edge case + result = await responses_store.list_response_input_items("resp_123", limit=0, order=Order.asc) + assert result.object == "list" + assert len(result.data) == 0 # Should return no items + + +async def test_store_response_uses_rehydrated_input_with_previous_response( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test that _store_response uses the full re-hydrated input (including previous responses) + rather than just the original input when previous_response_id is provided.""" + + # Setup - Create a previous response that should be included in the stored input + previous_response = _OpenAIResponseObjectWithInputAndMessages( + id="resp-previous-123", + object="response", + created_at=1234567890, + model="meta-llama/Llama-3.1-8B-Instruct", + status="completed", + text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[ + OpenAIResponseMessage( + id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")] + ) + ], + output=[ + OpenAIResponseMessage( + id="msg-prev-assistant", + role="assistant", + content=[OpenAIResponseOutputMessageContentOutputText(text="2+2 equals 4.")], + ) + ], + messages=[ + OpenAIUserMessageParam(content="What is 2+2?"), + OpenAIAssistantMessageParam(content="2+2 equals 4."), + ], + store=True, + ) + + mock_responses_store.get_response_object.return_value = previous_response + + current_input = "Now what is 3+3?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute - Create response with previous_response_id + result = await openai_responses_impl.create_openai_response( + input=current_input, + model=model, + previous_response_id="resp-previous-123", + store=True, + ) + + store_call_args = mock_responses_store.upsert_response_object.call_args + stored_input = store_call_args.kwargs["input"] + + # Verify that the stored input contains the full re-hydrated conversation: + # 1. Previous user message + # 2. Previous assistant response + # 3. Current user message + assert len(stored_input) == 3 + + assert stored_input[0].role == "user" + assert stored_input[0].content[0].text == "What is 2+2?" + + assert stored_input[1].role == "assistant" + assert stored_input[1].content[0].text == "2+2 equals 4." + + assert stored_input[2].role == "user" + assert stored_input[2].content == "Now what is 3+3?" + + # Verify the response itself is correct + assert result.model == model + assert result.status == "completed" + + +@pytest.mark.parametrize( + "text_format, response_format", + [ + (OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), None), + ( + OpenAIResponseText(format=OpenAIResponseTextFormat(name="Test", schema={"foo": "bar"}, type="json_schema")), + OpenAIResponseFormatJSONSchema(json_schema=OpenAIJSONSchema(name="Test", schema={"foo": "bar"})), + ), + (OpenAIResponseText(format=OpenAIResponseTextFormat(type="json_object")), OpenAIResponseFormatJSONObject()), + # ensure text param with no format specified defaults to None + (OpenAIResponseText(format=None), None), + # ensure text param of None defaults to None + (None, None), + ], +) +async def test_create_openai_response_with_text_format( + openai_responses_impl, mock_inference_api, text_format, response_format +): + """Test creating Responses with text formats.""" + # Setup + input_text = "How hot it is in San Francisco today?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + _result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + text=text_format, + ) + + # Verify + first_call = mock_inference_api.openai_chat_completion.call_args_list[0] + first_params = first_call.args[0] + assert first_params.messages[0].content == input_text + assert first_params.response_format == response_format + + +async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with an invalid text format.""" + # Setup + input_text = "How hot it is in San Francisco today?" + model = "meta-llama/Llama-3.1-8B-Instruct" + + # Execute + with pytest.raises(ValueError): + _result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + text=OpenAIResponseText(format={"type": "invalid"}), + ) + + +async def test_create_openai_response_with_output_types_as_input( + openai_responses_impl, mock_inference_api, mock_responses_store +): + """Test that response outputs can be used as inputs in multi-turn conversations. + + Before adding OpenAIResponseOutput types to OpenAIResponseInput, + creating a _OpenAIResponseObjectWithInputAndMessages with some output types + in the input field would fail with a Pydantic ValidationError. + + This test simulates storing a response where the input contains output message + types (MCP calls, function calls), which happens in multi-turn conversations. 
+ """ + model = "meta-llama/Llama-3.1-8B-Instruct" + + # Mock the inference response + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Create a response with store=True to trigger the storage path + result = await openai_responses_impl.create_openai_response( + input="What's the weather?", + model=model, + stream=True, + temperature=0.1, + store=True, + ) + + # Consume the stream + _ = [chunk async for chunk in result] + + # Verify store was called + assert mock_responses_store.upsert_response_object.called + + # Get the stored data + store_call_args = mock_responses_store.upsert_response_object.call_args + stored_response = store_call_args.kwargs["response_object"] + + # Now simulate a multi-turn conversation where outputs become inputs + input_with_output_types = [ + OpenAIResponseMessage(role="user", content="What's the weather?", name=None), + # These output types need to be valid OpenAIResponseInput + OpenAIResponseOutputMessageFunctionToolCall( + call_id="call_123", + name="get_weather", + arguments='{"city": "Tokyo"}', + type="function_call", + ), + OpenAIResponseOutputMessageMCPCall( + id="mcp_456", + type="mcp_call", + server_label="weather_server", + name="get_temperature", + arguments='{"location": "Tokyo"}', + output="25°C", + ), + ] + + # This simulates storing a response in a multi-turn conversation + # where previous outputs are included in the input. 
+ stored_with_outputs = _OpenAIResponseObjectWithInputAndMessages( + id=stored_response.id, + created_at=stored_response.created_at, + model=stored_response.model, + status=stored_response.status, + output=stored_response.output, + input=input_with_output_types, # This will trigger Pydantic validation + messages=None, + store=True, + ) + + assert stored_with_outputs.input == input_with_output_types + assert len(stored_with_outputs.input) == 3 diff --git a/tests/unit/providers/responses/builtin/test_openai_responses_helpers.py b/tests/unit/providers/responses/builtin/test_openai_responses_helpers.py new file mode 100644 index 0000000000..bd8fdcc2aa --- /dev/null +++ b/tests/unit/providers/responses/builtin/test_openai_responses_helpers.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) + +from tests.unit.providers.responses.builtin.fixtures import load_chat_completion_fixture + + +async def fake_stream(fixture: str = "simple_chat_completion.yaml"): + value = load_chat_completion_fixture(fixture) + yield ChatCompletionChunk( + id=value.id, + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + content=c.message.content, + role=c.message.role, + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id=t.id, + function=ChoiceDeltaToolCallFunction( + name=t.function.name, + arguments=t.function.arguments, + ), + ) + for t in (c.message.tool_calls or []) + ], + ), + ) + for c in value.choices + ], + created=1, + model=value.model, + object="chat.completion.chunk", + ) diff --git a/tests/unit/providers/responses/builtin/test_openai_responses_params.py b/tests/unit/providers/responses/builtin/test_openai_responses_params.py new file mode 100644 
index 0000000000..2d870781d6 --- /dev/null +++ b/tests/unit/providers/responses/builtin/test_openai_responses_params.py @@ -0,0 +1,811 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) + +from llama_stack.providers.inline.responses.builtin.responses.openai_responses import ( + OpenAIResponsesImpl, +) +from llama_stack.providers.remote.inference.openai.config import OpenAIConfig +from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter +from llama_stack.providers.utils.responses.responses_store import ( + _OpenAIResponseObjectWithInputAndMessages, +) +from llama_stack_api import ( + ResponseStreamOptions, + ResponseTruncation, +) +from llama_stack_api.inference import ( + OpenAIAssistantMessageParam, + OpenAIUserMessageParam, + ServiceTier, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolFunction, + OpenAIResponseMessage, + OpenAIResponseText, + OpenAIResponseTextFormat, +) +from llama_stack_api.tools import ToolDef, ToolInvocationResult +from tests.unit.providers.responses.builtin.test_openai_responses_helpers import fake_stream + + +async def test_create_openai_response_with_max_output_tokens_non_streaming( + openai_responses_impl, mock_inference_api, mock_responses_store +): + """Test that max_output_tokens is properly handled in non-streaming responses.""" + input_text = "Write a long story about AI." 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + max_tokens = 100 + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + max_output_tokens=max_tokens, + stream=False, + store=True, + ) + + # Verify response includes the max_output_tokens + assert result.max_output_tokens == max_tokens + assert result.model == model + assert result.status == "completed" + + # Verify the max_output_tokens was passed to inference API + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + assert params.max_completion_tokens == max_tokens + + # Verify the max_output_tokens was stored + mock_responses_store.upsert_response_object.assert_called() + store_call_args = mock_responses_store.upsert_response_object.call_args + stored_response = store_call_args.kwargs["response_object"] + assert stored_response.max_output_tokens == max_tokens + + +async def test_create_openai_response_with_max_output_tokens_streaming( + openai_responses_impl, mock_inference_api, mock_responses_store +): + """Test that max_output_tokens is properly handled in streaming responses.""" + input_text = "Explain machine learning in detail." 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + max_tokens = 200 + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + max_output_tokens=max_tokens, + stream=True, + store=True, + ) + + # Collect all chunks + chunks = [chunk async for chunk in result] + + # Verify max_output_tokens is in the created event + created_event = chunks[0] + assert created_event.type == "response.created" + assert created_event.response.max_output_tokens == max_tokens + + # Verify max_output_tokens is in the completed event + completed_event = chunks[-1] + assert completed_event.type == "response.completed" + assert completed_event.response.max_output_tokens == max_tokens + + # Verify the max_output_tokens was passed to inference API + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + assert params.max_completion_tokens == max_tokens + + # Verify the max_output_tokens was stored + mock_responses_store.upsert_response_object.assert_called() + store_call_args = mock_responses_store.upsert_response_object.call_args + stored_response = store_call_args.kwargs["response_object"] + assert stored_response.max_output_tokens == max_tokens + + +async def test_create_openai_response_with_max_output_tokens_boundary_value(openai_responses_impl, mock_inference_api): + """Test that max_output_tokens accepts the minimum valid value of 16.""" + input_text = "Hi" + model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute with minimum valid value + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + max_output_tokens=16, + stream=False, + ) + + # Verify it accepts 16 + assert result.max_output_tokens == 16 + assert result.status == "completed" + + # 
Verify the inference API was called with max_completion_tokens=16 + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + assert params.max_completion_tokens == 16 + + +async def test_create_openai_response_with_max_output_tokens_and_tools(openai_responses_impl, mock_inference_api): + """Test that max_output_tokens works correctly with tool calls.""" + input_text = "What's the weather in San Francisco?" + model = "meta-llama/Llama-3.1-8B-Instruct" + max_tokens = 150 + + openai_responses_impl.tool_groups_api.get_tool.return_value = ToolDef( + name="get_weather", + toolgroup_id="weather", + description="Get weather information", + input_schema={ + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + ) + + openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult( + status="completed", + content="Sunny, 72°F", + ) + + # Mock two inference calls: one for tool call, one for final response + mock_inference_api.openai_chat_completion.side_effect = [ + fake_stream("tool_call_completion.yaml"), + fake_stream(), + ] + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + max_output_tokens=max_tokens, + stream=False, + tools=[ + OpenAIResponseInputToolFunction( + name="get_weather", + description="Get weather information", + parameters={"location": "string"}, + ) + ], + ) + + # Verify max_output_tokens is preserved + assert result.max_output_tokens == max_tokens + assert result.status == "completed" + + # Verify both inference calls received max_completion_tokens + assert mock_inference_api.openai_chat_completion.call_count == 2 + for call in mock_inference_api.openai_chat_completion.call_args_list: + params = call.args[0] + # The first call gets the full max_tokens, subsequent calls get remaining tokens + assert params.max_completion_tokens 
is not None + assert params.max_completion_tokens <= max_tokens + + +@pytest.mark.parametrize("store", [False, True]) +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize( + "param_name,param_value,backend_param_name,backend_expected_value,response_expected_value,stored_expected_value", + [ + ("temperature", 1.5, "temperature", 1.5, 1.5, 1.5), + ("safety_identifier", "user-123", "safety_identifier", "user-123", "user-123", "user-123"), + ("max_output_tokens", 500, "max_completion_tokens", 500, 500, 500), + ( + "prompt_cache_key", + "geography-cache-001", + "prompt_cache_key", + "geography-cache-001", + "geography-cache-001", + "geography-cache-001", + ), + ("service_tier", ServiceTier.flex, "service_tier", "flex", "flex", ServiceTier.default.value), + ("top_p", 0.9, "top_p", 0.9, 0.9, 0.9), + ("frequency_penalty", 0.5, "frequency_penalty", 0.5, 0.5, 0.5), + ("presence_penalty", 0.3, "presence_penalty", 0.3, 0.3, 0.3), + ("top_logprobs", 5, "top_logprobs", 5, 5, 5), + ( + "extra_body", + {"chat_template_kwargs": {"thinking": True}}, + "extra_body", + {"chat_template_kwargs": {"thinking": True}}, + None, + None, + ), + ], +) +async def test_params_passed_through_full_chain_to_backend_service( + param_name, + param_value, + backend_param_name, + backend_expected_value, + response_expected_value, + stored_expected_value, + stream, + store, + mock_responses_store, +): + """Test that parameters which pass through to the backend service are correctly propagated. + + Only parameters that are forwarded as kwargs to the underlying chat completions API belong + here. Parameters handled internally by the responses layer (e.g. truncation) should be + tested separately since they don't produce a backend kwarg assertion. + + This test should not act differently based on the param_name/param_value/etc. Needing changes + in behavior based on those params suggests a bug in the implementation. 
+ + This test may act differently based on : + - stream: whether the response is streamed or not + - store: whether the response is persisted via the responses store + """ + config = OpenAIConfig(api_key="test-key") + openai_adapter = OpenAIInferenceAdapter(config=config) + openai_adapter.provider_data_api_key_field = None + + mock_model_store = AsyncMock() + mock_model_store.has_model = AsyncMock(return_value=False) + openai_adapter.model_store = mock_model_store + + openai_responses_impl = OpenAIResponsesImpl( + inference_api=openai_adapter, + tool_groups_api=AsyncMock(), + tool_runtime_api=AsyncMock(), + responses_store=mock_responses_store, + vector_io_api=AsyncMock(), + safety_api=AsyncMock(), + conversations_api=AsyncMock(), + prompts_api=AsyncMock(), + files_api=AsyncMock(), + connectors_api=AsyncMock(), + ) + + with patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") as mock_openai_class: + mock_client = MagicMock() + mock_chat_completions = AsyncMock() + mock_client.chat.completions.create = mock_chat_completions + mock_openai_class.return_value = mock_client + + if stream: + mock_chat_completions.return_value = fake_stream() + else: + mock_response = MagicMock() + mock_response.id = "chatcmpl-123" + mock_response.choices = [ + MagicMock( + index=0, + message=MagicMock(content="Test response", role="assistant", tool_calls=None), + finish_reason="stop", + ) + ] + mock_response.model = "fake-model" + mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + mock_chat_completions.return_value = mock_response + + result = await openai_responses_impl.create_openai_response( + **{ + "input": "Test message", + "model": "fake-model", + "stream": stream, + "store": store, + param_name: param_value, + } + ) + if stream: + chunks = [chunk async for chunk in result] + created_event = chunks[0] + assert created_event.type == "response.created" + assert getattr(created_event.response, param_name, None) == 
response_expected_value, ( + f"Expected created {param_name}={response_expected_value}, got {getattr(created_event.response, param_name, None)}" + ) + completed_event = chunks[-1] + assert completed_event.type == "response.completed" + assert getattr(completed_event.response, param_name, None) == stored_expected_value, ( + f"Expected completed {param_name}={stored_expected_value}, got {getattr(completed_event.response, param_name, None)}" + ) + + mock_chat_completions.assert_called_once() + call_kwargs = mock_chat_completions.call_args[1] + + assert backend_param_name in call_kwargs, f"{backend_param_name} not found in backend call" + assert call_kwargs[backend_param_name] == backend_expected_value, ( + f"Expected {backend_param_name}={backend_expected_value}, got {call_kwargs[backend_param_name]}" + ) + + if store: + mock_responses_store.upsert_response_object.assert_called() + stored_response = mock_responses_store.upsert_response_object.call_args.kwargs["response_object"] + assert getattr(stored_response, param_name, None) == stored_expected_value, ( + f"Expected stored {param_name}={stored_expected_value}, got {getattr(stored_response, param_name, None)}" + ) + else: + mock_responses_store.upsert_response_object.assert_not_called() + + +async def test_create_openai_response_with_truncation_disabled_streaming( + openai_responses_impl, mock_inference_api, mock_responses_store +): + """Test that truncation='disabled' is properly handled in streaming responses.""" + input_text = "Explain machine learning comprehensively." 
"""Test that truncation='auto' yields a failed response since it's not yet supported."""
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + truncation=ResponseTruncation.auto, + stream=True, + store=True, + ) + + # Collect all chunks + chunks = [chunk async for chunk in result] + + # Verify truncation is in the created event + created_event = chunks[0] + assert created_event.type == "response.created" + assert created_event.response.truncation == ResponseTruncation.auto + + # Verify the response failed due to unsupported truncation mode + failed_event = chunks[-1] + assert failed_event.type == "response.failed" + assert failed_event.response.truncation == ResponseTruncation.auto + assert failed_event.response.error is not None + assert failed_event.response.error.code == "server_error" + assert "Truncation mode 'auto' is not supported" in failed_event.response.error.message + + # Inference API should not be called since error occurs before inference + mock_inference_api.openai_chat_completion.assert_not_called() + + # Verify the failed response was stored + mock_responses_store.upsert_response_object.assert_called() + store_call_args = mock_responses_store.upsert_response_object.call_args + stored_response = store_call_args.kwargs["response_object"] + assert stored_response.truncation == ResponseTruncation.auto + assert stored_response.status == "failed" + + +async def test_create_openai_response_with_prompt_cache_key_and_previous_response( + openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test that prompt_cache_key works correctly with previous_response_id.""" + # Setup previous response + previous_response = _OpenAIResponseObjectWithInputAndMessages( + id="resp-prev-123", + object="response", + created_at=1234567890, + model="meta-llama/Llama-3.1-8B-Instruct", + status="completed", + 
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), + input=[OpenAIResponseMessage(id="msg-1", role="user", content="First question")], + output=[OpenAIResponseMessage(id="msg-2", role="assistant", content="First answer")], + messages=[ + OpenAIUserMessageParam(content="First question"), + OpenAIAssistantMessageParam(content="First answer"), + ], + prompt_cache_key="conversation-cache-001", + store=True, + ) + + mock_responses_store.get_response_object.return_value = previous_response + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Create a new response with the same cache key + result = await openai_responses_impl.create_openai_response( + input="Second question", + model="meta-llama/Llama-3.1-8B-Instruct", + previous_response_id="resp-prev-123", + prompt_cache_key="conversation-cache-001", + store=True, + ) + + # Verify cache key is preserved + assert result.prompt_cache_key == "conversation-cache-001" + assert result.status == "completed" + + # Verify the cache key was stored + mock_responses_store.upsert_response_object.assert_called() + store_call_args = mock_responses_store.upsert_response_object.call_args + stored_response = store_call_args.kwargs["response_object"] + assert stored_response.prompt_cache_key == "conversation-cache-001" + + +async def test_create_openai_response_with_service_tier(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with service_tier parameter.""" + # Setup + input_text = "What is the capital of France?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + service_tier = ServiceTier.flex + + # Stub the streaming chat completion with the shared fake_stream helper + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute - non-streaming to get final response directly + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + service_tier=service_tier, + stream=False, + ) + + # NOTE(review): 'flex' was requested but 'default' is asserted below — presumably the provider-reported tier overrides the requested one; confirm this is the intended behavior + assert result.service_tier == ServiceTier.default.value + assert result.status == "completed" + + # Verify inference call received service_tier + mock_inference_api.openai_chat_completion.assert_called_once() + params = mock_inference_api.openai_chat_completion.call_args.args[0] + assert params.service_tier == service_tier + + + async def test_create_openai_response_service_tier_auto_transformation(openai_responses_impl, mock_inference_api): + """Test that service_tier 'auto' is transformed to actual tier from provider response.""" + # Setup + input_text = "Hello" + model = "meta-llama/Llama-3.1-8B-Instruct" + + # Mock a response that returns actual service tier when "auto" was requested + async def fake_stream_with_service_tier(): + yield ChatCompletionChunk( + id="chatcmpl-123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(content="Hi there!", role="assistant"), + finish_reason="stop", + ) + ], + created=1234567890, + model=model, + object="chat.completion.chunk", + service_tier="default", # Provider returns actual tier used + ) + + mock_inference_api.openai_chat_completion.return_value = fake_stream_with_service_tier() + + # Execute with "auto" service tier + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + service_tier=ServiceTier.auto, + stream=False, + ) + + # Verify the response has the actual tier from provider, not "auto" + assert result.service_tier == "default", "service_tier should be transformed from 'auto' to actual tier" + assert
result.service_tier != ServiceTier.auto.value, "service_tier should not remain as 'auto'" + assert result.status == "completed" + + # Verify inference was called with "auto" + mock_inference_api.openai_chat_completion.assert_called_once() + params = mock_inference_api.openai_chat_completion.call_args.args[0] + assert params.service_tier == "auto" + + +async def test_create_openai_response_service_tier_propagation_streaming(openai_responses_impl, mock_inference_api): + """Test that service_tier from chat completion is propagated to response object in streaming mode.""" + # Setup + input_text = "Tell me about AI" + model = "meta-llama/Llama-3.1-8B-Instruct" + + # Mock streaming response with service_tier + async def fake_stream_with_service_tier(): + yield ChatCompletionChunk( + id="chatcmpl-456", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(content="AI is", role="assistant"), + finish_reason=None, + ) + ], + created=1234567890, + model=model, + object="chat.completion.chunk", + service_tier="priority", # First chunk with service_tier + ) + yield ChatCompletionChunk( + id="chatcmpl-456", + choices=[ + Choice( + index=0, + delta=ChoiceDelta(content=" amazing!"), + finish_reason="stop", + ) + ], + created=1234567890, + model=model, + object="chat.completion.chunk", + ) + + mock_inference_api.openai_chat_completion.return_value = fake_stream_with_service_tier() + + # Execute with "auto" but provider returns "priority" + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + service_tier=ServiceTier.auto, + stream=True, + ) + + # Collect all chunks + chunks = [chunk async for chunk in result] + # Verify service_tier is propagated to all events + created_event = chunks[0] + assert created_event.type == "response.created" + # Initially should have "auto" value + assert created_event.response.service_tier == "auto" + + # Check final response has the actual tier from provider + completed_event = chunks[-1] + assert 
completed_event.type == "response.completed" + assert completed_event.response.service_tier == "priority", "Final response should have actual tier from provider" + + +async def test_create_openai_response_with_top_logprobs_boundary_values( + openai_responses_impl, mock_inference_api, mock_responses_store +): + """Test that top_logprobs works with boundary values (0 and 20).""" + input_text = "Test message" + model = "meta-llama/Llama-3.1-8B-Instruct" + + # Test with minimum value (0) + mock_inference_api.openai_chat_completion.return_value = fake_stream() + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + top_logprobs=0, + stream=False, + store=True, + ) + assert result.top_logprobs == 0 + + # Test with maximum value (20) + mock_inference_api.openai_chat_completion.return_value = fake_stream() + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + top_logprobs=20, + stream=False, + store=True, + ) + assert result.top_logprobs == 20 + + +async def test_create_openai_response_with_frequency_penalty_default(openai_responses_impl, mock_inference_api): + """Test that frequency_penalty defaults to 0.0 when not provided.""" + input_text = "Hello" + model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute without frequency_penalty + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=False, + ) + + # Verify response has 0.0 for frequency_penalty (non-null default for OpenResponses conformance) + assert result.frequency_penalty == 0.0 + + # Verify inference API was called with None + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + assert params.frequency_penalty is None + + +async def 
test_create_openai_response_with_presence_penalty_default(openai_responses_impl, mock_inference_api): + """Test that presence_penalty defaults to 0.0 when not provided.""" + input_text = "Hi" + model = "meta-llama/Llama-3.1-8B-Instruct" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute without presence_penalty + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=False, + ) + + # Verify presence_penalty is 0.0 (non-null default for OpenResponses conformance) + assert result.presence_penalty == 0.0 + assert result.status == "completed" + + # Verify the inference API was called with presence_penalty=None + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + assert params.presence_penalty is None + + +async def test_hallucinated_tool_call_does_not_cause_500(openai_responses_impl, mock_inference_api): + """Regression test: a hallucinated tool name should not produce a 500 (InternalServerError). + + When the LLM calls a tool name that is not in the registered tools list the server + was raising ValueError from _coordinate_tool_execution which then propagated as an + InternalServerError (HTTP 500). The correct behaviour is to surface the unknown call + as a regular function-tool-call output so the client can respond, exactly as OpenAI + does for any function tool call. + """ + input_text = "What is the capital of Ireland?" + model = "meta-llama/Llama-3.1-8B-Instruct" + + async def fake_stream_hallucinated_tool(): + # The LLM calls "lookup_capital_city" which is NOT in the registered tools list. 
+ yield ChatCompletionChunk( + id="hallucinated-123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="tc_hall_123", + function=ChoiceDeltaToolCallFunction( + name="lookup_capital_city", + arguments='{"country": "Ireland"}', + ), + type="function", + ) + ] + ), + ), + ], + created=1, + model=model, + object="chat.completion.chunk", + ) + + mock_inference_api.openai_chat_completion.return_value = fake_stream_hallucinated_tool() + + # The only registered tool is "get_weather". The LLM hallucinated "lookup_capital_city". + # The response should complete without raising InternalServerError, and the hallucinated + # call should appear in the output as a function_call item so the client can handle it. + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + tools=[ + OpenAIResponseInputToolFunction( + name="get_weather", + description="Get current temperature for a given location.", + parameters={ + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + ) + ], + ) + + assert result is not None + assert result.status == "completed" + assert len(result.output) == 1 + assert result.output[0].type == "function_call" + assert result.output[0].name == "lookup_capital_city" + + +async def test_create_openai_response_with_stream_options_merges_with_default( + openai_responses_impl, mock_inference_api +): + """Test that stream_options merges with default include_usage.""" + input_text = "Test stream options" + model = "meta-llama/Llama-3.1-8B-Instruct" + stream_options = ResponseStreamOptions(include_obfuscation=False) + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream_options=stream_options, + stream=True, + ) + + # Collect chunks (consume the async iterator) + _ = [chunk async for chunk 
in result] + + # Verify the stream_options was merged properly + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + assert params.stream_options is not None + # Should have both default include_usage and user's option + assert params.stream_options["include_usage"] is True + assert params.stream_options["include_obfuscation"] is False + + +async def test_create_openai_response_with_empty_stream_options(openai_responses_impl, mock_inference_api): + """Test that default stream_options still merges with default include_usage.""" + input_text = "Test empty options" + model = "meta-llama/Llama-3.1-8B-Instruct" + stream_options = ResponseStreamOptions() # Uses default include_obfuscation=True + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream_options=stream_options, + stream=True, + ) + + # Collect chunks (consume the async iterator) + _ = [chunk async for chunk in result] + + # Verify the stream_options has both defaults + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + params = call_args.args[0] + assert params.stream_options is not None + assert params.stream_options["include_usage"] is True + assert params.stream_options["include_obfuscation"] is True diff --git a/tests/unit/providers/responses/builtin/test_openai_responses_prompts.py b/tests/unit/providers/responses/builtin/test_openai_responses_prompts.py new file mode 100644 index 0000000000..d8b453fb5c --- /dev/null +++ b/tests/unit/providers/responses/builtin/test_openai_responses_prompts.py @@ -0,0 +1,518 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import pytest + +from llama_stack_api import ( + GetPromptRequest, + InvalidParameterError, + OpenAIChatCompletionContentPartImageParam, + OpenAIFile, + OpenAIFileObject, + OpenAISystemMessageParam, + Prompt, +) +from llama_stack_api.inference import ( + OpenAIUserMessageParam, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputMessageContentFile, + OpenAIResponseInputMessageContentImage, + OpenAIResponseInputMessageContentText, + OpenAIResponsePrompt, +) +from tests.unit.providers.responses.builtin.test_openai_responses_helpers import fake_stream + + +async def test_create_openai_response_with_prompt(openai_responses_impl, mock_inference_api, mock_prompts_api): + """Test creating an OpenAI response with a prompt.""" + input_text = "What is the capital of Ireland?" + model = "meta-llama/Llama-3.1-8B-Instruct" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + prompt = Prompt( + prompt="You are a helpful {{ area_name }} assistant at {{ company_name }}. 
Always provide accurate information.", + prompt_id=prompt_id, + version=1, + variables=["area_name", "company_name"], + is_default=True, + ) + + openai_response_prompt = OpenAIResponsePrompt( + id=prompt_id, + version="1", + variables={ + "area_name": OpenAIResponseInputMessageContentText(text="geography"), + "company_name": OpenAIResponseInputMessageContentText(text="Dummy Company"), + }, + ) + + mock_prompts_api.get_prompt.return_value = prompt + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + prompt=openai_response_prompt, + ) + + mock_prompts_api.get_prompt.assert_called_with(GetPromptRequest(prompt_id=prompt_id, version=1)) + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + sent_messages = call_args.args[0].messages + assert len(sent_messages) == 2 + + system_messages = [msg for msg in sent_messages if msg.role == "system"] + assert len(system_messages) == 1 + assert ( + system_messages[0].content + == "You are a helpful geography assistant at Dummy Company. Always provide accurate information." + ) + + user_messages = [msg for msg in sent_messages if msg.role == "user"] + assert len(user_messages) == 1 + assert user_messages[0].content == input_text + + assert result.model == model + assert result.status == "completed" + assert isinstance(result.prompt, OpenAIResponsePrompt) + assert result.prompt.id == prompt_id + assert result.prompt.variables == openai_response_prompt.variables + assert result.prompt.version == "1" + + +async def test_prepend_prompt_successful_without_variables(openai_responses_impl, mock_prompts_api, mock_inference_api): + """Test prepend_prompt function without variables.""" + input_text = "What is the capital of Ireland?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + prompt = Prompt( + prompt="You are a helpful assistant. Always provide accurate information.", + prompt_id=prompt_id, + version=1, + variables=[], + is_default=True, + ) + + openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1") + + mock_prompts_api.get_prompt.return_value = prompt + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + prompt=openai_response_prompt, + ) + + mock_prompts_api.get_prompt.assert_called_with(GetPromptRequest(prompt_id=prompt_id, version=1)) + mock_inference_api.openai_chat_completion.assert_called() + call_args = mock_inference_api.openai_chat_completion.call_args + sent_messages = call_args.args[0].messages + assert len(sent_messages) == 2 + system_messages = [msg for msg in sent_messages if msg.role == "system"] + assert system_messages[0].content == "You are a helpful assistant. Always provide accurate information." 
+ + +async def test_prepend_prompt_invalid_variable(openai_responses_impl, mock_prompts_api): + """Test error handling in prepend_prompt function when prompt parameters contain invalid variables.""" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + prompt = Prompt( + prompt="You are a {{ role }} assistant.", + prompt_id=prompt_id, + version=1, + variables=["role"], # Only "role" is valid + is_default=True, + ) + + openai_response_prompt = OpenAIResponsePrompt( + id=prompt_id, + version="1", + variables={ + "role": OpenAIResponseInputMessageContentText(text="helpful"), + "company": OpenAIResponseInputMessageContentText( + text="Dummy Company" + ), # company is not in prompt.variables + }, + ) + + mock_prompts_api.get_prompt.return_value = prompt + + # Initial messages + messages = [OpenAIUserMessageParam(content="Test prompt")] + + # Execute - should raise InvalidParameterError for invalid variable + with pytest.raises(InvalidParameterError) as exc_info: + await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) + assert "Invalid value for 'prompt.variables': company" in str(exc_info.value) + assert f"Variable not defined in prompt '{prompt_id}'" in str(exc_info.value) + + # Verify + mock_prompts_api.get_prompt.assert_called_once_with(GetPromptRequest(prompt_id=prompt_id, version=1)) + + +async def test_prepend_prompt_not_found(openai_responses_impl, mock_prompts_api): + """Test prepend_prompt function when prompt is not found.""" + prompt_id = "pmpt_nonexistent" + openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1") + + mock_prompts_api.get_prompt.return_value = None # Prompt not found + + # Initial messages + messages = [OpenAIUserMessageParam(content="Test prompt")] + initial_length = len(messages) + + # Execute + result = await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) + + # Verify + 
mock_prompts_api.get_prompt.assert_called_once_with(GetPromptRequest(prompt_id=prompt_id, version=1)) + + # Should return None when prompt not found + assert result is None + + # Messages should not be modified + assert len(messages) == initial_length + assert messages[0].content == "Test prompt" + + +async def test_prepend_prompt_variable_substitution(openai_responses_impl, mock_prompts_api): + """Test complex variable substitution with multiple occurrences and special characters in prepend_prompt function.""" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + + # Support all whitespace variations: {{name}}, {{ name }}, {{ name}}, {{name }}, etc. + prompt = Prompt( + prompt="Hello {{name}}! You are working at {{ company}}. Your role is {{role}} at {{company}}. Remember, {{ name }}, to be {{ tone }}.", + prompt_id=prompt_id, + version=1, + variables=["name", "company", "role", "tone"], + is_default=True, + ) + + openai_response_prompt = OpenAIResponsePrompt( + id=prompt_id, + version="1", + variables={ + "name": OpenAIResponseInputMessageContentText(text="Alice"), + "company": OpenAIResponseInputMessageContentText(text="Dummy Company"), + "role": OpenAIResponseInputMessageContentText(text="AI Assistant"), + "tone": OpenAIResponseInputMessageContentText(text="professional"), + }, + ) + + mock_prompts_api.get_prompt.return_value = prompt + + # Initial messages + messages = [OpenAIUserMessageParam(content="Test")] + + # Execute + await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) + + # Verify + assert len(messages) == 2 + assert isinstance(messages[0], OpenAISystemMessageParam) + expected_content = "Hello Alice! You are working at Dummy Company. Your role is AI Assistant at Dummy Company. Remember, Alice, to be professional." 
+ assert messages[0].content == expected_content + + +async def test_prepend_prompt_with_image_variable(openai_responses_impl, mock_prompts_api, mock_files_api): + """Test prepend_prompt with image variable - should create placeholder in system message and append image as separate user message.""" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + prompt = Prompt( + prompt="Analyze this {{product_image}} and describe what you see.", + prompt_id=prompt_id, + version=1, + variables=["product_image"], + is_default=True, + ) + + # Mock file content and file metadata + mock_file_content = b"fake_image_data" + mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})() + mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject( + object="file", + id="file-abc123", + bytes=len(mock_file_content), + created_at=1234567890, + expires_at=1234567890, + filename="product.jpg", + purpose="assistants", + ) + + openai_response_prompt = OpenAIResponsePrompt( + id=prompt_id, + version="1", + variables={ + "product_image": OpenAIResponseInputMessageContentImage( + file_id="file-abc123", + detail="high", + ) + }, + ) + + mock_prompts_api.get_prompt.return_value = prompt + + # Initial messages + messages = [OpenAIUserMessageParam(content="What do you think?")] + + # Execute + await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) + + assert len(messages) == 3 + + # Check system message has placeholder + assert isinstance(messages[0], OpenAISystemMessageParam) + assert messages[0].content == "Analyze this [Image: product_image] and describe what you see." + + # Check original user message is still there + assert isinstance(messages[1], OpenAIUserMessageParam) + assert messages[1].content == "What do you think?" 
+ + # Check new user message with image is appended + assert isinstance(messages[2], OpenAIUserMessageParam) + assert isinstance(messages[2].content, list) + assert len(messages[2].content) == 1 + + # Should be image with data URL + assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam) + assert messages[2].content[0].image_url.url.startswith("data:image/") + assert messages[2].content[0].image_url.detail == "high" + + +async def test_prepend_prompt_with_file_variable(openai_responses_impl, mock_prompts_api, mock_files_api): + """Test prepend_prompt with file variable - should create placeholder in system message and append file as separate user message.""" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + prompt = Prompt( + prompt="Review the document {{contract_file}} and summarize key points.", + prompt_id=prompt_id, + version=1, + variables=["contract_file"], + is_default=True, + ) + + # Mock file retrieval + mock_file_content = b"fake_pdf_content" + mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})() + mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject( + object="file", + id="file-contract-789", + bytes=len(mock_file_content), + created_at=1234567890, + expires_at=1234567890, + filename="contract.pdf", + purpose="assistants", + ) + + openai_response_prompt = OpenAIResponsePrompt( + id=prompt_id, + version="1", + variables={ + "contract_file": OpenAIResponseInputMessageContentFile( + file_id="file-contract-789", + filename="contract.pdf", + ) + }, + ) + + mock_prompts_api.get_prompt.return_value = prompt + + # Initial messages + messages = [OpenAIUserMessageParam(content="Please review this.")] + + # Execute + await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) + + assert len(messages) == 3 + + # Check system message has placeholder + assert isinstance(messages[0], OpenAISystemMessageParam) + assert 
messages[0].content == "Review the document [File: contract_file] and summarize key points." + + # Check original user message is still there + assert isinstance(messages[1], OpenAIUserMessageParam) + assert messages[1].content == "Please review this." + + # Check new user message with file is appended + assert isinstance(messages[2], OpenAIUserMessageParam) + assert isinstance(messages[2].content, list) + assert len(messages[2].content) == 1 + + # First part should be file with data URL + assert isinstance(messages[2].content[0], OpenAIFile) + assert messages[2].content[0].file.file_data.startswith("data:application/pdf;base64,") + assert messages[2].content[0].file.filename == "contract.pdf" + assert messages[2].content[0].file.file_id is None + + +async def test_prepend_prompt_with_mixed_variables(openai_responses_impl, mock_prompts_api, mock_files_api): + """Test prepend_prompt with text, image, and file variables mixed together.""" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + prompt = Prompt( + prompt="Hello {{name}}! Analyze {{photo}} and review {{document}}. 
Provide insights for {{company}}.", + prompt_id=prompt_id, + version=1, + variables=["name", "photo", "document", "company"], + is_default=True, + ) + + # Mock file retrieval for image and file + mock_image_content = b"fake_image_data" + mock_file_content = b"fake_doc_content" + + async def mock_retrieve_file_content(request): + file_id = request.file_id + if file_id == "file-photo-123": + return type("obj", (object,), {"body": mock_image_content})() + elif file_id == "file-doc-456": + return type("obj", (object,), {"body": mock_file_content})() + + mock_files_api.openai_retrieve_file_content.side_effect = mock_retrieve_file_content + + def mock_retrieve_file(request): + file_id = request.file_id + if file_id == "file-photo-123": + return OpenAIFileObject( + object="file", + id="file-photo-123", + bytes=len(mock_image_content), + created_at=1234567890, + expires_at=1234567890, + filename="photo.jpg", + purpose="assistants", + ) + elif file_id == "file-doc-456": + return OpenAIFileObject( + object="file", + id="file-doc-456", + bytes=len(mock_file_content), + created_at=1234567890, + expires_at=1234567890, + filename="doc.pdf", + purpose="assistants", + ) + + mock_files_api.openai_retrieve_file.side_effect = mock_retrieve_file + + openai_response_prompt = OpenAIResponsePrompt( + id=prompt_id, + version="1", + variables={ + "name": OpenAIResponseInputMessageContentText(text="Alice"), + "photo": OpenAIResponseInputMessageContentImage(file_id="file-photo-123", detail="auto"), + "document": OpenAIResponseInputMessageContentFile(file_id="file-doc-456", filename="doc.pdf"), + "company": OpenAIResponseInputMessageContentText(text="Acme Corp"), + }, + ) + + mock_prompts_api.get_prompt.return_value = prompt + + # Initial messages + messages = [OpenAIUserMessageParam(content="Here's my question.")] + + # Execute + await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) + + assert len(messages) == 3 + + # Check system message has text and placeholders + 
assert isinstance(messages[0], OpenAISystemMessageParam) + expected_system = "Hello Alice! Analyze [Image: photo] and review [File: document]. Provide insights for Acme Corp." + assert messages[0].content == expected_system + + # Check original user message is still there + assert isinstance(messages[1], OpenAIUserMessageParam) + assert messages[1].content == "Here's my question." + + # Check new user message with media is appended (2 media items) + assert isinstance(messages[2], OpenAIUserMessageParam) + assert isinstance(messages[2].content, list) + assert len(messages[2].content) == 2 + + # First part should be image with data URL + assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam) + assert messages[2].content[0].image_url.url.startswith("data:image/") + + # Second part should be file with data URL + assert isinstance(messages[2].content[1], OpenAIFile) + assert messages[2].content[1].file.file_data.startswith("data:application/pdf;base64,") + assert messages[2].content[1].file.filename == "doc.pdf" + assert messages[2].content[1].file.file_id is None + + +async def test_prepend_prompt_with_image_using_image_url(openai_responses_impl, mock_prompts_api): + """Test prepend_prompt with image variable using image_url instead of file_id.""" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + prompt = Prompt( + prompt="Describe {{screenshot}}.", + prompt_id=prompt_id, + version=1, + variables=["screenshot"], + is_default=True, + ) + + openai_response_prompt = OpenAIResponsePrompt( + id=prompt_id, + version="1", + variables={ + "screenshot": OpenAIResponseInputMessageContentImage( + image_url="https://example.com/screenshot.png", + detail="low", + ) + }, + ) + + mock_prompts_api.get_prompt.return_value = prompt + + # Initial messages + messages = [OpenAIUserMessageParam(content="What is this?")] + + # Execute + await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) + + assert len(messages) == 
3 + + # Check system message has placeholder + assert isinstance(messages[0], OpenAISystemMessageParam) + assert messages[0].content == "Describe [Image: screenshot]." + + # Check original user message is still there + assert isinstance(messages[1], OpenAIUserMessageParam) + assert messages[1].content == "What is this?" + + # Check new user message with image is appended + assert isinstance(messages[2], OpenAIUserMessageParam) + assert isinstance(messages[2].content, list) + + # Image should use the provided URL + assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam) + assert messages[2].content[0].image_url.url == "https://example.com/screenshot.png" + assert messages[2].content[0].image_url.detail == "low" + + +async def test_prepend_prompt_image_variable_missing_required_fields(openai_responses_impl, mock_prompts_api): + """Test prepend_prompt with image variable that has neither file_id nor image_url - should raise error.""" + prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef" + prompt = Prompt( + prompt="Analyze {{bad_image}}.", + prompt_id=prompt_id, + version=1, + variables=["bad_image"], + is_default=True, + ) + + # Create image content with neither file_id nor image_url + openai_response_prompt = OpenAIResponsePrompt( + id=prompt_id, + version="1", + variables={"bad_image": OpenAIResponseInputMessageContentImage()}, # No file_id or image_url + ) + + mock_prompts_api.get_prompt.return_value = prompt + messages = [OpenAIUserMessageParam(content="Test")] + + # Execute - should raise ValueError + with pytest.raises(ValueError, match="Image content must have either 'image_url' or 'file_id'"): + await openai_responses_impl._prepend_prompt(messages, openai_response_prompt) diff --git a/tests/unit/providers/responses/builtin/test_openai_responses_tools.py b/tests/unit/providers/responses/builtin/test_openai_responses_tools.py new file mode 100644 index 0000000000..b5a2a43533 --- /dev/null +++ 
b/tests/unit/providers/responses/builtin/test_openai_responses_tools.py @@ -0,0 +1,621 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import patch + +from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, +) + +from llama_stack.core.datatypes import VectorStoresConfig +from llama_stack.providers.inline.responses.builtin.responses.tool_executor import ToolExecutor +from llama_stack.providers.utils.responses.responses_store import ( + _OpenAIResponseObjectWithInputAndMessages, +) +from llama_stack_api import ( + GetConnectorRequest, +) +from llama_stack_api.openai_responses import ( + OpenAIResponseInputToolFileSearch, + OpenAIResponseInputToolFunction, + OpenAIResponseInputToolMCP, + OpenAIResponseInputToolWebSearch, + OpenAIResponseMessage, + WebSearchToolTypes, +) +from llama_stack_api.tools import ListToolDefsResponse, ToolDef, ToolInvocationResult +from llama_stack_api.vector_io import ( + VectorStoreContent, + VectorStoreSearchResponse, + VectorStoreSearchResponsePage, +) +from tests.unit.providers.responses.builtin.test_openai_responses_helpers import fake_stream + + +async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with a simple string input and tools.""" + # Setup + input_text = "What is the capital of Ireland?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + openai_responses_impl.tool_groups_api.get_tool.return_value = ToolDef( + name="web_search", + toolgroup_id="web_search", + description="Search the web for information", + input_schema={ + "type": "object", + "properties": {"query": {"type": "string", "description": "The query to search for"}}, + "required": ["query"], + }, + ) + + openai_responses_impl.tool_runtime_api.invoke_tool.return_value = ToolInvocationResult( + status="completed", + content="Dublin", + ) + + # Execute + for tool_name in WebSearchToolTypes: + # Reset mock states as we loop through each tool type + mock_inference_api.openai_chat_completion.side_effect = [ + fake_stream("tool_call_completion.yaml"), + fake_stream(), + ] + openai_responses_impl.tool_groups_api.get_tool.reset_mock() + openai_responses_impl.tool_runtime_api.invoke_tool.reset_mock() + openai_responses_impl.responses_store.upsert_response_object.reset_mock() + + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + temperature=0.1, + tools=[ + OpenAIResponseInputToolWebSearch( + name=tool_name, + ) + ], + ) + + # Verify + first_call = mock_inference_api.openai_chat_completion.call_args_list[0] + first_params = first_call.args[0] + assert first_params.messages[0].content == "What is the capital of Ireland?" 
+ assert first_params.tools is not None + assert first_params.temperature == 0.1 + + second_call = mock_inference_api.openai_chat_completion.call_args_list[1] + second_params = second_call.args[0] + assert second_params.messages[-1].content == "Dublin" + assert second_params.temperature == 0.1 + + openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search") + openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with( + tool_name="web_search", + kwargs={"query": "What is the capital of Ireland?"}, + ) + + openai_responses_impl.responses_store.upsert_response_object.assert_called() + + # Check that we got the content from our mocked tool execution result + assert len(result.output) >= 1 + assert isinstance(result.output[1], OpenAIResponseMessage) + assert result.output[1].content[0].text == "Dublin" + assert result.output[1].content[0].annotations == [] + + +async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with a tool call response that has a type of None.""" + # Setup + input_text = "How hot it is in San Francisco today?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + async def fake_stream_toolcall(): + yield ChatCompletionChunk( + id="123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="tc_123", + function=ChoiceDeltaToolCallFunction(name="get_weather", arguments="{}"), + type=None, + ) + ] + ), + ), + ], + created=1, + model=model, + object="chat.completion.chunk", + ) + + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + + # Execute + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=True, + temperature=0.1, + tools=[ + OpenAIResponseInputToolFunction( + name="get_weather", + description="Get current temperature for a given location.", + parameters={ + "location": "string", + }, + ) + ], + ) + + # Check that we got the content from our mocked tool execution result + chunks = [chunk async for chunk in result] + + # Verify event types + # Should have: response.created, response.in_progress, output_item.added, + # function_call_arguments.delta, function_call_arguments.done, output_item.done, response.completed + assert len(chunks) == 7 + + event_types = [chunk.type for chunk in chunks] + assert event_types == [ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.delta", + "response.function_call_arguments.done", + "response.output_item.done", + "response.completed", + ] + + # Verify inference API was called correctly (after iterating over result) + first_call = mock_inference_api.openai_chat_completion.call_args_list[0] + first_params = first_call.args[0] + assert first_params.messages[0].content == input_text + assert first_params.tools is not None + assert first_params.temperature == 0.1 + + # Check response.created event (should have empty output) + assert len(chunks[0].response.output) == 0 + + # Check response.completed event (should have the tool call) + 
completed_chunk = chunks[-1] + assert completed_chunk.type == "response.completed" + assert len(completed_chunk.response.output) == 1 + assert completed_chunk.response.output[0].type == "function_call" + assert completed_chunk.response.output[0].name == "get_weather" + + +async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api): + """Test creating an OpenAI response with tool calls that omit arguments.""" + + input_text = "What is the time right now?" + model = "meta-llama/Llama-3.1-8B-Instruct" + + async def fake_stream_toolcall(): + yield ChatCompletionChunk( + id="123", + choices=[ + Choice( + index=0, + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="tc_123", + function=ChoiceDeltaToolCallFunction(name="get_current_time", arguments=None), + type=None, + ) + ] + ), + ), + ], + created=1, + model=model, + object="chat.completion.chunk", + ) + + def assert_common_expectations(chunks) -> None: + first_call = mock_inference_api.openai_chat_completion.call_args_list[0] + first_params = first_call.args[0] + assert first_params.messages[0].content == input_text + assert first_params.tools is not None + assert first_params.temperature == 0.1 + assert len(chunks[0].response.output) == 0 + completed_chunk = chunks[-1] + assert completed_chunk.type == "response.completed" + assert len(completed_chunk.response.output) == 1 + assert completed_chunk.response.output[0].type == "function_call" + assert completed_chunk.response.output[0].name == "get_current_time" + assert completed_chunk.response.output[0].arguments == "{}" + + # Function does not accept arguments + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=True, + temperature=0.1, + tools=[ + OpenAIResponseInputToolFunction( + name="get_current_time", description="Get current time for system's 
timezone", parameters={} + ) + ], + ) + chunks = [chunk async for chunk in result] + assert [chunk.type for chunk in chunks] == [ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.done", + "response.output_item.done", + "response.completed", + ] + assert_common_expectations(chunks) + + # Function accepts optional arguments + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=True, + temperature=0.1, + tools=[ + OpenAIResponseInputToolFunction( + name="get_current_time", + description="Get current time for system's timezone", + parameters={"timezone": "string"}, + ) + ], + ) + chunks = [chunk async for chunk in result] + assert [chunk.type for chunk in chunks] == [ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.done", + "response.output_item.done", + "response.completed", + ] + assert_common_expectations(chunks) + + # Function accepts optional arguments with additional optional fields + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + result = await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=True, + temperature=0.1, + tools=[ + OpenAIResponseInputToolFunction( + name="get_current_time", + description="Get current time for system's timezone", + parameters={"timezone": "string", "location": "string"}, + ) + ], + ) + chunks = [chunk async for chunk in result] + assert [chunk.type for chunk in chunks] == [ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.function_call_arguments.done", + "response.output_item.done", + "response.completed", + ] + assert_common_expectations(chunks) + mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall() + + 
+@patch("llama_stack.providers.inline.responses.builtin.responses.streaming.list_mcp_tools") +async def test_reuse_mcp_tool_list( + mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api +): + """Test that mcp_list_tools can be reused where appropriate.""" + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + mock_list_mcp_tools.return_value = ListToolDefsResponse( + data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})] + ) + + res1 = await openai_responses_impl.create_openai_response( + input="What is 2+2?", + model="meta-llama/Llama-3.1-8B-Instruct", + store=True, + tools=[ + OpenAIResponseInputToolFunction(name="fake", parameters=None), + OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"), + ], + ) + args = mock_responses_store.upsert_response_object.call_args + data = args.kwargs["response_object"].model_dump() + data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]] + data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]] + stored = _OpenAIResponseObjectWithInputAndMessages(**data) + mock_responses_store.get_response_object.return_value = stored + + res2 = await openai_responses_impl.create_openai_response( + previous_response_id=res1.id, + input="Now what is 3+3?", + model="meta-llama/Llama-3.1-8B-Instruct", + store=True, + tools=[ + OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"), + ], + ) + assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2 + second_call = mock_inference_api.openai_chat_completion.call_args_list[1] + second_params = second_call.args[0] + tools_seen = second_params.tools + assert len(tools_seen) == 1 + assert tools_seen[0]["function"]["name"] == "test_tool" + assert tools_seen[0]["function"]["description"] == "a test tool" + + assert mock_list_mcp_tools.call_count == 1 + listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"] + 
assert len(listings) == 1 + assert listings[0].server_label == "alabel" + assert len(listings[0].tools) == 1 + assert listings[0].tools[0].name == "test_tool" + + +@patch("llama_stack.providers.inline.responses.builtin.responses.streaming.list_mcp_tools") +async def test_mcp_tool_connector_id_resolved_to_server_url( + mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api, mock_connectors_api +): + """Test that connector_id is resolved to server_url when using MCP tools.""" + from llama_stack_api import Connector, ConnectorType + + # Setup mock connector that will be returned when resolving connector_id + mock_connector = Connector( + connector_id="my-mcp-connector", + connector_type=ConnectorType.MCP, + url="http://resolved-mcp-server:8080/mcp", + server_label="Resolved MCP Server", + ) + mock_connectors_api.get_connector.return_value = mock_connector + + mock_inference_api.openai_chat_completion.return_value = fake_stream() + mock_list_mcp_tools.return_value = ListToolDefsResponse( + data=[ToolDef(name="resolved_tool", description="a resolved tool", input_schema={}, output_schema={})] + ) + + # Create a response using connector_id instead of server_url + result = await openai_responses_impl.create_openai_response( + input="Test connector resolution", + model="meta-llama/Llama-3.1-8B-Instruct", + store=True, + tools=[ + OpenAIResponseInputToolMCP(server_label="my-label", connector_id="my-mcp-connector"), + ], + ) + + # Verify the connector_id was resolved via the connectors API + mock_connectors_api.get_connector.assert_called_once_with(GetConnectorRequest(connector_id="my-mcp-connector")) + + # Verify list_mcp_tools was called with the resolved URL + mock_list_mcp_tools.assert_called_once() + call_kwargs = mock_list_mcp_tools.call_args.kwargs + assert call_kwargs["endpoint"] == "http://resolved-mcp-server:8080/mcp" + + # Verify the response contains the resolved tools + listings = [obj for obj in result.output if obj.type == 
"mcp_list_tools"] + assert len(listings) == 1 + assert listings[0].server_label == "my-label" + assert len(listings[0].tools) == 1 + assert listings[0].tools[0].name == "resolved_tool" + + +async def test_file_search_results_include_chunk_metadata_attributes(mock_vector_io_api): + """Test that file_search tool executor preserves chunk metadata attributes.""" + query = "What is machine learning?" + vector_store_id = "test_vector_store" + + # Mock vector_io to return search results with custom attributes + mock_vector_io_api.openai_search_vector_store.return_value = VectorStoreSearchResponsePage( + search_query=[query], + data=[ + VectorStoreSearchResponse( + file_id="doc-123", + filename="ml-intro.md", + content=[VectorStoreContent(type="text", text="Machine learning is a subset of AI")], + score=0.95, + attributes={ + "document_id": "ml-intro", + "source_url": "https://example.com/ml-guide", + "title": "Introduction to ML", + "author": "John Doe", + "year": "2024", + }, + ), + VectorStoreSearchResponse( + file_id="doc-456", + filename="dl-basics.md", + content=[VectorStoreContent(type="text", text="Deep learning uses neural networks")], + score=0.85, + attributes={ + "document_id": "dl-basics", + "source_url": "https://example.com/dl-guide", + "title": "Deep Learning Basics", + "category": "tutorial", + }, + ), + ], + ) + + # Create tool executor with mock vector_io + tool_executor = ToolExecutor( + tool_groups_api=None, # type: ignore + tool_runtime_api=None, # type: ignore + vector_io_api=mock_vector_io_api, + vector_stores_config=VectorStoresConfig(), + mcp_session_manager=None, + ) + + # Execute the file search + file_search_tool = OpenAIResponseInputToolFileSearch(vector_store_ids=[vector_store_id]) + result = await tool_executor._execute_file_search_via_vector_store( + query=query, + response_file_search_tool=file_search_tool, + ) + + mock_vector_io_api.openai_search_vector_store.assert_called_once() + + # Verify the result metadata includes chunk attributes 
+ assert result.metadata is not None + assert "attributes" in result.metadata + attributes = result.metadata["attributes"] + assert len(attributes) == 2 + + # Verify first result has all expected attributes + attrs1 = attributes[0] + assert attrs1["document_id"] == "ml-intro" + assert attrs1["source_url"] == "https://example.com/ml-guide" + assert attrs1["title"] == "Introduction to ML" + assert attrs1["author"] == "John Doe" + assert attrs1["year"] == "2024" + + # Verify second result has its attributes + attrs2 = attributes[1] + assert attrs2["document_id"] == "dl-basics" + assert attrs2["source_url"] == "https://example.com/dl-guide" + assert attrs2["title"] == "Deep Learning Basics" + assert attrs2["category"] == "tutorial" + + # Verify scores and document_ids are also present + assert result.metadata["scores"] == [0.95, 0.85] + assert result.metadata["document_ids"] == ["doc-123", "doc-456"] + assert result.metadata["chunks"] == [ + "Machine learning is a subset of AI", + "Deep learning uses neural networks", + ] + + +async def test_function_tool_strict_field_excluded_when_none(openai_responses_impl, mock_inference_api): + """Test that function tool 'strict' field is excluded when None (fix for #4617).""" + input_text = "What is the weather?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + # Mock inference response + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute with function tool that has strict=None (default) + await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=False, + tools=[ + OpenAIResponseInputToolFunction( + type="function", + name="get_weather", + description="Get weather information", + parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, + # strict is None by default + ) + ], + ) + + # Verify the call was made + assert mock_inference_api.openai_chat_completion.call_count == 1 + params = mock_inference_api.openai_chat_completion.call_args[0][0] + + # Verify tools were passed + assert params.tools is not None + assert len(params.tools) == 1 + + # Critical: verify 'strict' field is NOT present when it's None + # This prevents "strict: null" from being sent to OpenAI API + tool_function = params.tools[0]["function"] + assert "strict" not in tool_function, ( + "strict field should be excluded when None to avoid OpenAI API validation error" + ) + + # Verify other fields are present + assert tool_function["name"] == "get_weather" + assert tool_function["description"] == "Get weather information" + assert tool_function["parameters"] is not None + + +async def test_function_tool_strict_field_included_when_set(openai_responses_impl, mock_inference_api): + """Test that function tool 'strict' field is included when explicitly set.""" + input_text = "What is the weather?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + # Mock inference response + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute with function tool that has strict=True + await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=False, + tools=[ + OpenAIResponseInputToolFunction( + type="function", + name="get_weather", + description="Get weather information", + parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, + strict=True, # Explicitly set to True + ) + ], + ) + + # Verify the call was made + assert mock_inference_api.openai_chat_completion.call_count == 1 + params = mock_inference_api.openai_chat_completion.call_args[0][0] + + # Verify tools were passed + assert params.tools is not None + assert len(params.tools) == 1 + + # Verify 'strict' field IS present when explicitly set + tool_function = params.tools[0]["function"] + assert "strict" in tool_function, "strict field should be included when explicitly set" + assert tool_function["strict"] is True, "strict field should have the correct value" + + # Verify other fields are present + assert tool_function["name"] == "get_weather" + assert tool_function["description"] == "Get weather information" + + +async def test_function_tool_strict_false_included(openai_responses_impl, mock_inference_api): + """Test that function tool 'strict' field is included when set to False.""" + input_text = "What is the weather?" 
+ model = "meta-llama/Llama-3.1-8B-Instruct" + + # Mock inference response + mock_inference_api.openai_chat_completion.return_value = fake_stream() + + # Execute with function tool that has strict=False + await openai_responses_impl.create_openai_response( + input=input_text, + model=model, + stream=False, + tools=[ + OpenAIResponseInputToolFunction( + type="function", + name="get_weather", + description="Get weather information", + parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, + strict=False, # Explicitly set to False + ) + ], + ) + + # Verify the call was made + assert mock_inference_api.openai_chat_completion.call_count == 1 + params = mock_inference_api.openai_chat_completion.call_args[0][0] + + # Verify 'strict' field IS present and set to False + tool_function = params.tools[0]["function"] + assert "strict" in tool_function, "strict field should be included when explicitly set to False" + assert tool_function["strict"] is False, "strict field should be False" From ea9c9459d3744d9ec1e740b083ac7bf31e3bd2a0 Mon Sep 17 00:00:00 2001 From: skamenan7 Date: Wed, 25 Mar 2026 14:23:47 -0400 Subject: [PATCH 3/4] refactor(scoring): split ifeval_utils.py and extract WORD_LIST data Split 3319-line ifeval_utils.py into: - ifeval_word_list.py: WORD_LIST data literal only (1552L, grandfathered as pure data) - ifeval_support.py: LANGUAGE_CODES, regex constants, utility functions, arg constants (296L) - ifeval_checkers_core.py: Instruction base + 15 checker classes (828L) - ifeval_checkers_format.py: 15 checker classes (703L) - ifeval_utils.py: thin entry point with INSTRUCTION_DICT/LIST registry (96L) Updates pyproject.toml per-file-ignores to point RUF001 suppression at the new files containing Unicode characters instead of the now-thin ifeval_utils.py. External import path preserved: ifeval_scoring_fn.py unchanged. Also removes stale plugin import of deleted test_openai_responses from root conftest.py. 
Signed-off-by: skamenan7 --- conftest.py | 2 - pyproject.toml | 11 +- scripts/check_file_size.py | 2 +- .../basic/utils/ifeval_checkers_core.py | 819 ++++ .../basic/utils/ifeval_checkers_format.py | 698 ++++ .../scoring/basic/utils/ifeval_support.py | 289 ++ .../scoring/basic/utils/ifeval_utils.py | 3291 +---------------- .../scoring/basic/utils/ifeval_word_list.py | 1538 ++++++++ 8 files changed, 3382 insertions(+), 3268 deletions(-) create mode 100644 src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_core.py create mode 100644 src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_format.py create mode 100644 src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py create mode 100644 src/llama_stack/providers/inline/scoring/basic/utils/ifeval_word_list.py diff --git a/conftest.py b/conftest.py index 8ab6b2116e..77817b3952 100644 --- a/conftest.py +++ b/conftest.py @@ -67,8 +67,6 @@ def _looks_like_existing_test_path(arg: object) -> bool: # Import plugins dynamically if running_unit: config.pluginmanager.import_plugin("tests.unit.fixtures") - # Load shared fixtures from openai_responses test file (used by conversations tests) - config.pluginmanager.import_plugin("tests.unit.providers.responses.builtin.test_openai_responses") if running_integration: config.pluginmanager.import_plugin("tests.integration.fixtures.common") diff --git a/pyproject.toml b/pyproject.toml index d42b894ac4..4e1295aef7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -276,7 +276,16 @@ convention = "google" "benchmarking/**/*.py" = ["D101"] # Ignore docstring rules for benchmarking scripts "client-sdks/**/*.py" = ["D101"] # Ignore docstring rules for client SDKs "scripts/**/*.py" = ["D101"] # Ignore docstring rules for scripts -"src/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py" = [ +"src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py" = [ + "RUF001", +] 
+"src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_core.py" = [ + "RUF001", +] +"src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_format.py" = [ + "RUF001", +] +"src/llama_stack/providers/inline/scoring/basic/utils/ifeval_word_list.py" = [ "RUF001", ] "src/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py" = [ diff --git a/scripts/check_file_size.py b/scripts/check_file_size.py index 0a179edc07..6510ce1915 100755 --- a/scripts/check_file_size.py +++ b/scripts/check_file_size.py @@ -27,7 +27,7 @@ GRANDFATHERED_FILES = { "src/llama_stack/providers/inline/responses/builtin/responses/openai_responses.py", "src/llama_stack/providers/inline/responses/builtin/responses/streaming.py", - "src/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py", + "src/llama_stack/providers/inline/scoring/basic/utils/ifeval_word_list.py", # pure data file "src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py", "src/llama_stack/providers/registry/vector_io.py", "src/llama_stack/testing/api_recorder.py", diff --git a/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_core.py b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_core.py new file mode 100644 index 0000000000..5374fe0cfc --- /dev/null +++ b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_core.py @@ -0,0 +1,819 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import random +import re + +import langdetect +from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai +from pythainlp.tokenize import word_tokenize as word_tokenize_thai + +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="scoring") + +from llama_stack.providers.inline.scoring.basic.utils.ifeval_support import ( + _COMPARISON_RELATION, + _CONSTRAINED_RESPONSE_OPTIONS, + _KEYWORD_FREQUENCY, + _LANGUAGES, + _MAX_NUM_SENTENCES, + _NUM_BULLETS, + _NUM_HIGHLIGHTED_SECTIONS, + _NUM_KEYWORDS, + _NUM_PARAGRAPHS, + _NUM_PLACEHOLDERS, + _NUM_SECTIONS, + _NUM_WORDS_LOWER_LIMIT, + _NUM_WORDS_UPPER_LIMIT, + _POSTSCRIPT_MARKER, + _SECTION_SPLITER, + _STARTER_OPTIONS, + count_sentences, + count_words, + count_words_cjk, + generate_keywords, + get_langid, + split_chinese_japanese_hindi, +) + + +class Instruction: + """An instruction template.""" + + def __init__(self, instruction_id): + self.id = instruction_id + + def build_description(self, **kwargs): + raise NotImplementedError("`build_description` not implemented.") + + def get_instruction_args(self): + raise NotImplementedError("`get_instruction_args` not implemented.") + + def get_instruction_args_keys(self): + raise NotImplementedError("`get_instruction_args_keys` not implemented.") + + def check_following(self, value): + raise NotImplementedError("`check_following` not implemented.") + + +class ResponseLanguageChecker(Instruction): + """Check the language of the entire response.""" + + def build_description(self, *, language=None): + """Build the instruction description. + + Args: + language: A string representing the expected language of the response. The + language has to comply to the 97 types defined in + `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows + ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes); + for example, `en` for English, `zh` for Chinese, `fr` for French. 
+ + Returns: + A string representing the instruction description. + """ + self._language = language + if self._language is None: + self._language = random.choice(list(_LANGUAGES.keys())) + + self._description_pattern = ( + "Your ENTIRE response should be in {language} language, no other " + "language is allowed." + ) + return self._description_pattern.format(language=_LANGUAGES[self._language]) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"language": self._language} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["language"] + + def check_following(self, value): + """Check if the language of the entire response follows the instruction. + + Args: + value: A string representing the response. + + Returns: + True if the language of `value` follows instruction; otherwise False. + """ + assert isinstance(value, str) + + try: + return langdetect.detect(value) == self._language + except langdetect.LangDetectException as e: + # Count as instruction is followed. + logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 + return True + + +class NumberOfSentences(Instruction): + """Check the number of sentences.""" + + def build_description(self, *, num_sentences=None, relation=None): + """Build the instruction description. + + Args: + num_sentences: An integer specifying the number of sentences as a + threshold. + relation: A string in (`less than`, `at least`), defining the relational + operator for comparison. + Two relational comparisons are supported for now: + if 'less than', the actual number of sentences < the threshold; + if 'at least', the actual number of sentences >= the threshold. + + Returns: + A string representing the instruction description. + """ + # The number of sentences as a threshold for comparison. 
+ self._num_sentences_threshold = num_sentences + if self._num_sentences_threshold is None or self._num_sentences_threshold < 0: + self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES) + + if relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif relation not in _COMPARISON_RELATION: + raise ValueError( + f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." + ) + else: + self._comparison_relation = relation + + self._description_pattern = "Your response should contain {relation} {num_sentences} sentences." + return self._description_pattern.format( + relation=self._comparison_relation, + num_sentences=self._num_sentences_threshold, + ) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return { + "num_sentences": self._num_sentences_threshold, + "relation": self._comparison_relation, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_sentences", "relation"] + + def check_following(self, value): + """Check if the number of sentences follows the instruction. + + Args: + value: A string representing the response. + + Returns: + True if the response follows the instruction. + + Raise: + ValueError if the string in `instruction_args` is not in + [`less_than`, `at_least`]. 
+ """ + lang = get_langid(value) + if lang == "th": + # Counting Newline also as a new sentence: + num_sentences = sum([len(sent_tokenize_thai(line)) for line in value.splitlines()]) + elif lang in ["zh", "zh-cn", "zh-tw", "ja", "hi"]: + num_sentences = len(list(split_chinese_japanese_hindi(value))) + else: + num_sentences = count_sentences(value) + if self._comparison_relation == _COMPARISON_RELATION[0]: + return num_sentences < self._num_sentences_threshold + elif self._comparison_relation == _COMPARISON_RELATION[1]: + return num_sentences >= self._num_sentences_threshold + + +class PlaceholderChecker(Instruction): + """Check the placeholders in template writing.""" + + def build_description(self, *, num_placeholders=None): + """Build the instruction description. + + Args: + num_placeholders: An integer denoting the minimum number of + placeholders required in the response. + + Returns: + A string representing the instruction description. + """ + self._num_placeholders = num_placeholders + if self._num_placeholders is None or self._num_placeholders < 0: + self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS) + self._description_pattern = ( + "The response must contain at least {num_placeholders} placeholders " + + "represented by square brackets, such as [address]." + ) + return self._description_pattern.format(num_placeholders=self._num_placeholders) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"num_placeholders": self._num_placeholders} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_placeholders"] + + def check_following(self, value): + """Check if the number of placeholders follows the instruction. + + Args: + value: A string representing the response. + + Returns: + True if the actual number of placeholders in the response is greater than + or equal to `num_placeholders`; otherwise, False. 
+ """ + placeholders = re.findall(r"\[.*?\]", value) + num_placeholders = len(placeholders) + return num_placeholders >= self._num_placeholders + + +class BulletListChecker(Instruction): + """Checks the bullet list in the prompt.""" + + def build_description(self, *, num_bullets=None): + """Build the instruction description. + + Args: + num_bullets: An integer specifying the exact number of bullet lists + that is required to appear in the response. + + Returns: + A string representing the instruction description. + """ + self._num_bullets = num_bullets + if self._num_bullets is None or self._num_bullets < 0: + self._num_bullets = random.randint(1, _NUM_BULLETS) + self._description_pattern = ( + "Your answer must contain exactly {num_bullets} bullet points. " + + "Use the markdown bullet points such as:\n" + + "* This is point 1. \n" + + "* This is point 2" + ) + return self._description_pattern.format(num_bullets=self._num_bullets) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"num_bullets": self._num_bullets} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_bullets"] + + def check_following(self, value): + r"""Check if the number of bullet lists meets the requirement. + + Args: + value: A string representing the response. The response is expected to + contain some bullet lists that start with `\*`. + + Returns: + True if the actual number of bullet lists in the response meets the + requirement. 
+ """ + bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE) + bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE) + num_bullet_lists = len(bullet_lists) + len(bullet_lists_2) + return num_bullet_lists == self._num_bullets + + +class ConstrainedResponseChecker(Instruction): + """Checks the constrained response.""" + + def build_description(self): + """Build the instruction description.""" + # A sequence of string(s) representing the options of the expected response. + self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS + self._description_pattern = "Answer with one of the following options: {response_options}" + return self._description_pattern.format(response_options=self._constrained_responses) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response matches the constrained options. + + Args: + value: A string representing the response. + + Returns: + True if the actual response contains one of the options in the constrained + responses; otherwise False. + """ + value = value.strip() + for constrained_response in self._constrained_responses: + if constrained_response in value: + return True + return False + + +class ConstrainedStartChecker(Instruction): + """Checks the response start.""" + + def build_description(self, *, starter=None): + """Build the instruction description. + + Args: + starter: A string representing the keyward that the response should start + with. + + Returns: + A string representing the instruction description. 
+ """ + self._starter = starter.strip() if isinstance(starter, str) else starter + if self._starter is None: + self._starter = random.choice(_STARTER_OPTIONS) + self._description_pattern = ( + "During the conversation, when it is your turn, " + "please always start with {starter}" + ) + return self._description_pattern.format(starter=self._starter) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"starter": self._starter} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["starter"] + + def check_following(self, value): + """Checks if the response starts with the constrained keyword or phrase. + + Args: + value: A string representing the response. + + Returns: + True if the response starts with the given phrase or keyword that is + contained in `instruction_args`; otherwise, False. + """ + response_pattern = r"^\s*" + self._starter + r".*$" + response_with_constrained_start = re.search(response_pattern, value, flags=re.MULTILINE) + return True if response_with_constrained_start else False + + +class HighlightSectionChecker(Instruction): + """Checks the highlighted section.""" + + def build_description(self, *, num_highlights=None): + """Build the instruction description. + + Args: + num_highlights: An integer specifying the minimum number of highlighted + sections. + + Returns: + A string representing the instruction description. + """ + self._num_highlights = num_highlights + if self._num_highlights is None or self._num_highlights < 0: + self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS) + + self._description_pattern = ( + "Highlight at least {num_highlights} sections in your answer with " + + "markdown, i.e. *highlighted section*." 
+ ) + + return self._description_pattern.format(num_highlights=self._num_highlights) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"num_highlights": self._num_highlights} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_highlights"] + + def check_following(self, value): + """Checks if the number of highlighted sections meets the requirement. + + Args: + value: a string repesenting the response. The response is expected to + contain highlighted sections in the format of *highlighted*. + + Returns: + True if the actual number of highlighted sections in the format of + *highlighed sections* meets the minimum requirement; otherwise False. + """ + num_highlights = 0 + highlights = re.findall(r"\*[^\n\*]*\*", value) + double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value) + for highlight in highlights: + if highlight.strip("*").strip(): + num_highlights += 1 + for highlight in double_highlights: + if highlight.removeprefix("**").removesuffix("**").strip(): + num_highlights += 1 + + return num_highlights >= self._num_highlights + + +class SectionChecker(Instruction): + """Checks the sections.""" + + def build_description(self, *, section_spliter=None, num_sections=None): + """Build the instruction description. + + Args: + section_spliter: A string represents the section spliter keyword that + marks a new section, i.e., `Section` or `SECTION`. + num_sections: An integer specifying the number of sections. + + Returns: + A string representing the instruction description. 
+ """ + self._section_spliter = section_spliter.strip() if isinstance(section_spliter, str) else section_spliter + if self._section_spliter is None: + self._section_spliter = random.choice(_SECTION_SPLITER) + + self._num_sections = num_sections + if self._num_sections is None or self._num_sections < 0: + self._num_sections = random.randint(1, _NUM_SECTIONS) + + self._description_pattern = ( + "Your response must have {num_sections} sections. Mark the beginning " + + "of each section with {section_spliter} X, such as:\n" + + "{section_spliter} 1\n" + + "[content of section 1]\n" + + "{section_spliter} 2\n" + + "[content of section 2]" + ) + + return self._description_pattern.format(num_sections=self._num_sections, section_spliter=self._section_spliter) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return { + "section_spliter": self._section_spliter, + "num_sections": self._num_sections, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["section_spliter", "num_sections"] + + def check_following(self, value): + """Checks the response contains multiple sections. + + Args: + value: A string representing the response. The response is expected + to contain multiple sections (number of sections is greater than 1). + A new section starts with `Section 1`, where the number denotes the + section index. + + Returns: + True if the number of sections in the response is greater than or equal to + the minimum number of sections; otherwise, False. + """ + section_splitter_patten = r"\s?" + self._section_spliter + r"\s?\d+\s?" + sections = re.split(section_splitter_patten, value) + num_sections = len(sections) - 1 + return num_sections >= self._num_sections + + +class ParagraphChecker(Instruction): + """Checks the paragraphs.""" + + def build_description(self, *, num_paragraphs=None): + """Build the instruction description. 
+ + Args: + num_paragraphs: An integer specifying the number of paragraphs. + + Returns: + A string representing the instruction description. + """ + self._num_paragraphs = num_paragraphs + if self._num_paragraphs is None or self._num_paragraphs < 0: + self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) + + self._description_pattern = ( + "There should be {num_paragraphs} paragraphs. " + "Paragraphs are separated with the markdown divider: ***" + ) + + return self._description_pattern.format(num_paragraphs=self._num_paragraphs) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"num_paragraphs": self._num_paragraphs} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_paragraphs"] + + def check_following(self, value): + """Checks the response contains required number of paragraphs. + + Args: + value: A string representing the response. The response may contain + paragraphs that are separated by the markdown divider: `***`. + + Returns: + True if the actual number of paragraphs is the same as required; + otherwise, False. + """ + paragraphs = re.split(r"\s?\*\*\*\s?", value) + num_paragraphs = len(paragraphs) + + for index, paragraph in enumerate(paragraphs): + if not paragraph.strip(): + if index == 0 or index == len(paragraphs) - 1: + num_paragraphs -= 1 + else: + return False + + return num_paragraphs == self._num_paragraphs + + +class PostscriptChecker(Instruction): + """Checks the postscript.""" + + def build_description(self, *, postscript_marker=None): + """Build the instruction description. + + Args: + postscript_marker: A string containing the keyword that marks the start + of the postscript section. + + Returns: + A string representing the instruction description. 
+ """ + self._postscript_marker = postscript_marker.strip() if isinstance(postscript_marker, str) else postscript_marker + if self._postscript_marker is None: + self._postscript_marker = random.choice(_POSTSCRIPT_MARKER) + + self._description_pattern = ( + "At the end of your response, please explicitly add a postscript " + "starting with {postscript}" + ) + + return self._description_pattern.format(postscript=self._postscript_marker) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"postscript_marker": self._postscript_marker} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["postscript_marker"] + + def check_following(self, value): + """Checks if the response follows the postscript format. + + Args: + value: a string representing the response. The response is expected to + contain a postscript section. + + Returns: + True if the response contains a postscript section starting with + the keyword containing in the `instruction_args`; otherwise False. + """ + value = value.lower() + if self._postscript_marker == "P.P.S": + postscript_pattern = r"\s*p\.\s?p\.\s?s.*$" + elif self._postscript_marker == "P.S.": + postscript_pattern = r"\s*p\.\s?s\..*$" + else: + postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$" + postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE) + return True if postscript else False + + +class RephraseChecker(Instruction): + """Checks the repharse.""" + + def build_description(self, *, original_message): + """Build the instruction description. + + Args: + original_message: A string representing the original message. The + rephrased response should only change its words/sentences in between + its two asterisks, for example, *change me*. Both original and rephrased + messages should contain the changes in the form of *change me*. + + Returns: + A string representing the instruction description. 
+ """ + if not self.is_change(original_message): + raise ValueError(f"Message {original_message} does not contain changes in the form of *change me*.") + + self._reference_without_change = original_message + self._description = ( + "Rephrasing: Your rephrased response should only" + + "change the words/sentences in between two asterisks" + + "such as *change me*." + ) + return self._description + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"original_message": self._reference_without_change} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["original_message"] + + def check_following(self, value): + r"""Checks if the rephrasing follows the instruction. + + Args: + value: A string representing the response, which is expected to rephras + the string of `instruction_args`. + + Returns: + True if `value` and `instruction_args` only differ by the words/sentences + in between two asterisks such as *change me*; otherwise, False. + """ + + if not self.is_change(value): + raise ValueError(f"value {value} does not contain changes in the form of *change me*.") + + response_without_changes = self.strip_changes(value) + reference_without_changes = self.strip_changes(self._reference_without_change) + + return response_without_changes == reference_without_changes + + def is_change(self, response): + """Check if there is change in the response in the form of *change me*.""" + return re.search(r"\*.*\*", response) + + def strip_changes(self, response): + """Strips off the changes.""" + return re.sub(r"\*.*\*", "", response) + + +class KeywordChecker(Instruction): + """Check the exisitence of certain keywords.""" + + def build_description(self, *, keywords=None): + """Build the instruction description. + + Args: + keywords: A sequence of strings representing the keywords that are + expected in the response. + + Returns: + A string representing the instruction description. 
+ """ + + if not keywords: + self._keywords = generate_keywords(num_keywords=_NUM_KEYWORDS) + else: + self._keywords = keywords + self._keywords = sorted(self._keywords) + + self._description_pattern = "Include keywords {keywords} in the response." + + return self._description_pattern.format(keywords=self._keywords) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"keywords": self._keywords} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["keywords"] + + def check_following(self, value): + """Check if the response contain the expected keywords.""" + for keyword in self._keywords: + if not re.search(keyword, value, flags=re.IGNORECASE): + return False + return True + + +class KeywordFrequencyChecker(Instruction): + """Check the keyword frequency.""" + + def build_description(self, *, keyword=None, frequency=None, relation=None): + """Build the instruction description. + + Args: + keyword: A string representing a keyword that is expected in the response. + frequency: An integer specifying the number of times `keyword` is expected + to appear in the response. + relation: A string in (`less than`, `at least`), defining the relational + operator for comparison. + Two relational comparisons are supported for now: + if 'less than', the actual number of occurrences < frequency; + if 'at least', the actual number of occurrences >= frequency. + + Returns: + A string representing the instruction description. 
+ """ + if not keyword: + self._keyword = generate_keywords(num_keywords=1)[0] + else: + self._keyword = keyword.strip() + + self._frequency = frequency + if self._frequency is None or self._frequency < 0: + self._frequency = random.randint(1, _KEYWORD_FREQUENCY) + + if relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif relation not in _COMPARISON_RELATION: + raise ValueError( + f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." + ) + else: + self._comparison_relation = relation + + self._description_pattern = ( + "In your response, the word {keyword} should appear {relation} " + "{frequency} times." + ) + + return self._description_pattern.format( + keyword=self._keyword, + relation=self._comparison_relation, + frequency=self._frequency, + ) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return { + "keyword": self._keyword, + "frequency": self._frequency, + "relation": self._comparison_relation, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["keyword", "frequency", "relation"] + + def check_following(self, value): + """Checks if the response contain the keyword with required frequency.""" + actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return actual_occurrences < self._frequency + elif self._comparison_relation == _COMPARISON_RELATION[1]: + return actual_occurrences >= self._frequency + + +class NumberOfWords(Instruction): + """Checks the number of words.""" + + def build_description(self, *, num_words=None, relation=None): + """Build the instruction description. + + Args: + num_words: An integer specifying the number of words contained in the + response. + relation: A string in (`less than`, `at least`), defining the relational + operator for comparison. 
+ Two relational comparisons are supported for now: + if 'less than', the actual number of words < num_words; + if 'at least', the actual number of words >= num_words. + + Returns: + A string representing the instruction description. + """ + + self._num_words = num_words + if self._num_words is None or self._num_words < 0: + self._num_words = random.randint(_NUM_WORDS_LOWER_LIMIT, _NUM_WORDS_UPPER_LIMIT) + + if relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif relation not in _COMPARISON_RELATION: + raise ValueError( + f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." + ) + else: + self._comparison_relation = relation + + self._description_pattern = "Answer with {relation} {num_words} words." + + return self._description_pattern.format(relation=self._comparison_relation, num_words=self._num_words) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"num_words": self._num_words, "relation": self._comparison_relation} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["num_words", "relation"] + + def check_following(self, value): + """Checks if the response contains the expected number of words.""" + lang = get_langid(value) + if lang == "th": + num_words = len(word_tokenize_thai(value)) + elif lang in ["zh", "zh-cn", "zh-tw", "ja", "ko"]: + num_words = count_words_cjk(value) + else: + num_words = count_words(value) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return num_words < self._num_words + elif self._comparison_relation == _COMPARISON_RELATION[1]: + return num_words >= self._num_words diff --git a/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_format.py b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_format.py new file mode 100644 index 0000000000..e874562a15 --- /dev/null +++ 
b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_checkers_format.py @@ -0,0 +1,698 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import collections +import json +import random +import re +import string + +import langdetect +import nltk + +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="scoring") + +from llama_stack.providers.inline.scoring.basic.utils.ifeval_checkers_core import Instruction +from llama_stack.providers.inline.scoring.basic.utils.ifeval_support import ( + _ALL_CAPITAL_WORD_FREQUENCY, + _COMPARISON_RELATION, + _ENDING_OPTIONS, + _LETTER_FREQUENCY, + _NUM_KEYWORDS, + _NUM_PARAGRAPHS, + generate_keywords, + get_langid, + split_into_sentences, +) + + +class JsonFormat(Instruction): + """Check the Json format.""" + + def build_description(self): + self._description_pattern = ( + "Entire output should be wrapped in JSON format. You can use markdown ticks such as ```." + ) + return self._description_pattern + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + value = ( + value.strip() + .removeprefix("```json") + .removeprefix("```Json") + .removeprefix("```JSON") + .removeprefix("```") + .removesuffix("```") + .strip() + ) + try: + json.loads(value) + except ValueError as _: + return False + return True + + +class ParagraphFirstWordCheck(Instruction): + """Check the paragraph and the first word of the nth paragraph.""" + + def build_description(self, num_paragraphs=None, nth_paragraph=None, first_word=None): + r"""Build the instruction description. 
+
+        Args:
+            num_paragraphs: An integer indicating the number of paragraphs expected
+                in the response. A paragraph is a subset of the string that is
+                expected to be separated by '\n\n'.
+            nth_paragraph: An integer indicating the paragraph number that we look at.
+                Note that n starts from 1.
+            first_word: A string that represent the first word of the nth paragraph.
+
+        Returns:
+            A string representing the instruction description.
+        """
+        self._num_paragraphs = num_paragraphs
+        if self._num_paragraphs is None or self._num_paragraphs < 0:
+            self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS)
+
+        self._nth_paragraph = nth_paragraph
+        if self._nth_paragraph is None or self._nth_paragraph <= 0 or self._nth_paragraph > self._num_paragraphs:
+            self._nth_paragraph = random.randint(1, self._num_paragraphs)
+
+        self._first_word = first_word
+        if self._first_word is None:
+            self._first_word = generate_keywords(num_keywords=1)[0]
+        self._first_word = self._first_word.lower()
+
+        self._description_pattern = (
+            "There should be {num_paragraphs} paragraphs. "
+            + "Paragraphs and only paragraphs are separated with each other by two "
+            + "new lines as if it was '\\n\\n' in python. "
+            + "Paragraph {nth_paragraph} must start with word {first_word}."
+        )
+
+        return self._description_pattern.format(
+            num_paragraphs=self._num_paragraphs,
+            nth_paragraph=self._nth_paragraph,
+            first_word=self._first_word,
+        )
+
+    def get_instruction_args(self):
+        """Returns the keyward args of `build_description`."""
+        return {
+            "num_paragraphs": self._num_paragraphs,
+            "nth_paragraph": self._nth_paragraph,
+            "first_word": self._first_word,
+        }
+
+    def get_instruction_args_keys(self):
+        """Returns the args keys of `build_description`."""
+        return ["num_paragraphs", "nth_paragraph", "first_word"]
+
+    def check_following(self, value):
+        """Checks for required number of paragraphs and correct first word.
+
+        Args:
+            value: a string representing the response.
The response may contain + paragraphs that are separated by two new lines and the first word of + the nth paragraph will have to match a specified word. + + Returns: + True if the number of paragraphs is the same as required and the first + word of the specified paragraph is the same as required. Otherwise, false. + """ + + paragraphs = re.split(r"\n\n", value) + num_paragraphs = len(paragraphs) + + for paragraph in paragraphs: + if not paragraph.strip(): + num_paragraphs -= 1 + + # check that index doesn't go out of bounds + if self._nth_paragraph <= num_paragraphs: + paragraph = paragraphs[self._nth_paragraph - 1].strip() + if not paragraph: + return False + else: + return False + + first_word = "" + punctuation = {".", ",", "?", "!", "'", '"'} + + # get first word and remove punctuation + word = paragraph.split()[0].strip() + word = word.lstrip("'") + word = word.lstrip('"') + + for letter in word: + if letter in punctuation: + break + first_word += letter.lower() + + return num_paragraphs == self._num_paragraphs and first_word == self._first_word + + +class KeySentenceChecker(Instruction): + """Check the existence of certain key sentences.""" + + def build_description(self, key_sentences=None, num_sentences=None): + """Build the instruction description. + + Args: + key_sentences: A sequences of strings representing the key sentences that + are expected in the response. + num_sentences: The number of key sentences that are expected to be seen in + the response. + + Returns: + A string representing the instruction description. 
+        """
+
+        if not key_sentences:
+            self._key_sentences = {"For now, this is fine."}
+        else:
+            self._key_sentences = key_sentences
+
+        if not num_sentences:
+            self._num_sentences = random.randint(1, len(self._key_sentences))
+        else:
+            self._num_sentences = num_sentences
+
+        self._description_pattern = "Include {num_sentences} of the following sentences {key_sentences}"
+
+        return self._description_pattern.format(num_sentences=self._num_sentences, key_sentences=self._key_sentences)
+
+    def get_instruction_args(self):
+        """Returns the keyward args of `build_description`."""
+        return {
+            "num_sentences": self._num_sentences,
+            "key_sentences": list(self._key_sentences),
+        }
+
+    def get_instruction_args_keys(self):
+        """Returns the args keys of `build_description`."""
+        return ["num_sentences", "key_sentences"]
+
+    def check_following(self, value):
+        """Checks if the response contains the expected key sentences."""
+        count = 0
+        sentences = split_into_sentences(value)
+        for sentence in self._key_sentences:
+            if sentence in sentences:
+                count += 1
+
+        return count == self._num_sentences
+
+
+class ForbiddenWords(Instruction):
+    """Checks that specified words are not used in response."""
+
+    def build_description(self, forbidden_words=None):
+        """Build the instruction description.
+
+        Args:
+            forbidden_words: A sequences of strings respresenting words that are not
+                allowed in the response.
+
+        Returns:
+            A string representing the instruction description.
+        """
+
+        if not forbidden_words:
+            self._forbidden_words = generate_keywords(num_keywords=_NUM_KEYWORDS)
+        else:
+            self._forbidden_words = list(set(forbidden_words))
+        self._forbidden_words = sorted(self._forbidden_words)
+        self._description_pattern = "Do not include keywords {forbidden_words} in the response."
+ + return self._description_pattern.format(forbidden_words=self._forbidden_words) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return {"forbidden_words": self._forbidden_words} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["forbidden_words"] + + def check_following(self, value): + """Check if the response does not contain the expected keywords.""" + for word in self._forbidden_words: + if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE): + return False + return True + + +class RephraseParagraph(Instruction): + """Checks that the paragraph is rephrased.""" + + def build_description(self, *, original_paragraph, low, high): + """Builds the instruction description. + + Args: + original_paragraph: A string presenting the original paragraph. The + rephrases response should have betweeb low-high words in common. + low: An integer presenting the lower bound of similar words. + high: An integer representing the upper bound of similar words. + + Returns: + A string representing the instruction description. + """ + self._original_paragraph = original_paragraph + self._low = low + self._high = high + + self._description = ( + "Rephrase the following paragraph: " + + "{original_paragraph}\nYour response should have " + + "between {low} and {high} of the same words. " + + "Words are the same if and only if all of the " + + "letters, ignoring cases, are the same. For " + + "example, 'run' is the same as 'Run' but different " + + "to 'ran'." 
+ ) + + return self._description.format(original_paragraph=original_paragraph, low=self._low, high=self._high) + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return { + "original_paragraph": self._original_paragraph, + "low": self._low, + "high": self._high, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["original_paragraph", "low", "high"] + + def check_following(self, value): + val_words = re.findall(r"\w+", value.lower()) + original_words = re.findall(r"\w+", self._original_paragraph.lower()) + similar_words = 0 + + dict_val = collections.Counter(val_words) + dict_original = collections.Counter(original_words) + + for word in dict_original: + similar_words += min(dict_original[word], dict_val[word]) + + return similar_words >= self._low and similar_words <= self._high + + +class TwoResponsesChecker(Instruction): + """Check that two responses were given.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Give two different responses. Responses and only responses should" + " be separated by 6 asterisk symbols: ******." + ) + return self._description_pattern + + def get_instruction_args(self): + """Returns the keyward args of `build_description`.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response has two different answers. + + Args: + value: A string representing the response. + + Returns: + True if two responses are detected and false otherwise. 
+ """ + valid_responses = list() + responses = value.split("******") + for index, response in enumerate(responses): + if not response.strip(): + if index != 0 and index != len(responses) - 1: + return False + else: + valid_responses.append(response) + return len(valid_responses) == 2 and valid_responses[0].strip() != valid_responses[1].strip() + + +class RepeatPromptThenAnswer(Instruction): + """Checks that Prompt is first repeated then answered.""" + + def build_description(self, *, prompt_to_repeat=None): + """Build the instruction description. + + Args: + prompt_to_repeat: The prompt that is meant to be repeated. + + Returns: + A string representing the instruction description. + """ + if not prompt_to_repeat: + raise ValueError("prompt_to_repeat must be set.") + else: + self._prompt_to_repeat = prompt_to_repeat + self._description_pattern = ( + "First repeat the request word for word without change," + " then give your answer (1. do not say any words or characters" + " before repeating the request; 2. the request you need to repeat" + " does not include this sentence)" + ) + return self._description_pattern + + def get_instruction_args(self): + return {"prompt_to_repeat": self._prompt_to_repeat} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["prompt_to_repeat"] + + def check_following(self, value): + if value.strip().lower().startswith(self._prompt_to_repeat.strip().lower()): + return True + return False + + +class EndChecker(Instruction): + """Checks that the prompt ends with a given phrase.""" + + def build_description(self, *, end_phrase=None): + """Build the instruction description. + + Args: + end_phrase: A string representing the phrase the response should end with. + + Returns: + A string representing the instruction description. 
+ """ + self._end_phrase = end_phrase.strip() if isinstance(end_phrase, str) else end_phrase + if self._end_phrase is None: + self._end_phrase = random.choice(_ENDING_OPTIONS) + self._description_pattern = ( + "Finish your response with this exact phrase {ender}. No other words should follow this phrase." + ) + return self._description_pattern.format(ender=self._end_phrase) + + def get_instruction_args(self): + return {"end_phrase": self._end_phrase} + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["end_phrase"] + + def check_following(self, value): + """Checks if the response ends with the expected phrase.""" + value = value.strip().strip('"').lower() + self._end_phrase = self._end_phrase.strip().lower() + return value.endswith(self._end_phrase) + + +class TitleChecker(Instruction): + """Checks the response for a title.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Your answer must contain a title, wrapped in double angular brackets, such as <>." + ) + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response contains a title.""" + pattern = r"<<[^\n]+>>" + re_pattern = re.compile(pattern) + titles = re.findall(re_pattern, value) + + for title in titles: + if title.lstrip("<").rstrip(">").strip(): + return True + return False + + +class LetterFrequencyChecker(Instruction): + """Checks letter frequency.""" + + def build_description(self, *, letter=None, let_frequency=None, let_relation=None): + """Build the instruction description. + + Args: + letter: A string representing a letter that is expected in the response. + let_frequency: An integer specifying the number of times `keyword` is + expected to appear in the response. 
+ let_relation: A string in (`less than`, `at least`), defining the + relational operator for comparison. Two relational comparisons are + supported for now; if 'less than', the actual number of + occurrences < frequency; if 'at least', the actual number of + occurrences >= frequency. + + Returns: + A string representing the instruction description. + """ + if not letter or len(letter) > 1 or ord(letter.lower()) < 97 or ord(letter.lower()) > 122: + self._letter = random.choice(list(string.ascii_letters)) + else: + self._letter = letter.strip() + self._letter = self._letter.lower() + + self._frequency = let_frequency + if self._frequency is None or self._frequency < 0: + self._frequency = random.randint(1, _LETTER_FREQUENCY) + + if let_relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif let_relation not in _COMPARISON_RELATION: + raise ValueError( + f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {let_relation} is given." + ) + else: + self._comparison_relation = let_relation + + self._description_pattern = ( + "In your response, the letter {letter} should appear {let_relation} {let_frequency} times." 
+ ) + + return self._description_pattern.format( + letter=self._letter, + let_frequency=self._frequency, + let_relation=self._comparison_relation, + ) + + def get_instruction_args(self): + """Returns the keyword args of build description.""" + return { + "letter": self._letter, + "let_frequency": self._frequency, + "let_relation": self._comparison_relation, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["letter", "let_frequency", "let_relation"] + + def check_following(self, value): + """Checks that the response contains the letter at the right frequency.""" + value = value.lower() + letters = collections.Counter(value) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return letters[self._letter] < self._frequency + else: + return letters[self._letter] >= self._frequency + + +class CapitalLettersEnglishChecker(Instruction): + """Checks that the response is in english and is in all capital letters.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = "Your entire response should be in English, and in all capital letters." + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks that the response is in English and in all capital letters.""" + assert isinstance(value, str) + + try: + return value.isupper() and langdetect.detect(value) == "en" + except langdetect.LangDetectException as e: + # Count as instruction is followed. 
+ logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 + return True + + +class LowercaseLettersEnglishChecker(Instruction): + """Checks that the response is in english and is in all lowercase letters.""" + + def build_description(self): + """Build the instruction description.""" + self._description_pattern = ( + "Your entire response should be in English, and in all lowercase letters. No capital letters are allowed." + ) + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks that the response is in English and in all lowercase letters.""" + assert isinstance(value, str) + + try: + return value.islower() and langdetect.detect(value) == "en" + except langdetect.LangDetectException as e: + # Count as instruction is followed. + logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 + return True + + +class CommaChecker(Instruction): + """Checks the response for no commas.""" + + def build_description(self, **kwargs): + """Build the instruction description.""" + self._description_pattern = "In your entire response, refrain from the use of any commas." + return self._description_pattern + + def get_instruction_args(self): + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks that the response does not contain commas.""" + return not re.search(r"\,", value) + + +class CapitalWordFrequencyChecker(Instruction): + """Checks frequency of words with all capital letters.""" + + def build_description( + self, + capital_frequency=None, + capital_relation=None, + ): + """Build the instruction description. 
+ + Args: + capital_frequency: An integer that represents the number of words that + should be in all capital letters. + capital_relation: A string that is 'at least' or 'at most' that refers to + the frequency. + + Returns: + A string representing the instruction description. + """ + self._frequency = capital_frequency + if self._frequency is None: + self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY) + + self._comparison_relation = capital_relation + if capital_relation is None: + self._comparison_relation = random.choice(_COMPARISON_RELATION) + elif capital_relation not in _COMPARISON_RELATION: + raise ValueError( + "The supported relation for comparison must be in " + f"{_COMPARISON_RELATION}, but {capital_relation} is given." + ) + + self._description_pattern = ( + "In your response, words with all capital letters should appear {relation} {frequency} times." + ) + + return self._description_pattern.format(frequency=self._frequency, relation=self._comparison_relation) + + def get_instruction_args(self): + """Returns the keyword args of build description.""" + return { + "capital_frequency": self._frequency, + "capital_relation": self._comparison_relation, + } + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return ["capital_frequency", "capital_relation"] + + def check_following(self, value): + """Checks the frequency of words with all capital letters.""" + # Hyphenated words will count as one word + nltk.download("punkt_tab") + words = nltk.word_tokenize(value) + capital_words = [word for word in words if word.isupper()] + + capital_words = len(capital_words) + + if self._comparison_relation == _COMPARISON_RELATION[0]: + return capital_words < self._frequency + else: + return capital_words >= self._frequency + + +class QuotationChecker(Instruction): + """Checks response is wrapped with double quotation marks.""" + + def build_description(self): + """Build the instruction description.""" + 
self._description_pattern = "Wrap your entire response with double quotation marks." + return self._description_pattern + + def get_instruction_args(self): + """Returns the keyword args of build description.""" + return None + + def get_instruction_args_keys(self): + """Returns the args keys of `build_description`.""" + return [] + + def check_following(self, value): + """Checks if the response is wrapped with double quotation marks.""" + quotations_map = { + "ja": "「」", + "ru": "«»", + "th": "“”", + "zh": "“”", + "zh-cn": "“”", + "zh-tw": "“”", + } + value = value.strip() + lang = get_langid(value) + quotes = quotations_map.get(lang, '""') + # TODO: We may wanna revisit this logic in new generations to only check of the response language's quotes. + return len(value) > 1 and value[0] in [quotes[0], '"'] and value[-1] in [quotes[1], '"'] + + +# Define instruction dicts diff --git a/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py new file mode 100644 index 0000000000..ff347175e1 --- /dev/null +++ b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py @@ -0,0 +1,289 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import collections +import functools +import random +import re +from collections.abc import Iterable, Sequence +from types import MappingProxyType + +import emoji +import langdetect +import nltk + +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="scoring") + +from llama_stack.providers.inline.scoring.basic.utils.ifeval_word_list import WORD_LIST + +# ISO 639-1 codes to language names. 
+LANGUAGE_CODES = MappingProxyType( + { + "en": "English", + "es": "Spanish", + "pt": "Portuguese", + "ar": "Arabic", + "hi": "Hindi", + "fr": "French", + "ru": "Russian", + "de": "German", + "ja": "Japanese", + "it": "Italian", + "bn": "Bengali", + "uk": "Ukrainian", + "th": "Thai", + "ur": "Urdu", + "ta": "Tamil", + "te": "Telugu", + "bg": "Bulgarian", + "ko": "Korean", + "pl": "Polish", + "he": "Hebrew", + "fa": "Persian", + "vi": "Vietnamese", + "ne": "Nepali", + "sw": "Swahili", + "kn": "Kannada", + "mr": "Marathi", + "gu": "Gujarati", + "pa": "Punjabi", + "ml": "Malayalam", + "fi": "Finnish", + } +) + +# Chinese characters +_CHINESE_CHARS_PATTERN = r"[\u4E00-\u9FFF\u3400-\u4DBF]" +# Japanese Hiragana & Katakana +_JAPANESE_CHARS_PATTERN = r"[\u3040-\u309f\u30a0-\u30ff]" +# Korean (Hangul Syllables) +_KOREAN_CHARS_PATTERN = r"[\uAC00-\uD7AF]" +_ALPHABETS = "([A-Za-z])" +_PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]" +_SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)" +_STARTERS = ( + r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)" +) +_ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)" +_WEBSITES = "[.](com|net|org|io|gov|edu|me)" +_DIGITS = "([0-9])" +_MULTIPLE_DOTS = r"\.{2,}" + + +# Util functions +def split_into_sentences(text): + """Split the text into sentences. + + Args: + text: A string that consists of more than or equal to one sentences. + + Returns: + A list of strings where each string is a sentence. + """ + text = " " + text + " " + text = text.replace("\n", " ") + text = re.sub(_PREFIXES, "\\1", text) + text = re.sub(_WEBSITES, "\\1", text) + text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1\\2", text) + text = re.sub( + _MULTIPLE_DOTS, + lambda match: "" * len(match.group(0)) + "", + text, + ) + if "Ph.D" in text: + text = text.replace("Ph.D.", "PhD") + text = re.sub(r"\s" + _ALPHABETS + "[.] 
", " \\1 ", text) + text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1 \\2", text) + text = re.sub( + _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]", + "\\1\\2\\3", + text, + ) + text = re.sub(_ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1\\2", text) + text = re.sub(" " + _SUFFIXES + "[.] " + _STARTERS, " \\1 \\2", text) + text = re.sub(" " + _SUFFIXES + "[.]", " \\1", text) + text = re.sub(" " + _ALPHABETS + "[.]", " \\1", text) + if "”" in text: + text = text.replace(".”", "”.") + if '"' in text: + text = text.replace('."', '".') + if "!" in text: + text = text.replace('!"', '"!') + if "?" in text: + text = text.replace('?"', '"?') + text = text.replace(".", ".") + text = text.replace("?", "?") + text = text.replace("!", "!") + text = text.replace("", ".") + sentences = text.split("") + sentences = [s.strip() for s in sentences] + if sentences and not sentences[-1]: + sentences = sentences[:-1] + return sentences + + +def count_words(text): + """Counts the number of words.""" + tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+") + tokens = tokenizer.tokenize(text) + num_words = len(tokens) + return num_words + + +def split_chinese_japanese_hindi(lines: str) -> Iterable[str]: + """ + Split Chinese and Japanese text into sentences. + From https://stackoverflow.com/questions/27441191/splitting-chinese-document-into-sentences + Special question/exclamation marks were added upon inspection of our raw data, + Also supports multiple lines. + The separator for hindi is '।' + """ + for line in lines.splitlines(): + yield from re.findall( + r"[^!?。\.\!\?\!\?\.\n।]+[!?。\.\!\?\!\?\.\n।]?", + line.strip(), + flags=re.U, + ) + + +def count_words_cjk(text: str) -> int: + """Counts the number of words for Chinese and Japanese and Korean. + Can be extended to additional languages. 
+ Source: https://stackoverflow.com/questions/49164507/how-to-count-the-number-of-chinese-korean-and-english-words withadditional modifications + Example: + >In: count_words_cjk('こんにちは、ジェイソンさん、Jason? Nice to meet you☺ ❤') + >Out: 19 + """ + # Non alpha numeric patterns in latin and asian languages. + non_alphanumeric_patterns = ( + r"[\\.\!\?\.\/_,\{\}<>:;$%^&*(+\"\'+——!,。?、`~@#¥……():;《)《》“”()\[\]«»〔〕\-「」]+" + ) + text = re.sub(non_alphanumeric_patterns, "", text) + + emoji_cnt = emoji.emoji_count(text) # count emojis + text = emoji.replace_emoji(text, "") # remove emojis + + foreign_chars_patterns = "|".join([_CHINESE_CHARS_PATTERN, _JAPANESE_CHARS_PATTERN, _KOREAN_CHARS_PATTERN]) + asian_chars = re.findall(foreign_chars_patterns, text) + asian_chars_cnt = len(asian_chars) + non_asian_chars = re.sub(foreign_chars_patterns, " ", text) + non_asian_words_cnt = len(non_asian_chars.split()) + + return non_asian_words_cnt + asian_chars_cnt + emoji_cnt + + +@functools.cache +def _get_sentence_tokenizer(): + return nltk.data.load("nltk:tokenizers/punkt/english.pickle") + + +def count_sentences(text): + """Count the number of sentences.""" + tokenizer = _get_sentence_tokenizer() + tokenized_sentences = tokenizer.tokenize(text) + return len(tokenized_sentences) + + +def get_langid(text: str, lid_path: str | None = None) -> str: + line_langs: list[str] = [] + lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4] + + for line in lines: + try: + line_langs.append(langdetect.detect(line)) + except langdetect.LangDetectException as e: + logger.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037 + + if len(line_langs) == 0: + return "en" + # select the text language to be the most commonly predicted language of the lines. 
+ return collections.Counter(line_langs).most_common(1)[0][0] + + +def generate_keywords(num_keywords): + """Randomly generates a few keywords.""" + return random.sample(WORD_LIST, k=num_keywords) + + +"""Library of instructions""" +_InstructionArgsDtype = dict[str, int | str | Sequence[str]] | None + +_LANGUAGES = LANGUAGE_CODES + +# The relational operation for comparison. +_COMPARISON_RELATION = ("less than", "at least") + +# The maximum number of sentences. +_MAX_NUM_SENTENCES = 20 + +# The number of placeholders. +_NUM_PLACEHOLDERS = 4 + +# The number of bullet lists. +_NUM_BULLETS = 5 + +# The options of constrained response. +_CONSTRAINED_RESPONSE_OPTIONS = ( + "My answer is yes.", + "My answer is no.", + "My answer is maybe.", +) + +# The options of starter keywords. +_STARTER_OPTIONS = ( + "I would say", + "My answer is", + "I believe", + "In my opinion", + "I think", + "I reckon", + "I feel", + "From my perspective", + "As I see it", + "According to me", + "As far as I'm concerned", + "To my understanding", + "In my view", + "My take on it is", + "As per my perception", +) + +# The options of ending keywords. +# TODO(jeffreyzhou) add more ending options +_ENDING_OPTIONS = ("Any other questions?", "Is there anything else I can help with?") + +# The number of highlighted sections. +_NUM_HIGHLIGHTED_SECTIONS = 4 + +# The section spliter. +_SECTION_SPLITER = ("Section", "SECTION") + +# The number of sections. +_NUM_SECTIONS = 5 + +# The number of paragraphs. +_NUM_PARAGRAPHS = 5 + +# The postscript marker. +_POSTSCRIPT_MARKER = ("P.S.", "P.P.S") + +# The number of keywords. +_NUM_KEYWORDS = 2 + +# The occurrences of a single keyword. +_KEYWORD_FREQUENCY = 3 + +# The occurrences of a single letter. +_LETTER_FREQUENCY = 10 + +# The occurrences of words with all capital letters. +_ALL_CAPITAL_WORD_FREQUENCY = 20 + +# The number of words in the response. 
+_NUM_WORDS_LOWER_LIMIT = 100 +_NUM_WORDS_UPPER_LIMIT = 500 diff --git a/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py index 886488ec18..acce13ea22 100644 --- a/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +++ b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py @@ -4,3273 +4,36 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import collections -import functools -import json -import random -import re -import string -from collections.abc import Iterable, Sequence -from types import MappingProxyType - -import emoji -import langdetect -import nltk -from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai -from pythainlp.tokenize import word_tokenize as word_tokenize_thai - -from llama_stack.log import get_logger - -logger = get_logger(name=__name__, category="scoring") - -WORD_LIST = [ - "western", - "sentence", - "signal", - "dump", - "spot", - "opposite", - "bottom", - "potato", - "administration", - "working", - "welcome", - "morning", - "good", - "agency", - "primary", - "wish", - "responsibility", - "press", - "problem", - "president", - "steal", - "brush", - "read", - "type", - "beat", - "trainer", - "growth", - "lock", - "bone", - "case", - "equal", - "comfortable", - "region", - "replacement", - "performance", - "mate", - "walk", - "medicine", - "film", - "thing", - "rock", - "tap", - "total", - "competition", - "ease", - "south", - "establishment", - "gather", - "parking", - "world", - "plenty", - "breath", - "claim", - "alcohol", - "trade", - "dear", - "highlight", - "street", - "matter", - "decision", - "mess", - "agreement", - "studio", - "coach", - "assist", - "brain", - "wing", - "style", - "private", - "top", - "brown", - "leg", - "buy", - "procedure", - "method", - "speed", - "high", - "company", - "valuable", - "pie", - "analyst", 
- "session", - "pattern", - "district", - "pleasure", - "dinner", - "swimming", - "joke", - "order", - "plate", - "department", - "motor", - "cell", - "spend", - "cabinet", - "difference", - "power", - "examination", - "engine", - "horse", - "dimension", - "pay", - "toe", - "curve", - "literature", - "bother", - "fire", - "possibility", - "debate", - "activity", - "passage", - "hello", - "cycle", - "background", - "quiet", - "author", - "effect", - "actor", - "page", - "bicycle", - "error", - "throat", - "attack", - "character", - "phone", - "tea", - "increase", - "outcome", - "file", - "specific", - "inspector", - "internal", - "potential", - "staff", - "building", - "employer", - "shoe", - "hand", - "direction", - "garden", - "purchase", - "interview", - "study", - "recognition", - "member", - "spiritual", - "oven", - "sandwich", - "weird", - "passenger", - "particular", - "response", - "reaction", - "size", - "variation", - "a", - "cancel", - "candy", - "exit", - "guest", - "condition", - "fly", - "price", - "weakness", - "convert", - "hotel", - "great", - "mouth", - "mind", - "song", - "sugar", - "suspect", - "telephone", - "ear", - "roof", - "paint", - "refrigerator", - "organization", - "jury", - "reward", - "engineering", - "day", - "possession", - "crew", - "bar", - "road", - "description", - "celebration", - "score", - "mark", - "letter", - "shower", - "suggestion", - "sir", - "luck", - "national", - "progress", - "hall", - "stroke", - "theory", - "offer", - "story", - "tax", - "definition", - "history", - "ride", - "medium", - "opening", - "glass", - "elevator", - "stomach", - "question", - "ability", - "leading", - "village", - "computer", - "city", - "grand", - "confidence", - "candle", - "priest", - "recommendation", - "point", - "necessary", - "body", - "desk", - "secret", - "horror", - "noise", - "culture", - "warning", - "water", - "round", - "diet", - "flower", - "bus", - "tough", - "permission", - "week", - "prompt", - "connection", - "abuse", - 
"height", - "save", - "corner", - "border", - "stress", - "drive", - "stop", - "rip", - "meal", - "listen", - "confusion", - "girlfriend", - "living", - "relation", - "significance", - "plan", - "creative", - "atmosphere", - "blame", - "invite", - "housing", - "paper", - "drink", - "roll", - "silver", - "drunk", - "age", - "damage", - "smoke", - "environment", - "pack", - "savings", - "influence", - "tourist", - "rain", - "post", - "sign", - "grandmother", - "run", - "profit", - "push", - "clerk", - "final", - "wine", - "swim", - "pause", - "stuff", - "singer", - "funeral", - "average", - "source", - "scene", - "tradition", - "personal", - "snow", - "nobody", - "distance", - "sort", - "sensitive", - "animal", - "major", - "negotiation", - "click", - "mood", - "period", - "arrival", - "expression", - "holiday", - "repeat", - "dust", - "closet", - "gold", - "bad", - "sail", - "combination", - "clothes", - "emphasis", - "duty", - "black", - "step", - "school", - "jump", - "document", - "professional", - "lip", - "chemical", - "front", - "wake", - "while", - "inside", - "watch", - "row", - "subject", - "penalty", - "balance", - "possible", - "adult", - "aside", - "sample", - "appeal", - "wedding", - "depth", - "king", - "award", - "wife", - "blow", - "site", - "camp", - "music", - "safe", - "gift", - "fault", - "guess", - "act", - "shame", - "drama", - "capital", - "exam", - "stupid", - "record", - "sound", - "swing", - "novel", - "minimum", - "ratio", - "machine", - "shape", - "lead", - "operation", - "salary", - "cloud", - "affair", - "hit", - "chapter", - "stage", - "quantity", - "access", - "army", - "chain", - "traffic", - "kick", - "analysis", - "airport", - "time", - "vacation", - "philosophy", - "ball", - "chest", - "thanks", - "place", - "mountain", - "advertising", - "red", - "past", - "rent", - "return", - "tour", - "house", - "construction", - "net", - "native", - "war", - "figure", - "fee", - "spray", - "user", - "dirt", - "shot", - "task", - "stick", - 
"friend", - "software", - "promotion", - "interaction", - "surround", - "block", - "purpose", - "practice", - "conflict", - "routine", - "requirement", - "bonus", - "hole", - "state", - "junior", - "sweet", - "catch", - "tear", - "fold", - "wall", - "editor", - "life", - "position", - "pound", - "respect", - "bathroom", - "coat", - "script", - "job", - "teach", - "birth", - "view", - "resolve", - "theme", - "employee", - "doubt", - "market", - "education", - "serve", - "recover", - "tone", - "harm", - "miss", - "union", - "understanding", - "cow", - "river", - "association", - "concept", - "training", - "recipe", - "relationship", - "reserve", - "depression", - "proof", - "hair", - "revenue", - "independent", - "lift", - "assignment", - "temporary", - "amount", - "loss", - "edge", - "track", - "check", - "rope", - "estimate", - "pollution", - "stable", - "message", - "delivery", - "perspective", - "mirror", - "assistant", - "representative", - "witness", - "nature", - "judge", - "fruit", - "tip", - "devil", - "town", - "emergency", - "upper", - "drop", - "stay", - "human", - "neck", - "speaker", - "network", - "sing", - "resist", - "league", - "trip", - "signature", - "lawyer", - "importance", - "gas", - "choice", - "engineer", - "success", - "part", - "external", - "worker", - "simple", - "quarter", - "student", - "heart", - "pass", - "spite", - "shift", - "rough", - "lady", - "grass", - "community", - "garage", - "youth", - "standard", - "skirt", - "promise", - "blind", - "television", - "disease", - "commission", - "positive", - "energy", - "calm", - "presence", - "tune", - "basis", - "preference", - "head", - "common", - "cut", - "somewhere", - "presentation", - "current", - "thought", - "revolution", - "effort", - "master", - "implement", - "republic", - "floor", - "principle", - "stranger", - "shoulder", - "grade", - "button", - "tennis", - "police", - "collection", - "account", - "register", - "glove", - "divide", - "professor", - "chair", - "priority", - 
"combine", - "peace", - "extension", - "maybe", - "evening", - "frame", - "sister", - "wave", - "code", - "application", - "mouse", - "match", - "counter", - "bottle", - "half", - "cheek", - "resolution", - "back", - "knowledge", - "make", - "discussion", - "screw", - "length", - "accident", - "battle", - "dress", - "knee", - "log", - "package", - "it", - "turn", - "hearing", - "newspaper", - "layer", - "wealth", - "profile", - "imagination", - "answer", - "weekend", - "teacher", - "appearance", - "meet", - "bike", - "rise", - "belt", - "crash", - "bowl", - "equivalent", - "support", - "image", - "poem", - "risk", - "excitement", - "remote", - "secretary", - "public", - "produce", - "plane", - "display", - "money", - "sand", - "situation", - "punch", - "customer", - "title", - "shake", - "mortgage", - "option", - "number", - "pop", - "window", - "extent", - "nothing", - "experience", - "opinion", - "departure", - "dance", - "indication", - "boy", - "material", - "band", - "leader", - "sun", - "beautiful", - "muscle", - "farmer", - "variety", - "fat", - "handle", - "director", - "opportunity", - "calendar", - "outside", - "pace", - "bath", - "fish", - "consequence", - "put", - "owner", - "go", - "doctor", - "information", - "share", - "hurt", - "protection", - "career", - "finance", - "force", - "golf", - "garbage", - "aspect", - "kid", - "food", - "boot", - "milk", - "respond", - "objective", - "reality", - "raw", - "ring", - "mall", - "one", - "impact", - "area", - "news", - "international", - "series", - "impress", - "mother", - "shelter", - "strike", - "loan", - "month", - "seat", - "anything", - "entertainment", - "familiar", - "clue", - "year", - "glad", - "supermarket", - "natural", - "god", - "cost", - "conversation", - "tie", - "ruin", - "comfort", - "earth", - "storm", - "percentage", - "assistance", - "budget", - "strength", - "beginning", - "sleep", - "other", - "young", - "unit", - "fill", - "store", - "desire", - "hide", - "value", - "cup", - 
"maintenance", - "nurse", - "function", - "tower", - "role", - "class", - "camera", - "database", - "panic", - "nation", - "basket", - "ice", - "art", - "spirit", - "chart", - "exchange", - "feedback", - "statement", - "reputation", - "search", - "hunt", - "exercise", - "nasty", - "notice", - "male", - "yard", - "annual", - "collar", - "date", - "platform", - "plant", - "fortune", - "passion", - "friendship", - "spread", - "cancer", - "ticket", - "attitude", - "island", - "active", - "object", - "service", - "buyer", - "bite", - "card", - "face", - "steak", - "proposal", - "patient", - "heat", - "rule", - "resident", - "broad", - "politics", - "west", - "knife", - "expert", - "girl", - "design", - "salt", - "baseball", - "grab", - "inspection", - "cousin", - "couple", - "magazine", - "cook", - "dependent", - "security", - "chicken", - "version", - "currency", - "ladder", - "scheme", - "kitchen", - "employment", - "local", - "attention", - "manager", - "fact", - "cover", - "sad", - "guard", - "relative", - "county", - "rate", - "lunch", - "program", - "initiative", - "gear", - "bridge", - "breast", - "talk", - "dish", - "guarantee", - "beer", - "vehicle", - "reception", - "woman", - "substance", - "copy", - "lecture", - "advantage", - "park", - "cold", - "death", - "mix", - "hold", - "scale", - "tomorrow", - "blood", - "request", - "green", - "cookie", - "church", - "strip", - "forever", - "beyond", - "debt", - "tackle", - "wash", - "following", - "feel", - "maximum", - "sector", - "sea", - "property", - "economics", - "menu", - "bench", - "try", - "language", - "start", - "call", - "solid", - "address", - "income", - "foot", - "senior", - "honey", - "few", - "mixture", - "cash", - "grocery", - "link", - "map", - "form", - "factor", - "pot", - "model", - "writer", - "farm", - "winter", - "skill", - "anywhere", - "birthday", - "policy", - "release", - "husband", - "lab", - "hurry", - "mail", - "equipment", - "sink", - "pair", - "driver", - "consideration", - 
"leather", - "skin", - "blue", - "boat", - "sale", - "brick", - "two", - "feed", - "square", - "dot", - "rush", - "dream", - "location", - "afternoon", - "manufacturer", - "control", - "occasion", - "trouble", - "introduction", - "advice", - "bet", - "eat", - "kill", - "category", - "manner", - "office", - "estate", - "pride", - "awareness", - "slip", - "crack", - "client", - "nail", - "shoot", - "membership", - "soft", - "anybody", - "web", - "official", - "individual", - "pizza", - "interest", - "bag", - "spell", - "profession", - "queen", - "deal", - "resource", - "ship", - "guy", - "chocolate", - "joint", - "formal", - "upstairs", - "car", - "resort", - "abroad", - "dealer", - "associate", - "finger", - "surgery", - "comment", - "team", - "detail", - "crazy", - "path", - "tale", - "initial", - "arm", - "radio", - "demand", - "single", - "draw", - "yellow", - "contest", - "piece", - "quote", - "pull", - "commercial", - "shirt", - "contribution", - "cream", - "channel", - "suit", - "discipline", - "instruction", - "concert", - "speech", - "low", - "effective", - "hang", - "scratch", - "industry", - "breakfast", - "lay", - "join", - "metal", - "bedroom", - "minute", - "product", - "rest", - "temperature", - "many", - "give", - "argument", - "print", - "purple", - "laugh", - "health", - "credit", - "investment", - "sell", - "setting", - "lesson", - "egg", - "middle", - "marriage", - "level", - "evidence", - "phrase", - "love", - "self", - "benefit", - "guidance", - "affect", - "you", - "dad", - "anxiety", - "special", - "boyfriend", - "test", - "blank", - "payment", - "soup", - "obligation", - "reply", - "smile", - "deep", - "complaint", - "addition", - "review", - "box", - "towel", - "minor", - "fun", - "soil", - "issue", - "cigarette", - "internet", - "gain", - "tell", - "entry", - "spare", - "incident", - "family", - "refuse", - "branch", - "can", - "pen", - "grandfather", - "constant", - "tank", - "uncle", - "climate", - "ground", - "volume", - "communication", 
- "kind", - "poet", - "child", - "screen", - "mine", - "quit", - "gene", - "lack", - "charity", - "memory", - "tooth", - "fear", - "mention", - "marketing", - "reveal", - "reason", - "court", - "season", - "freedom", - "land", - "sport", - "audience", - "classroom", - "law", - "hook", - "win", - "carry", - "eye", - "smell", - "distribution", - "research", - "country", - "dare", - "hope", - "whereas", - "stretch", - "library", - "if", - "delay", - "college", - "plastic", - "book", - "present", - "use", - "worry", - "champion", - "goal", - "economy", - "march", - "election", - "reflection", - "midnight", - "slide", - "inflation", - "action", - "challenge", - "guitar", - "coast", - "apple", - "campaign", - "field", - "jacket", - "sense", - "way", - "visual", - "remove", - "weather", - "trash", - "cable", - "regret", - "buddy", - "beach", - "historian", - "courage", - "sympathy", - "truck", - "tension", - "permit", - "nose", - "bed", - "son", - "person", - "base", - "meat", - "usual", - "air", - "meeting", - "worth", - "game", - "independence", - "physical", - "brief", - "play", - "raise", - "board", - "she", - "key", - "writing", - "pick", - "command", - "party", - "yesterday", - "spring", - "candidate", - "physics", - "university", - "concern", - "development", - "change", - "string", - "target", - "instance", - "room", - "bitter", - "bird", - "football", - "normal", - "split", - "impression", - "wood", - "long", - "meaning", - "stock", - "cap", - "leadership", - "media", - "ambition", - "fishing", - "essay", - "salad", - "repair", - "today", - "designer", - "night", - "bank", - "drawing", - "inevitable", - "phase", - "vast", - "chip", - "anger", - "switch", - "cry", - "twist", - "personality", - "attempt", - "storage", - "being", - "preparation", - "bat", - "selection", - "white", - "technology", - "contract", - "side", - "section", - "station", - "till", - "structure", - "tongue", - "taste", - "truth", - "difficulty", - "group", - "limit", - "main", - "move", - 
"feeling", - "light", - "example", - "mission", - "might", - "wait", - "wheel", - "shop", - "host", - "classic", - "alternative", - "cause", - "agent", - "consist", - "table", - "airline", - "text", - "pool", - "craft", - "range", - "fuel", - "tool", - "partner", - "load", - "entrance", - "deposit", - "hate", - "article", - "video", - "summer", - "feature", - "extreme", - "mobile", - "hospital", - "flight", - "fall", - "pension", - "piano", - "fail", - "result", - "rub", - "gap", - "system", - "report", - "suck", - "ordinary", - "wind", - "nerve", - "ask", - "shine", - "note", - "line", - "mom", - "perception", - "brother", - "reference", - "bend", - "charge", - "treat", - "trick", - "term", - "homework", - "bake", - "bid", - "status", - "project", - "strategy", - "orange", - "let", - "enthusiasm", - "parent", - "concentrate", - "device", - "travel", - "poetry", - "business", - "society", - "kiss", - "end", - "vegetable", - "employ", - "schedule", - "hour", - "brave", - "focus", - "process", - "movie", - "illegal", - "general", - "coffee", - "ad", - "highway", - "chemistry", - "psychology", - "hire", - "bell", - "conference", - "relief", - "show", - "neat", - "funny", - "weight", - "quality", - "club", - "daughter", - "zone", - "touch", - "tonight", - "shock", - "burn", - "excuse", - "name", - "survey", - "landscape", - "advance", - "satisfaction", - "bread", - "disaster", - "item", - "hat", - "prior", - "shopping", - "visit", - "east", - "photo", - "home", - "idea", - "father", - "comparison", - "cat", - "pipe", - "winner", - "count", - "lake", - "fight", - "prize", - "foundation", - "dog", - "keep", - "ideal", - "fan", - "struggle", - "peak", - "safety", - "solution", - "hell", - "conclusion", - "population", - "strain", - "alarm", - "measurement", - "second", - "train", - "race", - "due", - "insurance", - "boss", - "tree", - "monitor", - "sick", - "course", - "drag", - "appointment", - "slice", - "still", - "care", - "patience", - "rich", - "escape", - 
"emotion", - "royal", - "female", - "childhood", - "government", - "picture", - "will", - "sock", - "big", - "gate", - "oil", - "cross", - "pin", - "improvement", - "championship", - "silly", - "help", - "sky", - "pitch", - "man", - "diamond", - "most", - "transition", - "work", - "science", - "committee", - "moment", - "fix", - "teaching", - "dig", - "specialist", - "complex", - "guide", - "people", - "dead", - "voice", - "original", - "break", - "topic", - "data", - "degree", - "reading", - "recording", - "bunch", - "reach", - "judgment", - "lie", - "regular", - "set", - "painting", - "mode", - "list", - "player", - "bear", - "north", - "wonder", - "carpet", - "heavy", - "officer", - "negative", - "clock", - "unique", - "baby", - "pain", - "assumption", - "disk", - "iron", - "bill", - "drawer", - "look", - "double", - "mistake", - "finish", - "future", - "brilliant", - "contact", - "math", - "rice", - "leave", - "restaurant", - "discount", - "sex", - "virus", - "bit", - "trust", - "event", - "wear", - "juice", - "failure", - "bug", - "context", - "mud", - "whole", - "wrap", - "intention", - "draft", - "pressure", - "cake", - "dark", - "explanation", - "space", - "angle", - "word", - "efficiency", - "management", - "habit", - "star", - "chance", - "finding", - "transportation", - "stand", - "criticism", - "flow", - "door", - "injury", - "insect", - "surprise", - "apartment", -] # pylint: disable=line-too-long - -# ISO 639-1 codes to language names. 
-LANGUAGE_CODES = MappingProxyType( - { - "en": "English", - "es": "Spanish", - "pt": "Portuguese", - "ar": "Arabic", - "hi": "Hindi", - "fr": "French", - "ru": "Russian", - "de": "German", - "ja": "Japanese", - "it": "Italian", - "bn": "Bengali", - "uk": "Ukrainian", - "th": "Thai", - "ur": "Urdu", - "ta": "Tamil", - "te": "Telugu", - "bg": "Bulgarian", - "ko": "Korean", - "pl": "Polish", - "he": "Hebrew", - "fa": "Persian", - "vi": "Vietnamese", - "ne": "Nepali", - "sw": "Swahili", - "kn": "Kannada", - "mr": "Marathi", - "gu": "Gujarati", - "pa": "Punjabi", - "ml": "Malayalam", - "fi": "Finnish", - } +from llama_stack.providers.inline.scoring.basic.utils.ifeval_checkers_core import ( + BulletListChecker, + ConstrainedResponseChecker, + HighlightSectionChecker, + KeywordChecker, + KeywordFrequencyChecker, + NumberOfSentences, + NumberOfWords, + ParagraphChecker, + PlaceholderChecker, + PostscriptChecker, + ResponseLanguageChecker, + SectionChecker, ) - -# Chinese characters -_CHINESE_CHARS_PATTERN = r"[\u4E00-\u9FFF\u3400-\u4DBF]" -# Japanese Hiragana & Katakana -_JAPANESE_CHARS_PATTERN = r"[\u3040-\u309f\u30a0-\u30ff]" -# Korean (Hangul Syllables) -_KOREAN_CHARS_PATTERN = r"[\uAC00-\uD7AF]" -_ALPHABETS = "([A-Za-z])" -_PREFIXES = "(Mr|St|Mrs|Ms|Dr)[.]" -_SUFFIXES = "(Inc|Ltd|Jr|Sr|Co)" -_STARTERS = ( - r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)" +from llama_stack.providers.inline.scoring.basic.utils.ifeval_checkers_format import ( + CapitalLettersEnglishChecker, + CapitalWordFrequencyChecker, + CommaChecker, + EndChecker, + ForbiddenWords, + JsonFormat, + LetterFrequencyChecker, + LowercaseLettersEnglishChecker, + ParagraphFirstWordCheck, + QuotationChecker, + RepeatPromptThenAnswer, + TitleChecker, + TwoResponsesChecker, ) -_ACRONYMS = "([A-Z][.][A-Z][.](?:[A-Z][.])?)" -_WEBSITES = "[.](com|net|org|io|gov|edu|me)" -_DIGITS = "([0-9])" -_MULTIPLE_DOTS = r"\.{2,}" - - -# Util functions -def 
split_into_sentences(text): - """Split the text into sentences. - - Args: - text: A string that consists of more than or equal to one sentences. - - Returns: - A list of strings where each string is a sentence. - """ - text = " " + text + " " - text = text.replace("\n", " ") - text = re.sub(_PREFIXES, "\\1", text) - text = re.sub(_WEBSITES, "\\1", text) - text = re.sub(_DIGITS + "[.]" + _DIGITS, "\\1\\2", text) - text = re.sub( - _MULTIPLE_DOTS, - lambda match: "" * len(match.group(0)) + "", - text, - ) - if "Ph.D" in text: - text = text.replace("Ph.D.", "PhD") - text = re.sub(r"\s" + _ALPHABETS + "[.] ", " \\1 ", text) - text = re.sub(_ACRONYMS + " " + _STARTERS, "\\1 \\2", text) - text = re.sub( - _ALPHABETS + "[.]" + _ALPHABETS + "[.]" + _ALPHABETS + "[.]", - "\\1\\2\\3", - text, - ) - text = re.sub(_ALPHABETS + "[.]" + _ALPHABETS + "[.]", "\\1\\2", text) - text = re.sub(" " + _SUFFIXES + "[.] " + _STARTERS, " \\1 \\2", text) - text = re.sub(" " + _SUFFIXES + "[.]", " \\1", text) - text = re.sub(" " + _ALPHABETS + "[.]", " \\1", text) - if "”" in text: - text = text.replace(".”", "”.") - if '"' in text: - text = text.replace('."', '".') - if "!" in text: - text = text.replace('!"', '"!') - if "?" in text: - text = text.replace('?"', '"?') - text = text.replace(".", ".") - text = text.replace("?", "?") - text = text.replace("!", "!") - text = text.replace("", ".") - sentences = text.split("") - sentences = [s.strip() for s in sentences] - if sentences and not sentences[-1]: - sentences = sentences[:-1] - return sentences - - -def count_words(text): - """Counts the number of words.""" - tokenizer = nltk.tokenize.RegexpTokenizer(r"\w+") - tokens = tokenizer.tokenize(text) - num_words = len(tokens) - return num_words - - -def split_chinese_japanese_hindi(lines: str) -> Iterable[str]: - """ - Split Chinese and Japanese text into sentences. 
- From https://stackoverflow.com/questions/27441191/splitting-chinese-document-into-sentences - Special question/exclamation marks were added upon inspection of our raw data, - Also supports multiple lines. - The separator for hindi is '।' - """ - for line in lines.splitlines(): - yield from re.findall( - r"[^!?。\.\!\?\!\?\.\n।]+[!?。\.\!\?\!\?\.\n।]?", - line.strip(), - flags=re.U, - ) - - -def count_words_cjk(text: str) -> int: - """Counts the number of words for Chinese and Japanese and Korean. - Can be extended to additional languages. - Source: https://stackoverflow.com/questions/49164507/how-to-count-the-number-of-chinese-korean-and-english-words withadditional modifications - Example: - >In: count_words_cjk('こんにちは、ジェイソンさん、Jason? Nice to meet you☺ ❤') - >Out: 19 - """ - # Non alpha numeric patterns in latin and asian languages. - non_alphanumeric_patterns = ( - r"[\\.\!\?\.\/_,\{\}<>:;$%^&*(+\"\'+——!,。?、`~@#¥……():;《)《》“”()\[\]«»〔〕\-「」]+" - ) - text = re.sub(non_alphanumeric_patterns, "", text) - - emoji_cnt = emoji.emoji_count(text) # count emojis - text = emoji.replace_emoji(text, "") # remove emojis - - foreign_chars_patterns = "|".join([_CHINESE_CHARS_PATTERN, _JAPANESE_CHARS_PATTERN, _KOREAN_CHARS_PATTERN]) - asian_chars = re.findall(foreign_chars_patterns, text) - asian_chars_cnt = len(asian_chars) - non_asian_chars = re.sub(foreign_chars_patterns, " ", text) - non_asian_words_cnt = len(non_asian_chars.split()) - - return non_asian_words_cnt + asian_chars_cnt + emoji_cnt - - -@functools.cache -def _get_sentence_tokenizer(): - return nltk.data.load("nltk:tokenizers/punkt/english.pickle") - - -def count_sentences(text): - """Count the number of sentences.""" - tokenizer = _get_sentence_tokenizer() - tokenized_sentences = tokenizer.tokenize(text) - return len(tokenized_sentences) - - -def get_langid(text: str, lid_path: str | None = None) -> str: - """Detect the primary language of a text using per-line language detection. 
- - Args: - text: input text to analyze - lid_path: unused, kept for interface compatibility - - Returns: - ISO 639-1 language code, defaulting to "en" if detection fails - """ - line_langs: list[str] = [] - lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4] - - for line in lines: - try: - line_langs.append(langdetect.detect(line)) - except langdetect.LangDetectException as e: - logger.info("Unable to detect language for text %s due to %s", line, e) # refex: disable=pytotw.037 - - if len(line_langs) == 0: - return "en" - # select the text language to be the most commonly predicted language of the lines. - return collections.Counter(line_langs).most_common(1)[0][0] - - -def generate_keywords(num_keywords): - """Randomly generates a few keywords.""" - return random.sample(WORD_LIST, k=num_keywords) - - -"""Library of instructions""" -_InstructionArgsDtype = dict[str, int | str | Sequence[str]] | None - -_LANGUAGES = LANGUAGE_CODES - -# The relational operation for comparison. -_COMPARISON_RELATION = ("less than", "at least") - -# The maximum number of sentences. -_MAX_NUM_SENTENCES = 20 - -# The number of placeholders. -_NUM_PLACEHOLDERS = 4 - -# The number of bullet lists. -_NUM_BULLETS = 5 - -# The options of constrained response. -_CONSTRAINED_RESPONSE_OPTIONS = ( - "My answer is yes.", - "My answer is no.", - "My answer is maybe.", -) - -# The options of starter keywords. -_STARTER_OPTIONS = ( - "I would say", - "My answer is", - "I believe", - "In my opinion", - "I think", - "I reckon", - "I feel", - "From my perspective", - "As I see it", - "According to me", - "As far as I'm concerned", - "To my understanding", - "In my view", - "My take on it is", - "As per my perception", -) - -# The options of ending keywords. -# TODO(jeffreyzhou) add more ending options -_ENDING_OPTIONS = ("Any other questions?", "Is there anything else I can help with?") - -# The number of highlighted sections. 
-_NUM_HIGHLIGHTED_SECTIONS = 4 - -# The section spliter. -_SECTION_SPLITER = ("Section", "SECTION") - -# The number of sections. -_NUM_SECTIONS = 5 - -# The number of paragraphs. -_NUM_PARAGRAPHS = 5 - -# The postscript marker. -_POSTSCRIPT_MARKER = ("P.S.", "P.P.S") - -# The number of keywords. -_NUM_KEYWORDS = 2 - -# The occurrences of a single keyword. -_KEYWORD_FREQUENCY = 3 - -# The occurrences of a single letter. -_LETTER_FREQUENCY = 10 - -# The occurrences of words with all capital letters. -_ALL_CAPITAL_WORD_FREQUENCY = 20 - -# The number of words in the response. -_NUM_WORDS_LOWER_LIMIT = 100 -_NUM_WORDS_UPPER_LIMIT = 500 - - -class Instruction: - """An instruction template.""" - - def __init__(self, instruction_id): - self.id = instruction_id - - def build_description(self, **kwargs): - raise NotImplementedError("`build_description` not implemented.") - - def get_instruction_args(self): - raise NotImplementedError("`get_instruction_args` not implemented.") - - def get_instruction_args_keys(self): - raise NotImplementedError("`get_instruction_args_keys` not implemented.") - - def check_following(self, value): - raise NotImplementedError("`check_following` not implemented.") - - -class ResponseLanguageChecker(Instruction): - """Check the language of the entire response.""" - - def build_description(self, *, language=None): - """Build the instruction description. - - Args: - language: A string representing the expected language of the response. The - language has to comply to the 97 types defined in - `langid.py` (https://pypi.org/project/langid/1.1.5/), which follows - ISO 639-1 codes (https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes); - for example, `en` for English, `zh` for Chinese, `fr` for French. - - Returns: - A string representing the instruction description. 
- """ - self._language = language - if self._language is None: - self._language = random.choice(list(_LANGUAGES.keys())) - - self._description_pattern = ( - "Your ENTIRE response should be in {language} language, no other " + "language is allowed." - ) - return self._description_pattern.format(language=_LANGUAGES[self._language]) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"language": self._language} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["language"] - - def check_following(self, value): - """Check if the language of the entire response follows the instruction. - - Args: - value: A string representing the response. - - Returns: - True if the language of `value` follows instruction; otherwise False. - """ - assert isinstance(value, str) - - try: - return langdetect.detect(value) == self._language - except langdetect.LangDetectException as e: - # Count as instruction is followed. - logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 - return True - - -class NumberOfSentences(Instruction): - """Check the number of sentences.""" - - def build_description(self, *, num_sentences=None, relation=None): - """Build the instruction description. - - Args: - num_sentences: An integer specifying the number of sentences as a - threshold. - relation: A string in (`less than`, `at least`), defining the relational - operator for comparison. - Two relational comparisons are supported for now: - if 'less than', the actual number of sentences < the threshold; - if 'at least', the actual number of sentences >= the threshold. - - Returns: - A string representing the instruction description. - """ - # The number of sentences as a threshold for comparison. 
- self._num_sentences_threshold = num_sentences - if self._num_sentences_threshold is None or self._num_sentences_threshold < 0: - self._num_sentences_threshold = random.randint(1, _MAX_NUM_SENTENCES) - - if relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." - ) - else: - self._comparison_relation = relation - - self._description_pattern = "Your response should contain {relation} {num_sentences} sentences." - return self._description_pattern.format( - relation=self._comparison_relation, - num_sentences=self._num_sentences_threshold, - ) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "num_sentences": self._num_sentences_threshold, - "relation": self._comparison_relation, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_sentences", "relation"] - - def check_following(self, value): - """Check if the number of sentences follows the instruction. - - Args: - value: A string representing the response. - - Returns: - True if the response follows the instruction. - - Raise: - ValueError if the string in `instruction_args` is not in - [`less_than`, `at_least`]. 
- """ - lang = get_langid(value) - if lang == "th": - # Counting Newline also as a new sentence: - num_sentences = sum([len(sent_tokenize_thai(line)) for line in value.splitlines()]) - elif lang in ["zh", "zh-cn", "zh-tw", "ja", "hi"]: - num_sentences = len(list(split_chinese_japanese_hindi(value))) - else: - num_sentences = count_sentences(value) - if self._comparison_relation == _COMPARISON_RELATION[0]: - return num_sentences < self._num_sentences_threshold - elif self._comparison_relation == _COMPARISON_RELATION[1]: - return num_sentences >= self._num_sentences_threshold - - -class PlaceholderChecker(Instruction): - """Check the placeholders in template writing.""" - - def build_description(self, *, num_placeholders=None): - """Build the instruction description. - - Args: - num_placeholders: An integer denoting the minimum number of - placeholders required in the response. - - Returns: - A string representing the instruction description. - """ - self._num_placeholders = num_placeholders - if self._num_placeholders is None or self._num_placeholders < 0: - self._num_placeholders = random.randint(1, _NUM_PLACEHOLDERS) - self._description_pattern = ( - "The response must contain at least {num_placeholders} placeholders " - + "represented by square brackets, such as [address]." - ) - return self._description_pattern.format(num_placeholders=self._num_placeholders) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_placeholders": self._num_placeholders} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_placeholders"] - - def check_following(self, value): - """Check if the number of placeholders follows the instruction. - - Args: - value: A string representing the response. - - Returns: - True if the actual number of placeholders in the response is greater than - or equal to `num_placeholders`; otherwise, False. 
- """ - placeholders = re.findall(r"\[.*?\]", value) - num_placeholders = len(placeholders) - return num_placeholders >= self._num_placeholders - - -class BulletListChecker(Instruction): - """Checks the bullet list in the prompt.""" - - def build_description(self, *, num_bullets=None): - """Build the instruction description. - - Args: - num_bullets: An integer specifying the exact number of bullet lists - that is required to appear in the response. - - Returns: - A string representing the instruction description. - """ - self._num_bullets = num_bullets - if self._num_bullets is None or self._num_bullets < 0: - self._num_bullets = random.randint(1, _NUM_BULLETS) - self._description_pattern = ( - "Your answer must contain exactly {num_bullets} bullet points. " - + "Use the markdown bullet points such as:\n" - + "* This is point 1. \n" - + "* This is point 2" - ) - return self._description_pattern.format(num_bullets=self._num_bullets) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_bullets": self._num_bullets} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_bullets"] - - def check_following(self, value): - r"""Check if the number of bullet lists meets the requirement. - - Args: - value: A string representing the response. The response is expected to - contain some bullet lists that start with `\*`. - - Returns: - True if the actual number of bullet lists in the response meets the - requirement. 
- """ - bullet_lists = re.findall(r"^\s*\*[^\*].*$", value, flags=re.MULTILINE) - bullet_lists_2 = re.findall(r"^\s*-.*$", value, flags=re.MULTILINE) - num_bullet_lists = len(bullet_lists) + len(bullet_lists_2) - return num_bullet_lists == self._num_bullets - - -class ConstrainedResponseChecker(Instruction): - """Checks the constrained response.""" - - def build_description(self): - """Build the instruction description.""" - # A sequence of string(s) representing the options of the expected response. - self._constrained_responses = _CONSTRAINED_RESPONSE_OPTIONS - self._description_pattern = "Answer with one of the following options: {response_options}" - return self._description_pattern.format(response_options=self._constrained_responses) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks if the response matches the constrained options. - - Args: - value: A string representing the response. - - Returns: - True if the actual response contains one of the options in the constrained - responses; otherwise False. - """ - value = value.strip() - for constrained_response in self._constrained_responses: - if constrained_response in value: - return True - return False - - -class ConstrainedStartChecker(Instruction): - """Checks the response start.""" - - def build_description(self, *, starter=None): - """Build the instruction description. - - Args: - starter: A string representing the keyward that the response should start - with. - - Returns: - A string representing the instruction description. 
- """ - self._starter = starter.strip() if isinstance(starter, str) else starter - if self._starter is None: - self._starter = random.choice(_STARTER_OPTIONS) - self._description_pattern = ( - "During the conversation, when it is your turn, " + "please always start with {starter}" - ) - return self._description_pattern.format(starter=self._starter) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"starter": self._starter} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["starter"] - - def check_following(self, value): - """Checks if the response starts with the constrained keyword or phrase. - - Args: - value: A string representing the response. - - Returns: - True if the response starts with the given phrase or keyword that is - contained in `instruction_args`; otherwise, False. - """ - response_pattern = r"^\s*" + self._starter + r".*$" - response_with_constrained_start = re.search(response_pattern, value, flags=re.MULTILINE) - return True if response_with_constrained_start else False - - -class HighlightSectionChecker(Instruction): - """Checks the highlighted section.""" - - def build_description(self, *, num_highlights=None): - """Build the instruction description. - - Args: - num_highlights: An integer specifying the minimum number of highlighted - sections. - - Returns: - A string representing the instruction description. - """ - self._num_highlights = num_highlights - if self._num_highlights is None or self._num_highlights < 0: - self._num_highlights = random.randint(1, _NUM_HIGHLIGHTED_SECTIONS) - - self._description_pattern = ( - "Highlight at least {num_highlights} sections in your answer with " - + "markdown, i.e. *highlighted section*." 
- ) - - return self._description_pattern.format(num_highlights=self._num_highlights) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_highlights": self._num_highlights} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_highlights"] - - def check_following(self, value): - """Checks if the number of highlighted sections meets the requirement. - - Args: - value: a string repesenting the response. The response is expected to - contain highlighted sections in the format of *highlighted*. - - Returns: - True if the actual number of highlighted sections in the format of - *highlighed sections* meets the minimum requirement; otherwise False. - """ - num_highlights = 0 - highlights = re.findall(r"\*[^\n\*]*\*", value) - double_highlights = re.findall(r"\*\*[^\n\*]*\*\*", value) - for highlight in highlights: - if highlight.strip("*").strip(): - num_highlights += 1 - for highlight in double_highlights: - if highlight.removeprefix("**").removesuffix("**").strip(): - num_highlights += 1 - - return num_highlights >= self._num_highlights - - -class SectionChecker(Instruction): - """Checks the sections.""" - - def build_description(self, *, section_spliter=None, num_sections=None): - """Build the instruction description. - - Args: - section_spliter: A string represents the section spliter keyword that - marks a new section, i.e., `Section` or `SECTION`. - num_sections: An integer specifying the number of sections. - - Returns: - A string representing the instruction description. 
- """ - self._section_spliter = section_spliter.strip() if isinstance(section_spliter, str) else section_spliter - if self._section_spliter is None: - self._section_spliter = random.choice(_SECTION_SPLITER) - - self._num_sections = num_sections - if self._num_sections is None or self._num_sections < 0: - self._num_sections = random.randint(1, _NUM_SECTIONS) - - self._description_pattern = ( - "Your response must have {num_sections} sections. Mark the beginning " - + "of each section with {section_spliter} X, such as:\n" - + "{section_spliter} 1\n" - + "[content of section 1]\n" - + "{section_spliter} 2\n" - + "[content of section 2]" - ) - - return self._description_pattern.format(num_sections=self._num_sections, section_spliter=self._section_spliter) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "section_spliter": self._section_spliter, - "num_sections": self._num_sections, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["section_spliter", "num_sections"] - - def check_following(self, value): - """Checks the response contains multiple sections. - - Args: - value: A string representing the response. The response is expected - to contain multiple sections (number of sections is greater than 1). - A new section starts with `Section 1`, where the number denotes the - section index. - - Returns: - True if the number of sections in the response is greater than or equal to - the minimum number of sections; otherwise, False. - """ - section_splitter_patten = r"\s?" + self._section_spliter + r"\s?\d+\s?" - sections = re.split(section_splitter_patten, value) - num_sections = len(sections) - 1 - return num_sections >= self._num_sections - - -class ParagraphChecker(Instruction): - """Checks the paragraphs.""" - - def build_description(self, *, num_paragraphs=None): - """Build the instruction description. 
- - Args: - num_paragraphs: An integer specifying the number of paragraphs. - - Returns: - A string representing the instruction description. - """ - self._num_paragraphs = num_paragraphs - if self._num_paragraphs is None or self._num_paragraphs < 0: - self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) - - self._description_pattern = ( - "There should be {num_paragraphs} paragraphs. " + "Paragraphs are separated with the markdown divider: ***" - ) - - return self._description_pattern.format(num_paragraphs=self._num_paragraphs) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_paragraphs": self._num_paragraphs} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_paragraphs"] - - def check_following(self, value): - """Checks the response contains required number of paragraphs. - - Args: - value: A string representing the response. The response may contain - paragraphs that are separated by the markdown divider: `***`. - - Returns: - True if the actual number of paragraphs is the same as required; - otherwise, False. - """ - paragraphs = re.split(r"\s?\*\*\*\s?", value) - num_paragraphs = len(paragraphs) - - for index, paragraph in enumerate(paragraphs): - if not paragraph.strip(): - if index == 0 or index == len(paragraphs) - 1: - num_paragraphs -= 1 - else: - return False - - return num_paragraphs == self._num_paragraphs - - -class PostscriptChecker(Instruction): - """Checks the postscript.""" - - def build_description(self, *, postscript_marker=None): - """Build the instruction description. - - Args: - postscript_marker: A string containing the keyword that marks the start - of the postscript section. - - Returns: - A string representing the instruction description. 
- """ - self._postscript_marker = postscript_marker.strip() if isinstance(postscript_marker, str) else postscript_marker - if self._postscript_marker is None: - self._postscript_marker = random.choice(_POSTSCRIPT_MARKER) - - self._description_pattern = ( - "At the end of your response, please explicitly add a postscript " + "starting with {postscript}" - ) - - return self._description_pattern.format(postscript=self._postscript_marker) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"postscript_marker": self._postscript_marker} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["postscript_marker"] - - def check_following(self, value): - """Checks if the response follows the postscript format. - - Args: - value: a string representing the response. The response is expected to - contain a postscript section. - - Returns: - True if the response contains a postscript section starting with - the keyword containing in the `instruction_args`; otherwise False. - """ - value = value.lower() - if self._postscript_marker == "P.P.S": - postscript_pattern = r"\s*p\.\s?p\.\s?s.*$" - elif self._postscript_marker == "P.S.": - postscript_pattern = r"\s*p\.\s?s\..*$" - else: - postscript_pattern = r"\s*" + self._postscript_marker.lower() + r".*$" - postscript = re.findall(postscript_pattern, value, flags=re.MULTILINE) - return True if postscript else False - - -class RephraseChecker(Instruction): - """Checks the repharse.""" - - def build_description(self, *, original_message): - """Build the instruction description. - - Args: - original_message: A string representing the original message. The - rephrased response should only change its words/sentences in between - its two asterisks, for example, *change me*. Both original and rephrased - messages should contain the changes in the form of *change me*. - - Returns: - A string representing the instruction description. 
- """ - if not self.is_change(original_message): - raise ValueError(f"Message {original_message} does not contain changes in the form of *change me*.") - - self._reference_without_change = original_message - self._description = ( - "Rephrasing: Your rephrased response should only" - + "change the words/sentences in between two asterisks" - + "such as *change me*." - ) - return self._description - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"original_message": self._reference_without_change} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["original_message"] - - def check_following(self, value): - r"""Checks if the rephrasing follows the instruction. - - Args: - value: A string representing the response, which is expected to rephras - the string of `instruction_args`. - - Returns: - True if `value` and `instruction_args` only differ by the words/sentences - in between two asterisks such as *change me*; otherwise, False. - """ - - if not self.is_change(value): - raise ValueError(f"value {value} does not contain changes in the form of *change me*.") - - response_without_changes = self.strip_changes(value) - reference_without_changes = self.strip_changes(self._reference_without_change) - - return response_without_changes == reference_without_changes - - def is_change(self, response): - """Check if there is change in the response in the form of *change me*.""" - return re.search(r"\*.*\*", response) - - def strip_changes(self, response): - """Strips off the changes.""" - return re.sub(r"\*.*\*", "", response) - - -class KeywordChecker(Instruction): - """Check the exisitence of certain keywords.""" - - def build_description(self, *, keywords=None): - """Build the instruction description. - - Args: - keywords: A sequence of strings representing the keywords that are - expected in the response. - - Returns: - A string representing the instruction description. 
- """ - - if not keywords: - self._keywords = generate_keywords(num_keywords=_NUM_KEYWORDS) - else: - self._keywords = keywords - self._keywords = sorted(self._keywords) - - self._description_pattern = "Include keywords {keywords} in the response." - - return self._description_pattern.format(keywords=self._keywords) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"keywords": self._keywords} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["keywords"] - - def check_following(self, value): - """Check if the response contain the expected keywords.""" - for keyword in self._keywords: - if not re.search(keyword, value, flags=re.IGNORECASE): - return False - return True - - -class KeywordFrequencyChecker(Instruction): - """Check the keyword frequency.""" - - def build_description(self, *, keyword=None, frequency=None, relation=None): - """Build the instruction description. - - Args: - keyword: A string representing a keyword that is expected in the response. - frequency: An integer specifying the number of times `keyword` is expected - to appear in the response. - relation: A string in (`less than`, `at least`), defining the relational - operator for comparison. - Two relational comparisons are supported for now: - if 'less than', the actual number of occurrences < frequency; - if 'at least', the actual number of occurrences >= frequency. - - Returns: - A string representing the instruction description. 
- """ - if not keyword: - self._keyword = generate_keywords(num_keywords=1)[0] - else: - self._keyword = keyword.strip() - - self._frequency = frequency - if self._frequency is None or self._frequency < 0: - self._frequency = random.randint(1, _KEYWORD_FREQUENCY) - - if relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." - ) - else: - self._comparison_relation = relation - - self._description_pattern = ( - "In your response, the word {keyword} should appear {relation} " + "{frequency} times." - ) - - return self._description_pattern.format( - keyword=self._keyword, - relation=self._comparison_relation, - frequency=self._frequency, - ) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "keyword": self._keyword, - "frequency": self._frequency, - "relation": self._comparison_relation, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["keyword", "frequency", "relation"] - - def check_following(self, value): - """Checks if the response contain the keyword with required frequency.""" - actual_occurrences = len(re.findall(self._keyword, value, flags=re.IGNORECASE)) - - if self._comparison_relation == _COMPARISON_RELATION[0]: - return actual_occurrences < self._frequency - elif self._comparison_relation == _COMPARISON_RELATION[1]: - return actual_occurrences >= self._frequency - - -class NumberOfWords(Instruction): - """Checks the number of words.""" - - def build_description(self, *, num_words=None, relation=None): - """Build the instruction description. - - Args: - num_words: An integer specifying the number of words contained in the - response. - relation: A string in (`less than`, `at least`), defining the relational - operator for comparison. 
- Two relational comparisons are supported for now: - if 'less than', the actual number of words < num_words; - if 'at least', the actual number of words >= num_words. - - Returns: - A string representing the instruction description. - """ - - self._num_words = num_words - if self._num_words is None or self._num_words < 0: - self._num_words = random.randint(_NUM_WORDS_LOWER_LIMIT, _NUM_WORDS_UPPER_LIMIT) - - if relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {relation} is given." - ) - else: - self._comparison_relation = relation - - self._description_pattern = "Answer with {relation} {num_words} words." - - return self._description_pattern.format(relation=self._comparison_relation, num_words=self._num_words) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"num_words": self._num_words, "relation": self._comparison_relation} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_words", "relation"] - - def check_following(self, value): - """Checks if the response contains the expected number of words.""" - lang = get_langid(value) - if lang == "th": - num_words = len(word_tokenize_thai(value)) - elif lang in ["zh", "zh-cn", "zh-tw", "ja", "ko"]: - num_words = count_words_cjk(value) - else: - num_words = count_words(value) - - if self._comparison_relation == _COMPARISON_RELATION[0]: - return num_words < self._num_words - elif self._comparison_relation == _COMPARISON_RELATION[1]: - return num_words >= self._num_words - - -class JsonFormat(Instruction): - """Check the Json format.""" - - def build_description(self): - self._description_pattern = ( - "Entire output should be wrapped in JSON format. You can use markdown ticks such as ```." 
- ) - return self._description_pattern - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - value = ( - value.strip() - .removeprefix("```json") - .removeprefix("```Json") - .removeprefix("```JSON") - .removeprefix("```") - .removesuffix("```") - .strip() - ) - try: - json.loads(value) - except ValueError as _: - return False - return True - - -class ParagraphFirstWordCheck(Instruction): - """Check the paragraph and the first word of the nth paragraph.""" - - def build_description(self, num_paragraphs=None, nth_paragraph=None, first_word=None): - r"""Build the instruction description. - - Args: - num_paragraphs: An integer indicating the number of paragraphs expected - in the response. A paragraph is a subset of the string that is - expected to be separated by '\n\n'. - nth_paragraph: An integer indicating the paragraph number that we look at. - Note that n starts from 1. - first_word: A string that represent the first word of the bth paragraph. - - Returns: - A string representing the instruction description. - """ - self._num_paragraphs = num_paragraphs - if self._num_paragraphs is None or self._num_paragraphs < 0: - self._num_paragraphs = random.randint(1, _NUM_PARAGRAPHS) - - self._nth_paragraph = nth_paragraph - if self._nth_paragraph is None or self._nth_paragraph <= 0 or self._nth_paragraph > self._num_paragraphs: - self._nth_paragraph = random.randint(1, self._num_paragraphs + 1) - - self._first_word = first_word - if self._first_word is None: - self._first_word = generate_keywords(num_keywords=1)[0] - self._first_word = self._first_word.lower() - - self._description_pattern = ( - "There should be {num_paragraphs} paragraphs. " - + "Paragraphs and only paragraphs are separated with each other by two " - + "new lines as if it was '\\n\\n' in python. 
" - + "Paragraph {nth_paragraph} must start with word {first_word}." - ) - - return self._description_pattern.format( - num_paragraphs=self._num_paragraphs, - nth_paragraph=self._nth_paragraph, - first_word=self._first_word, - ) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "num_paragraphs": self._num_paragraphs, - "nth_paragraph": self._nth_paragraph, - "first_word": self._first_word, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_paragraphs", "nth_paragraph", "first_word"] - - def check_following(self, value): - """Checks for required number of paragraphs and correct first word. - - Args: - value: a string representing the response. The response may contain - paragraphs that are separated by two new lines and the first word of - the nth paragraph will have to match a specified word. - - Returns: - True if the number of paragraphs is the same as required and the first - word of the specified paragraph is the same as required. Otherwise, false. 
- """ - - paragraphs = re.split(r"\n\n", value) - num_paragraphs = len(paragraphs) - - for paragraph in paragraphs: - if not paragraph.strip(): - num_paragraphs -= 1 - - # check that index doesn't go out of bounds - if self._nth_paragraph <= num_paragraphs: - paragraph = paragraphs[self._nth_paragraph - 1].strip() - if not paragraph: - return False - else: - return False - - first_word = "" - punctuation = {".", ",", "?", "!", "'", '"'} - - # get first word and remove punctuation - word = paragraph.split()[0].strip() - word = word.lstrip("'") - word = word.lstrip('"') - - for letter in word: - if letter in punctuation: - break - first_word += letter.lower() - - return num_paragraphs == self._num_paragraphs and first_word == self._first_word - - -class KeySentenceChecker(Instruction): - """Check the existence of certain key sentences.""" - - def build_description(self, key_sentences=None, num_sentences=None): - """Build the instruction description. - - Args: - key_sentences: A sequences of strings representing the key sentences that - are expected in the response. - num_sentences: The number of key sentences that are expected to be seen in - the response. - - Returns: - A string representing the instruction description. 
- """ - - if not key_sentences: - self._key_sentences = {["For now, this is fine."]} - else: - self._key_sentences = key_sentences - - if not num_sentences: - self._num_sentences = random.randint(1, len(self._key_sentences)) - else: - self._num_sentences = num_sentences - - self._description_pattern = "Include {num_sentences} of the following sentences {key_sentences}" - - return self._description_pattern.format(num_sentences=self._num_sentences, key_sentences=self._key_sentences) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "num_sentences": self._num_sentences, - "key_sentences": list(self._key_sentences), - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["num_sentences", "key_sentences"] - - def check_following(self, value): - """Checks if the response contains the expected key sentences.""" - count = 0 - sentences = split_into_sentences(value) - for sentence in self._key_sentences: - if sentence in sentences: - count += 1 - - return count == self._num_sentences - - -class ForbiddenWords(Instruction): - """Checks that specified words are not used in response.""" - - def build_description(self, forbidden_words=None): - """Build the instruction description. - - Args: - forbidden_words: A sequences of strings respresenting words that are not - allowed in the response. - - Returns: - A string representing the instruction description. - """ - - if not forbidden_words: - self._forbidden_words = generate_keywords(num_keywords=_NUM_KEYWORDS) - else: - self._forbidden_words = list(set(forbidden_words)) - self._forbidden_words = sorted(self._forbidden_words) - self._description_pattern = "Do not include keywords {forbidden_words} in the response." 
- - return self._description_pattern.format(forbidden_words=self._forbidden_words) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return {"forbidden_words": self._forbidden_words} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["forbidden_words"] - - def check_following(self, value): - """Check if the response does not contain the expected keywords.""" - for word in self._forbidden_words: - if re.search(r"\b" + word + r"\b", value, flags=re.IGNORECASE): - return False - return True - - -class RephraseParagraph(Instruction): - """Checks that the paragraph is rephrased.""" - - def build_description(self, *, original_paragraph, low, high): - """Builds the instruction description. - - Args: - original_paragraph: A string presenting the original paragraph. The - rephrases response should have betweeb low-high words in common. - low: An integer presenting the lower bound of similar words. - high: An integer representing the upper bound of similar words. - - Returns: - A string representing the instruction description. - """ - self._original_paragraph = original_paragraph - self._low = low - self._high = high - - self._description = ( - "Rephrase the following paragraph: " - + "{original_paragraph}\nYour response should have " - + "between {low} and {high} of the same words. " - + "Words are the same if and only if all of the " - + "letters, ignoring cases, are the same. For " - + "example, 'run' is the same as 'Run' but different " - + "to 'ran'." 
- ) - - return self._description.format(original_paragraph=original_paragraph, low=self._low, high=self._high) - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return { - "original_paragraph": self._original_paragraph, - "low": self._low, - "high": self._high, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["original_paragraph", "low", "high"] - - def check_following(self, value): - val_words = re.findall(r"\w+", value.lower()) - original_words = re.findall(r"\w+", self._original_paragraph.lower()) - similar_words = 0 - - dict_val = collections.Counter(val_words) - dict_original = collections.Counter(original_words) - - for word in dict_original: - similar_words += min(dict_original[word], dict_val[word]) - - return similar_words >= self._low and similar_words <= self._high - - -class TwoResponsesChecker(Instruction): - """Check that two responses were given.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = ( - "Give two different responses. Responses and only responses should" - " be separated by 6 asterisk symbols: ******." - ) - return self._description_pattern - - def get_instruction_args(self): - """Returns the keyward args of `build_description`.""" - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks if the response has two different answers. - - Args: - value: A string representing the response. - - Returns: - True if two responses are detected and false otherwise. 
- """ - valid_responses = list() - responses = value.split("******") - for index, response in enumerate(responses): - if not response.strip(): - if index != 0 and index != len(responses) - 1: - return False - else: - valid_responses.append(response) - return len(valid_responses) == 2 and valid_responses[0].strip() != valid_responses[1].strip() - - -class RepeatPromptThenAnswer(Instruction): - """Checks that Prompt is first repeated then answered.""" - - def build_description(self, *, prompt_to_repeat=None): - """Build the instruction description. - - Args: - prompt_to_repeat: The prompt that is meant to be repeated. - - Returns: - A string representing the instruction description. - """ - if not prompt_to_repeat: - raise ValueError("prompt_to_repeat must be set.") - else: - self._prompt_to_repeat = prompt_to_repeat - self._description_pattern = ( - "First repeat the request word for word without change," - " then give your answer (1. do not say any words or characters" - " before repeating the request; 2. the request you need to repeat" - " does not include this sentence)" - ) - return self._description_pattern - - def get_instruction_args(self): - return {"prompt_to_repeat": self._prompt_to_repeat} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["prompt_to_repeat"] - - def check_following(self, value): - if value.strip().lower().startswith(self._prompt_to_repeat.strip().lower()): - return True - return False - - -class EndChecker(Instruction): - """Checks that the prompt ends with a given phrase.""" - - def build_description(self, *, end_phrase=None): - """Build the instruction description. - - Args: - end_phrase: A string representing the phrase the response should end with. - - Returns: - A string representing the instruction description. 
- """ - self._end_phrase = end_phrase.strip() if isinstance(end_phrase, str) else end_phrase - if self._end_phrase is None: - self._end_phrase = random.choice(_ENDING_OPTIONS) - self._description_pattern = ( - "Finish your response with this exact phrase {ender}. No other words should follow this phrase." - ) - return self._description_pattern.format(ender=self._end_phrase) - - def get_instruction_args(self): - return {"end_phrase": self._end_phrase} - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["end_phrase"] - - def check_following(self, value): - """Checks if the response ends with the expected phrase.""" - value = value.strip().strip('"').lower() - self._end_phrase = self._end_phrase.strip().lower() - return value.endswith(self._end_phrase) - - -class TitleChecker(Instruction): - """Checks the response for a title.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = ( - "Your answer must contain a title, wrapped in double angular brackets, such as <>." - ) - return self._description_pattern - - def get_instruction_args(self): - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks if the response contains a title.""" - pattern = r"<<[^\n]+>>" - re_pattern = re.compile(pattern) - titles = re.findall(re_pattern, value) - - for title in titles: - if title.lstrip("<").rstrip(">").strip(): - return True - return False - - -class LetterFrequencyChecker(Instruction): - """Checks letter frequency.""" - - def build_description(self, *, letter=None, let_frequency=None, let_relation=None): - """Build the instruction description. - - Args: - letter: A string representing a letter that is expected in the response. - let_frequency: An integer specifying the number of times `keyword` is - expected to appear in the response. 
- let_relation: A string in (`less than`, `at least`), defining the - relational operator for comparison. Two relational comparisons are - supported for now; if 'less than', the actual number of - occurrences < frequency; if 'at least', the actual number of - occurrences >= frequency. - - Returns: - A string representing the instruction description. - """ - if not letter or len(letter) > 1 or ord(letter.lower()) < 97 or ord(letter.lower()) > 122: - self._letter = random.choice(list(string.ascii_letters)) - else: - self._letter = letter.strip() - self._letter = self._letter.lower() - - self._frequency = let_frequency - if self._frequency is None or self._frequency < 0: - self._frequency = random.randint(1, _LETTER_FREQUENCY) - - if let_relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif let_relation not in _COMPARISON_RELATION: - raise ValueError( - f"The supported relation for comparison must be in {_COMPARISON_RELATION}, but {let_relation} is given." - ) - else: - self._comparison_relation = let_relation - - self._description_pattern = ( - "In your response, the letter {letter} should appear {let_relation} {let_frequency} times." 
- ) - - return self._description_pattern.format( - letter=self._letter, - let_frequency=self._frequency, - let_relation=self._comparison_relation, - ) - - def get_instruction_args(self): - """Returns the keyword args of build description.""" - return { - "letter": self._letter, - "let_frequency": self._frequency, - "let_relation": self._comparison_relation, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["letter", "let_frequency", "let_relation"] - - def check_following(self, value): - """Checks that the response contains the letter at the right frequency.""" - value = value.lower() - letters = collections.Counter(value) - - if self._comparison_relation == _COMPARISON_RELATION[0]: - return letters[self._letter] < self._frequency - else: - return letters[self._letter] >= self._frequency - - -class CapitalLettersEnglishChecker(Instruction): - """Checks that the response is in english and is in all capital letters.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = "Your entire response should be in English, and in all capital letters." - return self._description_pattern - - def get_instruction_args(self): - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks that the response is in English and in all capital letters.""" - assert isinstance(value, str) - - try: - return value.isupper() and langdetect.detect(value) == "en" - except langdetect.LangDetectException as e: - # Count as instruction is followed. 
- logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 - return True - - -class LowercaseLettersEnglishChecker(Instruction): - """Checks that the response is in english and is in all lowercase letters.""" - - def build_description(self): - """Build the instruction description.""" - self._description_pattern = ( - "Your entire response should be in English, and in all lowercase letters. No capital letters are allowed." - ) - return self._description_pattern - - def get_instruction_args(self): - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks that the response is in English and in all lowercase letters.""" - assert isinstance(value, str) - - try: - return value.islower() and langdetect.detect(value) == "en" - except langdetect.LangDetectException as e: - # Count as instruction is followed. - logger.info("Unable to detect language for text %s due to %s", value, e) # refex: disable=pytotw.037 - return True - - -class CommaChecker(Instruction): - """Checks the response for no commas.""" - - def build_description(self, **kwargs): - """Build the instruction description.""" - self._description_pattern = "In your entire response, refrain from the use of any commas." - return self._description_pattern - - def get_instruction_args(self): - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks that the response does not contain commas.""" - return not re.search(r"\,", value) - - -class CapitalWordFrequencyChecker(Instruction): - """Checks frequency of words with all capital letters.""" - - def build_description( - self, - capital_frequency=None, - capital_relation=None, - ): - """Build the instruction description. 
- - Args: - capital_frequency: An integer that represents the number of words that - should be in all capital letters. - capital_relation: A string that is 'at least' or 'at most' that refers to - the frequency. - - Returns: - A string representing the instruction description. - """ - self._frequency = capital_frequency - if self._frequency is None: - self._frequency = random.randint(1, _ALL_CAPITAL_WORD_FREQUENCY) - - self._comparison_relation = capital_relation - if capital_relation is None: - self._comparison_relation = random.choice(_COMPARISON_RELATION) - elif capital_relation not in _COMPARISON_RELATION: - raise ValueError( - "The supported relation for comparison must be in " - f"{_COMPARISON_RELATION}, but {capital_relation} is given." - ) - - self._description_pattern = ( - "In your response, words with all capital letters should appear {relation} {frequency} times." - ) - - return self._description_pattern.format(frequency=self._frequency, relation=self._comparison_relation) - - def get_instruction_args(self): - """Returns the keyword args of build description.""" - return { - "capital_frequency": self._frequency, - "capital_relation": self._comparison_relation, - } - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return ["capital_frequency", "capital_relation"] - - def check_following(self, value): - """Checks the frequency of words with all capital letters.""" - # Hyphenated words will count as one word - nltk.download("punkt_tab") - words = nltk.word_tokenize(value) - capital_words = [word for word in words if word.isupper()] - - capital_words = len(capital_words) - - if self._comparison_relation == _COMPARISON_RELATION[0]: - return capital_words < self._frequency - else: - return capital_words >= self._frequency - - -class QuotationChecker(Instruction): - """Checks response is wrapped with double quotation marks.""" - - def build_description(self): - """Build the instruction description.""" - 
self._description_pattern = "Wrap your entire response with double quotation marks." - return self._description_pattern - - def get_instruction_args(self): - """Returns the keyword args of build description.""" - return None - - def get_instruction_args_keys(self): - """Returns the args keys of `build_description`.""" - return [] - - def check_following(self, value): - """Checks if the response is wrapped with double quotation marks.""" - quotations_map = { - "ja": "「」", - "ru": "«»", - "th": "“”", - "zh": "“”", - "zh-cn": "“”", - "zh-tw": "“”", - } - value = value.strip() - lang = get_langid(value) - quotes = quotations_map.get(lang, '""') - # TODO: We may wanna revisit this logic in new generations to only check of the response language's quotes. - return len(value) > 1 and value[0] in [quotes[0], '"'] and value[-1] in [quotes[1], '"'] - -# Define instruction dicts _KEYWORD = "keywords:" _LANGUAGE = "language:" _LENGTH = "length_constraints:" diff --git a/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_word_list.py b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_word_list.py new file mode 100644 index 0000000000..679f483326 --- /dev/null +++ b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_word_list.py @@ -0,0 +1,1538 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ + +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="scoring") + +WORD_LIST = [ + "western", + "sentence", + "signal", + "dump", + "spot", + "opposite", + "bottom", + "potato", + "administration", + "working", + "welcome", + "morning", + "good", + "agency", + "primary", + "wish", + "responsibility", + "press", + "problem", + "president", + "steal", + "brush", + "read", + "type", + "beat", + "trainer", + "growth", + "lock", + "bone", + "case", + "equal", + "comfortable", + "region", + "replacement", + "performance", + "mate", + "walk", + "medicine", + "film", + "thing", + "rock", + "tap", + "total", + "competition", + "ease", + "south", + "establishment", + "gather", + "parking", + "world", + "plenty", + "breath", + "claim", + "alcohol", + "trade", + "dear", + "highlight", + "street", + "matter", + "decision", + "mess", + "agreement", + "studio", + "coach", + "assist", + "brain", + "wing", + "style", + "private", + "top", + "brown", + "leg", + "buy", + "procedure", + "method", + "speed", + "high", + "company", + "valuable", + "pie", + "analyst", + "session", + "pattern", + "district", + "pleasure", + "dinner", + "swimming", + "joke", + "order", + "plate", + "department", + "motor", + "cell", + "spend", + "cabinet", + "difference", + "power", + "examination", + "engine", + "horse", + "dimension", + "pay", + "toe", + "curve", + "literature", + "bother", + "fire", + "possibility", + "debate", + "activity", + "passage", + "hello", + "cycle", + "background", + "quiet", + "author", + "effect", + "actor", + "page", + "bicycle", + "error", + "throat", + "attack", + "character", + "phone", + "tea", + "increase", + "outcome", + "file", + "specific", + "inspector", + "internal", + "potential", + "staff", + "building", + "employer", + "shoe", + "hand", + "direction", + "garden", + "purchase", + "interview", + "study", + "recognition", + "member", + "spiritual", + "oven", + "sandwich", + "weird", + "passenger", + "particular", + "response", 
+ "reaction", + "size", + "variation", + "a", + "cancel", + "candy", + "exit", + "guest", + "condition", + "fly", + "price", + "weakness", + "convert", + "hotel", + "great", + "mouth", + "mind", + "song", + "sugar", + "suspect", + "telephone", + "ear", + "roof", + "paint", + "refrigerator", + "organization", + "jury", + "reward", + "engineering", + "day", + "possession", + "crew", + "bar", + "road", + "description", + "celebration", + "score", + "mark", + "letter", + "shower", + "suggestion", + "sir", + "luck", + "national", + "progress", + "hall", + "stroke", + "theory", + "offer", + "story", + "tax", + "definition", + "history", + "ride", + "medium", + "opening", + "glass", + "elevator", + "stomach", + "question", + "ability", + "leading", + "village", + "computer", + "city", + "grand", + "confidence", + "candle", + "priest", + "recommendation", + "point", + "necessary", + "body", + "desk", + "secret", + "horror", + "noise", + "culture", + "warning", + "water", + "round", + "diet", + "flower", + "bus", + "tough", + "permission", + "week", + "prompt", + "connection", + "abuse", + "height", + "save", + "corner", + "border", + "stress", + "drive", + "stop", + "rip", + "meal", + "listen", + "confusion", + "girlfriend", + "living", + "relation", + "significance", + "plan", + "creative", + "atmosphere", + "blame", + "invite", + "housing", + "paper", + "drink", + "roll", + "silver", + "drunk", + "age", + "damage", + "smoke", + "environment", + "pack", + "savings", + "influence", + "tourist", + "rain", + "post", + "sign", + "grandmother", + "run", + "profit", + "push", + "clerk", + "final", + "wine", + "swim", + "pause", + "stuff", + "singer", + "funeral", + "average", + "source", + "scene", + "tradition", + "personal", + "snow", + "nobody", + "distance", + "sort", + "sensitive", + "animal", + "major", + "negotiation", + "click", + "mood", + "period", + "arrival", + "expression", + "holiday", + "repeat", + "dust", + "closet", + "gold", + "bad", + "sail", + "combination", 
+ "clothes", + "emphasis", + "duty", + "black", + "step", + "school", + "jump", + "document", + "professional", + "lip", + "chemical", + "front", + "wake", + "while", + "inside", + "watch", + "row", + "subject", + "penalty", + "balance", + "possible", + "adult", + "aside", + "sample", + "appeal", + "wedding", + "depth", + "king", + "award", + "wife", + "blow", + "site", + "camp", + "music", + "safe", + "gift", + "fault", + "guess", + "act", + "shame", + "drama", + "capital", + "exam", + "stupid", + "record", + "sound", + "swing", + "novel", + "minimum", + "ratio", + "machine", + "shape", + "lead", + "operation", + "salary", + "cloud", + "affair", + "hit", + "chapter", + "stage", + "quantity", + "access", + "army", + "chain", + "traffic", + "kick", + "analysis", + "airport", + "time", + "vacation", + "philosophy", + "ball", + "chest", + "thanks", + "place", + "mountain", + "advertising", + "red", + "past", + "rent", + "return", + "tour", + "house", + "construction", + "net", + "native", + "war", + "figure", + "fee", + "spray", + "user", + "dirt", + "shot", + "task", + "stick", + "friend", + "software", + "promotion", + "interaction", + "surround", + "block", + "purpose", + "practice", + "conflict", + "routine", + "requirement", + "bonus", + "hole", + "state", + "junior", + "sweet", + "catch", + "tear", + "fold", + "wall", + "editor", + "life", + "position", + "pound", + "respect", + "bathroom", + "coat", + "script", + "job", + "teach", + "birth", + "view", + "resolve", + "theme", + "employee", + "doubt", + "market", + "education", + "serve", + "recover", + "tone", + "harm", + "miss", + "union", + "understanding", + "cow", + "river", + "association", + "concept", + "training", + "recipe", + "relationship", + "reserve", + "depression", + "proof", + "hair", + "revenue", + "independent", + "lift", + "assignment", + "temporary", + "amount", + "loss", + "edge", + "track", + "check", + "rope", + "estimate", + "pollution", + "stable", + "message", + "delivery", + 
"perspective", + "mirror", + "assistant", + "representative", + "witness", + "nature", + "judge", + "fruit", + "tip", + "devil", + "town", + "emergency", + "upper", + "drop", + "stay", + "human", + "neck", + "speaker", + "network", + "sing", + "resist", + "league", + "trip", + "signature", + "lawyer", + "importance", + "gas", + "choice", + "engineer", + "success", + "part", + "external", + "worker", + "simple", + "quarter", + "student", + "heart", + "pass", + "spite", + "shift", + "rough", + "lady", + "grass", + "community", + "garage", + "youth", + "standard", + "skirt", + "promise", + "blind", + "television", + "disease", + "commission", + "positive", + "energy", + "calm", + "presence", + "tune", + "basis", + "preference", + "head", + "common", + "cut", + "somewhere", + "presentation", + "current", + "thought", + "revolution", + "effort", + "master", + "implement", + "republic", + "floor", + "principle", + "stranger", + "shoulder", + "grade", + "button", + "tennis", + "police", + "collection", + "account", + "register", + "glove", + "divide", + "professor", + "chair", + "priority", + "combine", + "peace", + "extension", + "maybe", + "evening", + "frame", + "sister", + "wave", + "code", + "application", + "mouse", + "match", + "counter", + "bottle", + "half", + "cheek", + "resolution", + "back", + "knowledge", + "make", + "discussion", + "screw", + "length", + "accident", + "battle", + "dress", + "knee", + "log", + "package", + "it", + "turn", + "hearing", + "newspaper", + "layer", + "wealth", + "profile", + "imagination", + "answer", + "weekend", + "teacher", + "appearance", + "meet", + "bike", + "rise", + "belt", + "crash", + "bowl", + "equivalent", + "support", + "image", + "poem", + "risk", + "excitement", + "remote", + "secretary", + "public", + "produce", + "plane", + "display", + "money", + "sand", + "situation", + "punch", + "customer", + "title", + "shake", + "mortgage", + "option", + "number", + "pop", + "window", + "extent", + "nothing", + "experience", 
+ "opinion", + "departure", + "dance", + "indication", + "boy", + "material", + "band", + "leader", + "sun", + "beautiful", + "muscle", + "farmer", + "variety", + "fat", + "handle", + "director", + "opportunity", + "calendar", + "outside", + "pace", + "bath", + "fish", + "consequence", + "put", + "owner", + "go", + "doctor", + "information", + "share", + "hurt", + "protection", + "career", + "finance", + "force", + "golf", + "garbage", + "aspect", + "kid", + "food", + "boot", + "milk", + "respond", + "objective", + "reality", + "raw", + "ring", + "mall", + "one", + "impact", + "area", + "news", + "international", + "series", + "impress", + "mother", + "shelter", + "strike", + "loan", + "month", + "seat", + "anything", + "entertainment", + "familiar", + "clue", + "year", + "glad", + "supermarket", + "natural", + "god", + "cost", + "conversation", + "tie", + "ruin", + "comfort", + "earth", + "storm", + "percentage", + "assistance", + "budget", + "strength", + "beginning", + "sleep", + "other", + "young", + "unit", + "fill", + "store", + "desire", + "hide", + "value", + "cup", + "maintenance", + "nurse", + "function", + "tower", + "role", + "class", + "camera", + "database", + "panic", + "nation", + "basket", + "ice", + "art", + "spirit", + "chart", + "exchange", + "feedback", + "statement", + "reputation", + "search", + "hunt", + "exercise", + "nasty", + "notice", + "male", + "yard", + "annual", + "collar", + "date", + "platform", + "plant", + "fortune", + "passion", + "friendship", + "spread", + "cancer", + "ticket", + "attitude", + "island", + "active", + "object", + "service", + "buyer", + "bite", + "card", + "face", + "steak", + "proposal", + "patient", + "heat", + "rule", + "resident", + "broad", + "politics", + "west", + "knife", + "expert", + "girl", + "design", + "salt", + "baseball", + "grab", + "inspection", + "cousin", + "couple", + "magazine", + "cook", + "dependent", + "security", + "chicken", + "version", + "currency", + "ladder", + "scheme", + 
"kitchen", + "employment", + "local", + "attention", + "manager", + "fact", + "cover", + "sad", + "guard", + "relative", + "county", + "rate", + "lunch", + "program", + "initiative", + "gear", + "bridge", + "breast", + "talk", + "dish", + "guarantee", + "beer", + "vehicle", + "reception", + "woman", + "substance", + "copy", + "lecture", + "advantage", + "park", + "cold", + "death", + "mix", + "hold", + "scale", + "tomorrow", + "blood", + "request", + "green", + "cookie", + "church", + "strip", + "forever", + "beyond", + "debt", + "tackle", + "wash", + "following", + "feel", + "maximum", + "sector", + "sea", + "property", + "economics", + "menu", + "bench", + "try", + "language", + "start", + "call", + "solid", + "address", + "income", + "foot", + "senior", + "honey", + "few", + "mixture", + "cash", + "grocery", + "link", + "map", + "form", + "factor", + "pot", + "model", + "writer", + "farm", + "winter", + "skill", + "anywhere", + "birthday", + "policy", + "release", + "husband", + "lab", + "hurry", + "mail", + "equipment", + "sink", + "pair", + "driver", + "consideration", + "leather", + "skin", + "blue", + "boat", + "sale", + "brick", + "two", + "feed", + "square", + "dot", + "rush", + "dream", + "location", + "afternoon", + "manufacturer", + "control", + "occasion", + "trouble", + "introduction", + "advice", + "bet", + "eat", + "kill", + "category", + "manner", + "office", + "estate", + "pride", + "awareness", + "slip", + "crack", + "client", + "nail", + "shoot", + "membership", + "soft", + "anybody", + "web", + "official", + "individual", + "pizza", + "interest", + "bag", + "spell", + "profession", + "queen", + "deal", + "resource", + "ship", + "guy", + "chocolate", + "joint", + "formal", + "upstairs", + "car", + "resort", + "abroad", + "dealer", + "associate", + "finger", + "surgery", + "comment", + "team", + "detail", + "crazy", + "path", + "tale", + "initial", + "arm", + "radio", + "demand", + "single", + "draw", + "yellow", + "contest", + "piece", + 
"quote", + "pull", + "commercial", + "shirt", + "contribution", + "cream", + "channel", + "suit", + "discipline", + "instruction", + "concert", + "speech", + "low", + "effective", + "hang", + "scratch", + "industry", + "breakfast", + "lay", + "join", + "metal", + "bedroom", + "minute", + "product", + "rest", + "temperature", + "many", + "give", + "argument", + "print", + "purple", + "laugh", + "health", + "credit", + "investment", + "sell", + "setting", + "lesson", + "egg", + "middle", + "marriage", + "level", + "evidence", + "phrase", + "love", + "self", + "benefit", + "guidance", + "affect", + "you", + "dad", + "anxiety", + "special", + "boyfriend", + "test", + "blank", + "payment", + "soup", + "obligation", + "reply", + "smile", + "deep", + "complaint", + "addition", + "review", + "box", + "towel", + "minor", + "fun", + "soil", + "issue", + "cigarette", + "internet", + "gain", + "tell", + "entry", + "spare", + "incident", + "family", + "refuse", + "branch", + "can", + "pen", + "grandfather", + "constant", + "tank", + "uncle", + "climate", + "ground", + "volume", + "communication", + "kind", + "poet", + "child", + "screen", + "mine", + "quit", + "gene", + "lack", + "charity", + "memory", + "tooth", + "fear", + "mention", + "marketing", + "reveal", + "reason", + "court", + "season", + "freedom", + "land", + "sport", + "audience", + "classroom", + "law", + "hook", + "win", + "carry", + "eye", + "smell", + "distribution", + "research", + "country", + "dare", + "hope", + "whereas", + "stretch", + "library", + "if", + "delay", + "college", + "plastic", + "book", + "present", + "use", + "worry", + "champion", + "goal", + "economy", + "march", + "election", + "reflection", + "midnight", + "slide", + "inflation", + "action", + "challenge", + "guitar", + "coast", + "apple", + "campaign", + "field", + "jacket", + "sense", + "way", + "visual", + "remove", + "weather", + "trash", + "cable", + "regret", + "buddy", + "beach", + "historian", + "courage", + "sympathy", + 
"truck", + "tension", + "permit", + "nose", + "bed", + "son", + "person", + "base", + "meat", + "usual", + "air", + "meeting", + "worth", + "game", + "independence", + "physical", + "brief", + "play", + "raise", + "board", + "she", + "key", + "writing", + "pick", + "command", + "party", + "yesterday", + "spring", + "candidate", + "physics", + "university", + "concern", + "development", + "change", + "string", + "target", + "instance", + "room", + "bitter", + "bird", + "football", + "normal", + "split", + "impression", + "wood", + "long", + "meaning", + "stock", + "cap", + "leadership", + "media", + "ambition", + "fishing", + "essay", + "salad", + "repair", + "today", + "designer", + "night", + "bank", + "drawing", + "inevitable", + "phase", + "vast", + "chip", + "anger", + "switch", + "cry", + "twist", + "personality", + "attempt", + "storage", + "being", + "preparation", + "bat", + "selection", + "white", + "technology", + "contract", + "side", + "section", + "station", + "till", + "structure", + "tongue", + "taste", + "truth", + "difficulty", + "group", + "limit", + "main", + "move", + "feeling", + "light", + "example", + "mission", + "might", + "wait", + "wheel", + "shop", + "host", + "classic", + "alternative", + "cause", + "agent", + "consist", + "table", + "airline", + "text", + "pool", + "craft", + "range", + "fuel", + "tool", + "partner", + "load", + "entrance", + "deposit", + "hate", + "article", + "video", + "summer", + "feature", + "extreme", + "mobile", + "hospital", + "flight", + "fall", + "pension", + "piano", + "fail", + "result", + "rub", + "gap", + "system", + "report", + "suck", + "ordinary", + "wind", + "nerve", + "ask", + "shine", + "note", + "line", + "mom", + "perception", + "brother", + "reference", + "bend", + "charge", + "treat", + "trick", + "term", + "homework", + "bake", + "bid", + "status", + "project", + "strategy", + "orange", + "let", + "enthusiasm", + "parent", + "concentrate", + "device", + "travel", + "poetry", + "business", + 
"society", + "kiss", + "end", + "vegetable", + "employ", + "schedule", + "hour", + "brave", + "focus", + "process", + "movie", + "illegal", + "general", + "coffee", + "ad", + "highway", + "chemistry", + "psychology", + "hire", + "bell", + "conference", + "relief", + "show", + "neat", + "funny", + "weight", + "quality", + "club", + "daughter", + "zone", + "touch", + "tonight", + "shock", + "burn", + "excuse", + "name", + "survey", + "landscape", + "advance", + "satisfaction", + "bread", + "disaster", + "item", + "hat", + "prior", + "shopping", + "visit", + "east", + "photo", + "home", + "idea", + "father", + "comparison", + "cat", + "pipe", + "winner", + "count", + "lake", + "fight", + "prize", + "foundation", + "dog", + "keep", + "ideal", + "fan", + "struggle", + "peak", + "safety", + "solution", + "hell", + "conclusion", + "population", + "strain", + "alarm", + "measurement", + "second", + "train", + "race", + "due", + "insurance", + "boss", + "tree", + "monitor", + "sick", + "course", + "drag", + "appointment", + "slice", + "still", + "care", + "patience", + "rich", + "escape", + "emotion", + "royal", + "female", + "childhood", + "government", + "picture", + "will", + "sock", + "big", + "gate", + "oil", + "cross", + "pin", + "improvement", + "championship", + "silly", + "help", + "sky", + "pitch", + "man", + "diamond", + "most", + "transition", + "work", + "science", + "committee", + "moment", + "fix", + "teaching", + "dig", + "specialist", + "complex", + "guide", + "people", + "dead", + "voice", + "original", + "break", + "topic", + "data", + "degree", + "reading", + "recording", + "bunch", + "reach", + "judgment", + "lie", + "regular", + "set", + "painting", + "mode", + "list", + "player", + "bear", + "north", + "wonder", + "carpet", + "heavy", + "officer", + "negative", + "clock", + "unique", + "baby", + "pain", + "assumption", + "disk", + "iron", + "bill", + "drawer", + "look", + "double", + "mistake", + "finish", + "future", + "brilliant", + "contact", + 
"math", + "rice", + "leave", + "restaurant", + "discount", + "sex", + "virus", + "bit", + "trust", + "event", + "wear", + "juice", + "failure", + "bug", + "context", + "mud", + "whole", + "wrap", + "intention", + "draft", + "pressure", + "cake", + "dark", + "explanation", + "space", + "angle", + "word", + "efficiency", + "management", + "habit", + "star", + "chance", + "finding", + "transportation", + "stand", + "criticism", + "flow", + "door", + "injury", + "insect", + "surprise", + "apartment", +] # pylint: disable=line-too-long From da84966997598939c61da8c47db39d225667e7c0 Mon Sep 17 00:00:00 2001 From: skamenan7 Date: Thu, 26 Mar 2026 15:47:01 -0400 Subject: [PATCH 4/4] refactor(scoring): apply upstream get_langid docstring to ifeval_support.py Upstream PR #5267 added a docstring to get_langid() in ifeval_utils.py. Since that function now lives in ifeval_support.py after the split, carry the docstring over to the correct location. Signed-off-by: skamenan7 --- .../inline/scoring/basic/utils/ifeval_support.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py index ff347175e1..39f70db130 100644 --- a/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py +++ b/src/llama_stack/providers/inline/scoring/basic/utils/ifeval_support.py @@ -190,6 +190,15 @@ def count_sentences(text): def get_langid(text: str, lid_path: str | None = None) -> str: + """Detect the primary language of a text using per-line language detection. + + Args: + text: input text to analyze + lid_path: unused, kept for interface compatibility + + Returns: + ISO 639-1 language code, defaulting to "en" if detection fails + """ line_langs: list[str] = [] lines = [line.strip() for line in text.split("\n") if len(line.strip()) >= 4]