From d8ce12b830228f6a7581fbc2c0e62c9772f448b1 Mon Sep 17 00:00:00 2001 From: drewburchfield Date: Fri, 3 Apr 2026 15:01:35 -0500 Subject: [PATCH] Extract tool handlers from server.py into tools.py (NAS-991) Separates transport-agnostic tool logic from MCP plumbing: - tools.py: validates args, calls engine, returns structured dicts - server.py: MCP schemas, response formatting, dispatch Enables future CLI/REST transports to reuse tool handlers without depending on MCP. No sub-packages needed at current codebase size. --- src/server.py | 426 +++++++++++++------------------------- src/tools.py | 161 ++++++++++++++ tests/conftest.py | 17 +- tests/test_error_paths.py | 6 +- 4 files changed, 319 insertions(+), 291 deletions(-) create mode 100644 src/tools.py diff --git a/src/server.py b/src/server.py index 8d1cd08..63cae53 100644 --- a/src/server.py +++ b/src/server.py @@ -1,12 +1,12 @@ """ Obsidian Graph MCP Server -Provides semantic knowledge graph navigation for Obsidian vaults. +MCP transport layer. Defines tool schemas, delegates to tools.py for +execution, and formats results as MCP-compatible text responses. """ import asyncio import os -from dataclasses import dataclass from pathlib import Path from typing import Any @@ -16,43 +16,19 @@ from mcp.types import Tool from .embedder import VoyageEmbedder -from .exceptions import EmbeddingError from .file_watcher import VaultWatcher from .graph_builder import GraphBuilder from .hub_analyzer import HubAnalyzer -from .security_utils import SecurityError, validate_note_path_parameter -from .validation import ( - ValidationError, - validate_connection_graph_args, - validate_hub_notes_args, - validate_orphaned_notes_args, - validate_search_notes_args, - validate_similar_notes_args, -) +from .security_utils import SecurityError +from .tools import TOOLS, ToolContext, ToolError +from .validation import ValidationError from .vector_store import PostgreSQLVectorStore +# Global tool context (initialized once at startup) +_tool_context: ToolContext | None = None -@dataclass -class ServerContext: - """ - Encapsulates server dependencies for dependency injection and testing. - - Benefits: - - Makes dependencies explicit - - Easier unit testing (inject mock context) - - Foundation for future full DI migration - - No breaking changes (internal refactor) - """ - - store: PostgreSQLVectorStore - embedder: VoyageEmbedder - graph_builder: GraphBuilder - hub_analyzer: HubAnalyzer - vault_watcher: VaultWatcher | None = None - - -# Global server context (initialized once at startup) -_server_context: ServerContext | None = None +# Global vault watcher (separate from tool context, MCP-specific lifecycle) +_vault_watcher: VaultWatcher | None = None # Create MCP server app = Server("obsidian-graph") @@ -60,10 +36,12 @@ class ServerContext: async def initialize_server(): """Initialize server context with all components.""" - global _server_context + global _tool_context, _vault_watcher logger.info("Initializing Obsidian Graph MCP Server...") + vault_path = os.getenv("OBSIDIAN_VAULT_PATH", "/vault") + # Initialize embedder embedder = VoyageEmbedder( model="voyage-context-3", @@ -89,42 +67,40 @@ async def initialize_server(): graph_builder = GraphBuilder(store) hub_analyzer = HubAnalyzer(store) + # Create tool context + _tool_context = ToolContext( + store=store, + embedder=embedder, + graph_builder=graph_builder, + hub_analyzer=hub_analyzer, + vault_path=vault_path, + ) + # Start file watching if enabled - vault_path = os.getenv("OBSIDIAN_VAULT_PATH", "/vault") watch_enabled = os.getenv("OBSIDIAN_WATCH_ENABLED", "true").lower() == "true" - vault_watcher = None if watch_enabled and os.path.exists(vault_path): - vault_watcher = VaultWatcher( + _vault_watcher = VaultWatcher( vault_path, store, embedder, debounce_seconds=int(os.getenv("OBSIDIAN_DEBOUNCE_SECONDS", "30")), ) - # Start file watching first (creates event_handler) loop = asyncio.get_running_loop() - vault_watcher.start(loop) - - # Run startup scan to catch files changed while offline - await vault_watcher.startup_scan() + _vault_watcher.start(loop) + await _vault_watcher.startup_scan() logger.success(f"File watching enabled: {vault_path}") else: logger.info("File watching disabled") - # Create server context - _server_context = ServerContext( - store=store, - embedder=embedder, - graph_builder=graph_builder, - hub_analyzer=hub_analyzer, - vault_watcher=vault_watcher, - ) - logger.success("Server initialized successfully") +# -- MCP Tool Schemas -- + + @app.list_tools() async def list_tools() -> list[Tool]: """List available MCP tools.""" @@ -233,7 +209,7 @@ async def list_tools() -> list[Tool]: }, "threshold": { "type": "number", - "description": ("Similarity threshold for counting connections (0.0-1.0)"), + "description": "Similarity threshold for counting connections (0.0-1.0)", "default": 0.5, "minimum": 0.0, "maximum": 1.0, @@ -262,7 +238,7 @@ async def list_tools() -> list[Tool]: }, "threshold": { "type": "number", - "description": ("Similarity threshold for counting connections (0.0-1.0)"), + "description": "Similarity threshold for counting connections (0.0-1.0)", "default": 0.5, "minimum": 0.0, "maximum": 1.0, @@ -280,242 +256,134 @@ async def list_tools() -> list[Tool]: ] -@app.call_tool() -async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any]]: - """ - Handle tool calls with comprehensive input validation. - - Security features: - - All parameters validated before processing - - Path traversal protection for note_path parameters - - Type checking and range validation - - Graceful error handling with descriptive messages - """ - # Log tool call for debugging - logger.info(f"Tool called: {name} with args: {list(arguments.keys())}") +# -- Response Formatters -- - # Get server context - ctx = _server_context - if ctx is None: - logger.error("Server context not initialized") - return [{"type": "text", "text": "Error: Server not initialized"}] - if name == "search_notes": - try: - # Validate arguments - validated = validate_search_notes_args(arguments) - query = validated["query"] - limit = validated["limit"] - threshold = validated["threshold"] - - # Generate query embedding - try: - query_embedding = await ctx.embedder.embed(query, input_type="query") - except EmbeddingError as e: - logger.error(f"Query embedding failed: {e}", exc_info=True) - return [ - { - "type": "text", - "text": f"Error: Failed to generate query embedding: {e}", - } - ] - - # Search - results = await ctx.store.search(query_embedding, limit, threshold) - - # Format results - response = f"Found {len(results)} notes:\n\n" - for i, result in enumerate(results, 1): - snippet = ( - result.content[:200] + "..." if len(result.content) > 200 else result.content - ) - response += f"{i}. **{result.title}** " f"(similarity: {result.similarity:.3f})\n" - response += f" Path: `{result.path}`\n" - response += f" {snippet}\n\n" - - return [{"type": "text", "text": response}] - - except ValidationError as e: - logger.warning(f"Validation error in search_notes: {e}") - return [{"type": "text", "text": f"Validation Error: {str(e)}"}] - except Exception as e: - logger.error(f"Error in search_notes: {e}", exc_info=True) - return [{"type": "text", "text": f"Error: {str(e)}"}] - - elif name == "get_similar_notes": - try: - # Validate arguments - validated = validate_similar_notes_args(arguments) - - # SECURITY: Validate note_path before processing - note_path = validate_note_path_parameter( - validated["note_path"], - vault_path=os.getenv("OBSIDIAN_VAULT_PATH", "/vault"), - ) - limit = validated["limit"] - threshold = validated["threshold"] - - # Get similar notes - results = await ctx.store.get_similar_notes(note_path, limit, threshold) - - # Format results - response = f"Notes similar to `{note_path}`:\n\n" - for i, result in enumerate(results, 1): - response += f"{i}. **{result.title}** " f"(similarity: {result.similarity:.3f})\n" - response += f" Path: `{result.path}`\n\n" - - return [{"type": "text", "text": response}] - - except ValidationError as e: - logger.warning(f"Validation error in get_similar_notes: {e}") - return [{"type": "text", "text": f"Validation Error: {str(e)}"}] - except SecurityError as e: - logger.warning(f"Security validation failed for get_similar_notes: {e}") - return [{"type": "text", "text": f"Security Error: {str(e)}"}] - except Exception as e: - logger.error(f"Error in get_similar_notes: {e}", exc_info=True) - return [{"type": "text", "text": f"Error: {str(e)}"}] - - elif name == "get_connection_graph": - try: - # Validate arguments - validated = validate_connection_graph_args(arguments) - - # SECURITY: Validate note_path before processing - note_path = validate_note_path_parameter( - validated["note_path"], - vault_path=os.getenv("OBSIDIAN_VAULT_PATH", "/vault"), - ) - depth = validated["depth"] - max_per_level = validated["max_per_level"] - threshold = validated["threshold"] - - # Build connection graph - graph = await ctx.graph_builder.build_connection_graph( - note_path, depth, max_per_level, threshold - ) - - # Format results - response = f"# Connection Graph: {graph['root']['title']}\n\n" - response += f"**Starting note:** `{graph['root']['path']}`\n" - response += ( - f"**Network size:** {graph['stats']['total_nodes']} nodes, " - f"{graph['stats']['total_edges']} edges\n\n" - ) - - # Group nodes by level - nodes_by_level = {} - for node in graph["nodes"]: - level = node["level"] - if level not in nodes_by_level: - nodes_by_level[level] = [] - nodes_by_level[level].append(node) - - # Display by level - for level in sorted(nodes_by_level.keys()): - response += f"\n## Level {level}\n" - for node in nodes_by_level[level]: - response += f"- **{node['title']}** (`{node['path']}`)\n" - if node["parent_path"]: - # Find edge to get similarity - edge = next( - (e for e in graph["edges"] if e["target"] == node["path"]), - None, - ) - if edge: - response += ( - f" Connected from: `{node['parent_path']}` " - f"(similarity: {edge['similarity']:.3f})\n" - ) - - return [{"type": "text", "text": response}] - - except ValidationError as e: - logger.warning(f"Validation error in get_connection_graph: {e}") - return [{"type": "text", "text": f"Validation Error: {str(e)}"}] - except SecurityError as e: - logger.warning(f"Security validation failed for get_connection_graph: {e}") - return [{"type": "text", "text": f"Security Error: {str(e)}"}] - except Exception as e: - logger.error(f"Error in get_connection_graph: {e}", exc_info=True) - return [{"type": "text", "text": f"Error: {str(e)}"}] - - elif name == "get_hub_notes": - try: - # Validate arguments - validated = validate_hub_notes_args(arguments) - min_connections = validated["min_connections"] - threshold = validated["threshold"] - limit = validated["limit"] - - # Get hub notes - hubs = await ctx.hub_analyzer.get_hub_notes(min_connections, threshold, limit) - - # Format results - if not hubs: - response = ( - f"No hub notes found with >={min_connections} connections " - f"at threshold {threshold}" - ) - else: - response = "# Hub Notes (Highly Connected)\n\n" - response += f"Found {len(hubs)} notes with " f">={min_connections} connections:\n\n" - for i, hub in enumerate(hubs, 1): - response += ( - f"{i}. **{hub['title']}** " f"({hub['connection_count']} connections)\n" - ) - response += f" Path: `{hub['path']}`\n\n" - - return [{"type": "text", "text": response}] - - except ValidationError as e: - logger.warning(f"Validation error in get_hub_notes: {e}") - return [{"type": "text", "text": f"Validation Error: {str(e)}"}] - except Exception as e: - logger.error(f"Error in get_hub_notes: {e}", exc_info=True) - return [{"type": "text", "text": f"Error: {str(e)}"}] - - elif name == "get_orphaned_notes": - try: - # Validate arguments - validated = validate_orphaned_notes_args(arguments) - max_connections = validated["max_connections"] - threshold = validated["threshold"] - limit = validated["limit"] - - # Get orphaned notes - orphans = await ctx.hub_analyzer.get_orphaned_notes(max_connections, threshold, limit) - - # Format results - if not orphans: - response = f"No orphaned notes found with " f"<={max_connections} connections" - else: - response = "# Orphaned Notes (Isolated)\n\n" - response += ( - f"Found {len(orphans)} notes with " f"<={max_connections} connections:\n\n" +def _format_search_results(data: dict) -> str: + results = data["results"] + response = f"Found {len(results)} notes:\n\n" + for i, r in enumerate(results, 1): + snippet = r["content"][:200] + "..." if len(r["content"]) > 200 else r["content"] + response += f"{i}. **{r['title']}** (similarity: {r['similarity']:.3f})\n" + response += f" Path: `{r['path']}`\n" + response += f" {snippet}\n\n" + return response + + +def _format_similar_notes(data: dict) -> str: + response = f"Notes similar to `{data['note_path']}`:\n\n" + for i, r in enumerate(data["results"], 1): + response += f"{i}. **{r['title']}** (similarity: {r['similarity']:.3f})\n" + response += f" Path: `{r['path']}`\n\n" + return response + + +def _format_connection_graph(graph: dict) -> str: + response = f"# Connection Graph: {graph['root']['title']}\n\n" + response += f"**Starting note:** `{graph['root']['path']}`\n" + response += ( + f"**Network size:** {graph['stats']['total_nodes']} nodes, " + f"{graph['stats']['total_edges']} edges\n\n" + ) + + nodes_by_level: dict[int, list] = {} + for node in graph["nodes"]: + nodes_by_level.setdefault(node["level"], []).append(node) + + for level in sorted(nodes_by_level.keys()): + response += f"\n## Level {level}\n" + for node in nodes_by_level[level]: + response += f"- **{node['title']}** (`{node['path']}`)\n" + if node["parent_path"]: + edge = next( + (e for e in graph["edges"] if e["target"] == node["path"]), + None, ) - for i, orphan in enumerate(orphans, 1): + if edge: response += ( - f"{i}. **{orphan['title']}** " - f"({orphan['connection_count']} connections)\n" + f" Connected from: `{node['parent_path']}` " + f"(similarity: {edge['similarity']:.3f})\n" ) - response += f" Path: `{orphan['path']}`\n" - if orphan["modified_at"]: - response += f" Modified: {orphan['modified_at']}\n" - response += "\n" - return [{"type": "text", "text": response}] + return response - except ValidationError as e: - logger.warning(f"Validation error in get_orphaned_notes: {e}") - return [{"type": "text", "text": f"Validation Error: {str(e)}"}] - except Exception as e: - logger.error(f"Error in get_orphaned_notes: {e}", exc_info=True) - return [{"type": "text", "text": f"Error: {str(e)}"}] - else: +def _format_hub_notes(data: dict) -> str: + hubs = data["results"] + if not hubs: + return ( + f"No hub notes found with >={data['min_connections']} connections " + f"at threshold {data['threshold']}" + ) + + response = "# Hub Notes (Highly Connected)\n\n" + response += f"Found {len(hubs)} notes with >={data['min_connections']} connections:\n\n" + for i, hub in enumerate(hubs, 1): + response += f"{i}. **{hub['title']}** ({hub['connection_count']} connections)\n" + response += f" Path: `{hub['path']}`\n\n" + return response + + +def _format_orphaned_notes(data: dict) -> str: + orphans = data["results"] + if not orphans: + return f"No orphaned notes found with <={data['max_connections']} connections" + + response = "# Orphaned Notes (Isolated)\n\n" + response += f"Found {len(orphans)} notes with <={data['max_connections']} connections:\n\n" + for i, orphan in enumerate(orphans, 1): + response += f"{i}. **{orphan['title']}** ({orphan['connection_count']} connections)\n" + response += f" Path: `{orphan['path']}`\n" + if orphan.get("modified_at"): + response += f" Modified: {orphan['modified_at']}\n" + response += "\n" + return response + + +_FORMATTERS = { + "search_notes": _format_search_results, + "get_similar_notes": _format_similar_notes, + "get_connection_graph": _format_connection_graph, + "get_hub_notes": _format_hub_notes, + "get_orphaned_notes": _format_orphaned_notes, +} + + +# -- MCP Tool Dispatch -- + + +@app.call_tool() +async def call_tool(name: str, arguments: dict[str, Any]) -> list[dict[str, Any]]: + """Dispatch MCP tool calls to handlers in tools.py.""" + logger.info(f"Tool called: {name} with args: {list(arguments.keys())}") + + ctx = _tool_context + if ctx is None: + logger.error("Server context not initialized") + return [{"type": "text", "text": "Error: Server not initialized"}] + + handler = TOOLS.get(name) + if not handler: return [{"type": "text", "text": f"Unknown tool: {name}"}] + try: + result = await handler(ctx, arguments) + formatted = _FORMATTERS[name](result) + return [{"type": "text", "text": formatted}] + + except ValidationError as e: + logger.warning(f"Validation error in {name}: {e}") + return [{"type": "text", "text": f"Validation Error: {str(e)}"}] + except SecurityError as e: + logger.warning(f"Security validation failed for {name}: {e}") + return [{"type": "text", "text": f"Security Error: {str(e)}"}] + except ToolError as e: + logger.error(f"Tool error in {name}: {e}") + return [{"type": "text", "text": f"Error: {str(e)}"}] + except Exception as e: + logger.error(f"Error in {name}: {e}", exc_info=True) + return [{"type": "text", "text": f"Error: {str(e)}"}] + async def main(): """Run the MCP server.""" diff --git a/src/tools.py b/src/tools.py new file mode 100644 index 0000000..017f672 --- /dev/null +++ b/src/tools.py @@ -0,0 +1,161 @@ +""" +Tool handlers for Obsidian Graph. + +Transport-agnostic tool implementations that validate inputs, call engine +components, and return structured results. Used by server.py (MCP) and +can be reused by future transports (CLI, REST). +""" + +from dataclasses import dataclass +from typing import Any + +from .embedder import VoyageEmbedder +from .exceptions import EmbeddingError +from .graph_builder import GraphBuilder +from .hub_analyzer import HubAnalyzer +from .security_utils import validate_note_path_parameter +from .validation import ( + validate_connection_graph_args, + validate_hub_notes_args, + validate_orphaned_notes_args, + validate_search_notes_args, + validate_similar_notes_args, +) +from .vector_store import PostgreSQLVectorStore + + +@dataclass +class ToolContext: + """Dependencies needed by tool handlers.""" + + store: PostgreSQLVectorStore + embedder: VoyageEmbedder + graph_builder: GraphBuilder + hub_analyzer: HubAnalyzer + vault_path: str = "/vault" + + +class ToolError(Exception): + """Raised when a tool handler fails. Contains a user-facing message.""" + + pass + + +async def search_notes(ctx: ToolContext, arguments: dict[str, Any]) -> dict[str, Any]: + """ + Semantic search across vault. + + Returns: + {"results": [{"path", "title", "content", "similarity"}, ...]} + """ + validated = validate_search_notes_args(arguments) + + try: + query_embedding = await ctx.embedder.embed(validated["query"], input_type="query") + except EmbeddingError as e: + raise ToolError(f"Failed to generate query embedding: {e}") from e + + results = await ctx.store.search(query_embedding, validated["limit"], validated["threshold"]) + + return { + "results": [ + { + "path": r.path, + "title": r.title, + "content": r.content, + "similarity": r.similarity, + } + for r in results + ] + } + + +async def get_similar_notes(ctx: ToolContext, arguments: dict[str, Any]) -> dict[str, Any]: + """ + Find notes similar to a given note. + + Returns: + {"note_path": str, "results": [{"path", "title", "similarity"}, ...]} + """ + validated = validate_similar_notes_args(arguments) + note_path = validate_note_path_parameter(validated["note_path"], vault_path=ctx.vault_path) + + results = await ctx.store.get_similar_notes( + note_path, validated["limit"], validated["threshold"] + ) + + return { + "note_path": note_path, + "results": [ + { + "path": r.path, + "title": r.title, + "similarity": r.similarity, + } + for r in results + ], + } + + +async def get_connection_graph(ctx: ToolContext, arguments: dict[str, Any]) -> dict[str, Any]: + """ + Build multi-hop connection graph from a starting note. + + Returns: + The graph dict from GraphBuilder (root, nodes, edges, stats). + """ + validated = validate_connection_graph_args(arguments) + note_path = validate_note_path_parameter(validated["note_path"], vault_path=ctx.vault_path) + + return await ctx.graph_builder.build_connection_graph( + note_path, validated["depth"], validated["max_per_level"], validated["threshold"] + ) + + +async def get_hub_notes(ctx: ToolContext, arguments: dict[str, Any]) -> dict[str, Any]: + """ + Identify highly connected hub notes. + + Returns: + {"min_connections": int, "threshold": float, "results": [{"path", "title", "connection_count"}, ...]} + """ + validated = validate_hub_notes_args(arguments) + + hubs = await ctx.hub_analyzer.get_hub_notes( + validated["min_connections"], validated["threshold"], validated["limit"] + ) + + return { + "min_connections": validated["min_connections"], + "threshold": validated["threshold"], + "results": hubs, + } + + +async def get_orphaned_notes(ctx: ToolContext, arguments: dict[str, Any]) -> dict[str, Any]: + """ + Find isolated notes with few connections. + + Returns: + {"max_connections": int, "results": [{"path", "title", "connection_count", "modified_at"}, ...]} + """ + validated = validate_orphaned_notes_args(arguments) + + orphans = await ctx.hub_analyzer.get_orphaned_notes( + validated["max_connections"], validated["threshold"], validated["limit"] + ) + + return { + "max_connections": validated["max_connections"], + "results": orphans, + } + + +# Tool dispatch table +TOOLS = { + "search_notes": search_notes, + "get_similar_notes": get_similar_notes, + "get_connection_graph": get_connection_graph, + "get_hub_notes": get_hub_notes, + "get_orphaned_notes": get_orphaned_notes, +} diff --git a/tests/conftest.py b/tests/conftest.py index 0be2f2d..a0c3fea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -120,40 +120,39 @@ async def tmp_vault(tmp_path): @pytest.fixture async def server_context(mock_store, mock_embedder): """ - Mock ServerContext for testing tool handlers. + Mock ToolContext for testing tool handlers. - Provides complete server context with mocked dependencies. + Provides complete tool context with mocked dependencies. - IMPORTANT: This fixture injects the context into src.server._server_context + IMPORTANT: This fixture injects the context into src.server._tool_context so that call_tool() uses the mocked dependencies. The original context is restored after the test completes. """ import src.server from src.graph_builder import GraphBuilder from src.hub_analyzer import HubAnalyzer - from src.server import ServerContext + from src.tools import ToolContext # Save original context - original_context = src.server._server_context + original_context = src.server._tool_context graph_builder = GraphBuilder(mock_store) hub_analyzer = HubAnalyzer(mock_store) - context = ServerContext( + context = ToolContext( store=mock_store, embedder=mock_embedder, graph_builder=graph_builder, hub_analyzer=hub_analyzer, - vault_watcher=None, ) # Inject into global for call_tool() - src.server._server_context = context + src.server._tool_context = context yield context # Restore original context - src.server._server_context = original_context + src.server._tool_context = original_context @pytest.fixture diff --git a/tests/test_error_paths.py b/tests/test_error_paths.py index a018b00..7a9788e 100644 --- a/tests/test_error_paths.py +++ b/tests/test_error_paths.py @@ -105,11 +105,11 @@ async def test_all_tools_handle_uninitialized_server(self): from src.server import call_tool # Save original context - original_context = src.server._server_context + original_context = src.server._tool_context try: # Temporarily set context to None (simulate initialization failure) - src.server._server_context = None + src.server._tool_context = None # Test all 5 tools tools = [ @@ -132,7 +132,7 @@ async def test_all_tools_handle_uninitialized_server(self): finally: # Restore context - src.server._server_context = original_context + src.server._tool_context = original_context class TestGraphBuilderErrors: