diff --git a/router-tests/lifecycle/mcp_hot_reload_shutdown_test.go b/router-tests/lifecycle/mcp_hot_reload_shutdown_test.go new file mode 100644 index 0000000000..0ff3567c89 --- /dev/null +++ b/router-tests/lifecycle/mcp_hot_reload_shutdown_test.go @@ -0,0 +1,92 @@ +package integration + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "go.uber.org/goleak" +) + +func TestShutDownMCPGoRoutineLeaks(t *testing.T) { + + defer goleak.VerifyNone(t, + goleak.IgnoreTopFunction("github.com/hashicorp/consul/sdk/freeport.checkFreedPorts"), // Freeport, spawned by init + goleak.IgnoreAnyFunction("net/http.(*conn).serve"), // HTTPTest server I can't close if I want to keep the problematic goroutine open for the test + ) + + operationsDir := t.TempDir() + storageProviderID := "mcp_hot_reload_test_id" + + xEnv, err := testenv.CreateTestEnv(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + Storage: config.MCPStorageConfig{ + ProviderID: storageProviderID, + }, + HotReloadConfig: config.MCPOperationsHotReloadConfig{ + Enabled: true, + Interval: 1 * time.Second, + }, + }, + RouterOptions: []core.Option{ + core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: storageProviderID, + Path: operationsDir, + }, + }, + }), + }, + }) + + require.NoError(t, err) + + mcpOperationFile := filepath.Join(operationsDir, "main.graphql") + // write mcp operation content + require.NoError(t, os.WriteFile(mcpOperationFile, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0o600)) + + // Verify GoRoutines are properly setup for Hot Reloading + require.EventuallyWithT(t, func(t *assert.CollectT) { + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 10*time.Second, 100*time.Millisecond) + + xEnv.Shutdown() + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + if assert.Error(t, err) { + require.ErrorIs(t, err, testenv.ErrEnvironmentClosed) + } + require.Nil(t, resp) + +} diff --git a/router-tests/mcp_hot_reload_test.go b/router-tests/mcp_hot_reload_test.go new file mode 100644 index 0000000000..92b6f5ef6f --- /dev/null +++ b/router-tests/mcp_hot_reload_test.go @@ -0,0 +1,201 @@ +package integration + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestMCPOperationHotReload(t *testing.T) { + t.Parallel() + operationsDir := t.TempDir() + storageProviderID := "mcp_hot_reload_test_id" + + t.Run("List Updated User Operations On Addition and Removal", func(t *testing.T) { + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + Storage: config.MCPStorageConfig{ + ProviderID: storageProviderID, + }, + HotReloadConfig: config.MCPOperationsHotReloadConfig{ + Enabled: true, + Interval: 1 * time.Second, + }, + }, + RouterOptions: []core.Option{ + core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: storageProviderID, + Path: operationsDir, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + require.NoError(t, err) + + initialToolsCount := len(resp.Tools) + + mcpOperationFile := filepath.Join(operationsDir, "main.graphql") + + // write mcp operation content + err = os.WriteFile(mcpOperationFile, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) + assert.NoError(t, err) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + assert.Len(t, resp.Tools, initialToolsCount+1) + + // verity getEmployeeNotes operation is present + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 10*time.Second, 100*time.Millisecond) + + err = os.Remove(mcpOperationFile) + assert.NoError(t, err) + + assert.EventuallyWithT(t, func(t *assert.CollectT) { + + resp, err = xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + assert.Len(t, resp.Tools, initialToolsCount) + + // verity getEmployeeNotes operation tool is properly removed + require.NotContains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + + }, 10*time.Second, 100*time.Millisecond) + + }) + }) + + t.Run("List Updated User Operations On Content Update", func(t *testing.T) { + + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + Storage: config.MCPStorageConfig{ + ProviderID: storageProviderID, + }, + HotReloadConfig: config.MCPOperationsHotReloadConfig{ + Enabled: true, + Interval: 1 * time.Second, + }, + }, + RouterOptions: []core.Option{ + core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: storageProviderID, + Path: operationsDir, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + mcpOperationFile := filepath.Join(operationsDir, "main.graphql") + + // write mcp operation content + require.NoError(t, os.WriteFile(mcpOperationFile, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0o600)) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + + // verity getEmployeeNotes operation is present + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 10*time.Second, 100*time.Millisecond) + + // update mcp operation content + require.NoError(t, os.WriteFile(mcpOperationFile, []byte("\nquery getEmployeeNotesUpdatedTitle($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0o600)) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + + // verity getEmployeeNotesUpdatedTitle operation is present + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes_updated_title", + Description: "Executes the GraphQL operation 'getEmployeeNotesUpdatedTitle' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotesUpdatedTitle", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 10*time.Second, 100*time.Millisecond) + }) + }) +} diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 9a3d953ac2..8b8ad1dcb6 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -1434,17 +1434,19 @@ func configureRouter(listenerAddr string, testConfig *Config, routerConfig *node } if testConfig.MCP.Enabled { - // Add Storage provider - routerOpts = append(routerOpts, core.WithStorageProviders(config.StorageProviders{ - FileSystem: []config.FileSystemStorageProvider{ - { - ID: "test", - Path: "testdata/mcp_operations", + if testConfig.MCP.Storage.ProviderID == "" { + // Add Storage provider + routerOpts = append(routerOpts, core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: "test", + Path: "testdata/mcp_operations", + }, }, - }, - })) + })) - testConfig.MCP.Storage.ProviderID = "test" + testConfig.MCP.Storage.ProviderID = "test" + } routerOpts = append(routerOpts, core.WithMCP(testConfig.MCP)) } diff --git a/router/core/graph_server.go b/router/core/graph_server.go index d112c75552..90a78ea49b 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1210,9 +1210,22 @@ func (s *graphServer) buildGraphMux( // We support the MCP only on the base graph. Feature flags are not supported yet. if opts.IsBaseGraph() && s.mcpServer != nil { - if mErr := s.mcpServer.Reload(executor.ClientSchema); mErr != nil { + if mErr := s.mcpServer.Reload(ctx, executor.ClientSchema); mErr != nil { return nil, fmt.Errorf("failed to reload MCP server: %w", mErr) } + go func() { + for { + select { + case <-s.mcpServer.ReloadOperationsChannel(): + s.logger.Log(zap.InfoLevel, "Reloading mcp server!") + if mErr := s.mcpServer.Reload(ctx, executor.ClientSchema); mErr != nil { + return + } + case <-ctx.Done(): + return + } + } + }() } if s.Config.cacheWarmup != nil && s.Config.cacheWarmup.Enabled { diff --git a/router/core/router.go b/router/core/router.go index f7aa41de12..ac0a7d4eaa 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -861,6 +861,7 @@ func (r *Router) bootstrap(ctx context.Context) error { mcpserver.WithExcludeMutations(r.mcp.ExcludeMutations), mcpserver.WithEnableArbitraryOperations(r.mcp.EnableArbitraryOperations), mcpserver.WithExposeSchema(r.mcp.ExposeSchema), + mcpserver.WithHotReload(r.mcp.HotReloadConfig.Enabled, r.mcp.HotReloadConfig.Interval), } // Determine the router GraphQL endpoint diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 923fcac36f..59f1ff8e6d 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -925,14 +925,15 @@ type CacheWarmupConfiguration struct { } type MCPConfiguration struct { - Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` - Server MCPServer `yaml:"server,omitempty"` - Storage MCPStorageConfig `yaml:"storage,omitempty"` - GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` - ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` - EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` - ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` - RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` + Server MCPServer `yaml:"server,omitempty"` + Storage MCPStorageConfig `yaml:"storage,omitempty"` + GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` + ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` + EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` + ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` + RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + HotReloadConfig MCPOperationsHotReloadConfig `yaml:"hot_reload_config,omitempty" envPrefix:"MCP_OPERATIONS_HOT_RELOAD_"` } type MCPStorageConfig struct { @@ -944,6 +945,11 @@ type MCPServer struct { BaseURL string `yaml:"base_url,omitempty" env:"MCP_SERVER_BASE_URL"` } +type MCPOperationsHotReloadConfig struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + Interval time.Duration `yaml:"interval" envDefault:"10s" env:"INTERVAL"` +} + type PluginsConfiguration struct { Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` Path string `yaml:"path" envDefault:"plugins" env:"PATH"` diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 477be33e9a..f4eb0fd20b 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1920,6 +1920,26 @@ "type": "boolean", "default": false, "description": "Expose the full GraphQL schema through MCP. When enabled, AI models can request the complete schema of your API." + }, + "hot_reload_config": { + "type": "object", + "description": "Hot reloading configuration for MCP operations.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable hot reloading for MCP Operations.", + "default": false + }, + "interval": { + "type": "string", + "description": "The interval at which the MCP Operations directory is checked for changes. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", + "default": "10s", + "duration": { + "minimum": "5s" + } + } + } } } }, diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index c8726845b8..70544b33c4 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -49,6 +49,9 @@ mcp: base_url: 'http://localhost:5025' storage: provider_id: mcp + hot_reload_config: + enabled: false + interval: '10s' watch_config: enabled: true diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 0e11cad37c..6a378cb9ac 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -125,7 +125,11 @@ "ExcludeMutations": false, "EnableArbitraryOperations": false, "ExposeSchema": false, - "RouterURL": "" + "RouterURL": "", + "HotReloadConfig": { + "Enabled": false, + "Interval": 10000000000 + } }, "DemoMode": false, "Modules": null, diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 5be309fc5d..ae55c8026c 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -160,7 +160,11 @@ "ExcludeMutations": false, "EnableArbitraryOperations": false, "ExposeSchema": false, - "RouterURL": "https://cosmo-router.wundergraph.com" + "RouterURL": "https://cosmo-router.wundergraph.com", + "HotReloadConfig": { + "Enabled": false, + "Interval": 10000000000 + } }, "DemoMode": true, "Modules": { diff --git a/router/pkg/mcpserver/operation_manager.go b/router/pkg/mcpserver/operation_manager.go index 0bbe2e15d6..e055329f13 100644 --- a/router/pkg/mcpserver/operation_manager.go +++ b/router/pkg/mcpserver/operation_manager.go @@ -1,7 +1,9 @@ package mcpserver import ( + "context" "fmt" + "time" "github.com/wundergraph/cosmo/router/pkg/schemaloader" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" @@ -30,10 +32,10 @@ func NewOperationsManager(schemaDoc *ast.Document, logger *zap.Logger, excludeMu } // LoadOperationsFromDirectory loads operations from a specified directory -func (om *OperationsManager) LoadOperationsFromDirectory(operationsDir string) error { +func (om *OperationsManager) LoadOperationsFromDirectory(ctx context.Context, operationsDir string, reloadOperationsChan chan bool, hotReload bool, hotReloadInterval time.Duration) error { // Load operations loader := schemaloader.NewOperationLoader(om.logger, om.schemaDoc) - operations, err := loader.LoadOperationsFromDirectory(operationsDir) + operations, err := loader.LoadOperationsFromDirectory(ctx, operationsDir, reloadOperationsChan, hotReload, hotReloadInterval) if err != nil { return fmt.Errorf("failed to load operations: %w", err) } diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index 26b636b54e..8d671b096b 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -68,6 +68,10 @@ type Options struct { EnableArbitraryOperations bool // ExposeSchema determines whether the GraphQL schema is exposed ExposeSchema bool + // Enables hot reloading for MCP operations + HotReload bool + // The interval at which the MCP Operations directory is checked for changes + HotReloadInterval time.Duration } // GraphQLSchemaServer represents an MCP server that works with GraphQL schemas and operations @@ -88,6 +92,9 @@ type GraphQLSchemaServer struct { operationsManager *OperationsManager schemaCompiler *SchemaCompiler registeredTools []string + hotReload bool + hotReloadInterval time.Duration + reloadOperationsChan chan bool } type graphqlRequest struct { @@ -213,6 +220,9 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) enableArbitraryOperations: options.EnableArbitraryOperations, exposeSchema: options.ExposeSchema, baseURL: options.BaseURL, + hotReload: options.HotReload, + hotReloadInterval: options.HotReloadInterval, + reloadOperationsChan: make(chan bool), } return gs, nil @@ -273,6 +283,14 @@ func WithExposeSchema(exposeSchema bool) func(*Options) { } } +// WithHotReload sets the hot reload options +func WithHotReload(hotReload bool, hotReloadInterval time.Duration) func(*Options) { + return func(o *Options) { + o.HotReload = hotReload + o.HotReloadInterval = hotReloadInterval + } +} + // ServeSSE starts the server with SSE transport func (s *GraphQLSchemaServer) ServeSSE() (*server.SSEServer, error) { sseServer := server.NewSSEServer(s.server, @@ -321,7 +339,7 @@ func (s *GraphQLSchemaServer) Start() error { } // Reload reloads the operations and schema -func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error { +func (s *GraphQLSchemaServer) Reload(ctx context.Context, schema *ast.Document) error { if s.server == nil { return fmt.Errorf("server is not started") @@ -330,7 +348,7 @@ func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error { s.schemaCompiler = NewSchemaCompiler(s.logger) s.operationsManager = NewOperationsManager(schema, s.logger, s.excludeMutations) - if err := s.operationsManager.LoadOperationsFromDirectory(s.operationsDir); err != nil { + if err := s.operationsManager.LoadOperationsFromDirectory(ctx, s.operationsDir, s.reloadOperationsChan, s.hotReload, s.hotReloadInterval); err != nil { return fmt.Errorf("failed to load operations: %w", err) } @@ -723,3 +741,7 @@ func (s *GraphQLSchemaServer) handleGetGraphQLSchema() func(ctx context.Context, return mcp.NewToolResultText(schemaStr), nil } } + +func (s *GraphQLSchemaServer) ReloadOperationsChannel() chan bool { + return s.reloadOperationsChan +} diff --git a/router/pkg/schemaloader/loader.go b/router/pkg/schemaloader/loader.go index 5a4cb928eb..a3510843d7 100644 --- a/router/pkg/schemaloader/loader.go +++ b/router/pkg/schemaloader/loader.go @@ -1,13 +1,17 @@ package schemaloader import ( + "context" "encoding/json" + "errors" "fmt" "io/fs" "os" "path/filepath" "strings" + "time" + "github.com/wundergraph/cosmo/router/pkg/watcher" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" @@ -43,12 +47,14 @@ func NewOperationLoader(logger *zap.Logger, schemaDoc *ast.Document) *OperationL } // LoadOperationsFromDirectory loads all GraphQL operations from files in the specified directory -func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operation, error) { +func (l *OperationLoader) LoadOperationsFromDirectory(ctx context.Context, dirPath string, reloadOperationsChan chan bool, hotReload bool, hotReloadInterval time.Duration) ([]Operation, error) { var operations []Operation // Create an operation validator validator := astvalidation.DefaultOperationValidator() + ctx, cancel := context.WithCancel(ctx) + // Walk through the directory and process GraphQL files err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { if err != nil { @@ -129,9 +135,44 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operati }) if err != nil { + cancel() return nil, fmt.Errorf("error walking mcp operations directory %s: %w", dirPath, err) } + if hotReload { + watchFunc, err := watcher.New(watcher.Options{ + Interval: hotReloadInterval, + Logger: l.Logger, + Callback: func() { + reloadOperationsChan <- true + cancel() + }, + Directory: watcher.DirOptions{ + DirPath: dirPath, + Filter: func(path string) bool { + return isGraphQLFile(path) + }, + }, + }) + + if err != nil { + cancel() + l.Logger.Error("Could not create watcher", zap.Error(err)) + } + + go func() { + if err := watchFunc(ctx); err != nil { + if !errors.Is(err, context.Canceled) { + l.Logger.Error("Error watching operations path", zap.Error(err)) + } else { + l.Logger.Debug("Watcher context cancelled, shutting down") + } + } + }() + } else { + cancel() + } + return operations, nil } diff --git a/router/pkg/watcher/watcher.go b/router/pkg/watcher/watcher.go index 1990a9b502..e2b4d94431 100644 --- a/router/pkg/watcher/watcher.go +++ b/router/pkg/watcher/watcher.go @@ -3,20 +3,54 @@ package watcher import ( "context" "errors" + "fmt" + "io/fs" "os" + "path/filepath" "time" "go.uber.org/zap" ) +type DirOptions struct { + DirPath string + Filter func(string) bool +} + type Options struct { Interval time.Duration Logger *zap.Logger Paths []string + Directory DirOptions Callback func() TickSource <-chan time.Time } +func ListDirFilePaths(diropts DirOptions) ([]string, error) { + var files []string + if diropts.DirPath != "" { + err := filepath.WalkDir(diropts.DirPath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + // Skip directories + if d.IsDir() { + return nil + } + // Accept if filter passes the file + if diropts.Filter != nil && !diropts.Filter(path) { + return nil + } + files = append(files, path) + return nil + }) + if err != nil { + return files, fmt.Errorf("error walking directory %s: %w", diropts.DirPath, err) + } + } + return files, nil +} + func New(options Options) (func(ctx context.Context) error, error) { if options.Interval <= 0 { return nil, errors.New("interval must be greater than zero") @@ -26,8 +60,12 @@ func New(options Options) (func(ctx context.Context) error, error) { return nil, errors.New("logger must be provided") } - if len(options.Paths) == 0 { - return nil, errors.New("path must be provided") + if len(options.Paths) == 0 && options.Directory.DirPath == "" { + return nil, errors.New("either paths or directory must be provided") + } + + if len(options.Paths) != 0 && options.Directory.DirPath != "" { + return nil, errors.New("can't watch both paths and directory") } if options.Callback == nil { @@ -46,6 +84,14 @@ func New(options Options) (func(ctx context.Context) error, error) { prevModTimes := make(map[string]time.Time) + var err error + if options.Directory.DirPath != "" { + options.Paths, err = ListDirFilePaths(options.Directory) + if err != nil { + ll.Error("failed to list directory files", zap.Error(err)) + } + } + for _, path := range options.Paths { stat, err := os.Stat(path) if err != nil { @@ -63,6 +109,13 @@ func New(options Options) (func(ctx context.Context) error, error) { case <-ticker: changesDetected := false + if options.Directory.DirPath != "" { + options.Paths, err = ListDirFilePaths(options.Directory) + if err != nil { + ll.Error("failed to list directory files", zap.Error(err)) + } + } + for _, path := range options.Paths { stat, err := os.Stat(path) if err != nil { @@ -76,12 +129,23 @@ func New(options Options) (func(ctx context.Context) error, error) { zap.Time("prev_mod_time", prevModTimes[path]), zap.Time("current_mod_time", stat.ModTime()), ) - if stat.ModTime().After(prevModTimes[path]) { + _, seen := prevModTimes[path] + + // Detects new files & existing file updates + if !seen || stat.ModTime().After(prevModTimes[path]) { prevModTimes[path] = stat.ModTime() changesDetected = true } } + for path := range prevModTimes { + _, err := os.Stat(path) + if os.IsNotExist(err) { + delete(prevModTimes, path) + changesDetected = true + } + } + if changesDetected { // If there are changes detected this tick // We want to wait for the next tick (without changes) @@ -92,6 +156,7 @@ func New(options Options) (func(ctx context.Context) error, error) { // but the previous tick had changes detected pendingCallback = false options.Callback() + ll.Info("Running callback!") } case <-ctx.Done(): return ctx.Err() diff --git a/router/pkg/watcher/watcher_test.go b/router/pkg/watcher/watcher_test.go index db128d3dc8..e02030e180 100644 --- a/router/pkg/watcher/watcher_test.go +++ b/router/pkg/watcher/watcher_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "path/filepath" + "strings" "testing" "time" @@ -52,30 +53,35 @@ func TestOptionsValidation(t *testing.T) { } }) - t.Run("path not provided", func(t *testing.T) { + t.Run("either paths or directory must be provided", func(t *testing.T) { t.Parallel() - t.Run("nil path slice", func(t *testing.T) { - _, err := watcher.New(watcher.Options{ - Interval: watchInterval, - Logger: zap.NewNop(), - Paths: nil, - }) - if assert.Error(t, err) { - assert.ErrorContains(t, err, "path must be provided") - } + _, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), }) - t.Run("empty path slice", func(t *testing.T) { - _, err := watcher.New(watcher.Options{ - Interval: watchInterval, - Logger: zap.NewNop(), - Paths: []string{}, - }) - if assert.Error(t, err) { - assert.ErrorContains(t, err, "path must be provided") - } + if assert.Error(t, err) { + assert.ErrorContains(t, err, "either paths or directory must be provided") + } + }) + + t.Run("can't watch both paths and directory", func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + tempFile := filepath.Join(dir, "temp_1.json") + + _, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Paths: []string{tempFile}, + Directory: watcher.DirOptions{ + DirPath: dir, + }, }) + if assert.Error(t, err) { + assert.ErrorContains(t, err, "can't watch both paths and directory") + } }) t.Run("callback not provided", func(t *testing.T) { @@ -775,6 +781,167 @@ func TestWatch(t *testing.T) { sendTick(ticker) spy.AssertCalled(t, 1) }) + + t.Run("create a file in watcher directory", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + + spy := test.NewCallSpy() + + tickerChan := make(chan time.Time) + watchFunc, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Directory: watcher.DirOptions{ + DirPath: dir, + }, + Callback: spy.Call, + TickSource: tickerChan, + }) + require.NoError(t, err) + + go func() { + _ = watchFunc(ctx) + }() + + sendTick(tickerChan) + sendTick(tickerChan) + + tempFile := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(tempFile, []byte("a"), 0o600)) + + sendTick(tickerChan) + spy.AssertCalled(t, 0) + sendTick(tickerChan) + spy.AssertCalled(t, 1) + }) + + t.Run("modify a file in watcher directory", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + tempFile := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(tempFile, []byte("a"), 0o600)) + + spy := test.NewCallSpy() + + tickerChan := make(chan time.Time) + watchFunc, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Directory: watcher.DirOptions{ + DirPath: dir, + }, + Callback: spy.Call, + TickSource: tickerChan, + }) + require.NoError(t, err) + + go func() { + _ = watchFunc(ctx) + }() + + sendTick(tickerChan) + sendTick(tickerChan) + + require.NoError(t, os.WriteFile(tempFile, []byte("b"), 0o600)) + + sendTick(tickerChan) + sendTick(tickerChan) + spy.AssertCalled(t, 1) + }) + + t.Run("delete a file in watcher directory", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + tempFile := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(tempFile, []byte("a"), 0o600)) + + spy := test.NewCallSpy() + + tickerChan := make(chan time.Time) + watchFunc, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Directory: watcher.DirOptions{ + DirPath: dir, + }, + Callback: spy.Call, + TickSource: tickerChan, + }) + require.NoError(t, err) + + go func() { + _ = watchFunc(ctx) + }() + + sendTick(tickerChan) + sendTick(tickerChan) + + require.NoError(t, os.Remove(tempFile)) + + sendTick(tickerChan) + sendTick(tickerChan) + spy.AssertCalled(t, 1) + }) + + t.Run("rename multiple files in watcher directory", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + tempFile1 := filepath.Join(dir, "file_1.json") + tempFile2 := filepath.Join(dir, "file_2.json") + tempFile3 := filepath.Join(dir, "file_3.json") + + require.NoError(t, os.WriteFile(tempFile1, []byte("file_1"), 0o600)) + require.NoError(t, os.WriteFile(tempFile2, []byte("file_2"), 0o600)) + require.NoError(t, os.WriteFile(tempFile3, []byte("file_3"), 0o600)) + + spy := test.NewCallSpy() + + tickerChan := make(chan time.Time) + watchFunc, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Directory: watcher.DirOptions{ + DirPath: dir, + }, + Callback: spy.Call, + TickSource: tickerChan, + }) + require.NoError(t, err) + + go func() { + _ = watchFunc(ctx) + }() + + sendTick(tickerChan) + sendTick(tickerChan) + + newTempFile1 := filepath.Join(dir, "new_file_1.json") + newTempFile2 := filepath.Join(dir, "new_file_2.json") + newTempFile3 := filepath.Join(dir, "new_file_3.json") + + require.NoError(t, os.Rename(tempFile1, newTempFile1)) + require.NoError(t, os.Rename(tempFile2, newTempFile2)) + require.NoError(t, os.Rename(tempFile3, newTempFile3)) + + // Two ticks are needed to run the callback + sendTick(tickerChan) + sendTick(tickerChan) + spy.AssertCalled(t, 1) + }) } func TestCancel(t *testing.T) { @@ -806,6 +973,35 @@ func TestCancel(t *testing.T) { require.ErrorIs(t, eg.Wait(), context.Canceled) } +func TestListDirFilePaths(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + tempFile1 := filepath.Join(dir, "file_1.json") + tempFile2 := filepath.Join(dir, "file_2.json") + tempFile3 := filepath.Join(dir, "file_3.txt") + + require.NoError(t, os.WriteFile(tempFile1, []byte("a1"), 0o600)) + require.NoError(t, os.WriteFile(tempFile2, []byte("a2"), 0o600)) + require.NoError(t, os.WriteFile(tempFile3, []byte("a3"), 0o600)) + + diropts := watcher.DirOptions{ + DirPath: dir, + Filter: func(path string) bool { + return strings.ToLower(filepath.Ext(path)) == ".json" + }, + } + + filteredFilePaths, err := watcher.ListDirFilePaths(diropts) + + require.NoError(t, err) + require.Len(t, filteredFilePaths, 2) + + require.Contains(t, filteredFilePaths, tempFile1) + require.Contains(t, filteredFilePaths, tempFile2) + require.NotContains(t, filteredFilePaths, tempFile3) +} + // sendTick helper function which adds a sleep timeout // so users don't need to manually add sleep to every test func sendTick(channel chan time.Time) {