diff --git a/router-tests/mcp_test.go b/router-tests/mcp_test.go index e89a2c0388..0506760655 100644 --- a/router-tests/mcp_test.go +++ b/router-tests/mcp_test.go @@ -42,9 +42,16 @@ func TestMCP(t *testing.T) { Name: "get_operation_info", Description: "Provides instructions on how to execute the GraphQL operation via HTTP and how to integrate it into your application.", InputSchema: mcp.ToolInputSchema{ - Type: "object", - Properties: map[string]interface{}{"operationName": map[string]interface{}{"description": "The exact name of the GraphQL operation to retrieve information for.", "enum": []interface{}{"UpdateMood", "MyEmployees"}, "type": "string"}}, - Required: []string{"operationName"}}, + Type: "object", + Properties: map[string]interface{}{ + "operationName": map[string]interface{}{ + "description": "The exact name of the GraphQL operation to retrieve information for.", + "enum": []interface{}{"CustomNamedQuery", "UpdateMood", "MyEmployees"}, + "type": "string", + }, + }, + Required: []string{"operationName"}, + }, RawInputSchema: json.RawMessage(nil), Annotations: mcp.ToolAnnotation{ Title: "Get GraphQL Operation Info", @@ -187,6 +194,36 @@ func TestMCP(t *testing.T) { }) }) + t.Run("List user Operations / Custom tool name via @mcpTool directive", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + MCP: config.MCPConfiguration{ + Enabled: true, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + toolsRequest := mcp.ListToolsRequest{} + resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest) + require.NoError(t, err) + require.NotNil(t, resp) + + var customTool *mcp.Tool + for i, tool := range resp.Tools { + if tool.Name == "get_employee_by_id" { + customTool = &resp.Tools[i] + break + } + } + + require.NotNil(t, customTool, "Tool get_employee_by_id should be found") + assert.Equal(t, "A query with a custom MCP tool name.", customTool.Description) + + for _, tool := range resp.Tools { + assert.NotEqual(t, "execute_operation_custom_named_query", tool.Name, + "Tool should not be registered with default generated name") + } + }) + }) + t.Run("Execute Operation Info", func(t *testing.T) { testenv.Run(t, &testenv.Config{ MCP: config.MCPConfiguration{ diff --git a/router-tests/testdata/mcp_operations/CustomToolName.graphql b/router-tests/testdata/mcp_operations/CustomToolName.graphql new file mode 100644 index 0000000000..c2d77d9838 --- /dev/null +++ b/router-tests/testdata/mcp_operations/CustomToolName.graphql @@ -0,0 +1,11 @@ +""" +A query with a custom MCP tool name. +""" +query CustomNamedQuery($id: Int!) @mcpTool(name: "get_employee_by_id") { + employee(id: $id) { + id + details { + forename + } + } +} diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index 173e2bf9d2..b970d503f3 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -536,8 +536,14 @@ func (s *GraphQLSchemaServer) registerTools() error { compiledSchema: compiledSchema, } - // Convert the operation name to snake_case for consistent tool naming - operationToolName := strcase.ToSnake(op.Name) + // Use custom tool name if provided via @mcpTool directive, otherwise generate default + var toolName string + if op.ToolName != "" { + toolName = op.ToolName + } else { + operationToolName := strcase.ToSnake(op.Name) + toolName = fmt.Sprintf("execute_operation_%s", operationToolName) + } // Use the operation description directly if provided, otherwise generate a default description var toolDescription string @@ -546,8 +552,6 @@ func (s *GraphQLSchemaServer) registerTools() error { } else { toolDescription = fmt.Sprintf("Executes the GraphQL operation '%s' of type %s.", op.Name, op.OperationType) } - - toolName := fmt.Sprintf("execute_operation_%s", operationToolName) tool := mcp.NewToolWithRawSchema( toolName, toolDescription, diff --git a/router/pkg/schemaloader/loader.go b/router/pkg/schemaloader/loader.go index cd3f53dad8..ead08b4b6c 100644 --- a/router/pkg/schemaloader/loader.go +++ b/router/pkg/schemaloader/loader.go @@ -18,6 +18,7 @@ import ( // Operation represents a GraphQL operation with its AST document and schema information type Operation struct { Name string + ToolName string FilePath string Document ast.Document OperationString string @@ -46,7 +47,6 @@ func NewOperationLoader(logger *zap.Logger, schemaDoc *ast.Document) *OperationL func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operation, error) { var operations []Operation - // Create an operation validator validator := astvalidation.DefaultOperationValidator() // Walk through the directory and process GraphQL files @@ -92,7 +92,10 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operati return nil } - // Validate operation against schema + opDescription := extractOperationDescription(&opDoc) + toolName := extractMCPToolName(&opDoc) + stripMCPDirective(&opDoc) + validationReport := operationreport.Report{} validationState := validator.Validate(&opDoc, l.SchemaDocument, &validationReport) if validationState == astvalidation.Invalid { @@ -116,12 +119,10 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operati } } - // Extract description from operation definition - opDescription := extractOperationDescription(&opDoc) - // Add to our list of operations operations = append(operations, Operation{ Name: opName, + ToolName: toolName, FilePath: path, Document: opDoc, OperationString: operationString, @@ -200,3 +201,50 @@ func extractOperationDescription(doc *ast.Document) string { } return "" } + +var mcpDirectiveName = []byte("mcpTool") +var mcpNameArgument = []byte("name") + +// extractMCPToolName extracts the custom tool name from the @mcpTool directive on an operation +func extractMCPToolName(doc *ast.Document) string { + for _, ref := range doc.RootNodes { + if ref.Kind == ast.NodeKindOperationDefinition { + opDef := doc.OperationDefinitions[ref.Ref] + if !opDef.HasDirectives { + return "" + } + + directiveRef, exists := doc.DirectiveWithNameBytes(opDef.Directives.Refs, mcpDirectiveName) + if !exists { + return "" + } + + value, argExists := doc.DirectiveArgumentValueByName(directiveRef, mcpNameArgument) + if !argExists { + return "" + } + + if value.Kind == ast.ValueKindString { + return doc.StringValueContentString(value.Ref) + } + + return "" + } + } + return "" +} + +// stripMCPDirective removes the @mcpTool directive from an operation before validation +func stripMCPDirective(doc *ast.Document) { + for _, ref := range doc.RootNodes { + if ref.Kind == ast.NodeKindOperationDefinition { + opDef := &doc.OperationDefinitions[ref.Ref] + if !opDef.HasDirectives { + return + } + opDef.Directives.RemoveDirectiveByName(doc, string(mcpDirectiveName)) + opDef.HasDirectives = len(opDef.Directives.Refs) > 0 + return + } + } +} diff --git a/router/pkg/schemaloader/loader_test.go b/router/pkg/schemaloader/loader_test.go index b4573d89a5..729fc0167c 100644 --- a/router/pkg/schemaloader/loader_test.go +++ b/router/pkg/schemaloader/loader_test.go @@ -165,3 +165,107 @@ func TestLoadOperationsFromEmptyDirectory(t *testing.T) { require.NoError(t, err) assert.Len(t, operations, 0, "Empty directory should return no operations") } + +func TestExtractMCPToolName(t *testing.T) { + tests := []struct { + name string + query string + expected string + }{ + { + name: "with @mcpTool directive and name argument", + query: `query Foo @mcpTool(name: "custom_foo") { bar }`, + expected: "custom_foo", + }, + { + name: "without directive", + query: `query Foo { bar }`, + expected: "", + }, + { + name: "with @mcpTool but no name argument", + query: `query Foo @mcpTool { bar }`, + expected: "", + }, + { + name: "with different directive", + query: `query Foo @deprecated { bar }`, + expected: "", + }, + { + name: "with @mcpTool and other arguments but no name", + query: `query Foo @mcpTool(other: "value") { bar }`, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doc, report := astparser.ParseGraphqlDocumentString(tt.query) + require.False(t, report.HasErrors()) + + result := extractMCPToolName(&doc) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestLoadOperationsWithMCPDirective(t *testing.T) { + tempDir := t.TempDir() + + testFiles := map[string]string{ + "WithCustomToolName.graphql": `"""Custom tool operation""" +query MyQuery @mcpTool(name: "custom_tool_name") { + employee(id: "1") { + id + name + } +}`, + "WithoutDirective.graphql": `query AnotherQuery { + employee(id: "1") { + id + name + } +}`, + } + + for filename, content := range testFiles { + err := os.WriteFile(filepath.Join(tempDir, filename), []byte(content), 0644) + require.NoError(t, err) + } + + schemaStr := ` +directive @mcpTool(name: String) on QUERY | MUTATION + +type Query { + employee(id: ID!): Employee +} + +type Employee { + id: ID! + name: String! +} +` + schemaDoc, report := astparser.ParseGraphqlDocumentString(schemaStr) + require.False(t, report.HasErrors()) + + err := asttransform.MergeDefinitionWithBaseSchema(&schemaDoc) + require.NoError(t, err) + + logger := zap.NewNop() + loader := NewOperationLoader(logger, &schemaDoc) + operations, err := loader.LoadOperationsFromDirectory(tempDir) + require.NoError(t, err) + require.Len(t, operations, 2) + + opMap := make(map[string]Operation) + for _, op := range operations { + opMap[op.Name] = op + } + + op1 := opMap["MyQuery"] + assert.Equal(t, "custom_tool_name", op1.ToolName) + + op2 := opMap["AnotherQuery"] + assert.Empty(t, op2.ToolName) +}