diff --git a/airbyte/constants.py b/airbyte/constants.py index 4b4d99b1d..50503ff9a 100644 --- a/airbyte/constants.py +++ b/airbyte/constants.py @@ -258,36 +258,68 @@ def _str_to_bool(value: str) -> bool: # MCP (Model Context Protocol) Constants -MCP_TOOL_DOMAINS: list[str] = ["cloud", "local", "registry"] -"""Valid MCP tool domains available in the server. +MCP_READONLY_MODE_ENV_VAR: str = "AIRBYTE_CLOUD_MCP_READONLY_MODE" +"""Environment variable to enable read-only mode for the MCP server. -- `cloud`: Tools for managing Airbyte Cloud resources (sources, destinations, connections) -- `local`: Tools for local operations (connector validation, caching, SQL queries) -- `registry`: Tools for querying the Airbyte connector registry +When set to "1" or "true", only tools with readOnlyHint=True will be available. """ -AIRBYTE_MCP_DOMAINS: list[str] | None = [ - d.strip().lower() for d in os.getenv("AIRBYTE_MCP_DOMAINS", "").split(",") if d.strip() -] or None -"""Enabled MCP tool domains from the `AIRBYTE_MCP_DOMAINS` environment variable. +MCP_DOMAINS_DISABLED_ENV_VAR: str = "AIRBYTE_MCP_DOMAINS_DISABLED" +"""Environment variable to disable specific MCP tool domains. -Accepts a comma-separated list of domain names (e.g., "registry,cloud"). +Accepts a comma-separated list of domain names (e.g., "local,registry"). +Tools from these domains will not be advertised by the MCP server. +""" + +MCP_DOMAINS_ENV_VAR: str = "AIRBYTE_MCP_DOMAINS" +"""Environment variable to enable specific MCP tool domains. + +Accepts a comma-separated list of domain names (e.g., "cloud,registry"). If set, only tools from these domains will be advertised by the MCP server. -If not set (None), all domains are enabled by default. +""" -Values are case-insensitive and whitespace is trimmed. +MCP_WORKSPACE_ID_HEADER: str = "X-Airbyte-Workspace-Id" +"""HTTP header key for passing workspace ID to the MCP server. + +This allows per-request workspace ID configuration when using HTTP transport. """ -AIRBYTE_MCP_DOMAINS_DISABLED: list[str] | None = [ - d.strip().lower() for d in os.getenv("AIRBYTE_MCP_DOMAINS_DISABLED", "").split(",") if d.strip() -] or None -"""Disabled MCP tool domains from the `AIRBYTE_MCP_DOMAINS_DISABLED` environment variable. +# MCP Config Arg Names (used with get_mcp_config) -Accepts a comma-separated list of domain names (e.g., "registry"). -Tools from these domains will not be advertised by the MCP server. +MCP_CONFIG_READONLY_MODE: str = "airbyte_readonly_mode" +"""Config arg name for the legacy AIRBYTE_CLOUD_MCP_READONLY_MODE setting.""" -When both `AIRBYTE_MCP_DOMAINS` and `AIRBYTE_MCP_DOMAINS_DISABLED` are set, -the disabled list takes precedence (subtracts from the enabled list). +MCP_CONFIG_EXCLUDE_MODULES: str = "airbyte_exclude_modules" +"""Config arg name for the legacy AIRBYTE_MCP_DOMAINS_DISABLED setting.""" -Values are case-insensitive and whitespace is trimmed. 
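For reference, a minimal sketch of how these environment variables might be combined when launching the MCP server; the values shown are illustrative only.

import os

# Expose only tools annotated with readOnlyHint=True:
os.environ["AIRBYTE_CLOUD_MCP_READONLY_MODE"] = "1"

# Advertise only the cloud and registry tool domains:
os.environ["AIRBYTE_MCP_DOMAINS"] = "cloud,registry"

# Or, alternatively, keep every domain except local:
os.environ["AIRBYTE_MCP_DOMAINS_DISABLED"] = "local"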
-""" +MCP_CONFIG_INCLUDE_MODULES: str = "airbyte_include_modules" +"""Config arg name for the legacy AIRBYTE_MCP_DOMAINS setting.""" + +MCP_CONFIG_WORKSPACE_ID: str = "workspace_id" +"""Config arg name for the workspace ID setting.""" + +MCP_CONFIG_BEARER_TOKEN: str = "bearer_token" +"""Config arg name for the bearer token setting.""" + +MCP_CONFIG_CLIENT_ID: str = "client_id" +"""Config arg name for the client ID setting.""" + +MCP_CONFIG_CLIENT_SECRET: str = "client_secret" +"""Config arg name for the client secret setting.""" + +MCP_CONFIG_API_URL: str = "api_url" +"""Config arg name for the API URL setting.""" + +# MCP HTTP Header Keys for credentials + +MCP_BEARER_TOKEN_HEADER: str = "Authorization" +"""HTTP header key for bearer token (standard Authorization header).""" + +MCP_CLIENT_ID_HEADER: str = "X-Airbyte-Cloud-Client-Id" +"""HTTP header key for client ID.""" + +MCP_CLIENT_SECRET_HEADER: str = "X-Airbyte-Cloud-Client-Secret" +"""HTTP header key for client secret.""" + +MCP_API_URL_HEADER: str = "X-Airbyte-Cloud-Api-Url" +"""HTTP header key for API URL.""" diff --git a/airbyte/mcp/__init__.py b/airbyte/mcp/__init__.py index d7dcbeb9b..db8086c26 100644 --- a/airbyte/mcp/__init__.py +++ b/airbyte/mcp/__init__.py @@ -210,13 +210,13 @@ """ # noqa: D415 -from airbyte.mcp import cloud_ops, connector_registry, local_ops, server +from airbyte.mcp import cloud, local, registry, server __all__: list[str] = [ - "cloud_ops", - "connector_registry", - "local_ops", + "cloud", + "local", + "registry", "server", ] diff --git a/airbyte/mcp/_annotations.py b/airbyte/mcp/_annotations.py deleted file mode 100644 index 78cd82f05..000000000 --- a/airbyte/mcp/_annotations.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2024 Airbyte, Inc., all rights reserved. -"""MCP tool annotation constants. - -These constants define the standard MCP annotations for tools, following the -FastMCP 2.2.7+ specification. - -For more information, see: -https://gofastmcp.com/concepts/tools#mcp-annotations -""" - -from __future__ import annotations - - -READ_ONLY_HINT = "readOnlyHint" -"""Indicates if the tool only reads data without making any changes. - -When True, the tool performs read-only operations and does not modify any state. -When False, the tool may write, create, update, or delete data. - -FastMCP default if not specified: False -""" - -DESTRUCTIVE_HINT = "destructiveHint" -"""Signals if the tool's changes are destructive (updates or deletes existing data). - -This hint is only relevant for non-read-only tools (readOnlyHint=False). -When True, the tool modifies or deletes existing data in a way that may be -difficult or impossible to reverse. -When False, the tool creates new data or performs non-destructive operations. - -FastMCP default if not specified: True -""" - -IDEMPOTENT_HINT = "idempotentHint" -"""Indicates if repeated calls with the same parameters have the same effect. - -When True, calling the tool multiple times with identical parameters produces -the same result and side effects as calling it once. -When False, each call may produce different results or side effects. - -FastMCP default if not specified: False -""" - -OPEN_WORLD_HINT = "openWorldHint" -"""Specifies if the tool interacts with external systems. - -When True, the tool communicates with external services, APIs, or systems -outside the local environment (e.g., cloud APIs, remote databases, internet). -When False, the tool only operates on local state or resources. 
-
-FastMCP default if not specified: True
-"""
diff --git a/airbyte/mcp/_arg_resolvers.py b/airbyte/mcp/_arg_resolvers.py
new file mode 100644
index 000000000..cb4f8df5c
--- /dev/null
+++ b/airbyte/mcp/_arg_resolvers.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
+"""Argument resolver functions for MCP tools.
+
+This module provides functions to resolve and validate arguments passed to MCP tools,
+including connector configurations and list-of-strings arguments.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, overload
+
+import yaml
+
+from airbyte.secrets.hydration import deep_update, detect_hardcoded_secrets
+from airbyte.secrets.util import get_secret
+
+
+# Hint: Null result if input is Null
+@overload
+def resolve_list_of_strings(value: None) -> None: ...
+
+
+# Hint: Non-null result if input is non-null
+@overload
+def resolve_list_of_strings(value: str | list[str] | set[str]) -> list[str]: ...
+
+
+def resolve_list_of_strings(value: str | list[str] | set[str] | None) -> list[str] | None:
+    """Resolve a string or list of strings to a list of strings.
+
+    This function handles the following types of input:
+
+    1. A list or set of strings (e.g., ["stream1", "stream2"]) will be returned as a list.
+    2. None will be returned as None.
+    3. An empty string will be returned as an empty list.
+    4. A JSON array string (e.g., '["stream1", "stream2"]') will be parsed into a list.
+    5. Any other string will be treated as CSV (e.g., "stream1,stream2") and split into a list.
+
+    Args:
+        value: A string, a list or set of strings, or None.
+    """
+    if value is None:
+        return None
+
+    if isinstance(value, list):
+        return value
+
+    if isinstance(value, set):
+        return list(value)
+
+    if not isinstance(value, str):
+        raise TypeError(
+            "Expected a string, list of strings, a set of strings, or None. "
+            f"Got '{type(value).__name__}': {value}"
+        )
+
+    value = value.strip()
+    if not value:
+        return []
+
+    if value.startswith("[") and value.endswith("]"):
+        # Try to parse as JSON array:
+        try:
+            parsed = json.loads(value)
+            if isinstance(parsed, list) and all(isinstance(item, str) for item in parsed):
+                return parsed
+        except json.JSONDecodeError as ex:
+            raise ValueError(f"Invalid JSON array: {value}") from ex
+
+    # Fallback to CSV split:
+    return [item.strip() for item in value.split(",") if item.strip()]
+
+
+def resolve_connector_config(  # noqa: PLR0912
+    config: dict | str | None = None,
+    config_file: str | Path | None = None,
+    config_secret_name: str | None = None,
+    config_spec_jsonschema: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Resolve a configuration dictionary, JSON string, or file path to a dictionary.
+
+    Returns:
+        Resolved configuration dictionary (empty if no inputs provided)
+
+    Raises:
+        ValueError: If JSON parsing fails or a provided input is invalid
+
+    We reject hardcoded secrets in a config dict if we detect them.
+ """ + config_dict: dict[str, Any] = {} + + if config is None and config_file is None and config_secret_name is None: + return {} + + if config_file is not None: + if isinstance(config_file, str): + config_file = Path(config_file) + + if not isinstance(config_file, Path): + raise ValueError( + f"config_file must be a string or Path object, got: {type(config_file).__name__}" + ) + + if not config_file.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_file}") + + def _raise_invalid_type(file_config: object) -> None: + raise TypeError( + f"Configuration file must contain a valid JSON/YAML object, " + f"got: {type(file_config).__name__}" + ) + + try: + file_config = yaml.safe_load(config_file.read_text()) + if not isinstance(file_config, dict): + _raise_invalid_type(file_config) + config_dict.update(file_config) + except Exception as e: + raise ValueError(f"Error reading configuration file {config_file}: {e}") from e + + if config is not None: + if isinstance(config, dict): + config_dict.update(config) + elif isinstance(config, str): + try: + parsed_config = json.loads(config) + if not isinstance(parsed_config, dict): + raise TypeError( + f"Parsed JSON config must be an object/dict, " + f"got: {type(parsed_config).__name__}" + ) + config_dict.update(parsed_config) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in config parameter: {e}") from e + else: + raise ValueError(f"Config must be a dict or JSON string, got: {type(config).__name__}") + + if config_dict and config_spec_jsonschema is not None: + hardcoded_secrets: list[list[str]] = detect_hardcoded_secrets( + config=config_dict, + spec_json_schema=config_spec_jsonschema, + ) + if hardcoded_secrets: + error_msg = "Configuration contains hardcoded secrets in fields: " + error_msg += ", ".join( + [".".join(hardcoded_secret) for hardcoded_secret in hardcoded_secrets] + ) + + error_msg += ( + "Please use environment variables instead. For example:\n" + "To set a secret via reference, set its value to " + "`secret_reference::ENV_VAR_NAME`.\n" + ) + raise ValueError(error_msg) + + if config_secret_name is not None: + # Assume this is a secret name that points to a JSON/YAML config. + secret_config = yaml.safe_load(str(get_secret(config_secret_name))) + if not isinstance(secret_config, dict): + raise ValueError( + f"Secret '{config_secret_name}' must contain a valid JSON or YAML object, " + f"but got: {type(secret_config).__name__}" + ) + + # Merge the secret config into the main config: + deep_update( + config_dict, + secret_config, + ) + + return config_dict diff --git a/airbyte/mcp/_config.py b/airbyte/mcp/_config.py new file mode 100644 index 000000000..00925a0af --- /dev/null +++ b/airbyte/mcp/_config.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
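To make the resolver contracts concrete, a small usage sketch that only exercises in-memory inputs (no files or secret stores); the import path assumes the new airbyte.mcp._arg_resolvers module introduced above.

from airbyte.mcp._arg_resolvers import resolve_connector_config, resolve_list_of_strings

# List-like tool arguments may arrive as lists, sets, CSV strings, or JSON arrays:
assert resolve_list_of_strings(None) is None
assert resolve_list_of_strings("") == []
assert resolve_list_of_strings("users, orders") == ["users", "orders"]
assert resolve_list_of_strings('["users", "orders"]') == ["users", "orders"]

# Connector config may arrive as a dict or a JSON string; all inputs merge into one dict:
assert resolve_connector_config(config={"count": 100}) == {"count": 100}
assert resolve_connector_config(config='{"count": 100}') == {"count": 100}
assert resolve_connector_config() == {}  # nothing provided -> empty config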
+"""Internal utility functions for MCP secret loading.""" + +from __future__ import annotations + +import os +from pathlib import Path + +import dotenv + +from airbyte._util.meta import is_interactive +from airbyte.secrets import ( + DotenvSecretManager, + GoogleGSMSecretManager, + SecretSourceEnum, + register_secret_manager, +) +from airbyte.secrets.config import disable_secret_source +from airbyte.secrets.util import get_secret, is_secret_available + + +AIRBYTE_MCP_DOTENV_PATH_ENVVAR = "AIRBYTE_MCP_ENV_FILE" + + +def _load_dotenv_file(dotenv_path: Path | str) -> None: + """Load environment variables from a .env file.""" + if isinstance(dotenv_path, str): + dotenv_path = Path(dotenv_path) + if not dotenv_path.exists(): + raise FileNotFoundError(f".env file not found: {dotenv_path}") + + dotenv.load_dotenv(dotenv_path=dotenv_path) + + +def load_secrets_to_env_vars() -> None: + """Load secrets from dotenv files and secret managers into environment variables. + + This function must be called before mcp_server() so that config args can resolve + from the loaded environment variables. + + Note: Later secret manager registrations have higher priority than earlier ones. + """ + # Load the .env file from the current working directory. + envrc_path = Path.cwd() / ".envrc" + if envrc_path.exists(): + envrc_secret_mgr = DotenvSecretManager(envrc_path) + _load_dotenv_file(envrc_path) + register_secret_manager( + envrc_secret_mgr, + ) + + if AIRBYTE_MCP_DOTENV_PATH_ENVVAR in os.environ: + dotenv_path = Path(os.environ[AIRBYTE_MCP_DOTENV_PATH_ENVVAR]).absolute() + custom_dotenv_secret_mgr = DotenvSecretManager(dotenv_path) + _load_dotenv_file(dotenv_path) + register_secret_manager( + custom_dotenv_secret_mgr, + ) + + if is_secret_available("GCP_GSM_CREDENTIALS") and is_secret_available("GCP_GSM_PROJECT_ID"): + # Initialize the GoogleGSMSecretManager if the credentials and project are set. + register_secret_manager( + GoogleGSMSecretManager( + project=get_secret("GCP_GSM_PROJECT_ID"), + credentials_json=get_secret("GCP_GSM_CREDENTIALS"), + ) + ) + + # Make sure we disable the prompt source in non-interactive environments. + if not is_interactive(): + disable_secret_source(SecretSourceEnum.PROMPT) diff --git a/airbyte/mcp/_tool_utils.py b/airbyte/mcp/_tool_utils.py index c7043c94d..ec2babd62 100644 --- a/airbyte/mcp/_tool_utils.py +++ b/airbyte/mcp/_tool_utils.py @@ -1,42 +1,71 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. -"""MCP tool utility functions. +"""MCP tool utility functions for safe mode and config args. -This module provides a decorator to tag tool functions with MCP annotations -for deferred registration with safe mode filtering. +This module provides: +- Safe mode functionality for MCP tools, allowing tracking of resources created + during a session to prevent accidental deletion of pre-existing resources. +- Config args and filters for backward compatibility with legacy Airbyte env vars. 
""" from __future__ import annotations -import inspect import os -import warnings -from collections.abc import Callable -from functools import lru_cache -from typing import Any, Literal, TypeVar +from typing import TYPE_CHECKING -from airbyte.constants import ( - AIRBYTE_MCP_DOMAINS, - AIRBYTE_MCP_DOMAINS_DISABLED, - MCP_TOOL_DOMAINS, +from fastmcp_extensions import MCPServerConfigArg, get_mcp_config +from fastmcp_extensions.tool_filters import ( + ANNOTATION_MCP_MODULE, + ANNOTATION_READ_ONLY_HINT, + get_annotation, ) -from airbyte.mcp._annotations import ( - DESTRUCTIVE_HINT, - IDEMPOTENT_HINT, - OPEN_WORLD_HINT, - READ_ONLY_HINT, + +from airbyte.constants import ( + CLOUD_API_ROOT_ENV_VAR, + CLOUD_BEARER_TOKEN_ENV_VAR, + CLOUD_CLIENT_ID_ENV_VAR, + CLOUD_CLIENT_SECRET_ENV_VAR, + CLOUD_WORKSPACE_ID_ENV_VAR, + MCP_API_URL_HEADER, + MCP_BEARER_TOKEN_HEADER, + MCP_CLIENT_ID_HEADER, + MCP_CLIENT_SECRET_HEADER, + MCP_CONFIG_API_URL, + MCP_CONFIG_BEARER_TOKEN, + MCP_CONFIG_CLIENT_ID, + MCP_CONFIG_CLIENT_SECRET, + MCP_CONFIG_EXCLUDE_MODULES, + MCP_CONFIG_INCLUDE_MODULES, + MCP_CONFIG_READONLY_MODE, + MCP_CONFIG_WORKSPACE_ID, + MCP_DOMAINS_DISABLED_ENV_VAR, + MCP_DOMAINS_ENV_VAR, + MCP_READONLY_MODE_ENV_VAR, + MCP_WORKSPACE_ID_HEADER, ) -F = TypeVar("F", bound=Callable[..., Any]) +if TYPE_CHECKING: + from fastmcp import FastMCP + from mcp.types import Tool + + +# ============================================================================= +# Safe Mode Configuration +# ============================================================================= -AIRBYTE_CLOUD_MCP_READONLY_MODE = ( - os.environ.get("AIRBYTE_CLOUD_MCP_READONLY_MODE", "").strip() == "1" -) AIRBYTE_CLOUD_MCP_SAFE_MODE = os.environ.get("AIRBYTE_CLOUD_MCP_SAFE_MODE", "1").strip() != "0" +"""Whether safe mode is enabled for cloud operations. + +When enabled (default), destructive operations are only allowed on resources +created during the current session. +""" + AIRBYTE_CLOUD_WORKSPACE_ID_IS_SET = bool(os.environ.get("AIRBYTE_CLOUD_WORKSPACE_ID", "").strip()) +"""Whether the AIRBYTE_CLOUD_WORKSPACE_ID environment variable is set. +When set, the workspace_id parameter is hidden from cloud tools. +""" -_REGISTERED_TOOLS: list[tuple[Callable[..., Any], dict[str, Any]]] = [] _GUIDS_CREATED_IN_SESSION: set[str] = set() @@ -74,191 +103,125 @@ def check_guid_created_in_session(guid: str) -> None: ) -@lru_cache(maxsize=1) -def _resolve_mcp_domain_filters() -> tuple[set[str], set[str]]: - """Resolve MCP domain filters from environment variables. +# ============================================================================= +# Backward-Compatible Config Args +# ============================================================================= +# These config args support the legacy Airbyte-specific environment variables +# while the standard fastmcp-extensions config args support the new MCP_* vars. +# Both sets of filters are applied, so either env var will work. +# ============================================================================= - This function is cached to ensure warnings are only emitted once per process. - - Returns: - Tuple of (enabled_domains, disabled_domains) as sets. - If an env var is not set, the corresponding set will be empty. 
- """ - known_domains = set(MCP_TOOL_DOMAINS) - enabled = set(AIRBYTE_MCP_DOMAINS or []) - disabled = set(AIRBYTE_MCP_DOMAINS_DISABLED or []) - - # Check for unknown domains and warn - unknown_enabled = enabled - known_domains - unknown_disabled = disabled - known_domains - - if unknown_enabled or unknown_disabled: - parts: list[str] = [] - if unknown_enabled: - parts.append( - f"AIRBYTE_MCP_DOMAINS contains unknown domain(s): {sorted(unknown_enabled)}" - ) - if unknown_disabled: - parts.append( - "AIRBYTE_MCP_DOMAINS_DISABLED contains unknown domain(s): " - f"{sorted(unknown_disabled)}" - ) - known_list = ", ".join(sorted(known_domains)) - warning_message = "; ".join(parts) + f". Known MCP domains are: [{known_list}]." - warnings.warn(warning_message, stacklevel=3) - - return enabled, disabled - - -def is_domain_enabled(domain: str) -> bool: - """Check if a domain is enabled based on AIRBYTE_MCP_DOMAINS and AIRBYTE_MCP_DOMAINS_DISABLED. - - The logic is: - - If neither env var is set: all domains are enabled - - If only AIRBYTE_MCP_DOMAINS is set: only those domains are enabled - - If only AIRBYTE_MCP_DOMAINS_DISABLED is set: all domains except those are enabled - - If both are set: disabled domains subtract from enabled domains - - Args: - domain: The domain to check (e.g., "cloud", "local", "registry") +AIRBYTE_READONLY_MODE_CONFIG_ARG = MCPServerConfigArg( + name=MCP_CONFIG_READONLY_MODE, + env_var=MCP_READONLY_MODE_ENV_VAR, + default="0", + required=False, +) +"""Config arg for legacy AIRBYTE_CLOUD_MCP_READONLY_MODE env var.""" - Returns: - True if the domain is enabled, False otherwise - """ - enabled, disabled = _resolve_mcp_domain_filters() - domain_lower = domain.lower() +AIRBYTE_EXCLUDE_MODULES_CONFIG_ARG = MCPServerConfigArg( + name=MCP_CONFIG_EXCLUDE_MODULES, + env_var=MCP_DOMAINS_DISABLED_ENV_VAR, + default="", + required=False, +) +"""Config arg for legacy AIRBYTE_MCP_DOMAINS_DISABLED env var.""" - # If neither env var is set, all domains are enabled - if not enabled and not disabled: - return True +AIRBYTE_INCLUDE_MODULES_CONFIG_ARG = MCPServerConfigArg( + name=MCP_CONFIG_INCLUDE_MODULES, + env_var=MCP_DOMAINS_ENV_VAR, + default="", + required=False, +) +"""Config arg for legacy AIRBYTE_MCP_DOMAINS env var.""" + +WORKSPACE_ID_CONFIG_ARG = MCPServerConfigArg( + name=MCP_CONFIG_WORKSPACE_ID, + http_header_key=MCP_WORKSPACE_ID_HEADER, + env_var=CLOUD_WORKSPACE_ID_ENV_VAR, + required=False, + sensitive=False, +) +"""Config arg for workspace ID, supporting both HTTP header and env var.""" + +BEARER_TOKEN_CONFIG_ARG = MCPServerConfigArg( + name=MCP_CONFIG_BEARER_TOKEN, + http_header_key=MCP_BEARER_TOKEN_HEADER, + env_var=CLOUD_BEARER_TOKEN_ENV_VAR, + required=False, + sensitive=True, +) +"""Config arg for bearer token, supporting Authorization header and env var.""" + +CLIENT_ID_CONFIG_ARG = MCPServerConfigArg( + name=MCP_CONFIG_CLIENT_ID, + http_header_key=MCP_CLIENT_ID_HEADER, + env_var=CLOUD_CLIENT_ID_ENV_VAR, + required=False, + sensitive=True, +) +"""Config arg for client ID, supporting HTTP header and env var.""" + +CLIENT_SECRET_CONFIG_ARG = MCPServerConfigArg( + name=MCP_CONFIG_CLIENT_SECRET, + http_header_key=MCP_CLIENT_SECRET_HEADER, + env_var=CLOUD_CLIENT_SECRET_ENV_VAR, + required=False, + sensitive=True, +) +"""Config arg for client secret, supporting HTTP header and env var.""" + +API_URL_CONFIG_ARG = MCPServerConfigArg( + name=MCP_CONFIG_API_URL, + http_header_key=MCP_API_URL_HEADER, + env_var=CLOUD_API_ROOT_ENV_VAR, + required=False, + sensitive=False, +) 
+"""Config arg for API URL, supporting HTTP header and env var.""" - # If only disabled list is set, enable all except disabled - if not enabled and disabled: - return domain_lower not in disabled - # If only enabled list is set, only enable those domains - if enabled and not disabled: - return domain_lower in enabled +# ============================================================================= +# Tool Filters for Backward Compatibility +# ============================================================================= - # Both are set: disabled list subtracts from enabled list - return domain_lower in enabled and domain_lower not in disabled +def _parse_csv_config(value: str) -> list[str]: + """Parse a comma-separated config value into a list of strings.""" + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] -def should_register_tool(annotations: dict[str, Any]) -> bool: - """Check if a tool should be registered based on mode settings. - Args: - annotations: Tool annotations dict containing domain, readOnlyHint, and destructiveHint +def airbyte_readonly_mode_filter(tool: Tool, app: FastMCP) -> bool: + """Filter tools based on legacy AIRBYTE_CLOUD_MCP_READONLY_MODE env var. - Returns: - True if the tool should be registered, False if it should be filtered out + When set to "1", only show tools with readOnlyHint=True. """ - domain = annotations.get("domain") - domain_normalized = domain.lower() if isinstance(domain, str) else None - - # Check domain filtering first - if domain_normalized and not is_domain_enabled(domain_normalized): - return False - - # Cloud-specific readonly mode check (case-insensitive) - if domain_normalized == "cloud" and AIRBYTE_CLOUD_MCP_READONLY_MODE: - is_readonly = annotations.get(READ_ONLY_HINT, False) - if not is_readonly: - return False - + config_value = (get_mcp_config(app, MCP_CONFIG_READONLY_MODE) or "").lower() + if config_value in {"1", "true"}: + return bool(get_annotation(tool, ANNOTATION_READ_ONLY_HINT, default=False)) return True -def get_registered_tools( - domain: Literal["cloud", "local", "registry"] | None = None, -) -> list[tuple[Callable[..., Any], dict[str, Any]]]: - """Get all registered tools, optionally filtered by domain. +def airbyte_module_filter(tool: Tool, app: FastMCP) -> bool: + """Filter tools based on legacy AIRBYTE_MCP_DOMAINS and AIRBYTE_MCP_DOMAINS_DISABLED. - Args: - domain: The domain to filter by (e.g., "cloud", "local", "registry"). - If None, returns all tools. - - Returns: - List of tuples containing (function, annotations) for each registered tool + When AIRBYTE_MCP_DOMAINS_DISABLED is set, hide tools from those modules. + When AIRBYTE_MCP_DOMAINS is set, only show tools from those modules. """ - if domain is None: - return _REGISTERED_TOOLS.copy() - return [(func, ann) for func, ann in _REGISTERED_TOOLS if ann.get("domain") == domain] + exclude_modules = _parse_csv_config(get_mcp_config(app, MCP_CONFIG_EXCLUDE_MODULES) or "") + include_modules = _parse_csv_config(get_mcp_config(app, MCP_CONFIG_INCLUDE_MODULES) or "") + # Get the tool's mcp_module from annotations + tool_module = get_annotation(tool, ANNOTATION_MCP_MODULE, None) -def mcp_tool( - domain: Literal["cloud", "local", "registry"], - *, - read_only: bool = False, - destructive: bool = False, - idempotent: bool = False, - open_world: bool = False, - extra_help_text: str | None = None, -) -> Callable[[F], F]: - """Decorator to tag an MCP tool function with annotations for deferred registration. 
+ if exclude_modules: + # Hide tools from excluded modules + return not (tool_module and tool_module in exclude_modules) - This decorator stores the annotations on the function for later use during - deferred registration. It does not register the tool immediately. + if include_modules: + # Only show tools from included modules + return bool(tool_module and tool_module in include_modules) - Args: - domain: The domain this tool belongs to (e.g., "cloud", "local", "registry") - read_only: If True, tool only reads without making changes (default: False) - destructive: If True, tool modifies/deletes existing data (default: False) - idempotent: If True, repeated calls have same effect (default: False) - open_world: If True, tool interacts with external systems (default: False) - extra_help_text: Optional text to append to the function's docstring - with a newline delimiter - - Returns: - Decorator function that tags the tool with annotations - - Example: - @mcp_tool("cloud", read_only=True, idempotent=True) - def list_sources(): - ... - """ - annotations: dict[str, Any] = { - "domain": domain, - READ_ONLY_HINT: read_only, - DESTRUCTIVE_HINT: destructive, - IDEMPOTENT_HINT: idempotent, - OPEN_WORLD_HINT: open_world, - } - - def decorator(func: F) -> F: - func._mcp_annotations = annotations # type: ignore[attr-defined] # noqa: SLF001 - func._mcp_domain = domain # type: ignore[attr-defined] # noqa: SLF001 - func._mcp_extra_help_text = extra_help_text # type: ignore[attr-defined] # noqa: SLF001 - _REGISTERED_TOOLS.append((func, annotations)) - return func - - return decorator - - -def register_tools(app: Any, domain: Literal["cloud", "local", "registry"]) -> None: # noqa: ANN401 - """Register tools with the FastMCP app, filtered by domain and safe mode settings. - - Args: - app: The FastMCP app instance - domain: The domain to register tools for (e.g., "cloud", "local", "registry") - """ - for func, tool_annotations in get_registered_tools(domain): - if should_register_tool(tool_annotations): - extra_help_text = getattr(func, "_mcp_extra_help_text", None) - description: str | None = None - if extra_help_text: - description = (func.__doc__ or "").rstrip() + "\n" + extra_help_text - - # For cloud tools, conditionally hide workspace_id parameter when env var is set - exclude_args: list[str] | None = None - if domain == "cloud" and AIRBYTE_CLOUD_WORKSPACE_ID_IS_SET: - params = set(inspect.signature(func).parameters.keys()) - excluded = [name for name in ["workspace_id"] if name in params] - exclude_args = excluded or None - - app.tool( - func, - annotations=tool_annotations, - description=description, - exclude_args=exclude_args, - ) + return True diff --git a/airbyte/mcp/_util.py b/airbyte/mcp/_util.py deleted file mode 100644 index d93974d8f..000000000 --- a/airbyte/mcp/_util.py +++ /dev/null @@ -1,459 +0,0 @@ -# Copyright (c) 2025 Airbyte, Inc., all rights reserved. 
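Because airbyte_module_filter needs live Tool and FastMCP objects to call directly, here is a plain-Python restatement of its decision rule for reference (the exclude list wins over the include list); the helper name is illustrative.

def _module_visible(tool_module: str | None, include: list[str], exclude: list[str]) -> bool:
    """Mirror airbyte_module_filter: check the exclude list first, then include, else allow."""
    if exclude:
        return not (tool_module and tool_module in exclude)
    if include:
        return bool(tool_module and tool_module in include)
    return True


assert _module_visible("local", include=[], exclude=["local"]) is False
assert _module_visible("cloud", include=["cloud", "registry"], exclude=[]) is True
assert _module_visible("registry", include=["cloud"], exclude=[]) is False
assert _module_visible(None, include=[], exclude=[]) is True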
-"""Internal utility functions for MCP.""" - -from __future__ import annotations - -import json -import os -from pathlib import Path -from typing import Any, overload - -import dotenv -import yaml -from fastmcp.server.dependencies import get_http_headers - -from airbyte._util.meta import is_interactive -from airbyte.cloud.auth import ( - resolve_cloud_api_url, - resolve_cloud_bearer_token, - resolve_cloud_client_id, - resolve_cloud_client_secret, - resolve_cloud_workspace_id, -) -from airbyte.cloud.client_config import CloudClientConfig -from airbyte.secrets import ( - DotenvSecretManager, - GoogleGSMSecretManager, - SecretSourceEnum, - register_secret_manager, -) -from airbyte.secrets.base import SecretString -from airbyte.secrets.config import disable_secret_source -from airbyte.secrets.hydration import deep_update, detect_hardcoded_secrets -from airbyte.secrets.util import get_secret, is_secret_available - - -AIRBYTE_MCP_DOTENV_PATH_ENVVAR = "AIRBYTE_MCP_ENV_FILE" - -# HTTP header names for Airbyte Cloud authentication (X-Airbyte-Cloud-* convention) -HEADER_CLIENT_ID = "X-Airbyte-Cloud-Client-Id" -HEADER_CLIENT_SECRET = "X-Airbyte-Cloud-Client-Secret" -HEADER_WORKSPACE_ID = "X-Airbyte-Cloud-Workspace-Id" -HEADER_API_URL = "X-Airbyte-Cloud-Api-Url" - - -def _load_dotenv_file(dotenv_path: Path | str) -> None: - """Load environment variables from a .env file.""" - if isinstance(dotenv_path, str): - dotenv_path = Path(dotenv_path) - if not dotenv_path.exists(): - raise FileNotFoundError(f".env file not found: {dotenv_path}") - - dotenv.load_dotenv(dotenv_path=dotenv_path) - - -def initialize_secrets() -> None: - """Initialize dotenv to load environment variables from .env files. - - Note: Later secret manager registrations have higher priority than earlier ones. - """ - # Load the .env file from the current working directory. - envrc_path = Path.cwd() / ".envrc" - if envrc_path.exists(): - envrc_secret_mgr = DotenvSecretManager(envrc_path) - _load_dotenv_file(envrc_path) - register_secret_manager( - envrc_secret_mgr, - ) - - if AIRBYTE_MCP_DOTENV_PATH_ENVVAR in os.environ: - dotenv_path = Path(os.environ[AIRBYTE_MCP_DOTENV_PATH_ENVVAR]).absolute() - custom_dotenv_secret_mgr = DotenvSecretManager(dotenv_path) - _load_dotenv_file(dotenv_path) - register_secret_manager( - custom_dotenv_secret_mgr, - ) - - if is_secret_available("GCP_GSM_CREDENTIALS") and is_secret_available("GCP_GSM_PROJECT_ID"): - # Initialize the GoogleGSMSecretManager if the credentials and project are set. - register_secret_manager( - GoogleGSMSecretManager( - project=get_secret("GCP_GSM_PROJECT_ID"), - credentials_json=get_secret("GCP_GSM_CREDENTIALS"), - ) - ) - - # Make sure we disable the prompt source in non-interactive environments. - if not is_interactive(): - disable_secret_source(SecretSourceEnum.PROMPT) - - -# Hint: Null result if input is Null -@overload -def resolve_list_of_strings(value: None) -> None: ... - - -# Hint: Non-null result if input is non-null -@overload -def resolve_list_of_strings(value: str | list[str] | set[str]) -> list[str]: ... - - -def resolve_list_of_strings(value: str | list[str] | set[str] | None) -> list[str] | None: - """Resolve a string or list of strings to a list of strings. - - This method will handle three types of input: - - 1. A list of strings (e.g., ["stream1", "stream2"]) will be returned as-is. - 2. None or empty input will return None. - 3. A single CSV string (e.g., "stream1,stream2") will be split into a list. - 4. 
A JSON string (e.g., '["stream1", "stream2"]') will be parsed into a list. - 5. If the input is empty or None, an empty list will be returned. - - Args: - value: A string or list of strings. - """ - if value is None: - return None - - if isinstance(value, list): - return value - - if isinstance(value, set): - return list(value) - - if not isinstance(value, str): - raise TypeError( - "Expected a string, list of strings, a set of strings, or None. " - f"Got '{type(value).__name__}': {value}" - ) - - value = value.strip() - if not value: - return [] - - if value.startswith("[") and value.endswith("]"): - # Try to parse as JSON array: - try: - parsed = json.loads(value) - if isinstance(parsed, list) and all(isinstance(item, str) for item in parsed): - return parsed - except json.JSONDecodeError as ex: - raise ValueError(f"Invalid JSON array: {value}") from ex - - # Fallback to CSV split: - return [item.strip() for item in value.split(",") if item.strip()] - - -def resolve_config( # noqa: PLR0912 - config: dict | str | None = None, - config_file: str | Path | None = None, - config_secret_name: str | None = None, - config_spec_jsonschema: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Resolve a configuration dictionary, JSON string, or file path to a dictionary. - - Returns: - Resolved configuration dictionary - - Raises: - ValueError: If no configuration provided or if JSON parsing fails - - We reject hardcoded secrets in a config dict if we detect them. - """ - config_dict: dict[str, Any] = {} - - if config is None and config_file is None and config_secret_name is None: - return {} - - if config_file is not None: - if isinstance(config_file, str): - config_file = Path(config_file) - - if not isinstance(config_file, Path): - raise ValueError( - f"config_file must be a string or Path object, got: {type(config_file).__name__}" - ) - - if not config_file.exists(): - raise FileNotFoundError(f"Configuration file not found: {config_file}") - - def _raise_invalid_type(file_config: object) -> None: - raise TypeError( - f"Configuration file must contain a valid JSON/YAML object, " - f"got: {type(file_config).__name__}" - ) - - try: - file_config = yaml.safe_load(config_file.read_text()) - if not isinstance(file_config, dict): - _raise_invalid_type(file_config) - config_dict.update(file_config) - except Exception as e: - raise ValueError(f"Error reading configuration file {config_file}: {e}") from e - - if config is not None: - if isinstance(config, dict): - config_dict.update(config) - elif isinstance(config, str): - try: - parsed_config = json.loads(config) - if not isinstance(parsed_config, dict): - raise TypeError( - f"Parsed JSON config must be an object/dict, " - f"got: {type(parsed_config).__name__}" - ) - config_dict.update(parsed_config) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in config parameter: {e}") from e - else: - raise ValueError(f"Config must be a dict or JSON string, got: {type(config).__name__}") - - if config_dict and config_spec_jsonschema is not None: - hardcoded_secrets: list[list[str]] = detect_hardcoded_secrets( - config=config_dict, - spec_json_schema=config_spec_jsonschema, - ) - if hardcoded_secrets: - error_msg = "Configuration contains hardcoded secrets in fields: " - error_msg += ", ".join( - [".".join(hardcoded_secret) for hardcoded_secret in hardcoded_secrets] - ) - - error_msg += ( - "Please use environment variables instead. 
For example:\n" - "To set a secret via reference, set its value to " - "`secret_reference::ENV_VAR_NAME`.\n" - ) - raise ValueError(error_msg) - - if config_secret_name is not None: - # Assume this is a secret name that points to a JSON/YAML config. - secret_config = yaml.safe_load(str(get_secret(config_secret_name))) - if not isinstance(secret_config, dict): - raise ValueError( - f"Secret '{config_secret_name}' must contain a valid JSON or YAML object, " - f"but got: {type(secret_config).__name__}" - ) - - # Merge the secret config into the main config: - deep_update( - config_dict, - secret_config, - ) - - return config_dict - - -def _get_header_value(headers: dict[str, str], header_name: str) -> str | None: - """Get a header value from a headers dict, case-insensitively. - - Args: - headers: Dictionary of HTTP headers. - header_name: The header name to look for (case-insensitive). - - Returns: - The header value if found, None otherwise. - """ - header_name_lower = header_name.lower() - for key, value in headers.items(): - if key.lower() == header_name_lower: - return value - return None - - -def get_bearer_token_from_headers() -> SecretString | None: - """Extract bearer token from HTTP Authorization header. - - This function extracts the bearer token from the standard HTTP - `Authorization: Bearer ` header when running as an MCP HTTP server. - - Returns: - The bearer token as a SecretString, or None if not found or not in HTTP context. - """ - headers = get_http_headers() - if not headers: - return None - - auth_header = _get_header_value(headers, "Authorization") - if not auth_header: - return None - - # Parse "Bearer " format - if auth_header.lower().startswith("bearer "): - token = auth_header[7:].strip() # Remove "Bearer " prefix - if token: - return SecretString(token) - - return None - - -def get_client_id_from_headers() -> SecretString | None: - """Extract client ID from HTTP headers. - - Returns: - The client ID as a SecretString, or None if not found. - """ - headers = get_http_headers() - if not headers: - return None - - value = _get_header_value(headers, HEADER_CLIENT_ID) - if value: - return SecretString(value) - return None - - -def get_client_secret_from_headers() -> SecretString | None: - """Extract client secret from HTTP headers. - - Returns: - The client secret as a SecretString, or None if not found. - """ - headers = get_http_headers() - if not headers: - return None - - value = _get_header_value(headers, HEADER_CLIENT_SECRET) - if value: - return SecretString(value) - return None - - -def get_workspace_id_from_headers() -> str | None: - """Extract workspace ID from HTTP headers. - - Returns: - The workspace ID, or None if not found. - """ - headers = get_http_headers() - if not headers: - return None - - return _get_header_value(headers, HEADER_WORKSPACE_ID) - - -def get_api_url_from_headers() -> str | None: - """Extract API URL from HTTP headers. - - Returns: - The API URL, or None if not found. - """ - headers = get_http_headers() - if not headers: - return None - - return _get_header_value(headers, HEADER_API_URL) - - -def resolve_cloud_credentials( - *, - client_id: SecretString | str | None = None, - client_secret: SecretString | str | None = None, - bearer_token: SecretString | str | None = None, - api_root: str | None = None, -) -> CloudClientConfig: - """Resolve CloudClientConfig from multiple sources. - - This function resolves authentication credentials for Airbyte Cloud - from multiple sources in the following priority order: - - 1. 
Explicit parameters passed to this function - 2. HTTP headers (when running as MCP HTTP server) - 3. Environment variables - - For bearer token authentication, the resolution order is: - 1. Explicit `bearer_token` parameter - 2. HTTP `Authorization: Bearer ` header - 3. `AIRBYTE_CLOUD_BEARER_TOKEN` environment variable - - For client credentials authentication, the resolution order is: - 1. Explicit `client_id` and `client_secret` parameters - 2. HTTP `X-Airbyte-Cloud-Client-Id` and `X-Airbyte-Cloud-Client-Secret` headers - 3. `AIRBYTE_CLOUD_CLIENT_ID` and `AIRBYTE_CLOUD_CLIENT_SECRET` environment variables - - Args: - client_id: Optional explicit client ID. - client_secret: Optional explicit client secret. - bearer_token: Optional explicit bearer token. - api_root: Optional explicit API root URL. - - Returns: - A CloudClientConfig instance with resolved authentication. - - Raises: - PyAirbyteInputError: If no valid authentication can be resolved. - """ - # Resolve API root (explicit -> header -> env var -> default) - resolved_api_root = api_root or get_api_url_from_headers() or resolve_cloud_api_url() - - # Try to resolve bearer token first (explicit -> header -> env var) - resolved_bearer_token: SecretString | None = None - if bearer_token is not None: - resolved_bearer_token = SecretString(bearer_token) - else: - # Try HTTP header - resolved_bearer_token = get_bearer_token_from_headers() - if resolved_bearer_token is None: - # Try env var - resolved_bearer_token = resolve_cloud_bearer_token() - - if resolved_bearer_token: - return CloudClientConfig( - bearer_token=resolved_bearer_token, - api_root=resolved_api_root, - ) - - # Fall back to client credentials (explicit -> header -> env var) - resolved_client_id: SecretString | None = None - resolved_client_secret: SecretString | None = None - - if client_id is not None: - resolved_client_id = SecretString(client_id) - else: - resolved_client_id = get_client_id_from_headers() - if resolved_client_id is None: - resolved_client_id = resolve_cloud_client_id() - - if client_secret is not None: - resolved_client_secret = SecretString(client_secret) - else: - resolved_client_secret = get_client_secret_from_headers() - if resolved_client_secret is None: - resolved_client_secret = resolve_cloud_client_secret() - - return CloudClientConfig( - client_id=resolved_client_id, - client_secret=resolved_client_secret, - api_root=resolved_api_root, - ) - - -def resolve_workspace_id( - workspace_id: str | None = None, -) -> str: - """Resolve workspace ID from multiple sources. - - Resolution order: - 1. Explicit `workspace_id` parameter - 2. HTTP `X-Airbyte-Cloud-Workspace-Id` header - 3. `AIRBYTE_CLOUD_WORKSPACE_ID` environment variable - - Args: - workspace_id: Optional explicit workspace ID. - - Returns: - The resolved workspace ID. - - Raises: - PyAirbyteSecretNotFoundError: If no workspace ID can be resolved. 
- """ - if workspace_id is not None: - return workspace_id - - # Try HTTP header - header_workspace_id = get_workspace_id_from_headers() - if header_workspace_id: - return header_workspace_id - - # Fall back to env var - return resolve_cloud_workspace_id() diff --git a/airbyte/mcp/cloud_ops.py b/airbyte/mcp/cloud.py similarity index 92% rename from airbyte/mcp/cloud_ops.py rename to airbyte/mcp/cloud.py index d614ed9bb..b951f3da4 100644 --- a/airbyte/mcp/cloud_ops.py +++ b/airbyte/mcp/cloud.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import Annotated, Any, Literal, cast -from fastmcp import FastMCP +from fastmcp import Context, FastMCP +from fastmcp_extensions import get_mcp_config, mcp_tool, register_mcp_tools from pydantic import BaseModel, Field from airbyte import cloud, get_destination, get_source @@ -12,19 +13,20 @@ from airbyte.cloud.connectors import CustomCloudSourceDefinition from airbyte.cloud.constants import FAILED_STATUSES from airbyte.cloud.workspaces import CloudWorkspace +from airbyte.constants import ( + MCP_CONFIG_API_URL, + MCP_CONFIG_BEARER_TOKEN, + MCP_CONFIG_CLIENT_ID, + MCP_CONFIG_CLIENT_SECRET, + MCP_CONFIG_WORKSPACE_ID, +) from airbyte.destinations.util import get_noop_destination from airbyte.exceptions import AirbyteMissingResourceError, PyAirbyteInputError +from airbyte.mcp._arg_resolvers import resolve_connector_config, resolve_list_of_strings from airbyte.mcp._tool_utils import ( + AIRBYTE_CLOUD_WORKSPACE_ID_IS_SET, check_guid_created_in_session, - mcp_tool, register_guid_created_in_session, - register_tools, -) -from airbyte.mcp._util import ( - resolve_cloud_credentials, - resolve_config, - resolve_list_of_strings, - resolve_workspace_id, ) from airbyte.secrets import SecretString @@ -210,35 +212,47 @@ class SyncJobListResult(BaseModel): """Whether jobs are ordered newest-first (True) or oldest-first (False).""" -def _get_cloud_workspace(workspace_id: str | None = None) -> CloudWorkspace: +def _get_cloud_workspace( + ctx: Context, + workspace_id: str | None = None, +) -> CloudWorkspace: """Get an authenticated CloudWorkspace. - Resolves credentials from multiple sources in order: + Resolves credentials from multiple sources via MCP config args in order: 1. HTTP headers (when running as MCP server with HTTP/SSE transport) 2. Environment variables - Args: - workspace_id: Optional workspace ID. If not provided, uses HTTP headers - or the AIRBYTE_CLOUD_WORKSPACE_ID environment variable. + The ctx parameter provides access to MCP config values that are resolved + from HTTP headers or environment variables based on the config args + defined in server.py. 
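For illustration, an additional cloud tool following this pattern might look like the sketch below: the Context parameter comes first and workspace resolution is delegated to _get_cloud_workspace. The tool name and return value are hypothetical.

from fastmcp import Context
from fastmcp_extensions import mcp_tool

from airbyte.mcp.cloud import _get_cloud_workspace


@mcp_tool(read_only=True, idempotent=True, open_world=True)
def get_workspace_id_in_use(ctx: Context, workspace_id: str | None = None) -> str:
    """Return the workspace ID this server would operate on (hypothetical tool)."""
    workspace = _get_cloud_workspace(ctx, workspace_id)
    return workspace.workspace_id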
""" - credentials = resolve_cloud_credentials() - resolved_workspace_id = resolve_workspace_id(workspace_id) + resolved_workspace_id = workspace_id or get_mcp_config(ctx, MCP_CONFIG_WORKSPACE_ID) + if not resolved_workspace_id: + raise PyAirbyteInputError( + message="Workspace ID is required but not provided.", + guidance="Set AIRBYTE_CLOUD_WORKSPACE_ID env var or pass workspace_id parameter.", + ) + + bearer_token = get_mcp_config(ctx, MCP_CONFIG_BEARER_TOKEN) + client_id = get_mcp_config(ctx, MCP_CONFIG_CLIENT_ID) + client_secret = get_mcp_config(ctx, MCP_CONFIG_CLIENT_SECRET) + api_url = get_mcp_config(ctx, MCP_CONFIG_API_URL) or api_util.CLOUD_API_ROOT return CloudWorkspace( workspace_id=resolved_workspace_id, - client_id=credentials.client_id, - client_secret=credentials.client_secret, - bearer_token=credentials.bearer_token, - api_root=credentials.api_root, + client_id=SecretString(client_id) if client_id else None, + client_secret=SecretString(client_secret) if client_secret else None, + bearer_token=SecretString(bearer_token) if bearer_token else None, + api_root=api_url, ) @mcp_tool( - domain="cloud", open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def deploy_source_to_cloud( + ctx: Context, source_name: Annotated[ str, Field(description="The name to use when deploying the source."), @@ -282,14 +296,14 @@ def deploy_source_to_cloud( source_connector_name, no_executor=True, ) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_secret_name=config_secret_name, config_spec_jsonschema=source.config_spec, ) source.set_config(config_dict, validate=True) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) deployed_source = workspace.deploy_source( name=source_name, source=source, @@ -304,11 +318,11 @@ def deploy_source_to_cloud( @mcp_tool( - domain="cloud", open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def deploy_destination_to_cloud( + ctx: Context, destination_name: Annotated[ str, Field(description="The name to use when deploying the destination."), @@ -352,14 +366,14 @@ def deploy_destination_to_cloud( destination_connector_name, no_executor=True, ) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_secret_name=config_secret_name, config_spec_jsonschema=destination.config_spec, ) destination.set_config(config_dict, validate=True) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) deployed_destination = workspace.deploy_destination( name=destination_name, destination=destination, @@ -374,11 +388,11 @@ def deploy_destination_to_cloud( @mcp_tool( - domain="cloud", open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def create_connection_on_cloud( + ctx: Context, connection_name: Annotated[ str, Field(description="The name of the connection."), @@ -419,7 +433,7 @@ def create_connection_on_cloud( ) -> str: """Create a connection between a deployed source and destination on Airbyte Cloud.""" resolved_streams_list: list[str] = resolve_list_of_strings(selected_streams) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) deployed_connection = workspace.deploy_connection( connection_name=connection_name, source=source_id, @@ -437,11 +451,11 @@ def create_connection_on_cloud( @mcp_tool( - domain="cloud", open_world=True, 
extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def run_cloud_sync( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the Airbyte Cloud connection."), @@ -474,7 +488,7 @@ def run_cloud_sync( ], ) -> str: """Run a sync job on Airbyte Cloud.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) sync_result = connection.run_sync(wait=wait, wait_timeout=wait_timeout) @@ -489,13 +503,13 @@ def run_cloud_sync( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def check_airbyte_cloud_workspace( + ctx: Context, *, workspace_id: Annotated[ str | None, @@ -509,7 +523,7 @@ def check_airbyte_cloud_workspace( Returns workspace details including workspace ID, name, and organization info. """ - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) # Get workspace details from the public API using workspace's credentials workspace_response = api_util.get_workspace( @@ -539,11 +553,11 @@ def check_airbyte_cloud_workspace( @mcp_tool( - domain="cloud", open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def deploy_noop_destination_to_cloud( + ctx: Context, name: str = "No-op Destination", *, workspace_id: Annotated[ @@ -557,7 +571,7 @@ def deploy_noop_destination_to_cloud( ) -> str: """Deploy the No-op destination to Airbyte Cloud for testing purposes.""" destination = get_noop_destination() - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) deployed_destination = workspace.deploy_destination( name=name, destination=destination, @@ -572,13 +586,13 @@ def deploy_noop_destination_to_cloud( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def get_cloud_sync_status( + ctx: Context, connection_id: Annotated[ str, Field( @@ -609,7 +623,7 @@ def get_cloud_sync_status( ], ) -> dict[str, Any]: """Get the status of a sync job from the Airbyte Cloud.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) # If a job ID is provided, get the job by ID. 
@@ -646,13 +660,13 @@ def get_cloud_sync_status( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def list_cloud_sync_jobs( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the Airbyte Cloud connection."), @@ -718,7 +732,7 @@ def list_cloud_sync_jobs( elif from_tail is None: from_tail = False - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) # Cap at 500 to avoid overloading agent context @@ -751,13 +765,13 @@ def list_cloud_sync_jobs( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def list_deployed_cloud_source_connectors( + ctx: Context, *, workspace_id: Annotated[ str | None, @@ -782,7 +796,7 @@ def list_deployed_cloud_source_connectors( ], ) -> list[CloudSourceResult]: """List all deployed source connectors in the Airbyte Cloud workspace.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) sources = workspace.list_sources() # Filter by name if requested @@ -806,13 +820,13 @@ def list_deployed_cloud_source_connectors( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def list_deployed_cloud_destination_connectors( + ctx: Context, *, workspace_id: Annotated[ str | None, @@ -837,7 +851,7 @@ def list_deployed_cloud_destination_connectors( ], ) -> list[CloudDestinationResult]: """List all deployed destination connectors in the Airbyte Cloud workspace.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) destinations = workspace.list_destinations() # Filter by name if requested @@ -861,13 +875,13 @@ def list_deployed_cloud_destination_connectors( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def describe_cloud_source( + ctx: Context, source_id: Annotated[ str, Field(description="The ID of the source to describe."), @@ -882,7 +896,7 @@ def describe_cloud_source( ], ) -> CloudSourceDetails: """Get detailed information about a specific deployed source connector.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) source = workspace.get_source(source_id=source_id) # Access name property to ensure _connector_info is populated @@ -897,13 +911,13 @@ def describe_cloud_source( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def describe_cloud_destination( + ctx: Context, destination_id: Annotated[ str, Field(description="The ID of the destination to describe."), @@ -918,7 +932,7 @@ def describe_cloud_destination( ], ) -> CloudDestinationDetails: """Get detailed information about a specific deployed destination connector.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) destination = workspace.get_destination(destination_id=destination_id) # Access name property to ensure _connector_info is populated @@ -933,13 +947,13 @@ def describe_cloud_destination( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, 
extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def describe_cloud_connection( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the connection to describe."), @@ -954,7 +968,7 @@ def describe_cloud_connection( ], ) -> CloudConnectionDetails: """Get detailed information about a specific deployed connection.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) return CloudConnectionDetails( @@ -971,13 +985,13 @@ def describe_cloud_connection( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def get_cloud_sync_logs( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the Airbyte Cloud connection."), @@ -1043,7 +1057,7 @@ def get_cloud_sync_logs( if from_tail is None and line_offset is None: from_tail = True - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) sync_result: cloud.SyncResult | None = connection.get_sync_result(job_id=job_id) @@ -1119,13 +1133,13 @@ def get_cloud_sync_logs( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def list_deployed_cloud_connections( + ctx: Context, *, workspace_id: Annotated[ str | None, @@ -1173,7 +1187,7 @@ def list_deployed_cloud_connections( recent completed sync job failed or was cancelled will be returned. This implicitly enables with_connection_status. """ - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connections = workspace.list_connections() # Filter by name if requested @@ -1344,13 +1358,13 @@ def _resolve_organization_id( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def list_cloud_workspaces( + ctx: Context, *, organization_id: Annotated[ str | None, @@ -1389,23 +1403,26 @@ def list_cloud_workspaces( This tool will NOT list workspaces across all organizations - you must specify which organization to list workspaces from. 
""" - credentials = resolve_cloud_credentials() + bearer_token = get_mcp_config(ctx, MCP_CONFIG_BEARER_TOKEN) + client_id = get_mcp_config(ctx, MCP_CONFIG_CLIENT_ID) + client_secret = get_mcp_config(ctx, MCP_CONFIG_CLIENT_SECRET) + api_url = get_mcp_config(ctx, MCP_CONFIG_API_URL) or api_util.CLOUD_API_ROOT resolved_org_id = _resolve_organization_id( organization_id=organization_id, organization_name=organization_name, - api_root=credentials.api_root, - client_id=credentials.client_id, - client_secret=credentials.client_secret, - bearer_token=credentials.bearer_token, + api_root=api_url, + client_id=SecretString(client_id) if client_id else None, + client_secret=SecretString(client_secret) if client_secret else None, + bearer_token=SecretString(bearer_token) if bearer_token else None, ) workspaces = api_util.list_workspaces_in_organization( organization_id=resolved_org_id, - api_root=credentials.api_root, - client_id=credentials.client_id, - client_secret=credentials.client_secret, - bearer_token=credentials.bearer_token, + api_root=api_url, + client_id=SecretString(client_id) if client_id else None, + client_secret=SecretString(client_secret) if client_secret else None, + bearer_token=SecretString(bearer_token) if bearer_token else None, name_contains=name_contains, max_items_limit=max_items_limit, ) @@ -1421,13 +1438,13 @@ def list_cloud_workspaces( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def describe_cloud_organization( + ctx: Context, *, organization_id: Annotated[ str | None, @@ -1451,15 +1468,18 @@ def describe_cloud_organization( Requires either organization_id OR organization_name (exact match) to be provided. This tool is useful for looking up an organization's ID from its name, or vice versa. 
""" - credentials = resolve_cloud_credentials() + bearer_token = get_mcp_config(ctx, MCP_CONFIG_BEARER_TOKEN) + client_id = get_mcp_config(ctx, MCP_CONFIG_CLIENT_ID) + client_secret = get_mcp_config(ctx, MCP_CONFIG_CLIENT_SECRET) + api_url = get_mcp_config(ctx, MCP_CONFIG_API_URL) or api_util.CLOUD_API_ROOT org = _resolve_organization( organization_id=organization_id, organization_name=organization_name, - api_root=credentials.api_root, - client_id=credentials.client_id, - client_secret=credentials.client_secret, - bearer_token=credentials.bearer_token, + api_root=api_url, + client_id=SecretString(client_id) if client_id else None, + client_secret=SecretString(client_secret) if client_secret else None, + bearer_token=SecretString(bearer_token) if bearer_token else None, ) return CloudOrganizationResult( @@ -1484,11 +1504,11 @@ def _get_custom_source_definition_description( @mcp_tool( - domain="cloud", open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def publish_custom_source_definition( + ctx: Context, name: Annotated[ str, Field(description="The name for the custom connector definition."), @@ -1565,14 +1585,14 @@ def publish_custom_source_definition( testing_values_dict: dict[str, Any] | None = None if testing_values is not None or testing_values_secret_name is not None: testing_values_dict = ( - resolve_config( + resolve_connector_config( config=testing_values, config_secret_name=testing_values_secret_name, ) or None ) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) custom_source = workspace.publish_custom_source_definition( name=name, manifest_yaml=processed_manifest, @@ -1591,12 +1611,12 @@ def publish_custom_source_definition( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, ) def list_custom_source_definitions( + ctx: Context, *, workspace_id: Annotated[ str | None, @@ -1611,7 +1631,7 @@ def list_custom_source_definitions( Note: Only YAML (declarative) connectors are currently supported. Docker-based custom sources are not yet available. """ - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) definitions = workspace.list_custom_source_definitions( definition_type="yaml", ) @@ -1628,12 +1648,12 @@ def list_custom_source_definitions( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, ) def get_custom_source_definition( + ctx: Context, definition_id: Annotated[ str, Field(description="The ID of the custom source definition to retrieve."), @@ -1655,7 +1675,7 @@ def get_custom_source_definition( Note: Only YAML (declarative) connectors are currently supported. Docker-based custom sources are not yet available. 
""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) definition = workspace.get_custom_source_definition( definition_id=definition_id, definition_type="yaml", @@ -1672,11 +1692,11 @@ def get_custom_source_definition( @mcp_tool( - domain="cloud", destructive=True, open_world=True, ) def update_custom_source_definition( + ctx: Context, definition_id: Annotated[ str, Field(description="The ID of the definition to update."), @@ -1741,7 +1761,7 @@ def update_custom_source_definition( """ check_guid_created_in_session(definition_id) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) if manifest_yaml is None and testing_values is None and testing_values_secret_name is None: raise PyAirbyteInputError( @@ -1763,7 +1783,7 @@ def update_custom_source_definition( testing_values_dict: dict[str, Any] | None = None if testing_values is not None or testing_values_secret_name is not None: testing_values_dict = ( - resolve_config( + resolve_connector_config( config=testing_values, config_secret_name=testing_values_secret_name, ) @@ -1794,11 +1814,11 @@ def update_custom_source_definition( @mcp_tool( - domain="cloud", destructive=True, open_world=True, ) def permanently_delete_custom_source_definition( + ctx: Context, definition_id: Annotated[ str, Field(description="The ID of the custom source definition to delete."), @@ -1832,7 +1852,7 @@ def permanently_delete_custom_source_definition( Docker-based custom sources are not yet available. """ check_guid_created_in_session(definition_id) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) definition = workspace.get_custom_source_definition( definition_id=definition_id, definition_type="yaml", @@ -1861,12 +1881,12 @@ def permanently_delete_custom_source_definition( @mcp_tool( - domain="cloud", destructive=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def permanently_delete_cloud_source( + ctx: Context, source_id: Annotated[ str, Field(description="The ID of the deployed source to delete."), @@ -1889,7 +1909,7 @@ def permanently_delete_cloud_source( This is a safety measure to ensure you are deleting the correct resource. """ check_guid_created_in_session(source_id) - workspace: CloudWorkspace = _get_cloud_workspace() + workspace: CloudWorkspace = _get_cloud_workspace(ctx) source = workspace.get_source(source_id=source_id) actual_name: str = cast(str, source.name) @@ -1917,12 +1937,12 @@ def permanently_delete_cloud_source( @mcp_tool( - domain="cloud", destructive=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def permanently_delete_cloud_destination( + ctx: Context, destination_id: Annotated[ str, Field(description="The ID of the deployed destination to delete."), @@ -1945,7 +1965,7 @@ def permanently_delete_cloud_destination( This is a safety measure to ensure you are deleting the correct resource. 
""" check_guid_created_in_session(destination_id) - workspace: CloudWorkspace = _get_cloud_workspace() + workspace: CloudWorkspace = _get_cloud_workspace(ctx) destination = workspace.get_destination(destination_id=destination_id) actual_name: str = cast(str, destination.name) @@ -1973,12 +1993,12 @@ def permanently_delete_cloud_destination( @mcp_tool( - domain="cloud", destructive=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def permanently_delete_cloud_connection( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the connection to delete."), @@ -2020,7 +2040,7 @@ def permanently_delete_cloud_connection( This is a safety measure to ensure you are deleting the correct resource. """ check_guid_created_in_session(connection_id) - workspace: CloudWorkspace = _get_cloud_workspace() + workspace: CloudWorkspace = _get_cloud_workspace(ctx) connection = workspace.get_connection(connection_id=connection_id) actual_name: str = cast(str, connection.name) @@ -2050,11 +2070,11 @@ def permanently_delete_cloud_connection( @mcp_tool( - domain="cloud", open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def rename_cloud_source( + ctx: Context, source_id: Annotated[ str, Field(description="The ID of the deployed source to rename."), @@ -2073,19 +2093,19 @@ def rename_cloud_source( ], ) -> str: """Rename a deployed source connector on Airbyte Cloud.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) source = workspace.get_source(source_id=source_id) source.rename(name=name) return f"Successfully renamed source '{source_id}' to '{name}'. URL: {source.connector_url}" @mcp_tool( - domain="cloud", destructive=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def update_cloud_source_config( + ctx: Context, source_id: Annotated[ str, Field(description="The ID of the deployed source to update."), @@ -2118,10 +2138,10 @@ def update_cloud_source_config( configuration is changed incorrectly. Use with caution. """ check_guid_created_in_session(source_id) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) source = workspace.get_source(source_id=source_id) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_secret_name=config_secret_name, config_spec_jsonschema=None, # We don't have the spec here @@ -2132,11 +2152,11 @@ def update_cloud_source_config( @mcp_tool( - domain="cloud", open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def rename_cloud_destination( + ctx: Context, destination_id: Annotated[ str, Field(description="The ID of the deployed destination to rename."), @@ -2155,7 +2175,7 @@ def rename_cloud_destination( ], ) -> str: """Rename a deployed destination connector on Airbyte Cloud.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) destination = workspace.get_destination(destination_id=destination_id) destination.rename(name=name) return ( @@ -2165,12 +2185,12 @@ def rename_cloud_destination( @mcp_tool( - domain="cloud", destructive=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def update_cloud_destination_config( + ctx: Context, destination_id: Annotated[ str, Field(description="The ID of the deployed destination to update."), @@ -2203,10 +2223,10 @@ def update_cloud_destination_config( configuration is changed incorrectly. 
Use with caution. """ check_guid_created_in_session(destination_id) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) destination = workspace.get_destination(destination_id=destination_id) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_secret_name=config_secret_name, config_spec_jsonschema=None, # We don't have the spec here @@ -2219,11 +2239,11 @@ def update_cloud_destination_config( @mcp_tool( - domain="cloud", open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def rename_cloud_connection( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the connection to rename."), @@ -2242,7 +2262,7 @@ def rename_cloud_connection( ], ) -> str: """Rename a connection on Airbyte Cloud.""" - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) connection.rename(name=name) return ( @@ -2252,12 +2272,12 @@ def rename_cloud_connection( @mcp_tool( - domain="cloud", destructive=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def set_cloud_connection_table_prefix( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the connection to update."), @@ -2281,7 +2301,7 @@ def set_cloud_connection_table_prefix( table prefix is changed incorrectly. Use with caution. """ check_guid_created_in_session(connection_id) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) connection.set_table_prefix(prefix=prefix) return ( @@ -2291,12 +2311,12 @@ def set_cloud_connection_table_prefix( @mcp_tool( - domain="cloud", destructive=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def set_cloud_connection_selected_streams( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the connection to update."), @@ -2325,7 +2345,7 @@ def set_cloud_connection_selected_streams( stream selection is changed incorrectly. Use with caution. """ check_guid_created_in_session(connection_id) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) resolved_streams_list: list[str] = resolve_list_of_strings(stream_names) @@ -2338,12 +2358,12 @@ def set_cloud_connection_selected_streams( @mcp_tool( - domain="cloud", open_world=True, destructive=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def update_cloud_connection( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the connection to update."), @@ -2421,7 +2441,7 @@ def update_cloud_connection( "for manual-only syncs." 
) - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) changes_made: list[str] = [] @@ -2448,13 +2468,13 @@ def update_cloud_connection( @mcp_tool( - domain="cloud", read_only=True, idempotent=True, open_world=True, extra_help_text=CLOUD_AUTH_TIP_TEXT, ) def get_connection_artifact( + ctx: Context, connection_id: Annotated[ str, Field(description="The ID of the Airbyte Cloud connection."), @@ -2480,7 +2500,7 @@ def get_connection_artifact( - 'catalog': Returns the configured catalog (syncCatalog) as a dict, or {"ERROR": "..."} if not found. """ - workspace: CloudWorkspace = _get_cloud_workspace(workspace_id) + workspace: CloudWorkspace = _get_cloud_workspace(ctx, workspace_id) connection = workspace.get_connection(connection_id=connection_id) if artifact_type == "state": @@ -2496,14 +2516,14 @@ def get_connection_artifact( return result -def register_cloud_ops_tools(app: FastMCP) -> None: - """@private Register tools with the FastMCP app. - - This is an internal function and should not be called directly. +def register_cloud_tools(app: FastMCP) -> None: + """Register cloud tools with the FastMCP app. - Tools are filtered based on mode settings: - - AIRBYTE_CLOUD_MCP_READONLY_MODE=1: Only read-only tools are registered - - AIRBYTE_CLOUD_MCP_SAFE_MODE=1: All tools are registered, but destructive - operations are protected by runtime session checks + Args: + app: FastMCP application instance """ - register_tools(app, domain="cloud") + register_mcp_tools( + app, + mcp_module=__name__, + exclude_args=["workspace_id"] if AIRBYTE_CLOUD_WORKSPACE_ID_IS_SET else None, + ) diff --git a/airbyte/mcp/local_ops.py b/airbyte/mcp/local.py similarity index 96% rename from airbyte/mcp/local_ops.py rename to airbyte/mcp/local.py index df018a2b4..56216ff95 100644 --- a/airbyte/mcp/local_ops.py +++ b/airbyte/mcp/local.py @@ -8,13 +8,13 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal from fastmcp import FastMCP +from fastmcp_extensions import mcp_tool, register_mcp_tools from pydantic import BaseModel, Field from airbyte import get_source from airbyte._util.meta import is_docker_installed from airbyte.caches.util import get_default_cache -from airbyte.mcp._tool_utils import mcp_tool, register_tools -from airbyte.mcp._util import resolve_config, resolve_list_of_strings +from airbyte.mcp._arg_resolvers import resolve_connector_config, resolve_list_of_strings from airbyte.registry import get_connector_metadata from airbyte.secrets.config import _get_secret_sources from airbyte.secrets.env_vars import DotenvSecretManager @@ -106,7 +106,6 @@ def _get_mcp_source( @mcp_tool( - domain="local", read_only=True, idempotent=True, extra_help_text=_CONFIG_HELP, @@ -167,7 +166,7 @@ def validate_connector_config( return False, f"Failed to get connector '{connector_name}': {ex}" try: - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_file=config_file, config_secret_name=config_secret_name, @@ -186,7 +185,6 @@ def validate_connector_config( @mcp_tool( - domain="local", read_only=True, idempotent=True, ) @@ -216,7 +214,6 @@ def list_connector_config_secrets( @mcp_tool( - domain="local", read_only=True, idempotent=True, extra_help_text=_CONFIG_HELP, @@ -236,7 +233,6 @@ def list_dotenv_secrets() -> dict[str, list[str]]: @mcp_tool( - domain="local", read_only=True, idempotent=True, extra_help_text=_CONFIG_HELP, @@ -292,7 +288,7 
@@ def list_source_streams( override_execution_mode=override_execution_mode, manifest_path=manifest_path, ) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_file=config_file, config_secret_name=config_secret_name, @@ -303,7 +299,6 @@ def list_source_streams( @mcp_tool( - domain="local", read_only=True, idempotent=True, extra_help_text=_CONFIG_HELP, @@ -360,7 +355,7 @@ def get_source_stream_json_schema( override_execution_mode=override_execution_mode, manifest_path=manifest_path, ) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_file=config_file, config_secret_name=config_secret_name, @@ -371,7 +366,6 @@ def get_source_stream_json_schema( @mcp_tool( - domain="local", read_only=True, extra_help_text=_CONFIG_HELP, ) @@ -436,7 +430,7 @@ def read_source_stream_records( override_execution_mode=override_execution_mode, manifest_path=manifest_path, ) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_file=config_file, config_secret_name=config_secret_name, @@ -462,7 +456,6 @@ def read_source_stream_records( @mcp_tool( - domain="local", read_only=True, extra_help_text=_CONFIG_HELP, ) @@ -537,7 +530,7 @@ def get_stream_previews( manifest_path=manifest_path, ) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_file=config_file, config_secret_name=config_secret_name, @@ -575,7 +568,6 @@ def get_stream_previews( @mcp_tool( - domain="local", destructive=False, extra_help_text=_CONFIG_HELP, ) @@ -634,7 +626,7 @@ def sync_source_to_cache( override_execution_mode=override_execution_mode, manifest_path=manifest_path, ) - config_dict = resolve_config( + config_dict = resolve_connector_config( config=config, config_file=config_file, config_secret_name=config_secret_name, @@ -684,7 +676,6 @@ class CachedDatasetInfo(BaseModel): @mcp_tool( - domain="local", read_only=True, idempotent=True, extra_help_text=_CONFIG_HELP, @@ -705,7 +696,6 @@ def list_cached_streams() -> list[CachedDatasetInfo]: @mcp_tool( - domain="local", read_only=True, idempotent=True, extra_help_text=_CONFIG_HELP, @@ -757,7 +747,6 @@ def _is_safe_sql(sql_query: str) -> bool: @mcp_tool( - domain="local", read_only=True, idempotent=True, extra_help_text=_CONFIG_HELP, @@ -815,9 +804,10 @@ def run_sql_query( del cache # Ensure the cache is closed properly -def register_local_ops_tools(app: FastMCP) -> None: - """@private Register tools with the FastMCP app. +def register_local_tools(app: FastMCP) -> None: + """Register local tools with the FastMCP app. - This is an internal function and should not be called directly. 
+ Args: + app: FastMCP application instance """ - register_tools(app, domain="local") + register_mcp_tools(app, mcp_module=__name__) diff --git a/airbyte/mcp/prompts.py b/airbyte/mcp/prompts.py index 87c0f0ab8..47993aef6 100644 --- a/airbyte/mcp/prompts.py +++ b/airbyte/mcp/prompts.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Annotated +from fastmcp_extensions import mcp_prompt, register_mcp_prompts from pydantic import Field @@ -38,6 +39,10 @@ """.strip() +@mcp_prompt( + name="test-my-tools", + description="Test all available MCP tools to confirm they are working properly", +) def test_my_tools_prompt( scope: Annotated[ str | None, @@ -65,8 +70,9 @@ def test_my_tools_prompt( def register_prompts(app: FastMCP) -> None: - """Register all prompts with the FastMCP app.""" - app.prompt( - name="test-my-tools", - description="Test all available MCP tools to confirm they are working properly", - )(test_my_tools_prompt) + """Register prompts with the FastMCP app. + + Args: + app: FastMCP application instance + """ + register_mcp_prompts(app, mcp_module=__name__) diff --git a/airbyte/mcp/connector_registry.py b/airbyte/mcp/registry.py similarity index 95% rename from airbyte/mcp/connector_registry.py rename to airbyte/mcp/registry.py index 7571c199c..77b44f061 100644 --- a/airbyte/mcp/connector_registry.py +++ b/airbyte/mcp/registry.py @@ -9,12 +9,12 @@ import requests from fastmcp import FastMCP +from fastmcp_extensions import mcp_tool, register_mcp_tools from pydantic import BaseModel, Field from airbyte import exceptions as exc from airbyte._util.meta import is_docker_installed -from airbyte.mcp._tool_utils import mcp_tool, register_tools -from airbyte.mcp._util import resolve_list_of_strings +from airbyte.mcp._arg_resolvers import resolve_list_of_strings from airbyte.registry import ( _DEFAULT_MANIFEST_URL, ApiDocsUrl, @@ -33,7 +33,6 @@ @mcp_tool( - domain="registry", read_only=True, idempotent=True, ) @@ -131,7 +130,6 @@ class ConnectorInfo(BaseModel): @mcp_tool( - domain="registry", read_only=True, idempotent=True, ) @@ -176,7 +174,6 @@ def get_connector_info( @mcp_tool( - domain="registry", read_only=True, idempotent=True, ) @@ -204,7 +201,6 @@ def get_api_docs_urls( @mcp_tool( - domain="registry", read_only=True, idempotent=True, ) @@ -268,9 +264,10 @@ def get_connector_version_history( return versions -def register_connector_registry_tools(app: FastMCP) -> None: - """@private Register tools with the FastMCP app. +def register_registry_tools(app: FastMCP) -> None: + """Register registry tools with the FastMCP app. - This is an internal function and should not be called directly. + Args: + app: FastMCP application instance """ - register_tools(app, domain="registry") + register_mcp_tools(app, mcp_module=__name__) diff --git a/airbyte/mcp/server.py b/airbyte/mcp/server.py index 4637a3811..5464e060b 100644 --- a/airbyte/mcp/server.py +++ b/airbyte/mcp/server.py @@ -1,17 +1,31 @@ # Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
"""Experimental MCP (Model Context Protocol) server for PyAirbyte connector management.""" +from __future__ import annotations + import asyncio import sys -from fastmcp import FastMCP +from fastmcp_extensions import mcp_server from airbyte._util.meta import set_mcp_mode -from airbyte.mcp._util import initialize_secrets -from airbyte.mcp.cloud_ops import register_cloud_ops_tools -from airbyte.mcp.connector_registry import register_connector_registry_tools -from airbyte.mcp.local_ops import register_local_ops_tools +from airbyte.mcp._config import load_secrets_to_env_vars +from airbyte.mcp._tool_utils import ( + AIRBYTE_EXCLUDE_MODULES_CONFIG_ARG, + AIRBYTE_INCLUDE_MODULES_CONFIG_ARG, + AIRBYTE_READONLY_MODE_CONFIG_ARG, + API_URL_CONFIG_ARG, + BEARER_TOKEN_CONFIG_ARG, + CLIENT_ID_CONFIG_ARG, + CLIENT_SECRET_CONFIG_ARG, + WORKSPACE_ID_CONFIG_ARG, + airbyte_module_filter, + airbyte_readonly_mode_filter, +) +from airbyte.mcp.cloud import register_cloud_tools +from airbyte.mcp.local import register_local_tools from airbyte.mcp.prompts import register_prompts +from airbyte.mcp.registry import register_registry_tools # ============================================================================= @@ -49,14 +63,34 @@ """.strip() set_mcp_mode() -initialize_secrets() - -app: FastMCP = FastMCP("airbyte-mcp", instructions=MCP_SERVER_INSTRUCTIONS) +load_secrets_to_env_vars() + +app = mcp_server( + name="airbyte-mcp", + package_name="airbyte", + instructions=MCP_SERVER_INSTRUCTIONS, + include_standard_tool_filters=True, + server_config_args=[ + AIRBYTE_READONLY_MODE_CONFIG_ARG, + AIRBYTE_EXCLUDE_MODULES_CONFIG_ARG, + AIRBYTE_INCLUDE_MODULES_CONFIG_ARG, + WORKSPACE_ID_CONFIG_ARG, + BEARER_TOKEN_CONFIG_ARG, + CLIENT_ID_CONFIG_ARG, + CLIENT_SECRET_CONFIG_ARG, + API_URL_CONFIG_ARG, + ], + tool_filters=[ + airbyte_readonly_mode_filter, + airbyte_module_filter, + ], +) """The Airbyte MCP Server application instance.""" -register_connector_registry_tools(app) -register_local_ops_tools(app) -register_cloud_ops_tools(app) +# Register tools from each module +register_cloud_tools(app) +register_local_tools(app) +register_registry_tools(app) register_prompts(app) diff --git a/bin/test_mcp_tool.py b/bin/test_mcp_tool.py deleted file mode 100755 index d92c32f89..000000000 --- a/bin/test_mcp_tool.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2024 Airbyte, Inc., all rights reserved. -"""One-liner CLI tool for testing PyAirbyte MCP tools directly with JSON arguments. 
- -Usage: - poe mcp-tool-test '' - -Examples: - poe mcp-tool-test list_connectors '{}' - poe mcp-tool-test get_config_spec '{"connector_name": "source-pokeapi"}' - poe mcp-tool-test validate_config \ - '{"connector_name": "source-pokeapi", "config": {"pokemon_name": "pikachu"}}' - poe mcp-tool-test run_sync \ - '{"connector_name": "source-pokeapi", "config": {"pokemon_name": "pikachu"}}' - - poe mcp-tool-test check_airbyte_cloud_workspace '{}' - poe mcp-tool-test list_deployed_cloud_connections '{}' - poe mcp-tool-test get_cloud_sync_status \ - '{"connection_id": "0791e193-811b-4fcf-91c3-f8c5963e74a0", "include_attempts": true}' - poe mcp-tool-test get_cloud_sync_logs \ - '{"connection_id": "0791e193-811b-4fcf-91c3-f8c5963e74a0"}' -""" - -import asyncio -import json -import sys -import traceback -from typing import Any - -from fastmcp import Client - -from airbyte.mcp.server import app - - -MIN_ARGS = 3 - - -async def call_mcp_tool(tool_name: str, args: dict[str, Any]) -> object: - """Call an MCP tool using the FastMCP client.""" - async with Client(app) as client: - return await client.call_tool(tool_name, args) - - -def main() -> None: - """Main entry point for the MCP tool tester.""" - if len(sys.argv) < MIN_ARGS: - print(__doc__, file=sys.stderr) - sys.exit(1) - - tool_name = sys.argv[1] - json_args = sys.argv[2] - - try: - args: dict[str, Any] = json.loads(json_args) - except json.JSONDecodeError as e: - print(f"Error parsing JSON arguments: {e}", file=sys.stderr) - sys.exit(1) - - try: - result = asyncio.run(call_mcp_tool(tool_name, args)) - - if hasattr(result, "text"): - print(result.text) - else: - print(str(result)) - - except Exception as e: - print(f"Error executing tool '{tool_name}': {e}", file=sys.stderr) - traceback.print_exc() - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index c453607f9..64963bc85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "typing-extensions", "uuid7>=0.1.0,<1.0", "fastmcp>=2.11.3,<3.0.0", + "fastmcp-extensions>=0.2.0,<1.0.0", "uv>=0.5.0,<0.9.0", ] @@ -171,7 +172,7 @@ mcp-serve-local = { cmd = "airbyte-mcp", help = "Start the MCP server with STDIO mcp-serve-http = { cmd = "python -c \"from airbyte.mcp.server import app; app.run(transport='http', host='127.0.0.1', port=8000)\"", help = "Start the MCP server with HTTP transport" } mcp-serve-sse = { cmd = "python -c \"from airbyte.mcp.server import app; app.run(transport='sse', host='127.0.0.1', port=8000)\"", help = "Start the MCP server with SSE transport" } mcp-inspect = { cmd = "fastmcp inspect airbyte/mcp/server.py:app", help = "Inspect MCP tools and resources (supports --tools, --health, etc.)" } -mcp-tool-test = { cmd = "python bin/test_mcp_tool.py", help = "Test MCP tools directly with JSON arguments: poe mcp-tool-test ''" } +mcp-tool-test = { cmd = "python -m fastmcp_extensions.utils.test_tool --app airbyte.mcp.server:app", help = "Test MCP tools directly with JSON arguments: poe mcp-tool-test ''" } # Claude Code MCP Testing Tasks [tool.poe.tasks.test-my-tools] diff --git a/tests/unit_tests/test_mcp_connector_registry.py b/tests/unit_tests/test_mcp_connector_registry.py index 5f70ea167..63476619f 100644 --- a/tests/unit_tests/test_mcp_connector_registry.py +++ b/tests/unit_tests/test_mcp_connector_registry.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch from airbyte import exceptions as exc -from airbyte.mcp.connector_registry import get_api_docs_urls +from airbyte.mcp.registry import 
get_api_docs_urls from airbyte.registry import ( ApiDocsUrl, _fetch_manifest_dict, @@ -132,9 +132,7 @@ class TestGetApiDocsUrls: def test_connector_not_found(self) -> None: """Test handling when connector is not found.""" - with patch( - "airbyte.mcp.connector_registry.get_connector_api_docs_urls" - ) as mock_get_docs: + with patch("airbyte.mcp.registry.get_connector_api_docs_urls") as mock_get_docs: mock_get_docs.side_effect = exc.AirbyteConnectorNotRegisteredError( connector_name="nonexistent-connector", context={}, @@ -145,9 +143,7 @@ def test_connector_not_found(self) -> None: def test_deduplication_of_urls(self) -> None: """Test that duplicate URLs are deduplicated.""" - with patch( - "airbyte.mcp.connector_registry.get_connector_api_docs_urls" - ) as mock_get_docs: + with patch("airbyte.mcp.registry.get_connector_api_docs_urls") as mock_get_docs: mock_get_docs.return_value = [ ApiDocsUrl( title="Airbyte Documentation", diff --git a/tests/unit_tests/test_mcp_tool_utils.py b/tests/unit_tests/test_mcp_tool_utils.py index d0c8792d7..e8013b377 100644 --- a/tests/unit_tests/test_mcp_tool_utils.py +++ b/tests/unit_tests/test_mcp_tool_utils.py @@ -3,151 +3,66 @@ from __future__ import annotations -import importlib -import warnings from unittest.mock import patch import pytest -import airbyte.constants as constants -import airbyte.mcp._tool_utils as tool_utils -from airbyte.mcp._annotations import READ_ONLY_HINT - -# (enabled, disabled, domain, readonly_mode, is_readonly, domain_enabled, should_register) -_DOMAIN_CASES = [ - (None, None, "cloud", False, False, True, True), - (None, None, "registry", False, False, True, True), - (None, None, "local", False, False, True, True), - (["cloud"], None, "cloud", False, False, True, True), - (["cloud"], None, "registry", False, False, False, False), - (None, ["registry"], "registry", False, False, False, False), - (None, ["registry"], "cloud", False, False, True, True), - (["registry", "cloud"], ["registry"], "cloud", False, False, True, True), - (["registry", "cloud"], ["registry"], "registry", False, False, False, False), - (["cloud"], ["registry"], "local", False, False, False, False), - (["CLOUD"], None, "cloud", False, False, True, True), - (["cloud"], None, "CLOUD", False, False, True, True), - (None, None, "cloud", True, False, True, False), - (None, None, "cloud", True, True, True, True), - (None, None, "registry", True, False, True, True), - (["cloud"], None, "cloud", True, True, True, True), - (["registry"], None, "cloud", True, True, False, False), -] - - -@pytest.mark.parametrize( - "enabled,disabled,domain,readonly_mode,is_readonly,domain_enabled,should_register", - _DOMAIN_CASES, +from airbyte.mcp._tool_utils import ( + SafeModeError, + _GUIDS_CREATED_IN_SESSION, + check_guid_created_in_session, + register_guid_created_in_session, ) -def test_domain_logic( - enabled: list[str] | None, - disabled: list[str] | None, - domain: str, - readonly_mode: bool, - is_readonly: bool, - domain_enabled: bool, - should_register: bool, -) -> None: - norm_enabled = [d.lower() for d in enabled] if enabled else None - norm_disabled = [d.lower() for d in disabled] if disabled else None - with ( - patch("airbyte.mcp._tool_utils.AIRBYTE_MCP_DOMAINS", norm_enabled), - patch("airbyte.mcp._tool_utils.AIRBYTE_MCP_DOMAINS_DISABLED", norm_disabled), - patch("airbyte.mcp._tool_utils.AIRBYTE_CLOUD_MCP_READONLY_MODE", readonly_mode), - ): - tool_utils._resolve_mcp_domain_filters.cache_clear() - assert tool_utils.is_domain_enabled(domain) == domain_enabled - assert ( - 
tool_utils.should_register_tool({ - "domain": domain, - READ_ONLY_HINT: is_readonly, - }) - == should_register - ) - - -# (env_var, attr, env_value, expected) -_ENV_PARSE_CASES = [ - ("AIRBYTE_MCP_DOMAINS", "AIRBYTE_MCP_DOMAINS", "", None), - ("AIRBYTE_MCP_DOMAINS", "AIRBYTE_MCP_DOMAINS", "cloud", ["cloud"]), - ( - "AIRBYTE_MCP_DOMAINS", - "AIRBYTE_MCP_DOMAINS", - "registry,cloud", - ["registry", "cloud"], - ), - ( - "AIRBYTE_MCP_DOMAINS", - "AIRBYTE_MCP_DOMAINS", - "registry, cloud", - ["registry", "cloud"], - ), - ( - "AIRBYTE_MCP_DOMAINS", - "AIRBYTE_MCP_DOMAINS", - "REGISTRY,CLOUD", - ["registry", "cloud"], - ), - ( - "AIRBYTE_MCP_DOMAINS", - "AIRBYTE_MCP_DOMAINS", - "registry,,cloud", - ["registry", "cloud"], - ), - ("AIRBYTE_MCP_DOMAINS_DISABLED", "AIRBYTE_MCP_DOMAINS_DISABLED", "", None), - ( - "AIRBYTE_MCP_DOMAINS_DISABLED", - "AIRBYTE_MCP_DOMAINS_DISABLED", - "registry", - ["registry"], - ), - ( - "AIRBYTE_MCP_DOMAINS_DISABLED", - "AIRBYTE_MCP_DOMAINS_DISABLED", - "registry,local", - ["registry", "local"], - ), -] - - -@pytest.mark.parametrize("env_var,attr,env_value,expected", _ENV_PARSE_CASES) -def test_env_parsing( - env_var: str, attr: str, env_value: str, expected: list[str] | None -) -> None: - with patch.dict("os.environ", {env_var: env_value}, clear=False): - importlib.reload(constants) - assert getattr(constants, attr) == expected - importlib.reload(constants) - - -# (env_var, env_value, warning_fragment) -_WARNING_CASES = [ - ( - "AIRBYTE_MCP_DOMAINS", - "cloud,invalid", - "AIRBYTE_MCP_DOMAINS contains unknown domain(s)", - ), - ( - "AIRBYTE_MCP_DOMAINS_DISABLED", - "registry,fake", - "AIRBYTE_MCP_DOMAINS_DISABLED contains unknown domain(s)", - ), -] - - -@pytest.mark.parametrize("env_var,env_value,fragment", _WARNING_CASES) -def test_unknown_domain_warning(env_var: str, env_value: str, fragment: str) -> None: - with ( - patch.dict("os.environ", {env_var: env_value}, clear=False), - warnings.catch_warnings(record=True) as caught, - ): - warnings.simplefilter("always") - importlib.reload(constants) - importlib.reload(tool_utils) - tool_utils._resolve_mcp_domain_filters.cache_clear() - tool_utils._resolve_mcp_domain_filters() - messages = [str(w.message) for w in caught] - assert any(fragment in m for m in messages) - assert any("Known MCP domains are:" in m for m in messages) - importlib.reload(constants) - importlib.reload(tool_utils) + + +@pytest.fixture(autouse=True) +def clear_session_guids() -> None: + """Clear the session GUIDs before each test.""" + _GUIDS_CREATED_IN_SESSION.clear() + + +def test_register_guid_created_in_session() -> None: + """Test that GUIDs can be registered as created in session.""" + assert "test-guid-123" not in _GUIDS_CREATED_IN_SESSION + register_guid_created_in_session("test-guid-123") + assert "test-guid-123" in _GUIDS_CREATED_IN_SESSION + + +def test_check_guid_created_in_session_passes_for_registered_guid() -> None: + """Test that check passes for GUIDs registered in session.""" + register_guid_created_in_session("test-guid-456") + # Should not raise + check_guid_created_in_session("test-guid-456") + + +def test_check_guid_created_in_session_raises_for_unregistered_guid() -> None: + """Test that check raises SafeModeError for unregistered GUIDs when safe mode is enabled.""" + with patch("airbyte.mcp._tool_utils.AIRBYTE_CLOUD_MCP_SAFE_MODE", True): + with pytest.raises(SafeModeError) as exc_info: + check_guid_created_in_session("unregistered-guid") + assert "unregistered-guid" in str(exc_info.value) + assert "not created in this session" 
in str(exc_info.value) + + +def test_check_guid_created_in_session_passes_when_safe_mode_disabled() -> None: + """Test that check passes for any GUID when safe mode is disabled.""" + with patch("airbyte.mcp._tool_utils.AIRBYTE_CLOUD_MCP_SAFE_MODE", False): + # Should not raise even for unregistered GUID + check_guid_created_in_session("any-guid-at-all") + + +def test_multiple_guids_can_be_registered() -> None: + """Test that multiple GUIDs can be registered in the same session.""" + guids = ["guid-1", "guid-2", "guid-3"] + for guid in guids: + register_guid_created_in_session(guid) + + for guid in guids: + assert guid in _GUIDS_CREATED_IN_SESSION + + +def test_duplicate_guid_registration_is_idempotent() -> None: + """Test that registering the same GUID multiple times is safe.""" + register_guid_created_in_session("duplicate-guid") + register_guid_created_in_session("duplicate-guid") + assert "duplicate-guid" in _GUIDS_CREATED_IN_SESSION diff --git a/uv.lock b/uv.lock index e0ff68e18..bf02e74e1 100644 --- a/uv.lock +++ b/uv.lock @@ -122,6 +122,7 @@ dependencies = [ { name = "duckdb" }, { name = "duckdb-engine" }, { name = "fastmcp" }, + { name = "fastmcp-extensions" }, { name = "google-auth" }, { name = "google-cloud-bigquery" }, { name = "google-cloud-bigquery-storage" }, @@ -186,6 +187,7 @@ requires-dist = [ { name = "duckdb", specifier = "==1.4.3" }, { name = "duckdb-engine", specifier = "==0.17.0" }, { name = "fastmcp", specifier = ">=2.11.3,<3.0.0" }, + { name = "fastmcp-extensions", specifier = ">=0.2.0,<1.0.0" }, { name = "google-auth", specifier = ">=2.27.0,<3.0" }, { name = "google-cloud-bigquery", specifier = ">=3.12.0,<4.0" }, { name = "google-cloud-bigquery-storage", specifier = ">=2.25.0,<3.0" }, @@ -1120,6 +1122,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/dc/f7dd14213bf511690dccaa5094d436947c253b418c86c86211d1c76e6e44/fastmcp-2.14.3-py3-none-any.whl", hash = "sha256:103c6b4c6e97a9acc251c81d303f110fe4f2bdba31353df515d66272bf1b9414", size = 416220 }, ] +[[package]] +name = "fastmcp-extensions" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastmcp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/38/95549e7bb6bfe6ffedb5c388488b7c81b6a70d5e649d5a08047687a0e811/fastmcp_extensions-0.2.0.tar.gz", hash = "sha256:c456d4d00a96d9fe41b630e51cc6cb4b9920796e6943185e797669d10fe7e917", size = 156381 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/bc/0d2edadb8629afaaec9f05fe05ca9184256d8f046601f71d07cbc8ff8aeb/fastmcp_extensions-0.2.0-py3-none-any.whl", hash = "sha256:b48f13ecfbceb8e5bc75569e41029f451efa2f0f69b390cf2ad23ed41d1160e0", size = 34431 }, +] + [[package]] name = "filelock" version = "3.20.3"
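
With bin/test_mcp_tool.py removed in favor of the fastmcp_extensions test runner, individual tools can still be exercised in-process through the FastMCP client, as the deleted helper did. A minimal sketch, assuming the server app imports cleanly in the local environment (the tool name and empty argument dict are illustrative examples taken from the old helper's docstring):

# In-process smoke test for a single MCP tool, mirroring the removed
# bin/test_mcp_tool.py helper. The tool name and arguments are examples.
import asyncio

from fastmcp import Client

from airbyte.mcp.server import app


async def main() -> None:
    # Open a client session directly against the server app (no transport).
    async with Client(app) as client:
        result = await client.call_tool("list_connectors", {})
        print(result)


if __name__ == "__main__":
    asyncio.run(main())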
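
The three tool modules now share one registration pattern: module-level functions decorated with fastmcp_extensions.mcp_tool, collected by register_mcp_tools. A hedged sketch of a hypothetical new tool module following that convention (the example tool, its annotations, and the registration function name are illustrative, not part of this change; only the decorator keywords and register_mcp_tools call shape are taken from the modules above):

# Hypothetical tool module following the fastmcp_extensions conventions
# adopted in this change. `echo_connector_name` is illustrative only.
from __future__ import annotations

from typing import Annotated

from fastmcp import FastMCP
from fastmcp_extensions import mcp_tool, register_mcp_tools
from pydantic import Field


@mcp_tool(
    read_only=True,
    idempotent=True,
)
def echo_connector_name(
    connector_name: Annotated[
        str,
        Field(description="The connector name to echo back."),
    ],
) -> str:
    """Echo the connector name (illustrative example tool)."""
    return connector_name


def register_example_tools(app: FastMCP) -> None:
    """Register example tools with the FastMCP app.

    Args:
        app: FastMCP application instance
    """
    register_mcp_tools(app, mcp_module=__name__)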
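
The rewritten test_mcp_tool_utils.py exercises the session-GUID helpers directly. The guard pattern those tests imply — record IDs created in the current session, then check them before a destructive call — can be sketched as follows (the wrapper function names are illustrative; only the imported helpers come from this change):

# Sketch of the safe-mode guard pattern implied by the tests above: a
# create-style tool records the new resource's ID, and a destructive tool
# checks it before acting. When AIRBYTE_CLOUD_MCP_SAFE_MODE is enabled,
# check_guid_created_in_session raises SafeModeError for IDs that were
# not created in this session. Wrapper functions here are illustrative.
from airbyte.mcp._tool_utils import (
    check_guid_created_in_session,
    register_guid_created_in_session,
)


def record_created_resource(resource_id: str) -> None:
    # Call immediately after a tool creates a resource in this session.
    register_guid_created_in_session(resource_id)


def guard_destructive_call(resource_id: str) -> None:
    # Call at the top of a destructive tool; in safe mode this raises
    # SafeModeError if the ID was not created in the current session.
    check_guid_created_in_session(resource_id)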