diff --git a/a2a/weather_service/src/weather_service/agent.py b/a2a/weather_service/src/weather_service/agent.py index 0d289728..745c7677 100644 --- a/a2a/weather_service/src/weather_service/agent.py +++ b/a2a/weather_service/src/weather_service/agent.py @@ -1,5 +1,6 @@ import logging import os +import re from textwrap import dedent import uvicorn @@ -14,6 +15,7 @@ from a2a.server.tasks import InMemoryTaskStore, TaskUpdater from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart from a2a.utils import new_agent_text_message, new_task +from weather_service.configuration import Configuration from weather_service.graph import get_graph, get_mcpclient from weather_service.observability import ( create_tracing_middleware, @@ -21,7 +23,53 @@ set_span_output, ) -logging.basicConfig(level=logging.DEBUG) + +class SecretRedactionFilter(logging.Filter): + """Redacts Bearer tokens and API keys from log messages. + + Covers three layers: + 1. Bearer tokens in Authorization headers (any format). + 2. OpenAI-style ``sk-*`` API keys (pattern-based). + 3. The literal configured ``LLM_API_KEY`` value (provider-agnostic, + catches non-standard key formats like RHOAI MaaS 32-char keys). + """ + + _BEARER_RE = re.compile(r"(Bearer\s+)\S+", re.IGNORECASE) + _API_KEY_RE = re.compile(r"(sk-[a-zA-Z0-9]{3})[a-zA-Z0-9]+") + + def __init__(self, name: str = ""): + super().__init__(name) + # Redact the literal configured key when it's long enough to be real + configured_key = os.environ.get("LLM_API_KEY", "").strip() + if len(configured_key) > 8: + self._literal_key_re: re.Pattern | None = re.compile(re.escape(configured_key)) + else: + self._literal_key_re = None + + def _redact(self, text: str) -> str: + text = self._BEARER_RE.sub(r"\1[REDACTED]", text) + text = self._API_KEY_RE.sub(r"\1...[REDACTED]", text) + if self._literal_key_re is not None: + text = self._literal_key_re.sub("[REDACTED]", text) + return text + + def filter(self, record: logging.LogRecord) -> bool: + if isinstance(record.msg, str): + record.msg = self._redact(record.msg) + if record.args: + args = record.args if isinstance(record.args, tuple) else (record.args,) + new_args = [] + for arg in args: + if isinstance(arg, str): + arg = self._redact(arg) + new_args.append(arg) + record.args = tuple(new_args) + return True + + +logging.basicConfig(level=logging.INFO) +# Apply secret redaction filter to the root logger so all loggers benefit +logging.getLogger().addFilter(SecretRedactionFilter()) logger = logging.getLogger(__name__) @@ -111,6 +159,17 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): task_updater = TaskUpdater(event_queue, task.id, task.context_id) event_emitter = A2AEvent(task_updater) + # Check API key before attempting LLM calls + config = Configuration() + if not config.has_valid_api_key: + await event_emitter.emit_event( + "Error: No LLM API key configured. Please set the LLM_API_KEY " + "environment variable (or charts.kagenti.values.secrets.openaiApiKey " + "during Kagenti installation).", + failed=True, + ) + return + # Get user input for the agent user_input = context.get_user_input() @@ -216,14 +275,4 @@ def run(): # Add tracing middleware - creates root span with MLflow/GenAI attributes app.add_middleware(BaseHTTPMiddleware, dispatch=create_tracing_middleware()) - # Add logging middleware - @app.middleware("http") - async def log_authorization_header(request, call_next): - auth_header = request.headers.get("authorization", "No Authorization header") - logger.info( - f"🔐 Incoming request to {request.url.path} with Authorization: {auth_header[:80] + '...' if len(auth_header) > 80 else auth_header}" - ) - response = await call_next(request) - return response - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/a2a/weather_service/src/weather_service/configuration.py b/a2a/weather_service/src/weather_service/configuration.py index dd1c9f4c..7a2e9ef9 100644 --- a/a2a/weather_service/src/weather_service/configuration.py +++ b/a2a/weather_service/src/weather_service/configuration.py @@ -1,7 +1,44 @@ +import logging + from pydantic_settings import BaseSettings +logger = logging.getLogger(__name__) + +_PLACEHOLDER_KEYS = {"dummy", "changeme", "your-api-key-here", ""} + +# API bases that are known to accept placeholder/dummy keys (local LLMs) +_LOCAL_LLM_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0"} + class Configuration(BaseSettings): llm_model: str = "llama3.1" llm_api_base: str = "http://localhost:11434/v1" llm_api_key: str = "dummy" + + @property + def is_local_llm(self) -> bool: + """Check if the API base points to a local LLM (e.g. Ollama).""" + from urllib.parse import urlparse + + parsed = urlparse(self.llm_api_base) + hostname = parsed.hostname or "" + return hostname in _LOCAL_LLM_HOSTS + + @property + def has_valid_api_key(self) -> bool: + """Check if the API key is usable. + + Local LLMs (Ollama, vLLM, etc.) accept any key including placeholders, + so placeholder keys are only flagged when pointing at a remote API. + """ + if self.is_local_llm: + return True + return self.llm_api_key.strip() not in _PLACEHOLDER_KEYS + + def log_warnings(self) -> None: + """Log warnings about configuration issues at startup.""" + if not self.has_valid_api_key: + logger.warning( + "No LLM API key configured (set LLM_API_KEY env var). " + "The weather agent will not be able to call the LLM." + ) diff --git a/a2a/weather_service/src/weather_service/graph.py b/a2a/weather_service/src/weather_service/graph.py index cd9e8cf2..b9d34c59 100644 --- a/a2a/weather_service/src/weather_service/graph.py +++ b/a2a/weather_service/src/weather_service/graph.py @@ -9,6 +9,7 @@ from weather_service.configuration import Configuration config = Configuration() +config.log_warnings() # Extend MessagesState to include a final answer diff --git a/tests/a2a/test_weather_secret_redaction.py b/tests/a2a/test_weather_secret_redaction.py new file mode 100644 index 00000000..1c9ef348 --- /dev/null +++ b/tests/a2a/test_weather_secret_redaction.py @@ -0,0 +1,240 @@ +"""Tests for secret redaction and API key validation in the weather service. + +Loads agent.py and configuration.py in isolation (same approach as +test_weather_service.py) to avoid pulling in heavy deps like opentelemetry. +""" + +import importlib.util +import logging +import pathlib +import sys +from types import ModuleType +from unittest.mock import MagicMock + +# --- Isolation setup (must happen before any weather_service imports) --- +_fake_ws = ModuleType("weather_service") +_fake_ws.__path__ = [] # type: ignore[attr-defined] +sys.modules.setdefault("weather_service", _fake_ws) +sys.modules.setdefault("weather_service.observability", MagicMock()) + +_BASE = pathlib.Path(__file__).parent.parent.parent / "a2a" / "weather_service" / "src" / "weather_service" + + +def _load_module(name: str, path: pathlib.Path): + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) # type: ignore[arg-type] + sys.modules[name] = mod + spec.loader.exec_module(mod) # type: ignore[union-attr] + return mod + + +_config_mod = _load_module("weather_service.configuration", _BASE / "configuration.py") + +# Mock modules that agent.py imports but we don't need +for mod_name in [ + "uvicorn", + "langchain_core", + "langchain_core.messages", + "starlette", + "starlette.middleware", + "starlette.middleware.base", + "starlette.routing", + "a2a", + "a2a.server", + "a2a.server.agent_execution", + "a2a.server.apps", + "a2a.server.events", + "a2a.server.events.event_queue", + "a2a.server.request_handlers", + "a2a.server.tasks", + "a2a.types", + "a2a.utils", + "weather_service.graph", +]: + sys.modules.setdefault(mod_name, MagicMock()) + +_agent_mod = _load_module("weather_service.agent", _BASE / "agent.py") + +Configuration = _config_mod.Configuration +SecretRedactionFilter = _agent_mod.SecretRedactionFilter + + +# --- Tests --- + + +class TestSecretRedactionFilter: + """Test the logging filter that redacts Bearer tokens and API keys.""" + + def setup_method(self): + self.filt = SecretRedactionFilter() + + def _make_record(self, msg: str, args=None) -> logging.LogRecord: + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg=msg, + args=args, + exc_info=None, + ) + return record + + def test_redacts_bearer_token_in_msg(self): + record = self._make_record("Authorization: Bearer sk-abc123xyz789secret") + self.filt.filter(record) + assert "sk-abc123xyz789secret" not in record.msg + assert "[REDACTED]" in record.msg + + def test_redacts_bearer_token_case_insensitive(self): + record = self._make_record("header: bearer my-secret-token-value") + self.filt.filter(record) + assert "my-secret-token-value" not in record.msg + assert "[REDACTED]" in record.msg + + def test_redacts_openai_api_key_pattern(self): + record = self._make_record("Using key sk-proj1234567890abcdefghijklmnop") + self.filt.filter(record) + assert "1234567890abcdefghijklmnop" not in record.msg + assert "sk-pro...[REDACTED]" in record.msg + + def test_preserves_non_secret_messages(self): + record = self._make_record("Processing weather request for New York") + self.filt.filter(record) + assert record.msg == "Processing weather request for New York" + + def test_redacts_bearer_in_args(self): + record = self._make_record("Header: %s", ("Bearer sk-abc123xyz789secret",)) + self.filt.filter(record) + assert "sk-abc123xyz789secret" not in record.args[0] + assert "[REDACTED]" in record.args[0] + + def test_always_returns_true(self): + """Filter should never suppress log records, only redact content.""" + record = self._make_record("Bearer secret123") + assert self.filt.filter(record) is True + + def test_redacts_literal_configured_key(self, monkeypatch): + """Non-sk-* keys (e.g. RHOAI MaaS 32-char keys) are redacted via LLM_API_KEY.""" + rhoai_key = "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6" + monkeypatch.setenv("LLM_API_KEY", rhoai_key) + filt = SecretRedactionFilter() + record = self._make_record(f"Sending request with api-key={rhoai_key}") + filt.filter(record) + assert rhoai_key not in record.msg + assert "[REDACTED]" in record.msg + + def test_literal_key_redaction_in_args(self, monkeypatch): + """Literal key redaction also applies to log record args.""" + rhoai_key = "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6" + monkeypatch.setenv("LLM_API_KEY", rhoai_key) + filt = SecretRedactionFilter() + record = self._make_record("key=%s", (rhoai_key,)) + filt.filter(record) + assert rhoai_key not in record.args[0] + assert "[REDACTED]" in record.args[0] + + def test_short_key_not_literal_redacted(self, monkeypatch): + """Short keys (<=8 chars like 'dummy') should not trigger literal redaction.""" + monkeypatch.setenv("LLM_API_KEY", "dummy") + filt = SecretRedactionFilter() + record = self._make_record("Using dummy config for testing dummy values") + filt.filter(record) + # "dummy" should NOT be redacted — it's too short/common + assert "dummy" in record.msg + + def test_no_literal_key_when_unset(self, monkeypatch): + """No crash when LLM_API_KEY is not set.""" + monkeypatch.delenv("LLM_API_KEY", raising=False) + filt = SecretRedactionFilter() + record = self._make_record("Normal log message") + filt.filter(record) + assert record.msg == "Normal log message" + + +class TestConfigurationApiKeyValidation: + """Test API key validation logic.""" + + def test_dummy_key_with_remote_api_is_invalid(self, monkeypatch): + """Dummy key pointing at OpenAI should be flagged.""" + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + monkeypatch.setenv("LLM_API_KEY", "dummy") + config = Configuration() + assert config.has_valid_api_key is False + + def test_dummy_key_with_ollama_is_valid(self): + """Default config (Ollama on localhost) should work with dummy key.""" + config = Configuration() + assert config.is_local_llm is True + assert config.has_valid_api_key is True + + def test_empty_key_with_remote_api_is_invalid(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + monkeypatch.setenv("LLM_API_KEY", "") + config = Configuration() + assert config.has_valid_api_key is False + + def test_placeholder_keys_with_remote_api_are_invalid(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + for placeholder in ["changeme", "your-api-key-here"]: + monkeypatch.setenv("LLM_API_KEY", placeholder) + config = Configuration() + assert config.has_valid_api_key is False, f"'{placeholder}' should be invalid" + + def test_placeholder_keys_with_local_llm_are_valid(self, monkeypatch): + """Local LLMs (Ollama, vLLM) accept any key — don't block them.""" + monkeypatch.setenv("LLM_API_BASE", "http://localhost:11434/v1") + for placeholder in ["dummy", "changeme", ""]: + monkeypatch.setenv("LLM_API_KEY", placeholder) + config = Configuration() + assert config.has_valid_api_key is True, f"'{placeholder}' with local LLM should be valid" + + def test_real_key_is_valid(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + monkeypatch.setenv("LLM_API_KEY", "sk-proj-realkey123") + config = Configuration() + assert config.has_valid_api_key is True + + def test_rhoai_maas_key_is_valid(self, monkeypatch): + """RHOAI MaaS uses non-sk-* 32-char alphanumeric keys — should be valid.""" + monkeypatch.setenv( + "LLM_API_BASE", + "https://deepseek-r1-qwen-14b-w4a16--maas-apicast-production.apps.prod.rhoai.rh-aiservices-bu.com:443/v1", + ) + monkeypatch.setenv("LLM_API_KEY", "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6") + config = Configuration() + assert config.is_local_llm is False + assert config.has_valid_api_key is True + + def test_is_local_llm_with_127(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "http://127.0.0.1:8080/v1") + config = Configuration() + assert config.is_local_llm is True + + def test_is_not_local_llm_with_remote_host(self, monkeypatch): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + config = Configuration() + assert config.is_local_llm is False + + def test_log_warnings_with_dummy_key_remote(self, monkeypatch, caplog): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + monkeypatch.setenv("LLM_API_KEY", "dummy") + config = Configuration() + with caplog.at_level(logging.WARNING): + config.log_warnings() + assert "No LLM API key configured" in caplog.text + + def test_no_warning_with_ollama_dummy_key(self, caplog): + """Default Ollama config should NOT warn about the dummy key.""" + config = Configuration() + with caplog.at_level(logging.WARNING): + config.log_warnings() + assert "No LLM API key configured" not in caplog.text + + def test_log_warnings_with_real_key(self, monkeypatch, caplog): + monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1") + monkeypatch.setenv("LLM_API_KEY", "sk-proj-realkey123") + config = Configuration() + with caplog.at_level(logging.WARNING): + config.log_warnings() + assert "No LLM API key configured" not in caplog.text