diff --git a/.env.example b/.env.example index cceb209e..b82879df 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,13 @@ ANTHROPIC_API_KEY=your_anthropic_api_key_here + OPENAI_API_KEY=your_openai_api_key_here + GOOGLE_API_KEY=your_google_api_key_here + AZURE_API_KEY=your_azure_api_key_here AZURE_ENDPOINT=your_azure_endpoint_here -AZURE_API_VERSION=your_azure_api_version_here \ No newline at end of file +AZURE_API_VERSION=your_azure_api_version_here + +ENDPOINT_URL=http://0.0.0.0:8000/api/chat +ENDPOINT_START_URL=http://0.0.0.0:8000/api/start_conversation +ENDPOINT_API_KEY=your_endpoint_api_key_here \ No newline at end of file diff --git a/README.md b/README.md index 9205822d..1da8971c 100644 --- a/README.md +++ b/README.md @@ -302,6 +302,7 @@ VERA-MH simulates realistic conversations between Large Language Models (LLMs) f - **`gemini_llm.py`**: Google Gemini implementation with structured output - **`azure_llm.py`**: Azure OpenAI and Azure AI Foundry implementation with structured output - **`ollama_llm.py`**: Ollama model implementation + - **`endpoint_llm.py`**: Example for using your own API as the provider agent (currently chat-only; see [evaluating.md](docs/evaluating.md)) - **`config.py`**: Configuration management for API keys and model settings - **`utils/`**: Utility functions and helpers - **`prompt_loader.py`**: Functions for loading prompt configurations diff --git a/docs/evaluating.md b/docs/evaluating.md index 9392170b..8bf0113e 100644 --- a/docs/evaluating.md +++ b/docs/evaluating.md @@ -3,6 +3,7 @@ VERA-MH is ready to be used to evaluate any chat-based interface. [This](../llm_clients/llm_interface.py) Abstract Base Class (ABC) represents the interface to be implemented. Four concrete implementations of that class are provided for the APIs of ChatGPT, Claude, Gemini, Azure, and Llama (via Ollama). 
+For developers who wish to use their own API as the provider agent, [EndpointLLM](../llm_clients/endpoint_llm.py) serves as a working example (currently chat-only; no judge support). To test your service, you need to instantiate a concrete class and implement these key methods: - `start_conversation()`: Async method that returns the first conversational turn as a string. For raw LLM APIs you can call `generate_response(self.get_initial_prompt_turns())`; for service-based APIs you may call your own start endpoint (e.g. POST /start_conversation) and return the message. diff --git a/llm_clients/__init__.py b/llm_clients/__init__.py index f6a00aa8..dd6c10eb 100644 --- a/llm_clients/__init__.py +++ b/llm_clients/__init__.py @@ -6,6 +6,7 @@ - Gemini (gemini-*) - Azure (azure-*) - Ollama (ollama-*) + - Custom endpoint (endpoint, endpoint-*) """ from .config import Config diff --git a/llm_clients/config.py b/llm_clients/config.py index 5961eaca..e930bc07 100644 --- a/llm_clients/config.py +++ b/llm_clients/config.py @@ -25,11 +25,18 @@ class Config: # API Keys ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY") + OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") - GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # For Gemini + + GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") + AZURE_API_KEY = os.getenv("AZURE_API_KEY") AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT") - AZURE_API_VERSION = os.getenv("AZURE_API_VERSION") # Optional + AZURE_API_VERSION = os.getenv("AZURE_API_VERSION") + + ENDPOINT_API_KEY = os.getenv("ENDPOINT_API_KEY", None) + ENDPOINT_URL = os.getenv("ENDPOINT_URL", None) + ENDPOINT_START_URL = os.getenv("ENDPOINT_START_URL", None) @classmethod def get_claude_config(cls) -> Dict[str, Any]: @@ -80,3 +87,32 @@ def get_ollama_config(cls) -> Dict[str, Any]: "model": "llama3:8b", "base_url": "http://localhost:11434", # Default Ollama URL } + + @classmethod + def get_endpoint_config(cls) -> Dict[str, Any]: + """Get custom endpoint configuration. 
+ + Returns base_url (the full chat endpoint URL, e.g. ending in /api/chat), api_key, and default model. + Runtime parameters can override via kwargs. + Raises ValueError if ENDPOINT_API_KEY or ENDPOINT_URL + are not set in the environment. + ENDPOINT_START_URL is optional and may be left unset (None). + """ + missing = [] + if cls.ENDPOINT_API_KEY is None: + missing.append("ENDPOINT_API_KEY") + if cls.ENDPOINT_URL is None: + missing.append("ENDPOINT_URL") + if cls.ENDPOINT_START_URL is None: + print("ENDPOINT_START_URL is not set in the environment.") + if missing: + raise ValueError( + "Custom endpoint requires these environment variables: " + f"{', '.join(missing)}" + ) + return { + "base_url": cls.ENDPOINT_URL, + "api_key": cls.ENDPOINT_API_KEY, + "start_url": cls.ENDPOINT_START_URL, + "model": "phi4", + } diff --git a/llm_clients/endpoint_llm.py b/llm_clients/endpoint_llm.py new file mode 100644 index 00000000..4c0e519b --- /dev/null +++ b/llm_clients/endpoint_llm.py @@ -0,0 +1,194 @@ +import time +from typing import Any, Dict, List, Optional + +import aiohttp + +from utils.conversation_utils import build_langchain_messages + +from .config import Config +from .llm_interface import LLMInterface, Role + + +class EndpointLLM(LLMInterface): + """Chat-only LLM that calls a custom POST /api/chat endpoint. + + The API manages conversation history server-side via conversation_id. + This implementation does not support structured output and cannot be used + as a judge. For judge operations, use Claude, OpenAI, Gemini, or Azure. + + System prompt: This class accepts system_prompt (from LLMInterface) for + interface consistency and as an example for subclasses. By default we do + not send it to the endpoint as custom APIs typically manage system context + themselves. To apply it (e.g. prefix first user message with + \"System: ...\"), override generate_response or _build_body in a subclass. 
+ """ + + def __init__( + self, + name: str, + role: Role, + system_prompt: Optional[str] = None, + model_name: Optional[str] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs, + ): + first_message = kwargs.pop("first_message", None) + start_prompt = kwargs.pop("start_prompt", None) + super().__init__( + name, + role, + system_prompt, + first_message=first_message, + start_prompt=start_prompt, + ) + + cfg = Config.get_endpoint_config() + self._api_key = api_key or cfg["api_key"] + self._base_url = base_url or cfg["base_url"] + self._start_url = cfg.get("start_url", None) + + # NOTE: if start_url is set, we don't need to use the start_prompt + # unless the developer wants to utilize it + if self._start_url is not None: + self.start_prompt = None + + if model_name and model_name.lower().startswith("endpoint-"): + self._api_model = model_name[len("endpoint-") :].strip() or cfg["model"] + else: + self._api_model = cfg["model"] + self.model_name = model_name or "endpoint" + self.temperature = kwargs.pop("temperature", None) + self.max_tokens = kwargs.pop("max_tokens", None) + + def __getattr__(self, name): + """Delegate to self.llm when present; else return self's attribute or None. + + Only uses __dict__ lookups to avoid recursion. Attributes like + temperature and max_tokens are on self; unknown names return None. + """ + if "llm" in self.__dict__ and hasattr(self.__dict__["llm"], name): + return getattr(self.__dict__["llm"], name) + if name in self.__dict__: + return self.__dict__[name] + return None + + async def start_conversation(self) -> str: + """Produce the first conversational turn: + - static first_message if set, or + - API call to start_url if set, or + - API call to /api/chat with start_prompt if neither is set. 
+ """ + if self.first_message is not None: + self._set_response_metadata("endpoint", static_first_message=True) + return self.first_message + elif self._start_url is not None: + start_time = time.time() + resp_data = await self._ainvoke(self._start_url, self.start_prompt) + return self._process_chat_response( + resp_data, round(time.time() - start_time, 3) + ) + else: + return await self.generate_response(self.get_initial_prompt_turns()) + + def _default_headers(self) -> Dict[str, str]: + """Default request headers (API key and content type).""" + return { + "X-API-Key": self._api_key, + "Content-Type": "application/json", + } + + def _process_chat_response( + self, resp_data: Dict[str, Any], response_time_seconds: float + ) -> str: + """Extract message text from API response and set metadata. Return content.""" + msg_data = resp_data.get("message") or {} + msg_text: str = msg_data.get("content", "") + + usage = {} + if resp_data.get("prompt_eval_count") is not None: + usage["prompt_tokens"] = resp_data.get("prompt_eval_count", 0) + if resp_data.get("eval_count") is not None: + usage["completion_tokens"] = resp_data.get("eval_count", 0) + if usage: + usage.setdefault("prompt_tokens", 0) + usage.setdefault("completion_tokens", 0) + usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"] + + self._set_response_metadata( + "endpoint", + model=resp_data.get("model", self._api_model), + response_id=msg_data.get("id"), + usage=usage, + conversation_id=resp_data.get("conversation_id"), + response_time_seconds=response_time_seconds, + total_duration=resp_data.get("total_duration"), + load_duration=resp_data.get("load_duration"), + prompt_eval_count=resp_data.get("prompt_eval_count"), + prompt_eval_duration=resp_data.get("prompt_eval_duration"), + eval_count=resp_data.get("eval_count"), + eval_duration=resp_data.get("eval_duration"), + ) + self._update_conversation_id_from_metadata() + return msg_text + + def _build_body(self, content: str) -> Dict[str, 
Any]: + """Body: model, messages (user content only), stream, conversation_id. + System prompt is not included; see class docstring. + """ + return { + "model": self._api_model, + "messages": [{"role": "user", "content": content}], + "stream": False, + "conversation_id": self.conversation_id, + } + + async def _ainvoke( + self, + url: str, + content: str, + *, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """POST to url with body built from content; return parsed JSON. + Body: model, messages (single user message), stream=False, conversation_id. + Default headers when headers is None. Raises RuntimeError on non-200. + """ + req_headers = headers if headers is not None else self._default_headers() + body = self._build_body(content) + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=req_headers, json=body) as resp: + if resp.status != 200: + text = await resp.text() + raise RuntimeError(f"Endpoint returned {resp.status}: {text[:500]}") + return await resp.json() + + async def generate_response( + self, + conversation_history: Optional[List[Dict[str, Any]]] = None, + ) -> str: + """Generate a response via POST /api/chat with server-side conversation_id. + + Only the latest user content is sent; self.system_prompt is not included + in the request (see class docstring for rationale). 
+ """ + if not conversation_history or len(conversation_history) == 0: + return await self.start_conversation() + + messages = build_langchain_messages(self.role, conversation_history) + last_message = messages[-1].text # no system_prompt in payload by design + + try: + start_time = time.time() + resp_data = await self._ainvoke(self._base_url, last_message) + return self._process_chat_response( + resp_data, round(time.time() - start_time, 3) + ) + except Exception as e: + self._set_response_metadata("endpoint", error=str(e)) + self._update_conversation_id_from_metadata() + return f"Error generating response: {str(e)}" + + def set_system_prompt(self, system_prompt: str) -> None: + """Set or update the system prompt.""" + self.system_prompt = system_prompt diff --git a/llm_clients/llm_factory.py b/llm_clients/llm_factory.py index 3ce0d57e..a60b0e4c 100644 --- a/llm_clients/llm_factory.py +++ b/llm_clients/llm_factory.py @@ -67,6 +67,10 @@ def create_llm( from .gemini_llm import GeminiLLM return GeminiLLM(name, role, system_prompt, model_name, **model_params) + elif "endpoint" in model_lower: + from .endpoint_llm import EndpointLLM + + return EndpointLLM(name, role, system_prompt, model_name, **model_params) else: raise ValueError(f"Unsupported model: {model_name}") @@ -111,7 +115,7 @@ def create_judge_llm( f"generation. Judge operations require models with structured " f"output support. Supported models: Claude (claude-*), " f"OpenAI (gpt-*), Gemini (gemini-*), Azure (azure-*). " - f"Not supported: Ollama models." + f"Not supported: Ollama (ollama-*), Endpoint (endpoint-*)." ) return llm diff --git a/llm_clients/llm_interface.py b/llm_clients/llm_interface.py index ca051bbd..6fc94566 100644 --- a/llm_clients/llm_interface.py +++ b/llm_clients/llm_interface.py @@ -80,7 +80,7 @@ def _update_conversation_id_from_metadata(self) -> None: will overwrite self.conversation_id here. 
""" cid = (self._last_response_metadata or {}).get("conversation_id") - if cid is not None: + if cid is not None and cid != self.conversation_id and cid != "": self.conversation_id = cid def _set_response_metadata(self, provider: str, **extra: Any) -> None: diff --git a/tests/unit/llm_clients/conftest.py b/tests/unit/llm_clients/conftest.py index 63a031de..49e4fb7e 100644 --- a/tests/unit/llm_clients/conftest.py +++ b/tests/unit/llm_clients/conftest.py @@ -339,6 +339,22 @@ def mock_ollama_model(): yield mock +@pytest.fixture +def mock_endpoint_config(): + """Patch custom endpoint configuration for EndpointLLM tests.""" + from unittest.mock import patch + + with patch( + "llm_clients.endpoint_llm.Config.get_endpoint_config", + return_value={ + "base_url": "https://api.example.com/chat", + "api_key": "test-endpoint-key", + "model": "phi4", + }, + ): + yield + + # Note there is no need to mock the other LLM Client configs as Azure's is a bit complex @pytest.fixture def mock_azure_config(): diff --git a/tests/unit/llm_clients/test_endpoint_llm.py b/tests/unit/llm_clients/test_endpoint_llm.py new file mode 100644 index 00000000..7c149d2e --- /dev/null +++ b/tests/unit/llm_clients/test_endpoint_llm.py @@ -0,0 +1,259 @@ +"""Unit tests for EndpointLLM class.""" + +from contextlib import contextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from llm_clients import Role +from llm_clients.endpoint_llm import EndpointLLM +from llm_clients.llm_interface import DEFAULT_START_PROMPT + +from .test_base_llm import TestLLMBase +from .test_helpers import ( + assert_error_metadata, + assert_error_response, + assert_iso_timestamp, + assert_metadata_copy_behavior, + assert_metadata_structure, + assert_response_timing, +) + + +def _make_aiohttp_mock( + content: str = "Test response text", + conversation_id: str | None = "server-cid-1", + status: int = 200, +): + """Build mock aiohttp ClientSession/post/response for EndpointLLM.""" + resp_mock = 
MagicMock() + resp_mock.status = status + resp_mock.json = AsyncMock( + return_value={ + "message": {"content": content, "id": "msg-1"}, + "conversation_id": conversation_id, + "model": "phi4", + } + ) + resp_mock.text = AsyncMock(return_value="") + + post_cm = MagicMock() + post_cm.__aenter__ = AsyncMock(return_value=resp_mock) + post_cm.__aexit__ = AsyncMock(return_value=None) + + session_mock = MagicMock() + session_mock.post = MagicMock(return_value=post_cm) + + session_cm = MagicMock() + session_cm.__aenter__ = AsyncMock(return_value=session_mock) + session_cm.__aexit__ = AsyncMock(return_value=None) + + client_session_mock = MagicMock(return_value=session_cm) + return client_session_mock + + +@pytest.mark.unit +@pytest.mark.usefixtures("mock_endpoint_config") +class TestEndpointLLM(TestLLMBase): + """Unit tests for EndpointLLM. + + EndpointLLM implements LLMInterface only (no JudgeLLM); it uses aiohttp + instead of an underlying .llm, so some base tests are overridden. + """ + + def create_llm(self, role: Role, **kwargs): + if "name" not in kwargs: + kwargs["name"] = "test-endpoint" + return EndpointLLM(role=role, **kwargs) + + def get_provider_name(self) -> str: + return "endpoint" + + @contextmanager + def get_mock_patches(self): + with patch( + "llm_clients.endpoint_llm.aiohttp.ClientSession", + new_callable=lambda: _make_aiohttp_mock(), + ): + yield + + # ------------------------------------------------------------------------- + # Overrides: generate_response uses aiohttp, not llm.llm + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_generate_response_returns_llm_text( + self, mock_response_factory, mock_llm_factory, mock_system_message + ): + expected_text = "Test response text" + with self.get_mock_patches(): + with patch( + "llm_clients.endpoint_llm.aiohttp.ClientSession", + new_callable=lambda: _make_aiohttp_mock(content=expected_text), + ): + llm = 
self.create_llm(role=Role.PROVIDER, name="TestLLM") + response = await llm.generate_response( + conversation_history=mock_system_message + ) + assert response == expected_text + + @pytest.mark.asyncio + async def test_generate_response_updates_metadata( + self, mock_response_factory, mock_llm_factory, mock_system_message + ): + with self.get_mock_patches(): + llm = self.create_llm(role=Role.PROVIDER, name="TestLLM") + await llm.generate_response(conversation_history=mock_system_message) + metadata = assert_metadata_structure( + llm, + expected_provider=self.get_provider_name(), + expected_role=Role.PROVIDER, + ) + assert "timestamp" in metadata + assert_iso_timestamp(metadata["timestamp"]) + assert_response_timing(metadata) + + @pytest.mark.asyncio + async def test_generate_response_handles_errors( + self, mock_llm_factory, mock_system_message + ): + with self.get_mock_patches(): + with patch( + "llm_clients.endpoint_llm.aiohttp.ClientSession" + ) as mock_session_class: + session_cm = MagicMock() + session_cm.__aenter__ = AsyncMock(side_effect=Exception("API Error")) + session_cm.__aexit__ = AsyncMock(return_value=None) + mock_session_class.return_value = session_cm + + llm = self.create_llm(role=Role.PROVIDER, name="TestLLM") + response = await llm.generate_response( + conversation_history=mock_system_message + ) + + assert_error_response(response, "API Error") + assert_error_metadata( + llm, + expected_provider=self.get_provider_name(), + expected_error_substring="API Error", + ) + + # ------------------------------------------------------------------------- + # Endpoint-specific tests + # ------------------------------------------------------------------------- + + def test_init_passes_first_message_and_start_prompt_to_super(self): + with self.get_mock_patches(): + llm = EndpointLLM( + name="ep", + role=Role.PROVIDER, + first_message="Hello", + start_prompt="Custom start", + ) + assert llm.first_message == "Hello" + assert llm.start_prompt == "Custom start" + + 
def test_init_default_start_prompt(self): + with self.get_mock_patches(): + llm = EndpointLLM(name="ep", role=Role.PROVIDER) + assert llm.start_prompt == DEFAULT_START_PROMPT + + @pytest.mark.asyncio + async def test_start_conversation_returns_first_message_when_set(self): + with self.get_mock_patches(): + llm = EndpointLLM( + name="ep", + role=Role.PROVIDER, + first_message="Static first reply", + ) + out = await llm.start_conversation() + assert out == "Static first reply" + meta = llm.last_response_metadata + assert meta.get("static_first_message") is True + assert meta.get("provider") == "endpoint" + + @pytest.mark.asyncio + async def test_start_conversation_calls_api_when_no_first_message(self): + with self.get_mock_patches(): + with patch( + "llm_clients.endpoint_llm.aiohttp.ClientSession", + new_callable=lambda: _make_aiohttp_mock(content="First turn from API"), + ) as mock_session_class: + llm = EndpointLLM(name="ep", role=Role.PROVIDER) + out = await llm.start_conversation() + assert out == "First turn from API" + mock_session_class.return_value.__aenter__.return_value.post.assert_called_once() + + @pytest.mark.asyncio + async def test_conversation_id_overwritten_when_endpoint_returns_different( + self, mock_system_message + ): + """Endpoint response conversation_id overwrites client-generated id.""" + client_cid = "client-generated-cid" + server_cid = "server-returned-cid" + with self.get_mock_patches(): + with patch( + "llm_clients.endpoint_llm.aiohttp.ClientSession", + new_callable=lambda: _make_aiohttp_mock( + content="OK", conversation_id=server_cid + ), + ): + llm = EndpointLLM(name="ep", role=Role.PROVIDER) + llm.conversation_id = client_cid + await llm.generate_response(conversation_history=mock_system_message) + assert llm.conversation_id == server_cid + + @pytest.mark.asyncio + async def test_generate_response_with_empty_conversation_history(self): + """Verify start_conversation / default start_prompt with empty history.""" + with 
self.get_mock_patches(): + with patch( + "llm_clients.endpoint_llm.aiohttp.ClientSession", + new_callable=lambda: _make_aiohttp_mock(content="Delegated first turn"), + ): + llm = EndpointLLM(name="ep", role=Role.PROVIDER) + out = await llm.generate_response(conversation_history=[]) + assert out == "Delegated first turn" + + @pytest.mark.asyncio + async def test_generate_response_none_history_delegates_to_start_conversation( + self, + ): + with self.get_mock_patches(): + with patch( + "llm_clients.endpoint_llm.aiohttp.ClientSession", + new_callable=lambda: _make_aiohttp_mock(content="Delegated from None"), + ): + llm = EndpointLLM(name="ep", role=Role.PROVIDER) + out = await llm.generate_response(conversation_history=None) + assert out == "Delegated from None" + + def test_set_system_prompt(self): + with self.get_mock_patches(): + llm = self.create_llm( + role=Role.PROVIDER, name="TestLLM", system_prompt="Initial" + ) + assert llm.system_prompt == "Initial" + llm.set_system_prompt("Updated") + assert llm.system_prompt == "Updated" + + def test_getattr_returns_none_for_unknown_attribute(self): + with self.get_mock_patches(): + llm = EndpointLLM(name="ep", role=Role.PROVIDER) + assert llm.nonexistent_attr is None + + def test_temperature_and_max_tokens_accessible_from_self(self): + with self.get_mock_patches(): + llm = EndpointLLM( + name="ep", + role=Role.PROVIDER, + temperature=0.3, + max_tokens=100, + ) + assert llm.temperature == 0.3 + assert llm.max_tokens == 100 + + def test_last_response_metadata_copy_returns_copy(self): + with self.get_mock_patches(): + llm = self.create_llm(role=Role.PROVIDER, name="TestLLM") + assert_metadata_copy_behavior(llm)