Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/CONFIGURATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ JSON configuration is the primary format for containerized deployments. Pass via

### Configuration Validation

The gateway provides fail-fast validation with precise error locations (line/column for TOML parse errors), unknown key detection (catches typos like `prot` instead of `port`), and environment variable expansion validation. Check log files for warnings after startup.
The gateway provides fail-fast validation with precise error locations (line/column for TOML parse errors), unknown field rejection (typos like `prot` instead of `port` are rejected with an error per spec §4.3.1), and environment variable expansion validation.

### Usage

Expand Down
26 changes: 26 additions & 0 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cmd
import (
"bufio"
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -297,6 +299,19 @@ func run(cmd *cobra.Command, args []string) error {

debugLog.Printf("Server mode: %s, guards mode: %s", mode, cfg.DIFCMode)

// Per spec §7.3: generate a random API key on startup if none is configured.
// The generated key is set in the config so it propagates to both the HTTP
// server authentication and the stdout configuration output (spec §5.4).
if cfg.GetAPIKey() == "" {
randomKey, err := generateRandomAPIKey()
if err != nil {
return fmt.Errorf("failed to generate random API key: %w", err)
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generateRandomAPIKey() already wraps rand.Read errors with "failed to generate random API key". The caller in run() wraps the returned error with the same prefix again, which will produce duplicated text like "failed to generate random API key: failed to generate random API key: …". Consider returning the raw rand.Read error from generateRandomAPIKey (or removing the extra wrap in run()) so the error chain is not repetitive.

Suggested change
return fmt.Errorf("failed to generate random API key: %w", err)
return err

Copilot uses AI. Check for mistakes.
}
cfg.Gateway.APIKey = randomKey
log.Printf("No API key configured — generated a temporary random API key for this session")
logger.LogInfoMd("startup", "Generated temporary random API key (spec §7.3)")
}

// Create unified MCP server (backend for both modes)
unifiedServer, err := server.NewUnified(ctx, cfg)
if err != nil {
Expand Down Expand Up @@ -572,6 +587,17 @@ func loadEnvFile(path string) error {
return scanner.Err()
}

// generateRandomAPIKey generates a cryptographically random API key.
// Per spec §7.3, the gateway SHOULD generate a random API key on startup
// if none is provided. Returns a 32-byte hex-encoded string (64 chars).
func generateRandomAPIKey() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random API key: %w", err)
}
return hex.EncodeToString(bytes), nil
}

// Execute runs the root command
func Execute() {
if err := rootCmd.Execute(); err != nil {
Expand Down
15 changes: 15 additions & 0 deletions internal/cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,18 @@ func TestPostRunCleanup(t *testing.T) {
assert.NotNil(t, rootCmd.PersistentPostRun, "PersistentPostRun should be set")
})
}

// TestGenerateRandomAPIKey verifies that generateRandomAPIKey produces a
// non-empty, unique, hex-encoded string per spec §7.3.
func TestGenerateRandomAPIKey(t *testing.T) {
key, err := generateRandomAPIKey()
require.NoError(t, err, "generateRandomAPIKey() should not fail")
assert.NotEmpty(t, key, "generated key should not be empty")
// 32 bytes encoded as hex = 64 characters
assert.Len(t, key, 64, "generated key should be 64 hex characters")

// Verify keys are unique across calls
key2, err := generateRandomAPIKey()
require.NoError(t, err)
assert.NotEqual(t, key, key2, "successive calls should produce unique keys")
}
58 changes: 42 additions & 16 deletions internal/config/config_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
//
// Streaming Decoder: Uses toml.NewDecoder() for memory efficiency with large configs
// Error Reporting: Wraps ParseError with %w to preserve structured type and surface full source context
// Unknown Fields: Uses MetaData.Undecoded() for typo warnings (not hard errors)
// Unknown Fields: Uses MetaData.Undecoded() to reject configurations with unrecognized fields (spec §4.3.1)
// Validation: Multi-layer approach (parse → schema → field-level → variable expansion)
//
// # TOML 1.1 Features Used
Expand All @@ -29,10 +29,10 @@ import (
"io"
"log"
"os"
"strings"
"time"

"github.com/BurntSushi/toml"
"github.com/github/gh-aw-mcpg/internal/logger"
)

// Core constants for configuration defaults
Expand Down Expand Up @@ -235,15 +235,37 @@ func (cfg *Config) EnsureGatewayDefaults() {
applyDefaults(cfg)
}

// LoadFromFile loads configuration from a TOML file.
// isDynamicTOMLPath reports whether the TOML key path falls under a known
// map[string]interface{} field in the config struct. Such fields accept
// arbitrary nested keys by design and must be excluded from the unknown-field check.
//
// toml.Key is a []string of path components, e.g.:
//
// ["servers", "github", "guard_policies", "mypolicy", "repos"]
// [0] [1] [2] [3] [4]
//
// Dynamic sections:
// - servers[0].<name>[1].guard_policies[2].<policy>[3].<key>[4+] (len ≥ 5)
// - guards[0].<name>[1].config[2].<key>[3+] (len ≥ 4)
func isDynamicTOMLPath(key toml.Key) bool {
// servers.<name>.guard_policies.<policy>.<key> → indices [0]="servers" [2]="guard_policies", len ≥ 5
if len(key) >= 5 && key[0] == "servers" && key[2] == "guard_policies" {
return true
}
// guards.<name>.config.<key> → indices [0]="guards" [2]="config", len ≥ 4
if len(key) >= 4 && key[0] == "guards" && key[2] == "config" {
return true
}
return false
}

// This function uses the BurntSushi/toml v1.6.0+ parser with TOML 1.1 support,
// which enables modern syntax features like newlines in inline tables and
// improved duplicate key detection.
//
// Error Handling:
// - Parse errors include both line AND column numbers (v1.5.0+ feature)
// - Unknown fields generate warnings instead of hard errors (typo detection)
// - Unknown fields are rejected with an error per spec §4.3.1
// - Metadata tracks all decoded keys for validation purposes
//
// Example usage with TOML 1.1 multi-line arrays:
Expand Down Expand Up @@ -280,22 +302,26 @@ func LoadFromFile(path string) (*Config, error) {

logConfig.Printf("Parsed TOML config with %d servers", len(cfg.Servers))

// Detect and warn about unknown configuration keys (typos, deprecated options)
// Detect and reject unknown configuration keys (typos, unrecognized fields).
// This uses MetaData.Undecoded() to identify keys present in TOML but not
// in the Config struct. This provides a balance between strict validation
// (hard errors) and user-friendliness (warnings allow config to load).
// in the Config struct. Per spec §4.3.1, the gateway MUST reject configurations
// containing unrecognized fields with an informative error message.
//
// Design decision: We use warnings rather than toml.Decoder.DisallowUnknownFields()
// (which doesn't exist) or hard errors to maintain backward compatibility and
// allow gradual config migration. Common typos like "prot" → "port" are caught
// while still allowing the gateway to start.
// Note: map[string]interface{} fields (guard_policies, guards.*.config) are
// intentionally flexible and their nested keys are exempt from this check.
undecoded := md.Undecoded()
if len(undecoded) > 0 {
for _, key := range undecoded {
// Log to both debug logger and file logger for visibility
logConfig.Printf("WARNING: Unknown configuration key '%s' - check for typos or deprecated options", key)
logger.LogWarn("config", "Unknown configuration key '%s' - check for typos or deprecated options", key)
var unknownKeys []toml.Key
for _, key := range undecoded {
if !isDynamicTOMLPath(key) {
unknownKeys = append(unknownKeys, key)
}
}
if len(unknownKeys) > 0 {
keyStrs := make([]string, len(unknownKeys))
for i, k := range unknownKeys {
keyStrs[i] = k.String()
}
return nil, fmt.Errorf("configuration contains unrecognized field(s): %s — check the MCP Gateway Specification for supported fields", strings.Join(keyStrs, ", "))
}

// Validate required fields
Expand Down
11 changes: 5 additions & 6 deletions internal/config/config_core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ GITHUB_TOKEN = "mytoken"
}

// TestLoadFromFile_UnknownKeysDoNotCauseError verifies that unknown configuration
// keys produce a warning log but do not prevent the config from loading.
// keys are rejected with an error per spec §4.3.1.
func TestLoadFromFile_UnknownKeysDoNotCauseError(t *testing.T) {
path := writeTempTOML(t, `
[gateway]
Expand All @@ -182,12 +182,11 @@ prot = 3000
command = "docker"
args = ["run", "--rm", "-i", "ghcr.io/github/github-mcp-server:latest"]
`)
// Unknown key "prot" (typo for "port") should warn but not error
// Unknown key "prot" (typo for "port") must now return an error per spec §4.3.1
cfg, err := LoadFromFile(path)
require.NoError(t, err)
require.NotNil(t, cfg)
// Port should use default since "prot" was not recognized
assert.Equal(t, DefaultPort, cfg.Gateway.Port)
require.Error(t, err)
assert.Nil(t, cfg)
assert.Contains(t, err.Error(), "unrecognized field")
Comment on lines 174 to +189
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test name TestLoadFromFile_UnknownKeysDoNotCauseError no longer matches the behavior/assertions (it now expects LoadFromFile to return an error on unknown keys). Renaming the test to reflect the new contract (e.g., TestLoadFromFile_UnknownKeysCauseError) would avoid confusion and keep intent clear.

Copilot uses AI. Check for mistakes.
}

// TestLoadFromFile_TrustedBotsEmptyArray verifies that an explicitly set but
Expand Down
45 changes: 15 additions & 30 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1036,16 +1036,17 @@ func TestLoadFromFile_UnknownKeys(t *testing.T) {
[servers.test]
command = "docker"
args = ["run", "--rm", "-i", "test/container:latest"]
unknown_field = "should trigger warning"
unknown_field = "should trigger error"
`

err := os.WriteFile(tmpFile, []byte(tomlContent), 0644)
require.NoError(t, err, "Failed to write temp TOML file")

// Should still load successfully but log warning
// Must now return an error per spec §4.3.1: unknown fields MUST be rejected
cfg, err := LoadFromFile(tmpFile)
require.NoError(t, err, "LoadFromFile() should succeed with unknown keys")
require.NotNil(t, cfg, "Config should not be nil")
require.Error(t, err, "LoadFromFile() should fail with unknown keys")
assert.Nil(t, cfg, "Config should be nil on error")
assert.Contains(t, err.Error(), "unrecognized field", "Error should mention unrecognized field")
}

func TestLoadFromFile_NonExistentFile(t *testing.T) {
Expand Down Expand Up @@ -1092,7 +1093,7 @@ port 3000
"Error should mention column or line position, got: %s", errMsg)
}

// TestLoadFromFile_UnknownKeysInGateway tests detection of unknown keys in gateway section
// TestLoadFromFile_UnknownKeysInGateway tests that unknown keys in gateway section are rejected
func TestLoadFromFile_UnknownKeysInGateway(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "config.toml")
Expand All @@ -1111,21 +1112,14 @@ args = ["run", "--rm", "-i", "test/container:latest"]
err := os.WriteFile(tmpFile, []byte(tomlContent), 0644)
require.NoError(t, err, "Failed to write temp TOML file")

// Enable debug logging to capture warning about unknown key
SetDebug(true)
defer SetDebug(false)

// Should still load successfully, but warning will be logged
// Must return an error per spec §4.3.1: unknown fields MUST be rejected
cfg, err := LoadFromFile(tmpFile)
require.NoError(t, err, "LoadFromFile() should succeed even with unknown keys")
require.NotNil(t, cfg, "Config should not be nil")

// Port should be default since "prot" was not recognized
assert.Equal(t, DefaultPort, cfg.Gateway.Port, "Port should be default since 'prot' is unknown")
assert.Equal(t, "test-key", cfg.Gateway.APIKey, "API key should be set correctly")
require.Error(t, err, "LoadFromFile() should fail with unknown keys")
assert.Nil(t, cfg, "Config should be nil on error")
assert.Contains(t, err.Error(), "unrecognized field", "Error should mention unrecognized field")
}

// TestLoadFromFile_MultipleUnknownKeys tests detection of multiple typos
// TestLoadFromFile_MultipleUnknownKeys tests that multiple unknown keys are rejected
func TestLoadFromFile_MultipleUnknownKeys(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "config.toml")
Expand All @@ -1146,20 +1140,11 @@ typ = "stdio"
err := os.WriteFile(tmpFile, []byte(tomlContent), 0644)
require.NoError(t, err, "Failed to write temp TOML file")

// Enable debug logging to capture warnings
SetDebug(true)
defer SetDebug(false)

// Should still load successfully
// Must return an error per spec §4.3.1: unknown fields MUST be rejected
cfg, err := LoadFromFile(tmpFile)
require.NoError(t, err, "LoadFromFile() should succeed even with multiple unknown keys")
require.NotNil(t, cfg, "Config should not be nil")

// Correctly spelled fields should work
assert.Equal(t, 8080, cfg.Gateway.Port, "Port should be set correctly")
// Misspelled fields should use defaults
assert.Equal(t, DefaultStartupTimeout, cfg.Gateway.StartupTimeout, "StartupTimeout should be default")
assert.Equal(t, DefaultToolTimeout, cfg.Gateway.ToolTimeout, "ToolTimeout should be default")
require.Error(t, err, "LoadFromFile() should fail with multiple unknown keys")
assert.Nil(t, cfg, "Config should be nil on error")
assert.Contains(t, err.Error(), "unrecognized field", "Error should mention unrecognized field")
}

// TestLoadFromFile_StreamingLargeFile tests that streaming decoder works efficiently
Expand Down
20 changes: 20 additions & 0 deletions internal/server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ func logRuntimeError(errorType, detail string, r *http.Request, serverName *stri
timestamp, server, requestID, errorType, detail, r.URL.Path, r.Method)
}

// isMalformedAuthHeader returns true if the header value contains characters
// that are not valid in HTTP header values per RFC 7230: null bytes, control
// characters below 0x20 (except horizontal tab 0x09), or DEL (0x7F).
// Per spec 7.2 item 3, such headers must be rejected with HTTP 400.
func isMalformedAuthHeader(header string) bool {
for _, c := range header {
if c == 0x00 || (c < 0x20 && c != 0x09) || c == 0x7F {
return true
}
}
return false
}

// authMiddleware implements API key authentication per spec section 7.1
// Per spec: Authorization header MUST contain the API key directly (NOT Bearer scheme)
//
Expand All @@ -50,6 +63,13 @@ func authMiddleware(apiKey string, next http.HandlerFunc) http.HandlerFunc {
return
}

// Spec 7.2 item 3: Malformed Authorization headers (null bytes, non-printable
// control characters) must return 400 Bad Request, not 401.
if isMalformedAuthHeader(authHeader) {
rejectRequest(w, r, http.StatusBadRequest, "bad_request", "malformed Authorization header", "auth", "authentication_failed", "malformed_auth_header")
return
}

// Spec 7.1: Authorization header must contain API key directly (not Bearer scheme)
if authHeader != apiKey {
rejectRequest(w, r, http.StatusUnauthorized, "unauthorized", "invalid API key", "auth", "authentication_failed", "invalid_api_key")
Expand Down
67 changes: 67 additions & 0 deletions internal/server/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,46 @@ func TestAuthMiddleware(t *testing.T) {
expectNextCalled: true,
expectErrorMessage: "",
},
{
name: "MalformedHeaderNullByte",
configuredAPIKey: "valid-key",
authHeader: "valid-key\x00extra",
expectStatusCode: http.StatusBadRequest,
expectNextCalled: false,
expectErrorMessage: "malformed Authorization header",
},
{
name: "MalformedHeaderControlChar",
configuredAPIKey: "valid-key",
authHeader: "valid-key\x01extra",
expectStatusCode: http.StatusBadRequest,
expectNextCalled: false,
expectErrorMessage: "malformed Authorization header",
},
{
name: "MalformedHeaderDEL",
configuredAPIKey: "valid-key",
authHeader: "valid-key\x7F",
expectStatusCode: http.StatusBadRequest,
expectNextCalled: false,
expectErrorMessage: "malformed Authorization header",
},
{
name: "MalformedHeaderNewline",
configuredAPIKey: "valid-key",
authHeader: "valid-key\nextra",
expectStatusCode: http.StatusBadRequest,
expectNextCalled: false,
expectErrorMessage: "malformed Authorization header",
},
{
name: "TabAllowedInHeader",
configuredAPIKey: "valid\tkey",
authHeader: "valid\tkey",
expectStatusCode: http.StatusOK,
expectNextCalled: true,
expectErrorMessage: "",
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -313,3 +353,30 @@ func TestLogRuntimeError(t *testing.T) {
func stringPtr(s string) *string {
return &s
}

// TestIsMalformedAuthHeader tests the isMalformedAuthHeader helper.
func TestIsMalformedAuthHeader(t *testing.T) {
tests := []struct {
name string
header string
malformed bool
}{
{name: "EmptyString", header: "", malformed: false},
{name: "NormalKey", header: "my-api-key", malformed: false},
{name: "SpecialPrintableChars", header: "key!@#$%^&*()", malformed: false},
{name: "HorizontalTab", header: "key\tvalue", malformed: false},
{name: "NullByte", header: "key\x00value", malformed: true},
{name: "ControlCharSOH", header: "\x01key", malformed: true},
{name: "ControlCharLF", header: "key\nvalue", malformed: true},
{name: "ControlCharCR", header: "key\rvalue", malformed: true},
{name: "DELChar", header: "key\x7Fvalue", malformed: true},
{name: "ControlCharUS", header: "key\x1Fvalue", malformed: true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isMalformedAuthHeader(tt.header)
assert.Equal(t, tt.malformed, got, "isMalformedAuthHeader(%q) should return %v", tt.header, tt.malformed)
})
}
}
Loading
Loading