From 4617ce2d39279de7492faa2e8b8841ced6d8e479 Mon Sep 17 00:00:00 2001 From: melsonic Date: Fri, 13 Jun 2025 00:56:43 +0530 Subject: [PATCH 01/12] feat: Support hot reloading of MCP operation files --- router/core/graph_server.go | 10 ++++ router/pkg/mcpserver/operation_manager.go | 4 +- router/pkg/mcpserver/server.go | 4 +- router/pkg/schemaloader/loader.go | 71 ++++++++++++++++++++++- 4 files changed, 84 insertions(+), 5 deletions(-) diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 447382c1e3..a6b4f0d99c 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1082,6 +1082,16 @@ func (s *graphServer) buildGraphMux(ctx context.Context, if mErr := s.mcpServer.Reload(executor.ClientSchema); mErr != nil { return nil, fmt.Errorf("failed to reload MCP server: %w", mErr) } + go func() { + for { + if reloadOperations := <-s.mcpServer.ReloadOperationsChan; reloadOperations { + s.logger.Log(zap.InfoLevel, "Reloading mcp server!") + if mErr := s.mcpServer.Reload(executor.ClientSchema); mErr != nil { + return + } + } + } + }() } if s.Config.cacheWarmup != nil && s.Config.cacheWarmup.Enabled { diff --git a/router/pkg/mcpserver/operation_manager.go b/router/pkg/mcpserver/operation_manager.go index 0bbe2e15d6..03035b940c 100644 --- a/router/pkg/mcpserver/operation_manager.go +++ b/router/pkg/mcpserver/operation_manager.go @@ -30,10 +30,10 @@ func NewOperationsManager(schemaDoc *ast.Document, logger *zap.Logger, excludeMu } // LoadOperationsFromDirectory loads operations from a specified directory -func (om *OperationsManager) LoadOperationsFromDirectory(operationsDir string) error { +func (om *OperationsManager) LoadOperationsFromDirectory(ReloadOperationsChan chan bool, operationsDir string) error { // Load operations loader := schemaloader.NewOperationLoader(om.logger, om.schemaDoc) - operations, err := loader.LoadOperationsFromDirectory(operationsDir) + operations, err := loader.LoadOperationsFromDirectory(ReloadOperationsChan, operationsDir) if err != nil { return fmt.Errorf("failed to load operations: %w", err) } diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index 26b636b54e..07f4ad27d8 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -88,6 +88,7 @@ type GraphQLSchemaServer struct { operationsManager *OperationsManager schemaCompiler *SchemaCompiler registeredTools []string + ReloadOperationsChan chan bool } type graphqlRequest struct { @@ -213,6 +214,7 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) enableArbitraryOperations: options.EnableArbitraryOperations, exposeSchema: options.ExposeSchema, baseURL: options.BaseURL, + ReloadOperationsChan: make(chan bool), } return gs, nil @@ -330,7 +332,7 @@ func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error { s.schemaCompiler = NewSchemaCompiler(s.logger) s.operationsManager = NewOperationsManager(schema, s.logger, s.excludeMutations) - if err := s.operationsManager.LoadOperationsFromDirectory(s.operationsDir); err != nil { + if err := s.operationsManager.LoadOperationsFromDirectory(s.ReloadOperationsChan, s.operationsDir); err != nil { return fmt.Errorf("failed to load operations: %w", err) } diff --git a/router/pkg/schemaloader/loader.go b/router/pkg/schemaloader/loader.go index 5a4cb928eb..440acaf3e2 100644 --- a/router/pkg/schemaloader/loader.go +++ b/router/pkg/schemaloader/loader.go @@ -1,13 +1,19 @@ package schemaloader import ( + "context" "encoding/json" + "errors" "fmt" "io/fs" "os" "path/filepath" "strings" + "syscall" + "time" + "unsafe" + "github.com/wundergraph/cosmo/router/pkg/watcher" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" @@ -43,14 +49,49 @@ func NewOperationLoader(logger *zap.Logger, schemaDoc *ast.Document) *OperationL } // LoadOperationsFromDirectory loads all GraphQL operations from files in the specified directory -func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operation, error) { +func (l *OperationLoader) LoadOperationsFromDirectory(ReloadOperationsChan chan bool, dirPath string) ([]Operation, error) { var operations []Operation // Create an operation validator validator := astvalidation.DefaultOperationValidator() + fileDescriptor, err := syscall.InotifyInit() + if err != nil { + return nil, err + } + + watchDescriptor, err := syscall.InotifyAddWatch(fileDescriptor, dirPath, syscall.IN_CREATE|syscall.IN_DELETE|syscall.IN_DELETE_SELF) + if err != nil { + return nil, err + } + + pathCtx, pathCtxCancel := context.WithCancel(context.Background()) + + startupDelay := 5 * time.Second + + go func() { + inotifyEvent := make([]byte, 4096) + for { + _, err := syscall.Read(fileDescriptor, inotifyEvent) + if err != nil { + break + } + + var event *syscall.InotifyEvent = (*syscall.InotifyEvent)(unsafe.Pointer(&inotifyEvent[0])) + + if event.Mask&syscall.IN_CREATE == syscall.IN_CREATE || event.Mask&syscall.IN_DELETE == syscall.IN_DELETE || event.Mask&syscall.IN_DELETE_SELF == syscall.IN_DELETE_SELF { + break + } + + } + syscall.InotifyRmWatch(fileDescriptor, uint32(watchDescriptor)) + ReloadOperationsChan <- true + pathCtxCancel() + }() + + l.Logger.Info("Will walk through " + dirPath) // Walk through the directory and process GraphQL files - err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { + err = filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -71,6 +112,32 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operati return fmt.Errorf("failed to read file %s: %w", path, err) } + watchFunc, err := watcher.New(watcher.Options{ + Interval: 10 * time.Second, + Logger: l.Logger, + Path: path, + Callback: func() { + ReloadOperationsChan <- true + pathCtxCancel() + }, + }) + + if err != nil { + l.Logger.Error("Could not create watcher", zap.Error(err)) + return err + } + + go func() { + time.Sleep(startupDelay) + if err := watchFunc(pathCtx); err != nil { + if !errors.Is(err, context.Canceled) { + l.Logger.Error("Error watching operations path", zap.Error(err)) + } else { + l.Logger.Debug("Watcher context cancelled, shutting down") + } + } + }() + // Parse the operation operationString := string(content) opDoc, err := parseOperation(path, operationString) From 8df711397c937733d11227a7603f3067b24f170a Mon Sep 17 00:00:00 2001 From: melsonic Date: Wed, 18 Jun 2025 00:58:06 +0530 Subject: [PATCH 02/12] added mcp hot reload tests --- router-tests/mcp_hot_reload_test.go | 186 ++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 router-tests/mcp_hot_reload_test.go diff --git a/router-tests/mcp_hot_reload_test.go b/router-tests/mcp_hot_reload_test.go new file mode 100644 index 0000000000..1be1483e0f --- /dev/null +++ b/router-tests/mcp_hot_reload_test.go @@ -0,0 +1,186 @@ +package integration + +import ( + "encoding/json" + "os" + "path" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestMCPOperationHotReload(t *testing.T) { + // create a temp graphql file into mcp_operations + dir := "./testdata/mcp_operations" + fileName := "getEmployeeNotes.graphql" + filePath := path.Join(dir, fileName) + + t.Parallel() + + t.Run("List Updated User Operations On Addition and Removal", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + os.Remove(filePath) + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + require.NoError(t, err) + require.NotNil(t, resp) + + // verify initial tools count + assert.Len(t, resp.Tools, 3) + + // create new mcp operation file + file, err := os.Create(filePath) + assert.NoError(t, err) + defer func() { + file.Close() + os.Remove(filePath) + }() + + // write mcp operation content + _, err = file.WriteString(` + query getEmployeeNotes($id: Int!) { + employee(id: $id) { + id + notes + } + } + `) + assert.NoError(t, err) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + // List updated Tools + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + + // verify updated tools count + assert.Len(t, resp.Tools, 4) + + // verity getEmployeeNotes operation + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 10*time.Second, 250*time.Millisecond) + + }) + }) + + t.Run("List Updated User Operations On Content Update", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + os.Remove(filePath) + + // create new mcp operation file + file, err := os.Create(filePath) + assert.NoError(t, err) + defer func() { + file.Close() + os.Remove(filePath) + }() + + // write mcp operation content + _, err = file.WriteString(` + query getEmployeeNotes($id: Int!) { + employee(id: $id) { + id + notes + } + } + `) + assert.NoError(t, err) + + time.Sleep(5 * time.Second) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + // List updated Tools + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + + // verity getEmployeeNotes operation + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 15*time.Second, 250*time.Millisecond) + + // update mcp operation content + err = os.WriteFile(filePath, []byte(` + query getEmployeeNotesUpdatedTitle($id: Int!) { + employee(id: $id) { + id + notes + } + } + `), 0644) + assert.NoError(t, err) + + time.Sleep(5 * time.Second) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + // fetch updated mcp tools list + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + + // verity getEmployeeNotesUpdatedTitle operation + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes_updated_title", + Description: "Executes the GraphQL operation 'getEmployeeNotesUpdatedTitle' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotesUpdatedTitle", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 15*time.Second, 250*time.Millisecond) + }) + }) +} From 6b9e16131df841d7213911c12a8472295a6d3d96 Mon Sep 17 00:00:00 2001 From: melsonic Date: Sat, 21 Jun 2025 17:32:11 +0530 Subject: [PATCH 03/12] added polling based watcher, removed inotify --- router-tests/mcp_hot_reload_test.go | 25 ++-- router/core/graph_server.go | 2 +- router/core/router.go | 2 + router/pkg/config/config.go | 22 ++- router/pkg/config/config.schema.json | 21 +++ router/pkg/config/fixtures/full.yaml | 3 + .../pkg/config/testdata/config_defaults.json | 6 +- router/pkg/config/testdata/config_full.json | 6 +- router/pkg/mcpserver/operation_manager.go | 5 +- router/pkg/mcpserver/server.go | 32 +++- router/pkg/schemaloader/loader.go | 137 ++++++++++-------- 11 files changed, 177 insertions(+), 84 deletions(-) diff --git a/router-tests/mcp_hot_reload_test.go b/router-tests/mcp_hot_reload_test.go index 1be1483e0f..3bf7601012 100644 --- a/router-tests/mcp_hot_reload_test.go +++ b/router-tests/mcp_hot_reload_test.go @@ -16,9 +16,9 @@ import ( func TestMCPOperationHotReload(t *testing.T) { // create a temp graphql file into mcp_operations - dir := "./testdata/mcp_operations" + mcpOperationsDirectory := "./testdata/mcp_operations" fileName := "getEmployeeNotes.graphql" - filePath := path.Join(dir, fileName) + filePath := path.Join(mcpOperationsDirectory, fileName) t.Parallel() @@ -26,6 +26,10 @@ func TestMCPOperationHotReload(t *testing.T) { testenv.Run(t, &testenv.Config{ MCP: config.MCPConfiguration{ Enabled: true, + HotReloadConfig: config.MCPOperationsHotReloadConfig{ + Enabled: true, + Interval: 5 * time.Second, + }, }, }, func(t *testing.T, xEnv *testenv.Environment) { @@ -36,8 +40,8 @@ func TestMCPOperationHotReload(t *testing.T) { require.NoError(t, err) require.NotNil(t, resp) - // verify initial tools count - assert.Len(t, resp.Tools, 3) + // initial tools count + initialToolsCount := len(resp.Tools) // create new mcp operation file file, err := os.Create(filePath) @@ -65,7 +69,7 @@ func TestMCPOperationHotReload(t *testing.T) { assert.NoError(t, err) // verify updated tools count - assert.Len(t, resp.Tools, 4) + assert.Len(t, resp.Tools, initialToolsCount+1) // verity getEmployeeNotes operation require.Contains(t, resp.Tools, mcp.Tool{ @@ -84,7 +88,7 @@ func TestMCPOperationHotReload(t *testing.T) { OpenWorldHint: mcp.ToBoolPtr(true), }, }) - }, 10*time.Second, 250*time.Millisecond) + }, 15*time.Second, 250*time.Millisecond) }) }) @@ -93,6 +97,10 @@ func TestMCPOperationHotReload(t *testing.T) { testenv.Run(t, &testenv.Config{ MCP: config.MCPConfiguration{ Enabled: true, + HotReloadConfig: config.MCPOperationsHotReloadConfig{ + Enabled: true, + Interval: 5 * time.Second, + }, }, }, func(t *testing.T, xEnv *testenv.Environment) { @@ -117,8 +125,6 @@ func TestMCPOperationHotReload(t *testing.T) { `) assert.NoError(t, err) - time.Sleep(5 * time.Second) - require.EventuallyWithT(t, func(t *assert.CollectT) { // List updated Tools toolsRequest := mcp.ListToolsRequest{} @@ -155,8 +161,6 @@ func TestMCPOperationHotReload(t *testing.T) { `), 0644) assert.NoError(t, err) - time.Sleep(5 * time.Second) - require.EventuallyWithT(t, func(t *assert.CollectT) { // fetch updated mcp tools list toolsRequest := mcp.ListToolsRequest{} @@ -183,4 +187,5 @@ func TestMCPOperationHotReload(t *testing.T) { }, 15*time.Second, 250*time.Millisecond) }) }) + } diff --git a/router/core/graph_server.go b/router/core/graph_server.go index a6b4f0d99c..22443b2d1e 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1084,7 +1084,7 @@ func (s *graphServer) buildGraphMux(ctx context.Context, } go func() { for { - if reloadOperations := <-s.mcpServer.ReloadOperationsChan; reloadOperations { + if reloadOperations := <-s.mcpServer.ReloadOperationsChannel(); reloadOperations { s.logger.Log(zap.InfoLevel, "Reloading mcp server!") if mErr := s.mcpServer.Reload(executor.ClientSchema); mErr != nil { return diff --git a/router/core/router.go b/router/core/router.go index 6c59f3565f..91af4b3ac0 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -848,6 +848,8 @@ func (r *Router) bootstrap(ctx context.Context) error { mcpserver.WithExcludeMutations(r.mcp.ExcludeMutations), mcpserver.WithEnableArbitraryOperations(r.mcp.EnableArbitraryOperations), mcpserver.WithExposeSchema(r.mcp.ExposeSchema), + mcpserver.WithHotReload(r.mcp.HotReloadConfig.Enabled), + mcpserver.WithHotReloadInterval(r.mcp.HotReloadConfig.Interval), } // Determine the router GraphQL endpoint diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index ddf514dcff..a9517ae69d 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -898,14 +898,15 @@ type CacheWarmupConfiguration struct { } type MCPConfiguration struct { - Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` - Server MCPServer `yaml:"server,omitempty"` - Storage MCPStorageConfig `yaml:"storage,omitempty"` - GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` - ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` - EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` - ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` - RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` + Server MCPServer `yaml:"server,omitempty"` + Storage MCPStorageConfig `yaml:"storage,omitempty"` + GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` + ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` + EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` + ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` + RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + HotReloadConfig MCPOperationsHotReloadConfig `yaml:"hot_reload_config,omitempty" envPrefix:"MCP_OPERATIONS_HOT_RELOAD_"` } type MCPStorageConfig struct { @@ -917,6 +918,11 @@ type MCPServer struct { BaseURL string `yaml:"base_url,omitempty" env:"MCP_SERVER_BASE_URL"` } +type MCPOperationsHotReloadConfig struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + Interval time.Duration `yaml:"interval" envDefault:"10s" env:"INTERVAL"` +} + type PluginsConfiguration struct { Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` Path string `yaml:"path" envDefault:"plugins" env:"PATH"` diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 1ce6c83b59..d3eb339b18 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1903,6 +1903,27 @@ "type": "boolean", "default": false, "description": "Expose the full GraphQL schema through MCP. When enabled, AI models can request the complete schema of your API." + }, + "hot_reload_config": { + "type": "object", + "default": false, + "description": "Hot reloading configuration for MCP operations.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable hot reloading for MCP Operations.", + "default": false + }, + "interval": { + "type": "string", + "description": "The interval at which the MCP Operations directory is checked for changes. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", + "default": "10s", + "duration": { + "minimum": "5s" + } + } + } } } }, diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index 0ede894261..4c944a5af9 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -49,6 +49,9 @@ mcp: base_url: "http://localhost:5025" storage: provider_id: mcp + hot_reload_config: + enabled: false + interval: '10s' watch_config: enabled: true diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 4474124c11..e8280f8771 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -122,7 +122,11 @@ "ExcludeMutations": false, "EnableArbitraryOperations": false, "ExposeSchema": false, - "RouterURL": "" + "RouterURL": "", + "HotReloadConfig": { + "Enabled": false, + "Interval": 10000000000 + } }, "DemoMode": false, "Modules": null, diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index e39e5f479a..f97e21ad8f 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -157,7 +157,11 @@ "ExcludeMutations": false, "EnableArbitraryOperations": false, "ExposeSchema": false, - "RouterURL": "https://cosmo-router.wundergraph.com" + "RouterURL": "https://cosmo-router.wundergraph.com", + "HotReloadConfig": { + "Enabled": false, + "Interval": 10000000000 + } }, "DemoMode": true, "Modules": { diff --git a/router/pkg/mcpserver/operation_manager.go b/router/pkg/mcpserver/operation_manager.go index 03035b940c..44cd9a5ab9 100644 --- a/router/pkg/mcpserver/operation_manager.go +++ b/router/pkg/mcpserver/operation_manager.go @@ -2,6 +2,7 @@ package mcpserver import ( "fmt" + "time" "github.com/wundergraph/cosmo/router/pkg/schemaloader" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" @@ -30,10 +31,10 @@ func NewOperationsManager(schemaDoc *ast.Document, logger *zap.Logger, excludeMu } // LoadOperationsFromDirectory loads operations from a specified directory -func (om *OperationsManager) LoadOperationsFromDirectory(ReloadOperationsChan chan bool, operationsDir string) error { +func (om *OperationsManager) LoadOperationsFromDirectory(operationsDir string, reloadOperationsChan chan bool, hotReload bool, hotReloadInterval time.Duration) error { // Load operations loader := schemaloader.NewOperationLoader(om.logger, om.schemaDoc) - operations, err := loader.LoadOperationsFromDirectory(ReloadOperationsChan, operationsDir) + operations, err := loader.LoadOperationsFromDirectory(operationsDir, reloadOperationsChan, hotReload, hotReloadInterval) if err != nil { return fmt.Errorf("failed to load operations: %w", err) } diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index 07f4ad27d8..ba836c60d0 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -68,6 +68,10 @@ type Options struct { EnableArbitraryOperations bool // ExposeSchema determines whether the GraphQL schema is exposed ExposeSchema bool + // Enables hot reloading for MCP operations + HotReload bool + // The interval at which the MCP Operations directory is checked for changes + HotReloadInterval time.Duration } // GraphQLSchemaServer represents an MCP server that works with GraphQL schemas and operations @@ -88,7 +92,9 @@ type GraphQLSchemaServer struct { operationsManager *OperationsManager schemaCompiler *SchemaCompiler registeredTools []string - ReloadOperationsChan chan bool + hotReload bool + hotReloadInterval time.Duration + reloadOperationsChan chan bool } type graphqlRequest struct { @@ -214,7 +220,9 @@ func NewGraphQLSchemaServer(routerGraphQLEndpoint string, opts ...func(*Options) enableArbitraryOperations: options.EnableArbitraryOperations, exposeSchema: options.ExposeSchema, baseURL: options.BaseURL, - ReloadOperationsChan: make(chan bool), + hotReload: options.HotReload, + hotReloadInterval: options.HotReloadInterval, + reloadOperationsChan: make(chan bool), } return gs, nil @@ -275,6 +283,20 @@ func WithExposeSchema(exposeSchema bool) func(*Options) { } } +// WithHotReload sets the hot reload option +func WithHotReload(hotReload bool) func(*Options) { + return func(o *Options) { + o.HotReload = hotReload + } +} + +// WithHotReloadInterval sets the hot reload interval +func WithHotReloadInterval(hotReloadInterval time.Duration) func(*Options) { + return func(o *Options) { + o.HotReloadInterval = hotReloadInterval + } +} + // ServeSSE starts the server with SSE transport func (s *GraphQLSchemaServer) ServeSSE() (*server.SSEServer, error) { sseServer := server.NewSSEServer(s.server, @@ -332,7 +354,7 @@ func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error { s.schemaCompiler = NewSchemaCompiler(s.logger) s.operationsManager = NewOperationsManager(schema, s.logger, s.excludeMutations) - if err := s.operationsManager.LoadOperationsFromDirectory(s.ReloadOperationsChan, s.operationsDir); err != nil { + if err := s.operationsManager.LoadOperationsFromDirectory(s.operationsDir, s.reloadOperationsChan, s.hotReload, s.hotReloadInterval); err != nil { return fmt.Errorf("failed to load operations: %w", err) } @@ -725,3 +747,7 @@ func (s *GraphQLSchemaServer) handleGetGraphQLSchema() func(ctx context.Context, return mcp.NewToolResultText(schemaStr), nil } } + +func (s *GraphQLSchemaServer) ReloadOperationsChannel() chan bool { + return s.reloadOperationsChan +} diff --git a/router/pkg/schemaloader/loader.go b/router/pkg/schemaloader/loader.go index 440acaf3e2..e89427f71a 100644 --- a/router/pkg/schemaloader/loader.go +++ b/router/pkg/schemaloader/loader.go @@ -9,9 +9,7 @@ import ( "os" "path/filepath" "strings" - "syscall" "time" - "unsafe" "github.com/wundergraph/cosmo/router/pkg/watcher" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" @@ -49,49 +47,18 @@ func NewOperationLoader(logger *zap.Logger, schemaDoc *ast.Document) *OperationL } // LoadOperationsFromDirectory loads all GraphQL operations from files in the specified directory -func (l *OperationLoader) LoadOperationsFromDirectory(ReloadOperationsChan chan bool, dirPath string) ([]Operation, error) { +func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string, reloadOperationsChan chan bool, hotReload bool, hotReloadInterval time.Duration) ([]Operation, error) { var operations []Operation // Create an operation validator validator := astvalidation.DefaultOperationValidator() - fileDescriptor, err := syscall.InotifyInit() - if err != nil { - return nil, err - } - - watchDescriptor, err := syscall.InotifyAddWatch(fileDescriptor, dirPath, syscall.IN_CREATE|syscall.IN_DELETE|syscall.IN_DELETE_SELF) - if err != nil { - return nil, err - } - pathCtx, pathCtxCancel := context.WithCancel(context.Background()) - startupDelay := 5 * time.Second - - go func() { - inotifyEvent := make([]byte, 4096) - for { - _, err := syscall.Read(fileDescriptor, inotifyEvent) - if err != nil { - break - } - - var event *syscall.InotifyEvent = (*syscall.InotifyEvent)(unsafe.Pointer(&inotifyEvent[0])) - - if event.Mask&syscall.IN_CREATE == syscall.IN_CREATE || event.Mask&syscall.IN_DELETE == syscall.IN_DELETE || event.Mask&syscall.IN_DELETE_SELF == syscall.IN_DELETE_SELF { - break - } - - } - syscall.InotifyRmWatch(fileDescriptor, uint32(watchDescriptor)) - ReloadOperationsChan <- true - pathCtxCancel() - }() + filesSeen := make(map[string]struct{}) - l.Logger.Info("Will walk through " + dirPath) // Walk through the directory and process GraphQL files - err = filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { + err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -112,31 +79,34 @@ func (l *OperationLoader) LoadOperationsFromDirectory(ReloadOperationsChan chan return fmt.Errorf("failed to read file %s: %w", path, err) } - watchFunc, err := watcher.New(watcher.Options{ - Interval: 10 * time.Second, - Logger: l.Logger, - Path: path, - Callback: func() { - ReloadOperationsChan <- true - pathCtxCancel() - }, - }) + if hotReload { + filesSeen[path] = struct{}{} - if err != nil { - l.Logger.Error("Could not create watcher", zap.Error(err)) - return err - } + watchFunc, err := watcher.New(watcher.Options{ + Interval: hotReloadInterval, + Logger: l.Logger, + Path: path, + Callback: func() { + reloadOperationsChan <- true + pathCtxCancel() + }, + }) - go func() { - time.Sleep(startupDelay) - if err := watchFunc(pathCtx); err != nil { - if !errors.Is(err, context.Canceled) { - l.Logger.Error("Error watching operations path", zap.Error(err)) - } else { - l.Logger.Debug("Watcher context cancelled, shutting down") - } + if err != nil { + l.Logger.Error("Could not create watcher", zap.Error(err)) + return err } - }() + + go func() { + if err := watchFunc(pathCtx); err != nil { + if !errors.Is(err, context.Canceled) { + l.Logger.Error("Error watching operations path", zap.Error(err)) + } else { + l.Logger.Debug("Watcher context cancelled, shutting down") + } + } + }() + } // Parse the operation operationString := string(content) @@ -199,6 +169,57 @@ func (l *OperationLoader) LoadOperationsFromDirectory(ReloadOperationsChan chan return nil, fmt.Errorf("error walking mcp operations directory %s: %w", dirPath, err) } + if hotReload { + go func() { + ticker := time.NewTicker(hotReloadInterval) + for { + select { + case <-ticker.C: + operationsCount := 0 + + err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + // Skip directories + if d.IsDir() { + return nil + } + + // Only process GraphQL files + if !isGraphQLFile(path) { + return nil + } + + _, ok := filesSeen[path] + + // new operation added + if !ok { + reloadOperationsChan <- true + pathCtxCancel() + return filepath.SkipAll + } + + operationsCount = operationsCount + 1 + + return nil + }) + + if err != nil || operationsCount != len(filesSeen) { + reloadOperationsChan <- true + pathCtxCancel() + return + } + + case <-pathCtx.Done(): + return + } + + } + }() + } + return operations, nil } From 9ed95770f81fc6f559c373a8d30ed86f3391d07a Mon Sep 17 00:00:00 2001 From: melsonic Date: Wed, 25 Jun 2025 23:42:37 +0530 Subject: [PATCH 04/12] fixed leaked goroutines --- router/core/graph_server.go | 9 ++++++--- router/core/router.go | 3 +-- router/pkg/mcpserver/operation_manager.go | 5 +++-- router/pkg/mcpserver/server.go | 14 ++++---------- router/pkg/schemaloader/loader.go | 14 +++++++------- 5 files changed, 21 insertions(+), 24 deletions(-) diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 22443b2d1e..564957a190 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1079,16 +1079,19 @@ func (s *graphServer) buildGraphMux(ctx context.Context, // We support the MCP only on the base graph. Feature flags are not supported yet. if featureFlagName == "" && s.mcpServer != nil { - if mErr := s.mcpServer.Reload(executor.ClientSchema); mErr != nil { + if mErr := s.mcpServer.Reload(ctx, executor.ClientSchema); mErr != nil { return nil, fmt.Errorf("failed to reload MCP server: %w", mErr) } go func() { for { - if reloadOperations := <-s.mcpServer.ReloadOperationsChannel(); reloadOperations { + select { + case <-s.mcpServer.ReloadOperationsChannel(): s.logger.Log(zap.InfoLevel, "Reloading mcp server!") - if mErr := s.mcpServer.Reload(executor.ClientSchema); mErr != nil { + if mErr := s.mcpServer.Reload(ctx, executor.ClientSchema); mErr != nil { return } + case <-ctx.Done(): + return } } }() diff --git a/router/core/router.go b/router/core/router.go index 91af4b3ac0..d953739f54 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -848,8 +848,7 @@ func (r *Router) bootstrap(ctx context.Context) error { mcpserver.WithExcludeMutations(r.mcp.ExcludeMutations), mcpserver.WithEnableArbitraryOperations(r.mcp.EnableArbitraryOperations), mcpserver.WithExposeSchema(r.mcp.ExposeSchema), - mcpserver.WithHotReload(r.mcp.HotReloadConfig.Enabled), - mcpserver.WithHotReloadInterval(r.mcp.HotReloadConfig.Interval), + mcpserver.WithHotReload(r.mcp.HotReloadConfig.Enabled, r.mcp.HotReloadConfig.Interval), } // Determine the router GraphQL endpoint diff --git a/router/pkg/mcpserver/operation_manager.go b/router/pkg/mcpserver/operation_manager.go index 44cd9a5ab9..e055329f13 100644 --- a/router/pkg/mcpserver/operation_manager.go +++ b/router/pkg/mcpserver/operation_manager.go @@ -1,6 +1,7 @@ package mcpserver import ( + "context" "fmt" "time" @@ -31,10 +32,10 @@ func NewOperationsManager(schemaDoc *ast.Document, logger *zap.Logger, excludeMu } // LoadOperationsFromDirectory loads operations from a specified directory -func (om *OperationsManager) LoadOperationsFromDirectory(operationsDir string, reloadOperationsChan chan bool, hotReload bool, hotReloadInterval time.Duration) error { +func (om *OperationsManager) LoadOperationsFromDirectory(ctx context.Context, operationsDir string, reloadOperationsChan chan bool, hotReload bool, hotReloadInterval time.Duration) error { // Load operations loader := schemaloader.NewOperationLoader(om.logger, om.schemaDoc) - operations, err := loader.LoadOperationsFromDirectory(operationsDir, reloadOperationsChan, hotReload, hotReloadInterval) + operations, err := loader.LoadOperationsFromDirectory(ctx, operationsDir, reloadOperationsChan, hotReload, hotReloadInterval) if err != nil { return fmt.Errorf("failed to load operations: %w", err) } diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index ba836c60d0..8d671b096b 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -283,16 +283,10 @@ func WithExposeSchema(exposeSchema bool) func(*Options) { } } -// WithHotReload sets the hot reload option -func WithHotReload(hotReload bool) func(*Options) { +// WithHotReload sets the hot reload options +func WithHotReload(hotReload bool, hotReloadInterval time.Duration) func(*Options) { return func(o *Options) { o.HotReload = hotReload - } -} - -// WithHotReloadInterval sets the hot reload interval -func WithHotReloadInterval(hotReloadInterval time.Duration) func(*Options) { - return func(o *Options) { o.HotReloadInterval = hotReloadInterval } } @@ -345,7 +339,7 @@ func (s *GraphQLSchemaServer) Start() error { } // Reload reloads the operations and schema -func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error { +func (s *GraphQLSchemaServer) Reload(ctx context.Context, schema *ast.Document) error { if s.server == nil { return fmt.Errorf("server is not started") @@ -354,7 +348,7 @@ func (s *GraphQLSchemaServer) Reload(schema *ast.Document) error { s.schemaCompiler = NewSchemaCompiler(s.logger) s.operationsManager = NewOperationsManager(schema, s.logger, s.excludeMutations) - if err := s.operationsManager.LoadOperationsFromDirectory(s.operationsDir, s.reloadOperationsChan, s.hotReload, s.hotReloadInterval); err != nil { + if err := s.operationsManager.LoadOperationsFromDirectory(ctx, s.operationsDir, s.reloadOperationsChan, s.hotReload, s.hotReloadInterval); err != nil { return fmt.Errorf("failed to load operations: %w", err) } diff --git a/router/pkg/schemaloader/loader.go b/router/pkg/schemaloader/loader.go index e89427f71a..ba3f02eb29 100644 --- a/router/pkg/schemaloader/loader.go +++ b/router/pkg/schemaloader/loader.go @@ -47,13 +47,13 @@ func NewOperationLoader(logger *zap.Logger, schemaDoc *ast.Document) *OperationL } // LoadOperationsFromDirectory loads all GraphQL operations from files in the specified directory -func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string, reloadOperationsChan chan bool, hotReload bool, hotReloadInterval time.Duration) ([]Operation, error) { +func (l *OperationLoader) LoadOperationsFromDirectory(ctx context.Context, dirPath string, reloadOperationsChan chan bool, hotReload bool, hotReloadInterval time.Duration) ([]Operation, error) { var operations []Operation // Create an operation validator validator := astvalidation.DefaultOperationValidator() - pathCtx, pathCtxCancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) filesSeen := make(map[string]struct{}) @@ -88,7 +88,7 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string, reloadOper Path: path, Callback: func() { reloadOperationsChan <- true - pathCtxCancel() + cancel() }, }) @@ -98,7 +98,7 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string, reloadOper } go func() { - if err := watchFunc(pathCtx); err != nil { + if err := watchFunc(ctx); err != nil { if !errors.Is(err, context.Canceled) { l.Logger.Error("Error watching operations path", zap.Error(err)) } else { @@ -197,7 +197,7 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string, reloadOper // new operation added if !ok { reloadOperationsChan <- true - pathCtxCancel() + cancel() return filepath.SkipAll } @@ -208,11 +208,11 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string, reloadOper if err != nil || operationsCount != len(filesSeen) { reloadOperationsChan <- true - pathCtxCancel() + cancel() return } - case <-pathCtx.Done(): + case <-ctx.Done(): return } From dbcc0241f63641d935ea904b3fcb6c3ebd532ca0 Mon Sep 17 00:00:00 2001 From: melsonic Date: Wed, 25 Jun 2025 23:43:12 +0530 Subject: [PATCH 05/12] updated mcp hot reload tests + added shutdown test --- router-tests/mcp_hot_reload_test.go | 209 ++++++++++++++++++++-------- router-tests/testenv/testenv.go | 20 +-- 2 files changed, 161 insertions(+), 68 deletions(-) diff --git a/router-tests/mcp_hot_reload_test.go b/router-tests/mcp_hot_reload_test.go index 3bf7601012..f4ff26657e 100644 --- a/router-tests/mcp_hot_reload_test.go +++ b/router-tests/mcp_hot_reload_test.go @@ -3,7 +3,6 @@ package integration import ( "encoding/json" "os" - "path" "testing" "time" @@ -11,67 +10,61 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" "github.com/wundergraph/cosmo/router/pkg/config" + "go.uber.org/goleak" ) func TestMCPOperationHotReload(t *testing.T) { - // create a temp graphql file into mcp_operations - mcpOperationsDirectory := "./testdata/mcp_operations" - fileName := "getEmployeeNotes.graphql" - filePath := path.Join(mcpOperationsDirectory, fileName) - t.Parallel() t.Run("List Updated User Operations On Addition and Removal", func(t *testing.T) { + + operationsDir := t.TempDir() + storageProviderId := "mcp_hot_reload_test_id" + testenv.Run(t, &testenv.Config{ MCP: config.MCPConfiguration{ Enabled: true, + Storage: config.MCPStorageConfig{ + ProviderID: storageProviderId, + }, HotReloadConfig: config.MCPOperationsHotReloadConfig{ Enabled: true, Interval: 5 * time.Second, }, }, + RouterOptions: []core.Option{ + core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: storageProviderId, + Path: operationsDir, + }, + }, + }), + }, }, func(t *testing.T, xEnv *testenv.Environment) { - os.Remove(filePath) - toolsRequest := mcp.ListToolsRequest{} resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) require.NoError(t, err) - require.NotNil(t, resp) - // initial tools count initialToolsCount := len(resp.Tools) - // create new mcp operation file - file, err := os.Create(filePath) - assert.NoError(t, err) - defer func() { - file.Close() - os.Remove(filePath) - }() + filePath := operationsDir + "/main.graphql" // write mcp operation content - _, err = file.WriteString(` - query getEmployeeNotes($id: Int!) { - employee(id: $id) { - id - notes - } - } - `) + err = os.WriteFile(filePath, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) assert.NoError(t, err) require.EventuallyWithT(t, func(t *assert.CollectT) { - // List updated Tools - toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) assert.NoError(t, err) - - // verify updated tools count assert.Len(t, resp.Tools, initialToolsCount+1) - // verity getEmployeeNotes operation + // verity getEmployeeNotes operation is present require.Contains(t, resp.Tools, mcp.Tool{ Name: "execute_operation_get_employee_notes", Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", @@ -90,48 +83,78 @@ func TestMCPOperationHotReload(t *testing.T) { }) }, 15*time.Second, 250*time.Millisecond) + err = os.Remove(filePath) + assert.NoError(t, err) + + assert.EventuallyWithT(t, func(t *assert.CollectT) { + + resp, err = xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + assert.Len(t, resp.Tools, initialToolsCount) + + // verity getEmployeeNotes operation tool is properly removed + require.NotContains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + + }, 15*time.Second, 250*time.Millisecond) + }) }) t.Run("List Updated User Operations On Content Update", func(t *testing.T) { + operationsDir := t.TempDir() + storageProviderId := "mcp_hot_reload_test_id" + testenv.Run(t, &testenv.Config{ MCP: config.MCPConfiguration{ Enabled: true, + Storage: config.MCPStorageConfig{ + ProviderID: storageProviderId, + }, HotReloadConfig: config.MCPOperationsHotReloadConfig{ Enabled: true, Interval: 5 * time.Second, }, }, + RouterOptions: []core.Option{ + core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: storageProviderId, + Path: operationsDir, + }, + }, + }), + }, }, func(t *testing.T, xEnv *testenv.Environment) { - os.Remove(filePath) - - // create new mcp operation file - file, err := os.Create(filePath) - assert.NoError(t, err) - defer func() { - file.Close() - os.Remove(filePath) - }() + filePath := operationsDir + "/main.graphql" // write mcp operation content - _, err = file.WriteString(` - query getEmployeeNotes($id: Int!) { - employee(id: $id) { - id - notes - } - } - `) + err := os.WriteFile(filePath, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) assert.NoError(t, err) require.EventuallyWithT(t, func(t *assert.CollectT) { - // List updated Tools + toolsRequest := mcp.ListToolsRequest{} resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) assert.NoError(t, err) - // verity getEmployeeNotes operation + // verity getEmployeeNotes operation is present require.Contains(t, resp.Tools, mcp.Tool{ Name: "execute_operation_get_employee_notes", Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", @@ -151,23 +174,16 @@ func TestMCPOperationHotReload(t *testing.T) { }, 15*time.Second, 250*time.Millisecond) // update mcp operation content - err = os.WriteFile(filePath, []byte(` - query getEmployeeNotesUpdatedTitle($id: Int!) { - employee(id: $id) { - id - notes - } - } - `), 0644) + err = os.WriteFile(filePath, []byte("\nquery getEmployeeNotesUpdatedTitle($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) assert.NoError(t, err) require.EventuallyWithT(t, func(t *assert.CollectT) { - // fetch updated mcp tools list + toolsRequest := mcp.ListToolsRequest{} resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) assert.NoError(t, err) - // verity getEmployeeNotesUpdatedTitle operation + // verity getEmployeeNotesUpdatedTitle operation is present require.Contains(t, resp.Tools, mcp.Tool{ Name: "execute_operation_get_employee_notes_updated_title", Description: "Executes the GraphQL operation 'getEmployeeNotesUpdatedTitle' of type query.", @@ -187,5 +203,80 @@ func TestMCPOperationHotReload(t *testing.T) { }, 15*time.Second, 250*time.Millisecond) }) }) +} + +func TestShutDownMCPGoRoutineLeaks(t *testing.T) { + + defer goleak.VerifyNone(t, + goleak.IgnoreTopFunction("github.com/hashicorp/consul/sdk/freeport.checkFreedPorts"), // Freeport, spawned by init + goleak.IgnoreAnyFunction("net/http.(*conn).serve"), // HTTPTest server I can't close if I want to keep the problematic goroutine open for the test + ) + + operationsDir := t.TempDir() + storageProviderId := "mcp_hot_reload_test_id" + + xEnv, err := testenv.CreateTestEnv(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + Storage: config.MCPStorageConfig{ + ProviderID: storageProviderId, + }, + HotReloadConfig: config.MCPOperationsHotReloadConfig{ + Enabled: true, + Interval: 5 * time.Second, + }, + }, + RouterOptions: []core.Option{ + core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: storageProviderId, + Path: operationsDir, + }, + }, + }), + }, + }) + + require.NoError(t, err) + + filePath := operationsDir + "/main.graphql" + // write mcp operation content + err = os.WriteFile(filePath, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) + assert.NoError(t, err) + + // Verify GoRoutines are properly setup for Hot Reloading + require.EventuallyWithT(t, func(t *assert.CollectT) { + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 15*time.Second, 250*time.Millisecond) + + xEnv.Shutdown() + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + if assert.Error(t, err) { + require.ErrorIs(t, err, testenv.ErrEnvironmentClosed) + } + require.Nil(t, resp) } diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 616a402a7f..369279d30d 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -1341,17 +1341,19 @@ func configureRouter(listenerAddr string, testConfig *Config, routerConfig *node } if testConfig.MCP.Enabled { - // Add Storage provider - routerOpts = append(routerOpts, core.WithStorageProviders(config.StorageProviders{ - FileSystem: []config.FileSystemStorageProvider{ - { - ID: "test", - Path: "testdata/mcp_operations", + if testConfig.MCP.Storage.ProviderID == "" { + // Add Storage provider + routerOpts = append(routerOpts, core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: "test", + Path: "testdata/mcp_operations", + }, }, - }, - })) + })) - testConfig.MCP.Storage.ProviderID = "test" + testConfig.MCP.Storage.ProviderID = "test" + } routerOpts = append(routerOpts, core.WithMCP(testConfig.MCP)) } From 84912f3b4d02cacf4f798d81666114c52683bd61 Mon Sep 17 00:00:00 2001 From: melsonic Date: Sat, 5 Jul 2025 01:26:28 +0530 Subject: [PATCH 06/12] added directory watching capability to cosmo watcher with filter --- router/pkg/schemaloader/loader.go | 101 ++++++++---------------------- router/pkg/watcher/watcher.go | 98 ++++++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 76 deletions(-) diff --git a/router/pkg/schemaloader/loader.go b/router/pkg/schemaloader/loader.go index ba3f02eb29..1ea837d83f 100644 --- a/router/pkg/schemaloader/loader.go +++ b/router/pkg/schemaloader/loader.go @@ -55,8 +55,6 @@ func (l *OperationLoader) LoadOperationsFromDirectory(ctx context.Context, dirPa ctx, cancel := context.WithCancel(ctx) - filesSeen := make(map[string]struct{}) - // Walk through the directory and process GraphQL files err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { if err != nil { @@ -79,35 +77,6 @@ func (l *OperationLoader) LoadOperationsFromDirectory(ctx context.Context, dirPa return fmt.Errorf("failed to read file %s: %w", path, err) } - if hotReload { - filesSeen[path] = struct{}{} - - watchFunc, err := watcher.New(watcher.Options{ - Interval: hotReloadInterval, - Logger: l.Logger, - Path: path, - Callback: func() { - reloadOperationsChan <- true - cancel() - }, - }) - - if err != nil { - l.Logger.Error("Could not create watcher", zap.Error(err)) - return err - } - - go func() { - if err := watchFunc(ctx); err != nil { - if !errors.Is(err, context.Canceled) { - l.Logger.Error("Error watching operations path", zap.Error(err)) - } else { - l.Logger.Debug("Watcher context cancelled, shutting down") - } - } - }() - } - // Parse the operation operationString := string(content) opDoc, err := parseOperation(path, operationString) @@ -166,58 +135,42 @@ func (l *OperationLoader) LoadOperationsFromDirectory(ctx context.Context, dirPa }) if err != nil { + cancel() return nil, fmt.Errorf("error walking mcp operations directory %s: %w", dirPath, err) } if hotReload { - go func() { - ticker := time.NewTicker(hotReloadInterval) - for { - select { - case <-ticker.C: - operationsCount := 0 - - err := filepath.WalkDir(dirPath, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - - // Skip directories - if d.IsDir() { - return nil - } - - // Only process GraphQL files - if !isGraphQLFile(path) { - return nil - } - - _, ok := filesSeen[path] - - // new operation added - if !ok { - reloadOperationsChan <- true - cancel() - return filepath.SkipAll - } - - operationsCount = operationsCount + 1 - - return nil - }) + watchFunc, err := watcher.New(watcher.Options{ + Interval: hotReloadInterval, + Logger: l.Logger, + Callback: func() { + reloadOperationsChan <- true + cancel() + }, + Directory: watcher.DirOptions{ + DirPath: dirPath, + Filter: func(path string) bool { + return !isGraphQLFile(path) + }, + }, + }) - if err != nil || operationsCount != len(filesSeen) { - reloadOperationsChan <- true - cancel() - return - } + if err != nil { + cancel() + l.Logger.Error("Could not create watcher", zap.Error(err)) + } - case <-ctx.Done(): - return + go func() { + if err := watchFunc(ctx); err != nil { + if !errors.Is(err, context.Canceled) { + l.Logger.Error("Error watching operations path", zap.Error(err)) + } else { + l.Logger.Debug("Watcher context cancelled, shutting down") } - } }() + } else { + cancel() } return operations, nil diff --git a/router/pkg/watcher/watcher.go b/router/pkg/watcher/watcher.go index 1990a9b502..3b90e76297 100644 --- a/router/pkg/watcher/watcher.go +++ b/router/pkg/watcher/watcher.go @@ -3,16 +3,25 @@ package watcher import ( "context" "errors" + "fmt" + "io/fs" "os" + "path/filepath" "time" "go.uber.org/zap" ) +type DirOptions struct { + DirPath string + Filter func(string) bool +} + type Options struct { Interval time.Duration Logger *zap.Logger Paths []string + Directory DirOptions Callback func() TickSource <-chan time.Time } @@ -26,8 +35,8 @@ func New(options Options) (func(ctx context.Context) error, error) { return nil, errors.New("logger must be provided") } - if len(options.Paths) == 0 { - return nil, errors.New("path must be provided") + if len(options.Paths) == 0 && options.Directory.DirPath == "" { + return nil, errors.New("paths or directory must be provided") } if options.Callback == nil { @@ -36,6 +45,31 @@ func New(options Options) (func(ctx context.Context) error, error) { ll := options.Logger.With(zap.String("component", "file_watcher"), zap.Strings("path", options.Paths)) + listDirFilePaths := func() ([]string, error) { + var files []string + if options.Directory.DirPath != "" { + err := filepath.WalkDir(options.Directory.DirPath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + // Skip directories + if d.IsDir() { + return nil + } + // Skip if filter rejects the file + if options.Directory.Filter != nil && options.Directory.Filter(path) { + return nil + } + files = append(files, path) + return nil + }) + if err != nil { + return nil, fmt.Errorf("error walking directory %s: %w", options.Directory.DirPath, err) + } + } + return files, nil + } + return func(ctx context.Context) error { // If a ticker source is provided, use that instead of the default ticker // The ticker source is right now used for testing @@ -45,6 +79,24 @@ func New(options Options) (func(ctx context.Context) error, error) { } prevModTimes := make(map[string]time.Time) + seenDirFilePaths := make(map[string]struct{}) + + dirFilePaths, err := listDirFilePaths() + if err != nil { + ll.Error("failed to list directory files", zap.Error(err)) + return err + } + + for _, path := range dirFilePaths { + stat, err := os.Stat(path) + if err != nil { + ll.Debug("Target file cannot be statted", zap.Error(err)) + } else { + prevModTimes[path] = stat.ModTime() + ll.Debug("Watching file for changes", zap.String("path", path), zap.Time("initial_mod_time", prevModTimes[path])) + } + seenDirFilePaths[path] = struct{}{} + } for _, path := range options.Paths { stat, err := os.Stat(path) @@ -63,6 +115,48 @@ func New(options Options) (func(ctx context.Context) error, error) { case <-ticker: changesDetected := false + dirFilePaths, err := listDirFilePaths() + if err != nil { + ll.Error("failed to list directory files", zap.Error(err)) + return err + } + + visitedDirFilePaths := make(map[string]struct{}) + + for _, path := range dirFilePaths { + stat, err := os.Stat(path) + if err != nil { + ll.Debug("Target file cannot be statted", zap.String("path", path), zap.Error(err)) + // Reset the mod time so we catch any new file at the target path + prevModTimes[path] = time.Time{} + continue + } + ll.Debug("Checking file for changes", + zap.String("path", path), + zap.Time("prev_mod_time", prevModTimes[path]), + zap.Time("current_mod_time", stat.ModTime()), + ) + _, seen := seenDirFilePaths[path] + // Detects new files & existing file updates in `options.Directory.DirPath` + if !seen || stat.ModTime().After(prevModTimes[path]) { + seenDirFilePaths[path] = struct{}{} + prevModTimes[path] = stat.ModTime() + changesDetected = true + } + visitedDirFilePaths[path] = struct{}{} + } + + // Detects deleted files + if len(seenDirFilePaths) > len(dirFilePaths) { + changesDetected = true + } + + for path := range seenDirFilePaths { + if _, ok := visitedDirFilePaths[path]; !ok { + delete(seenDirFilePaths, path) + } + } + for _, path := range options.Paths { stat, err := os.Stat(path) if err != nil { From 5a6a9a2c0cc04bdc4d6127ca52182a5663262379 Mon Sep 17 00:00:00 2001 From: melsonic Date: Sat, 5 Jul 2025 01:43:38 +0530 Subject: [PATCH 07/12] fix coderabbitai comments --- router/pkg/config/config.schema.json | 1 - router/pkg/watcher/watcher.go | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 550fbaabf1..91a9430cd2 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1902,7 +1902,6 @@ }, "hot_reload_config": { "type": "object", - "default": false, "description": "Hot reloading configuration for MCP operations.", "additionalProperties": false, "properties": { diff --git a/router/pkg/watcher/watcher.go b/router/pkg/watcher/watcher.go index 3b90e76297..0db70ee2c5 100644 --- a/router/pkg/watcher/watcher.go +++ b/router/pkg/watcher/watcher.go @@ -64,7 +64,7 @@ func New(options Options) (func(ctx context.Context) error, error) { return nil }) if err != nil { - return nil, fmt.Errorf("error walking directory %s: %w", options.Directory.DirPath, err) + return []string{}, fmt.Errorf("error walking directory %s: %w", options.Directory.DirPath, err) } } return files, nil @@ -84,7 +84,6 @@ func New(options Options) (func(ctx context.Context) error, error) { dirFilePaths, err := listDirFilePaths() if err != nil { ll.Error("failed to list directory files", zap.Error(err)) - return err } for _, path := range dirFilePaths { @@ -118,7 +117,6 @@ func New(options Options) (func(ctx context.Context) error, error) { dirFilePaths, err := listDirFilePaths() if err != nil { ll.Error("failed to list directory files", zap.Error(err)) - return err } visitedDirFilePaths := make(map[string]struct{}) From 01d0551a45ee98448f566e72ec3fa77b2a334429 Mon Sep 17 00:00:00 2001 From: melsonic Date: Wed, 16 Jul 2025 00:01:23 +0530 Subject: [PATCH 08/12] extracted list dir files func + removed redundant maps --- router/pkg/schemaloader/loader.go | 2 +- router/pkg/watcher/watcher.go | 126 +++++++++++------------------- 2 files changed, 47 insertions(+), 81 deletions(-) diff --git a/router/pkg/schemaloader/loader.go b/router/pkg/schemaloader/loader.go index 1ea837d83f..a3510843d7 100644 --- a/router/pkg/schemaloader/loader.go +++ b/router/pkg/schemaloader/loader.go @@ -150,7 +150,7 @@ func (l *OperationLoader) LoadOperationsFromDirectory(ctx context.Context, dirPa Directory: watcher.DirOptions{ DirPath: dirPath, Filter: func(path string) bool { - return !isGraphQLFile(path) + return isGraphQLFile(path) }, }, }) diff --git a/router/pkg/watcher/watcher.go b/router/pkg/watcher/watcher.go index 0db70ee2c5..d3ea1bc252 100644 --- a/router/pkg/watcher/watcher.go +++ b/router/pkg/watcher/watcher.go @@ -26,6 +26,31 @@ type Options struct { TickSource <-chan time.Time } +func ListDirFilePaths(diropts DirOptions) ([]string, error) { + var files []string + if diropts.DirPath != "" { + err := filepath.WalkDir(diropts.DirPath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + // Skip directories + if d.IsDir() { + return nil + } + // Accept if filter passes the file + if diropts.Filter != nil && !diropts.Filter(path) { + return nil + } + files = append(files, path) + return nil + }) + if err != nil { + return []string{}, fmt.Errorf("error walking directory %s: %w", diropts.DirPath, err) + } + } + return files, nil +} + func New(options Options) (func(ctx context.Context) error, error) { if options.Interval <= 0 { return nil, errors.New("interval must be greater than zero") @@ -35,8 +60,8 @@ func New(options Options) (func(ctx context.Context) error, error) { return nil, errors.New("logger must be provided") } - if len(options.Paths) == 0 && options.Directory.DirPath == "" { - return nil, errors.New("paths or directory must be provided") + if len(options.Paths) != 0 && options.Directory.DirPath != "" { + return nil, errors.New("can't watch both paths and directory") } if options.Callback == nil { @@ -45,31 +70,6 @@ func New(options Options) (func(ctx context.Context) error, error) { ll := options.Logger.With(zap.String("component", "file_watcher"), zap.Strings("path", options.Paths)) - listDirFilePaths := func() ([]string, error) { - var files []string - if options.Directory.DirPath != "" { - err := filepath.WalkDir(options.Directory.DirPath, func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - // Skip directories - if d.IsDir() { - return nil - } - // Skip if filter rejects the file - if options.Directory.Filter != nil && options.Directory.Filter(path) { - return nil - } - files = append(files, path) - return nil - }) - if err != nil { - return []string{}, fmt.Errorf("error walking directory %s: %w", options.Directory.DirPath, err) - } - } - return files, nil - } - return func(ctx context.Context) error { // If a ticker source is provided, use that instead of the default ticker // The ticker source is right now used for testing @@ -79,22 +79,13 @@ func New(options Options) (func(ctx context.Context) error, error) { } prevModTimes := make(map[string]time.Time) - seenDirFilePaths := make(map[string]struct{}) - - dirFilePaths, err := listDirFilePaths() - if err != nil { - ll.Error("failed to list directory files", zap.Error(err)) - } - for _, path := range dirFilePaths { - stat, err := os.Stat(path) + var err error + if options.Directory.DirPath != "" { + options.Paths, err = ListDirFilePaths(options.Directory) if err != nil { - ll.Debug("Target file cannot be statted", zap.Error(err)) - } else { - prevModTimes[path] = stat.ModTime() - ll.Debug("Watching file for changes", zap.String("path", path), zap.Time("initial_mod_time", prevModTimes[path])) + ll.Error("failed to list directory files", zap.Error(err)) } - seenDirFilePaths[path] = struct{}{} } for _, path := range options.Paths { @@ -114,19 +105,17 @@ func New(options Options) (func(ctx context.Context) error, error) { case <-ticker: changesDetected := false - dirFilePaths, err := listDirFilePaths() - if err != nil { - ll.Error("failed to list directory files", zap.Error(err)) - } - - visitedDirFilePaths := make(map[string]struct{}) - - for _, path := range dirFilePaths { + for _, path := range options.Paths { stat, err := os.Stat(path) if err != nil { ll.Debug("Target file cannot be statted", zap.String("path", path), zap.Error(err)) - // Reset the mod time so we catch any new file at the target path - prevModTimes[path] = time.Time{} + if os.IsNotExist(err) { + delete(prevModTimes, path) + changesDetected = true + } else { + // Reset the mod time so we catch any new file at the target path + prevModTimes[path] = time.Time{} + } continue } ll.Debug("Checking file for changes", @@ -134,43 +123,19 @@ func New(options Options) (func(ctx context.Context) error, error) { zap.Time("prev_mod_time", prevModTimes[path]), zap.Time("current_mod_time", stat.ModTime()), ) - _, seen := seenDirFilePaths[path] - // Detects new files & existing file updates in `options.Directory.DirPath` + _, seen := prevModTimes[path] + + // Detects new files & existing file updates if !seen || stat.ModTime().After(prevModTimes[path]) { - seenDirFilePaths[path] = struct{}{} prevModTimes[path] = stat.ModTime() changesDetected = true } - visitedDirFilePaths[path] = struct{}{} - } - - // Detects deleted files - if len(seenDirFilePaths) > len(dirFilePaths) { - changesDetected = true - } - - for path := range seenDirFilePaths { - if _, ok := visitedDirFilePaths[path]; !ok { - delete(seenDirFilePaths, path) - } } - for _, path := range options.Paths { - stat, err := os.Stat(path) + if options.Directory.DirPath != "" { + options.Paths, err = ListDirFilePaths(options.Directory) if err != nil { - ll.Debug("Target file cannot be statted", zap.String("path", path), zap.Error(err)) - // Reset the mod time so we catch any new file at the target path - prevModTimes[path] = time.Time{} - continue - } - ll.Debug("Checking file for changes", - zap.String("path", path), - zap.Time("prev_mod_time", prevModTimes[path]), - zap.Time("current_mod_time", stat.ModTime()), - ) - if stat.ModTime().After(prevModTimes[path]) { - prevModTimes[path] = stat.ModTime() - changesDetected = true + ll.Error("failed to list directory files", zap.Error(err)) } } @@ -184,6 +149,7 @@ func New(options Options) (func(ctx context.Context) error, error) { // but the previous tick had changes detected pendingCallback = false options.Callback() + ll.Info("Running callback!") } case <-ctx.Done(): return ctx.Err() From e84de468a0e6483482018ebfa2ef943322c1a98e Mon Sep 17 00:00:00 2001 From: melsonic Date: Wed, 16 Jul 2025 00:02:04 +0530 Subject: [PATCH 09/12] added new watcher tests --- router-tests/mcp_hot_reload_test.go | 57 ++++--- router/pkg/watcher/watcher_test.go | 227 +++++++++++++++++++++++++--- 2 files changed, 232 insertions(+), 52 deletions(-) diff --git a/router-tests/mcp_hot_reload_test.go b/router-tests/mcp_hot_reload_test.go index f4ff26657e..4d606dc95d 100644 --- a/router-tests/mcp_hot_reload_test.go +++ b/router-tests/mcp_hot_reload_test.go @@ -3,6 +3,7 @@ package integration import ( "encoding/json" "os" + "path/filepath" "testing" "time" @@ -17,28 +18,27 @@ import ( func TestMCPOperationHotReload(t *testing.T) { t.Parallel() + operationsDir := t.TempDir() + storageProviderID := "mcp_hot_reload_test_id" t.Run("List Updated User Operations On Addition and Removal", func(t *testing.T) { - operationsDir := t.TempDir() - storageProviderId := "mcp_hot_reload_test_id" - testenv.Run(t, &testenv.Config{ MCP: config.MCPConfiguration{ Enabled: true, Storage: config.MCPStorageConfig{ - ProviderID: storageProviderId, + ProviderID: storageProviderID, }, HotReloadConfig: config.MCPOperationsHotReloadConfig{ Enabled: true, - Interval: 5 * time.Second, + Interval: 1 * time.Second, }, }, RouterOptions: []core.Option{ core.WithStorageProviders(config.StorageProviders{ FileSystem: []config.FileSystemStorageProvider{ { - ID: storageProviderId, + ID: storageProviderID, Path: operationsDir, }, }, @@ -52,10 +52,10 @@ func TestMCPOperationHotReload(t *testing.T) { initialToolsCount := len(resp.Tools) - filePath := operationsDir + "/main.graphql" + mcpOperationFile := filepath.Join(operationsDir, "main.graphql") // write mcp operation content - err = os.WriteFile(filePath, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) + err = os.WriteFile(mcpOperationFile, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) assert.NoError(t, err) require.EventuallyWithT(t, func(t *assert.CollectT) { @@ -81,9 +81,9 @@ func TestMCPOperationHotReload(t *testing.T) { OpenWorldHint: mcp.ToBoolPtr(true), }, }) - }, 15*time.Second, 250*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) - err = os.Remove(filePath) + err = os.Remove(mcpOperationFile) assert.NoError(t, err) assert.EventuallyWithT(t, func(t *assert.CollectT) { @@ -110,31 +110,29 @@ func TestMCPOperationHotReload(t *testing.T) { }, }) - }, 15*time.Second, 250*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) }) }) t.Run("List Updated User Operations On Content Update", func(t *testing.T) { - operationsDir := t.TempDir() - storageProviderId := "mcp_hot_reload_test_id" testenv.Run(t, &testenv.Config{ MCP: config.MCPConfiguration{ Enabled: true, Storage: config.MCPStorageConfig{ - ProviderID: storageProviderId, + ProviderID: storageProviderID, }, HotReloadConfig: config.MCPOperationsHotReloadConfig{ Enabled: true, - Interval: 5 * time.Second, + Interval: 1 * time.Second, }, }, RouterOptions: []core.Option{ core.WithStorageProviders(config.StorageProviders{ FileSystem: []config.FileSystemStorageProvider{ { - ID: storageProviderId, + ID: storageProviderID, Path: operationsDir, }, }, @@ -142,11 +140,10 @@ func TestMCPOperationHotReload(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { - filePath := operationsDir + "/main.graphql" + mcpOperationFile := filepath.Join(operationsDir, "main.graphql") // write mcp operation content - err := os.WriteFile(filePath, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) - assert.NoError(t, err) + require.NoError(t, os.WriteFile(mcpOperationFile, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0o600)) require.EventuallyWithT(t, func(t *assert.CollectT) { @@ -171,11 +168,10 @@ func TestMCPOperationHotReload(t *testing.T) { OpenWorldHint: mcp.ToBoolPtr(true), }, }) - }, 15*time.Second, 250*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) // update mcp operation content - err = os.WriteFile(filePath, []byte("\nquery getEmployeeNotesUpdatedTitle($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) - assert.NoError(t, err) + require.NoError(t, os.WriteFile(mcpOperationFile, []byte("\nquery getEmployeeNotesUpdatedTitle($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0o600)) require.EventuallyWithT(t, func(t *assert.CollectT) { @@ -200,7 +196,7 @@ func TestMCPOperationHotReload(t *testing.T) { OpenWorldHint: mcp.ToBoolPtr(true), }, }) - }, 15*time.Second, 250*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) }) }) } @@ -213,24 +209,24 @@ func TestShutDownMCPGoRoutineLeaks(t *testing.T) { ) operationsDir := t.TempDir() - storageProviderId := "mcp_hot_reload_test_id" + storageProviderID := "mcp_hot_reload_test_id" xEnv, err := testenv.CreateTestEnv(t, &testenv.Config{ MCP: config.MCPConfiguration{ Enabled: true, Storage: config.MCPStorageConfig{ - ProviderID: storageProviderId, + ProviderID: storageProviderID, }, HotReloadConfig: config.MCPOperationsHotReloadConfig{ Enabled: true, - Interval: 5 * time.Second, + Interval: 1 * time.Second, }, }, RouterOptions: []core.Option{ core.WithStorageProviders(config.StorageProviders{ FileSystem: []config.FileSystemStorageProvider{ { - ID: storageProviderId, + ID: storageProviderID, Path: operationsDir, }, }, @@ -240,10 +236,9 @@ func TestShutDownMCPGoRoutineLeaks(t *testing.T) { require.NoError(t, err) - filePath := operationsDir + "/main.graphql" + mcpOperationFile := filepath.Join(operationsDir, "main.graphql") // write mcp operation content - err = os.WriteFile(filePath, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0644) - assert.NoError(t, err) + require.NoError(t, os.WriteFile(mcpOperationFile, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0o600)) // Verify GoRoutines are properly setup for Hot Reloading require.EventuallyWithT(t, func(t *assert.CollectT) { @@ -268,7 +263,7 @@ func TestShutDownMCPGoRoutineLeaks(t *testing.T) { OpenWorldHint: mcp.ToBoolPtr(true), }, }) - }, 15*time.Second, 250*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) xEnv.Shutdown() diff --git a/router/pkg/watcher/watcher_test.go b/router/pkg/watcher/watcher_test.go index db128d3dc8..877e443cec 100644 --- a/router/pkg/watcher/watcher_test.go +++ b/router/pkg/watcher/watcher_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "path/filepath" + "strings" "testing" "time" @@ -52,30 +53,22 @@ func TestOptionsValidation(t *testing.T) { } }) - t.Run("path not provided", func(t *testing.T) { + t.Run("can't watch both paths and directory", func(t *testing.T) { t.Parallel() + dir := t.TempDir() + tempFile := filepath.Join(dir, "temp_1.json") - t.Run("nil path slice", func(t *testing.T) { - _, err := watcher.New(watcher.Options{ - Interval: watchInterval, - Logger: zap.NewNop(), - Paths: nil, - }) - if assert.Error(t, err) { - assert.ErrorContains(t, err, "path must be provided") - } - }) - - t.Run("empty path slice", func(t *testing.T) { - _, err := watcher.New(watcher.Options{ - Interval: watchInterval, - Logger: zap.NewNop(), - Paths: []string{}, - }) - if assert.Error(t, err) { - assert.ErrorContains(t, err, "path must be provided") - } + _, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Paths: []string{tempFile}, + Directory: watcher.DirOptions{ + DirPath: dir, + }, }) + if assert.Error(t, err) { + assert.ErrorContains(t, err, "can't watch both paths and directory") + } }) t.Run("callback not provided", func(t *testing.T) { @@ -775,6 +768,169 @@ func TestWatch(t *testing.T) { sendTick(ticker) spy.AssertCalled(t, 1) }) + + t.Run("create a file in watcher directory", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + + spy := test.NewCallSpy() + + tickerChan := make(chan time.Time) + watchFunc, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Directory: watcher.DirOptions{ + DirPath: dir, + }, + Callback: spy.Call, + TickSource: tickerChan, + }) + require.NoError(t, err) + + go func() { + _ = watchFunc(ctx) + }() + + sendTick(tickerChan) + sendTick(tickerChan) + + tempFile := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(tempFile, []byte("a"), 0o600)) + + sendTick(tickerChan) + sendTick(tickerChan) + spy.AssertCalled(t, 0) + sendTick(tickerChan) + spy.AssertCalled(t, 1) + }) + + t.Run("modify a file in watcher directory", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + tempFile := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(tempFile, []byte("a"), 0o600)) + + spy := test.NewCallSpy() + + tickerChan := make(chan time.Time) + watchFunc, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Directory: watcher.DirOptions{ + DirPath: dir, + }, + Callback: spy.Call, + TickSource: tickerChan, + }) + require.NoError(t, err) + + go func() { + _ = watchFunc(ctx) + }() + + sendTick(tickerChan) + sendTick(tickerChan) + + require.NoError(t, os.WriteFile(tempFile, []byte("b"), 0o600)) + + sendTick(tickerChan) + sendTick(tickerChan) + spy.AssertCalled(t, 1) + }) + + t.Run("delete a file in watcher directory", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + tempFile := filepath.Join(dir, "config.json") + require.NoError(t, os.WriteFile(tempFile, []byte("a"), 0o600)) + + spy := test.NewCallSpy() + + tickerChan := make(chan time.Time) + watchFunc, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Directory: watcher.DirOptions{ + DirPath: dir, + }, + Callback: spy.Call, + TickSource: tickerChan, + }) + require.NoError(t, err) + + go func() { + _ = watchFunc(ctx) + }() + + sendTick(tickerChan) + sendTick(tickerChan) + + require.NoError(t, os.Remove(tempFile)) + + sendTick(tickerChan) + sendTick(tickerChan) + spy.AssertCalled(t, 1) + }) + + t.Run("rename multiple files in watcher directory", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dir := t.TempDir() + tempFile1 := filepath.Join(dir, "file_1.json") + tempFile2 := filepath.Join(dir, "file_2.json") + tempFile3 := filepath.Join(dir, "file_3.json") + + require.NoError(t, os.WriteFile(tempFile1, []byte("file_1"), 0o600)) + require.NoError(t, os.WriteFile(tempFile2, []byte("file_2"), 0o600)) + require.NoError(t, os.WriteFile(tempFile3, []byte("file_3"), 0o600)) + + spy := test.NewCallSpy() + + tickerChan := make(chan time.Time) + watchFunc, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + Directory: watcher.DirOptions{ + DirPath: dir, + }, + Callback: spy.Call, + TickSource: tickerChan, + }) + require.NoError(t, err) + + go func() { + _ = watchFunc(ctx) + }() + + sendTick(tickerChan) + sendTick(tickerChan) + + newTempFile1 := filepath.Join(dir, "new_file_1.json") + newTempFile2 := filepath.Join(dir, "new_file_2.json") + newTempFile3 := filepath.Join(dir, "new_file_3.json") + + require.NoError(t, os.Rename(tempFile1, newTempFile1)) + require.NoError(t, os.Rename(tempFile2, newTempFile2)) + require.NoError(t, os.Rename(tempFile3, newTempFile3)) + + sendTick(tickerChan) + sendTick(tickerChan) + spy.AssertCalled(t, 0) + sendTick(tickerChan) + spy.AssertCalled(t, 1) + }) } func TestCancel(t *testing.T) { @@ -806,6 +962,35 @@ func TestCancel(t *testing.T) { require.ErrorIs(t, eg.Wait(), context.Canceled) } +func TestListDirFilePaths(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + tempFile1 := filepath.Join(dir, "file_1.json") + tempFile2 := filepath.Join(dir, "file_2.json") + tempFile3 := filepath.Join(dir, "file_3.txt") + + require.NoError(t, os.WriteFile(tempFile1, []byte("a1"), 0o600)) + require.NoError(t, os.WriteFile(tempFile2, []byte("a2"), 0o600)) + require.NoError(t, os.WriteFile(tempFile3, []byte("a3"), 0o600)) + + diropts := watcher.DirOptions{ + DirPath: dir, + Filter: func(path string) bool { + return strings.ToLower(filepath.Ext(path)) == ".json" + }, + } + + filteredFilePaths, err := watcher.ListDirFilePaths(diropts) + + require.NoError(t, err) + require.Len(t, filteredFilePaths, 2) + + require.Contains(t, filteredFilePaths, tempFile1) + require.Contains(t, filteredFilePaths, tempFile2) + require.NotContains(t, filteredFilePaths, tempFile3) +} + // sendTick helper function which adds a sleep timeout // so users don't need to manually add sleep to every test func sendTick(channel chan time.Time) { From 12d12dfb175e7b27074ad1520731e2b75608cc70 Mon Sep 17 00:00:00 2001 From: melsonic Date: Wed, 16 Jul 2025 00:09:32 +0530 Subject: [PATCH 10/12] minor change --- router/pkg/watcher/watcher.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/pkg/watcher/watcher.go b/router/pkg/watcher/watcher.go index d3ea1bc252..c157fdf4d6 100644 --- a/router/pkg/watcher/watcher.go +++ b/router/pkg/watcher/watcher.go @@ -45,7 +45,7 @@ func ListDirFilePaths(diropts DirOptions) ([]string, error) { return nil }) if err != nil { - return []string{}, fmt.Errorf("error walking directory %s: %w", diropts.DirPath, err) + return files, fmt.Errorf("error walking directory %s: %w", diropts.DirPath, err) } } return files, nil From 112a1532e7bf88e4e3c1fcba7293be97771ad19d Mon Sep 17 00:00:00 2001 From: melsonic Date: Thu, 17 Jul 2025 00:22:03 +0530 Subject: [PATCH 11/12] fail watcher if both paths & dir are empty + consistent change detection --- router/pkg/watcher/watcher.go | 29 ++++++++++++++++++----------- router/pkg/watcher/watcher_test.go | 17 ++++++++++++++--- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/router/pkg/watcher/watcher.go b/router/pkg/watcher/watcher.go index c157fdf4d6..e2b4d94431 100644 --- a/router/pkg/watcher/watcher.go +++ b/router/pkg/watcher/watcher.go @@ -60,6 +60,10 @@ func New(options Options) (func(ctx context.Context) error, error) { return nil, errors.New("logger must be provided") } + if len(options.Paths) == 0 && options.Directory.DirPath == "" { + return nil, errors.New("either paths or directory must be provided") + } + if len(options.Paths) != 0 && options.Directory.DirPath != "" { return nil, errors.New("can't watch both paths and directory") } @@ -105,17 +109,19 @@ func New(options Options) (func(ctx context.Context) error, error) { case <-ticker: changesDetected := false + if options.Directory.DirPath != "" { + options.Paths, err = ListDirFilePaths(options.Directory) + if err != nil { + ll.Error("failed to list directory files", zap.Error(err)) + } + } + for _, path := range options.Paths { stat, err := os.Stat(path) if err != nil { ll.Debug("Target file cannot be statted", zap.String("path", path), zap.Error(err)) - if os.IsNotExist(err) { - delete(prevModTimes, path) - changesDetected = true - } else { - // Reset the mod time so we catch any new file at the target path - prevModTimes[path] = time.Time{} - } + // Reset the mod time so we catch any new file at the target path + prevModTimes[path] = time.Time{} continue } ll.Debug("Checking file for changes", @@ -132,10 +138,11 @@ func New(options Options) (func(ctx context.Context) error, error) { } } - if options.Directory.DirPath != "" { - options.Paths, err = ListDirFilePaths(options.Directory) - if err != nil { - ll.Error("failed to list directory files", zap.Error(err)) + for path := range prevModTimes { + _, err := os.Stat(path) + if os.IsNotExist(err) { + delete(prevModTimes, path) + changesDetected = true } } diff --git a/router/pkg/watcher/watcher_test.go b/router/pkg/watcher/watcher_test.go index 877e443cec..e02030e180 100644 --- a/router/pkg/watcher/watcher_test.go +++ b/router/pkg/watcher/watcher_test.go @@ -53,6 +53,19 @@ func TestOptionsValidation(t *testing.T) { } }) + t.Run("either paths or directory must be provided", func(t *testing.T) { + t.Parallel() + + _, err := watcher.New(watcher.Options{ + Interval: watchInterval, + Logger: zap.NewNop(), + }) + + if assert.Error(t, err) { + assert.ErrorContains(t, err, "either paths or directory must be provided") + } + }) + t.Run("can't watch both paths and directory", func(t *testing.T) { t.Parallel() dir := t.TempDir() @@ -800,7 +813,6 @@ func TestWatch(t *testing.T) { tempFile := filepath.Join(dir, "config.json") require.NoError(t, os.WriteFile(tempFile, []byte("a"), 0o600)) - sendTick(tickerChan) sendTick(tickerChan) spy.AssertCalled(t, 0) sendTick(tickerChan) @@ -925,10 +937,9 @@ func TestWatch(t *testing.T) { require.NoError(t, os.Rename(tempFile2, newTempFile2)) require.NoError(t, os.Rename(tempFile3, newTempFile3)) + // Two ticks are needed to run the callback sendTick(tickerChan) sendTick(tickerChan) - spy.AssertCalled(t, 0) - sendTick(tickerChan) spy.AssertCalled(t, 1) }) } From d6ba838bcc645cecb52456da8d43aee3d38cd6a2 Mon Sep 17 00:00:00 2001 From: melsonic Date: Thu, 17 Jul 2025 23:06:51 +0530 Subject: [PATCH 12/12] extracted mcp shutdown tests to lifecycle dir --- .../lifecycle/mcp_hot_reload_shutdown_test.go | 92 +++++++++++++++++++ router-tests/mcp_hot_reload_test.go | 76 --------------- 2 files changed, 92 insertions(+), 76 deletions(-) create mode 100644 router-tests/lifecycle/mcp_hot_reload_shutdown_test.go diff --git a/router-tests/lifecycle/mcp_hot_reload_shutdown_test.go b/router-tests/lifecycle/mcp_hot_reload_shutdown_test.go new file mode 100644 index 0000000000..0ff3567c89 --- /dev/null +++ b/router-tests/lifecycle/mcp_hot_reload_shutdown_test.go @@ -0,0 +1,92 @@ +package integration + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "go.uber.org/goleak" +) + +func TestShutDownMCPGoRoutineLeaks(t *testing.T) { + + defer goleak.VerifyNone(t, + goleak.IgnoreTopFunction("github.com/hashicorp/consul/sdk/freeport.checkFreedPorts"), // Freeport, spawned by init + goleak.IgnoreAnyFunction("net/http.(*conn).serve"), // HTTPTest server I can't close if I want to keep the problematic goroutine open for the test + ) + + operationsDir := t.TempDir() + storageProviderID := "mcp_hot_reload_test_id" + + xEnv, err := testenv.CreateTestEnv(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + Storage: config.MCPStorageConfig{ + ProviderID: storageProviderID, + }, + HotReloadConfig: config.MCPOperationsHotReloadConfig{ + Enabled: true, + Interval: 1 * time.Second, + }, + }, + RouterOptions: []core.Option{ + core.WithStorageProviders(config.StorageProviders{ + FileSystem: []config.FileSystemStorageProvider{ + { + ID: storageProviderID, + Path: operationsDir, + }, + }, + }), + }, + }) + + require.NoError(t, err) + + mcpOperationFile := filepath.Join(operationsDir, "main.graphql") + // write mcp operation content + require.NoError(t, os.WriteFile(mcpOperationFile, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0o600)) + + // Verify GoRoutines are properly setup for Hot Reloading + require.EventuallyWithT(t, func(t *assert.CollectT) { + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + assert.NoError(t, err) + + require.Contains(t, resp.Tools, mcp.Tool{ + Name: "execute_operation_get_employee_notes", + Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, + Required: []string{"id"}, + }, + RawInputSchema: json.RawMessage(nil), + Annotations: mcp.ToolAnnotation{ + Title: "Execute operation getEmployeeNotes", + ReadOnlyHint: mcp.ToBoolPtr(true), + IdempotentHint: mcp.ToBoolPtr(true), + OpenWorldHint: mcp.ToBoolPtr(true), + }, + }) + }, 10*time.Second, 100*time.Millisecond) + + xEnv.Shutdown() + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + if assert.Error(t, err) { + require.ErrorIs(t, err, testenv.ErrEnvironmentClosed) + } + require.Nil(t, resp) + +} diff --git a/router-tests/mcp_hot_reload_test.go b/router-tests/mcp_hot_reload_test.go index 4d606dc95d..92b6f5ef6f 100644 --- a/router-tests/mcp_hot_reload_test.go +++ b/router-tests/mcp_hot_reload_test.go @@ -13,7 +13,6 @@ import ( "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/core" "github.com/wundergraph/cosmo/router/pkg/config" - "go.uber.org/goleak" ) func TestMCPOperationHotReload(t *testing.T) { @@ -200,78 +199,3 @@ func TestMCPOperationHotReload(t *testing.T) { }) }) } - -func TestShutDownMCPGoRoutineLeaks(t *testing.T) { - - defer goleak.VerifyNone(t, - goleak.IgnoreTopFunction("github.com/hashicorp/consul/sdk/freeport.checkFreedPorts"), // Freeport, spawned by init - goleak.IgnoreAnyFunction("net/http.(*conn).serve"), // HTTPTest server I can't close if I want to keep the problematic goroutine open for the test - ) - - operationsDir := t.TempDir() - storageProviderID := "mcp_hot_reload_test_id" - - xEnv, err := testenv.CreateTestEnv(t, &testenv.Config{ - MCP: config.MCPConfiguration{ - Enabled: true, - Storage: config.MCPStorageConfig{ - ProviderID: storageProviderID, - }, - HotReloadConfig: config.MCPOperationsHotReloadConfig{ - Enabled: true, - Interval: 1 * time.Second, - }, - }, - RouterOptions: []core.Option{ - core.WithStorageProviders(config.StorageProviders{ - FileSystem: []config.FileSystemStorageProvider{ - { - ID: storageProviderID, - Path: operationsDir, - }, - }, - }), - }, - }) - - require.NoError(t, err) - - mcpOperationFile := filepath.Join(operationsDir, "main.graphql") - // write mcp operation content - require.NoError(t, os.WriteFile(mcpOperationFile, []byte("query getEmployeeNotes($id: Int!) {\nemployee(id: $id) {\nid\nnotes\n}\n}"), 0o600)) - - // Verify GoRoutines are properly setup for Hot Reloading - require.EventuallyWithT(t, func(t *assert.CollectT) { - - toolsRequest := mcp.ListToolsRequest{} - resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) - assert.NoError(t, err) - - require.Contains(t, resp.Tools, mcp.Tool{ - Name: "execute_operation_get_employee_notes", - Description: "Executes the GraphQL operation 'getEmployeeNotes' of type query.", - InputSchema: mcp.ToolInputSchema{ - Type: "object", - Properties: map[string]interface{}{"id": map[string]interface{}{"type": "integer"}}, - Required: []string{"id"}, - }, - RawInputSchema: json.RawMessage(nil), - Annotations: mcp.ToolAnnotation{ - Title: "Execute operation getEmployeeNotes", - ReadOnlyHint: mcp.ToBoolPtr(true), - IdempotentHint: mcp.ToBoolPtr(true), - OpenWorldHint: mcp.ToBoolPtr(true), - }, - }) - }, 10*time.Second, 100*time.Millisecond) - - xEnv.Shutdown() - - toolsRequest := mcp.ListToolsRequest{} - resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) - if assert.Error(t, err) { - require.ErrorIs(t, err, testenv.ErrEnvironmentClosed) - } - require.Nil(t, resp) - -}