Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
ANTHROPIC_API_KEY=your_anthropic_api_key_here

OPENAI_API_KEY=your_openai_api_key_here

GOOGLE_API_KEY=your_google_api_key_here

AZURE_API_KEY=your_azure_api_key_here
AZURE_ENDPOINT=your_azure_endpoint_here
AZURE_API_VERSION=your_azure_api_version_here
AZURE_API_VERSION=your_azure_api_version_here

ENDPOINT_URL=http://0.0.0.0:8000/api/chat
ENDPOINT_START_URL=http://0.0.0.0:8000/api/start_conversation
ENDPOINT_API_KEY=your_endpoint_api_key_here
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ VERA-MH simulates realistic conversations between Large Language Models (LLMs) f
- **`gemini_llm.py`**: Google Gemini implementation with structured output
- **`azure_llm.py`**: Azure OpenAI and Azure AI Foundry implementation with structured output
- **`ollama_llm.py`**: Ollama model implementation
- **`endpoint_llm.py`**: Example for using your own API as the provider agent (currently chat-only; see [evaluating.md](docs/evaluating.md))
- **`config.py`**: Configuration management for API keys and model settings
- **`utils/`**: Utility functions and helpers
- **`prompt_loader.py`**: Functions for loading prompt configurations
Expand Down
1 change: 1 addition & 0 deletions docs/evaluating.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
VERA-MH is ready to be used to evaluate any chat-based interface.
[This](../llm_clients/llm_interface.py) Abstract Base Class (ABC) represents the interface to be implemented.
Five concrete implementations of that class are provided for the APIs of ChatGPT, Claude, Gemini, Azure, and Llama (via Ollama).
For developers who wish to use their own API as the provider agent, [EndpointLLM](../llm_clients/endpoint_llm.py) serves as a working example (currently chat-only; no judge support).

To test your service, you need to instantiate a concrete class and implement these key methods:
- `start_conversation()`: Async method that returns the first conversational turn as a string. For raw LLM APIs you can call `generate_response(self.get_initial_prompt_turns())`; for service-based APIs you may call your own start endpoint (e.g. POST /start_conversation) and return the message.
Expand Down
1 change: 1 addition & 0 deletions llm_clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- Gemini (gemini-*)
- Azure (azure-*)
- Ollama (ollama-*)
- Custom endpoint (endpoint, endpoint-*)
"""

from .config import Config
Expand Down
40 changes: 38 additions & 2 deletions llm_clients/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@ class Config:

# API Keys
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # For Gemini

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

AZURE_API_KEY = os.getenv("AZURE_API_KEY")
AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT")
AZURE_API_VERSION = os.getenv("AZURE_API_VERSION") # Optional
AZURE_API_VERSION = os.getenv("AZURE_API_VERSION")

ENDPOINT_API_KEY = os.getenv("ENDPOINT_API_KEY", None)
ENDPOINT_URL = os.getenv("ENDPOINT_URL", None)
ENDPOINT_START_URL = os.getenv("ENDPOINT_START_URL", None)

@classmethod
def get_claude_config(cls) -> Dict[str, Any]:
Expand Down Expand Up @@ -80,3 +87,32 @@ def get_ollama_config(cls) -> Dict[str, Any]:
"model": "llama3:8b",
"base_url": "http://localhost:11434", # Default Ollama URL
}

@classmethod
def get_endpoint_config(cls) -> Dict[str, Any]:
    """Get custom endpoint configuration.

    Returns:
        Dict with:
        - base_url: full chat endpoint URL (e.g. http://host:8000/api/chat,
          as shown in .env.example) — requests are POSTed to it directly.
        - api_key: value of ENDPOINT_API_KEY.
        - start_url: optional start-conversation URL, or None when
          ENDPOINT_START_URL is unset.
        - model: default server-side model name.
        Runtime parameters can override via kwargs.

    Raises:
        ValueError: if ENDPOINT_API_KEY or ENDPOINT_URL is not set in the
            environment. ENDPOINT_START_URL is optional.
    """
    missing = []
    if cls.ENDPOINT_API_KEY is None:
        missing.append("ENDPOINT_API_KEY")
    if cls.ENDPOINT_URL is None:
        missing.append("ENDPOINT_URL")
    if missing:
        raise ValueError(
            "Custom endpoint requires these environment variables: "
            f"{', '.join(missing)}"
        )
    # Only warn about the optional variable once required config is valid;
    # EndpointLLM falls back to sending start_prompt to the chat URL.
    if cls.ENDPOINT_START_URL is None:
        print("ENDPOINT_START_URL is not set in the environment.")
    return {
        "base_url": cls.ENDPOINT_URL,
        "api_key": cls.ENDPOINT_API_KEY,
        "start_url": cls.ENDPOINT_START_URL,
        "model": "phi4",
    }
194 changes: 194 additions & 0 deletions llm_clients/endpoint_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import time
from typing import Any, Dict, List, Optional

import aiohttp

from utils.conversation_utils import build_langchain_messages

from .config import Config
from .llm_interface import LLMInterface, Role


class EndpointLLM(LLMInterface):
    """Chat-only LLM that calls a custom POST /api/chat endpoint.

    The API manages conversation history server-side via conversation_id.
    This implementation does not support structured output and cannot be used
    as a judge. For judge operations, use Claude, OpenAI, Gemini, or Azure.

    System prompt: This class accepts system_prompt (from LLMInterface) for
    interface consistency and as an example for subclasses. By default we do
    not send it to the endpoint as custom APIs typically manage system context
    themselves. To apply it (e.g. prefix first user message with
    "System: ..."), override generate_response or _build_body in a subclass.
    """

    def __init__(
        self,
        name: str,
        role: Role,
        system_prompt: Optional[str] = None,
        model_name: Optional[str] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs,
    ):
        # Pop interface-level kwargs first so they are forwarded to the base
        # class and never mistaken for model parameters.
        first_message = kwargs.pop("first_message", None)
        start_prompt = kwargs.pop("start_prompt", None)
        super().__init__(
            name,
            role,
            system_prompt,
            first_message=first_message,
            start_prompt=start_prompt,
        )

        # Explicit constructor args take precedence over environment config.
        cfg = Config.get_endpoint_config()
        self._api_key = api_key or cfg["api_key"]
        self._base_url = base_url or cfg["base_url"]
        self._start_url = cfg.get("start_url", None)

        # NOTE: if start_url is set, we don't need to use the start_prompt
        # unless the developer wants to utilize it
        if self._start_url is not None:
            self.start_prompt = None

        # "endpoint-<model>" selects a specific server-side model; a bare
        # "endpoint" (or empty suffix) falls back to the configured default.
        if model_name and model_name.lower().startswith("endpoint-"):
            self._api_model = model_name[len("endpoint-") :].strip() or cfg["model"]
        else:
            self._api_model = cfg["model"]
        self.model_name = model_name or "endpoint"
        self.temperature = kwargs.pop("temperature", None)
        self.max_tokens = kwargs.pop("max_tokens", None)

    def __getattr__(self, name):
        """Delegate to self.llm when a subclass sets one; else return None.

        Python only calls __getattr__ after normal attribute lookup has
        already failed, so instance attributes (temperature, max_tokens, ...)
        never reach this method — a `name in self.__dict__` check here would
        be dead code. Unknown names resolve to None instead of raising
        AttributeError so optional attributes (e.g. conversation_id before
        the first response) read as None. Only __dict__ lookups are used for
        "llm" to avoid recursing back into __getattr__.
        """
        if "llm" in self.__dict__ and hasattr(self.__dict__["llm"], name):
            return getattr(self.__dict__["llm"], name)
        return None

    async def start_conversation(self) -> str:
        """Produce the first conversational turn:
        - static first_message if set, or
        - API call to start_url if set, or
        - API call to /api/chat with start_prompt if neither is set.
        """
        if self.first_message is not None:
            self._set_response_metadata("endpoint", static_first_message=True)
            return self.first_message
        elif self._start_url is not None:
            # NOTE(review): when start_url is configured, __init__ cleared
            # start_prompt, so the body's user content is None here —
            # presumably the start endpoint ignores it; confirm server-side.
            start_time = time.time()
            resp_data = await self._ainvoke(self._start_url, self.start_prompt)
            return self._process_chat_response(
                resp_data, round(time.time() - start_time, 3)
            )
        else:
            return await self.generate_response(self.get_initial_prompt_turns())

    def _default_headers(self) -> Dict[str, str]:
        """Default request headers (API key and content type)."""
        return {
            "X-API-Key": self._api_key,
            "Content-Type": "application/json",
        }

    def _process_chat_response(
        self, resp_data: Dict[str, Any], response_time_seconds: float
    ) -> str:
        """Extract message text from API response and set metadata. Return content."""
        msg_data = resp_data.get("message") or {}
        msg_text: str = msg_data.get("content", "")

        # Token usage is optional (Ollama-style prompt_eval_count/eval_count
        # fields); only build the usage dict when at least one is present.
        usage = {}
        if resp_data.get("prompt_eval_count") is not None:
            usage["prompt_tokens"] = resp_data.get("prompt_eval_count", 0)
        if resp_data.get("eval_count") is not None:
            usage["completion_tokens"] = resp_data.get("eval_count", 0)
        if usage:
            usage.setdefault("prompt_tokens", 0)
            usage.setdefault("completion_tokens", 0)
            usage["total_tokens"] = usage["prompt_tokens"] + usage["completion_tokens"]

        self._set_response_metadata(
            "endpoint",
            model=resp_data.get("model", self._api_model),
            response_id=msg_data.get("id"),
            usage=usage,
            conversation_id=resp_data.get("conversation_id"),
            response_time_seconds=response_time_seconds,
            total_duration=resp_data.get("total_duration"),
            load_duration=resp_data.get("load_duration"),
            prompt_eval_count=resp_data.get("prompt_eval_count"),
            prompt_eval_duration=resp_data.get("prompt_eval_duration"),
            eval_count=resp_data.get("eval_count"),
            eval_duration=resp_data.get("eval_duration"),
        )
        # Adopt the server-assigned conversation_id so follow-up requests
        # continue the same server-side conversation.
        self._update_conversation_id_from_metadata()
        return msg_text

    def _build_body(self, content: str) -> Dict[str, Any]:
        """Body: model, messages (user content only), stream, conversation_id.
        System prompt is not included; see class docstring.
        """
        return {
            "model": self._api_model,
            "messages": [{"role": "user", "content": content}],
            "stream": False,
            "conversation_id": self.conversation_id,
        }

    async def _ainvoke(
        self,
        url: str,
        content: str,
        *,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """POST to url with body built from content; return parsed JSON.
        Body: model, messages (single user message), stream=False, conversation_id.
        Default headers when headers is None. Raises RuntimeError on non-200.
        """
        req_headers = headers if headers is not None else self._default_headers()
        body = self._build_body(content)
        # A fresh session per request keeps the implementation simple; the
        # server tracks conversation state via conversation_id, not cookies.
        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=req_headers, json=body) as resp:
                if resp.status != 200:
                    text = await resp.text()
                    raise RuntimeError(f"Endpoint returned {resp.status}: {text[:500]}")
                return await resp.json()

    async def generate_response(
        self,
        conversation_history: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Generate a response via POST /api/chat with server-side conversation_id.

        Only the latest user content is sent; self.system_prompt is not included
        in the request (see class docstring for rationale).
        """
        if not conversation_history or len(conversation_history) == 0:
            return await self.start_conversation()

        messages = build_langchain_messages(self.role, conversation_history)
        last_message = messages[-1].text  # no system_prompt in payload by design

        try:
            start_time = time.time()
            resp_data = await self._ainvoke(self._base_url, last_message)
            return self._process_chat_response(
                resp_data, round(time.time() - start_time, 3)
            )
        except Exception as e:
            # Best-effort: surface the failure as a message rather than
            # crashing the conversation loop; error is kept in metadata.
            self._set_response_metadata("endpoint", error=str(e))
            self._update_conversation_id_from_metadata()
            return f"Error generating response: {str(e)}"

    def set_system_prompt(self, system_prompt: str) -> None:
        """Set or update the system prompt."""
        self.system_prompt = system_prompt
6 changes: 5 additions & 1 deletion llm_clients/llm_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def create_llm(
from .gemini_llm import GeminiLLM

return GeminiLLM(name, role, system_prompt, model_name, **model_params)
elif "endpoint" in model_lower:
from .endpoint_llm import EndpointLLM

return EndpointLLM(name, role, system_prompt, model_name, **model_params)
else:
raise ValueError(f"Unsupported model: {model_name}")

Expand Down Expand Up @@ -111,7 +115,7 @@ def create_judge_llm(
f"generation. Judge operations require models with structured "
f"output support. Supported models: Claude (claude-*), "
f"OpenAI (gpt-*), Gemini (gemini-*), Azure (azure-*). "
f"Not supported: Ollama models."
f"Not supported: Ollama (ollama-*), Endpoint (endpoint-*)."
)

return llm
2 changes: 1 addition & 1 deletion llm_clients/llm_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _update_conversation_id_from_metadata(self) -> None:
will overwrite self.conversation_id here.
"""
cid = (self._last_response_metadata or {}).get("conversation_id")
if cid is not None:
if cid is not None and cid != self.conversation_id and cid != "":
self.conversation_id = cid

def _set_response_metadata(self, provider: str, **extra: Any) -> None:
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/llm_clients/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,22 @@ def mock_ollama_model():
yield mock


@pytest.fixture
def mock_endpoint_config():
    """Patch custom endpoint configuration for EndpointLLM tests.

    The return value mirrors the full shape of the real
    Config.get_endpoint_config(), including the optional "start_url" key
    (None here so tests exercise the start_prompt fallback path).
    """
    from unittest.mock import patch

    with patch(
        "llm_clients.endpoint_llm.Config.get_endpoint_config",
        return_value={
            "base_url": "https://api.example.com/chat",
            "api_key": "test-endpoint-key",
            "start_url": None,
            "model": "phi4",
        },
    ):
        yield


# Note there is no need to mock the other LLM Client configs as Azure's is a bit complex
@pytest.fixture
def mock_azure_config():
Expand Down
Loading