Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 40 additions & 39 deletions src/llama_stack/providers/remote/inference/vllm/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import ssl
from collections.abc import AsyncIterator
from urllib.parse import urljoin

import aiohttp
import httpx
from pydantic import ConfigDict

from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.http_client import _build_network_client_kwargs
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
HealthResponse,
Expand Down Expand Up @@ -60,14 +58,12 @@ async def initialize(self) -> None:
"You must provide a URL in config.yaml (or via the VLLM_URL environment variable) to use vLLM."
)

# Shared SSL context for all calls to improve performance
if self.config.tls_verify is False:
self.shared_ssl_context = False
elif isinstance(self.config.tls_verify, str):
if os.path.isdir(self.config.tls_verify):
self.shared_ssl_context = ssl.create_default_context(capath=self.config.tls_verify)
else:
self.shared_ssl_context = ssl.create_default_context(cafile=self.config.tls_verify)
def _build_httpx_client_kwargs(self) -> dict:
    """Assemble keyword arguments for ``httpx.AsyncClient`` honouring network/TLS config.

    Returns the kwargs produced from ``self.config.network`` when any are
    present; otherwise falls back to the legacy shared SSL context.
    """
    client_kwargs = _build_network_client_kwargs(self.config.network)
    if client_kwargs:
        return client_kwargs
    # NOTE(review): the legacy shared SSL context is used only when the network
    # config yields *no* kwargs at all — presumably network config, when set,
    # always covers TLS; confirm against _build_network_client_kwargs.
    return {"verify": self.shared_ssl_context}

async def health(self) -> HealthResponse:
"""
Expand All @@ -83,7 +79,7 @@ async def health(self) -> HealthResponse:
base_url = self.get_base_url()
health_url = urljoin(base_url, "health")

async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(**self._build_httpx_client_kwargs()) as client:
response = await client.get(health_url)
response.raise_for_status()
return HealthResponse(status=HealthStatus.OK)
Expand Down Expand Up @@ -159,40 +155,45 @@ def format_item(
if request.max_num_results is not None:
payload["top_n"] = request.max_num_results

# vLLM does not support /v1/rerank ->
# "To indicate that the rerank API is not part of the standard OpenAI API,
# we have located it at `/rerank`. Please update your client accordingly.
# (Note: Conforms to JinaAI rerank API)" - vLLM 0.15.1
endpoint = self.get_base_url().replace("/v1", "") + "/rerank" # TODO: find a better solution

headers: dict[str, str] = {}
api_key = self.get_api_key()
if api_key and api_key != "NO KEY REQUIRED":
headers["Authorization"] = f"Bearer {api_key}"

try:
async with aiohttp.ClientSession() as session:
# vLLM does not support /v1/rerank ->
# "To indicate that the rerank API is not part of the standard OpenAI API,
# we have located it at `/rerank`. Please update your client accordingly.
# (Note: Conforms to JinaAI rerank API)" - vLLM 0.15.1
endpoint = self.get_base_url().replace("/v1", "") + "/rerank" # TODO: find a better solution
async with session.post(endpoint, headers={}, json=payload) as response:
if response.status != 200:
response_text = await response.text()
async with httpx.AsyncClient(**self._build_httpx_client_kwargs()) as client:
response = await client.post(endpoint, headers=headers, json=payload)
if response.status_code != 200:
raise RuntimeError(
f"vLLM rerank API request failed with status {response.status_code}: {response.text}"
)

def convert_result_item(item: dict) -> RerankData:
if "index" not in item or "relevance_score" not in item:
raise RuntimeError(
f"vLLM rerank API request failed with status {response.status}: {response_text}"
"vLLM rerank API response missing required fields 'index' or 'relevance_score'"
)

def convert_result_item(item: dict) -> RerankData:
if "index" not in item or "relevance_score" not in item:
raise RuntimeError(
"vLLM rerank API response missing required fields 'index' or 'relevance_score'"
)

try:
return RerankData(index=int(item["index"]), relevance_score=float(item["relevance_score"]))
except (TypeError, ValueError) as e:
raise RuntimeError(f"Invalid data types in vLLM rerank API response: {e}") from e
try:
return RerankData(index=int(item["index"]), relevance_score=float(item["relevance_score"]))
except (TypeError, ValueError) as e:
raise RuntimeError(f"Invalid data types in vLLM rerank API response: {e}") from e

result = await response.json()
result = response.json()

if "results" not in result:
raise RuntimeError("vLLM rerank API response missing 'results' field")
if "results" not in result:
raise RuntimeError("vLLM rerank API response missing 'results' field")

rerank_data = [convert_result_item(item) for item in result.get("results")]
rerank_data.sort(key=lambda entry: entry.relevance_score, reverse=True)
rerank_data = [convert_result_item(item) for item in result.get("results")]
rerank_data.sort(key=lambda entry: entry.relevance_score, reverse=True)

return RerankResponse(data=rerank_data)
return RerankResponse(data=rerank_data)

except aiohttp.ClientError as e:
except httpx.HTTPError as e:
raise ConnectionError(f"Failed to connect to vLLM rerank API at {endpoint}: {e}") from e
146 changes: 146 additions & 0 deletions tests/unit/providers/inference/test_remote_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# the root directory of this source tree.

import asyncio
import ssl
import time
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch

Expand Down Expand Up @@ -311,3 +312,148 @@ async def test_vllm_chat_completion_extra_body():
assert "extra_body" in call_kwargs
assert "chat_template_kwargs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}


class TestHealthTLSConfig:
    """Verify that health() respects the adapter's TLS/network configuration."""

    async def test_health_uses_shared_ssl_context_by_default(self):
        """With no network config, `verify` falls back to the shared SSL context."""
        cfg = VLLMInferenceAdapterConfig(base_url="https://vllm.example.com/v1")
        vllm_adapter = VLLMInferenceAdapter(config=cfg)
        await vllm_adapter.initialize()

        client_kwargs = vllm_adapter._build_httpx_client_kwargs()
        assert "verify" in client_kwargs
        assert isinstance(client_kwargs["verify"], ssl.SSLContext)

    async def test_health_uses_network_tls_verify_false(self):
        """network.tls.verify=False must disable TLS verification."""
        cfg = VLLMInferenceAdapterConfig(
            base_url="https://vllm.example.com/v1",
            network={"tls": {"verify": False}},
        )
        vllm_adapter = VLLMInferenceAdapter(config=cfg)
        await vllm_adapter.initialize()

        client_kwargs = vllm_adapter._build_httpx_client_kwargs()
        assert client_kwargs["verify"] is False

    async def test_health_passes_kwargs_to_httpx(self):
        """The kwargs built by _build_httpx_client_kwargs() must reach httpx.AsyncClient."""
        cfg = VLLMInferenceAdapterConfig(
            base_url="https://vllm.example.com/v1",
            network={"tls": {"verify": False}},
        )
        vllm_adapter = VLLMInferenceAdapter(config=cfg)
        await vllm_adapter.initialize()

        with patch("httpx.AsyncClient") as client_cls:
            ok_response = MagicMock()
            ok_response.raise_for_status.return_value = None
            fake_client = MagicMock()
            fake_client.get = AsyncMock(return_value=ok_response)
            client_cls.return_value.__aenter__.return_value = fake_client

            await vllm_adapter.health()

            client_cls.assert_called_once()
            assert client_cls.call_args.kwargs.get("verify") is False

    async def test_legacy_tls_verify_migrates_to_network(self):
        """Legacy tls_verify=False should be migrated to network.tls.verify."""
        cfg = VLLMInferenceAdapterConfig(
            base_url="https://vllm.example.com/v1",
            tls_verify=False,
        )
        vllm_adapter = VLLMInferenceAdapter(config=cfg)
        await vllm_adapter.initialize()

        client_kwargs = vllm_adapter._build_httpx_client_kwargs()
        assert client_kwargs["verify"] is False


class TestRerankTLSAndAuth:
    """Tests that rerank() honours TLS and auth configuration.

    The mock httpx client wiring and the minimal rerank call were duplicated
    verbatim across all three tests; they are factored into the two private
    helpers below so each test states only its distinctive setup and assertion.
    """

    @staticmethod
    def _configure_mock_client(mock_client_class):
        """Wire *mock_client_class* so ``async with httpx.AsyncClient(...)`` yields
        a client whose POST returns a minimal successful rerank payload.

        Returns the mock client instance so callers can inspect ``post`` calls.
        """
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "results": [{"index": 0, "relevance_score": 0.9}],
        }
        mock_client_instance = MagicMock()
        mock_client_instance.post = AsyncMock(return_value=mock_response)
        mock_client_class.return_value.__aenter__.return_value = mock_client_instance
        return mock_client_instance

    @staticmethod
    async def _call_rerank(adapter):
        """Issue a minimal rerank request against *adapter*."""
        # Imported locally (as in the original tests) to keep the module
        # importable even if the rerank API surface moves.
        from llama_stack_api.inference import RerankRequest

        request = RerankRequest(model="rerank-model", query="test", items=["doc1"])
        await adapter.rerank(request)

    async def test_rerank_passes_tls_config(self):
        """rerank() should use _build_httpx_client_kwargs() for TLS."""
        config = VLLMInferenceAdapterConfig(
            base_url="https://vllm.example.com/v1",
            network={"tls": {"verify": False}},
        )
        adapter = VLLMInferenceAdapter(config=config)
        await adapter.initialize()

        with patch("httpx.AsyncClient") as mock_client_class:
            self._configure_mock_client(mock_client_class)
            await self._call_rerank(adapter)

            mock_client_class.assert_called_once()
            call_kwargs = mock_client_class.call_args.kwargs
            assert call_kwargs.get("verify") is False

    async def test_rerank_sends_auth_header(self):
        """rerank() should include Authorization header when api_token is configured."""
        config = VLLMInferenceAdapterConfig(
            base_url="https://vllm.example.com/v1",
            api_token="my-secret-token",
        )
        adapter = VLLMInferenceAdapter(config=config)
        await adapter.initialize()

        with patch("httpx.AsyncClient") as mock_client_class:
            mock_client_instance = self._configure_mock_client(mock_client_class)
            await self._call_rerank(adapter)

            headers = mock_client_instance.post.call_args.kwargs.get("headers", {})
            assert headers.get("Authorization") == "Bearer my-secret-token"

    async def test_rerank_no_auth_header_without_token(self):
        """rerank() should not include Authorization header when no api_token is set."""
        config = VLLMInferenceAdapterConfig(base_url="https://vllm.example.com/v1")
        adapter = VLLMInferenceAdapter(config=config)
        await adapter.initialize()

        with patch("httpx.AsyncClient") as mock_client_class:
            mock_client_instance = self._configure_mock_client(mock_client_class)
            await self._call_rerank(adapter)

            headers = mock_client_instance.post.call_args.kwargs.get("headers", {})
            assert "Authorization" not in headers
Loading