diff --git a/src/mcp_zero/plugins/github_repo_filter.py b/src/mcp_zero/plugins/github_repo_filter.py index 390e8bf..80e8722 100644 --- a/src/mcp_zero/plugins/github_repo_filter.py +++ b/src/mcp_zero/plugins/github_repo_filter.py @@ -6,9 +6,12 @@ from __future__ import annotations +import fnmatch import logging from typing import Any +from mcp_zero.context import HookContext +from mcp_zero.pipeline.errors import ShortCircuitError from mcp_zero.pipeline.hooks import LifecycleHook from mcp_zero.pipeline.registry import HookRegistry from mcp_zero.plugin import BasePlugin @@ -32,6 +35,45 @@ def __init__( self._repos = repos self._servers = servers + async def on_post_validation(self, ctx: HookContext) -> HookContext: + """Filter requests based on GitHub repository allowlist/blocklist.""" + # Server scoping: skip if this server is not in scope + if self._servers and ctx.server_name not in self._servers: + logger.debug( + "Skipping repo filter for server %r (not in scoped servers)", ctx.server_name + ) + return ctx + + # Extract owner/repo from request arguments + arguments = ctx.request_payload.get("arguments", {}) + owner = arguments.get("owner") + repo = arguments.get("repo") + + if not owner or not repo: + logger.debug("No owner/repo in request arguments; skipping repo filter") + return ctx + + # Combine and normalize for case-insensitive comparison + full_repo = f"{owner}/{repo}".lower() + + if self._mode == "allowlist": + matched = any(fnmatch.fnmatch(full_repo, pattern.lower()) for pattern in self._repos) + if not matched: + logger.debug("Repository %r not in allowlist; denying", full_repo) + raise ShortCircuitError( + f"Repository '{owner}/{repo}' is not in the allowlist", deny=True + ) + logger.debug("Repository %r matched allowlist", full_repo) + + elif self._mode == "blocklist": + matched = any(fnmatch.fnmatch(full_repo, pattern.lower()) for pattern in self._repos) + if matched: + logger.debug("Repository %r matched blocklist; denying", full_repo) + raise ShortCircuitError(f"Repository '{owner}/{repo}' is blocked", deny=True) + logger.debug("Repository %r not in blocklist; allowing", full_repo) + + return ctx + class GitHubRepoFilterPlugin(BasePlugin): """Plugin that enforces allowlist/blocklist policies on GitHub repos.