Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions a2a/weather_service/src/weather_service/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import re
from textwrap import dedent

import uvicorn
Expand All @@ -14,14 +15,61 @@
from a2a.server.tasks import InMemoryTaskStore, TaskUpdater
from a2a.types import AgentCapabilities, AgentCard, AgentSkill, TaskState, TextPart
from a2a.utils import new_agent_text_message, new_task
from weather_service.configuration import Configuration
from weather_service.graph import get_graph, get_mcpclient
from weather_service.observability import (
create_tracing_middleware,
get_root_span,
set_span_output,
)

logging.basicConfig(level=logging.DEBUG)

class SecretRedactionFilter(logging.Filter):
    """Redacts Bearer tokens and API keys from log messages.

    Covers three layers:
    1. Bearer tokens in Authorization headers (any format).
    2. OpenAI-style ``sk-*`` API keys (pattern-based).
    3. The literal configured ``LLM_API_KEY`` value (provider-agnostic,
       catches non-standard key formats like RHOAI MaaS 32-char keys).
    """

    _BEARER_RE = re.compile(r"(Bearer\s+)\S+", re.IGNORECASE)
    _API_KEY_RE = re.compile(r"(sk-[a-zA-Z0-9]{3})[a-zA-Z0-9]+")

    def __init__(self, name: str = ""):
        super().__init__(name)
        # Redact the literal configured key only when it's long enough to be
        # real; short placeholders like "dummy" would otherwise cause noisy
        # false-positive redactions of ordinary words.
        configured_key = os.environ.get("LLM_API_KEY", "").strip()
        if len(configured_key) > 8:
            self._literal_key_re: re.Pattern | None = re.compile(re.escape(configured_key))
        else:
            self._literal_key_re = None

    def _redact(self, text: str) -> str:
        """Return *text* with every known secret pattern replaced."""
        text = self._BEARER_RE.sub(r"\1[REDACTED]", text)
        text = self._API_KEY_RE.sub(r"\1...[REDACTED]", text)
        if self._literal_key_re is not None:
            text = self._literal_key_re.sub("[REDACTED]", text)
        return text

    def filter(self, record: logging.LogRecord) -> bool:
        """Redact secrets in-place; never suppress the record (always True)."""
        if isinstance(record.msg, str):
            record.msg = self._redact(record.msg)
        if record.args:
            if isinstance(record.args, dict):
                # logging stores a single mapping argument as the mapping
                # itself (for ``%(name)s`` formatting). Wrapping it into a
                # tuple would break getMessage(), so redact its string
                # values while preserving the mapping shape.
                record.args = {
                    key: self._redact(value) if isinstance(value, str) else value
                    for key, value in record.args.items()
                }
            else:
                args = record.args if isinstance(record.args, tuple) else (record.args,)
                record.args = tuple(
                    self._redact(arg) if isinstance(arg, str) else arg
                    for arg in args
                )
        return True


logging.basicConfig(level=logging.INFO)
# Install the redaction filter on the root logger so every logger in the
# process (including third-party libraries) has secrets scrubbed.
logging.getLogger().addFilter(SecretRedactionFilter())
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -111,6 +159,17 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
task_updater = TaskUpdater(event_queue, task.id, task.context_id)
event_emitter = A2AEvent(task_updater)

# Check API key before attempting LLM calls
config = Configuration()
if not config.has_valid_api_key:
await event_emitter.emit_event(
"Error: No LLM API key configured. Please set the LLM_API_KEY "
"environment variable (or charts.kagenti.values.secrets.openaiApiKey "
"during Kagenti installation).",
failed=True,
)
return

# Get user input for the agent
user_input = context.get_user_input()

Expand Down Expand Up @@ -216,14 +275,4 @@ def run():
# Add tracing middleware - creates root span with MLflow/GenAI attributes
app.add_middleware(BaseHTTPMiddleware, dispatch=create_tracing_middleware())

# Add logging middleware
@app.middleware("http")
async def log_authorization_header(request, call_next):
auth_header = request.headers.get("authorization", "No Authorization header")
logger.info(
f"🔐 Incoming request to {request.url.path} with Authorization: {auth_header[:80] + '...' if len(auth_header) > 80 else auth_header}"
)
response = await call_next(request)
return response

uvicorn.run(app, host="0.0.0.0", port=8000)
37 changes: 37 additions & 0 deletions a2a/weather_service/src/weather_service/configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,44 @@
import logging

from pydantic_settings import BaseSettings

logger = logging.getLogger(__name__)

# Keys that are obviously placeholders rather than real credentials.
_PLACEHOLDER_KEYS = {"dummy", "changeme", "your-api-key-here", ""}

# API bases that are known to accept placeholder/dummy keys (local LLMs)
_LOCAL_LLM_HOSTS = {"localhost", "127.0.0.1", "0.0.0.0"}


class Configuration(BaseSettings):
    """Weather-service LLM settings, sourced from environment variables."""

    llm_model: str = "llama3.1"
    llm_api_base: str = "http://localhost:11434/v1"
    llm_api_key: str = "dummy"

    @property
    def is_local_llm(self) -> bool:
        """True when ``llm_api_base`` points at a local LLM host (e.g. Ollama)."""
        from urllib.parse import urlparse

        host = urlparse(self.llm_api_base).hostname or ""
        return host in _LOCAL_LLM_HOSTS

    @property
    def has_valid_api_key(self) -> bool:
        """Whether the configured key is usable for the configured endpoint.

        Local LLMs (Ollama, vLLM, etc.) accept any key including placeholders,
        so placeholder keys are only flagged when pointing at a remote API.
        """
        return self.is_local_llm or self.llm_api_key.strip() not in _PLACEHOLDER_KEYS

    def log_warnings(self) -> None:
        """Emit startup warnings about configuration problems."""
        if self.has_valid_api_key:
            return
        logger.warning(
            "No LLM API key configured (set LLM_API_KEY env var). "
            "The weather agent will not be able to call the LLM."
        )
1 change: 1 addition & 0 deletions a2a/weather_service/src/weather_service/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from weather_service.configuration import Configuration

config = Configuration()
config.log_warnings()


# Extend MessagesState to include a final answer
Expand Down
240 changes: 240 additions & 0 deletions tests/a2a/test_weather_secret_redaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""Tests for secret redaction and API key validation in the weather service.

Loads agent.py and configuration.py in isolation (same approach as
test_weather_service.py) to avoid pulling in heavy deps like opentelemetry.
"""

import importlib.util
import logging
import pathlib
import sys
from types import ModuleType
from unittest.mock import MagicMock

# --- Isolation setup (must happen before any weather_service imports) ---
# Register a fake top-level ``weather_service`` package so the real one (and
# its heavy dependency chain, e.g. opentelemetry) is never imported;
# submodules are loaded individually from file paths below.
_fake_ws = ModuleType("weather_service")
_fake_ws.__path__ = []  # type: ignore[attr-defined]
sys.modules.setdefault("weather_service", _fake_ws)
sys.modules.setdefault("weather_service.observability", MagicMock())

# Source directory of the weather_service package under test.
_BASE = pathlib.Path(__file__).parent.parent.parent / "a2a" / "weather_service" / "src" / "weather_service"


def _load_module(name: str, path: pathlib.Path):
    """Load the module at *path* under *name*, register it in ``sys.modules``,
    and return it."""
    module_spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(module_spec)  # type: ignore[arg-type]
    sys.modules[name] = module
    module_spec.loader.exec_module(module)  # type: ignore[union-attr]
    return module


# Load the real configuration module first (it has no heavy dependencies).
_config_mod = _load_module("weather_service.configuration", _BASE / "configuration.py")

# Stub out every module agent.py imports that these tests never exercise,
# so loading agent.py does not pull in uvicorn/starlette/a2a for real.
for mod_name in [
    "uvicorn",
    "langchain_core",
    "langchain_core.messages",
    "starlette",
    "starlette.middleware",
    "starlette.middleware.base",
    "starlette.routing",
    "a2a",
    "a2a.server",
    "a2a.server.agent_execution",
    "a2a.server.apps",
    "a2a.server.events",
    "a2a.server.events.event_queue",
    "a2a.server.request_handlers",
    "a2a.server.tasks",
    "a2a.types",
    "a2a.utils",
    "weather_service.graph",
]:
    sys.modules.setdefault(mod_name, MagicMock())

# agent.py can be loaded now that all of its imports resolve to mocks.
_agent_mod = _load_module("weather_service.agent", _BASE / "agent.py")

# The two objects under test, re-exported at module level for the tests below.
Configuration = _config_mod.Configuration
SecretRedactionFilter = _agent_mod.SecretRedactionFilter


# --- Tests ---


class TestSecretRedactionFilter:
    """Exercise the logging filter that redacts Bearer tokens and API keys."""

    def setup_method(self):
        self.redactor = SecretRedactionFilter()

    def _record(self, msg: str, args=None) -> logging.LogRecord:
        """Build a minimal INFO-level log record carrying *msg* and *args*."""
        return logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="",
            lineno=0,
            msg=msg,
            args=args,
            exc_info=None,
        )

    def test_redacts_bearer_token_in_msg(self):
        rec = self._record("Authorization: Bearer sk-abc123xyz789secret")
        self.redactor.filter(rec)
        assert "[REDACTED]" in rec.msg
        assert "sk-abc123xyz789secret" not in rec.msg

    def test_redacts_bearer_token_case_insensitive(self):
        rec = self._record("header: bearer my-secret-token-value")
        self.redactor.filter(rec)
        assert "[REDACTED]" in rec.msg
        assert "my-secret-token-value" not in rec.msg

    def test_redacts_openai_api_key_pattern(self):
        rec = self._record("Using key sk-proj1234567890abcdefghijklmnop")
        self.redactor.filter(rec)
        assert "sk-pro...[REDACTED]" in rec.msg
        assert "1234567890abcdefghijklmnop" not in rec.msg

    def test_preserves_non_secret_messages(self):
        clean_msg = "Processing weather request for New York"
        rec = self._record(clean_msg)
        self.redactor.filter(rec)
        assert rec.msg == clean_msg

    def test_redacts_bearer_in_args(self):
        rec = self._record("Header: %s", ("Bearer sk-abc123xyz789secret",))
        self.redactor.filter(rec)
        first_arg = rec.args[0]
        assert "[REDACTED]" in first_arg
        assert "sk-abc123xyz789secret" not in first_arg

    def test_always_returns_true(self):
        """The filter must only redact content, never drop a record."""
        assert self.redactor.filter(self._record("Bearer secret123")) is True

    def test_redacts_literal_configured_key(self, monkeypatch):
        """Non-sk-* keys (e.g. RHOAI MaaS 32-char keys) are caught via LLM_API_KEY."""
        maas_key = "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
        monkeypatch.setenv("LLM_API_KEY", maas_key)
        redactor = SecretRedactionFilter()
        rec = self._record(f"Sending request with api-key={maas_key}")
        redactor.filter(rec)
        assert "[REDACTED]" in rec.msg
        assert maas_key not in rec.msg

    def test_literal_key_redaction_in_args(self, monkeypatch):
        """Literal-key redaction also covers log record args."""
        maas_key = "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
        monkeypatch.setenv("LLM_API_KEY", maas_key)
        redactor = SecretRedactionFilter()
        rec = self._record("key=%s", (maas_key,))
        redactor.filter(rec)
        assert "[REDACTED]" in rec.args[0]
        assert maas_key not in rec.args[0]

    def test_short_key_not_literal_redacted(self, monkeypatch):
        """Keys of 8 chars or fewer (like 'dummy') must not trigger literal redaction."""
        monkeypatch.setenv("LLM_API_KEY", "dummy")
        redactor = SecretRedactionFilter()
        rec = self._record("Using dummy config for testing dummy values")
        redactor.filter(rec)
        # The common word "dummy" must survive redaction untouched.
        assert "dummy" in rec.msg

    def test_no_literal_key_when_unset(self, monkeypatch):
        """Constructing and using the filter without LLM_API_KEY must not crash."""
        monkeypatch.delenv("LLM_API_KEY", raising=False)
        redactor = SecretRedactionFilter()
        rec = self._record("Normal log message")
        redactor.filter(rec)
        assert rec.msg == "Normal log message"


class TestConfigurationApiKeyValidation:
    """Validate the Configuration API-key heuristics."""

    def test_dummy_key_with_remote_api_is_invalid(self, monkeypatch):
        """A placeholder key pointed at OpenAI must be rejected."""
        monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1")
        monkeypatch.setenv("LLM_API_KEY", "dummy")
        assert Configuration().has_valid_api_key is False

    def test_dummy_key_with_ollama_is_valid(self):
        """The default config (Ollama on localhost) accepts the dummy key."""
        cfg = Configuration()
        assert cfg.is_local_llm is True
        assert cfg.has_valid_api_key is True

    def test_empty_key_with_remote_api_is_invalid(self, monkeypatch):
        monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1")
        monkeypatch.setenv("LLM_API_KEY", "")
        assert Configuration().has_valid_api_key is False

    def test_placeholder_keys_with_remote_api_are_invalid(self, monkeypatch):
        monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1")
        for placeholder in ("changeme", "your-api-key-here"):
            monkeypatch.setenv("LLM_API_KEY", placeholder)
            assert Configuration().has_valid_api_key is False, f"'{placeholder}' should be invalid"

    def test_placeholder_keys_with_local_llm_are_valid(self, monkeypatch):
        """Local LLMs (Ollama, vLLM) accept any key — never block them."""
        monkeypatch.setenv("LLM_API_BASE", "http://localhost:11434/v1")
        for placeholder in ("dummy", "changeme", ""):
            monkeypatch.setenv("LLM_API_KEY", placeholder)
            assert Configuration().has_valid_api_key is True, f"'{placeholder}' with local LLM should be valid"

    def test_real_key_is_valid(self, monkeypatch):
        monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1")
        monkeypatch.setenv("LLM_API_KEY", "sk-proj-realkey123")
        assert Configuration().has_valid_api_key is True

    def test_rhoai_maas_key_is_valid(self, monkeypatch):
        """RHOAI MaaS keys are 32-char alphanumerics without the sk- prefix."""
        monkeypatch.setenv(
            "LLM_API_BASE",
            "https://deepseek-r1-qwen-14b-w4a16--maas-apicast-production.apps.prod.rhoai.rh-aiservices-bu.com:443/v1",
        )
        monkeypatch.setenv("LLM_API_KEY", "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6")
        cfg = Configuration()
        assert cfg.is_local_llm is False
        assert cfg.has_valid_api_key is True

    def test_is_local_llm_with_127(self, monkeypatch):
        monkeypatch.setenv("LLM_API_BASE", "http://127.0.0.1:8080/v1")
        assert Configuration().is_local_llm is True

    def test_is_not_local_llm_with_remote_host(self, monkeypatch):
        monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1")
        assert Configuration().is_local_llm is False

    def test_log_warnings_with_dummy_key_remote(self, monkeypatch, caplog):
        monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1")
        monkeypatch.setenv("LLM_API_KEY", "dummy")
        with caplog.at_level(logging.WARNING):
            Configuration().log_warnings()
        assert "No LLM API key configured" in caplog.text

    def test_no_warning_with_ollama_dummy_key(self, caplog):
        """The default Ollama config must NOT warn about its dummy key."""
        with caplog.at_level(logging.WARNING):
            Configuration().log_warnings()
        assert "No LLM API key configured" not in caplog.text

    def test_log_warnings_with_real_key(self, monkeypatch, caplog):
        monkeypatch.setenv("LLM_API_BASE", "https://api.openai.com/v1")
        monkeypatch.setenv("LLM_API_KEY", "sk-proj-realkey123")
        with caplog.at_level(logging.WARNING):
            Configuration().log_warnings()
        assert "No LLM API key configured" not in caplog.text
Loading