diff --git a/router-tests/go.mod b/router-tests/go.mod index 3de9ec5fcf..a3a264de3f 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -26,7 +26,7 @@ require ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e github.com/wundergraph/cosmo/router v0.0.0-20251125205644-175f80c4e6d9 github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.245 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.246.0.20260202150435-e2c713b42e65 go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel/sdk v1.36.0 go.opentelemetry.io/otel/sdk/metric v1.36.0 diff --git a/router-tests/go.sum b/router-tests/go.sum index dc73f635fd..12b17fd61c 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -352,8 +352,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.245 h1:MYewlXgIhI9jusocPUeyo346J3M5cqzc6ddru1qp+S8= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.245/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.246.0.20260202150435-e2c713b42e65 h1:o5wqeMnGK2GdyTXlcLKZzPRrVtpzT55Iua5zyMT4ErU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.246.0.20260202150435-e2c713b42e65/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index b517c287f1..a63083e414 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -711,4 +711,76 @@ func TestStartSubscriptionHook(t *testing.T) { assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) + + t.Run("Test StartSubscription hook can access field arguments", func(t *testing.T) { + t.Parallel() + + // This test verifies that the subscription start hook can access GraphQL field arguments + // via ctx.Operation().Arguments(). + + var capturedEmployeeID int + + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + args := ctx.Operation().Arguments() + if args != nil { + employeeIDArg := args.Get("subscription.employeeUpdatedMyKafka.employeeID") + if employeeIDArg != nil { + capturedEmployeeID = employeeIDArg.GetInt() + } + } + return nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": customModule, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 7, + } + subscriptionOneID, err := client.Subscribe(&subscriptionOne, vars, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) + assert.Equal(t, 7, capturedEmployeeID, "expected to capture employeeID argument value") + }) + }) } diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index bc5187d400..af93b4c6f6 100644 --- a/router-tests/modules/stream_publish_test.go +++ b/router-tests/modules/stream_publish_test.go @@ -358,4 +358,52 @@ func TestPublishHook(t *testing.T) { require.Equal(t, []byte("3"), header.Value) }) }) + + t.Run("Test Publish hook can access field arguments", func(t *testing.T) { + t.Parallel() + + // This test verifies that the publish hook can access GraphQL field arguments + // via ctx.Operation().Arguments(). + + var capturedEmployeeID int + + customModule := stream_publish.PublishModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + args := ctx.Operation().Arguments() + if args != nil { + employeeIDArg := args.Get("mutation.updateEmployeeMyKafka.employeeID") + if employeeIDArg != nil { + capturedEmployeeID = employeeIDArg.GetInt() + } + } + return events, nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": customModule, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 5, update: {name: "test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) + + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) + assert.Equal(t, 5, capturedEmployeeID, "expected to capture employeeID argument value") + }) + }) } diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index c71d3019bb..3936d92b6c 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -963,4 +963,91 @@ func TestReceiveHook(t *testing.T) { assert.Equal(t, int32(3), customModule.HookCallCount.Load()) }) }) + + t.Run("Test Receive hook can access field arguments", func(t *testing.T) { + t.Parallel() + + // This test verifies that the receive hook can access GraphQL field arguments + // via ctx.Operation().Arguments(). + + var capturedEmployeeID int + + customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + args := ctx.Operation().Arguments() + if args != nil { + employeeIDArg := args.Get("subscription.employeeUpdatedMyKafka.employeeID") + if employeeIDArg != nil { + capturedEmployeeID = employeeIDArg.GetInt() + } + } + return events, nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": customModule, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) + assert.Equal(t, 3, capturedEmployeeID, "expected to capture employeeID argument value") + }) + }) } diff --git a/router/core/arguments.go b/router/core/arguments.go new file mode 100644 index 0000000000..eab6788119 --- /dev/null +++ b/router/core/arguments.go @@ -0,0 +1,106 @@ +package core + +import ( + "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization" +) + +// Arguments allow access to GraphQL field arguments used by clients. +type Arguments struct { + // mapping maps "fieldPath.argumentName" to "variableName". + // For example: {"user.posts.limit": "a", "user.id": "userId"} + mapping astnormalization.FieldArgumentMapping + + // variables contains the JSON-parsed variables from the request. + variables *astjson.Value +} + +// NewArguments creates an Arguments instance. +func NewArguments( + mapping astnormalization.FieldArgumentMapping, + variables *astjson.Value, +) Arguments { + return Arguments{ + mapping: mapping, + variables: variables, + } +} + +// Get will return the value of the field argument at path. +// +// To access a specific field argument you need to provide +// the path in it's GraphQL operation via dot notation, +// prefixed by the root levels type. +// +// Get("rootfield_operation_type.rootfield_name.other.fields.argument_name") +// +// To access the storeId field argument of the operation +// +// subscription { +// orderUpdated(storeId: 1) { +// id +// status +// } +// } +// +// you need to call Get("subscription.orderUpdated.storeId") . +// You can also access deeper nested fields. +// For example you can access the categoryId field of the operation +// +// subscription { +// orderUpdated(storeId: 1) { +// lineItems(categoryId: 2) { +// id +// name +// } +// } +// } +// +// by calling Get("subscription.orderUpdated.lineItems.categoryId") . +// +// If you use aliases in operation you need to provide the alias name +// instead of the field name. +// +// query { +// a: user(id: "1") { name } +// b: user(id: "2") { name } +// } +// +// You need to call Get("query.a.id") or Get("query.b.id") respectively. +// +// If you want to access field arguments of fragments, you need to +// access it on one of the fields where the fragment is resolved. +// +// fragment GoldTrophies on RaceDrivers { +// trophies(color:"gold") { +// title +// } +// } +// +// subscription { +// driversFinish { +// name +// ... GoldTrophies +// } +// } +// +// If you want to access the "color" field argument, you need to +// call Get("subscription.driversFinish.trophies.color") . +// The same concept applies to inline fragments. +// +// If fa is nil, or f or a cannot be found, nil is returned. +func (fa *Arguments) Get(path string) *astjson.Value { + if fa == nil || len(fa.mapping) == 0 || fa.variables == nil { + return nil + } + + // Look up variable name from field argument map + varName, ok := fa.mapping[path] + if !ok { + return nil + } + + // Use the name to get the actual value from + // the operation contexts variables. + return fa.variables.Get(varName) +} diff --git a/router/core/arguments_test.go b/router/core/arguments_test.go new file mode 100644 index 0000000000..a7f1428792 --- /dev/null +++ b/router/core/arguments_test.go @@ -0,0 +1,340 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" + "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" +) + +func TestArgumentMapping(t *testing.T) { + testCases := []struct { + name string + schema string + operation string + variables string + assertions func(t *testing.T, result Arguments) + }{ + { + name: "root field arguments with variables are accessible", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUser($userId: ID!) { + user(id: $userId) { + id + name + } + } + `, + variables: `{"userId": "123"}`, + assertions: func(t *testing.T, result Arguments) { + idArg := result.Get("query.user.id") + require.NotNil(t, idArg, "expected 'id' argument on 'user' field") + assert.Equal(t, "123", string(idArg.GetStringBytes())) + }, + }, + { + name: "root field arguments without variables are accessible", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUser { + user(id: "123") { + id + name + } + } + `, + variables: `{}`, + assertions: func(t *testing.T, result Arguments) { + idArg := result.Get("query.user.id") + require.NotNil(t, idArg, "expected 'id' argument on 'user' field") + assert.Equal(t, "123", string(idArg.GetStringBytes())) + }, + }, + { + name: "nested field arguments are accessible", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + posts(limit: Int!, offset: Int): [Post!]! + } + type Post { + id: ID! + title: String! + } + `, + operation: ` + query GetUserPosts($userId: ID!, $limit: Int!, $offset: Int) { + user(id: $userId) { + id + posts(limit: $limit, offset: $offset) { + id + title + } + } + } + `, + variables: `{"userId": "user-1", "limit": 10, "offset": 5}`, + assertions: func(t *testing.T, result Arguments) { + // Assert root field argument + userIdArg := result.Get("query.user.id") + require.NotNil(t, userIdArg) + assert.Equal(t, "user-1", string(userIdArg.GetStringBytes())) + + // Assert nested field arguments (dot notation path) + limitArg := result.Get("query.user.posts.limit") + require.NotNil(t, limitArg, "expected 'limit' argument on 'user.posts' field") + assert.Equal(t, 10, limitArg.GetInt()) + + offsetArg := result.Get("query.user.posts.offset") + require.NotNil(t, offsetArg, "expected 'offset' argument on 'user.posts' field") + assert.Equal(t, 5, offsetArg.GetInt()) + }, + }, + { + name: "non-existent field returns nil", + schema: ` + type Query { + hello: String + } + `, + operation: ` + query { + hello + } + `, + variables: `{}`, + assertions: func(t *testing.T, result Arguments) { + arg := result.Get("query.hello.someArg") + require.Nil(t, arg, "expected nil for non-existent argument") + + arg = result.Get("query.nonExistent.arg") + require.Nil(t, arg, "expected nil for non-existent field") + }, + }, + { + name: "multiple root fields with arguments", + schema: ` + type Query { + user(id: ID!): User + post(slug: String!): Post + } + type User { + id: ID! + } + type Post { + slug: String! + } + `, + operation: ` + query GetUserAndPost($userId: ID!, $postSlug: String!) { + user(id: $userId) { + id + } + post(slug: $postSlug) { + slug + } + } + `, + variables: `{"userId": "user-123", "postSlug": "my-post"}`, + assertions: func(t *testing.T, result Arguments) { + userIdArg := result.Get("query.user.id") + require.NotNil(t, userIdArg) + assert.Equal(t, "user-123", string(userIdArg.GetStringBytes())) + + postSlugArg := result.Get("query.post.slug") + require.NotNil(t, postSlugArg) + assert.Equal(t, "my-post", string(postSlugArg.GetStringBytes())) + }, + }, + { + name: "array argument is accessible", + schema: ` + type Query { + users(ids: [ID!]!): [User!]! + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUsers($userIds: [ID!]!) { + users(ids: $userIds) { + id + name + } + } + `, + variables: `{"userIds": ["user-1", "user-2", "user-3"]}`, + assertions: func(t *testing.T, result Arguments) { + idsArg := result.Get("query.users.ids") + require.NotNil(t, idsArg, "expected 'ids' argument on 'users' field") + + // Verify it's an array + arr := idsArg.GetArray() + require.Len(t, arr, 3) + assert.Equal(t, "user-1", string(arr[0].GetStringBytes())) + assert.Equal(t, "user-2", string(arr[1].GetStringBytes())) + assert.Equal(t, "user-3", string(arr[2].GetStringBytes())) + }, + }, + { + name: "object argument is accessible", + schema: ` + type Query { + users(filter: UserFilter!): [User!]! + } + input UserFilter { + name: String + age: Int + active: Boolean! + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUsers($filter: UserFilter!) { + users(filter: $filter) { + id + name + } + } + `, + variables: `{"filter": {"name": "John", "age": 30, "active": true}}`, + assertions: func(t *testing.T, result Arguments) { + filterArg := result.Get("query.users.filter") + require.NotNil(t, filterArg, "expected 'filter' argument on 'users' field") + + // Verify it's an object and access its fields + obj := filterArg.GetObject() + require.NotNil(t, obj) + + nameVal := filterArg.Get("name") + require.NotNil(t, nameVal) + assert.Equal(t, "John", string(nameVal.GetStringBytes())) + + ageVal := filterArg.Get("age") + require.NotNil(t, ageVal) + assert.Equal(t, 30, ageVal.GetInt()) + + activeVal := filterArg.Get("active") + require.NotNil(t, activeVal) + assert.True(t, activeVal.GetBool()) + }, + }, + { + name: "aliased fields have unique paths", + schema: ` + type Query { + user(id: ID!): User + } + type User { + id: ID! + name: String! + } + `, + operation: ` + query GetUsers($id1: ID!, $id2: ID!) { + a: user(id: $id1) { + id + name + } + b: user(id: $id2) { + id + name + } + } + `, + variables: `{"id1": "user-1", "id2": "user-2"}`, + assertions: func(t *testing.T, result Arguments) { + // Access arguments using the alias, not the field name + aIdArg := result.Get("query.a.id") + require.NotNil(t, aIdArg, "expected 'id' argument on aliased field 'a'") + assert.Equal(t, "user-1", string(aIdArg.GetStringBytes())) + + bIdArg := result.Get("query.b.id") + require.NotNil(t, bIdArg, "expected 'id' argument on aliased field 'b'") + assert.Equal(t, "user-2", string(bIdArg.GetStringBytes())) + + // Using the field name should not find the arguments + userIdArg := result.Get("query.user.id") + assert.Nil(t, userIdArg, "expected nil when using field name instead of alias") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Parse schema + schema, report := astparser.ParseGraphqlDocumentString(tc.schema) + require.False(t, report.HasErrors(), "failed to parse schema") + err := asttransform.MergeDefinitionWithBaseSchema(&schema) + require.NoError(t, err) + + // Parse operation + operation, report := astparser.ParseGraphqlDocumentString(tc.operation) + require.False(t, report.HasErrors(), "failed to parse operation") + + // Set variables before normalization (like the router does) + operation.Input.Variables = []byte(tc.variables) + + // Normalize operation (merges provided variables with extracted inline literals) + rep := &operationreport.Report{} + norm := astnormalization.NewNormalizer(true, true) + norm.NormalizeOperation(&operation, &schema, rep) + require.False(t, rep.HasErrors(), "failed to normalize operation") + + // Then normalize variables using VariablesNormalizer which returns the field argument mapping + varNorm := astnormalization.NewVariablesNormalizer(true) + result := varNorm.NormalizeOperation(&operation, &schema, rep) + require.False(t, rep.HasErrors(), "failed to normalize variables") + + // Use normalized variables (includes both provided and extracted variables) + vars, err := astjson.ParseBytes(operation.Input.Variables) + require.NoError(t, err) + + arguments := NewArguments(result.FieldArgumentMapping, vars) + + tc.assertions(t, arguments) + }) + } +} + +func TestNewArguments_NilMapping(t *testing.T) { + // Test that nil mapping returns empty Arguments + result := NewArguments(nil, nil) + assert.Nil(t, result.Get("query.user.id")) +} + +func TestNewArguments_EmptyMapping(t *testing.T) { + // Test that empty mapping returns empty Arguments + result := NewArguments(astnormalization.FieldArgumentMapping{}, nil) + assert.Nil(t, result.Get("query.user.id")) +} diff --git a/router/core/cache_warmup.go b/router/core/cache_warmup.go index b914567291..2319a60f94 100644 --- a/router/core/cache_warmup.go +++ b/router/core/cache_warmup.go @@ -295,7 +295,7 @@ func (c *CacheWarmupPlanningProcessor) ProcessOperation(ctx context.Context, ope return nil, err } - _, _, err = k.NormalizeVariables() + _, _, _, err = k.NormalizeVariables() if err != nil { return nil, err } diff --git a/router/core/context.go b/router/core/context.go index 264e166c32..da81e3d66b 100644 --- a/router/core/context.go +++ b/router/core/context.go @@ -490,16 +490,16 @@ type OperationContext interface { Hash() uint64 // Content is the content of the operation Content() string - // Variables is the variables of the operation + // Arguments allow access to GraphQL operation field arguments. + Arguments() *Arguments + // Variables allow access to GraphQL operation variables. Variables() *astjson.Value // ClientInfo returns information about the client that initiated this operation ClientInfo() ClientInfo - // Sha256Hash returns the SHA256 hash of the original operation // It is important to note that this hash is not calculated just because this method has been called // and is only calculated based on other existing logic (such as if sha256Hash is used in expressions) Sha256Hash() string - // QueryPlanStats returns some statistics about the query plan for the operation // if called too early in request chain, it may be inaccurate for modules, using // in Middleware is recommended @@ -532,7 +532,11 @@ type operationContext struct { // RawContent is the raw content of the operation rawContent string // Content is the normalized content of the operation - content string + content string + // fieldArguments are the arguments of the operation. + // These are not mapped by default, only when certain custom modules require them. + fieldArguments Arguments + // variables are the variables of the operation variables *astjson.Value files []*httpclient.FileUpload clientInfo *ClientInfo @@ -568,6 +572,10 @@ func (o *operationContext) Variables() *astjson.Value { return o.variables } +func (o *operationContext) Arguments() *Arguments { + return &o.fieldArguments +} + func (o *operationContext) Files() []*httpclient.FileUpload { return o.files } diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 87aa96331c..7ed37d22a9 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1295,6 +1295,7 @@ func (s *graphServer) buildGraphMux( ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, DisableExposingVariablesContentOnValidationError: s.engineExecutionConfiguration.DisableExposingVariablesContentOnValidationError, ComplexityLimits: s.securityConfiguration.ComplexityLimits, + EnableFieldArgumentMapping: s.subscriptionHooks.needFieldArgumentMapping(), }) operationPlanner := NewOperationPlanner(executor, gm.planCache) @@ -1470,6 +1471,7 @@ func (s *graphServer) buildGraphMux( ComputeOperationSha256: computeSha256, ApolloCompatibilityFlags: &s.apolloCompatibilityFlags, DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping, + MapFieldArguments: s.subscriptionHooks.needFieldArgumentMapping(), ExprManager: exprManager, OmitBatchExtensions: s.batchingConfig.OmitExtensions, @@ -1495,6 +1497,7 @@ func (s *graphServer) buildGraphMux( WebSocketConfiguration: s.webSocketConfiguration, ClientHeader: s.clientHeader, DisableVariablesRemapping: s.engineExecutionConfiguration.DisableVariablesRemapping, + MapFieldArguments: s.subscriptionHooks.needFieldArgumentMapping(), ApolloCompatibilityFlags: s.apolloCompatibilityFlags, }) diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index 02a96f69e4..82867e07c0 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -67,6 +67,7 @@ type PreHandlerOptions struct { ComputeOperationSha256 bool ApolloCompatibilityFlags *config.ApolloCompatibilityFlags DisableVariablesRemapping bool + MapFieldArguments bool ExprManager *expr.Manager OmitBatchExtensions bool @@ -102,6 +103,7 @@ type PreHandler struct { apolloCompatibilityFlags *config.ApolloCompatibilityFlags variableParsePool astjson.ParserPool disableVariablesRemapping bool + mapFieldArguments bool exprManager *expr.Manager omitBatchExtensions bool @@ -160,6 +162,7 @@ func NewPreHandler(opts *PreHandlerOptions) *PreHandler { computeOperationSha256: opts.ComputeOperationSha256, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, disableVariablesRemapping: opts.DisableVariablesRemapping, + mapFieldArguments: opts.MapFieldArguments, exprManager: opts.ExprManager, omitBatchExtensions: opts.OmitBatchExtensions, @@ -783,7 +786,7 @@ func (h *PreHandler) handleOperation(w http.ResponseWriter, req *http.Request, v * Normalize the variables */ - cached, uploadsMapping, err := operationKit.NormalizeVariables() + cached, uploadsMapping, fieldArgMapping, err := operationKit.NormalizeVariables() if err != nil { rtrace.AttachErrToSpan(engineNormalizeSpan, err) @@ -798,6 +801,7 @@ func (h *PreHandler) handleOperation(w http.ResponseWriter, req *http.Request, v engineNormalizeSpan.End() return err } + // Store the field argument mapping for later use when creating Arguments engineNormalizeSpan.SetAttributes(otel.WgVariablesNormalizationCacheHit.Bool(cached)) requestContext.operation.variablesNormalizationCacheHit = cached @@ -907,6 +911,7 @@ func (h *PreHandler) handleOperation(w http.ResponseWriter, req *http.Request, v requestContext.operation.rawContent = operationKit.parsedOperation.Request.Query requestContext.operation.content = operationKit.parsedOperation.NormalizedRepresentation requestContext.operation.variables, err = variablesParser.ParseBytes(operationKit.parsedOperation.Request.Variables) + if err != nil { rtrace.AttachErrToSpan(engineNormalizeSpan, err) if !requestContext.operation.traceOptions.ExcludeNormalizeStats { @@ -915,6 +920,14 @@ func (h *PreHandler) handleOperation(w http.ResponseWriter, req *http.Request, v engineNormalizeSpan.End() return err } + + if h.mapFieldArguments { + requestContext.operation.fieldArguments = NewArguments( + fieldArgMapping, + requestContext.operation.variables, + ) + } + requestContext.operation.normalizationTime = time.Since(startNormalization) requestContext.expressionContext.Request.Operation.NormalizationTime = requestContext.operation.normalizationTime setTelemetryAttributes(normalizeCtx, requestContext, expr.BucketNormalizationTime) diff --git a/router/core/operation_processor.go b/router/core/operation_processor.go index 8ec9ba5bd1..b19f08d87e 100644 --- a/router/core/operation_processor.go +++ b/router/core/operation_processor.go @@ -125,6 +125,7 @@ type OperationProcessorOptions struct { ComplexityLimits *config.ComplexityLimits ParserTokenizerLimits astparser.TokenizerLimits OperationNameLengthLimit int + EnableFieldArgumentMapping bool } // OperationProcessor provides shared resources to the parseKit and OperationKit. @@ -780,6 +781,10 @@ type VariablesNormalizationCacheEntry struct { // request spec for file uploads. uploadsMapping []uploads.UploadPathMapping + // fieldArgumentMapping maps field arguments to their variable names for fast lookup. + // This is populated during variable normalization and cached to avoid repeated AST walks. + fieldArgumentMapping astnormalization.FieldArgumentMapping + // reparse indicates whether the operation document needs to be reparsed from // its string representation when retrieved from the cache. reparse bool @@ -907,10 +912,10 @@ func (o *OperationKit) normalizeVariablesCacheKey() uint64 { } // NormalizeVariables normalizes variables and returns a slice of upload mappings -// if any of them were present in a query. +// if any of them were present in a query, as well as the field argument mapping. // If normalized values were found in the cache, it skips normalization and returns the caching set to true. // If an error is returned, then caching is set to false. -func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.UploadPathMapping, err error) { +func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.UploadPathMapping, fieldArgMapping astnormalization.FieldArgumentMapping, err error) { cacheKey := o.normalizeVariablesCacheKey() if o.cache != nil && o.cache.variablesNormalizationCache != nil { entry, ok := o.cache.variablesNormalizationCache.Get(cacheKey) @@ -921,10 +926,10 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo if entry.reparse { if err = o.setAndParseOperationDoc(); err != nil { - return false, nil, err + return false, nil, nil, err } } - return true, entry.uploadsMapping, nil + return true, entry.uploadsMapping, entry.fieldArgumentMapping, nil } } @@ -935,11 +940,14 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo copy(operationRawBytesBefore, o.kit.doc.Input.RawBytes) report := &operationreport.Report{} - uploadsMapping := o.kit.variablesNormalizer.NormalizeOperation(o.kit.doc, o.operationProcessor.executor.ClientSchema, report) + normalizerResult := o.kit.variablesNormalizer.NormalizeOperation(o.kit.doc, o.operationProcessor.executor.ClientSchema, report) if report.HasErrors() { - return false, nil, &reportError{report: report} + return false, nil, nil, &reportError{report: report} } + uploadsMapping := normalizerResult.UploadsMapping + fieldArgumentMapping := normalizerResult.FieldArgumentMapping + // Assuming the user sends a multi-operation document // During normalization, we removed the unused operations from the document // This will always lead to operation definitions of a length of 1 even when multiple operations are sent @@ -959,14 +967,14 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo err = o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation) if err != nil { - return false, nil, err + return false, nil, nil, err } // Reset the doc with the original name o.kit.doc.OperationDefinitions[o.operationDefinitionRef].Name = nameRef _, err = o.kit.keyGen.Write(o.kit.normalizedOperation.Bytes()) if err != nil { - return false, nil, err + return false, nil, nil, err } o.parsedOperation.ID = o.kit.keyGen.Sum64() o.kit.keyGen.Reset() @@ -976,6 +984,7 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo if o.cache != nil && o.cache.variablesNormalizationCache != nil { entry := VariablesNormalizationCacheEntry{ uploadsMapping: uploadsMapping, + fieldArgumentMapping: fieldArgumentMapping, id: o.parsedOperation.ID, normalizedRepresentation: o.parsedOperation.NormalizedRepresentation, variables: o.parsedOperation.Request.Variables, @@ -984,14 +993,14 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo o.cache.variablesNormalizationCache.Set(cacheKey, entry, 1) } - return false, uploadsMapping, nil + return false, uploadsMapping, fieldArgumentMapping, nil } o.kit.normalizedOperation.Reset() err = o.kit.printer.Print(o.kit.doc, o.kit.normalizedOperation) if err != nil { - return false, nil, err + return false, nil, nil, err } o.parsedOperation.NormalizedRepresentation = o.kit.normalizedOperation.String() @@ -1000,6 +1009,7 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo if o.cache != nil && o.cache.variablesNormalizationCache != nil { entry := VariablesNormalizationCacheEntry{ uploadsMapping: uploadsMapping, + fieldArgumentMapping: fieldArgumentMapping, id: o.parsedOperation.ID, normalizedRepresentation: o.parsedOperation.NormalizedRepresentation, variables: o.parsedOperation.Request.Variables, @@ -1008,7 +1018,7 @@ func (o *OperationKit) NormalizeVariables() (cached bool, mapping []uploads.Uplo o.cache.variablesNormalizationCache.Set(cacheKey, entry, 1) } - return false, uploadsMapping, nil + return false, uploadsMapping, fieldArgumentMapping, nil } func (o *OperationKit) remapVariablesCacheKey(disabled bool) uint64 { @@ -1399,6 +1409,7 @@ type parseKitOptions struct { apolloCompatibilityFlags config.ApolloCompatibilityFlags apolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags disableExposingVariablesContentOnValidationError bool + enableFieldArgumentMapping bool } func createParseKit(i int, options *parseKitOptions) *parseKit { @@ -1414,7 +1425,7 @@ func createParseKit(i int, options *parseKitOptions) *parseKit { astnormalization.WithRemoveFragmentDefinitions(), astnormalization.WithRemoveUnusedVariables(), ), - variablesNormalizer: astnormalization.NewVariablesNormalizer(), + variablesNormalizer: astnormalization.NewVariablesNormalizer(options.enableFieldArgumentMapping), variablesRemapper: astnormalization.NewVariablesMapper(), printer: &astprinter.Printer{}, normalizedOperation: &bytes.Buffer{}, @@ -1453,6 +1464,7 @@ func NewOperationProcessor(opts OperationProcessorOptions) *OperationProcessor { apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, apolloRouterCompatibilityFlags: opts.ApolloRouterCompatibilityFlags, disableExposingVariablesContentOnValidationError: opts.DisableExposingVariablesContentOnValidationError, + enableFieldArgumentMapping: opts.EnableFieldArgumentMapping, }, } for i := 0; i < opts.ParseKitPoolSize; i++ { diff --git a/router/core/operation_processor_test.go b/router/core/operation_processor_test.go index 589fcdf50c..d92bdc6bd1 100644 --- a/router/core/operation_processor_test.go +++ b/router/core/operation_processor_test.go @@ -261,7 +261,7 @@ func TestNormalizeVariablesOperationProcessor(t *testing.T) { _, err = kit.NormalizeOperation("test", false) require.NoError(t, err) - _, _, err = kit.NormalizeVariables() + _, _, _, err = kit.NormalizeVariables() require.NoError(t, err) assert.Equal(t, tc.ExpectedNormalizedRepresentation, kit.parsedOperation.NormalizedRepresentation) diff --git a/router/core/router_config.go b/router/core/router_config.go index 319216a18a..afbe03db78 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -46,6 +46,10 @@ type onReceiveEventsHooks struct { timeout time.Duration } +func (h *subscriptionHooks) needFieldArgumentMapping() bool { + return len(h.onStart.handlers) > 0 || len(h.onPublishEvents.handlers) > 0 || len(h.onReceiveEvents.handlers) > 0 +} + type Config struct { clusterName string instanceID string diff --git a/router/core/websocket.go b/router/core/websocket.go index 35d6ffc304..573dedf6e6 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -63,6 +63,7 @@ type WebsocketMiddlewareOptions struct { ClientHeader config.ClientHeader DisableVariablesRemapping bool + MapFieldArguments bool ApolloCompatibilityFlags config.ApolloCompatibilityFlags } @@ -85,6 +86,7 @@ func NewWebsocketMiddleware(ctx context.Context, opts WebsocketMiddlewareOptions config: opts.WebSocketConfiguration, clientHeader: opts.ClientHeader, disableVariablesRemapping: opts.DisableVariablesRemapping, + mapFieldArguments: opts.MapFieldArguments, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, } if opts.WebSocketConfiguration != nil && opts.WebSocketConfiguration.AbsintheProtocol.Enabled { @@ -265,6 +267,7 @@ type WebsocketHandler struct { clientHeader config.ClientHeader disableVariablesRemapping bool + mapFieldArguments bool apolloCompatibilityFlags config.ApolloCompatibilityFlags } @@ -372,6 +375,7 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R ForwardUpgradeHeaders: h.forwardUpgradeHeadersConfig, ForwardQueryParams: h.forwardQueryParamsConfig, DisableVariablesRemapping: h.disableVariablesRemapping, + MapFieldArguments: h.mapFieldArguments, ApolloCompatibilityFlags: h.apolloCompatibilityFlags, }) err = handler.Initialize() @@ -713,6 +717,7 @@ type WebSocketConnectionHandlerOptions struct { ForwardUpgradeHeaders forwardConfig ForwardQueryParams forwardConfig DisableVariablesRemapping bool + MapFieldArguments bool ApolloCompatibilityFlags config.ApolloCompatibilityFlags } @@ -750,6 +755,7 @@ type WebSocketConnectionHandler struct { forwardQueryParams *forwardConfig disableVariablesRemapping bool + mapFieldArguments bool apolloCompatibilityFlags config.ApolloCompatibilityFlags @@ -791,6 +797,7 @@ func NewWebsocketConnectionHandler(ctx context.Context, opts WebSocketConnection forwardInitialPayload: opts.ForwardInitialPayload, plannerOptions: opts.PlanOptions, disableVariablesRemapping: opts.DisableVariablesRemapping, + mapFieldArguments: opts.MapFieldArguments, apolloCompatibilityFlags: opts.ApolloCompatibilityFlags, clientInfoFromInitialPayload: opts.ClientInfoFromInitialPayload, } @@ -907,7 +914,7 @@ func (h *WebSocketConnectionHandler) parseAndPlan(registration *SubscriptionRegi } opContext.normalizationCacheHit = operationKit.parsedOperation.NormalizationCacheHit - cached, _, err := operationKit.NormalizeVariables() + cached, _, fieldArgMapping, err := operationKit.NormalizeVariables() if err != nil { opContext.normalizationTime = time.Since(startNormalization) return nil, nil, err @@ -932,6 +939,13 @@ func (h *WebSocketConnectionHandler) parseAndPlan(registration *SubscriptionRegi return nil, nil, err } + if h.mapFieldArguments { + opContext.fieldArguments = NewArguments( + fieldArgMapping, + opContext.variables, + ) + } + startValidation := time.Now() _, _, err = operationKit.ValidateQueryComplexity() diff --git a/router/go.mod b/router/go.mod index 415918d050..43b24fdc15 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.245 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.246.0.20260202150435-e2c713b42e65 // Do not upgrade, it renames attributes we rely on go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 diff --git a/router/go.sum b/router/go.sum index 715b552f47..abc4185299 100644 --- a/router/go.sum +++ b/router/go.sum @@ -322,8 +322,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.245 h1:MYewlXgIhI9jusocPUeyo346J3M5cqzc6ddru1qp+S8= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.245/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.246.0.20260202150435-e2c713b42e65 h1:o5wqeMnGK2GdyTXlcLKZzPRrVtpzT55Iua5zyMT4ErU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.246.0.20260202150435-e2c713b42e65/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=