From 391b60fcb2c0399015fe17471f6e93ae07c6a901 Mon Sep 17 00:00:00 2001 From: Guangya Liu Date: Mon, 19 Jan 2026 15:39:50 -0500 Subject: [PATCH 1/6] fix: Auto-expand provider dependencies for --providers in stack CLI --- src/llama_stack/cli/stack/_list_deps.py | 44 +++++++++++++ src/llama_stack/cli/stack/run.py | 53 +++++++++++++++ src/llama_stack/cli/stack/utils.py | 64 ++++++++++++++++++ tests/unit/cli/test_stack_utils.py | 65 +++++++++++++++++++ .../distribution/test_list_deps_output.py | 16 +++++ 5 files changed, 242 insertions(+) create mode 100644 tests/unit/cli/test_stack_utils.py diff --git a/src/llama_stack/cli/stack/_list_deps.py b/src/llama_stack/cli/stack/_list_deps.py index b7116f3af2..234659af54 100644 --- a/src/llama_stack/cli/stack/_list_deps.py +++ b/src/llama_stack/cli/stack/_list_deps.py @@ -16,6 +16,8 @@ from llama_stack.core.stack import run_config_from_dynamic_config_spec from llama_stack.log import get_logger +from .utils import add_dependent_providers + TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" logger = get_logger(name=__name__, category="cli") @@ -103,6 +105,48 @@ def run_stack_list_deps_command(args: argparse.Namespace) -> None: except ValueError as e: cprint(str(e), color="red", file=sys.stderr) sys.exit(1) + provider_list: dict[str, list[Provider]] = dict() + provider_registry = get_provider_registry() + for api_provider in args.providers.split(","): + if "=" not in api_provider: + cprint( + "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2", + color="red", + file=sys.stderr, + ) + sys.exit(1) + api, provider_type = api_provider.split("=") + providers_for_api = provider_registry.get(Api(api), None) + if providers_for_api is None: + cprint( + f"{api} is not a valid API.", + color="red", + file=sys.stderr, + ) + sys.exit(1) + if provider_type in providers_for_api: + provider = Provider( + provider_type=provider_type, + provider_id=provider_type.split("::")[1], + module=None, + ) + provider_list.setdefault(api, []).append(provider) + else: + cprint( + f"{provider_type} is not a valid provider for the {api} API.", + color="red", + file=sys.stderr, + ) + sys.exit(1) + + add_dependent_providers( + provider_list=provider_list, + provider_registry=provider_registry, + requested_provider_types=list( + {provider.provider_type for providers in provider_list.values() for provider in providers} + ), + ) + config = StackConfig(providers=provider_list, distro_name="providers-run") normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(config) normal_deps += SERVER_DEPENDENCIES diff --git a/src/llama_stack/cli/stack/run.py b/src/llama_stack/cli/stack/run.py index c5a7c0a38c..22e408677b 100644 --- a/src/llama_stack/cli/stack/run.py +++ b/src/llama_stack/cli/stack/run.py @@ -22,6 +22,8 @@ from llama_stack.core.utils.config_resolution import resolve_config_or_distro from llama_stack.log import LoggingConfig, get_logger +from .utils import add_dependent_providers + REPO_ROOT = Path(__file__).parent.parent.parent.parent logger = get_logger(name=__name__, category="cli") @@ -92,6 +94,57 @@ def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: except ValueError as e: cprint(str(e), color="red", file=sys.stderr) sys.exit(1) + provider_list: dict[str, list[Provider]] = dict() + provider_registry = get_provider_registry() + requested_provider_types = [] + for api_provider in args.providers.split(","): + if "=" not in api_provider: + cprint( + "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2", + color="red", + file=sys.stderr, + ) + sys.exit(1) + api, provider_type = api_provider.split("=") + providers_for_api = provider_registry.get(Api(api), None) + if providers_for_api is None: + cprint( + f"{api} is not a valid API.", + color="red", + file=sys.stderr, + ) + sys.exit(1) + if provider_type in providers_for_api: + config_type = instantiate_class_type(providers_for_api[provider_type].config_class) + if config_type is not None and hasattr(config_type, "sample_run_config"): + config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run") + else: + config = {} + provider = Provider( + provider_type=provider_type, + config=config, + provider_id=provider_type.split("::")[1], + ) + provider_list.setdefault(api, []).append(provider) + requested_provider_types.append(provider_type) + else: + cprint( + f"{provider} is not a valid provider for the {api} API.", + color="red", + file=sys.stderr, + ) + sys.exit(1) + + # Expand transitive API dependencies for the requested providers. + add_dependent_providers( + provider_list=provider_list, + provider_registry=provider_registry, + requested_provider_types=requested_provider_types, + distro_dir="~/.llama/distributions/providers-run", + include_configs=True, + ) + + run_config = self._generate_run_config_from_providers(providers=provider_list) config_dict = run_config.model_dump(mode="json") config_file = distro_dir / "config.yaml" diff --git a/src/llama_stack/cli/stack/utils.py b/src/llama_stack/cli/stack/utils.py index 51e92f3df1..efedd2c9eb 100644 --- a/src/llama_stack/cli/stack/utils.py +++ b/src/llama_stack/cli/stack/utils.py @@ -7,6 +7,10 @@ from enum import Enum from pathlib import Path +from llama_stack.core.datatypes import Provider, ProviderSpec +from llama_stack.core.utils.dynamic import instantiate_class_type +from llama_stack_api import Api + TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions" @@ -22,3 +26,63 @@ def print_subcommand_description(parser, subparsers): description = subcommand.description description_text += f" {name:<21} {description}\n" parser.epilog = description_text + + +def add_dependent_providers( + provider_list: dict[str, list[Provider]], + provider_registry: dict[Api, dict[str, ProviderSpec]], + requested_provider_types: list[str], + *, + distro_dir: str | None = None, + include_configs: bool = False, +) -> None: + def add_provider_for_api(api: Api) -> None: + api_key = api.value + if api_key in provider_list and provider_list[api_key]: + return + providers_for_api = provider_registry.get(api) + if not providers_for_api: + return + provider_spec = next( + (spec for spec in providers_for_api.values() if spec.provider_type.startswith("inline::")), + None, + ) + if provider_spec is None: + provider_spec = next(iter(providers_for_api.values()), None) + if provider_spec is None: + return + + if include_configs: + if not distro_dir: + raise ValueError("distro_dir is required when include_configs=True") + config_type = instantiate_class_type(provider_spec.config_class) + if config_type is not None and hasattr(config_type, "sample_run_config"): + config = config_type.sample_run_config(__distro_dir__=distro_dir) + else: + config = {} + provider = Provider( + provider_type=provider_spec.provider_type, + config=config, + provider_id=provider_spec.provider_type.split("::")[1], + ) + else: + provider = Provider( + provider_type=provider_spec.provider_type, + provider_id=provider_spec.provider_type.split("::")[1], + module=None, + ) + provider_list.setdefault(api_key, []).append(provider) + + def expand_dependencies(provider_spec: ProviderSpec) -> None: + for api in provider_spec.api_dependencies: + add_provider_for_api(api) + for candidate in provider_list.get(api.value, []): + candidate_spec = provider_registry[api].get(candidate.provider_type) + if candidate_spec: + expand_dependencies(candidate_spec) + + for provider_type in requested_provider_types: + for _, api_providers in provider_registry.items(): + provider_spec = api_providers.get(provider_type) + if provider_spec: + expand_dependencies(provider_spec) diff --git a/tests/unit/cli/test_stack_utils.py b/tests/unit/cli/test_stack_utils.py new file mode 100644 index 0000000000..bb2c1fd9bb --- /dev/null +++ b/tests/unit/cli/test_stack_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.cli.stack.utils import add_dependent_providers +from llama_stack.core.datatypes import Provider +from llama_stack.core.distribution import get_provider_registry + + +def test_add_dependent_providers_expands_required_apis(): + provider_registry = get_provider_registry() + provider_list = { + "agents": [ + Provider( + provider_type="inline::meta-reference", + provider_id="meta-reference", + ) + ] + } + + add_dependent_providers( + provider_list=provider_list, + provider_registry=provider_registry, + requested_provider_types=["inline::meta-reference"], + ) + + # Required API dependencies for agents should be present. + assert "inference" in provider_list + assert "vector_io" in provider_list + assert "tool_runtime" in provider_list + assert "files" in provider_list + + # Providers should be added for those APIs. + assert provider_list["inference"] + assert provider_list["vector_io"] + assert provider_list["tool_runtime"] + assert provider_list["files"] + + +def test_add_dependent_providers_include_configs(): + provider_registry = get_provider_registry() + provider_list = { + "agents": [ + Provider( + provider_type="inline::meta-reference", + provider_id="meta-reference", + ) + ] + } + + add_dependent_providers( + provider_list=provider_list, + provider_registry=provider_registry, + requested_provider_types=["inline::meta-reference"], + include_configs=True, + distro_dir="~/.llama/distributions/providers-run", + ) + + inference_provider = provider_list["inference"][0] + assert inference_provider.config, "Expected sample config for inference provider" + + files_provider = provider_list["files"][0] + assert "storage_dir" in files_provider.config diff --git a/tests/unit/distribution/test_list_deps_output.py b/tests/unit/distribution/test_list_deps_output.py index de7c6fb16c..785ec64487 100644 --- a/tests/unit/distribution/test_list_deps_output.py +++ b/tests/unit/distribution/test_list_deps_output.py @@ -57,3 +57,19 @@ def test_list_deps_formatting_quotes_only_for_uv(): uv_format = format_output_deps_only(["mcp>=1.23.0"], [], [], uv=True) assert uv_format.strip() == "uv pip install 'mcp>=1.23.0'" + + +def test_stack_list_deps_expands_provider_dependencies(): + args = argparse.Namespace( + config=None, + env_name="test-env", + providers="agents=inline::meta-reference", + format="deps-only", + ) + + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + run_stack_list_deps_command(args) + output = mock_stdout.getvalue() + + # Agents provider depends on inference and others; ensure at least inference deps show up. + assert "torch" in output From ffd6954b76fa3984d29c5a151c90ea0678c5494c Mon Sep 17 00:00:00 2001 From: Guangya Liu Date: Fri, 30 Jan 2026 10:59:34 -0500 Subject: [PATCH 2/6] remove auto-expansion of dependencies in stack run --- src/llama_stack/cli/stack/run.py | 17 +-------- tests/unit/cli/test_stack_utils.py | 8 +++-- .../distribution/test_list_deps_output.py | 36 ++++++++++++++++--- 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/src/llama_stack/cli/stack/run.py b/src/llama_stack/cli/stack/run.py index 22e408677b..d5f81c3260 100644 --- a/src/llama_stack/cli/stack/run.py +++ b/src/llama_stack/cli/stack/run.py @@ -22,8 +22,6 @@ from llama_stack.core.utils.config_resolution import resolve_config_or_distro from llama_stack.log import LoggingConfig, get_logger -from .utils import add_dependent_providers - REPO_ROOT = Path(__file__).parent.parent.parent.parent logger = get_logger(name=__name__, category="cli") @@ -95,8 +93,6 @@ def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: cprint(str(e), color="red", file=sys.stderr) sys.exit(1) provider_list: dict[str, list[Provider]] = dict() - provider_registry = get_provider_registry() - requested_provider_types = [] for api_provider in args.providers.split(","): if "=" not in api_provider: cprint( @@ -106,7 +102,7 @@ def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: ) sys.exit(1) api, provider_type = api_provider.split("=") - providers_for_api = provider_registry.get(Api(api), None) + providers_for_api = get_provider_registry().get(Api(api), None) if providers_for_api is None: cprint( f"{api} is not a valid API.", @@ -126,7 +122,6 @@ def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: provider_id=provider_type.split("::")[1], ) provider_list.setdefault(api, []).append(provider) - requested_provider_types.append(provider_type) else: cprint( f"{provider} is not a valid provider for the {api} API.", @@ -134,16 +129,6 @@ def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: file=sys.stderr, ) sys.exit(1) - - # Expand transitive API dependencies for the requested providers. - add_dependent_providers( - provider_list=provider_list, - provider_registry=provider_registry, - requested_provider_types=requested_provider_types, - distro_dir="~/.llama/distributions/providers-run", - include_configs=True, - ) - run_config = self._generate_run_config_from_providers(providers=provider_list) config_dict = run_config.model_dump(mode="json") diff --git a/tests/unit/cli/test_stack_utils.py b/tests/unit/cli/test_stack_utils.py index bb2c1fd9bb..d8aa0debd0 100644 --- a/tests/unit/cli/test_stack_utils.py +++ b/tests/unit/cli/test_stack_utils.py @@ -58,8 +58,12 @@ def test_add_dependent_providers_include_configs(): distro_dir="~/.llama/distributions/providers-run", ) - inference_provider = provider_list["inference"][0] - assert inference_provider.config, "Expected sample config for inference provider" + # Some providers like sentence-transformers don't need configuration, + # so they may have empty configs. Check providers that have actual config needs. + vector_io_provider = provider_list["vector_io"][0] + assert vector_io_provider.config, "Expected sample config for vector_io provider" + assert "persistence" in vector_io_provider.config files_provider = provider_list["files"][0] + assert files_provider.config, "Expected sample config for files provider" assert "storage_dir" in files_provider.config diff --git a/tests/unit/distribution/test_list_deps_output.py b/tests/unit/distribution/test_list_deps_output.py index 785ec64487..945cb9bb44 100644 --- a/tests/unit/distribution/test_list_deps_output.py +++ b/tests/unit/distribution/test_list_deps_output.py @@ -60,7 +60,27 @@ def test_list_deps_formatting_quotes_only_for_uv(): def test_stack_list_deps_expands_provider_dependencies(): - args = argparse.Namespace( + """Test that listing deps for a provider also includes deps from its API dependencies. + + For example, agents=inline::meta-reference depends on the inference API. + When we list deps for agents, we should also get dependencies from an inference provider. + This test verifies the expansion happens by checking that dependencies unique to + inference providers appear in the agents output. + """ + # First, get dependencies for just the inference provider + inference_args = argparse.Namespace( + config=None, + env_name="test-env", + providers="inference=inline::meta-reference", + format="deps-only", + ) + + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + run_stack_list_deps_command(inference_args) + inference_output = mock_stdout.getvalue() + + # Now get dependencies for agents, which should include inference deps + agents_args = argparse.Namespace( config=None, env_name="test-env", providers="agents=inline::meta-reference", @@ -68,8 +88,14 @@ def test_stack_list_deps_expands_provider_dependencies(): ) with patch("sys.stdout", new_callable=StringIO) as mock_stdout: - run_stack_list_deps_command(args) - output = mock_stdout.getvalue() + run_stack_list_deps_command(agents_args) + agents_output = mock_stdout.getvalue() + + # Verify that inference-specific dependencies appear in agents output + # (because agents depends on inference API and dependencies were expanded) + # Pick a few packages that are specific to inference providers + inference_specific_packages = ["torch", "transformers", "accelerate"] - # Agents provider depends on inference and others; ensure at least inference deps show up. - assert "torch" in output + for package in inference_specific_packages: + assert package in inference_output, f"{package} should be in inference deps" + assert package in agents_output, f"{package} should be in agents deps (expanded from inference dependency)" From 455c342e3a222819b6acb98e095dcef7fa64121a Mon Sep 17 00:00:00 2001 From: Guangya Liu Date: Tue, 10 Feb 2026 08:21:34 -0500 Subject: [PATCH 3/6] address comments from leseb --- src/llama_stack/cli/stack/_list_deps.py | 45 +++---------------- src/llama_stack/cli/stack/run.py | 38 ---------------- tests/unit/cli/test_stack_list_deps.py | 45 +++++++++++++++++++ .../distribution/test_list_deps_output.py | 32 ++++++++----- 4 files changed, 72 insertions(+), 88 deletions(-) diff --git a/src/llama_stack/cli/stack/_list_deps.py b/src/llama_stack/cli/stack/_list_deps.py index 234659af54..f6f7ef719f 100644 --- a/src/llama_stack/cli/stack/_list_deps.py +++ b/src/llama_stack/cli/stack/_list_deps.py @@ -13,6 +13,7 @@ from llama_stack.core.build import get_provider_dependencies from llama_stack.core.datatypes import StackConfig +from llama_stack.core.distribution import get_provider_registry from llama_stack.core.stack import run_config_from_dynamic_config_spec from llama_stack.log import get_logger @@ -105,48 +106,16 @@ def run_stack_list_deps_command(args: argparse.Namespace) -> None: except ValueError as e: cprint(str(e), color="red", file=sys.stderr) sys.exit(1) - provider_list: dict[str, list[Provider]] = dict() + # Expand dependent providers (e.g. agents depends on inference, safety, etc.) provider_registry = get_provider_registry() - for api_provider in args.providers.split(","): - if "=" not in api_provider: - cprint( - "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2", - color="red", - file=sys.stderr, - ) - sys.exit(1) - api, provider_type = api_provider.split("=") - providers_for_api = provider_registry.get(Api(api), None) - if providers_for_api is None: - cprint( - f"{api} is not a valid API.", - color="red", - file=sys.stderr, - ) - sys.exit(1) - if provider_type in providers_for_api: - provider = Provider( - provider_type=provider_type, - provider_id=provider_type.split("::")[1], - module=None, - ) - provider_list.setdefault(api, []).append(provider) - else: - cprint( - f"{provider_type} is not a valid provider for the {api} API.", - color="red", - file=sys.stderr, - ) - sys.exit(1) - + requested_provider_types = list( + {provider.provider_type for providers in config.providers.values() for provider in providers} + ) add_dependent_providers( - provider_list=provider_list, + provider_list=config.providers, provider_registry=provider_registry, - requested_provider_types=list( - {provider.provider_type for providers in provider_list.values() for provider in providers} - ), + requested_provider_types=requested_provider_types, ) - config = StackConfig(providers=provider_list, distro_name="providers-run") normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(config) normal_deps += SERVER_DEPENDENCIES diff --git a/src/llama_stack/cli/stack/run.py b/src/llama_stack/cli/stack/run.py index d5f81c3260..c5a7c0a38c 100644 --- a/src/llama_stack/cli/stack/run.py +++ b/src/llama_stack/cli/stack/run.py @@ -92,44 +92,6 @@ def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: except ValueError as e: cprint(str(e), color="red", file=sys.stderr) sys.exit(1) - provider_list: dict[str, list[Provider]] = dict() - for api_provider in args.providers.split(","): - if "=" not in api_provider: - cprint( - "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2", - color="red", - file=sys.stderr, - ) - sys.exit(1) - api, provider_type = api_provider.split("=") - providers_for_api = get_provider_registry().get(Api(api), None) - if providers_for_api is None: - cprint( - f"{api} is not a valid API.", - color="red", - file=sys.stderr, - ) - sys.exit(1) - if provider_type in providers_for_api: - config_type = instantiate_class_type(providers_for_api[provider_type].config_class) - if config_type is not None and hasattr(config_type, "sample_run_config"): - config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run") - else: - config = {} - provider = Provider( - provider_type=provider_type, - config=config, - provider_id=provider_type.split("::")[1], - ) - provider_list.setdefault(api, []).append(provider) - else: - cprint( - f"{provider} is not a valid provider for the {api} API.", - color="red", - file=sys.stderr, - ) - sys.exit(1) - run_config = self._generate_run_config_from_providers(providers=provider_list) config_dict = run_config.model_dump(mode="json") config_file = distro_dir / "config.yaml" diff --git a/tests/unit/cli/test_stack_list_deps.py b/tests/unit/cli/test_stack_list_deps.py index 8ff0a839e3..87d8c38a74 100644 --- a/tests/unit/cli/test_stack_list_deps.py +++ b/tests/unit/cli/test_stack_list_deps.py @@ -49,6 +49,15 @@ def test_config_and_providers_are_independent(self, stack_list_deps: StackListDe class TestDelegation: + def _mock_provider_registry(self): + """Return a mock registry that accepts test provider types.""" + from llama_stack.core.datatypes import Api + + return { + Api.inference: {"fireworks": MagicMock()}, + Api.safety: {"llama-guard": MagicMock()}, + } + def test_providers_calls_dynamic_config_spec(self, stack_list_deps: StackListDeps): mock_config = MagicMock() mock_config.external_apis_dir = None @@ -58,6 +67,13 @@ def test_providers_calls_dynamic_config_spec(self, stack_list_deps: StackListDep "llama_stack.cli.stack._list_deps.run_config_from_dynamic_config_spec", return_value=mock_config, ) as mock_fn, + patch( + "llama_stack.cli.stack._list_deps.get_provider_registry", + return_value=self._mock_provider_registry(), + ), + patch( + "llama_stack.cli.stack._list_deps.add_dependent_providers", + ), patch( "llama_stack.cli.stack._list_deps.get_provider_dependencies", return_value=([], [], []), @@ -80,6 +96,13 @@ def test_providers_passes_semicolon_spec_unchanged(self, stack_list_deps: StackL "llama_stack.cli.stack._list_deps.run_config_from_dynamic_config_spec", return_value=mock_config, ) as mock_fn, + patch( + "llama_stack.cli.stack._list_deps.get_provider_registry", + return_value=self._mock_provider_registry(), + ), + patch( + "llama_stack.cli.stack._list_deps.add_dependent_providers", + ), patch( "llama_stack.cli.stack._list_deps.get_provider_dependencies", return_value=([], [], []), @@ -122,6 +145,14 @@ def test_value_error_message_printed_to_stderr(self, stack_list_deps: StackListD class TestOutput: + def _mock_provider_registry(self): + """Return a mock registry that accepts test provider types.""" + from llama_stack.core.datatypes import Api + + return { + Api.inference: {"fireworks": MagicMock()}, + } + def test_normal_deps_printed(self, stack_list_deps: StackListDeps, capsys): mock_config = MagicMock() mock_config.external_apis_dir = None @@ -131,6 +162,13 @@ def test_normal_deps_printed(self, stack_list_deps: StackListDeps, capsys): "llama_stack.cli.stack._list_deps.run_config_from_dynamic_config_spec", return_value=mock_config, ), + patch( + "llama_stack.cli.stack._list_deps.get_provider_registry", + return_value=self._mock_provider_registry(), + ), + patch( + "llama_stack.cli.stack._list_deps.add_dependent_providers", + ), patch( "llama_stack.cli.stack._list_deps.get_provider_dependencies", return_value=(["httpx", "aiohttp"], [], []), @@ -152,6 +190,13 @@ def test_server_dependencies_always_included(self, stack_list_deps: StackListDep "llama_stack.cli.stack._list_deps.run_config_from_dynamic_config_spec", return_value=mock_config, ), + patch( + "llama_stack.cli.stack._list_deps.get_provider_registry", + return_value=self._mock_provider_registry(), + ), + patch( + "llama_stack.cli.stack._list_deps.add_dependent_providers", + ), patch( "llama_stack.cli.stack._list_deps.get_provider_dependencies", return_value=([], [], []), diff --git a/tests/unit/distribution/test_list_deps_output.py b/tests/unit/distribution/test_list_deps_output.py index 945cb9bb44..f44c7d68ea 100644 --- a/tests/unit/distribution/test_list_deps_output.py +++ b/tests/unit/distribution/test_list_deps_output.py @@ -64,14 +64,14 @@ def test_stack_list_deps_expands_provider_dependencies(): For example, agents=inline::meta-reference depends on the inference API. When we list deps for agents, we should also get dependencies from an inference provider. - This test verifies the expansion happens by checking that dependencies unique to - inference providers appear in the agents output. + This test picks a known dependency (inference), lists its deps, then verifies those + deps appear in the agents output (proving expansion happened). """ - # First, get dependencies for just the inference provider + # First, get dependencies for the inference provider (which agents depends on) inference_args = argparse.Namespace( config=None, env_name="test-env", - providers="inference=inline::meta-reference", + providers="inference=inline::sentence-transformers", format="deps-only", ) @@ -91,11 +91,19 @@ def test_stack_list_deps_expands_provider_dependencies(): run_stack_list_deps_command(agents_args) agents_output = mock_stdout.getvalue() - # Verify that inference-specific dependencies appear in agents output - # (because agents depends on inference API and dependencies were expanded) - # Pick a few packages that are specific to inference providers - inference_specific_packages = ["torch", "transformers", "accelerate"] - - for package in inference_specific_packages: - assert package in inference_output, f"{package} should be in inference deps" - assert package in agents_output, f"{package} should be in agents deps (expanded from inference dependency)" + # Verify that dependencies were expanded: agents output should include + # inference-specific dependencies. Extract package names from the inference output + # and verify at least some appear in the agents output. + inference_lines = [line.strip() for line in inference_output.split("\n") if line.strip()] + agents_lines = [line.strip() for line in agents_output.split("\n") if line.strip()] + + # The inference provider should have some dependencies + assert len(inference_lines) > 0, "Inference provider should have dependencies" + + # At least one inference dependency should appear in agents output + # (proving that dependency expansion happened) + common_deps = set(inference_lines) & set(agents_lines) + assert len(common_deps) > 0, ( + "Agents dependencies should include at least some inference dependencies, " + "proving that dependency expansion happened" + ) From 3b5803a344d2092a8cd0ac13864378fd5a33fb94 Mon Sep 17 00:00:00 2001 From: Guangya Liu Date: Wed, 18 Mar 2026 09:15:03 -0400 Subject: [PATCH 4/6] fix unit test error --- tests/unit/cli/test_stack_utils.py | 12 ++++++------ tests/unit/distribution/test_list_deps_output.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/cli/test_stack_utils.py b/tests/unit/cli/test_stack_utils.py index d8aa0debd0..4c184f76e4 100644 --- a/tests/unit/cli/test_stack_utils.py +++ b/tests/unit/cli/test_stack_utils.py @@ -14,8 +14,8 @@ def test_add_dependent_providers_expands_required_apis(): provider_list = { "agents": [ Provider( - provider_type="inline::meta-reference", - provider_id="meta-reference", + provider_type="inline::builtin", + provider_id="builtin", ) ] } @@ -23,7 +23,7 @@ def test_add_dependent_providers_expands_required_apis(): add_dependent_providers( provider_list=provider_list, provider_registry=provider_registry, - requested_provider_types=["inline::meta-reference"], + requested_provider_types=["inline::builtin"], ) # Required API dependencies for agents should be present. @@ -44,8 +44,8 @@ def test_add_dependent_providers_include_configs(): provider_list = { "agents": [ Provider( - provider_type="inline::meta-reference", - provider_id="meta-reference", + provider_type="inline::builtin", + provider_id="builtin", ) ] } @@ -53,7 +53,7 @@ def test_add_dependent_providers_include_configs(): add_dependent_providers( provider_list=provider_list, provider_registry=provider_registry, - requested_provider_types=["inline::meta-reference"], + requested_provider_types=["inline::builtin"], include_configs=True, distro_dir="~/.llama/distributions/providers-run", ) diff --git a/tests/unit/distribution/test_list_deps_output.py b/tests/unit/distribution/test_list_deps_output.py index f44c7d68ea..d5a6979f03 100644 --- a/tests/unit/distribution/test_list_deps_output.py +++ b/tests/unit/distribution/test_list_deps_output.py @@ -83,7 +83,7 @@ def test_stack_list_deps_expands_provider_dependencies(): agents_args = argparse.Namespace( config=None, env_name="test-env", - providers="agents=inline::meta-reference", + providers="agents=inline::builtin", format="deps-only", ) From d2d44277de8ac42bdad83749761ab5c196bd25eb Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 27 Mar 2026 12:55:00 -0400 Subject: [PATCH 5/6] chore(mypy): reduce mypy errors in agents=builtin::responses --- pyproject.toml | 1 - .../builtin/responses/openai_responses.py | 6 ++- .../agents/builtin/responses/streaming.py | 54 +++++++++++-------- .../agents/builtin/responses/tool_executor.py | 2 + .../inline/agents/builtin/responses/types.py | 3 +- .../inline/agents/builtin/responses/utils.py | 6 ++- .../utils/responses/responses_store.py | 3 +- 7 files changed, 46 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 282603d070..3b699aee99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -302,7 +302,6 @@ exclude = [ "^src/llama_stack/core/store/registry\\.py$", "^src/llama_stack/core/utils/exec\\.py$", "^src/llama_stack/core/utils/prompt_for_config\\.py$", - "^src/llama_stack/providers/inline/agents/builtin/", "^src/llama_stack/providers/inline/datasetio/localfs/", "^src/llama_stack/providers/inline/eval/builtin/eval\\.py$", "^src/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", diff --git a/src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py index 5783ca6d7c..31b2e7ff9a 100644 --- a/src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/builtin/responses/openai_responses.py @@ -593,7 +593,7 @@ async def create_openai_response( presence_penalty: float | None = None, extra_body: dict | None = None, stream_options: ResponseStreamOptions | None = None, - ): + ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: stream = bool(stream) background = bool(background) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text @@ -802,7 +802,9 @@ async def _create_background_response( created_at = int(time.time()) # Normalize input to list format for storage - input_items = [OpenAIResponseMessage(content=input, role="user")] if isinstance(input, str) else input + input_items: list[OpenAIResponseInput] = ( + [OpenAIResponseMessage(content=input, role="user")] if isinstance(input, str) else input + ) # Create initial queued response queued_response = OpenAIResponseObject( diff --git a/src/llama_stack/providers/inline/agents/builtin/responses/streaming.py b/src/llama_stack/providers/inline/agents/builtin/responses/streaming.py index 54a28b9bd5..915458c544 100644 --- a/src/llama_stack/providers/inline/agents/builtin/responses/streaming.py +++ b/src/llama_stack/providers/inline/agents/builtin/responses/streaming.py @@ -176,7 +176,7 @@ def extract_openai_error(exc: Exception) -> tuple[str, str]: raw_message = body.get("message") if raw_code and isinstance(raw_code, str): - final_code: str = _RESPONSES_API_ERROR_CODES.get(raw_code, raw_code) + final_code: str = _RESPONSES_API_ERROR_CODES[raw_code] if raw_code in _RESPONSES_API_ERROR_CODES else raw_code else: final_code = "server_error" @@ -422,10 +422,11 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: async for stream_event in self._process_tools(output_messages): yield stream_event - chat_tool_choice = None + chat_tool_choice: str | dict[str, Any] | None = None # Track allowed tools for filtering (persists across iterations) allowed_tool_names: set[str] | None = None - if self.ctx.tool_choice and len(self.ctx.chat_tools) > 0: + # check truthiness of self.ctx.chat_tools to avoid len(None) + if self.ctx.tool_choice and self.ctx.chat_tools: processed_tool_choice = await _process_tool_choice( self.ctx.chat_tools, self.ctx.tool_choice, @@ -482,7 +483,7 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: ) # Filter tools to only allowed ones if tool_choice specified an allowed list effective_tools = self.ctx.chat_tools - if allowed_tool_names is not None: + if allowed_tool_names is not None and self.ctx.chat_tools is not None: effective_tools = [ tool for tool in self.ctx.chat_tools @@ -523,7 +524,7 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: parallel_tool_calls=effective_parallel_tool_calls, reasoning_effort=self.reasoning.effort if self.reasoning else None, safety_identifier=self.safety_identifier, - service_tier=self.service_tier, + service_tier=ServiceTier(self.service_tier) if self.service_tier else None, max_completion_tokens=remaining_output_tokens, prompt_cache_key=self.prompt_cache_key, top_logprobs=self.top_logprobs, @@ -1222,7 +1223,7 @@ async def _process_streaming_chunks( message_item_id=message_item_id, tool_call_item_ids=tool_call_item_ids, content_part_emitted=content_part_emitted, - logprobs=OpenAIChoiceLogprobs(content=chat_response_logprobs) if chat_response_logprobs else None, + logprobs=chat_response_logprobs if chat_response_logprobs else None, service_tier=chunk_service_tier, ) @@ -1245,7 +1246,7 @@ def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatComp message=assistant_message, finish_reason=result.finish_reason, index=0, - logprobs=result.logprobs, + logprobs=OpenAIChoiceLogprobs(content=result.logprobs) if result.logprobs else None, ) ], created=result.created, @@ -1444,7 +1445,8 @@ async def _process_mcp_tool( ) -> AsyncIterator[OpenAIResponseObjectStream]: """Process an MCP tool configuration and emit appropriate streaming events.""" # Resolve connector_id to server_url if provided - mcp_tool = await resolve_mcp_connector_id(mcp_tool, self.connectors_api) + if self.connectors_api is not None: + mcp_tool = await resolve_mcp_connector_id(mcp_tool, self.connectors_api) # Emit mcp_list_tools.in_progress self.sequence_number += 1 @@ -1474,6 +1476,11 @@ async def _process_mcp_tool( # Get session manager from tool_executor if available (fix for #4452) session_manager = getattr(self.tool_executor, "mcp_session_manager", None) + if not mcp_tool.server_url: + raise ValueError( + f"Failed to list MCP tools for server '{mcp_tool.server_label}': server_url is not set" + ) + # TODO: follow semantic conventions for Open Telemetry tool spans # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span with tracer.start_as_current_span("list_mcp_tools", attributes=attributes): @@ -1679,7 +1686,7 @@ async def _process_tool_choice( elif isinstance(tool_choice, OpenAIResponseInputToolChoiceAllowedTools): # ensure that specified tool choices are available in the chat tools, if not, remove them from the list - final_tools = [] + final_tools: list[dict[str, Any]] = [] for tool in tool_choice.tools: match tool.get("type"): case "function": @@ -1712,19 +1719,18 @@ async def _process_tool_choice( else: # Handle specific tool choice by type # Each case validates the tool exists in chat_tools before returning - tool_name = getattr(tool_choice, "name", None) match tool_choice: case OpenAIResponseInputToolChoiceCustomTool(): - if tool_name and tool_name not in chat_tool_names: - logger.warning(f"Tool {tool_name} not found in chat tools") + if tool_choice.name not in chat_tool_names: + logger.warning(f"Tool {tool_choice.name} not found in chat tools") return None - return OpenAIChatCompletionToolChoiceCustomTool(name=tool_name) + return OpenAIChatCompletionToolChoiceCustomTool(name=tool_choice.name) case OpenAIResponseInputToolChoiceFunctionTool(): - if tool_name and tool_name not in chat_tool_names: - logger.warning(f"Tool {tool_name} not found in chat tools") + if tool_choice.name not in chat_tool_names: + logger.warning(f"Tool {tool_choice.name} not found in chat tools") return None - return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_name) + return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_choice.name) case OpenAIResponseInputToolChoiceFileSearch(): if "file_search" not in chat_tool_names: @@ -1739,21 +1745,25 @@ async def _process_tool_choice( return OpenAIChatCompletionToolChoiceFunctionTool(name="web_search") case OpenAIResponseInputToolChoiceMCPTool(): - tool_choice = convert_mcp_tool_choice( + mcp_result = convert_mcp_tool_choice( chat_tool_names, tool_choice.server_label, server_label_to_tools, - tool_name, + tool_choice.name, ) - if isinstance(tool_choice, dict): + if isinstance(mcp_result, dict): # for single tool choice, return as function tool choice - return OpenAIChatCompletionToolChoiceFunctionTool(name=tool_choice["function"]["name"]) - elif isinstance(tool_choice, list): + function_info = mcp_result["function"] + if not isinstance(function_info, dict): + return None + return OpenAIChatCompletionToolChoiceFunctionTool(name=function_info["name"]) + elif isinstance(mcp_result, list): # for multiple tool choices, return as allowed tools return OpenAIChatCompletionToolChoiceAllowedTools( - tools=tool_choice, + tools=mcp_result, mode="required", ) + return None async def resolve_mcp_connector_id( diff --git a/src/llama_stack/providers/inline/agents/builtin/responses/tool_executor.py b/src/llama_stack/providers/inline/agents/builtin/responses/tool_executor.py index 8c7038143e..6763ab53e6 100644 --- a/src/llama_stack/providers/inline/agents/builtin/responses/tool_executor.py +++ b/src/llama_stack/providers/inline/agents/builtin/responses/tool_executor.py @@ -326,6 +326,8 @@ async def _execute_tool( from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool mcp_tool = mcp_tool_to_server[function_name] + if not mcp_tool.server_url: + raise ValueError(f"Failed to invoke MCP tool {function_name}: server_url is not set") attributes = { "server_label": mcp_tool.server_label, "server_url": mcp_tool.server_url, diff --git a/src/llama_stack/providers/inline/agents/builtin/responses/types.py b/src/llama_stack/providers/inline/agents/builtin/responses/types.py index ff6367bf3b..08dc88f200 100644 --- a/src/llama_stack/providers/inline/agents/builtin/responses/types.py +++ b/src/llama_stack/providers/inline/agents/builtin/responses/types.py @@ -38,7 +38,8 @@ def _json_equal(a: str, b: str) -> bool: """Compare two JSON strings by value, falling back to string comparison.""" try: - return json.loads(a) == json.loads(b) + # json.loads() returns Any, so == on two Any values is also Any + return cast(bool, json.loads(a) == json.loads(b)) except (json.JSONDecodeError, TypeError): return a == b diff --git a/src/llama_stack/providers/inline/agents/builtin/responses/utils.py b/src/llama_stack/providers/inline/agents/builtin/responses/utils.py index 9b906be6bd..8a1d5ba232 100644 --- a/src/llama_stack/providers/inline/agents/builtin/responses/utils.py +++ b/src/llama_stack/providers/inline/agents/builtin/responses/utils.py @@ -565,7 +565,7 @@ def convert_mcp_tool_choice( server_label: str | None = None, server_label_to_tools: dict[str, list[str]] | None = None, tool_name: str | None = None, -) -> dict[str, str] | list[dict[str, str]]: +) -> dict[str, str | dict[str, str]] | list[dict[str, str | dict[str, str]]] | None: """Convert a responses tool choice of type mcp to a chat completions compatible function tool choice.""" if tool_name: @@ -580,6 +580,8 @@ def convert_mcp_tool_choice( tool_names = server_label_to_tools.get(server_label, []) if not tool_names: return None - matching_tools = [{"type": "function", "function": {"name": tool_name}} for tool_name in tool_names] + matching_tools: list[dict[str, str | dict[str, str]]] = [ + {"type": "function", "function": {"name": name}} for name in tool_names + ] return matching_tools return [] diff --git a/src/llama_stack/providers/utils/responses/responses_store.py b/src/llama_stack/providers/utils/responses/responses_store.py index 6d02fa994d..dc8df47648 100644 --- a/src/llama_stack/providers/utils/responses/responses_store.py +++ b/src/llama_stack/providers/utils/responses/responses_store.py @@ -20,6 +20,7 @@ OpenAIResponseObjectWithInput, Order, ResponseInputItemNotFoundError, + ResponseItemInclude, ResponseNotFoundError, ) from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType @@ -253,7 +254,7 @@ async def list_response_input_items( response_id: str, after: str | None = None, before: str | None = None, - include: list[str] | None = None, + include: list[ResponseItemInclude] | None = None, limit: int | None = 20, order: Order | None = Order.desc, ) -> ListOpenAIResponseInputItem: From 4250342ba2d8c14996783441210f8d632ae80b4c Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 27 Mar 2026 15:38:24 -0400 Subject: [PATCH 6/6] chore: fix up --- .../responses/builtin/responses/streaming.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/llama_stack/providers/inline/responses/builtin/responses/streaming.py b/src/llama_stack/providers/inline/responses/builtin/responses/streaming.py index 16c7ae540c..1888d4647f 100644 --- a/src/llama_stack/providers/inline/responses/builtin/responses/streaming.py +++ b/src/llama_stack/providers/inline/responses/builtin/responses/streaming.py @@ -1412,8 +1412,8 @@ def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam: for input_tool in tools: if input_tool.type == "function": self.ctx.chat_tools.append( - ChatCompletionToolParam(type="function", function=input_tool.model_dump(exclude_none=True)) - ) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition + ChatCompletionToolParam(type="function", function=input_tool.model_dump(exclude_none=True)) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition + ) elif input_tool.type in WebSearchToolTypes: tool_name = "web_search" # Need to access tool_groups_api from tool_executor @@ -1471,11 +1471,6 @@ async def _process_mcp_tool( # Call list_mcp_tools tool_defs = None list_id = f"mcp_list_{uuid.uuid4()}" - attributes = { - "server_label": mcp_tool.server_label, - "server_url": mcp_tool.server_url, - "mcp_list_tools_id": list_id, - } # Get session manager from tool_executor if available (fix for #4452) session_manager = getattr(self.tool_executor, "mcp_session_manager", None) @@ -1485,6 +1480,12 @@ async def _process_mcp_tool( f"Failed to list MCP tools for server '{mcp_tool.server_label}': server_url is not set" ) + attributes = { + "server_label": mcp_tool.server_label, + "server_url": mcp_tool.server_url, + "mcp_list_tools_id": list_id, + } + # TODO: follow semantic conventions for Open Telemetry tool spans # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span with tracer.start_as_current_span("list_mcp_tools", attributes=attributes):