diff --git a/src/llama_stack/providers/remote/inference/vllm/vllm.py b/src/llama_stack/providers/remote/inference/vllm/vllm.py
index 39d0c2d030..4f915f2d8b 100644
--- a/src/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/src/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -3,16 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import os
-import ssl
 from collections.abc import AsyncIterator
 from urllib.parse import urljoin
 
-import aiohttp
 import httpx
 from pydantic import ConfigDict
 
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.http_client import _build_network_client_kwargs
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack_api import (
     HealthResponse,
@@ -60,14 +58,12 @@ async def initialize(self) -> None:
                 "You must provide a URL in config.yaml (or via the VLLM_URL environment variable) to use vLLM."
             )
 
-        # Shared SSL context for all calls to improve performance
-        if self.config.tls_verify is False:
-            self.shared_ssl_context = False
-        elif isinstance(self.config.tls_verify, str):
-            if os.path.isdir(self.config.tls_verify):
-                self.shared_ssl_context = ssl.create_default_context(capath=self.config.tls_verify)
-            else:
-                self.shared_ssl_context = ssl.create_default_context(cafile=self.config.tls_verify)
+    def _build_httpx_client_kwargs(self) -> dict:
+        """Build httpx.AsyncClient kwargs that honour network/TLS configuration."""
+        kwargs = _build_network_client_kwargs(self.config.network)
+        if not kwargs:
+            kwargs["verify"] = self.shared_ssl_context
+        return kwargs
 
     async def health(self) -> HealthResponse:
         """
@@ -83,7 +79,7 @@ async def health(self) -> HealthResponse:
         base_url = self.get_base_url()
         health_url = urljoin(base_url, "health")
 
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(**self._build_httpx_client_kwargs()) as client:
             response = await client.get(health_url)
             response.raise_for_status()
             return HealthResponse(status=HealthStatus.OK)
@@ -159,40 +155,45 @@ def format_item(
         if request.max_num_results is not None:
             payload["top_n"] = request.max_num_results
 
+        # vLLM does not support /v1/rerank ->
+        # "To indicate that the rerank API is not part of the standard OpenAI API,
+        # we have located it at `/rerank`. Please update your client accordingly.
+        # (Note: Conforms to JinaAI rerank API)" - vLLM 0.15.1
+        endpoint = self.get_base_url().replace("/v1", "") + "/rerank"  # TODO: find a better solution
+
+        headers: dict[str, str] = {}
+        api_key = self.get_api_key()
+        if api_key and api_key != "NO KEY REQUIRED":
+            headers["Authorization"] = f"Bearer {api_key}"
+
         try:
-            async with aiohttp.ClientSession() as session:
-                # vLLM does not support /v1/rerank ->
-                # "To indicate that the rerank API is not part of the standard OpenAI API,
-                # we have located it at `/rerank`. Please update your client accordingly.
-                # (Note: Conforms to JinaAI rerank API)" - vLLM 0.15.1
-                endpoint = self.get_base_url().replace("/v1", "") + "/rerank"  # TODO: find a better solution
-                async with session.post(endpoint, headers={}, json=payload) as response:
-                    if response.status != 200:
-                        response_text = await response.text()
+            async with httpx.AsyncClient(**self._build_httpx_client_kwargs()) as client:
+                response = await client.post(endpoint, headers=headers, json=payload)
+                if response.status_code != 200:
+                    raise RuntimeError(
+                        f"vLLM rerank API request failed with status {response.status_code}: {response.text}"
+                    )
+
+                def convert_result_item(item: dict) -> RerankData:
+                    if "index" not in item or "relevance_score" not in item:
                         raise RuntimeError(
-                            f"vLLM rerank API request failed with status {response.status}: {response_text}"
+                            "vLLM rerank API response missing required fields 'index' or 'relevance_score'"
                         )
 
-                    def convert_result_item(item: dict) -> RerankData:
-                        if "index" not in item or "relevance_score" not in item:
-                            raise RuntimeError(
-                                "vLLM rerank API response missing required fields 'index' or 'relevance_score'"
-                            )
-
-                        try:
-                            return RerankData(index=int(item["index"]), relevance_score=float(item["relevance_score"]))
-                        except (TypeError, ValueError) as e:
-                            raise RuntimeError(f"Invalid data types in vLLM rerank API response: {e}") from e
+                    try:
+                        return RerankData(index=int(item["index"]), relevance_score=float(item["relevance_score"]))
+                    except (TypeError, ValueError) as e:
+                        raise RuntimeError(f"Invalid data types in vLLM rerank API response: {e}") from e
 
-                    result = await response.json()
+                result = response.json()
 
-                    if "results" not in result:
-                        raise RuntimeError("vLLM rerank API response missing 'results' field")
+                if "results" not in result:
+                    raise RuntimeError("vLLM rerank API response missing 'results' field")
 
-                    rerank_data = [convert_result_item(item) for item in result.get("results")]
-                    rerank_data.sort(key=lambda entry: entry.relevance_score, reverse=True)
+                rerank_data = [convert_result_item(item) for item in result.get("results")]
+                rerank_data.sort(key=lambda entry: entry.relevance_score, reverse=True)
 
-                    return RerankResponse(data=rerank_data)
+                return RerankResponse(data=rerank_data)
 
-        except aiohttp.ClientError as e:
+        except httpx.HTTPError as e:
             raise ConnectionError(f"Failed to connect to vLLM rerank API at {endpoint}: {e}") from e
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 4540a0fb80..f96eb85a64 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import asyncio
+import ssl
 import time
 
 from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
@@ -311,3 +312,148 @@ async def test_vllm_chat_completion_extra_body():
     assert "extra_body" in call_kwargs
     assert "chat_template_kwargs" in call_kwargs["extra_body"]
     assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}
+
+
+class TestHealthTLSConfig:
+    """Tests that health() honours TLS/network configuration."""
+
+    async def test_health_uses_shared_ssl_context_by_default(self):
+        """Without network config, health() should pass shared_ssl_context as verify."""
+        config = VLLMInferenceAdapterConfig(base_url="https://vllm.example.com/v1")
+        adapter = VLLMInferenceAdapter(config=config)
+        await adapter.initialize()
+
+        kwargs = adapter._build_httpx_client_kwargs()
+        assert "verify" in kwargs
+        assert isinstance(kwargs["verify"], ssl.SSLContext)
+
+    async def test_health_uses_network_tls_verify_false(self):
+        """With network.tls.verify=False, health() should disable TLS verification."""
+        config = VLLMInferenceAdapterConfig(
+            base_url="https://vllm.example.com/v1",
+            network={"tls": {"verify": False}},
+        )
+        adapter = VLLMInferenceAdapter(config=config)
+        await adapter.initialize()
+
+        kwargs = adapter._build_httpx_client_kwargs()
+        assert kwargs["verify"] is False
+
+    async def test_health_passes_kwargs_to_httpx(self):
+        """health() should pass _build_httpx_client_kwargs() to httpx.AsyncClient."""
+        config = VLLMInferenceAdapterConfig(
+            base_url="https://vllm.example.com/v1",
+            network={"tls": {"verify": False}},
+        )
+        adapter = VLLMInferenceAdapter(config=config)
+        await adapter.initialize()
+
+        with patch("httpx.AsyncClient") as mock_client_class:
+            mock_response = MagicMock()
+            mock_response.raise_for_status.return_value = None
+            mock_client_instance = MagicMock()
+            mock_client_instance.get = AsyncMock(return_value=mock_response)
+            mock_client_class.return_value.__aenter__.return_value = mock_client_instance
+
+            await adapter.health()
+
+            mock_client_class.assert_called_once()
+            call_kwargs = mock_client_class.call_args.kwargs
+            assert call_kwargs.get("verify") is False
+
+    async def test_legacy_tls_verify_migrates_to_network(self):
+        """Legacy tls_verify=False should be migrated to network.tls.verify."""
+        config = VLLMInferenceAdapterConfig(
+            base_url="https://vllm.example.com/v1",
+            tls_verify=False,
+        )
+        adapter = VLLMInferenceAdapter(config=config)
+        await adapter.initialize()
+
+        kwargs = adapter._build_httpx_client_kwargs()
+        assert kwargs["verify"] is False
+
+
+class TestRerankTLSAndAuth:
+    """Tests that rerank() honours TLS and auth configuration."""
+
+    async def test_rerank_passes_tls_config(self):
+        """rerank() should use _build_httpx_client_kwargs() for TLS."""
+        config = VLLMInferenceAdapterConfig(
+            base_url="https://vllm.example.com/v1",
+            network={"tls": {"verify": False}},
+        )
+        adapter = VLLMInferenceAdapter(config=config)
+        await adapter.initialize()
+
+        with patch("httpx.AsyncClient") as mock_client_class:
+            mock_response = MagicMock()
+            mock_response.status_code = 200
+            mock_response.json.return_value = {
+                "results": [{"index": 0, "relevance_score": 0.9}],
+            }
+            mock_client_instance = MagicMock()
+            mock_client_instance.post = AsyncMock(return_value=mock_response)
+            mock_client_class.return_value.__aenter__.return_value = mock_client_instance
+
+            from llama_stack_api.inference import RerankRequest
+
+            request = RerankRequest(model="rerank-model", query="test", items=["doc1"])
+            await adapter.rerank(request)
+
+
+            mock_client_class.assert_called_once()
+            call_kwargs = mock_client_class.call_args.kwargs
+            assert call_kwargs.get("verify") is False
+
+    async def test_rerank_sends_auth_header(self):
+        """rerank() should include Authorization header when api_token is configured."""
+        config = VLLMInferenceAdapterConfig(
+            base_url="https://vllm.example.com/v1",
+            api_token="my-secret-token",
+        )
+        adapter = VLLMInferenceAdapter(config=config)
+        await adapter.initialize()
+
+        with patch("httpx.AsyncClient") as mock_client_class:
+            mock_response = MagicMock()
+            mock_response.status_code = 200
+            mock_response.json.return_value = {
+                "results": [{"index": 0, "relevance_score": 0.9}],
+            }
+            mock_client_instance = MagicMock()
+            mock_client_instance.post = AsyncMock(return_value=mock_response)
+            mock_client_class.return_value.__aenter__.return_value = mock_client_instance
+
+            from llama_stack_api.inference import RerankRequest
+
+            request = RerankRequest(model="rerank-model", query="test", items=["doc1"])
+            await adapter.rerank(request)
+
+            call_args = mock_client_instance.post.call_args
+            headers = call_args.kwargs.get("headers", {})
+            assert headers.get("Authorization") == "Bearer my-secret-token"
+
+    async def test_rerank_no_auth_header_without_token(self):
+        """rerank() should not include Authorization header when no api_token is set."""
+        config = VLLMInferenceAdapterConfig(base_url="https://vllm.example.com/v1")
+        adapter = VLLMInferenceAdapter(config=config)
+        await adapter.initialize()
+
+        with patch("httpx.AsyncClient") as mock_client_class:
+            mock_response = MagicMock()
+            mock_response.status_code = 200
+            mock_response.json.return_value = {
+                "results": [{"index": 0, "relevance_score": 0.9}],
+            }
+            mock_client_instance = MagicMock()
+            mock_client_instance.post = AsyncMock(return_value=mock_response)
+            mock_client_class.return_value.__aenter__.return_value = mock_client_instance
+
+            from llama_stack_api.inference import RerankRequest
+
+            request = RerankRequest(model="rerank-model", query="test", items=["doc1"])
+            await adapter.rerank(request)
+
+            call_args = mock_client_instance.post.call_args
+            headers = call_args.kwargs.get("headers", {})
+            assert "Authorization" not in headers
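
Note on `_build_network_client_kwargs`: the helper is imported from `llama_stack.providers.utils.inference.http_client` above, but its implementation is not part of this diff. Going only by the behaviour the adapter and tests rely on (an empty dict when no `network` config is present, so the caller falls back to `self.shared_ssl_context`, and a `verify` kwarg derived from `network.tls.verify` otherwise), a minimal sketch of that contract might look like the following; the real helper's field names and edge-case handling may differ.

```python
# Hypothetical sketch only -- the actual implementation lives in
# llama_stack/providers/utils/inference/http_client.py and is not shown here.
def _build_network_client_kwargs(network) -> dict:
    """Translate optional network/TLS config into httpx.AsyncClient kwargs."""
    kwargs: dict = {}
    if network is None or getattr(network, "tls", None) is None:
        # No network config: return an empty dict so callers can apply their
        # own default (the vLLM adapter falls back to its shared SSL context).
        return kwargs
    verify = network.tls.verify
    if verify is False:
        kwargs["verify"] = False  # disable server certificate verification
    elif isinstance(verify, str):
        kwargs["verify"] = verify  # CA bundle path; httpx accepts a str here
    return kwargs
```

Under that assumption the tests above line up: `network={"tls": {"verify": False}}` makes both `health()` and `rerank()` construct `httpx.AsyncClient(verify=False)`, an adapter configured without `network` keeps the previous shared-SSL-context behaviour, and legacy `tls_verify=False` is expected to migrate into `network.tls.verify`.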