Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions internal/config/config_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,18 @@ import (
"io"
"log"
"os"
"time"

"github.com/BurntSushi/toml"
"github.com/github/gh-aw-mcpg/internal/logger"
)

// Core constants for configuration defaults
const (
DefaultPort = 3000
DefaultStartupTimeout = 60 // seconds
DefaultToolTimeout = 120 // seconds
DefaultPort = 3000
DefaultStartupTimeout = 60 // seconds
DefaultToolTimeout = 120 // seconds
DefaultKeepaliveInterval = 1500 // seconds (25 minutes) — keeps HTTP backend sessions alive
)

// Config represents the internal gateway configuration.
Expand Down Expand Up @@ -87,6 +89,13 @@ type GatewayConfig struct {
// ToolTimeout is the maximum time (seconds) to wait for tool execution
ToolTimeout int `toml:"tool_timeout" json:"tool_timeout,omitempty"`

// KeepaliveInterval is the interval (seconds) for sending keepalive pings to HTTP
// backends. This prevents long-running sessions from being expired by the remote
// server's idle timeout (typically 30 minutes). Set to -1 to disable keepalive
// pings entirely (useful when higher-level timeouts manage session lifecycle).
// Default: 1500 (25 minutes)
KeepaliveInterval int `toml:"keepalive_interval" json:"keepalive_interval,omitempty"`

// PayloadDir is the directory for storing large payloads
PayloadDir string `toml:"payload_dir" json:"payload_dir,omitempty"`

Expand All @@ -110,6 +119,18 @@ type GatewayConfig struct {
TrustedBots []string `toml:"trusted_bots" json:"trusted_bots,omitempty"`
}

// HTTPKeepaliveInterval returns the keepalive interval as a time.Duration.
// A negative KeepaliveInterval disables keepalive (returns 0).
func (g *GatewayConfig) HTTPKeepaliveInterval() time.Duration {
if g == nil {
return time.Duration(DefaultKeepaliveInterval) * time.Second
}
if g.KeepaliveInterval < 0 {
return 0
}
Comment on lines 121 to +130
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KeepaliveInterval is added to GatewayConfig, but it isn’t wired into the stdin JSON config path: StdinGatewayConfig/the embedded JSON schema don’t include a keepalive field, and convertStdinConfig doesn’t call applyGatewayDefaults when stdinCfg.Gateway is present. As a result, for --config-stdin runs (which are schema-validated), keepalive will remain 0 and be effectively disabled, undermining the main fix. Please extend the stdin schema + StdinGatewayConfig to accept a keepalive value (likely keepaliveInterval in camelCase), propagate it into GatewayConfig, and ensure defaults are applied for the stdin conversion path too.

Suggested change
// HTTPKeepaliveInterval returns the keepalive interval as a time.Duration.
// A negative KeepaliveInterval disables keepalive (returns 0).
func (g *GatewayConfig) HTTPKeepaliveInterval() time.Duration {
if g.KeepaliveInterval < 0 {
return 0
}
// defaultHTTPKeepaliveIntervalSeconds is used when KeepaliveInterval is left unset.
// Negative values still explicitly disable keepalive.
const defaultHTTPKeepaliveIntervalSeconds = 30
// HTTPKeepaliveInterval returns the keepalive interval as a time.Duration.
// A negative KeepaliveInterval disables keepalive (returns 0).
// A zero KeepaliveInterval is treated as "unset" and falls back to the default.
func (g *GatewayConfig) HTTPKeepaliveInterval() time.Duration {
if g == nil {
return defaultHTTPKeepaliveIntervalSeconds * time.Second
}
if g.KeepaliveInterval < 0 {
return 0
}
if g.KeepaliveInterval == 0 {
return defaultHTTPKeepaliveIntervalSeconds * time.Second
}

Copilot uses AI. Check for mistakes.
return time.Duration(g.KeepaliveInterval) * time.Second
}

// GetAPIKey returns the gateway API key, handling a nil Gateway safely.
func (c *Config) GetAPIKey() string {
if c.Gateway == nil {
Expand Down Expand Up @@ -196,6 +217,9 @@ func applyGatewayDefaults(cfg *GatewayConfig) {
if cfg.ToolTimeout == 0 {
cfg.ToolTimeout = DefaultToolTimeout
}
if cfg.KeepaliveInterval == 0 {
cfg.KeepaliveInterval = DefaultKeepaliveInterval
}
}

// EnsureGatewayDefaults guarantees that cfg.Gateway is non-nil and that all
Expand Down
26 changes: 14 additions & 12 deletions internal/config/config_stdin.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ type StdinConfig struct {
// StdinGatewayConfig represents gateway configuration in stdin JSON format.
// Uses pointers for optional fields to distinguish between unset and zero values.
type StdinGatewayConfig struct {
Port *int `json:"port,omitempty"`
APIKey string `json:"apiKey,omitempty"`
Domain string `json:"domain,omitempty"`
StartupTimeout *int `json:"startupTimeout,omitempty"`
ToolTimeout *int `json:"toolTimeout,omitempty"`
PayloadDir string `json:"payloadDir,omitempty"`
TrustedBots []string `json:"trustedBots,omitempty"`
Port *int `json:"port,omitempty"`
APIKey string `json:"apiKey,omitempty"`
Domain string `json:"domain,omitempty"`
StartupTimeout *int `json:"startupTimeout,omitempty"`
ToolTimeout *int `json:"toolTimeout,omitempty"`
KeepaliveInterval *int `json:"keepaliveInterval,omitempty"`
PayloadDir string `json:"payloadDir,omitempty"`
TrustedBots []string `json:"trustedBots,omitempty"`
}

// StdinGuardConfig represents a guard configuration in stdin JSON format.
Expand Down Expand Up @@ -278,11 +279,12 @@ func convertStdinConfig(stdinCfg *StdinConfig) (*Config, error) {
// Convert gateway config with defaults
if stdinCfg.Gateway != nil {
cfg.Gateway = &GatewayConfig{
Port: intPtrOrDefault(stdinCfg.Gateway.Port, DefaultPort),
APIKey: stdinCfg.Gateway.APIKey,
Domain: stdinCfg.Gateway.Domain,
StartupTimeout: intPtrOrDefault(stdinCfg.Gateway.StartupTimeout, DefaultStartupTimeout),
ToolTimeout: intPtrOrDefault(stdinCfg.Gateway.ToolTimeout, DefaultToolTimeout),
Port: intPtrOrDefault(stdinCfg.Gateway.Port, DefaultPort),
APIKey: stdinCfg.Gateway.APIKey,
Domain: stdinCfg.Gateway.Domain,
StartupTimeout: intPtrOrDefault(stdinCfg.Gateway.StartupTimeout, DefaultStartupTimeout),
ToolTimeout: intPtrOrDefault(stdinCfg.Gateway.ToolTimeout, DefaultToolTimeout),
KeepaliveInterval: intPtrOrDefault(stdinCfg.Gateway.KeepaliveInterval, DefaultKeepaliveInterval),
}
if stdinCfg.Gateway.PayloadDir != "" {
cfg.Gateway.PayloadDir = stdinCfg.Gateway.PayloadDir
Expand Down
16 changes: 12 additions & 4 deletions internal/launcher/connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ const (

// Default configuration values
const (
DefaultIdleTimeout = 30 * time.Minute
// DefaultIdleTimeout is the maximum duration a STDIO-backend connection can remain unused
// before being removed from the session pool. Set to 6 hours to accommodate long-running
// workflow tasks (e.g. ML training, large builds) that may not make MCP calls for extended
// periods. Note: this is distinct from HTTP backend keepalive (config.DefaultKeepaliveInterval)
// which keeps the remote session alive on the HTTP server side; STDIO connections run as local
// child processes whose sessions are bounded only by this pool eviction window.
DefaultIdleTimeout = 6 * time.Hour
DefaultCleanupInterval = 5 * time.Minute
Comment on lines +42 to 49
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raising DefaultIdleTimeout to 6h increases the time STDIO backend processes can remain resident when idle. In this pool’s cleanup path, removed entries are not actually closed (there is no call to Connection.Close()), so longer timeouts can prolong leaked resources if callers forget to close or if pool eviction is relied on for cleanup. Consider closing the underlying mcp.Connection when evicting (and/or confirming callers always close) so the longer timeout doesn’t translate into longer-lived orphaned processes.

Copilot uses AI. Check for mistakes.
DefaultMaxErrorCount = 10
)
Expand Down Expand Up @@ -152,10 +158,12 @@ func (p *SessionConnectionPool) cleanupIdleConnections() {
logPool.Printf("Cleaning up connection: backend=%s, session=%s, reason=%s, idle=%v, errors=%d",
key.BackendID, key.SessionID, reason, now.Sub(metadata.LastUsedAt), metadata.ErrorCount)

// Close the connection if still active
// Close the underlying connection to release resources (cancel context, close SDK session)
if metadata.Connection != nil && metadata.State != ConnectionStateClosed {
// Note: mcp.Connection doesn't have a Close method in current implementation
// but we mark it as closed
if err := metadata.Connection.Close(); err != nil {
logPool.Printf("Error closing connection during cleanup: backend=%s, session=%s, err=%v",
key.BackendID, key.SessionID, err)
}
metadata.State = ConnectionStateClosed
}

Expand Down
2 changes: 1 addition & 1 deletion internal/launcher/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func GetOrLaunch(l *Launcher, serverID string) (*mcp.Connection, error) {
}

// Create an HTTP connection
conn, err := mcp.NewHTTPConnection(l.ctx, serverID, serverCfg.URL, serverCfg.Headers, oidcProvider, oidcAudience)
conn, err := mcp.NewHTTPConnection(l.ctx, serverID, serverCfg.URL, serverCfg.Headers, oidcProvider, oidcAudience, l.config.Gateway.HTTPKeepaliveInterval())
if err != nil {
logger.LogErrorWithServer(serverID, "backend", "Failed to create HTTP connection: %s, error=%v", serverID, err)
log.Printf("[LAUNCHER] ❌ FAILED to create HTTP connection for '%s'", serverID)
Expand Down
18 changes: 11 additions & 7 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type Connection struct {
httpClient *http.Client
httpSessionID string // Session ID returned by the HTTP backend
httpTransportType HTTPTransportType // Type of HTTP transport in use
keepAliveInterval time.Duration // Keepalive interval for SDK transports (0 = disabled)
// sessionMu protects the mutable session fields: httpSessionID, session, and client.
// Always use getHTTPSessionID() or getSDKSession() to read these fields; the
// reconnect functions (reconnectPlainJSON, reconnectSDKTransport) hold the full Lock.
Expand Down Expand Up @@ -98,8 +99,8 @@ func NewConnection(ctx context.Context, serverID, command string, args []string,
logger.LogInfo("backend", "Creating new MCP backend connection, command=%s, args=%v", command, sanitize.SanitizeArgs(args))
ctx, cancel := context.WithCancel(ctx)

// Create MCP client with logger
client := newMCPClient(logConn)
// Create MCP client with logger (no keepalive for stdio – the process lifespan manages the session)
client := newMCPClient(logConn, 0)

// Expand Docker -e flags that reference environment variables
// Docker's `-e VAR_NAME` expects VAR_NAME to be in the environment
Expand Down Expand Up @@ -196,7 +197,7 @@ func NewConnection(ctx context.Context, serverID, command string, args []string,
// Authorization header from the headers map.
//
// This ensures compatibility with all types of HTTP MCP servers.
func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[string]string, oidcProvider *oidc.Provider, oidcAudience string) (*Connection, error) {
func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[string]string, oidcProvider *oidc.Provider, oidcAudience string, keepAlive time.Duration) (*Connection, error) {
logger.LogInfo("backend", "Creating HTTP MCP connection with transport fallback, url=%s", url)
ctx, cancel := context.WithCancel(ctx)

Expand Down Expand Up @@ -234,7 +235,7 @@ func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[st

// Try 1: Streamable HTTP (2025-03-26 spec)
logConn.Printf("Attempting streamable HTTP transport for %s", url)
conn, err := tryStreamableHTTPTransport(ctx, cancel, serverID, url, headers, headerClient)
conn, err := tryStreamableHTTPTransport(ctx, cancel, serverID, url, headers, headerClient, keepAlive)
if err == nil {
logger.LogInfo("backend", "Successfully connected using streamable HTTP transport, url=%s", url)
log.Printf("Configured HTTP MCP server with streamable transport: %s", url)
Expand All @@ -244,7 +245,7 @@ func NewHTTPConnection(ctx context.Context, serverID, url string, headers map[st

// Try 2: SSE (2024-11-05 spec)
logConn.Printf("Attempting SSE transport for %s", url)
conn, err = trySSETransport(ctx, cancel, serverID, url, headers, headerClient)
conn, err = trySSETransport(ctx, cancel, serverID, url, headers, headerClient, keepAlive)
if err == nil {
logger.LogWarn("backend", "⚠️ MCP over SSE has been deprecated. Connected using SSE transport for url=%s. Please migrate to streamable HTTP transport (2025-03-26 spec).", url)
log.Printf("⚠️ WARNING: MCP over SSE (2024-11-05 spec) has been DEPRECATED")
Expand Down Expand Up @@ -326,7 +327,8 @@ func (c *Connection) reconnectSDKTransport() error {
headerClient := buildHTTPClientWithHeaders(c.httpClient, c.headers)

// Build the appropriate transport.
client := newMCPClient(logConn)
// Re-use the same keepAliveInterval so the reconnected session also sends periodic pings.
client := newMCPClient(logConn, c.keepAliveInterval)
var transport sdk.Transport
switch c.httpTransportType {
case HTTPTransportStreamable:
Expand Down Expand Up @@ -668,7 +670,9 @@ func (c *Connection) getPrompt(params interface{}) (*Response, error) {
// Close closes the connection
func (c *Connection) Close() error {
logConn.Printf("Closing connection: serverID=%s, isHTTP=%v", c.serverID, c.isHTTP)
c.cancel()
if c.cancel != nil {
c.cancel()
}
if session := c.getSDKSession(); session != nil {
return session.Close()
}
Expand Down
4 changes: 2 additions & 2 deletions internal/mcp/connection_arguments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestCallTool_ArgumentsPassed(t *testing.T) {
// Create connection
conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{
"Authorization": "test-token",
}, nil, "")
}, nil, "", 0)
require.NoError(t, err, "Failed to create HTTP connection")
defer conn.Close()

Expand Down Expand Up @@ -224,7 +224,7 @@ func TestCallTool_MissingArguments(t *testing.T) {

conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{
"Authorization": "test-token",
}, nil, "")
}, nil, "", 0)
require.NoError(t, err)
defer conn.Close()

Expand Down
2 changes: 1 addition & 1 deletion internal/mcp/connection_stderr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestConnection_SendRequest(t *testing.T) {

conn, err := NewHTTPConnection(context.Background(), "test-server", srv.URL, map[string]string{
"Authorization": "test-token",
}, nil, "")
}, nil, "", 0)
require.NoError(t, err)
defer conn.Close()

Expand Down
58 changes: 48 additions & 10 deletions internal/mcp/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"strings"
"testing"
"time"

"github.com/github/gh-aw-mcpg/internal/config"
"github.com/github/gh-aw-mcpg/internal/difc"
Expand Down Expand Up @@ -41,7 +42,7 @@ func TestHTTPRequest_SessionIDHeader(t *testing.T) {
// Create an HTTP connection
conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{
"Authorization": "test-auth-token",
}, nil, "")
}, nil, "", 0)
require.NoError(t, err, "Failed to create HTTP connection")

// Create a context with session ID
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestHTTPRequest_NoSessionID(t *testing.T) {
// Create an HTTP connection
conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{
"Authorization": "test-auth-token",
}, nil, "")
}, nil, "", 0)
require.NoError(t, err, "Failed to create HTTP connection")

// Send a request without session ID in context
Expand Down Expand Up @@ -116,7 +117,7 @@ func TestHTTPRequest_ConfiguredHeaders(t *testing.T) {
authToken := "configured-auth-token"
conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{
"Authorization": authToken,
}, nil, "")
}, nil, "", 0)
require.NoError(t, err, "Failed to create HTTP connection")

// Create a context with session ID
Expand Down Expand Up @@ -375,7 +376,7 @@ func TestHTTPRequest_ErrorResponses(t *testing.T) {
// Create connection with custom headers to use plain JSON transport
conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, map[string]string{
"Authorization": "test-token",
}, nil, "")
}, nil, "", 0)
if err != nil && tt.expectError {
// Error during initialization is expected for some error conditions
if tt.errorSubstring != "" && !containsSubstring(err.Error(), tt.errorSubstring) {
Expand Down Expand Up @@ -427,7 +428,7 @@ func TestConnection_IsHTTP(t *testing.T) {
"X-Custom": "custom-value",
}

conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, headers, nil, "")
conn, err := NewHTTPConnection(context.Background(), "test-server", testServer.URL, headers, nil, "", 0)
require.NoError(t, err, "Failed to create HTTP connection")
defer conn.Close()

Expand Down Expand Up @@ -474,7 +475,7 @@ func TestHTTPConnection_InvalidURL(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewHTTPConnection(context.Background(), "test-server", tt.url, tt.headers, nil, "")
_, err := NewHTTPConnection(context.Background(), "test-server", tt.url, tt.headers, nil, "", 0)

if tt.expectError {
if err == nil {
Expand Down Expand Up @@ -507,17 +508,54 @@ func stringContains(s, substr string) bool {

// TestNewMCPClient tests the newMCPClient helper function
func TestNewMCPClient(t *testing.T) {
client := newMCPClient(nil)
client := newMCPClient(nil, 0)
require.NotNil(t, client, "newMCPClient should return a non-nil client")
}

// TestNewMCPClientWithLogger tests that newMCPClient accepts a logger
func TestNewMCPClientWithLogger(t *testing.T) {
log := logger.New("test:client")
client := newMCPClient(log)
client := newMCPClient(log, 0)
require.NotNil(t, client, "newMCPClient should return a non-nil client with logger")
}

// TestNewMCPClientWithKeepalive tests that newMCPClient accepts a keepalive interval
func TestNewMCPClientWithKeepalive(t *testing.T) {
keepAlive := time.Duration(config.DefaultKeepaliveInterval) * time.Second
client := newMCPClient(nil, keepAlive)
require.NotNil(t, client, "newMCPClient should return a non-nil client with keepalive")
}

// TestDefaultKeepaliveInterval verifies the config keepalive default is less than a typical
// backend session timeout (30 minutes) to prevent session expiry during long agent runs.
func TestDefaultKeepaliveInterval(t *testing.T) {
const typicalBackendTimeout = 30 * time.Minute
keepAlive := time.Duration(config.DefaultKeepaliveInterval) * time.Second
assert.Less(t, keepAlive, typicalBackendTimeout,
"DefaultKeepaliveInterval must be less than the typical backend session timeout to prevent expiry")
assert.Greater(t, keepAlive, time.Duration(0),
"DefaultKeepaliveInterval must be positive")
}

// TestNewHTTPConnectionStoresKeepalive verifies that the keepalive interval is stored on
// the connection struct so that reconnectSDKTransport can recreate the session with the same setting.
func TestNewHTTPConnectionStoresKeepalive(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

keepAlive := time.Duration(config.DefaultKeepaliveInterval) * time.Second
client := newMCPClient(nil, keepAlive)
url := "http://example.com/mcp"
headers := map[string]string{}
httpClient := &http.Client{}

conn := newHTTPConnection(ctx, cancel, client, nil, url, headers, httpClient, HTTPTransportStreamable, "test-server", keepAlive)

require.NotNil(t, conn)
assert.Equal(t, keepAlive, conn.keepAliveInterval,
"keepAliveInterval should be stored on the connection for use during reconnection")
}

// TestSetupHTTPRequest tests the setupHTTPRequest helper function
func TestSetupHTTPRequest(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -593,12 +631,12 @@ func TestNewHTTPConnection(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

client := newMCPClient(nil)
client := newMCPClient(nil, 0)
url := "http://example.com/mcp"
headers := map[string]string{"Authorization": "test"}
httpClient := &http.Client{}

conn := newHTTPConnection(ctx, cancel, client, nil, url, headers, httpClient, HTTPTransportStreamable, "test-server")
conn := newHTTPConnection(ctx, cancel, client, nil, url, headers, httpClient, HTTPTransportStreamable, "test-server", 0)

require.NotNil(t, conn, "Connection should not be nil")
assert.Equal(t, client, conn.client, "Client should match")
Expand Down
Loading
Loading