From 57744521959199db627a866710b0e2c426f18710 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 12 Feb 2026 05:16:34 +0000 Subject: [PATCH 1/6] Refactor config_generator into modular, testable components Break the 507-line monolithic validate_and_render_schema() into focused modules with pure validation functions, proper error handling, and clean I/O separation: - config_providers.py: provider constants, ConfigValidationError, unified URL parsing (replaces 3 different inline implementations) - config_validator.py: 11 pure validation functions (no I/O, no print/exit) - config_generator.py: thin 146-line I/O orchestrator, reads files once (was twice), uses logging instead of print() Also cleans up module responsibilities: - Move stream_access_logs from utils.py to docker_cli.py (Docker operation) - Deduplicate llm_providers->model_providers migration - Fix "Model alias 2 -" debug artifact in error message - Update docker-compose.dev.yaml volume mounts for new files - Rewrite tests: 53 tests calling pure functions directly (no mock_open chains), up from 10 brittle mock-dependent tests Co-Authored-By: Claude Opus 4.6 --- cli/planoai/config_generator.py | 575 ++++-------------- cli/planoai/config_providers.py | 87 +++ cli/planoai/config_validator.py | 486 +++++++++++++++ cli/planoai/docker_cli.py | 22 + cli/planoai/main.py | 2 +- cli/planoai/utils.py | 37 +- cli/test/test_config_generator.py | 951 ++++++++++++++++++------------ config/docker-compose.dev.yaml | 2 + 8 files changed, 1295 insertions(+), 867 deletions(-) create mode 100644 cli/planoai/config_providers.py create mode 100644 cli/planoai/config_validator.py diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index 277685466..e7142a676 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -1,506 +1,145 @@ -import json -import os -from planoai.utils import convert_legacy_listeners -from jinja2 import Environment, FileSystemLoader -import yaml -from 
jsonschema import validate -from urllib.parse import urlparse -from copy import deepcopy -from planoai.consts import DEFAULT_OTEL_TRACING_GRPC_ENDPOINT +"""Config generator: loads config files, validates, and renders Envoy template. + +This module is the I/O boundary. It reads files, calls pure validation +functions from config_validator and config_providers, then writes output. +Entry point: ``python -m planoai.config_generator`` (called by supervisord). +""" -SUPPORTED_PROVIDERS_WITH_BASE_URL = [ - "azure_openai", - "ollama", - "qwen", - "amazon_bedrock", - "arch", -] +import logging +import os -SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [ - "deepseek", - "groq", - "mistral", - "openai", - "gemini", - "anthropic", - "together_ai", - "xai", - "moonshotai", - "zhipu", -] +import yaml +from copy import deepcopy +from jinja2 import Environment, FileSystemLoader -SUPPORTED_PROVIDERS = ( - SUPPORTED_PROVIDERS_WITHOUT_BASE_URL + SUPPORTED_PROVIDERS_WITH_BASE_URL +from planoai.config_providers import ( + ConfigValidationError, + # Re-export for backward compatibility + SUPPORTED_PROVIDERS, + SUPPORTED_PROVIDERS_WITH_BASE_URL, + SUPPORTED_PROVIDERS_WITHOUT_BASE_URL, +) +from planoai.config_validator import ( + build_clusters, + build_template_data, + migrate_legacy_providers, + process_model_providers, + resolve_agent_orchestrator, + validate_agents, + validate_listeners, + validate_model_aliases, + validate_prompt_targets, + validate_schema, + validate_tracing, ) +from planoai.consts import DEFAULT_OTEL_TRACING_GRPC_ENDPOINT +from planoai.utils import convert_legacy_listeners +log = logging.getLogger(__name__) -def get_endpoint_and_port(endpoint, protocol): - endpoint_tokens = endpoint.split(":") - if len(endpoint_tokens) > 1: - endpoint = endpoint_tokens[0] - port = int(endpoint_tokens[1]) - return endpoint, port - else: - if protocol == "http": - port = 80 - else: - port = 443 - return endpoint, port + +def load_yaml_file(path): + """Read a YAML file and return the parsed 
dict.""" + with open(path, "r") as f: + raw = f.read() + return yaml.safe_load(raw) def validate_and_render_schema(): - ENVOY_CONFIG_TEMPLATE_FILE = os.getenv( - "ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml" - ) - ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml") - ARCH_CONFIG_FILE_RENDERED = os.getenv( + """Main orchestrator: load -> validate -> process -> render -> write. + + Reads env vars for file paths (Docker integration). + Raises ConfigValidationError on validation failure. + """ + # --- Read environment config --- + template_file = os.getenv("ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml") + config_path = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml") + rendered_config_path = os.getenv( "ARCH_CONFIG_FILE_RENDERED", "/app/arch_config_rendered.yaml" ) - ENVOY_CONFIG_FILE_RENDERED = os.getenv( + envoy_rendered_path = os.getenv( "ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml" ) - ARCH_CONFIG_SCHEMA_FILE = os.getenv( - "ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml" - ) + schema_path = os.getenv("ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml") + template_root = os.getenv("TEMPLATE_ROOT", "./") - env = Environment(loader=FileSystemLoader(os.getenv("TEMPLATE_ROOT", "./"))) - template = env.get_template(ENVOY_CONFIG_TEMPLATE_FILE) + # --- Load files (each read exactly once) --- + config = load_yaml_file(config_path) + schema = load_yaml_file(schema_path) - try: - validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE) - except Exception as e: - print(str(e)) - exit(1) # validate_prompt_config failed. 
Exit - - with open(ARCH_CONFIG_FILE, "r") as file: - arch_config = file.read() - - with open(ARCH_CONFIG_SCHEMA_FILE, "r") as file: - arch_config_schema = file.read() + env = Environment(loader=FileSystemLoader(template_root)) + template = env.get_template(template_file) - config_yaml = yaml.safe_load(arch_config) - _ = yaml.safe_load(arch_config_schema) - inferred_clusters = {} - - # Convert legacy llm_providers to model_providers - if "llm_providers" in config_yaml: - if "model_providers" in config_yaml: - raise Exception( - "Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead" - ) - config_yaml["model_providers"] = config_yaml["llm_providers"] - del config_yaml["llm_providers"] + # --- Validate and process --- + validate_schema(config, schema) + config = migrate_legacy_providers(config) listeners, llm_gateway, prompt_gateway = convert_legacy_listeners( - config_yaml.get("listeners"), config_yaml.get("model_providers") + config.get("listeners"), config.get("model_providers") ) + config["listeners"] = listeners - config_yaml["listeners"] = listeners - - endpoints = config_yaml.get("endpoints", {}) - - # Process agents section and convert to endpoints - agents = config_yaml.get("agents", []) - filters = config_yaml.get("filters", []) - agents_combined = agents + filters - agent_id_keys = set() - - for agent in agents_combined: - agent_id = agent.get("id") - if agent_id in agent_id_keys: - raise Exception( - f"Duplicate agent id {agent_id}, please provide unique id for each agent" - ) - agent_id_keys.add(agent_id) - agent_endpoint = agent.get("url") - - if agent_id and agent_endpoint: - urlparse_result = urlparse(agent_endpoint) - if urlparse_result.scheme and urlparse_result.hostname: - protocol = urlparse_result.scheme - - port = urlparse_result.port - if port is None: - if protocol == "http": - port = 80 - else: - port = 443 - - endpoints[agent_id] = { - "endpoint": 
urlparse_result.hostname, - "port": port, - "protocol": protocol, - } - - # override the inferred clusters with the ones defined in the config - for name, endpoint_details in endpoints.items(): - inferred_clusters[name] = endpoint_details - # Only call get_endpoint_and_port for manually defined endpoints, not agent-derived ones - if "port" not in endpoint_details: - endpoint = inferred_clusters[name]["endpoint"] - protocol = inferred_clusters[name].get("protocol", "http") - ( - inferred_clusters[name]["endpoint"], - inferred_clusters[name]["port"], - ) = get_endpoint_and_port(endpoint, protocol) - - print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters)) - - if "prompt_targets" in config_yaml: - for prompt_target in config_yaml["prompt_targets"]: - name = prompt_target.get("endpoint", {}).get("name", None) - if not name: - continue - if name not in inferred_clusters: - raise Exception( - f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file" - ) - - arch_tracing = config_yaml.get("tracing", {}) - - # Resolution order: config yaml > OTEL_TRACING_GRPC_ENDPOINT env var > hardcoded default - opentracing_grpc_endpoint = arch_tracing.get( - "opentracing_grpc_endpoint", - os.environ.get( - "OTEL_TRACING_GRPC_ENDPOINT", DEFAULT_OTEL_TRACING_GRPC_ENDPOINT - ), + agent_endpoints = validate_agents( + config.get("agents", []), config.get("filters", []) ) - # resolve env vars in opentracing_grpc_endpoint if present - if opentracing_grpc_endpoint and "$" in opentracing_grpc_endpoint: - opentracing_grpc_endpoint = os.path.expandvars(opentracing_grpc_endpoint) - print( - f"Resolved opentracing_grpc_endpoint to {opentracing_grpc_endpoint} after expanding environment variables" - ) - arch_tracing["opentracing_grpc_endpoint"] = opentracing_grpc_endpoint - # ensure that opentracing_grpc_endpoint is a valid URL if present and start with http and must not have any path - if opentracing_grpc_endpoint: - urlparse_result = 
urlparse(opentracing_grpc_endpoint) - if urlparse_result.scheme != "http": - raise Exception( - f"Invalid opentracing_grpc_endpoint {opentracing_grpc_endpoint}, scheme must be http" - ) - if urlparse_result.path and urlparse_result.path != "/": - raise Exception( - f"Invalid opentracing_grpc_endpoint {opentracing_grpc_endpoint}, path must be empty" - ) - - llms_with_endpoint = [] - llms_with_endpoint_cluster_names = set() - updated_model_providers = [] - model_provider_name_set = set() - llms_with_usage = [] - model_name_keys = set() - model_usage_name_keys = set() - - print("listeners: ", listeners) + clusters = build_clusters(config.get("endpoints", {}), agent_endpoints) + log.info("Defined clusters: %s", clusters) - for listener in listeners: - if ( - listener.get("model_providers") is None - or listener.get("model_providers") == [] - ): - continue - print("Processing listener with model_providers: ", listener) - name = listener.get("name", None) + validate_prompt_targets(config, clusters) - for model_provider in listener.get("model_providers", []): - if model_provider.get("usage", None): - llms_with_usage.append(model_provider["name"]) - if model_provider.get("name") in model_provider_name_set: - raise Exception( - f"Duplicate model_provider name {model_provider.get('name')}, please provide unique name for each model_provider" - ) - - model_name = model_provider.get("model") - print("Processing model_provider: ", model_provider) - - # Check if this is a wildcard model (provider/*) - is_wildcard = False - if "/" in model_name: - model_name_tokens = model_name.split("/") - if len(model_name_tokens) >= 2 and model_name_tokens[-1] == "*": - is_wildcard = True - - if model_name in model_name_keys and not is_wildcard: - raise Exception( - f"Duplicate model name {model_name}, please provide unique model name for each model_provider" - ) - - if not is_wildcard: - model_name_keys.add(model_name) - if model_provider.get("name") is None: - model_provider["name"] = 
model_name - - model_provider_name_set.add(model_provider.get("name")) - - model_name_tokens = model_name.split("/") - if len(model_name_tokens) < 2: - raise Exception( - f"Invalid model name {model_name}. Please provide model name in the format / or /* for wildcards." - ) - provider = model_name_tokens[0].strip() - - # Check if this is a wildcard (provider/*) - is_wildcard = model_name_tokens[-1].strip() == "*" - - # Validate wildcard constraints - if is_wildcard: - if model_provider.get("default", False): - raise Exception( - f"Model {model_name} is configured as default but uses wildcard (*). Default models cannot be wildcards." - ) - if model_provider.get("routing_preferences"): - raise Exception( - f"Model {model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards." - ) - - # Validate azure_openai and ollama provider requires base_url - if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get( - "base_url" - ) is None: - raise Exception( - f"Provider '{provider}' requires 'base_url' to be set for model {model_name}" - ) - - model_id = "/".join(model_name_tokens[1:]) - - # For wildcard providers, allow any provider name - if not is_wildcard and provider not in SUPPORTED_PROVIDERS: - if ( - model_provider.get("base_url", None) is None - or model_provider.get("provider_interface", None) is None - ): - raise Exception( - f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. 
Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}" - ) - provider = model_provider.get("provider_interface", None) - elif is_wildcard and provider not in SUPPORTED_PROVIDERS: - # Wildcard models with unsupported providers require base_url and provider_interface - if ( - model_provider.get("base_url", None) is None - or model_provider.get("provider_interface", None) is None - ): - raise Exception( - f"Must provide base_url and provider_interface for unsupported provider {provider} for wildcard model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}" - ) - provider = model_provider.get("provider_interface", None) - elif ( - provider in SUPPORTED_PROVIDERS - and model_provider.get("provider_interface", None) is not None - ): - # For supported providers, provider_interface should not be manually set - raise Exception( - f"Please provide provider interface as part of model name {model_name} using the format /. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' " - ) - - # For wildcard models, don't add model_id to the keys since it's "*" - if not is_wildcard: - if model_id in model_name_keys: - raise Exception( - f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider" - ) - model_name_keys.add(model_id) - - for routing_preference in model_provider.get("routing_preferences", []): - if routing_preference.get("name") in model_usage_name_keys: - raise Exception( - f'Duplicate routing preference name "{routing_preference.get("name")}", please provide unique name for each routing preference' - ) - model_usage_name_keys.add(routing_preference.get("name")) - - # Warn if both passthrough_auth and access_key are configured - if model_provider.get("passthrough_auth") and model_provider.get( - "access_key" - ): - print( - f"WARNING: Model provider '{model_provider.get('name')}' has both 'passthrough_auth: true' and 'access_key' configured. 
" - f"The access_key will be ignored and the client's Authorization header will be forwarded instead." - ) - - model_provider["model"] = model_id - model_provider["provider_interface"] = provider - model_provider_name_set.add(model_provider.get("name")) - if model_provider.get("provider") and model_provider.get( - "provider_interface" - ): - raise Exception( - "Please provide either provider or provider_interface, not both" - ) - if model_provider.get("provider"): - provider = model_provider["provider"] - model_provider["provider_interface"] = provider - del model_provider["provider"] - updated_model_providers.append(model_provider) - - if model_provider.get("base_url", None): - base_url = model_provider["base_url"] - urlparse_result = urlparse(base_url) - base_url_path_prefix = urlparse_result.path - if base_url_path_prefix and base_url_path_prefix != "/": - # we will now support base_url_path_prefix. This means that the user can provide base_url like http://example.com/path and we will extract /path as base_url_path_prefix - model_provider["base_url_path_prefix"] = base_url_path_prefix - - if urlparse_result.scheme == "" or urlparse_result.scheme not in [ - "http", - "https", - ]: - raise Exception( - "Please provide a valid URL with scheme (http/https) in base_url" - ) - protocol = urlparse_result.scheme - port = urlparse_result.port - if port is None: - if protocol == "http": - port = 80 - else: - port = 443 - endpoint = urlparse_result.hostname - model_provider["endpoint"] = endpoint - model_provider["port"] = port - model_provider["protocol"] = protocol - cluster_name = ( - provider + "_" + endpoint - ) # make name unique by appending endpoint - model_provider["cluster_name"] = cluster_name - # Only add if cluster_name is not already present to avoid duplicates - if cluster_name not in llms_with_endpoint_cluster_names: - llms_with_endpoint.append(model_provider) - llms_with_endpoint_cluster_names.add(cluster_name) - - if len(model_usage_name_keys) > 0: - 
routing_model_provider = config_yaml.get("routing", {}).get( - "model_provider", None - ) - if ( - routing_model_provider - and routing_model_provider not in model_provider_name_set - ): - raise Exception( - f"Routing model_provider {routing_model_provider} is not defined in model_providers" - ) - if ( - routing_model_provider is None - and "arch-router" not in model_provider_name_set - ): - updated_model_providers.append( - { - "name": "arch-router", - "provider_interface": "arch", - "model": config_yaml.get("routing", {}).get("model", "Arch-Router"), - "internal": True, - } - ) - - # Always add arch-function model provider if not already defined - if "arch-function" not in model_provider_name_set: - updated_model_providers.append( - { - "name": "arch-function", - "provider_interface": "arch", - "model": "Arch-Function", - "internal": True, - } - ) - - if "plano-orchestrator" not in model_provider_name_set: - updated_model_providers.append( - { - "name": "plano-orchestrator", - "provider_interface": "arch", - "model": "Plano-Orchestrator", - "internal": True, - } - ) - - config_yaml["model_providers"] = deepcopy(updated_model_providers) - - listeners_with_provider = 0 - for listener in listeners: - print("Processing listener: ", listener) - model_providers = listener.get("model_providers", None) - if model_providers is not None: - listeners_with_provider += 1 - if listeners_with_provider > 1: - raise Exception( - "Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers" - ) - - # Validate model aliases if present - if "model_aliases" in config_yaml: - model_aliases = config_yaml["model_aliases"] - for alias_name, alias_config in model_aliases.items(): - target = alias_config.get("target") - if target not in model_name_keys: - raise Exception( - f"Model alias 2 - '{alias_name}' targets '{target}' which is not defined as a model. 
Available models: {', '.join(sorted(model_name_keys))}" - ) - - arch_config_string = yaml.dump(config_yaml) - arch_llm_config_string = yaml.dump(config_yaml) + tracing = validate_tracing( + config.get("tracing", {}), DEFAULT_OTEL_TRACING_GRPC_ENDPOINT + ) - use_agent_orchestrator = config_yaml.get("overrides", {}).get( - "use_agent_orchestrator", False + updated_providers, llms_with_endpoint, model_name_keys = process_model_providers( + listeners, config.get("routing", {}) ) + config["model_providers"] = deepcopy(updated_providers) - agent_orchestrator = None - if use_agent_orchestrator: - print("Using agent orchestrator") + validate_listeners(listeners) - if len(endpoints) == 0: - raise Exception( - "Please provide agent orchestrator in the endpoints section in your arch_config.yaml file" - ) - elif len(endpoints) > 1: - raise Exception( - "Please provide single agent orchestrator in the endpoints section in your arch_config.yaml file" - ) - else: - agent_orchestrator = list(endpoints.keys())[0] + if "model_aliases" in config: + validate_model_aliases(config["model_aliases"], model_name_keys) - print("agent_orchestrator: ", agent_orchestrator) + agent_orchestrator = resolve_agent_orchestrator( + config, config.get("endpoints", {}) + ) - data = { - "prompt_gateway_listener": prompt_gateway, - "llm_gateway_listener": llm_gateway, - "arch_config": arch_config_string, - "arch_llm_config": arch_llm_config_string, - "arch_clusters": inferred_clusters, - "arch_model_providers": updated_model_providers, - "arch_tracing": arch_tracing, - "local_llms": llms_with_endpoint, - "agent_orchestrator": agent_orchestrator, - "listeners": listeners, - } + data = build_template_data( + prompt_gateway, + llm_gateway, + config, + clusters, + updated_providers, + tracing, + llms_with_endpoint, + agent_orchestrator, + listeners, + ) + # --- Render and write --- rendered = template.render(data) - print(ENVOY_CONFIG_FILE_RENDERED) - print(rendered) - with open(ENVOY_CONFIG_FILE_RENDERED, 
"w") as file: - file.write(rendered) - - with open(ARCH_CONFIG_FILE_RENDERED, "w") as file: - file.write(arch_config_string) + log.info("Writing Envoy config to %s", envoy_rendered_path) + with open(envoy_rendered_path, "w") as f: + f.write(rendered) -def validate_prompt_config(arch_config_file, arch_config_schema_file): - with open(arch_config_file, "r") as file: - arch_config = file.read() + config_string = yaml.dump(config) + with open(rendered_config_path, "w") as f: + f.write(config_string) - with open(arch_config_schema_file, "r") as file: - arch_config_schema = file.read() - - config_yaml = yaml.safe_load(arch_config) - config_schema_yaml = yaml.safe_load(arch_config_schema) +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) try: - validate(config_yaml, config_schema_yaml) + validate_and_render_schema() + except ConfigValidationError as e: + log.error(str(e)) + exit(1) except Exception as e: - print( - f"Error validating arch_config file: {arch_config_file}, schema file: {arch_config_schema_file}, error: {e}" - ) - raise e - - -if __name__ == "__main__": - validate_and_render_schema() + log.error("Unexpected error: %s", e) + exit(1) diff --git a/cli/planoai/config_providers.py b/cli/planoai/config_providers.py new file mode 100644 index 000000000..5f934c8be --- /dev/null +++ b/cli/planoai/config_providers.py @@ -0,0 +1,87 @@ +"""Model provider constants, custom exception, and URL parsing utility.""" + +import logging +from urllib.parse import urlparse + +log = logging.getLogger(__name__) + + +class ConfigValidationError(Exception): + """Raised when config validation fails.""" + + pass + + +# --- Provider Constants --- + +SUPPORTED_PROVIDERS_WITH_BASE_URL = [ + "azure_openai", + "ollama", + "qwen", + "amazon_bedrock", + "arch", +] + +SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [ + "deepseek", + "groq", + "mistral", + "openai", + "gemini", + "anthropic", + "together_ai", + 
"xai", + "moonshotai", + "zhipu", +] + +SUPPORTED_PROVIDERS = ( + SUPPORTED_PROVIDERS_WITHOUT_BASE_URL + SUPPORTED_PROVIDERS_WITH_BASE_URL +) + +INTERNAL_PROVIDERS = { + "arch-function": { + "name": "arch-function", + "provider_interface": "arch", + "model": "Arch-Function", + "internal": True, + }, + "plano-orchestrator": { + "name": "plano-orchestrator", + "provider_interface": "arch", + "model": "Plano-Orchestrator", + "internal": True, + }, +} + + +def parse_url_endpoint(url): + """Parse a URL into endpoint, port, protocol, and optional path_prefix. + + Replaces the old get_endpoint_and_port() and inline urlparse logic. + Raises ConfigValidationError for invalid URLs. + + Returns dict with keys: endpoint, port, protocol, path_prefix (optional) + """ + result = urlparse(url) + if not result.scheme or result.scheme not in ("http", "https"): + raise ConfigValidationError( + f"Invalid URL '{url}': scheme must be http or https" + ) + if not result.hostname: + raise ConfigValidationError(f"Invalid URL '{url}': hostname is required") + + port = result.port + if port is None: + port = 80 if result.scheme == "http" else 443 + + parsed = { + "endpoint": result.hostname, + "port": port, + "protocol": result.scheme, + } + + if result.path and result.path != "/": + parsed["path_prefix"] = result.path + + return parsed diff --git a/cli/planoai/config_validator.py b/cli/planoai/config_validator.py new file mode 100644 index 000000000..a5a0b0e80 --- /dev/null +++ b/cli/planoai/config_validator.py @@ -0,0 +1,486 @@ +"""Pure validation and transformation functions for Plano config. + +Every function in this module takes data in, returns data out, and raises +ConfigValidationError on failure. No file I/O, no print(), no exit(). 
+""" + +import logging +import os +from copy import deepcopy +from urllib.parse import urlparse + +import yaml +from jsonschema import validate as jsonschema_validate + +from planoai.config_providers import ( + INTERNAL_PROVIDERS, + SUPPORTED_PROVIDERS, + SUPPORTED_PROVIDERS_WITH_BASE_URL, + ConfigValidationError, + parse_url_endpoint, +) + +log = logging.getLogger(__name__) + + +def validate_schema(config, schema): + """Validate config dict against JSON schema dict. + + Raises ConfigValidationError with a clear message on failure. + """ + try: + jsonschema_validate(config, schema) + except Exception as e: + raise ConfigValidationError(f"Schema validation failed: {e}") from e + + +def migrate_legacy_providers(config): + """Migrate llm_providers -> model_providers if needed. + + Returns a new config dict (does not mutate input). + Raises ConfigValidationError if both are present. + """ + config = deepcopy(config) + + if "llm_providers" in config: + if "model_providers" in config: + raise ConfigValidationError( + "Please provide either llm_providers or model_providers, not both. " + "llm_providers is deprecated, please use model_providers instead" + ) + config["model_providers"] = config.pop("llm_providers") + + return config + + +def validate_agents(agents, filters): + """Validate agent/filter entries and infer endpoint clusters from URLs. + + Returns dict of inferred endpoint clusters keyed by agent_id. + Raises ConfigValidationError on duplicate IDs. 
+ """ + combined = agents + filters + seen_ids = set() + inferred_endpoints = {} + + for agent in combined: + agent_id = agent.get("id") + if agent_id in seen_ids: + raise ConfigValidationError( + f"Duplicate agent id {agent_id}, please provide unique id for each agent" + ) + seen_ids.add(agent_id) + + agent_url = agent.get("url") + if agent_id and agent_url: + result = urlparse(agent_url) + if result.scheme and result.hostname: + port = result.port + if port is None: + port = 80 if result.scheme == "http" else 443 + + inferred_endpoints[agent_id] = { + "endpoint": result.hostname, + "port": port, + "protocol": result.scheme, + } + + return inferred_endpoints + + +def build_clusters(endpoints, agent_inferred): + """Merge explicit endpoints with agent-inferred clusters. + + Returns the final cluster dict. + """ + clusters = dict(agent_inferred) + + for name, endpoint_details in endpoints.items(): + clusters[name] = dict(endpoint_details) + # Resolve port for manually defined endpoints that lack one + if "port" not in clusters[name]: + endpoint = clusters[name]["endpoint"] + protocol = clusters[name].get("protocol", "http") + if ":" in endpoint: + parts = endpoint.split(":") + clusters[name]["endpoint"] = parts[0] + clusters[name]["port"] = int(parts[1]) + else: + clusters[name]["port"] = 80 if protocol == "http" else 443 + + return clusters + + +def validate_prompt_targets(config, clusters): + """Validate that prompt_targets reference valid endpoints.""" + for prompt_target in config.get("prompt_targets", []): + name = prompt_target.get("endpoint", {}).get("name", None) + if not name: + continue + if name not in clusters: + raise ConfigValidationError( + f"Unknown endpoint {name}, please add it in endpoints section " + "in your arch_config.yaml file" + ) + + +def validate_tracing(tracing_config, default_endpoint): + """Validate and resolve the tracing configuration. + + Handles env var resolution for opentracing_grpc_endpoint. + Returns the resolved tracing dict. 
+ Raises ConfigValidationError for invalid endpoints. + """ + tracing = deepcopy(tracing_config) + + # Resolution order: config yaml > OTEL_TRACING_GRPC_ENDPOINT env var > default + endpoint = tracing.get( + "opentracing_grpc_endpoint", + os.environ.get("OTEL_TRACING_GRPC_ENDPOINT", default_endpoint), + ) + + # Resolve env var references like $VAR or ${VAR} + if endpoint and "$" in endpoint: + endpoint = os.path.expandvars(endpoint) + log.info("Resolved opentracing_grpc_endpoint to %s", endpoint) + + tracing["opentracing_grpc_endpoint"] = endpoint + + if endpoint: + result = urlparse(endpoint) + if result.scheme != "http": + raise ConfigValidationError( + f"Invalid opentracing_grpc_endpoint {endpoint}, scheme must be http" + ) + if result.path and result.path != "/": + raise ConfigValidationError( + f"Invalid opentracing_grpc_endpoint {endpoint}, path must be empty" + ) + + return tracing + + +def process_model_providers(listeners, routing_config): + """Process all model providers from listeners. + + Validates names, models, provider interfaces, base_urls, wildcards, + routing preferences, and injects internal providers. + + Args: + listeners: List of listener dicts from config. + routing_config: The 'routing' section from config (may be empty dict). + + Returns: + Tuple of (updated_model_providers, llms_with_endpoint, model_name_keys). + + Raises: + ConfigValidationError on any validation failure. 
+ """ + llms_with_endpoint = [] + llms_with_endpoint_cluster_names = set() + updated_model_providers = [] + model_provider_name_set = set() + model_name_keys = set() + model_usage_name_keys = set() + + for listener in listeners: + if not listener.get("model_providers"): + continue + + for model_provider in listener.get("model_providers", []): + _validate_and_process_single_provider( + model_provider, + model_name_keys, + model_provider_name_set, + model_usage_name_keys, + updated_model_providers, + llms_with_endpoint, + llms_with_endpoint_cluster_names, + ) + + # Inject internal providers + _inject_internal_providers( + updated_model_providers, + model_provider_name_set, + model_usage_name_keys, + routing_config, + ) + + return updated_model_providers, llms_with_endpoint, model_name_keys + + +def _validate_and_process_single_provider( + model_provider, + model_name_keys, + model_provider_name_set, + model_usage_name_keys, + updated_model_providers, + llms_with_endpoint, + llms_with_endpoint_cluster_names, +): + """Validate and normalize a single model_provider entry.""" + # Check duplicate provider name + if model_provider.get("name") in model_provider_name_set: + raise ConfigValidationError( + f"Duplicate model_provider name {model_provider.get('name')}, " + "please provide unique name for each model_provider" + ) + + model_name = model_provider.get("model") + + # Parse model name into provider/model_id + model_name_tokens = model_name.split("/") + if len(model_name_tokens) < 2: + raise ConfigValidationError( + f"Invalid model name {model_name}. Please provide model name in the " + "format / or /* for wildcards." 
+ ) + + provider = model_name_tokens[0].strip() + model_id = "/".join(model_name_tokens[1:]) + is_wildcard = model_name_tokens[-1].strip() == "*" + + # Check duplicate model name (non-wildcard only) + if model_name in model_name_keys and not is_wildcard: + raise ConfigValidationError( + f"Duplicate model name {model_name}, please provide unique model " + "name for each model_provider" + ) + + if not is_wildcard: + model_name_keys.add(model_name) + + # Auto-name if not provided + if model_provider.get("name") is None: + model_provider["name"] = model_name + + model_provider_name_set.add(model_provider.get("name")) + + # Validate wildcard constraints + if is_wildcard: + if model_provider.get("default", False): + raise ConfigValidationError( + f"Model {model_name} is configured as default but uses wildcard (*). " + "Default models cannot be wildcards." + ) + if model_provider.get("routing_preferences"): + raise ConfigValidationError( + f"Model {model_name} has routing_preferences but uses wildcard (*). " + "Models with routing preferences cannot be wildcards." + ) + + # Validate provider requires base_url + if provider in SUPPORTED_PROVIDERS_WITH_BASE_URL and not model_provider.get( + "base_url" + ): + raise ConfigValidationError( + f"Provider '{provider}' requires 'base_url' to be set for model {model_name}" + ) + + # Resolve provider interface + if provider not in SUPPORTED_PROVIDERS: + if not model_provider.get("base_url") or not model_provider.get( + "provider_interface" + ): + raise ConfigValidationError( + f"Must provide base_url and provider_interface for unsupported " + f"provider {provider} for {'wildcard ' if is_wildcard else ''}model " + f"{model_name}. 
Supported providers are: "
+                f"{', '.join(SUPPORTED_PROVIDERS)}"
+            )
+        provider = model_provider.get("provider_interface")
+    elif model_provider.get("provider_interface") is not None:
+        raise ConfigValidationError(
+            f"Please provide provider interface as part of model name {model_name} "
+            "using the format <provider>/<model>. For example, use "
+            "'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
+        )
+
+    # Check duplicate model_id (non-wildcard only)
+    if not is_wildcard:
+        if model_id in model_name_keys:
+            raise ConfigValidationError(
+                f"Duplicate model_id {model_id}, please provide unique model_id "
+                "for each model_provider"
+            )
+        model_name_keys.add(model_id)
+
+    # Validate routing preferences
+    for routing_preference in model_provider.get("routing_preferences", []):
+        pref_name = routing_preference.get("name")
+        if pref_name in model_usage_name_keys:
+            raise ConfigValidationError(
+                f'Duplicate routing preference name "{pref_name}", please provide '
+                "unique name for each routing preference"
+            )
+        model_usage_name_keys.add(pref_name)
+
+    # Warn if both passthrough_auth and access_key are configured
+    if model_provider.get("passthrough_auth") and model_provider.get("access_key"):
+        log.warning(
+            "Model provider '%s' has both 'passthrough_auth: true' and 'access_key' "
+            "configured. 
The access_key will be ignored and the client's Authorization " + "header will be forwarded instead.", + model_provider.get("name"), + ) + + # Normalize provider fields + model_provider["model"] = model_id + model_provider["provider_interface"] = provider + model_provider_name_set.add(model_provider.get("name")) + + if model_provider.get("provider") and model_provider.get("provider_interface"): + raise ConfigValidationError( + "Please provide either provider or provider_interface, not both" + ) + if model_provider.get("provider"): + provider = model_provider["provider"] + model_provider["provider_interface"] = provider + del model_provider["provider"] + + updated_model_providers.append(model_provider) + + # Process base_url into cluster endpoint info + if model_provider.get("base_url"): + _process_base_url( + model_provider, + provider, + llms_with_endpoint, + llms_with_endpoint_cluster_names, + ) + + +def _process_base_url( + model_provider, provider, llms_with_endpoint, llms_with_endpoint_cluster_names +): + """Parse base_url and add cluster endpoint info to the model provider.""" + base_url = model_provider["base_url"] + parsed = parse_url_endpoint(base_url) + + if parsed.get("path_prefix"): + model_provider["base_url_path_prefix"] = parsed["path_prefix"] + + model_provider["endpoint"] = parsed["endpoint"] + model_provider["port"] = parsed["port"] + model_provider["protocol"] = parsed["protocol"] + + cluster_name = provider + "_" + parsed["endpoint"] + model_provider["cluster_name"] = cluster_name + + if cluster_name not in llms_with_endpoint_cluster_names: + llms_with_endpoint.append(model_provider) + llms_with_endpoint_cluster_names.add(cluster_name) + + +def _inject_internal_providers( + updated_model_providers, + model_provider_name_set, + model_usage_name_keys, + routing_config, +): + """Add arch-router, arch-function, plano-orchestrator if not already defined.""" + # Add arch-router if routing preferences exist and no router is configured + if 
len(model_usage_name_keys) > 0: + routing_model_provider = routing_config.get("model_provider", None) + if ( + routing_model_provider + and routing_model_provider not in model_provider_name_set + ): + raise ConfigValidationError( + f"Routing model_provider {routing_model_provider} is not defined " + "in model_providers" + ) + if ( + routing_model_provider is None + and "arch-router" not in model_provider_name_set + ): + updated_model_providers.append( + { + "name": "arch-router", + "provider_interface": "arch", + "model": routing_config.get("model", "Arch-Router"), + "internal": True, + } + ) + + for name, provider_def in INTERNAL_PROVIDERS.items(): + if name not in model_provider_name_set: + updated_model_providers.append(dict(provider_def)) + + +def validate_listeners(listeners): + """Validate that at most one listener has model_providers.""" + count = sum(1 for l in listeners if l.get("model_providers") is not None) + if count > 1: + raise ConfigValidationError( + "Please provide model_providers either under listeners or at root level, " + "not both. Currently we don't support multiple listeners with model_providers" + ) + + +def validate_model_aliases(aliases, model_name_keys): + """Validate that model aliases reference existing models.""" + for alias_name, alias_config in aliases.items(): + target = alias_config.get("target") + if target not in model_name_keys: + raise ConfigValidationError( + f"Model alias '{alias_name}' targets '{target}' which is not " + f"defined as a model. Available models: " + f"{', '.join(sorted(model_name_keys))}" + ) + + +def resolve_agent_orchestrator(config, endpoints): + """Resolve agent orchestrator from config overrides. + + Returns the orchestrator endpoint name, or None if not configured. 
+ """ + use_orchestrator = config.get("overrides", {}).get( + "use_agent_orchestrator", False + ) + if not use_orchestrator: + return None + + if len(endpoints) == 0: + raise ConfigValidationError( + "Please provide agent orchestrator in the endpoints section " + "in your arch_config.yaml file" + ) + if len(endpoints) > 1: + raise ConfigValidationError( + "Please provide single agent orchestrator in the endpoints section " + "in your arch_config.yaml file" + ) + + return list(endpoints.keys())[0] + + +def build_template_data( + prompt_gateway, + llm_gateway, + config_yaml, + clusters, + model_providers, + tracing, + llms_with_endpoint, + agent_orchestrator, + listeners, +): + """Assemble the Jinja2 template rendering context. + + Note: arch_config and arch_llm_config are intentionally the same value. + Both are kept for backward compatibility with the Envoy template. + """ + config_string = yaml.dump(config_yaml) + return { + "prompt_gateway_listener": prompt_gateway, + "llm_gateway_listener": llm_gateway, + "arch_config": config_string, + "arch_llm_config": config_string, + "arch_clusters": clusters, + "arch_model_providers": model_providers, + "arch_tracing": tracing, + "local_llms": llms_with_endpoint, + "agent_orchestrator": agent_orchestrator, + "listeners": listeners, + } diff --git a/cli/planoai/docker_cli.py b/cli/planoai/docker_cli.py index 0e0bc2d77..0233909df 100644 --- a/cli/planoai/docker_cli.py +++ b/cli/planoai/docker_cli.py @@ -115,6 +115,28 @@ def stream_gateway_logs(follow, service="plano"): log.info(f"Failed to stream logs: {str(e)}") +def stream_access_logs(follow): + """Stream access logs from the running Plano container.""" + + follow_arg = "-f" if follow else "" + + stream_command = [ + "docker", + "exec", + PLANO_DOCKER_NAME, + "sh", + "-c", + f"tail {follow_arg} /var/log/access_*.log", + ] + + subprocess.run( + stream_command, + check=True, + stdout=sys.stdout, + stderr=sys.stderr, + ) + + def docker_validate_plano_schema(arch_config_file): 
import os diff --git a/cli/planoai/main.py b/cli/planoai/main.py index ac0fb0191..7e2301026 100644 --- a/cli/planoai/main.py +++ b/cli/planoai/main.py @@ -10,6 +10,7 @@ from planoai.docker_cli import ( docker_validate_plano_schema, stream_gateway_logs, + stream_access_logs, docker_container_status, ) from planoai.utils import ( @@ -17,7 +18,6 @@ get_llm_provider_access_keys, load_env_file_to_dict, set_log_level, - stream_access_logs, find_config_file, find_repo_root, ) diff --git a/cli/planoai/utils.py b/cli/planoai/utils.py index d55774f40..3917b9f88 100644 --- a/cli/planoai/utils.py +++ b/cli/planoai/utils.py @@ -1,10 +1,6 @@ -import glob import os -import subprocess -import sys import yaml import logging -from planoai.consts import PLANO_DOCKER_NAME # Standard env var for log level across all Plano components @@ -162,20 +158,15 @@ def convert_legacy_listeners( def get_llm_provider_access_keys(arch_config_file): + from planoai.config_validator import migrate_legacy_providers + with open(arch_config_file, "r") as file: arch_config = file.read() arch_config_yaml = yaml.safe_load(arch_config) access_key_list = [] - # Convert legacy llm_providers to model_providers - if "llm_providers" in arch_config_yaml: - if "model_providers" in arch_config_yaml: - raise Exception( - "Please provide either llm_providers or model_providers, not both. 
llm_providers is deprecated, please use model_providers instead" - ) - arch_config_yaml["model_providers"] = arch_config_yaml["llm_providers"] - del arch_config_yaml["llm_providers"] + arch_config_yaml = migrate_legacy_providers(arch_config_yaml) listeners, _, _ = convert_legacy_listeners( arch_config_yaml.get("listeners"), arch_config_yaml.get("model_providers") @@ -258,25 +249,3 @@ def find_config_file(path=".", file=None): return arch_config_file -def stream_access_logs(follow): - """ - Get the archgw access logs - """ - - follow_arg = "-f" if follow else "" - - stream_command = [ - "docker", - "exec", - PLANO_DOCKER_NAME, - "sh", - "-c", - f"tail {follow_arg} /var/log/access_*.log", - ] - - subprocess.run( - stream_command, - check=True, - stdout=sys.stdout, - stderr=sys.stderr, - ) diff --git a/cli/test/test_config_generator.py b/cli/test/test_config_generator.py index 214ea06c4..c373f990d 100644 --- a/cli/test/test_config_generator.py +++ b/cli/test/test_config_generator.py @@ -1,25 +1,473 @@ +"""Tests for config validation, processing, and generation. + +Tests are organized in layers: +1. Unit tests for config_providers (parse_url_endpoint, constants) +2. Unit tests for config_validator (pure validation functions, no I/O) +3. Integration tests for validate_and_render_schema (file I/O with tmp_path) +4. 
Legacy listener conversion tests (unchanged) +""" + import json +import os import pytest +import yaml from unittest import mock + +from planoai.config_providers import ( + ConfigValidationError, + SUPPORTED_PROVIDERS, + parse_url_endpoint, +) +from planoai.config_validator import ( + build_clusters, + migrate_legacy_providers, + process_model_providers, + validate_agents, + validate_listeners, + validate_model_aliases, + validate_prompt_targets, + validate_schema, + validate_tracing, + resolve_agent_orchestrator, +) from planoai.config_generator import validate_and_render_schema +from planoai.utils import convert_legacy_listeners -@pytest.fixture(autouse=True) -def cleanup_env(monkeypatch): - # Clean up environment variables and mocks after each test - yield - monkeypatch.undo() +# --------------------------------------------------------------------------- +# Layer 1: config_providers unit tests +# --------------------------------------------------------------------------- -def test_validate_and_render_happy_path(monkeypatch): - monkeypatch.setenv("ARCH_CONFIG_FILE", "fake_arch_config.yaml") - monkeypatch.setenv("ARCH_CONFIG_SCHEMA_FILE", "fake_arch_config_schema.yaml") - monkeypatch.setenv("ENVOY_CONFIG_TEMPLATE_FILE", "./envoy.template.yaml") - monkeypatch.setenv("ARCH_CONFIG_FILE_RENDERED", "fake_arch_config_rendered.yaml") - monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", "fake_envoy.yaml") - monkeypatch.setenv("TEMPLATE_ROOT", "../") +class TestParseUrlEndpoint: + def test_https_with_port(self): + result = parse_url_endpoint("https://example.com:8443") + assert result == { + "endpoint": "example.com", + "port": 8443, + "protocol": "https", + } + + def test_http_default_port(self): + result = parse_url_endpoint("http://example.com") + assert result == { + "endpoint": "example.com", + "port": 80, + "protocol": "http", + } + + def test_https_default_port(self): + result = parse_url_endpoint("https://example.com") + assert result == { + "endpoint": "example.com", 
+ "port": 443, + "protocol": "https", + } - arch_config = """ + def test_with_path_prefix(self): + result = parse_url_endpoint("http://example.com/api/v2") + assert result["path_prefix"] == "/api/v2" + assert result["endpoint"] == "example.com" + assert result["port"] == 80 + + def test_invalid_scheme(self): + with pytest.raises(ConfigValidationError, match="scheme must be http or https"): + parse_url_endpoint("ftp://example.com") + + def test_no_scheme(self): + with pytest.raises(ConfigValidationError, match="scheme must be http or https"): + parse_url_endpoint("example.com") + + def test_no_hostname(self): + with pytest.raises(ConfigValidationError): + parse_url_endpoint("http://") + + +# --------------------------------------------------------------------------- +# Layer 2: config_validator unit tests +# --------------------------------------------------------------------------- + + +class TestMigrateLegacyProviders: + def test_no_migration_needed(self): + config = {"model_providers": [{"model": "openai/gpt-4o"}]} + result = migrate_legacy_providers(config) + assert "model_providers" in result + assert "llm_providers" not in result + + def test_migration_from_llm_providers(self): + config = {"llm_providers": [{"model": "openai/gpt-4o"}]} + result = migrate_legacy_providers(config) + assert result["model_providers"] == [{"model": "openai/gpt-4o"}] + assert "llm_providers" not in result + + def test_both_present_raises(self): + config = {"llm_providers": [], "model_providers": []} + with pytest.raises(ConfigValidationError, match="not both"): + migrate_legacy_providers(config) + + def test_does_not_mutate_input(self): + config = {"llm_providers": [{"model": "openai/gpt-4o"}]} + migrate_legacy_providers(config) + assert "llm_providers" in config # original unchanged + + +class TestValidateAgents: + def test_duplicate_agent_id(self): + agents = [ + {"id": "a1", "url": "http://localhost:8000"}, + {"id": "a1", "url": "http://localhost:8001"}, + ] + with 
pytest.raises(ConfigValidationError, match="Duplicate agent id"): + validate_agents(agents, []) + + def test_infers_clusters_from_urls(self): + agents = [{"id": "a1", "url": "http://localhost:8000"}] + clusters = validate_agents(agents, []) + assert "a1" in clusters + assert clusters["a1"]["port"] == 8000 + assert clusters["a1"]["endpoint"] == "localhost" + assert clusters["a1"]["protocol"] == "http" + + def test_agents_and_filters_combined(self): + agents = [{"id": "a1", "url": "http://localhost:8000"}] + filters = [{"id": "f1", "url": "http://localhost:9000"}] + clusters = validate_agents(agents, filters) + assert "a1" in clusters + assert "f1" in clusters + + def test_duplicate_across_agents_and_filters(self): + agents = [{"id": "shared", "url": "http://localhost:8000"}] + filters = [{"id": "shared", "url": "http://localhost:9000"}] + with pytest.raises(ConfigValidationError, match="Duplicate agent id"): + validate_agents(agents, filters) + + def test_agent_without_url(self): + agents = [{"id": "a1"}] + clusters = validate_agents(agents, []) + assert clusters == {} + + +class TestBuildClusters: + def test_merge_agent_and_explicit_endpoints(self): + agent_inferred = { + "agent1": {"endpoint": "localhost", "port": 8000, "protocol": "http"} + } + endpoints = { + "explicit1": {"endpoint": "api.example.com", "port": 443, "protocol": "https"} + } + result = build_clusters(endpoints, agent_inferred) + assert "agent1" in result + assert "explicit1" in result + + def test_infer_port_from_host_colon_port(self): + clusters = build_clusters( + {"svc": {"endpoint": "localhost:9090"}}, {} + ) + assert clusters["svc"]["endpoint"] == "localhost" + assert clusters["svc"]["port"] == 9090 + + def test_default_port_http(self): + clusters = build_clusters( + {"svc": {"endpoint": "localhost", "protocol": "http"}}, {} + ) + assert clusters["svc"]["port"] == 80 + + +class TestValidatePromptTargets: + def test_valid_targets(self): + config = { + "prompt_targets": [ + {"endpoint": 
{"name": "my_endpoint"}}, + ] + } + clusters = {"my_endpoint": {"endpoint": "localhost", "port": 80}} + validate_prompt_targets(config, clusters) # should not raise + + def test_unknown_endpoint(self): + config = { + "prompt_targets": [ + {"endpoint": {"name": "nonexistent"}}, + ] + } + with pytest.raises(ConfigValidationError, match="Unknown endpoint"): + validate_prompt_targets(config, {}) + + def test_target_without_name(self): + config = {"prompt_targets": [{"endpoint": {}}]} + validate_prompt_targets(config, {}) # should not raise + + +class TestValidateTracing: + def test_valid_http_endpoint(self): + result = validate_tracing( + {"random_sampling": 100}, + "http://host.docker.internal:4317", + ) + assert result["opentracing_grpc_endpoint"] == "http://host.docker.internal:4317" + assert result["random_sampling"] == 100 + + def test_invalid_scheme(self): + with pytest.raises(ConfigValidationError, match="scheme must be http"): + validate_tracing( + {"opentracing_grpc_endpoint": "https://example.com:4317"}, + "http://default:4317", + ) + + def test_invalid_path(self): + with pytest.raises(ConfigValidationError, match="path must be empty"): + validate_tracing( + {"opentracing_grpc_endpoint": "http://example.com:4317/some/path"}, + "http://default:4317", + ) + + def test_empty_tracing_uses_default(self): + result = validate_tracing({}, "http://default:4317") + assert result["opentracing_grpc_endpoint"] == "http://default:4317" + + +class TestProcessModelProviders: + def _make_listeners(self, model_providers): + return [{"model_providers": model_providers}] + + def test_happy_path(self): + listeners = self._make_listeners([ + {"model": "openai/gpt-4o", "access_key": "$KEY", "default": True}, + ]) + providers, llms, keys = process_model_providers(listeners, {}) + # Should have the user provider + internal providers + names = [p["name"] for p in providers] + assert "openai/gpt-4o" in names + assert "arch-function" in names + assert "plano-orchestrator" in names + + 
def test_duplicate_provider_name(self): + listeners = self._make_listeners([ + {"name": "test1", "model": "openai/gpt-4o", "access_key": "$KEY"}, + {"name": "test1", "model": "openai/gpt-4o-mini", "access_key": "$KEY"}, + ]) + with pytest.raises(ConfigValidationError, match="Duplicate model_provider name"): + process_model_providers(listeners, {}) + + def test_provider_interface_with_supported_provider(self): + listeners = self._make_listeners([ + { + "model": "openai/gpt-4o", + "access_key": "$KEY", + "provider_interface": "openai", + }, + ]) + with pytest.raises( + ConfigValidationError, + match="provide provider interface as part of model name", + ): + process_model_providers(listeners, {}) + + def test_duplicate_model_id(self): + listeners = self._make_listeners([ + {"model": "openai/gpt-4o", "access_key": "$KEY"}, + {"model": "mistral/gpt-4o"}, + ]) + with pytest.raises(ConfigValidationError, match="Duplicate model_id"): + process_model_providers(listeners, {}) + + def test_custom_provider_requires_base_url(self): + listeners = self._make_listeners([ + {"model": "custom/gpt-4o"}, + ]) + with pytest.raises( + ConfigValidationError, match="Must provide base_url and provider_interface" + ): + process_model_providers(listeners, {}) + + def test_base_url_with_path_prefix(self): + listeners = self._make_listeners([ + { + "model": "custom/gpt-4o", + "base_url": "http://custom.com/api/v2", + "provider_interface": "openai", + }, + ]) + providers, llms, keys = process_model_providers(listeners, {}) + # Find the custom provider + custom = next(p for p in providers if p.get("cluster_name")) + assert custom["base_url_path_prefix"] == "/api/v2" + assert custom["endpoint"] == "custom.com" + assert custom["port"] == 80 + + def test_duplicate_routing_preference_name(self): + listeners = self._make_listeners([ + {"model": "openai/gpt-4o-mini", "access_key": "$KEY", "default": True}, + { + "model": "openai/gpt-4o", + "access_key": "$KEY", + "routing_preferences": [ + {"name": 
"code understanding", "description": "explains code"}, + ], + }, + { + "model": "openai/gpt-4.1", + "access_key": "$KEY", + "routing_preferences": [ + {"name": "code understanding", "description": "generates code"}, + ], + }, + ]) + with pytest.raises( + ConfigValidationError, match="Duplicate routing preference name" + ): + process_model_providers(listeners, {}) + + def test_wildcard_cannot_be_default(self): + listeners = self._make_listeners([ + {"model": "openai/*", "access_key": "$KEY", "default": True}, + ]) + with pytest.raises(ConfigValidationError, match="Default models cannot be wildcards"): + process_model_providers(listeners, {}) + + def test_invalid_model_name_format(self): + listeners = self._make_listeners([ + {"model": "gpt-4o", "access_key": "$KEY"}, + ]) + with pytest.raises(ConfigValidationError, match="Invalid model name"): + process_model_providers(listeners, {}) + + def test_internal_providers_always_added(self): + listeners = self._make_listeners([ + {"model": "openai/gpt-4o", "access_key": "$KEY"}, + ]) + providers, _, _ = process_model_providers(listeners, {}) + names = [p["name"] for p in providers] + assert "arch-function" in names + assert "plano-orchestrator" in names + + def test_arch_router_added_when_routing_preferences_exist(self): + listeners = self._make_listeners([ + { + "model": "openai/gpt-4o", + "access_key": "$KEY", + "routing_preferences": [ + {"name": "coding", "description": "code tasks"}, + ], + }, + ]) + providers, _, _ = process_model_providers(listeners, {}) + names = [p["name"] for p in providers] + assert "arch-router" in names + + def test_skips_listeners_without_model_providers(self): + listeners = [ + {"name": "agent_listener", "type": "agent"}, + {"model_providers": [{"model": "openai/gpt-4o", "access_key": "$KEY"}]}, + ] + providers, _, _ = process_model_providers(listeners, {}) + names = [p["name"] for p in providers] + assert "openai/gpt-4o" in names + + +class TestValidateListeners: + def 
test_single_listener_with_providers(self): + listeners = [ + {"model_providers": [{"model": "openai/gpt-4o"}]}, + {"name": "agent_listener"}, + ] + validate_listeners(listeners) # should not raise + + def test_multiple_listeners_with_providers(self): + listeners = [ + {"model_providers": [{"model": "openai/gpt-4o"}]}, + {"model_providers": [{"model": "anthropic/claude-3"}]}, + ] + with pytest.raises(ConfigValidationError, match="not both"): + validate_listeners(listeners) + + +class TestValidateModelAliases: + def test_valid_alias(self): + validate_model_aliases( + {"fast": {"target": "gpt-4o"}}, + {"gpt-4o", "gpt-4o-mini"}, + ) + + def test_invalid_target(self): + with pytest.raises(ConfigValidationError, match="not defined as a model"): + validate_model_aliases( + {"fast": {"target": "nonexistent"}}, + {"gpt-4o"}, + ) + + def test_no_debug_artifact_in_error(self): + """Regression test: old code had 'Model alias 2 -' debug text.""" + with pytest.raises(ConfigValidationError) as exc_info: + validate_model_aliases({"fast": {"target": "bad"}}, {"gpt-4o"}) + assert "2 -" not in str(exc_info.value) + + +class TestResolveAgentOrchestrator: + def test_not_enabled(self): + result = resolve_agent_orchestrator({}, {}) + assert result is None + + def test_enabled_with_single_endpoint(self): + config = {"overrides": {"use_agent_orchestrator": True}} + endpoints = {"my_agent": {"endpoint": "localhost", "port": 8000}} + result = resolve_agent_orchestrator(config, endpoints) + assert result == "my_agent" + + def test_enabled_with_no_endpoints(self): + config = {"overrides": {"use_agent_orchestrator": True}} + with pytest.raises(ConfigValidationError, match="provide agent orchestrator"): + resolve_agent_orchestrator(config, {}) + + def test_enabled_with_multiple_endpoints(self): + config = {"overrides": {"use_agent_orchestrator": True}} + endpoints = {"a": {}, "b": {}} + with pytest.raises(ConfigValidationError, match="single agent orchestrator"): + 
resolve_agent_orchestrator(config, endpoints) + + +class TestValidateSchema: + @pytest.fixture + def schema(self): + schema_path = os.path.join( + os.path.dirname(__file__), "..", "..", "config", "arch_config_schema.yaml" + ) + with open(schema_path) as f: + return yaml.safe_load(f.read()) + + def test_valid_config(self, schema): + config = { + "version": "v0.1.0", + "listeners": { + "egress_traffic": {"port": 12000}, + }, + } + validate_schema(config, schema) # should not raise + + def test_invalid_config(self, schema): + config = {"invalid_key": "bad"} + with pytest.raises(ConfigValidationError, match="Schema validation failed"): + validate_schema(config, schema) + + +# --------------------------------------------------------------------------- +# Layer 3: Integration tests +# --------------------------------------------------------------------------- + + +class TestValidateAndRenderSchema: + """Integration tests that exercise the full pipeline.""" + + @pytest.fixture + def schema_content(self): + schema_path = os.path.join( + os.path.dirname(__file__), "..", "..", "config", "arch_config_schema.yaml" + ) + with open(schema_path) as f: + return f.read() + + def test_happy_path_legacy_format(self, tmp_path, schema_content, monkeypatch): + config_content = """\ version: v0.1.0 listeners: @@ -30,7 +478,6 @@ def test_validate_and_render_happy_path(monkeypatch): timeout: 30s llm_providers: - - model: openai/gpt-4o-mini access_key: $OPENAI_API_KEY default: true @@ -39,50 +486,39 @@ def test_validate_and_render_happy_path(monkeypatch): access_key: $OPENAI_API_KEY routing_preferences: - name: code understanding - description: understand and explain existing code snippets, functions, or libraries + description: understand and explain code - model: openai/gpt-4.1 access_key: $OPENAI_API_KEY routing_preferences: - name: code generation - description: generating new code snippets, functions, or boilerplate based on user prompts or requirements + description: generate new code 
tracing: random_sampling: 100 """ - arch_config_schema = "" - with open("../config/arch_config_schema.yaml", "r") as file: - arch_config_schema = file.read() - - m_open = mock.mock_open() - # Provide enough file handles for all open() calls in validate_and_render_schema - m_open.side_effect = [ - # Removed empty read - was causing validation failures - mock.mock_open(read_data=arch_config).return_value, # ARCH_CONFIG_FILE - mock.mock_open( - read_data=arch_config_schema - ).return_value, # ARCH_CONFIG_SCHEMA_FILE - mock.mock_open(read_data=arch_config).return_value, # ARCH_CONFIG_FILE - mock.mock_open( - read_data=arch_config_schema - ).return_value, # ARCH_CONFIG_SCHEMA_FILE - mock.mock_open().return_value, # ENVOY_CONFIG_FILE_RENDERED (write) - mock.mock_open().return_value, # ARCH_CONFIG_FILE_RENDERED (write) - ] - with mock.patch("builtins.open", m_open): - with mock.patch("planoai.config_generator.Environment"): + config_file = tmp_path / "config.yaml" + config_file.write_text(config_content) + schema_file = tmp_path / "schema.yaml" + schema_file.write_text(schema_content) + envoy_out = tmp_path / "envoy.yaml" + config_out = tmp_path / "config_rendered.yaml" + + monkeypatch.setenv("ARCH_CONFIG_FILE", str(config_file)) + monkeypatch.setenv("ARCH_CONFIG_SCHEMA_FILE", str(schema_file)) + monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", str(envoy_out)) + monkeypatch.setenv("ARCH_CONFIG_FILE_RENDERED", str(config_out)) + + mock_env = mock.patch("planoai.config_generator.Environment") + with mock_env as MockEnv: + mock_template = MockEnv.return_value.get_template.return_value + mock_template.render.return_value = "# rendered envoy config" validate_and_render_schema() + assert config_out.exists() -def test_validate_and_render_happy_path_agent_config(monkeypatch): - monkeypatch.setenv("ARCH_CONFIG_FILE", "fake_arch_config.yaml") - monkeypatch.setenv("ARCH_CONFIG_SCHEMA_FILE", "fake_arch_config_schema.yaml") - monkeypatch.setenv("ENVOY_CONFIG_TEMPLATE_FILE", 
"./envoy.template.yaml") - monkeypatch.setenv("ARCH_CONFIG_FILE_RENDERED", "fake_arch_config_rendered.yaml") - monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", "fake_envoy.yaml") - monkeypatch.setenv("TEMPLATE_ROOT", "../") - - arch_config = """ + def test_happy_path_agent_config(self, tmp_path, schema_content, monkeypatch): + config_content = """\ version: v0.3.0 agents: @@ -123,345 +559,132 @@ def test_validate_and_render_happy_path_agent_config(monkeypatch): - access_key: ${OPENAI_API_KEY} model: openai/gpt-4o """ - arch_config_schema = "" - with open("../config/arch_config_schema.yaml", "r") as file: - arch_config_schema = file.read() - - m_open = mock.mock_open() - # Provide enough file handles for all open() calls in validate_and_render_schema - m_open.side_effect = [ - # Removed empty read - was causing validation failures - mock.mock_open(read_data=arch_config).return_value, # ARCH_CONFIG_FILE - mock.mock_open( - read_data=arch_config_schema - ).return_value, # ARCH_CONFIG_SCHEMA_FILE - mock.mock_open(read_data=arch_config).return_value, # ARCH_CONFIG_FILE - mock.mock_open( - read_data=arch_config_schema - ).return_value, # ARCH_CONFIG_SCHEMA_FILE - mock.mock_open().return_value, # ENVOY_CONFIG_FILE_RENDERED (write) - mock.mock_open().return_value, # ARCH_CONFIG_FILE_RENDERED (write) - ] - with mock.patch("builtins.open", m_open): - with mock.patch("planoai.config_generator.Environment"): + config_file = tmp_path / "config.yaml" + config_file.write_text(config_content) + schema_file = tmp_path / "schema.yaml" + schema_file.write_text(schema_content) + envoy_out = tmp_path / "envoy.yaml" + config_out = tmp_path / "config_rendered.yaml" + + monkeypatch.setenv("ARCH_CONFIG_FILE", str(config_file)) + monkeypatch.setenv("ARCH_CONFIG_SCHEMA_FILE", str(schema_file)) + monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", str(envoy_out)) + monkeypatch.setenv("ARCH_CONFIG_FILE_RENDERED", str(config_out)) + + mock_env = mock.patch("planoai.config_generator.Environment") + 
with mock_env as MockEnv: + mock_template = MockEnv.return_value.get_template.return_value + mock_template.render.return_value = "# rendered envoy config" validate_and_render_schema() + assert config_out.exists() -arch_config_test_cases = [ - { - "id": "duplicate_provider_name", - "expected_error": "Duplicate model_provider name", - "arch_config": """ -version: v0.1.0 -listeners: - egress_traffic: - address: 0.0.0.0 - port: 12000 - message_format: openai - timeout: 30s - -llm_providers: - - - name: test1 - model: openai/gpt-4o - access_key: $OPENAI_API_KEY - - - name: test1 - model: openai/gpt-4o - access_key: $OPENAI_API_KEY - -""", - }, - { - "id": "provider_interface_with_model_id", - "expected_error": "Please provide provider interface as part of model name", - "arch_config": """ -version: v0.1.0 - -listeners: - egress_traffic: - address: 0.0.0.0 - port: 12000 - message_format: openai - timeout: 30s - -llm_providers: - - - model: openai/gpt-4o - access_key: $OPENAI_API_KEY - provider_interface: openai - -""", - }, - { - "id": "duplicate_model_id", - "expected_error": "Duplicate model_id", - "arch_config": """ -version: v0.1.0 +# --------------------------------------------------------------------------- +# Layer 4: Legacy listener conversion tests (unchanged behavior) +# --------------------------------------------------------------------------- -listeners: - egress_traffic: - address: 0.0.0.0 - port: 12000 - message_format: openai - timeout: 30s - -llm_providers: - - - model: openai/gpt-4o - access_key: $OPENAI_API_KEY - - - model: mistral/gpt-4o - -""", - }, - { - "id": "custom_provider_base_url", - "expected_error": "Must provide base_url and provider_interface", - "arch_config": """ -version: v0.1.0 - -listeners: - egress_traffic: - address: 0.0.0.0 - port: 12000 - message_format: openai - timeout: 30s - -llm_providers: - - model: custom/gpt-4o - -""", - }, - { - "id": "base_url_with_path_prefix", - "expected_error": None, - "arch_config": """ -version: 
v0.1.0 - -listeners: - egress_traffic: - address: 0.0.0.0 - port: 12000 - message_format: openai - timeout: 30s - -llm_providers: - - - model: custom/gpt-4o - base_url: "http://custom.com/api/v2" - provider_interface: openai - -""", - }, - { - "id": "duplicate_routeing_preference_name", - "expected_error": "Duplicate routing preference name", - "arch_config": """ -version: v0.1.0 - -listeners: - egress_traffic: - address: 0.0.0.0 - port: 12000 - message_format: openai - timeout: 30s - -llm_providers: - - - model: openai/gpt-4o-mini - access_key: $OPENAI_API_KEY - default: true - - - model: openai/gpt-4o - access_key: $OPENAI_API_KEY - routing_preferences: - - name: code understanding - description: understand and explain existing code snippets, functions, or libraries - - - model: openai/gpt-4.1 - access_key: $OPENAI_API_KEY - routing_preferences: - - name: code understanding - description: generating new code snippets, functions, or boilerplate based on user prompts or requirements - -tracing: - random_sampling: 100 - -""", - }, -] +class TestConvertLegacyListeners: + def test_dict_format_with_both_listeners(self): + listeners = { + "ingress_traffic": { + "address": "0.0.0.0", + "port": 10000, + "timeout": "30s", + }, + "egress_traffic": { + "address": "0.0.0.0", + "port": 12000, + "timeout": "30s", + }, + } + llm_providers = [ + {"model": "openai/gpt-4o", "access_key": "test_key"}, + ] + + updated, llm_gateway, prompt_gateway = convert_legacy_listeners( + listeners, llm_providers + ) + assert isinstance(updated, list) + assert llm_gateway is not None + assert prompt_gateway is not None + assert updated == [ + { + "name": "egress_traffic", + "type": "model_listener", + "port": 12000, + "address": "0.0.0.0", + "timeout": "30s", + "model_providers": [ + {"model": "openai/gpt-4o", "access_key": "test_key"} + ], + }, + { + "name": "ingress_traffic", + "type": "prompt_listener", + "port": 10000, + "address": "0.0.0.0", + "timeout": "30s", + }, + ] 
-@pytest.mark.parametrize( - "arch_config_test_case", - arch_config_test_cases, - ids=[case["id"] for case in arch_config_test_cases], -) -def test_validate_and_render_schema_tests(monkeypatch, arch_config_test_case): - monkeypatch.setenv("ARCH_CONFIG_FILE", "fake_arch_config.yaml") - monkeypatch.setenv("ARCH_CONFIG_SCHEMA_FILE", "fake_arch_config_schema.yaml") - monkeypatch.setenv("ENVOY_CONFIG_TEMPLATE_FILE", "./envoy.template.yaml") - monkeypatch.setenv("ARCH_CONFIG_FILE_RENDERED", "fake_arch_config_rendered.yaml") - monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", "fake_envoy.yaml") - monkeypatch.setenv("TEMPLATE_ROOT", "../") - - arch_config = arch_config_test_case["arch_config"] - expected_error = arch_config_test_case.get("expected_error") - - arch_config_schema = "" - with open("../config/arch_config_schema.yaml", "r") as file: - arch_config_schema = file.read() - - m_open = mock.mock_open() - # Provide enough file handles for all open() calls in validate_and_render_schema - m_open.side_effect = [ - mock.mock_open( - read_data=arch_config - ).return_value, # validate_prompt_config: ARCH_CONFIG_FILE - mock.mock_open( - read_data=arch_config_schema - ).return_value, # validate_prompt_config: ARCH_CONFIG_SCHEMA_FILE - mock.mock_open( - read_data=arch_config - ).return_value, # validate_and_render_schema: ARCH_CONFIG_FILE - mock.mock_open( - read_data=arch_config_schema - ).return_value, # validate_and_render_schema: ARCH_CONFIG_SCHEMA_FILE - mock.mock_open().return_value, # ENVOY_CONFIG_FILE_RENDERED (write) - mock.mock_open().return_value, # ARCH_CONFIG_FILE_RENDERED (write) - ] - with mock.patch("builtins.open", m_open): - with mock.patch("planoai.config_generator.Environment"): - if expected_error: - # Test expects an error - with pytest.raises(Exception) as excinfo: - validate_and_render_schema() - assert expected_error in str(excinfo.value) - else: - # Test expects success - no exception should be raised - validate_and_render_schema() - - -def 
test_convert_legacy_llm_providers(): - from planoai.utils import convert_legacy_listeners - - listeners = { - "ingress_traffic": { - "address": "0.0.0.0", - "port": 10000, - "timeout": "30s", - }, - "egress_traffic": { + assert llm_gateway == { "address": "0.0.0.0", - "port": 12000, - "timeout": "30s", - }, - } - llm_providers = [ - { - "model": "openai/gpt-4o", - "access_key": "test_key", - } - ] - - updated_providers, llm_gateway, prompt_gateway = convert_legacy_listeners( - listeners, llm_providers - ) - assert isinstance(updated_providers, list) - assert llm_gateway is not None - assert prompt_gateway is not None - print(json.dumps(updated_providers)) - assert updated_providers == [ - { + "model_providers": [ + {"access_key": "test_key", "model": "openai/gpt-4o"}, + ], "name": "egress_traffic", "type": "model_listener", "port": 12000, - "address": "0.0.0.0", - "timeout": "30s", - "model_providers": [{"model": "openai/gpt-4o", "access_key": "test_key"}], - }, - { - "name": "ingress_traffic", - "type": "prompt_listener", - "port": 10000, - "address": "0.0.0.0", "timeout": "30s", - }, - ] + } - assert llm_gateway == { - "address": "0.0.0.0", - "model_providers": [ - { - "access_key": "test_key", - "model": "openai/gpt-4o", - }, - ], - "name": "egress_traffic", - "type": "model_listener", - "port": 12000, - "timeout": "30s", - } - - assert prompt_gateway == { - "address": "0.0.0.0", - "name": "ingress_traffic", - "port": 10000, - "timeout": "30s", - "type": "prompt_listener", - } - - -def test_convert_legacy_llm_providers_no_prompt_gateway(): - from planoai.utils import convert_legacy_listeners - - listeners = { - "egress_traffic": { + assert prompt_gateway == { "address": "0.0.0.0", - "port": 12000, + "name": "ingress_traffic", + "port": 10000, "timeout": "30s", + "type": "prompt_listener", } - } - llm_providers = [ - { - "model": "openai/gpt-4o", - "access_key": "test_key", + + def test_dict_format_no_prompt_gateway(self): + listeners = { + "egress_traffic": { + 
"address": "0.0.0.0", + "port": 12000, + "timeout": "30s", + } } - ] - - updated_providers, llm_gateway, prompt_gateway = convert_legacy_listeners( - listeners, llm_providers - ) - assert isinstance(updated_providers, list) - assert llm_gateway is not None - assert prompt_gateway is not None - assert updated_providers == [ - { + llm_providers = [ + {"model": "openai/gpt-4o", "access_key": "test_key"}, + ] + + updated, llm_gateway, prompt_gateway = convert_legacy_listeners( + listeners, llm_providers + ) + assert isinstance(updated, list) + assert llm_gateway is not None + assert prompt_gateway is not None + assert updated == [ + { + "address": "0.0.0.0", + "model_providers": [ + {"access_key": "test_key", "model": "openai/gpt-4o"}, + ], + "name": "egress_traffic", + "port": 12000, + "timeout": "30s", + "type": "model_listener", + } + ] + assert llm_gateway == { "address": "0.0.0.0", "model_providers": [ - { - "access_key": "test_key", - "model": "openai/gpt-4o", - }, + {"access_key": "test_key", "model": "openai/gpt-4o"}, ], "name": "egress_traffic", + "type": "model_listener", "port": 12000, "timeout": "30s", - "type": "model_listener", } - ] - assert llm_gateway == { - "address": "0.0.0.0", - "model_providers": [ - { - "access_key": "test_key", - "model": "openai/gpt-4o", - }, - ], - "name": "egress_traffic", - "type": "model_listener", - "port": 12000, - "timeout": "30s", - } diff --git a/config/docker-compose.dev.yaml b/config/docker-compose.dev.yaml index 2e061939f..ec0fe6036 100644 --- a/config/docker-compose.dev.yaml +++ b/config/docker-compose.dev.yaml @@ -13,6 +13,8 @@ services: - ./envoy.template.yaml:/app/envoy.template.yaml - ./arch_config_schema.yaml:/app/arch_config_schema.yaml - ../cli/planoai/config_generator.py:/app/planoai/config_generator.py + - ../cli/planoai/config_validator.py:/app/planoai/config_validator.py + - ../cli/planoai/config_providers.py:/app/planoai/config_providers.py - 
../crates/target/wasm32-wasip1/release/llm_gateway.wasm:/etc/envoy/proxy-wasm-plugins/llm_gateway.wasm - ../crates/target/wasm32-wasip1/release/prompt_gateway.wasm:/etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm - ~/archgw_logs:/var/log/ From 0a84239bbf42af001c897ec2fda95779aa82949c Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 12 Feb 2026 05:34:17 +0000 Subject: [PATCH 2/6] Update uv.lock after dependency sync Co-Authored-By: Claude Opus 4.6 --- cli/uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/uv.lock b/cli/uv.lock index d7e6b3a0c..f8f72721e 100644 --- a/cli/uv.lock +++ b/cli/uv.lock @@ -337,7 +337,7 @@ wheels = [ [[package]] name = "planoai" -version = "0.4.4" +version = "0.4.6" source = { editable = "." } dependencies = [ { name = "click" }, From df2f1c7f29a28aaf7fbefbaa3a144c7682974cff Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 12 Feb 2026 05:44:36 +0000 Subject: [PATCH 3/6] Fix trailing newlines in utils.py Co-Authored-By: Claude Opus 4.6 --- cli/planoai/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/cli/planoai/utils.py b/cli/planoai/utils.py index 3917b9f88..d673634a6 100644 --- a/cli/planoai/utils.py +++ b/cli/planoai/utils.py @@ -247,5 +247,3 @@ def find_config_file(path=".", file=None): if not os.path.exists(arch_config_file): arch_config_file = os.path.abspath(os.path.join(path, "arch_config.yaml")) return arch_config_file - - From 9858ce4d31cc16a0e299d1b9a03a7886410136e9 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 12 Feb 2026 06:01:30 +0000 Subject: [PATCH 4/6] Upgrade black 23.1.0 -> 25.1.0 for Python 3.14 compatibility black 23.1.0 uses ast.Str which was removed in Python 3.14, causing pre-commit CI to crash on any PR that touches Python files. 
Co-Authored-By: Claude Opus 4.6 --- .pre-commit-config.yaml | 2 +- cli/planoai/config_generator.py | 4 +- cli/planoai/config_validator.py | 4 +- cli/planoai/main.py | 6 +- cli/planoai/targets.py | 18 ++-- cli/test/test_config_generator.py | 166 +++++++++++++++++------------- 6 files changed, 112 insertions(+), 88 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 42b439433..6373e73b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: pass_filenames: false - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 25.1.0 hooks: - id: black language_version: python3 diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index e7142a676..fd1bae5a7 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -102,9 +102,7 @@ def validate_and_render_schema(): if "model_aliases" in config: validate_model_aliases(config["model_aliases"], model_name_keys) - agent_orchestrator = resolve_agent_orchestrator( - config, config.get("endpoints", {}) - ) + agent_orchestrator = resolve_agent_orchestrator(config, config.get("endpoints", {})) data = build_template_data( prompt_gateway, diff --git a/cli/planoai/config_validator.py b/cli/planoai/config_validator.py index a5a0b0e80..8256d495c 100644 --- a/cli/planoai/config_validator.py +++ b/cli/planoai/config_validator.py @@ -435,9 +435,7 @@ def resolve_agent_orchestrator(config, endpoints): Returns the orchestrator endpoint name, or None if not configured. 
""" - use_orchestrator = config.get("overrides", {}).get( - "use_agent_orchestrator", False - ) + use_orchestrator = config.get("overrides", {}).get("use_agent_orchestrator", False) if not use_orchestrator: return None diff --git a/cli/planoai/main.py b/cli/planoai/main.py index 7e2301026..1b4aa3e27 100644 --- a/cli/planoai/main.py +++ b/cli/planoai/main.py @@ -296,9 +296,9 @@ def up(file, path, foreground, with_tracing, tracing_port): sys.exit(1) # Update the OTEL endpoint so the gateway sends traces to the right port - env_stage[ - "OTEL_TRACING_GRPC_ENDPOINT" - ] = f"http://host.docker.internal:{tracing_port}" + env_stage["OTEL_TRACING_GRPC_ENDPOINT"] = ( + f"http://host.docker.internal:{tracing_port}" + ) env.update(env_stage) try: diff --git a/cli/planoai/targets.py b/cli/planoai/targets.py index 73b14b36e..7c56f2b72 100644 --- a/cli/planoai/targets.py +++ b/cli/planoai/targets.py @@ -189,26 +189,26 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list: if isinstance( arg.annotation.value, ast.Name ) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]: - param_info[ - "type" - ] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc. + param_info["type"] = ( + f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc. 
+ ) else: param_info["type"] = "[UNKNOWN - PLEASE FIX]" # Default for unknown types else: - param_info[ - "type" - ] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type + param_info["type"] = ( + "[UNKNOWN - PLEASE FIX]" # If unable to detect type + ) # Handle default values if default is not None: if isinstance(default, ast.Constant) or isinstance( default, ast.NameConstant ): - param_info[ - "default" - ] = default.value # Use the default value directly + param_info["default"] = ( + default.value + ) # Use the default value directly else: param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type param_info["required"] = False # Optional since it has a default value diff --git a/cli/test/test_config_generator.py b/cli/test/test_config_generator.py index c373f990d..bd7007f96 100644 --- a/cli/test/test_config_generator.py +++ b/cli/test/test_config_generator.py @@ -154,16 +154,18 @@ def test_merge_agent_and_explicit_endpoints(self): "agent1": {"endpoint": "localhost", "port": 8000, "protocol": "http"} } endpoints = { - "explicit1": {"endpoint": "api.example.com", "port": 443, "protocol": "https"} + "explicit1": { + "endpoint": "api.example.com", + "port": 443, + "protocol": "https", + } } result = build_clusters(endpoints, agent_inferred) assert "agent1" in result assert "explicit1" in result def test_infer_port_from_host_colon_port(self): - clusters = build_clusters( - {"svc": {"endpoint": "localhost:9090"}}, {} - ) + clusters = build_clusters({"svc": {"endpoint": "localhost:9090"}}, {}) assert clusters["svc"]["endpoint"] == "localhost" assert clusters["svc"]["port"] == 9090 @@ -231,9 +233,11 @@ def _make_listeners(self, model_providers): return [{"model_providers": model_providers}] def test_happy_path(self): - listeners = self._make_listeners([ - {"model": "openai/gpt-4o", "access_key": "$KEY", "default": True}, - ]) + listeners = self._make_listeners( + [ + {"model": "openai/gpt-4o", "access_key": "$KEY", "default": True}, + ] + ) providers, llms, 
keys = process_model_providers(listeners, {}) # Should have the user provider + internal providers names = [p["name"] for p in providers] @@ -242,21 +246,27 @@ def test_happy_path(self): assert "plano-orchestrator" in names def test_duplicate_provider_name(self): - listeners = self._make_listeners([ - {"name": "test1", "model": "openai/gpt-4o", "access_key": "$KEY"}, - {"name": "test1", "model": "openai/gpt-4o-mini", "access_key": "$KEY"}, - ]) - with pytest.raises(ConfigValidationError, match="Duplicate model_provider name"): + listeners = self._make_listeners( + [ + {"name": "test1", "model": "openai/gpt-4o", "access_key": "$KEY"}, + {"name": "test1", "model": "openai/gpt-4o-mini", "access_key": "$KEY"}, + ] + ) + with pytest.raises( + ConfigValidationError, match="Duplicate model_provider name" + ): process_model_providers(listeners, {}) def test_provider_interface_with_supported_provider(self): - listeners = self._make_listeners([ - { - "model": "openai/gpt-4o", - "access_key": "$KEY", - "provider_interface": "openai", - }, - ]) + listeners = self._make_listeners( + [ + { + "model": "openai/gpt-4o", + "access_key": "$KEY", + "provider_interface": "openai", + }, + ] + ) with pytest.raises( ConfigValidationError, match="provide provider interface as part of model name", @@ -264,30 +274,36 @@ def test_provider_interface_with_supported_provider(self): process_model_providers(listeners, {}) def test_duplicate_model_id(self): - listeners = self._make_listeners([ - {"model": "openai/gpt-4o", "access_key": "$KEY"}, - {"model": "mistral/gpt-4o"}, - ]) + listeners = self._make_listeners( + [ + {"model": "openai/gpt-4o", "access_key": "$KEY"}, + {"model": "mistral/gpt-4o"}, + ] + ) with pytest.raises(ConfigValidationError, match="Duplicate model_id"): process_model_providers(listeners, {}) def test_custom_provider_requires_base_url(self): - listeners = self._make_listeners([ - {"model": "custom/gpt-4o"}, - ]) + listeners = self._make_listeners( + [ + {"model": 
"custom/gpt-4o"}, + ] + ) with pytest.raises( ConfigValidationError, match="Must provide base_url and provider_interface" ): process_model_providers(listeners, {}) def test_base_url_with_path_prefix(self): - listeners = self._make_listeners([ - { - "model": "custom/gpt-4o", - "base_url": "http://custom.com/api/v2", - "provider_interface": "openai", - }, - ]) + listeners = self._make_listeners( + [ + { + "model": "custom/gpt-4o", + "base_url": "http://custom.com/api/v2", + "provider_interface": "openai", + }, + ] + ) providers, llms, keys = process_model_providers(listeners, {}) # Find the custom provider custom = next(p for p in providers if p.get("cluster_name")) @@ -296,61 +312,73 @@ def test_base_url_with_path_prefix(self): assert custom["port"] == 80 def test_duplicate_routing_preference_name(self): - listeners = self._make_listeners([ - {"model": "openai/gpt-4o-mini", "access_key": "$KEY", "default": True}, - { - "model": "openai/gpt-4o", - "access_key": "$KEY", - "routing_preferences": [ - {"name": "code understanding", "description": "explains code"}, - ], - }, - { - "model": "openai/gpt-4.1", - "access_key": "$KEY", - "routing_preferences": [ - {"name": "code understanding", "description": "generates code"}, - ], - }, - ]) + listeners = self._make_listeners( + [ + {"model": "openai/gpt-4o-mini", "access_key": "$KEY", "default": True}, + { + "model": "openai/gpt-4o", + "access_key": "$KEY", + "routing_preferences": [ + {"name": "code understanding", "description": "explains code"}, + ], + }, + { + "model": "openai/gpt-4.1", + "access_key": "$KEY", + "routing_preferences": [ + {"name": "code understanding", "description": "generates code"}, + ], + }, + ] + ) with pytest.raises( ConfigValidationError, match="Duplicate routing preference name" ): process_model_providers(listeners, {}) def test_wildcard_cannot_be_default(self): - listeners = self._make_listeners([ - {"model": "openai/*", "access_key": "$KEY", "default": True}, - ]) - with 
pytest.raises(ConfigValidationError, match="Default models cannot be wildcards"): + listeners = self._make_listeners( + [ + {"model": "openai/*", "access_key": "$KEY", "default": True}, + ] + ) + with pytest.raises( + ConfigValidationError, match="Default models cannot be wildcards" + ): process_model_providers(listeners, {}) def test_invalid_model_name_format(self): - listeners = self._make_listeners([ - {"model": "gpt-4o", "access_key": "$KEY"}, - ]) + listeners = self._make_listeners( + [ + {"model": "gpt-4o", "access_key": "$KEY"}, + ] + ) with pytest.raises(ConfigValidationError, match="Invalid model name"): process_model_providers(listeners, {}) def test_internal_providers_always_added(self): - listeners = self._make_listeners([ - {"model": "openai/gpt-4o", "access_key": "$KEY"}, - ]) + listeners = self._make_listeners( + [ + {"model": "openai/gpt-4o", "access_key": "$KEY"}, + ] + ) providers, _, _ = process_model_providers(listeners, {}) names = [p["name"] for p in providers] assert "arch-function" in names assert "plano-orchestrator" in names def test_arch_router_added_when_routing_preferences_exist(self): - listeners = self._make_listeners([ - { - "model": "openai/gpt-4o", - "access_key": "$KEY", - "routing_preferences": [ - {"name": "coding", "description": "code tasks"}, - ], - }, - ]) + listeners = self._make_listeners( + [ + { + "model": "openai/gpt-4o", + "access_key": "$KEY", + "routing_preferences": [ + {"name": "coding", "description": "code tasks"}, + ], + }, + ] + ) providers, _, _ = process_model_providers(listeners, {}) names = [p["name"] for p in providers] assert "arch-router" in names From c422f06d00211da5c9e8920a155684a6ee1f4d27 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 12 Feb 2026 06:09:23 +0000 Subject: [PATCH 5/6] Reformat remaining Python files for black 25.1.0 Co-Authored-By: Claude Opus 4.6 --- .../crewai/flight_agent.py | 16 ++++++++++------ .../src/travel_agents/flight_agent.py | 16 ++++++++++------ 
docs/source/_ext/provider_models.py | 1 + docs/source/resources/includes/agents/weather.py | 12 ++++++------ 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/demos/use_cases/multi_agent_with_crewai_langchain/crewai/flight_agent.py b/demos/use_cases/multi_agent_with_crewai_langchain/crewai/flight_agent.py index bfff06de9..1988f860c 100644 --- a/demos/use_cases/multi_agent_with_crewai_langchain/crewai/flight_agent.py +++ b/demos/use_cases/multi_agent_with_crewai_langchain/crewai/flight_agent.py @@ -287,12 +287,16 @@ async def fetch_flights( "flight_number": flight.get("ident_iata") or flight.get("ident"), "departure_time": flight.get("scheduled_out"), "arrival_time": flight.get("scheduled_in"), - "origin": flight["origin"].get("code_iata") - if isinstance(flight.get("origin"), dict) - else None, - "destination": flight["destination"].get("code_iata") - if isinstance(flight.get("destination"), dict) - else None, + "origin": ( + flight["origin"].get("code_iata") + if isinstance(flight.get("origin"), dict) + else None + ), + "destination": ( + flight["destination"].get("code_iata") + if isinstance(flight.get("destination"), dict) + else None + ), "aircraft_type": flight.get("aircraft_type"), "status": flight.get("status"), "terminal_origin": flight.get("terminal_origin"), diff --git a/demos/use_cases/travel_agents/src/travel_agents/flight_agent.py b/demos/use_cases/travel_agents/src/travel_agents/flight_agent.py index f1e22266d..21e9bf2aa 100644 --- a/demos/use_cases/travel_agents/src/travel_agents/flight_agent.py +++ b/demos/use_cases/travel_agents/src/travel_agents/flight_agent.py @@ -197,12 +197,16 @@ async def fetch_flights( "flight_number": flight.get("ident_iata") or flight.get("ident"), "departure_time": flight.get("scheduled_out"), "arrival_time": flight.get("scheduled_in"), - "origin": flight["origin"].get("code_iata") - if isinstance(flight.get("origin"), dict) - else None, - "destination": flight["destination"].get("code_iata") - if 
isinstance(flight.get("destination"), dict) - else None, + "origin": ( + flight["origin"].get("code_iata") + if isinstance(flight.get("origin"), dict) + else None + ), + "destination": ( + flight["destination"].get("code_iata") + if isinstance(flight.get("destination"), dict) + else None + ), "aircraft_type": flight.get("aircraft_type"), "status": flight.get("status"), "terminal_origin": flight.get("terminal_origin"), diff --git a/docs/source/_ext/provider_models.py b/docs/source/_ext/provider_models.py index 9b7451c53..cad967ac2 100644 --- a/docs/source/_ext/provider_models.py +++ b/docs/source/_ext/provider_models.py @@ -1,4 +1,5 @@ """Sphinx extension to copy provider_models.yaml to build output.""" + from __future__ import annotations from pathlib import Path diff --git a/docs/source/resources/includes/agents/weather.py b/docs/source/resources/includes/agents/weather.py index 023512347..ad51d2dc8 100644 --- a/docs/source/resources/includes/agents/weather.py +++ b/docs/source/resources/includes/agents/weather.py @@ -230,12 +230,12 @@ async def get_weather_data(request: Request, messages: list, days: int = 1): "day_name": date_obj.strftime("%A"), "temperature_c": round(temp_c, 1) if temp_c is not None else None, "temperature_f": celsius_to_fahrenheit(temp_c), - "temperature_max_c": round(temp_max, 1) - if temp_max is not None - else None, - "temperature_min_c": round(temp_min, 1) - if temp_min is not None - else None, + "temperature_max_c": ( + round(temp_max, 1) if temp_max is not None else None + ), + "temperature_min_c": ( + round(temp_min, 1) if temp_min is not None else None + ), "weather_code": weather_code, "sunrise": sunrise.split("T")[1] if sunrise else None, "sunset": sunset.split("T")[1] if sunset else None, From f427c6b40836423610b690c8d6943209170488c8 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 12 Feb 2026 06:33:30 +0000 Subject: [PATCH 6/6] Fix rendered config to include merged endpoints and resolved tracing The old code mutated 
config_yaml via dict references (endpoints and tracing dicts were modified in-place), so yaml.dump captured the merged state. The refactored code used separate variables without writing them back to config, causing the rendered YAML to miss agent-inferred endpoints and resolved tracing config. Co-Authored-By: Claude Opus 4.6 --- cli/planoai/config_generator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index fd1bae5a7..df45b529c 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -84,6 +84,7 @@ def validate_and_render_schema(): config.get("agents", []), config.get("filters", []) ) clusters = build_clusters(config.get("endpoints", {}), agent_endpoints) + config["endpoints"] = clusters log.info("Defined clusters: %s", clusters) validate_prompt_targets(config, clusters) @@ -91,6 +92,7 @@ def validate_and_render_schema(): tracing = validate_tracing( config.get("tracing", {}), DEFAULT_OTEL_TRACING_GRPC_ENDPOINT ) + config["tracing"] = tracing updated_providers, llms_with_endpoint, model_name_keys = process_model_providers( listeners, config.get("routing", {})