From 4b7a4e6a248738c802a34e0df08aa797a68bdaee Mon Sep 17 00:00:00 2001 From: Thiago Santos Date: Wed, 3 Dec 2025 11:17:57 -0300 Subject: [PATCH] fix: fixing log forging vulnerability Adding sanitization to log to avoid log forging vulnerability from user interaction --- cmd/openapi-mcp/main.go | 38 ++++---- pkg/config/config.go | 17 ++-- pkg/parser/parser.go | 38 ++++---- pkg/server/manager.go | 15 +-- pkg/server/server.go | 194 +++++++++++++++++++------------------- pkg/server/server_test.go | 16 ++-- pkg/utils/logging.go | 41 ++++++++ 7 files changed, 201 insertions(+), 158 deletions(-) create mode 100644 pkg/utils/logging.go diff --git a/cmd/openapi-mcp/main.go b/cmd/openapi-mcp/main.go index 4e05368..e8d9f3b 100644 --- a/cmd/openapi-mcp/main.go +++ b/cmd/openapi-mcp/main.go @@ -3,7 +3,6 @@ package main import ( "flag" "fmt" - "log" "os" "path/filepath" "strings" @@ -11,6 +10,7 @@ import ( "github.com/ckanthony/openapi-mcp/pkg/config" "github.com/ckanthony/openapi-mcp/pkg/parser" "github.com/ckanthony/openapi-mcp/pkg/server" + "github.com/ckanthony/openapi-mcp/pkg/utils" "github.com/joho/godotenv" ) @@ -56,33 +56,33 @@ func main() { // --- Load .env after parsing flags --- if *specPath != "" && !strings.HasPrefix(*specPath, "http://") && !strings.HasPrefix(*specPath, "https://") { envPath := filepath.Join(filepath.Dir(*specPath), ".env") - log.Printf("Attempting to load .env file from spec directory: %s", envPath) + utils.SafeLogPrintf("Attempting to load .env file from spec directory: %s", envPath) err := godotenv.Load(envPath) if err != nil { // It's okay if the file doesn't exist, log other errors. if !os.IsNotExist(err) { - log.Printf("Warning: Error loading .env file from %s: %v", envPath, err) + utils.SafeLogPrintf("Warning: Error loading .env file from %s: %v", envPath, err) } else { - log.Printf("Info: No .env file found at %s, proceeding without it.", envPath) + utils.SafeLogPrintf("Info: No .env file found at %s, proceeding without it.", envPath) } } else { - log.Printf("Successfully loaded .env file from %s", envPath) + utils.SafeLogPrintf("Successfully loaded .env file from %s", envPath) } } else if *specPath == "" { - log.Println("Skipping .env load because --spec is missing.") + utils.SafeLogPrintln("Skipping .env load because --spec is missing.") } else { - log.Println("Skipping .env load because spec path appears to be a URL.") + utils.SafeLogPrintln("Skipping .env load because spec path appears to be a URL.") } // --- Read REQUEST_HEADERS env var --- customHeadersEnv := os.Getenv("REQUEST_HEADERS") if customHeadersEnv != "" { - log.Printf("Found REQUEST_HEADERS environment variable: %s", customHeadersEnv) + utils.SafeLogPrintf("Found REQUEST_HEADERS environment variable: %s", customHeadersEnv) } // --- Input Validation --- if *specPath == "" { - log.Println("Error: --spec flag is required.") + utils.SafeLogPrintln("Error: --spec flag is required.") flag.Usage() os.Exit(1) } @@ -99,7 +99,8 @@ func main() { case string(config.APIKeyLocationCookie): apiKeyLocation = config.APIKeyLocationCookie default: - log.Fatalf("Error: invalid --api-key-loc value: %s. Must be 'header', 'query', 'path', or 'cookie'.", *apiKeyLocStr) + utils.SafeLogPrintln("Error: invalid --api-key-loc value:", *apiKeyLocStr, "Must be 'header', 'query', 'path', or 'cookie'.") + os.Exit(1) } } @@ -120,27 +121,28 @@ func main() { CustomHeaders: customHeadersEnv, } - log.Printf("Configuration loaded: %+v\n", cfg) - log.Println("API Key (resolved):", cfg.GetAPIKey()) + utils.SafeLogPrintln("Configuration loaded.") + utils.SafeLogPrintln("API Key (resolved):", cfg.GetAPIKey()) // --- Call Parser --- specDoc, version, err := parser.LoadSwagger(cfg.SpecPath) if err != nil { - log.Fatalf("Failed to load OpenAPI/Swagger spec: %v", err) + utils.SafeLogFatalf("Failed to load OpenAPI/Swagger spec: %v", err) } - log.Printf("Spec type %s loaded successfully from %s.\n", version, cfg.SpecPath) + utils.SafeLogPrintln("Spec type", version, "loaded successfully from", cfg.SpecPath) toolSet, err := parser.GenerateToolSet(specDoc, version, cfg) if err != nil { - log.Fatalf("Failed to generate MCP toolset: %v", err) + utils.SafeLogFatalf("Failed to generate MCP toolset: %v", err) } - log.Printf("MCP toolset generated with %d tools.\n", len(toolSet.Tools)) + utils.SafeLogPrintln("MCP toolset generated with", len(toolSet.Tools), "tools.") // --- Start Server --- addr := fmt.Sprintf(":%d", *port) - log.Printf("Starting MCP server on %s...", addr) + utils.SafeLogPrintf("Starting MCP server on %s...", addr) err = server.ServeMCP(addr, toolSet, cfg) // Pass cfg to ServeMCP if err != nil { - log.Fatalf("Failed to start server: %v", err) + utils.SafeLogPrintln("Failed to start server:", err) + os.Exit(1) } } diff --git a/pkg/config/config.go b/pkg/config/config.go index 7555687..cf6e5d0 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,8 +1,9 @@ package config import ( - "log" "os" + + "github.com/ckanthony/openapi-mcp/pkg/utils" ) // APIKeyLocation specifies where the API key is located for requests. @@ -43,28 +44,28 @@ type Config struct { // GetAPIKey resolves the API key value, prioritizing the environment variable over the direct flag. func (c *Config) GetAPIKey() string { - log.Println("GetAPIKey: Attempting to resolve API key...") + utils.SafeLogPrintf("GetAPIKey: Attempting to resolve API key...") // 1. Check environment variable specified by --api-key-env if c.APIKeyFromEnvVar != "" { - log.Printf("GetAPIKey: Checking environment variable specified by --api-key-env: %s", c.APIKeyFromEnvVar) + utils.SafeLogPrintf("GetAPIKey: Checking environment variable specified by --api-key-env: %s", c.APIKeyFromEnvVar) val := os.Getenv(c.APIKeyFromEnvVar) if val != "" { - log.Printf("GetAPIKey: Found key in environment variable %s.", c.APIKeyFromEnvVar) + utils.SafeLogPrintf("GetAPIKey: Found key in environment variable %s.", c.APIKeyFromEnvVar) return val } - log.Printf("GetAPIKey: Environment variable %s not found or empty.", c.APIKeyFromEnvVar) + utils.SafeLogPrintf("GetAPIKey: Environment variable %s not found or empty.", c.APIKeyFromEnvVar) } else { - log.Println("GetAPIKey: No --api-key-env variable specified.") + utils.SafeLogPrintf("GetAPIKey: No --api-key-env variable specified.") } // 2. Check direct flag --api-key if c.APIKey != "" { - log.Println("GetAPIKey: Found key provided directly via --api-key flag.") + utils.SafeLogPrintf("GetAPIKey: Found key provided directly via --api-key flag.") return c.APIKey } // 3. No key found - log.Println("GetAPIKey: No API key found from config (env var or direct flag).") + utils.SafeLogPrintf("GetAPIKey: No API key found from config (env var or direct flag).") return "" } diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index 0ab6cca..39365b9 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "net/url" "os" @@ -15,6 +14,7 @@ import ( "github.com/ckanthony/openapi-mcp/pkg/config" "github.com/ckanthony/openapi-mcp/pkg/mcp" + "github.com/ckanthony/openapi-mcp/pkg/utils" "github.com/getkin/kin-openapi/openapi3" "github.com/go-openapi/loads" "github.com/go-openapi/spec" @@ -38,7 +38,7 @@ func LoadSwagger(location string) (interface{}, string, error) { var absPath string // Store absolute path if it's a file if !isURL { - log.Printf("Detected file path location: %s", location) + utils.SafeLogPrintf("Detected file path location: %s", location) absPath, err = filepath.Abs(location) if err != nil { return nil, "", fmt.Errorf("failed to get absolute path for '%s': %w", location, err) @@ -49,7 +49,7 @@ func LoadSwagger(location string) (interface{}, string, error) { return nil, "", fmt.Errorf("failed reading file path '%s': %w", absPath, err) } } else { - log.Printf("Detected URL location: %s", location) + utils.SafeLogPrintf("Detected URL location: %s", location) // Read data first for version detection resp, err := http.Get(location) if err != nil { @@ -81,11 +81,11 @@ func LoadSwagger(location string) (interface{}, string, error) { if !isURL { // Use LoadFromFile for local files - log.Printf("Loading V3 spec using LoadFromFile: %s", absPath) + utils.SafeLogPrintf("Loading V3 spec using LoadFromFile: %s", absPath) doc, loadErr = loader.LoadFromFile(absPath) } else { // Use LoadFromURI for URLs - log.Printf("Loading V3 spec using LoadFromURI: %s", location) + utils.SafeLogPrintf("Loading V3 spec using LoadFromURI: %s", location) doc, loadErr = loader.LoadFromURI(locationURL) } @@ -99,7 +99,7 @@ func LoadSwagger(location string) (interface{}, string, error) { return doc, VersionV3, nil } else if _, ok := detector["swagger"]; ok { // Swagger 2.0 - Still load from data as loads.Analyzed expects bytes - log.Printf("Loading V2 spec using loads.Analyzed from data (source: %s)", location) + utils.SafeLogPrintf("Loading V2 spec using loads.Analyzed from data (source: %s)", location) doc, err := loads.Analyzed(data, "2.0") if err != nil { return nil, "", fmt.Errorf("failed to load or validate Swagger v2 spec from '%s': %w", location, err) @@ -139,7 +139,7 @@ func generateToolSetV3(doc *openapi3.T, cfg *config.Config) (*mcp.ToolSet, error // Determine Base URL once baseURL, err := determineBaseURLV3(doc, cfg) if err != nil { - log.Printf("Warning: Could not determine base URL for V3 spec: %v. Operations might fail if base URL override is not set.", err) + utils.SafeLogPrintf("Warning: Could not determine base URL for V3 spec: %v. Operations might fail if base URL override is not set.", err) baseURL = "" // Allow proceeding if override is set } @@ -175,7 +175,7 @@ func generateToolSetV3(doc *openapi3.T, cfg *config.Config) (*mcp.ToolSet, error // Handle request body requestBody, err := requestBodyToMCPV3(op.RequestBody) if err != nil { - log.Printf("Warning: skipping request body for %s %s due to error: %v", method, rawPath, err) + utils.SafeLogPrintf("Warning: skipping request body for %s %s due to error: %v", method, rawPath, err) } else { // Merge request body schema into the main parameter schema if requestBody.Content != nil { @@ -189,7 +189,7 @@ func generateToolSetV3(doc *openapi3.T, cfg *config.Config) (*mcp.ToolSet, error } } else { // If body is not an object, represent as 'requestBody' - log.Printf("Warning: V3 request body for %s %s is not an object schema. Representing as 'requestBody' field.", method, rawPath) + utils.SafeLogPrintf("Warning: V3 request body for %s %s is not an object schema. Representing as 'requestBody' field.", method, rawPath) parametersSchema.Properties["requestBody"] = mediaTypeSchema } break // Only process the first content type @@ -219,7 +219,7 @@ func generateToolSetV3(doc *openapi3.T, cfg *config.Config) (*mcp.ToolSet, error // Optionally, add a note if the requestBody itself was marked as required if requestBody.Required { // Check the boolean field // How to indicate this? Maybe add to description? - log.Printf("Note: Request body for %s %s is marked as required.", method, rawPath) + utils.SafeLogPrintf("Note: Request body for %s %s is marked as required.", method, rawPath) // Or add all top-level body props to required? Needs decision. } } @@ -309,18 +309,18 @@ func parametersToMCPSchemaAndDetailsV3(params openapi3.Parameters, cfg *config.C opParams := []mcp.ParameterDetail{} for _, paramRef := range params { if paramRef.Value == nil { - log.Printf("Warning: Skipping parameter with nil value.") + utils.SafeLogPrintf("Warning: Skipping parameter with nil value.") continue } param := paramRef.Value if param.Schema == nil { - log.Printf("Warning: Skipping parameter '%s' with nil schema.", param.Name) + utils.SafeLogPrintf("Warning: Skipping parameter '%s' with nil schema.", param.Name) continue } // Skip the API key parameter if configured if cfg.APIKeyName != "" && param.Name == cfg.APIKeyName && param.In == string(cfg.APIKeyLocation) { - log.Printf("Parser V3: Skipping API key parameter '%s' ('%s') from input schema generation.", param.Name, param.In) + utils.SafeLogPrintf("Parser V3: Skipping API key parameter '%s' ('%s') from input schema generation.", param.Name, param.In) continue } @@ -442,7 +442,7 @@ func generateToolSetV2(doc *spec.Swagger, cfg *config.Config) (*mcp.ToolSet, err // Determine Base URL once baseURL, err := determineBaseURLV2(doc, cfg) if err != nil { - log.Printf("Warning: Could not determine base URL for V2 spec: %v. Operations might fail if base URL override is not set.", err) + utils.SafeLogPrintf("Warning: Could not determine base URL for V2 spec: %v. Operations might fail if base URL override is not set.", err) baseURL = "" // Allow proceeding if override is set } @@ -455,7 +455,7 @@ func generateToolSetV2(doc *spec.Swagger, cfg *config.Config) (*mcp.ToolSet, err if secDef.Type == "apiKey" { apiKeyName = secDef.Name apiKeyIn = secDef.In // "query" or "header" - log.Printf("Parser V2: Detected API key from security definition '%s': Name='%s', In='%s'", name, apiKeyName, apiKeyIn) + utils.SafeLogPrintf("Parser V2: Detected API key from security definition '%s': Name='%s', In='%s'", name, apiKeyName, apiKeyIn) break // Assume only one apiKey definition for simplicity } } @@ -519,7 +519,7 @@ func generateToolSetV2(doc *spec.Swagger, cfg *config.Config) (*mcp.ToolSet, err } } else { // If body is not an object, represent as 'requestBody' - log.Printf("Warning: V2 request body for %s %s is not an object schema. Representing as 'requestBody' field.", method, rawPath) + utils.SafeLogPrintf("Warning: V2 request body for %s %s is not an object schema. Representing as 'requestBody' field.", method, rawPath) if parametersSchema.Properties == nil { parametersSchema.Properties = make(map[string]mcp.Schema) } @@ -629,7 +629,7 @@ func parametersToMCPSchemaAndDetailsV2(params []spec.Parameter, definitions spec for _, param := range params { // Skip the API key parameter if it's configured/detected if apiKeyName != "" && param.Name == apiKeyName && (param.In == "query" || param.In == "header") { - log.Printf("Parser V2: Skipping API key parameter '%s' ('%s') from input schema generation.", param.Name, param.In) + utils.SafeLogPrintf("Parser V2: Skipping API key parameter '%s' ('%s') from input schema generation.", param.Name, param.In) continue } @@ -643,7 +643,7 @@ func parametersToMCPSchemaAndDetailsV2(params []spec.Parameter, definitions spec } if param.In != "query" && param.In != "path" && param.In != "header" && param.In != "formData" { - log.Printf("Parser V2: Skipping unsupported parameter type '%s' for parameter '%s'", param.In, param.Name) + utils.SafeLogPrintf("Parser V2: Skipping unsupported parameter type '%s' for parameter '%s'", param.In, param.Name) continue } @@ -702,7 +702,7 @@ func parametersToMCPSchemaAndDetailsV2(params []spec.Parameter, definitions spec } else { // Body param defined without a schema? Treat as simple string. - log.Printf("Warning: V2 body parameter '%s' defined without a schema. Treating as string.", bodyParam.Name) + utils.SafeLogPrintf("Warning: V2 body parameter '%s' defined without a schema. Treating as string.", bodyParam.Name) bodySchema.Type = "string" mcpSchema.Properties[bodyParam.Name] = bodySchema if bodyParam.Required { diff --git a/pkg/server/manager.go b/pkg/server/manager.go index cb18249..ffcc7b8 100644 --- a/pkg/server/manager.go +++ b/pkg/server/manager.go @@ -2,9 +2,10 @@ package server import ( "fmt" - "log" "net/http" "sync" + + "github.com/ckanthony/openapi-mcp/pkg/utils" ) // client holds information about a connected SSE client. @@ -36,7 +37,7 @@ func (m *connectionManager) addClient(r *http.Request, w http.ResponseWriter, f m.clients[r] = newClient m.mu.Unlock() - log.Printf("Client connected: %s (Total: %d)", r.RemoteAddr, m.getClientCount()) + utils.SafeLogPrintf("Client connected: %s (Total: %d)", r.RemoteAddr, m.getClientCount()) // Send initial toolset immediately go m.sendToolset(newClient) // Send in a goroutine to avoid blocking registration? @@ -48,9 +49,9 @@ func (m *connectionManager) removeClient(r *http.Request) { _, ok := m.clients[r] if ok { delete(m.clients, r) - log.Printf("Client disconnected: %s (Total: %d)", r.RemoteAddr, len(m.clients)) + utils.SafeLogPrintf("Client disconnected: %s (Total: %d)", r.RemoteAddr, len(m.clients)) } else { - log.Printf("Attempted to remove already disconnected client: %s", r.RemoteAddr) + utils.SafeLogPrintf("Attempted to remove already disconnected client: %s", r.RemoteAddr) } m.mu.Unlock() } @@ -68,15 +69,15 @@ func (m *connectionManager) sendToolset(c *client) { if c == nil { return } - log.Printf("Attempting to send toolset to client...") + utils.SafeLogPrintf("Attempting to send toolset to client...") _, err := fmt.Fprintf(c.writer, "event: tool_set\ndata: %s\n\n", string(m.toolSet)) if err != nil { // This error often happens if the client disconnected before/during the write - log.Printf("Error sending toolset data to client: %v (client likely disconnected)", err) + utils.SafeLogPrintf("Error sending toolset data to client: %v (client likely disconnected)", err) // Optionally trigger removal here if possible, though context done in handler is primary mechanism return } // Flush the data c.flusher.Flush() - log.Println("Sent tool_set event and flushed.") + utils.SafeLogPrintf("Sent tool_set event and flushed.") } diff --git a/pkg/server/server.go b/pkg/server/server.go index e4fd5f6..cc34e3a 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "net/url" "strings" @@ -18,6 +17,7 @@ import ( "github.com/ckanthony/openapi-mcp/pkg/config" "github.com/ckanthony/openapi-mcp/pkg/mcp" + "github.com/ckanthony/openapi-mcp/pkg/utils" "github.com/google/uuid" // Import UUID package ) @@ -99,7 +99,7 @@ const messageChannelBufferSize = 10 // ServeMCP starts an HTTP server handling MCP communication. func ServeMCP(addr string, toolSet *mcp.ToolSet, cfg *config.Config) error { - log.Printf("Preparing ToolSet for MCP...") + utils.SafeLogPrintf("Preparing ToolSet for MCP...") // --- Handler Functions --- mcpHandler := func(w http.ResponseWriter, r *http.Request) { @@ -110,7 +110,7 @@ func ServeMCP(addr string, toolSet *mcp.ToolSet, cfg *config.Config) error { w.Header().Set("Access-Control-Expose-Headers", "X-Connection-ID") if r.Method == http.MethodOptions { - log.Println("Responding to OPTIONS request") + utils.SafeLogPrintf("Responding to OPTIONS request") w.WriteHeader(http.StatusNoContent) // Use 204 No Content for OPTIONS return } @@ -120,7 +120,7 @@ func ServeMCP(addr string, toolSet *mcp.ToolSet, cfg *config.Config) error { } else if r.Method == http.MethodPost { httpMethodPostHandler(w, r, toolSet, cfg) // Pass the cfg object here } else { - log.Printf("Method Not Allowed: %s", r.Method) + utils.SafeLogPrintf("Method Not Allowed: %s", r.Method) http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) } } @@ -129,19 +129,19 @@ func ServeMCP(addr string, toolSet *mcp.ToolSet, cfg *config.Config) error { mux := http.NewServeMux() mux.HandleFunc("/mcp", mcpHandler) // Single endpoint for GET/POST/OPTIONS - log.Printf("MCP server listening on %s/mcp", addr) + utils.SafeLogPrintf("MCP server listening on %s/mcp", addr) return http.ListenAndServe(addr, mux) } // httpMethodGetHandler handles the initial GET request to establish the SSE connection. func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) { connectionID := uuid.New().String() - log.Printf("SSE client connecting: %s (Assigning ID: %s)", r.RemoteAddr, connectionID) + utils.SafeLogPrintf("SSE client connecting: %s (Assigning ID: %s)", r.RemoteAddr, connectionID) flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) - log.Println("Error: Client connection does not support flushing") + utils.SafeLogPrintf("Error: Client connection does not support flushing") return } @@ -157,20 +157,20 @@ func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) { // --- Send initial :ok --- (Must happen *after* headers) if _, err := fmt.Fprintf(w, ":ok\n\n"); err != nil { - log.Printf("Error sending SSE preamble to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) + utils.SafeLogPrintf("Error sending SSE preamble to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) return // Cannot proceed if preamble fails } flusher.Flush() - log.Printf("Sent :ok preamble to %s (ID: %s)", r.RemoteAddr, connectionID) + utils.SafeLogPrintf("Sent :ok preamble to %s (ID: %s)", r.RemoteAddr, connectionID) // --- Send initial SSE events --- (endpoint, mcp-ready) endpointURL := fmt.Sprintf("/mcp?sessionId=%s", connectionID) // Assuming /mcp is the mount path if err := writeSSEEvent(w, "endpoint", endpointURL); err != nil { - log.Printf("Error sending SSE endpoint event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) + utils.SafeLogPrintf("Error sending SSE endpoint event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) return } flusher.Flush() - log.Printf("Sent endpoint event to %s (ID: %s)", r.RemoteAddr, connectionID) + utils.SafeLogPrintf("Sent endpoint event to %s (ID: %s)", r.RemoteAddr, connectionID) readyMsg := jsonRPCRequest{ // Use request struct for notification format Jsonrpc: "2.0", @@ -182,18 +182,18 @@ func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) { }, } if err := writeSSEEvent(w, "message", readyMsg); err != nil { - log.Printf("Error sending SSE mcp-ready event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) + utils.SafeLogPrintf("Error sending SSE mcp-ready event to %s (ID: %s): %v", r.RemoteAddr, connectionID, err) return } flusher.Flush() - log.Printf("Sent mcp-ready event to %s (ID: %s)", r.RemoteAddr, connectionID) + utils.SafeLogPrintf("Sent mcp-ready event to %s (ID: %s)", r.RemoteAddr, connectionID) // --- Setup message channel and store connection --- msgChan := make(chan jsonRPCResponse, messageChannelBufferSize) // Channel for responses connMutex.Lock() activeConnections[connectionID] = msgChan connMutex.Unlock() - log.Printf("Registered channel for connection %s. Active connections: %d", connectionID, len(activeConnections)) + utils.SafeLogPrintf("Registered channel for connection %s. Active connections: %d", connectionID, len(activeConnections)) // --- Cleanup function --- cleanup := func() { @@ -201,7 +201,7 @@ func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) { delete(activeConnections, connectionID) connMutex.Unlock() close(msgChan) // Close channel when connection ends - log.Printf("Removed connection %s. Active connections: %d", connectionID, len(activeConnections)) + utils.SafeLogPrintf("Removed connection %s. Active connections: %d", connectionID, len(activeConnections)) } defer cleanup() @@ -210,20 +210,20 @@ func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) { defer cancel() go func() { - log.Printf("[SSE Writer %s] Starting message writer goroutine", connectionID) - defer log.Printf("[SSE Writer %s] Exiting message writer goroutine", connectionID) + utils.SafeLogPrintf("[SSE Writer %s] Starting message writer goroutine", connectionID) + defer utils.SafeLogPrintf("[SSE Writer %s] Exiting message writer goroutine", connectionID) for { select { case <-ctx.Done(): return // Exit if main context is cancelled case resp, ok := <-msgChan: if !ok { - log.Printf("[SSE Writer %s] Message channel closed.", connectionID) + utils.SafeLogPrintf("[SSE Writer %s] Message channel closed.", connectionID) return // Exit if channel is closed } - log.Printf("[SSE Writer %s] Sending message (ID: %v) via SSE", connectionID, resp.ID) + utils.SafeLogPrintf("[SSE Writer %s] Sending message (ID: %v) via SSE", connectionID, resp.ID) if err := writeSSEEvent(w, "message", resp); err != nil { - log.Printf("[SSE Writer %s] Error writing message to SSE stream: %v. Cancelling context.", connectionID, err) + utils.SafeLogPrintf("[SSE Writer %s] Error writing message to SSE stream: %v. Cancelling context.", connectionID, err) cancel() // Signal main loop to exit on write error return } @@ -236,11 +236,11 @@ func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) { keepAliveTicker := time.NewTicker(20 * time.Second) defer keepAliveTicker.Stop() - log.Printf("[SSE %s] Entering keep-alive loop", connectionID) + utils.SafeLogPrintf("[SSE %s] Entering keep-alive loop", connectionID) for { select { case <-ctx.Done(): - log.Printf("[SSE %s] Context done. Exiting keep-alive loop.", connectionID) + utils.SafeLogPrintf("[SSE %s] Context done. Exiting keep-alive loop.", connectionID) return // Exit loop if context cancelled (client disconnect or write error) case <-keepAliveTicker.C: // Send JSON-RPC ping notification instead of SSE comment @@ -252,7 +252,7 @@ func httpMethodGetHandler(w http.ResponseWriter, r *http.Request) { }, } if err := writeSSEEvent(w, "message", pingMsg); err != nil { - log.Printf("[SSE %s] Error sending ping notification: %v. Closing connection.", connectionID, err) + utils.SafeLogPrintf("[SSE %s] Error sending ping notification: %v. Closing connection.", connectionID, err) cancel() // Signal writer goroutine and exit return } @@ -303,11 +303,11 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. connID := r.Header.Get("X-Connection-ID") // Try header first if connID == "" { connID = r.URL.Query().Get("sessionId") // Fallback to query parameter - log.Printf("X-Connection-ID header missing, checking sessionId query param: found='%s'", connID) + utils.SafeLogPrintf("X-Connection-ID header missing, checking sessionId query param: found='%s'", connID) } if connID == "" { - log.Println("Error: POST request received without X-Connection-ID header or sessionId query parameter") + utils.SafeLogPrintf("Error: POST request received without X-Connection-ID header or sessionId query parameter") http.Error(w, "Missing X-Connection-ID header or sessionId query parameter", http.StatusBadRequest) return } @@ -318,7 +318,7 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. connMutex.RUnlock() if !isActive { - log.Printf("Error: POST request received for inactive/unknown connection ID: %s", connID) + utils.SafeLogPrintf("Error: POST request received for inactive/unknown connection ID: %s", connID) // Still send sync error here, as we don't have a channel tryWriteHTTPError(w, http.StatusNotFound, "Invalid or expired connection ID") return @@ -326,7 +326,7 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. bodyBytes, err := io.ReadAll(r.Body) if err != nil { - log.Printf("Error reading POST request body for %s: %v", connID, err) + utils.SafeLogPrintf("Error reading POST request body for %s: %v", connID, err) // Create error response in the ToolResultPayload format errPayload := ToolResultPayload{ IsError: true, @@ -346,12 +346,12 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. // Attempt to send via SSE channel select { case msgChan <- errResp: - log.Printf("Queued read error response (ID: %v) for %s onto SSE channel (as Result)", errResp.ID, connID) + utils.SafeLogPrintf("Queued read error response (ID: %v) for %s onto SSE channel (as Result)", errResp.ID, connID) // Send HTTP 202 Accepted back to the POST request w.WriteHeader(http.StatusAccepted) fmt.Fprintln(w, "Request accepted (with parse error), response will be sent via SSE.") default: - log.Printf("Error: Failed to queue read error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID) + utils.SafeLogPrintf("Error: Failed to queue read error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID) // Send an error back on the POST request if channel fails tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel") } @@ -359,7 +359,7 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. } // No defer r.Body.Close() needed here as io.ReadAll reads to EOF - log.Printf("Received POST data for %s: %s", connID, string(bodyBytes)) + utils.SafeLogPrintf("Received POST data for %s", connID) // Attempt to unmarshal into a temporary map first to extract ID if possible var rawReq map[string]interface{} @@ -375,26 +375,26 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. } } else { // Full unmarshal failed, log it but continue to try specific struct - log.Printf("Warning: Initial unmarshal into map failed for %s: %v. Will attempt specific struct unmarshal.", connID, err) + utils.SafeLogPrintf("Warning: Initial unmarshal into map failed for %s: %v. Will attempt specific struct unmarshal.", connID, err) reqID = nil // ID is unknown } var req jsonRPCRequest // Expect JSON-RPC request if err := json.Unmarshal(bodyBytes, &req); err != nil { - log.Printf("Error decoding JSON-RPC request for %s: %v", connID, err) + utils.SafeLogPrintf("Error decoding JSON-RPC request for %s: %v", connID, err) // Use createJSONRPCError to correctly format the error response errResp := createJSONRPCError(reqID, -32700, "Parse error decoding JSON request", err.Error()) // Attempt to send via SSE channel select { case msgChan <- errResp: - log.Printf("Queued decode error response (ID: %v) for %s onto SSE channel", errResp.ID, connID) + utils.SafeLogPrintf("Queued decode error response (ID: %v) for %s onto SSE channel", errResp.ID, connID) // Send HTTP 202 Accepted back to the POST request w.WriteHeader(http.StatusAccepted) // Use a specific message for decode errors fmt.Fprintln(w, "Request accepted (with decode error), response will be sent via SSE.") default: - log.Printf("Error: Failed to queue decode error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID) + utils.SafeLogPrintf("Error: Failed to queue decode error response (ID: %v) for %s - SSE channel likely full or closed.", errResp.ID, connID) // Send an error back on the POST request if channel fails tryWriteHTTPError(w, http.StatusInternalServerError, "Failed to queue error response for SSE channel") } @@ -413,23 +413,21 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. // --- Validate JSON-RPC Request --- if req.Jsonrpc != "2.0" { - log.Printf("Invalid JSON-RPC version ('%s') for %s, ID: %v", req.Jsonrpc, connID, reqID) + utils.SafeLogPrintf("Invalid JSON-RPC version ('%s') for %s, ID: %v", req.Jsonrpc, connID, reqID) respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: jsonrpc field must be \"2.0\"", nil) } else if req.Method == "" { - log.Printf("Missing JSON-RPC method for %s, ID: %v", connID, reqID) + utils.SafeLogPrintf("Missing JSON-RPC method for %s, ID: %v", connID, reqID) respToSend = createJSONRPCError(reqID, -32600, "Invalid Request: method field is missing or empty", nil) } else { // --- Process the valid request --- - log.Printf("Processing JSON-RPC message for %s: Method=%s, ID=%v", connID, req.Method, reqID) + utils.SafeLogPrintf("Processing JSON-RPC message for %s: Method=%s, ID=%v", connID, req.Method, reqID) switch req.Method { case "initialize": - incomingInitializeJSON, _ := json.Marshal(req) - log.Printf("DEBUG: Handling 'initialize' for %s. Incoming request: %s", connID, string(incomingInitializeJSON)) + utils.SafeLogPrintf("DEBUG: Handling 'initialize' for %s", connID) respToSend = handleInitializeJSONRPC(connID, &req) - outgoingInitializeJSON, _ := json.Marshal(respToSend) - log.Printf("DEBUG: Prepared 'initialize' response for %s. Outgoing response: %s", connID, string(outgoingInitializeJSON)) + utils.SafeLogPrintf("DEBUG: Prepared 'initialize' response for %s", connID) case "notifications/initialized": - log.Printf("Received 'notifications/initialized' notification for %s. Ignoring.", connID) + utils.SafeLogPrintf("Received 'notifications/initialized' notification for %s. Ignoring.", connID) w.WriteHeader(http.StatusAccepted) fmt.Fprintln(w, "Notification received.") return // Return early, do not send anything on SSE channel @@ -438,7 +436,7 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. case "tools/call": respToSend = handleToolCallJSONRPC(connID, &req, toolSet, cfg) default: - log.Printf("Received unknown JSON-RPC method '%s' for %s", req.Method, connID) + utils.SafeLogPrintf("Received unknown JSON-RPC method '%s' for %s", req.Method, connID) respToSend = createJSONRPCError(reqID, -32601, fmt.Sprintf("Method not found: %s", req.Method), nil) } } @@ -446,13 +444,13 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. // --- Send response ASYNCHRONOUSLY via SSE channel (unless handled earlier) --- select { case msgChan <- respToSend: - log.Printf("Queued response (ID: %v) for %s onto SSE channel", respToSend.ID, connID) + utils.SafeLogPrintf("Queued response (ID: %v) for %s onto SSE channel", respToSend.ID, connID) // Send HTTP 202 Accepted back to the POST request w.WriteHeader(http.StatusAccepted) // Use the standard message for successfully queued responses fmt.Fprintln(w, "Request accepted, response will be sent via SSE.") default: - log.Printf("Error: Failed to queue response (ID: %v) for %s - SSE channel likely full or closed.", respToSend.ID, connID) + utils.SafeLogPrintf("Error: Failed to queue response (ID: %v) for %s - SSE channel likely full or closed.", respToSend.ID, connID) http.Error(w, "Failed to queue response for SSE channel", http.StatusInternalServerError) } } @@ -460,7 +458,7 @@ func httpMethodPostHandler(w http.ResponseWriter, r *http.Request, toolSet *mcp. // --- JSON-RPC Message Handlers --- // Implementations returning jsonRPCResponse func handleInitializeJSONRPC(connID string, req *jsonRPCRequest) jsonRPCResponse { - log.Printf("Handling 'initialize' (JSON-RPC) for %s", connID) + utils.SafeLogPrintf("Handling 'initialize' (JSON-RPC) for %s", connID) // Construct the result payload based on gin-mcp's structure using map[string]interface{} resultPayload := map[string]interface{}{ @@ -501,7 +499,7 @@ func handleInitializeJSONRPC(connID string, req *jsonRPCRequest) jsonRPCResponse } func handleToolsListJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.ToolSet) jsonRPCResponse { - log.Printf("Handling 'tools/list' (JSON-RPC) for %s", connID) + utils.SafeLogPrintf("Handling 'tools/list' (JSON-RPC) for %s", connID) // Construct the result payload based on gin-mcp's structure resultPayload := map[string]interface{}{ @@ -525,13 +523,13 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C toolName := params.ToolName toolInput := params.Input // This is the map[string]interface{} from the client - log.Printf("[ExecuteToolCall] Looking up details for tool: %s", toolName) + utils.SafeLogPrintf("[ExecuteToolCall] Looking up details for tool: %s", toolName) operation, ok := toolSet.Operations[toolName] if !ok { - log.Printf("[ExecuteToolCall] Error: Operation details not found for tool '%s'", toolName) + utils.SafeLogPrintf("[ExecuteToolCall] Error: Operation details not found for tool '%s'", toolName) return nil, fmt.Errorf("operation details for tool '%s' not found", toolName) } - log.Printf("[ExecuteToolCall] Found operation: Method=%s, Path=%s", operation.Method, operation.Path) + utils.SafeLogPrintf("[ExecuteToolCall] Found operation: Method=%s, Path=%s", operation.Method, operation.Path) // --- Resolve API Key (using cfg passed from main) --- resolvedKey := cfg.GetAPIKey() @@ -539,16 +537,16 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C apiKeyLocation := cfg.APIKeyLocation hasServerKey := resolvedKey != "" && apiKeyName != "" && apiKeyLocation != "" - log.Printf("[ExecuteToolCall] API Key Details: Name='%s', In='%s', HasServerValue=%t", apiKeyName, apiKeyLocation, resolvedKey != "") + utils.SafeLogPrintf("[ExecuteToolCall] API Key Details: Name='%s', In='%s', HasServerValue=%t", apiKeyName, apiKeyLocation, resolvedKey != "") // --- Prepare Request Components --- baseURL := operation.BaseURL // Use BaseURL from the specific operation if cfg.ServerBaseURL != "" { baseURL = cfg.ServerBaseURL // Override if global base URL is set - log.Printf("[ExecuteToolCall] Overriding base URL with global config: %s", baseURL) + utils.SafeLogPrintf("[ExecuteToolCall] Overriding base URL with global config: %s", baseURL) } if baseURL == "" { - log.Printf("[ExecuteToolCall] Warning: No base URL found for operation %s and no global override set.", toolName) + utils.SafeLogPrintf("[ExecuteToolCall] Warning: No base URL found for operation %s and no global override set.", toolName) // For now, assume relative if empty. } @@ -567,13 +565,13 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C } // --- Process Input Parameters (Separating and Handling API Key Override) --- - log.Printf("[ExecuteToolCall] Processing %d input parameters...", len(toolInput)) + utils.SafeLogPrintf("[ExecuteToolCall] Processing %d input parameters...", len(toolInput)) for key, value := range toolInput { // --- API Key Override Check --- // If this input param is the API key AND we have a valid server key config, // skip processing the client's value entirely. if hasServerKey && key == apiKeyName { - log.Printf("[ExecuteToolCall] Skipping client-provided param '%s' due to server API key override.", key) + utils.SafeLogPrintf("[ExecuteToolCall] Skipping client-provided param '%s' due to server API key override.", key) continue } // --- End API Key Override --- @@ -584,43 +582,43 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C if strings.Contains(path, pathPlaceholder) { // Handle path parameter substitution pathParams[key] = fmt.Sprintf("%v", value) - log.Printf("[ExecuteToolCall] Found path parameter %s=%v", key, value) + utils.SafeLogPrintf("[ExecuteToolCall] Found path parameter %s=%v", key, value) } else if knownParam { // Handle parameters defined in the spec (query, header, cookie) switch paramLocation { case "query": queryParams.Add(key, fmt.Sprintf("%v", value)) - log.Printf("[ExecuteToolCall] Found query parameter %s=%v (from spec)", key, value) + utils.SafeLogPrintf("[ExecuteToolCall] Found query parameter %s=%v (from spec)", key, value) case "header": headerParams.Add(key, fmt.Sprintf("%v", value)) - log.Printf("[ExecuteToolCall] Found header parameter %s=%v (from spec)", key, value) + utils.SafeLogPrintf("[ExecuteToolCall] Found header parameter %s=%v (from spec)", key, value) case "cookie": cookieParams = append(cookieParams, &http.Cookie{Name: key, Value: fmt.Sprintf("%v", value)}) - log.Printf("[ExecuteToolCall] Found cookie parameter %s=%v (from spec)", key, value) + utils.SafeLogPrintf("[ExecuteToolCall] Found cookie parameter %s=%v (from spec)", key, value) // case "formData": // TODO: Handle form data if needed // bodyData[key] = value // Or handle differently based on content type - // log.Printf("[ExecuteToolCall] Found formData parameter %s=%v (from spec)", key, value) + // utils.SafeLogPrintf("[ExecuteToolCall] Found formData parameter %s=%v (from spec)", key, value) default: // Known parameter but location handling is missing or mismatched. if paramLocation == "path" && (operation.Method == "GET" || operation.Method == "DELETE") { // If spec says 'path' but it wasn't in the actual path, and it's a GET/DELETE, // treat it as a query parameter as a fallback. - log.Printf("[ExecuteToolCall] Warning: Parameter '%s' is 'path' in spec but not in URL path '%s'. Adding to query parameters as fallback for GET/DELETE.", key, operation.Path) + utils.SafeLogPrintf("[ExecuteToolCall] Warning: Parameter '%s' is 'path' in spec but not in URL path '%s'. Adding to query parameters as fallback for GET/DELETE.", key, operation.Path) queryParams.Add(key, fmt.Sprintf("%v", value)) } else { // Otherwise, log the warning and ignore. - log.Printf("[ExecuteToolCall] Warning: Parameter '%s' has unsupported or unhandled location '%s' in spec. Ignoring.", key, paramLocation) + utils.SafeLogPrintf("[ExecuteToolCall] Warning: Parameter '%s' has unsupported or unhandled location '%s' in spec. Ignoring.", key, paramLocation) } } } else if requestBodyRequired { // If parameter is not in path or defined in spec params, and method expects a body, // assume it belongs in the request body. bodyData[key] = value - log.Printf("[ExecuteToolCall] Added body parameter %s=%v (assumed)", key, value) + utils.SafeLogPrintf("[ExecuteToolCall] Added body parameter %s=%v (assumed)", key, value) } else { // Parameter not in path, not in spec, and not a body method. // This could be an extraneous parameter like 'explanation'. Log it. - log.Printf("[ExecuteToolCall] Ignoring parameter '%s' as it doesn't match path or known parameter location for method %s.", key, operation.Method) + utils.SafeLogPrintf("[ExecuteToolCall] Ignoring parameter '%s' as it doesn't match path or known parameter location for method %s.", key, operation.Method) } } @@ -631,43 +629,43 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C // --- Inject Server API Key (if applicable) --- if hasServerKey { - log.Printf("[ExecuteToolCall] Injecting server API key (Name: %s, Location: %s)", apiKeyName, string(apiKeyLocation)) + utils.SafeLogPrintf("[ExecuteToolCall] Injecting server API key (Name: %s, Location: %s)", apiKeyName, string(apiKeyLocation)) switch apiKeyLocation { case config.APIKeyLocationQuery: queryParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value - log.Printf("[ExecuteToolCall] Injected API key '%s' into query parameters", apiKeyName) + utils.SafeLogPrintf("[ExecuteToolCall] Injected API key '%s' into query parameters", apiKeyName) case config.APIKeyLocationHeader: headerParams.Set(apiKeyName, resolvedKey) // Set overrides any previous value - log.Printf("[ExecuteToolCall] Injected API key '%s' into headers", apiKeyName) + utils.SafeLogPrintf("[ExecuteToolCall] Injected API key '%s' into headers", apiKeyName) case config.APIKeyLocationPath: pathPlaceholder := "{" + apiKeyName + "}" if strings.Contains(path, pathPlaceholder) { path = strings.Replace(path, pathPlaceholder, resolvedKey, -1) - log.Printf("[ExecuteToolCall] Injected API key into path parameter '%s'", apiKeyName) + utils.SafeLogPrintf("[ExecuteToolCall] Injected API key into path parameter '%s'", apiKeyName) } else { - log.Printf("[ExecuteToolCall] Warning: API key location is 'path' but placeholder '%s' not found in final path '%s' for injection.", pathPlaceholder, path) + utils.SafeLogPrintf("[ExecuteToolCall] Warning: API key location is 'path' but placeholder '%s' not found in final path '%s' for injection.", pathPlaceholder, path) } case config.APIKeyLocationCookie: // Check if cookie already exists from input, replace if so foundCookie := false for i, c := range cookieParams { if c.Name == apiKeyName { - log.Printf("[ExecuteToolCall] Replacing existing cookie '%s' with injected API key.", apiKeyName) + utils.SafeLogPrintf("[ExecuteToolCall] Replacing existing cookie '%s' with injected API key.", apiKeyName) cookieParams[i] = &http.Cookie{Name: apiKeyName, Value: resolvedKey} // Replace existing foundCookie = true break } } if !foundCookie { - log.Printf("[ExecuteToolCall] Adding new cookie '%s' with injected API key.", apiKeyName) + utils.SafeLogPrintf("[ExecuteToolCall] Adding new cookie '%s' with injected API key.", apiKeyName) cookieParams = append(cookieParams, &http.Cookie{Name: apiKeyName, Value: resolvedKey}) // Append new } default: - // Use log.Printf for consistency - log.Printf("Warning: Unsupported API key location specified in config: '%s'", apiKeyLocation) + // Use utils.SafeLogPrintf for consistency + utils.SafeLogPrintf("Warning: Unsupported API key location specified in config: '%s'", apiKeyLocation) } } else { - log.Printf("[ExecuteToolCall] Skipping server API key injection (config incomplete or key unresolved).") + utils.SafeLogPrintf("[ExecuteToolCall] Skipping server API key injection (config incomplete or key unresolved).") } // --- Final URL Construction --- @@ -676,7 +674,7 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C if len(queryParams) > 0 { targetURL += "?" + queryParams.Encode() } - log.Printf("[ExecuteToolCall] Final Target URL: %s %s", operation.Method, targetURL) + utils.SafeLogPrintf("[ExecuteToolCall] Final Target URL: %s %s", operation.Method, targetURL) // --- Prepare Request Body --- var reqBody io.Reader @@ -685,17 +683,17 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C var err error bodyBytes, err = json.Marshal(bodyData) if err != nil { - log.Printf("[ExecuteToolCall] Error marshalling request body: %v", err) + utils.SafeLogPrintf("[ExecuteToolCall] Error marshalling request body: %v", err) return nil, fmt.Errorf("error marshalling request body: %w", err) } reqBody = bytes.NewBuffer(bodyBytes) - log.Printf("[ExecuteToolCall] Request body: %s", string(bodyBytes)) + utils.SafeLogPrintf("[ExecuteToolCall] Request body: %s", string(bodyBytes)) } // --- Create HTTP Request --- req, err := http.NewRequest(operation.Method, targetURL, reqBody) if err != nil { - log.Printf("[ExecuteToolCall] Error creating HTTP request: %v", err) + utils.SafeLogPrintf("[ExecuteToolCall] Error creating HTTP request: %v", err) return nil, fmt.Errorf("error creating request: %w", err) } @@ -725,7 +723,7 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C headerValue := strings.TrimSpace(parts[1]) if headerName != "" { req.Header.Set(headerName, headerValue) // Set overrides potential input - log.Printf("[ExecuteToolCall] Added custom header from config: %s", headerName) + utils.SafeLogPrintf("[ExecuteToolCall] Added custom header from config: %s", headerName) } } } @@ -736,21 +734,21 @@ func executeToolCall(params *ToolCallParams, toolSet *mcp.ToolSet, cfg *config.C req.AddCookie(cookie) } - log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header) + utils.SafeLogPrintf("[ExecuteToolCall] Sending request with headers: %v", req.Header) if len(req.Cookies()) > 0 { - log.Printf("[ExecuteToolCall] Sending request with cookies: %+v", req.Cookies()) + utils.SafeLogPrintf("[ExecuteToolCall] Sending request with cookies: %+v", req.Cookies()) } // --- Execute HTTP Request --- - log.Printf("[ExecuteToolCall] Sending request with headers: %v", req.Header) + utils.SafeLogPrintf("[ExecuteToolCall] Sending request with headers: %v", req.Header) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { - log.Printf("[ExecuteToolCall] Error executing HTTP request: %v", err) + utils.SafeLogPrintf("[ExecuteToolCall] Error executing HTTP request: %v", err) return nil, fmt.Errorf("error executing request: %w", err) } - log.Printf("[ExecuteToolCall] Request executed. Status Code: %d", resp.StatusCode) + utils.SafeLogPrintf("[ExecuteToolCall] Request executed. Status Code: %d", resp.StatusCode) // Note: Don't close resp.Body here, the caller (handleToolCallJSONRPC) needs it. return resp, nil } @@ -765,26 +763,26 @@ func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.Tool var marshalErr error rawParams, marshalErr = json.Marshal(paramsMap) if marshalErr != nil { - log.Printf("Error marshalling params map for %s: %v", connID, marshalErr) + utils.SafeLogPrintf("Error marshalling params map for %s: %v", connID, marshalErr) return createJSONRPCError(req.ID, -32602, "Invalid parameters format (map marshal failed)", marshalErr.Error()) } - log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from map)", connID, string(rawParams)) + utils.SafeLogPrintf("Handling 'tools/call' (JSON-RPC) for %s", connID) } else { - log.Printf("Invalid parameters format for tools/call (not json.RawMessage or map[string]interface{}): %T", req.Params) + utils.SafeLogPrintf("Invalid parameters format for tools/call (not json.RawMessage or map[string]interface{}): %T", req.Params) return createJSONRPCError(req.ID, -32602, "Invalid parameters format (expected JSON object)", nil) } } else { - log.Printf("Handling 'tools/call' (JSON-RPC) for %s, Params: %s (from RawMessage)", connID, string(rawParams)) + utils.SafeLogPrintf("Handling 'tools/call' (JSON-RPC) for %s", connID) } // Now, unmarshal the rawParams ([]byte) into ToolCallParams var params ToolCallParams if err := json.Unmarshal(rawParams, ¶ms); err != nil { - log.Printf("Error unmarshalling tools/call params for %s: %v", connID, err) + utils.SafeLogPrintf("Error unmarshalling tools/call params for %s: %v", connID, err) return createJSONRPCError(req.ID, -32602, "Invalid parameters structure (unmarshal)", err.Error()) } - log.Printf("Executing tool '%s' for %s with input: %+v", params.ToolName, connID, params.Input) + utils.SafeLogPrintf("Executing tool '%s' for %s with input: %+v", params.ToolName, connID, params.Input) // --- Execute the actual tool call --- httpResp, execErr := executeToolCall(¶ms, toolSet, cfg) @@ -792,7 +790,7 @@ func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.Tool // --- Process Response --- var resultPayload ToolResultPayload if execErr != nil { - log.Printf("Error executing tool call '%s': %v", params.ToolName, execErr) + utils.SafeLogPrintf("Error executing tool call '%s': %v", params.ToolName, execErr) // Populate content with error message resultContent := []ToolResultContent{ { @@ -812,7 +810,7 @@ func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.Tool defer httpResp.Body.Close() // Ensure body is closed bodyBytes, readErr := io.ReadAll(httpResp.Body) if readErr != nil { - log.Printf("Error reading response body for tool '%s': %v", params.ToolName, readErr) + utils.SafeLogPrintf("Error reading response body for tool '%s': %v", params.ToolName, readErr) // Populate content with error message resultContent := []ToolResultContent{ { @@ -829,7 +827,7 @@ func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.Tool ToolCallID: fmt.Sprintf("%v", req.ID), } } else { - log.Printf("Received response body for tool '%s': %s", params.ToolName, string(bodyBytes)) + utils.SafeLogPrintf("Received response body for tool '%s': %s", params.ToolName, string(bodyBytes)) // Check status code for API-level errors if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { // Error case: populate content with the error response body @@ -882,11 +880,11 @@ func handleToolCallJSONRPC(connID string, req *jsonRPCRequest, toolSet *mcp.Tool func sendJSONRPCResponse(w http.ResponseWriter, resp jsonRPCResponse) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(resp); err != nil { - log.Printf("Error encoding JSON-RPC response (ID: %v) for ConnID %v: %v", resp.ID, resp.Error, err) + utils.SafeLogPrintf("Error encoding JSON-RPC response (ID: %v) for ConnID %v: %v", resp.ID, resp.Error, err) // Attempt to send a plain text error if JSON encoding fails tryWriteHTTPError(w, http.StatusInternalServerError, "Internal Server Error encoding JSON-RPC response") } - log.Printf("Sent JSON-RPC response: Method=%s, ID=%v", getMethodFromResponse(resp), resp.ID) + utils.SafeLogPrintf("Sent JSON-RPC response: Method=%s, ID=%v", getMethodFromResponse(resp), resp.ID) } // createJSONRPCError creates a JSON-RPC error response. @@ -902,7 +900,7 @@ func createJSONRPCError(id interface{}, code int, message string, data interface // sendJSONRPCError sends a JSON-RPC error response. func sendJSONRPCError(w http.ResponseWriter, connID string, id interface{}, code int, message string, data interface{}) { resp := createJSONRPCError(id, code, message, data) - log.Printf("Sending JSON-RPC Error for ConnID %s, ID %v: Code=%d, Message='%s'", connID, id, code, message) + utils.SafeLogPrintf("Sending JSON-RPC Error for ConnID %s, ID %v: Code=%d, Message='%s'", connID, id, code, message) sendJSONRPCResponse(w, resp) } @@ -930,7 +928,7 @@ func getMethodFromResponse(resp jsonRPCResponse) string { // tryWriteHTTPError attempts to write an HTTP error, ignoring failures. func tryWriteHTTPError(w http.ResponseWriter, code int, message string) { if _, err := w.Write([]byte(message)); err != nil { - log.Printf("Error writing plain HTTP error response: %v", err) + utils.SafeLogPrintf("Error writing plain HTTP error response: %v", err) } - log.Printf("Sent plain HTTP error: %s (Code: %d)", message, code) + utils.SafeLogPrintf("Sent plain HTTP error: %s (Code: %d)", message, code) } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 06f7318..ea129fb 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "net/http/httptest" "strings" @@ -16,6 +15,7 @@ import ( "github.com/ckanthony/openapi-mcp/pkg/config" "github.com/ckanthony/openapi-mcp/pkg/mcp" + "github.com/ckanthony/openapi-mcp/pkg/utils" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -378,7 +378,7 @@ func TestHttpMethodGetHandler(t *testing.T) { for id, ch := range activeConnections { close(ch) delete(activeConnections, id) - log.Printf("[DEFER Cleanup] Closed channel and removed connection %s", id) + utils.SafeLogPrintf("[DEFER Cleanup] Closed channel and removed connection %s", id) } activeConnections = originalConnections // Restore the original map connMutex.Unlock() @@ -872,7 +872,7 @@ func (m *sseMockResponseWriter) Write(p []byte) (int, error) { // Check if write count triggers failure if m.failAfterNWrites >= 0 && m.writesMade >= m.failAfterNWrites { m.forceError = fmt.Errorf("forced write error after %d writes", m.failAfterNWrites) - log.Printf("DEBUG: sseMockResponseWriter triggering error: %v", m.forceError) // Debug log + utils.SafeLogPrintf("DEBUG: sseMockResponseWriter triggering error: %v", m.forceError) // Debug log return 0, m.forceError } @@ -1002,7 +1002,7 @@ func TestHttpMethodGetHandler_GoroutineErrors(t *testing.T) { go func() { defer close(done) httpMethodGetHandler(mockWriter, req) - log.Println("DEBUG: httpMethodGetHandler goroutine exited") + utils.SafeLogPrintln("DEBUG: httpMethodGetHandler goroutine exited") }() // Wait for the connection to be established @@ -1012,7 +1012,7 @@ func TestHttpMethodGetHandler_GoroutineErrors(t *testing.T) { for id, ch := range activeConnections { connID = id msgChan = ch - log.Printf("DEBUG: Connection established: %s", connID) + utils.SafeLogPrintf("DEBUG: Connection established: %s", connID) return true } return false @@ -1023,10 +1023,10 @@ func TestHttpMethodGetHandler_GoroutineErrors(t *testing.T) { // Send a message that should trigger the write error testResp := jsonRPCResponse{Jsonrpc: "2.0", ID: "test-msg-1", Result: "test data"} - log.Printf("DEBUG: Sending test message to channel for %s", connID) + utils.SafeLogPrintf("DEBUG: Sending test message to channel for %s", connID) select { case msgChan <- testResp: - log.Printf("DEBUG: Test message sent to channel for %s", connID) + utils.SafeLogPrintf("DEBUG: Test message sent to channel for %s", connID) case <-time.After(100 * time.Millisecond): t.Fatal("Timeout sending message to channel") } @@ -1034,7 +1034,7 @@ func TestHttpMethodGetHandler_GoroutineErrors(t *testing.T) { // Wait for the handler goroutine to finish due to the write error select { case <-done: - log.Printf("DEBUG: Handler goroutine finished as expected after message write error") + utils.SafeLogPrintf("DEBUG: Handler goroutine finished as expected after message write error") // Handler finished (presumably due to write error) case <-time.After(1000 * time.Millisecond): // Increased timeout to 1 second t.Fatal("Timeout waiting for httpMethodGetHandler goroutine to exit after message write error") diff --git a/pkg/utils/logging.go b/pkg/utils/logging.go new file mode 100644 index 0000000..de111f7 --- /dev/null +++ b/pkg/utils/logging.go @@ -0,0 +1,41 @@ +package utils + +import ( + "log" + "strings" +) + +// SanitizeLogInput sanitizes user input for logging to prevent log forging +func SanitizeLogInput(s string) string { + return strings.ReplaceAll(strings.ReplaceAll(s, "\n", "\\n"), "\r", "\\r") +} + +// SafeLogPrintf logs with sanitized string arguments to prevent log forging +func SafeLogPrintf(format string, args ...interface{}) { + for i, arg := range args { + if s, ok := arg.(string); ok { + args[i] = SanitizeLogInput(s) + } + } + log.Printf(format, args...) +} + +// SafeLogPrintln logs with sanitized string arguments to prevent log forging, like log.Println +func SafeLogPrintln(args ...interface{}) { + for i, arg := range args { + if s, ok := arg.(string); ok { + args[i] = SanitizeLogInput(s) + } + } + log.Println(args...) +} + +// SafeLogFatalf logs with sanitized string arguments to prevent log forging +func SafeLogFatalf(format string, args ...interface{}) { + for i, arg := range args { + if s, ok := arg.(string); ok { + args[i] = SanitizeLogInput(s) + } + } + log.Fatalf(format, args...) +}