Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions router-tests/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,16 @@ func TestMCP(t *testing.T) {
Name: "get_operation_info",
Description: "Provides instructions on how to execute the GraphQL operation via HTTP and how to integrate it into your application.",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]interface{}{"operationName": map[string]interface{}{"description": "The exact name of the GraphQL operation to retrieve information for.", "enum": []interface{}{"UpdateMood", "MyEmployees"}, "type": "string"}},
Required: []string{"operationName"}},
Type: "object",
Properties: map[string]interface{}{
"operationName": map[string]interface{}{
"description": "The exact name of the GraphQL operation to retrieve information for.",
"enum": []interface{}{"CustomNamedQuery", "UpdateMood", "MyEmployees"},
"type": "string",
},
},
Required: []string{"operationName"},
},
RawInputSchema: json.RawMessage(nil),
Annotations: mcp.ToolAnnotation{
Title: "Get GraphQL Operation Info",
Expand Down Expand Up @@ -187,6 +194,36 @@ func TestMCP(t *testing.T) {
})
})

t.Run("List user Operations / Custom tool name via @mcpTool directive", func(t *testing.T) {
testenv.Run(t, &testenv.Config{
MCP: config.MCPConfiguration{
Enabled: true,
},
}, func(t *testing.T, xEnv *testenv.Environment) {

toolsRequest := mcp.ListToolsRequest{}
resp, err := xEnv.MCPClient.ListTools(xEnv.Context, toolsRequest)
require.NoError(t, err)
require.NotNil(t, resp)

var customTool *mcp.Tool
for i, tool := range resp.Tools {
if tool.Name == "get_employee_by_id" {
customTool = &resp.Tools[i]
break
}
}

require.NotNil(t, customTool, "Tool get_employee_by_id should be found")
assert.Equal(t, "A query with a custom MCP tool name.", customTool.Description)

for _, tool := range resp.Tools {
assert.NotEqual(t, "execute_operation_custom_named_query", tool.Name,
"Tool should not be registered with default generated name")
}
})
})

t.Run("Execute Operation Info", func(t *testing.T) {
testenv.Run(t, &testenv.Config{
MCP: config.MCPConfiguration{
Expand Down
11 changes: 11 additions & 0 deletions router-tests/testdata/mcp_operations/CustomToolName.graphql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
A query with a custom MCP tool name.
"""
query CustomNamedQuery($id: Int!) @mcpTool(name: "get_employee_by_id") {
employee(id: $id) {
id
details {
forename
}
}
}
12 changes: 8 additions & 4 deletions router/pkg/mcpserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,14 @@ func (s *GraphQLSchemaServer) registerTools() error {
compiledSchema: compiledSchema,
}

// Convert the operation name to snake_case for consistent tool naming
operationToolName := strcase.ToSnake(op.Name)
// Use custom tool name if provided via @mcpTool directive, otherwise generate default
var toolName string
if op.ToolName != "" {
toolName = op.ToolName
} else {
operationToolName := strcase.ToSnake(op.Name)
toolName = fmt.Sprintf("execute_operation_%s", operationToolName)
}

// Use the operation description directly if provided, otherwise generate a default description
var toolDescription string
Expand All @@ -546,8 +552,6 @@ func (s *GraphQLSchemaServer) registerTools() error {
} else {
toolDescription = fmt.Sprintf("Executes the GraphQL operation '%s' of type %s.", op.Name, op.OperationType)
}

toolName := fmt.Sprintf("execute_operation_%s", operationToolName)
tool := mcp.NewToolWithRawSchema(
toolName,
toolDescription,
Expand Down
58 changes: 53 additions & 5 deletions router/pkg/schemaloader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
// Operation represents a GraphQL operation with its AST document and schema information
type Operation struct {
Name string
ToolName string
FilePath string
Document ast.Document
OperationString string
Expand Down Expand Up @@ -46,7 +47,6 @@ func NewOperationLoader(logger *zap.Logger, schemaDoc *ast.Document) *OperationL
func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operation, error) {
var operations []Operation

// Create an operation validator
validator := astvalidation.DefaultOperationValidator()

// Walk through the directory and process GraphQL files
Expand Down Expand Up @@ -92,7 +92,10 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operati
return nil
}

// Validate operation against schema
opDescription := extractOperationDescription(&opDoc)
toolName := extractMCPToolName(&opDoc)
stripMCPDirective(&opDoc)

validationReport := operationreport.Report{}
validationState := validator.Validate(&opDoc, l.SchemaDocument, &validationReport)
if validationState == astvalidation.Invalid {
Expand All @@ -116,12 +119,10 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operati
}
}

// Extract description from operation definition
opDescription := extractOperationDescription(&opDoc)

// Add to our list of operations
operations = append(operations, Operation{
Name: opName,
ToolName: toolName,
FilePath: path,
Document: opDoc,
OperationString: operationString,
Expand Down Expand Up @@ -200,3 +201,50 @@ func extractOperationDescription(doc *ast.Document) string {
}
return ""
}

var mcpDirectiveName = []byte("mcpTool")
var mcpNameArgument = []byte("name")

// extractMCPToolName extracts the custom tool name from the @mcpTool directive on an operation
func extractMCPToolName(doc *ast.Document) string {
for _, ref := range doc.RootNodes {
if ref.Kind == ast.NodeKindOperationDefinition {
opDef := doc.OperationDefinitions[ref.Ref]
if !opDef.HasDirectives {
return ""
}

directiveRef, exists := doc.DirectiveWithNameBytes(opDef.Directives.Refs, mcpDirectiveName)
if !exists {
return ""
}

value, argExists := doc.DirectiveArgumentValueByName(directiveRef, mcpNameArgument)
if !argExists {
return ""
}

if value.Kind == ast.ValueKindString {
return doc.StringValueContentString(value.Ref)
}

return ""
}
}
return ""
}

// stripMCPDirective removes the @mcpTool directive from an operation before validation
func stripMCPDirective(doc *ast.Document) {
for _, ref := range doc.RootNodes {
if ref.Kind == ast.NodeKindOperationDefinition {
opDef := &doc.OperationDefinitions[ref.Ref]
if !opDef.HasDirectives {
return
}
opDef.Directives.RemoveDirectiveByName(doc, string(mcpDirectiveName))
opDef.HasDirectives = len(opDef.Directives.Refs) > 0
return
}
}
}
104 changes: 104 additions & 0 deletions router/pkg/schemaloader/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,107 @@ func TestLoadOperationsFromEmptyDirectory(t *testing.T) {
require.NoError(t, err)
assert.Len(t, operations, 0, "Empty directory should return no operations")
}

func TestExtractMCPToolName(t *testing.T) {
tests := []struct {
name string
query string
expected string
}{
{
name: "with @mcpTool directive and name argument",
query: `query Foo @mcpTool(name: "custom_foo") { bar }`,
expected: "custom_foo",
},
{
name: "without directive",
query: `query Foo { bar }`,
expected: "",
},
{
name: "with @mcpTool but no name argument",
query: `query Foo @mcpTool { bar }`,
expected: "",
},
{
name: "with different directive",
query: `query Foo @deprecated { bar }`,
expected: "",
},
{
name: "with @mcpTool and other arguments but no name",
query: `query Foo @mcpTool(other: "value") { bar }`,
expected: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
doc, report := astparser.ParseGraphqlDocumentString(tt.query)
require.False(t, report.HasErrors())

result := extractMCPToolName(&doc)
assert.Equal(t, tt.expected, result)
})
}
}

func TestLoadOperationsWithMCPDirective(t *testing.T) {
tempDir := t.TempDir()

testFiles := map[string]string{
"WithCustomToolName.graphql": `"""Custom tool operation"""
query MyQuery @mcpTool(name: "custom_tool_name") {
employee(id: "1") {
id
name
}
}`,
"WithoutDirective.graphql": `query AnotherQuery {
employee(id: "1") {
id
name
}
}`,
}

for filename, content := range testFiles {
err := os.WriteFile(filepath.Join(tempDir, filename), []byte(content), 0644)
require.NoError(t, err)
}

schemaStr := `
directive @mcpTool(name: String) on QUERY | MUTATION

type Query {
employee(id: ID!): Employee
}

type Employee {
id: ID!
name: String!
}
`
schemaDoc, report := astparser.ParseGraphqlDocumentString(schemaStr)
require.False(t, report.HasErrors())

err := asttransform.MergeDefinitionWithBaseSchema(&schemaDoc)
require.NoError(t, err)

logger := zap.NewNop()
loader := NewOperationLoader(logger, &schemaDoc)
operations, err := loader.LoadOperationsFromDirectory(tempDir)
require.NoError(t, err)
require.Len(t, operations, 2)

opMap := make(map[string]Operation)
for _, op := range operations {
opMap[op.Name] = op
}

op1 := opMap["MyQuery"]
assert.Equal(t, "custom_tool_name", op1.ToolName)

op2 := opMap["AnotherQuery"]
assert.Empty(t, op2.ToolName)
}