Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ jobs:
WATSONX_API_KEY: replay-mode-dummy-key
WATSONX_BASE_URL: https://us-south.ml.cloud.ibm.com
WATSONX_PROJECT_ID: replay-mode-dummy-project
VERTEX_AI_PROJECT: ${{ matrix.config.setup == 'vertexai' && 'replay-mode-dummy-project' || '' }}
VERTEX_AI_LOCATION: ${{ matrix.config.setup == 'vertexai' && 'global' || '' }}
AWS_BEARER_TOKEN_BEDROCK: replay-mode-dummy-key
AWS_DEFAULT_REGION: us-west-2
TAVILY_SEARCH_API_KEY: ${{ secrets.TAVILY_SEARCH_API_KEY || 'replay-mode-dummy-key' }}
Expand Down
2 changes: 2 additions & 0 deletions scripts/integration-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,8 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
[ -n "${SAFETY_MODEL:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e SAFETY_MODEL=$SAFETY_MODEL"
[ -n "${AWS_BEARER_TOKEN_BEDROCK:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e AWS_BEARER_TOKEN_BEDROCK=$AWS_BEARER_TOKEN_BEDROCK"
[ -n "${AWS_DEFAULT_REGION:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION"
[ -n "${VERTEX_AI_PROJECT:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e VERTEX_AI_PROJECT=$VERTEX_AI_PROJECT"
[ -n "${VERTEX_AI_LOCATION:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e VERTEX_AI_LOCATION=$VERTEX_AI_LOCATION"

if [[ "$TEST_SETUP" == "vllm" ]]; then
DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e VLLM_URL=http://localhost:8000/v1"
Expand Down
12 changes: 11 additions & 1 deletion src/llama_stack/distributions/ci-tests/ci_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def get_distribution_template() -> DistributionTemplate:
model_type=ModelType.llm,
)

# Vertex AI model must be pre-registered because the recording system cannot
# replay model-list discovery calls against the Vertex AI endpoint in CI.
vertexai_model = ModelInput(
model_id="vertexai/publishers/google/models/gemini-2.0-flash",
provider_id="${env.VERTEX_AI_PROJECT:+vertexai}",
provider_model_id="publishers/google/models/gemini-2.0-flash",
model_type=ModelType.llm,
)

# Add conditional authentication config (disabled by default for CI tests)
# This tests the conditional auth provider feature and provides a template for users
# To enable: export AUTH_PROVIDER=enabled and configure the auth env vars
Expand Down Expand Up @@ -101,8 +110,9 @@ def get_distribution_template() -> DistributionTemplate:
run_config.default_models.append(azure_model)
run_config.default_models.append(watsonx_model)
run_config.default_models.append(bedrock_model)
run_config.default_models.append(vertexai_model)

# Add WatsonX inference provider
# Add WatsonX inference provider (vertexai is already in starter distribution)
run_config.provider_overrides["inference"].append(watsonx_provider)

# Replace sentence-transformers provider with one that has trust_remote_code=True
Expand Down
5 changes: 5 additions & 0 deletions src/llama_stack/distributions/ci-tests/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ registered_resources:
provider_id: ${env.AWS_BEARER_TOKEN_BEDROCK:+bedrock}
provider_model_id: openai.gpt-oss-20b
model_type: llm
- metadata: {}
model_id: vertexai/publishers/google/models/gemini-2.0-flash
provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_model_id: publishers/google/models/gemini-2.0-flash
model_type: llm
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ registered_resources:
provider_id: ${env.AWS_BEARER_TOKEN_BEDROCK:+bedrock}
provider_model_id: openai.gpt-oss-20b
model_type: llm
- metadata: {}
model_id: vertexai/publishers/google/models/gemini-2.0-flash
provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
provider_model_id: publishers/google/models/gemini-2.0-flash
model_type: llm
shields:
- shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard}
Expand Down
221 changes: 216 additions & 5 deletions src/llama_stack/testing/api_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# Test context uses ContextVar since it changes per-test and needs async isolation
from openai.types.completion_choice import CompletionChoice

from llama_stack.core.testing_context import get_test_context, is_debug_mode
from llama_stack.core.testing_context import get_test_context, is_debug_mode, set_test_context

# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
CompletionChoice.model_fields["finish_reason"].annotation = cast(
Expand Down Expand Up @@ -428,7 +428,7 @@ def _get_test_dir(self) -> Path:
For test at "tests/integration/inference/test_foo.py::test_bar",
returns "tests/integration/inference/recordings/".
"""
test_id = get_test_context()
test_id = _get_test_context_with_fallback()
if test_id:
# Extract the directory path from the test nodeid
# e.g., "tests/integration/inference/test_basic.py::test_foo[params]"
Expand Down Expand Up @@ -503,20 +503,21 @@ def store_recording(self, request_hash: str, request: dict[str, Any], response:
logger.info("[RECORDING DEBUG] Storing recording:")
logger.info(f" Request hash: {request_hash}")
logger.info(f" File: {response_path}")
logger.info(f" Test ID: {get_test_context()}")
logger.info(f" Test ID: {_get_test_context_with_fallback()}")
logger.info(f" Endpoint: {endpoint}")

# Save response to JSON file with metadata
with open(response_path, "w") as f:
json.dump(
{
"test_id": get_test_context(),
"test_id": _get_test_context_with_fallback(),
"request": request,
"response": serialized_response,
"id_normalization_mapping": {},
},
f,
indent=2,
default=str,
)
f.write("\n")
f.flush()
Expand Down Expand Up @@ -1050,8 +1051,173 @@ async def replay_recorded_stream():
raise AssertionError(f"Invalid mode: {mode}")


# Module-level fallback cache: remembers the most recent test ID returned by a
# successful context lookup, for use when ContextVars do not propagate across
# async boundaries (see _get_test_context_with_fallback).
_last_test_id: str | None = None


def _get_test_context_with_fallback() -> str | None:
    """Get test context, falling back to provider data header or last known test ID.

    In server mode, ContextVars may not propagate through all async boundaries
    (e.g., google-genai SDK may create internal async tasks). We fall back to:
    1. The provider data header (set by middleware)
    2. The last known test ID (set when any successful context lookup happens)
    """
    global _last_test_id

    # Fast path: the ContextVar is populated for this task.
    current = get_test_context()
    if current:
        _last_test_id = current
        return current

    # Fallback 1: the per-request provider data header set by middleware.
    try:
        from llama_stack.core.request_headers import PROVIDER_DATA_VAR

        provider_data = PROVIDER_DATA_VAR.get()
    except (LookupError, ImportError):
        provider_data = None

    if provider_data and "__test_id" in provider_data:
        _last_test_id = provider_data["__test_id"]

    # Fallback 2 (or result of fallback 1): the last known test ID, if any.
    return _last_test_id


async def _patched_genai_method(original_method, self, endpoint, *args, **kwargs):
    """Patched version of google-genai async methods for recording/replay.

    Wraps an ``AsyncModels`` coroutine method so that, depending on the global
    recording mode, the call is either passed through live, replayed from a
    stored recording, or executed and recorded.

    Args:
        original_method: The unpatched coroutine method to delegate to.
        self: The ``AsyncModels`` instance the method is bound to.
        endpoint: Logical endpoint label used to build the recording key
            (e.g. "/generate_content", "/generate_content_stream", "/embed_content").
        *args: Forwarded verbatim to ``original_method``.
        **kwargs: Forwarded verbatim; also serialized to compute the request hash.

    Returns:
        A response object (or async iterator of streaming chunks) matching what
        the original method would return.

    Raises:
        RuntimeError: In strict REPLAY mode when no recording exists.
        Exception: A deserialized recorded exception, when the recording
            captured a failure.
    """
    global _current_mode, _current_storage

    mode = _current_mode
    storage = _current_storage

    if is_debug_mode():
        logger.info("[RECORDING DEBUG] Entering genai method:")
        logger.info(f" Mode: {mode}")
        logger.info(f" Endpoint: {endpoint}")
        logger.info(f" Test context: {_get_test_context_with_fallback()}")

    # Live mode (or recorder not initialized): bypass entirely.
    if mode == APIRecordingMode.LIVE or storage is None:
        return await original_method(self, *args, **kwargs)

    # Ensure test context is set for recording path resolution
    test_id = _get_test_context_with_fallback()
    if test_id and not get_test_context():
        set_test_context(test_id)

    from google.genai import types as genai_types

    # Serialize request parameters into JSON-compatible values so the request
    # hash is deterministic across runs.
    model = kwargs.get("model", "")
    body = {}
    for k, v in kwargs.items():
        if hasattr(v, "model_dump"):
            # Pydantic model (google-genai request types)
            body[k] = v.model_dump(mode="json", exclude_none=True)
        elif isinstance(v, list):
            serialized = []
            for item in v:
                if hasattr(item, "model_dump"):
                    serialized.append(item.model_dump(mode="json", exclude_none=True))
                else:
                    # Stringify anything that is not already a plain JSON scalar/dict.
                    serialized.append(str(item) if not isinstance(item, dict | str | int | float | bool) else item)
            body[k] = serialized
        elif isinstance(v, dict | str | int | float | bool | type(None)):
            body[k] = v
        else:
            # Last resort for opaque SDK objects: a stable string repr.
            body[k] = str(v)

    # Pseudo-URL namespaces genai recordings apart from OpenAI/Ollama ones.
    url = f"vertexai://{endpoint}"
    method = "POST"
    request_hash = normalize_inference_request(method, url, {}, body)

    # Try replay
    if mode in (APIRecordingMode.REPLAY, APIRecordingMode.RECORD_IF_MISSING):
        recording = storage.find_recording(request_hash)
        if recording:
            response_data = recording["response"]
            if response_data.get("is_exception", False):
                # Re-raise the recorded failure so tests observe the same error.
                exc_data = response_data.get("exception_data")
                if exc_data:
                    raise deserialize_exception(exc_data)
                raise Exception(response_data.get("exception_message", "Unknown error"))

            response_body = response_data["body"]
            if response_data.get("is_streaming", False):
                response_type = genai_types.GenerateContentResponse

                # Replay the recorded chunk list as an async stream.
                async def replay_genai_stream():
                    for chunk_data in response_body:
                        yield response_type.model_validate(chunk_data)

                return replay_genai_stream()
            else:
                # Non-streaming: rehydrate the right response type per endpoint.
                if endpoint == "/embed_content":
                    return genai_types.EmbedContentResponse.model_validate(response_body)
                return genai_types.GenerateContentResponse.model_validate(response_body)
        elif mode == APIRecordingMode.REPLAY:
            # Strict replay with no recording on disk is a hard failure;
            # RECORD_IF_MISSING falls through to the recording path below.
            raise RuntimeError(
                f"Recording not found for genai request hash: {request_hash}\n"
                f"Model: {model} | Endpoint: {endpoint}\n"
                f"\n"
                f"Run './scripts/integration-tests.sh --inference-mode record-if-missing' with required API keys."
            )

    # Record
    if mode in (APIRecordingMode.RECORD, APIRecordingMode.RECORD_IF_MISSING):
        request_data = {
            "method": method,
            "url": url,
            "headers": {},
            "body": body,
            "endpoint": endpoint,
            "model": model,
            "provider_metadata": {"provider": "vertexai"},
        }

        is_streaming = endpoint == "/generate_content_stream"

        try:
            response = await original_method(self, *args, **kwargs)
        except Exception as exc:
            # Record the failure too, so replay reproduces it faithfully.
            response_data = {
                "body": None,
                "is_streaming": is_streaming,
                "is_exception": True,
                "exception_data": serialize_exception(exc),
                "exception_message": str(exc),
            }
            storage.store_recording(request_hash, request_data, response_data)
            raise

        if is_streaming:
            original_response = response

            # Tee the live stream: yield chunks to the caller while buffering
            # them, then persist the full chunk list once the stream closes.
            async def recording_genai_stream():
                chunks = []
                try:
                    async for chunk in original_response:
                        chunks.append(chunk)
                        yield chunk
                finally:
                    if chunks:
                        rec_data = {
                            "body": [c.model_dump(mode="json", exclude_none=True) for c in chunks],
                            "is_streaming": True,
                        }
                        storage.store_recording(request_hash, request_data, rec_data)

            return recording_genai_stream()
        else:
            response_data = {
                "body": response.model_dump(mode="json", exclude_none=True),
                "is_streaming": False,
            }
            storage.store_recording(request_hash, request_data, response_data)
            return response

    raise AssertionError(f"Invalid mode: {mode}")


def patch_inference_clients():
"""Install monkey patches for OpenAI client methods, Ollama AsyncClient methods, tool runtime methods, and aiohttp for rerank."""
"""Install monkey patches for OpenAI client methods, Ollama AsyncClient methods, google-genai methods, tool runtime methods, and aiohttp for rerank."""
global _original_methods

import aiohttp
Expand Down Expand Up @@ -1081,6 +1247,16 @@ def patch_inference_clients():
"aiohttp_post": aiohttp.ClientSession.post,
}

# Google genai patching (optional - only if google-genai is installed)
try:
from google.genai import models as genai_models

_original_methods["genai_generate_content"] = genai_models.AsyncModels.generate_content
_original_methods["genai_generate_content_stream"] = genai_models.AsyncModels.generate_content_stream
_original_methods["genai_embed_content"] = genai_models.AsyncModels.embed_content
except ImportError:
pass

# Create patched methods for OpenAI client
async def patched_chat_completions_create(self, *args, **kwargs):
return await _patched_inference_method(
Expand Down Expand Up @@ -1175,6 +1351,33 @@ def patched_aiohttp_session_post(self, url, **kwargs):
# Apply aiohttp patch
aiohttp.ClientSession.post = patched_aiohttp_session_post

# Apply google-genai patches (if available)
if "genai_generate_content" in _original_methods:
from google.genai import models as genai_models

async def patched_genai_generate_content(self, *args, **kwargs):
return await _patched_genai_method(
_original_methods["genai_generate_content"], self, "/generate_content", *args, **kwargs
)

async def patched_genai_generate_content_stream(self, *args, **kwargs):
return await _patched_genai_method(
_original_methods["genai_generate_content_stream"],
self,
"/generate_content_stream",
*args,
**kwargs,
)

async def patched_genai_embed_content(self, *args, **kwargs):
return await _patched_genai_method(
_original_methods["genai_embed_content"], self, "/embed_content", *args, **kwargs
)

genai_models.AsyncModels.generate_content = patched_genai_generate_content
genai_models.AsyncModels.generate_content_stream = patched_genai_generate_content_stream
genai_models.AsyncModels.embed_content = patched_genai_embed_content


def unpatch_inference_clients():
"""Remove monkey patches and restore original OpenAI, Ollama client, tool runtime, and aiohttp methods."""
Expand Down Expand Up @@ -1215,6 +1418,14 @@ def unpatch_inference_clients():
# Restore aiohttp method
aiohttp.ClientSession.post = _original_methods["aiohttp_post"]

# Restore google-genai methods (if they were patched)
if "genai_generate_content" in _original_methods:
from google.genai import models as genai_models

genai_models.AsyncModels.generate_content = _original_methods["genai_generate_content"]
genai_models.AsyncModels.generate_content_stream = _original_methods["genai_generate_content_stream"]
genai_models.AsyncModels.embed_content = _original_methods["genai_embed_content"]

_original_methods.clear()


Expand Down
1 change: 1 addition & 0 deletions tests/integration/ci_matrix.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
{"suite": "responses", "setup": "azure"},
{"suite": "gpt-reasoning", "setup": "gpt-reasoning"},
{"suite": "responses", "setup": "watsonx"},
{"suite": "responses", "setup": "vertexai"},
{"suite": "bedrock-responses", "setup": "bedrock"},
{"suite": "base-vllm-subset", "setup": "vllm"},
{"suite": "vllm-reasoning", "setup": "vllm"},
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/responses/test_basic_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def skip_if_provider_isnt_vllm(client_with_models, text_model_id):

def skip_if_chat_completions_logprobs_not_supported(client_with_models, text_model_id):
    """Skip the current test when the model's provider cannot serve chat-completions logprobs."""
    no_logprobs_providers = {"remote::ollama", "remote::watsonx", "remote::vertexai"}
    provider_type = provider_from_model(client_with_models, text_model_id).provider_type
    if provider_type in no_logprobs_providers:
        pytest.skip(f"Model {text_model_id} hosted by {provider_type} doesn't support /v1/chat/completions logprobs.")


Expand Down Expand Up @@ -207,6 +207,8 @@ def test_response_non_streaming_multi_turn(responses_client, text_model_id, case

@pytest.mark.parametrize("case", image_test_cases)
def test_response_non_streaming_image(responses_client, vision_model_id, case):
if vision_model_id and vision_model_id.startswith("vertexai/"):
pytest.skip("Vertex AI image handling differs from OpenAI format")
response = responses_client.responses.create(
model=vision_model_id,
input=case.input,
Expand All @@ -217,6 +219,8 @@ def test_response_non_streaming_image(responses_client, vision_model_id, case):

@pytest.mark.parametrize("case", multi_turn_image_test_cases)
def test_response_non_streaming_multi_turn_image(responses_client, vision_model_id, case):
if vision_model_id and vision_model_id.startswith("vertexai/"):
pytest.skip("Vertex AI image handling differs from OpenAI format")
previous_response_id = None
for turn_input, turn_expected in case.turns:
response = responses_client.responses.create(
Expand Down
Loading
Loading