From 3a993b44a90a3a06a6f2463c69b16c17bf45ab48 Mon Sep 17 00:00:00 2001 From: "rajinder.saini" Date: Thu, 11 Dec 2025 11:55:02 -0800 Subject: [PATCH 1/4] rate limiter per client --- router/core/graph_server.go | 28 +++++-- router/core/ratelimiter.go | 62 +++++++++++--- router/core/ratelimiter_test.go | 139 ++++++++++++++++++++++++++++++-- router/pkg/config/config.go | 51 +++++++----- 4 files changed, 232 insertions(+), 48 deletions(-) diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 87aa96331c..6e345d88ad 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -114,6 +114,21 @@ type BuildGraphMuxOptions struct { RoutingUrlGroupings map[string]map[string]bool } +func rateLimitOverridesFromConfig(source map[string]config.RateLimitSimpleOverride) map[string]RateLimitOverride { + if len(source) == 0 { + return nil + } + converted := make(map[string]RateLimitOverride, len(source)) + for suffix, override := range source { + converted[suffix] = RateLimitOverride{ + Rate: override.Rate, + Burst: override.Burst, + Period: override.Period, + } + } + return converted +} + func (b BuildGraphMuxOptions) IsBaseGraph() bool { return b.FeatureFlagName == "" } @@ -519,12 +534,12 @@ type graphMux struct { validationCache *ristretto.Cache[uint64, bool] operationHashCache *ristretto.Cache[uint64, string] - accessLogsFileLogger *logging.BufferedLogger - metricStore rmetric.Store - prometheusCacheMetrics *rmetric.CacheMetrics - otelCacheMetrics *rmetric.CacheMetrics - streamMetricStore rmetric.StreamMetricStore - prometheusMetricsExporter *graphqlmetrics.PrometheusMetricsExporter + accessLogsFileLogger *logging.BufferedLogger + metricStore rmetric.Store + prometheusCacheMetrics *rmetric.CacheMetrics + otelCacheMetrics *rmetric.CacheMetrics + streamMetricStore rmetric.StreamMetricStore + prometheusMetricsExporter *graphqlmetrics.PrometheusMetricsExporter } // buildOperationCaches creates the caches for the graph mux. @@ -1404,6 +1419,7 @@ func (s *graphServer) buildGraphMux( RejectStatusCode: s.rateLimit.SimpleStrategy.RejectStatusCode, KeySuffixExpression: s.rateLimit.KeySuffixExpression, ExprManager: exprManager, + Overrides: rateLimitOverridesFromConfig(s.rateLimit.SimpleStrategy.Overrides), }) if err != nil { return nil, fmt.Errorf("failed to create rate limiter: %w", err) diff --git a/router/core/ratelimiter.go b/router/core/ratelimiter.go index c206c3a421..058da9affa 100644 --- a/router/core/ratelimiter.go +++ b/router/core/ratelimiter.go @@ -8,7 +8,9 @@ import ( "fmt" "io" "reflect" + "strings" "sync" + "time" rd "github.com/wundergraph/cosmo/router/internal/rediscloser" @@ -30,15 +32,32 @@ type CosmoRateLimiterOptions struct { KeySuffixExpression string ExprManager *expr.Manager + Overrides map[string]RateLimitOverride } func NewCosmoRateLimiter(opts *CosmoRateLimiterOptions) (rl *CosmoRateLimiter, err error) { limiter := redis_rate.NewLimiter(opts.RedisClient) + var overrides map[string]redis_rate.Limit + if len(opts.Overrides) > 0 { + overrides = make(map[string]redis_rate.Limit, len(opts.Overrides)) + for rawKey, override := range opts.Overrides { + key := strings.TrimSpace(rawKey) + if key == "" { + continue + } + overrides[key] = redis_rate.Limit{ + Rate: override.Rate, + Burst: override.Burst, + Period: override.Period, + } + } + } rl = &CosmoRateLimiter{ client: opts.RedisClient, limiter: limiter, debug: opts.Debug, rejectStatusCode: opts.RejectStatusCode, + keyOverrides: overrides, } if rl.rejectStatusCode == 0 { rl.rejectStatusCode = 200 @@ -54,12 +73,23 @@ func NewCosmoRateLimiter(opts *CosmoRateLimiterOptions) (rl *CosmoRateLimiter, e type CosmoRateLimiter struct { client rd.RDCloser - limiter *redis_rate.Limiter + limiter redisLimiter debug bool rejectStatusCode int keySuffixProgram *vm.Program + keyOverrides map[string]redis_rate.Limit +} + +type redisLimiter interface { + AllowN(ctx context.Context, key string, limit redis_rate.Limit, n int) (*redis_rate.Result, error) +} + +type RateLimitOverride struct { + Rate int + Burst int + Period time.Duration } func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve.FetchInfo, input json.RawMessage) (result *resolve.RateLimitDeny, err error) { @@ -67,15 +97,11 @@ func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve return nil, nil } requestRate := c.calculateRate() - limit := redis_rate.Limit{ - Rate: ctx.RateLimitOptions.Rate, - Burst: ctx.RateLimitOptions.Burst, - Period: ctx.RateLimitOptions.Period, - } - key, err := c.generateKey(ctx) + key, _, err := c.generateKey(ctx) if err != nil { return nil, err } + limit := c.limitFor(ctx, key) allow, err := c.limiter.AllowN(ctx.Context(), key, limit, requestRate) if err != nil { return nil, err @@ -90,23 +116,35 @@ func (c *CosmoRateLimiter) RateLimitPreFetch(ctx *resolve.Context, info *resolve return &resolve.RateLimitDeny{}, nil } -func (c *CosmoRateLimiter) generateKey(ctx *resolve.Context) (string, error) { +func (c *CosmoRateLimiter) generateKey(ctx *resolve.Context) (string, string, error) { if c.keySuffixProgram == nil { - return ctx.RateLimitOptions.RateLimitKey, nil + return ctx.RateLimitOptions.RateLimitKey, "", nil } rc := getRequestContext(ctx.Context()) if rc == nil { - return "", errors.New("no request context") + return "", "", errors.New("no request context") } str, err := expr.ResolveStringExpression(c.keySuffixProgram, rc.expressionContext) if err != nil { - return "", fmt.Errorf("failed to resolve key suffix expression: %w", err) + return "", "", fmt.Errorf("failed to resolve key suffix expression: %w", err) } buf := bytes.NewBuffer(make([]byte, 0, len(ctx.RateLimitOptions.RateLimitKey)+len(str)+1)) _, _ = buf.WriteString(ctx.RateLimitOptions.RateLimitKey) _ = buf.WriteByte(':') _, _ = buf.WriteString(str) - return buf.String(), nil + return buf.String(), str, nil +} + +func (c *CosmoRateLimiter) limitFor(ctx *resolve.Context, key string) redis_rate.Limit { + limit := redis_rate.Limit{ + Rate: ctx.RateLimitOptions.Rate, + Burst: ctx.RateLimitOptions.Burst, + Period: ctx.RateLimitOptions.Period, + } + if override, ok := c.keyOverrides[key]; ok { + return override + } + return limit } func (c *CosmoRateLimiter) RejectStatusCode() int { diff --git a/router/core/ratelimiter_test.go b/router/core/ratelimiter_test.go index f260ac1ae0..affb479d59 100644 --- a/router/core/ratelimiter_test.go +++ b/router/core/ratelimiter_test.go @@ -4,7 +4,9 @@ import ( "context" "net/http" "testing" + "time" + "github.com/go-redis/redis_rate/v10" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/internal/expr" @@ -36,14 +38,17 @@ func expressionResolveContext(t *testing.T, header http.Header, claims map[strin func TestRateLimiterGenerateKey(t *testing.T) { t.Parallel() + t.Run("default", func(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{}) assert.NoError(t, err) - key, err := rl.generateKey(expressionResolveContext(t, nil, nil)) + key, suffix, err := rl.generateKey(expressionResolveContext(t, nil, nil)) assert.NoError(t, err) assert.Equal(t, "test", key) + assert.Equal(t, "", suffix) }) + t.Run("from header", func(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ @@ -51,12 +56,14 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) require.NoError(t, err) - key, err := rl.generateKey( + key, suffix, err := rl.generateKey( expressionResolveContext(t, http.Header{"Authorization": []string{"token"}}, nil), ) assert.NoError(t, err) assert.Equal(t, "test:token", key) + assert.Equal(t, "token", suffix) }) + t.Run("from header number", func(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ @@ -64,12 +71,14 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( + key, suffix, err := rl.generateKey( expressionResolveContext(t, http.Header{"Authorization": []string{"123"}}, nil), ) assert.NoError(t, err) assert.Equal(t, "test:123", key) + assert.Equal(t, "123", suffix) }) + t.Run("from header whitespace", func(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ @@ -77,12 +86,14 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( + key, suffix, err := rl.generateKey( expressionResolveContext(t, http.Header{"Authorization": []string{" token "}}, nil), ) assert.NoError(t, err) assert.Equal(t, "test:token", key) + assert.Equal(t, "token", suffix) }) + t.Run("from claims", func(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ @@ -90,12 +101,14 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( + key, suffix, err := rl.generateKey( expressionResolveContext(t, nil, map[string]any{"sub": "token"}), ) assert.NoError(t, err) assert.Equal(t, "test:token", key) + assert.Equal(t, "token", suffix) }) + t.Run("from claims invalid claim", func(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ @@ -103,12 +116,14 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( + key, suffix, err := rl.generateKey( expressionResolveContext(t, nil, map[string]any{"sub": 123}), ) assert.Error(t, err) assert.Empty(t, key) + assert.Empty(t, suffix) }) + t.Run("from claims or X-Forwarded-For header claims present", func(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ @@ -116,12 +131,14 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( + key, suffix, err := rl.generateKey( expressionResolveContext(t, http.Header{"X-Forwarded-For": []string{"192.168.0.1"}}, map[string]any{"sub": "token"}), ) assert.NoError(t, err) assert.Equal(t, "test:token", key) + assert.Equal(t, "token", suffix) }) + t.Run("from claims or X-Forwarded-For header claims not present", func(t *testing.T) { t.Parallel() rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ @@ -129,12 +146,118 @@ func TestRateLimiterGenerateKey(t *testing.T) { ExprManager: expr.CreateNewExprManager(), }) assert.NoError(t, err) - key, err := rl.generateKey( + key, suffix, err := rl.generateKey( expressionResolveContext(t, http.Header{"X-Forwarded-For": []string{"192.168.0.1"}}, nil), ) assert.NoError(t, err) assert.Equal(t, "test:192.168.0.1", key) + assert.Equal(t, "192.168.0.1", suffix) + }) +} + +func TestRateLimiterOverrides(t *testing.T) { + t.Parallel() + + baseCtx := func(header http.Header) *resolve.Context { + ctx := expressionResolveContext(t, header, nil) + ctx.RateLimitOptions.RateLimitKey = "cosmo_rate_limit" + ctx.RateLimitOptions.Rate = 100 + ctx.RateLimitOptions.Burst = 50 + ctx.RateLimitOptions.Period = 2 * time.Second + return ctx + } + + info := &resolve.FetchInfo{RootFields: []resolve.GraphCoordinate{{TypeName: "Query", FieldName: "product"}}} + + t.Run("uses override when key matches", func(t *testing.T) { + t.Parallel() + overrideLimit := RateLimitOverride{Rate: 5, Burst: 10, Period: time.Second} + rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ + KeySuffixExpression: "request.header.Get('Authorization')", + ExprManager: expr.CreateNewExprManager(), + Overrides: map[string]RateLimitOverride{ + "cosmo_rate_limit:planA": overrideLimit, + }, + }) + require.NoError(t, err) + + fake := &fakeLimiter{result: &redis_rate.Result{Allowed: 1, Remaining: 9}} + rl.limiter = fake + + _, err = rl.RateLimitPreFetch(baseCtx(http.Header{"Authorization": []string{"planA"}}), info, nil) + require.NoError(t, err) + + expected := redis_rate.Limit{Rate: overrideLimit.Rate, Burst: overrideLimit.Burst, Period: overrideLimit.Period} + assert.Equal(t, "cosmo_rate_limit:planA", fake.lastKey) + assert.Equal(t, expected, fake.lastLimit) + }) + + t.Run("falls back to default when key missing", func(t *testing.T) { + t.Parallel() + rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ + KeySuffixExpression: "request.header.Get('Authorization')", + ExprManager: expr.CreateNewExprManager(), + Overrides: map[string]RateLimitOverride{ + "planA": {Rate: 5, Burst: 10, Period: time.Second}, + }, + }) + require.NoError(t, err) + + fake := &fakeLimiter{result: &redis_rate.Result{Allowed: 1, Remaining: 9}} + rl.limiter = fake + + ctx := baseCtx(http.Header{"Authorization": []string{"unknown"}}) + _, err = rl.RateLimitPreFetch(ctx, info, nil) + require.NoError(t, err) + + expected := redis_rate.Limit{Rate: ctx.RateLimitOptions.Rate, Burst: ctx.RateLimitOptions.Burst, Period: ctx.RateLimitOptions.Period} + assert.Equal(t, "cosmo_rate_limit:unknown", fake.lastKey) + assert.Equal(t, expected, fake.lastLimit) }) + + t.Run("uses client_id claim for suffix", func(t *testing.T) { + t.Parallel() + overrideLimit := RateLimitOverride{Rate: 1000, Burst: 1000, Period: time.Second} + rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ + KeySuffixExpression: "request.auth.claims.client_id", + ExprManager: expr.CreateNewExprManager(), + Overrides: map[string]RateLimitOverride{ + "cosmo_rate_limit:id_my_client": overrideLimit, + }, + }) + require.NoError(t, err) + + ctx := expressionResolveContext(t, nil, map[string]any{"client_id": "id_my_client"}) + ctx.RateLimitOptions.RateLimitKey = "cosmo_rate_limit" + ctx.RateLimitOptions.Rate = 100 + ctx.RateLimitOptions.Burst = 100 + ctx.RateLimitOptions.Period = time.Second + + fake := &fakeLimiter{result: &redis_rate.Result{Allowed: 1, Remaining: 999}} + rl.limiter = fake + + _, err = rl.RateLimitPreFetch(ctx, info, nil) + require.NoError(t, err) + + expected := redis_rate.Limit{Rate: overrideLimit.Rate, Burst: overrideLimit.Burst, Period: overrideLimit.Period} + assert.Equal(t, "cosmo_rate_limit:id_my_client", fake.lastKey) + assert.Equal(t, expected, fake.lastLimit) + }) +} + +type fakeLimiter struct { + lastKey string + lastLimit redis_rate.Limit + lastRequestRate int + result *redis_rate.Result + err error +} + +func (f *fakeLimiter) AllowN(ctx context.Context, key string, limit redis_rate.Limit, n int) (*redis_rate.Result, error) { + f.lastKey = key + f.lastLimit = limit + f.lastRequestRate = n + return f.result, f.err } func ContextWithClaims(ctx *resolve.Context, claims map[string]any) *resolve.Context { diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 4e430627c0..5135cfcd19 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -360,25 +360,25 @@ func (r *ResponseHeaderRule) GetMatching() string { } type EngineDebugConfiguration struct { - PrintOperationTransformations bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_OPERATION_TRANSFORMATIONS" yaml:"print_operation_transformations"` - PrintOperationEnableASTRefs bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_OPERATION_ENABLE_AST_REFS" yaml:"print_operation_enable_ast_refs"` - PrintPlanningPaths bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_PLANNING_PATHS" yaml:"print_planning_paths"` - PrintQueryPlans bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_QUERY_PLANS" yaml:"print_query_plans"` - PrintIntermediateQueryPlans bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_INTERMEDIATE_QUERY_PLANS" yaml:"print_intermediate_query_plans"` - PrintNodeSuggestions bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_NODE_SUGGESTIONS" yaml:"print_node_suggestions"` - ConfigurationVisitor bool `envDefault:"false" env:"ENGINE_DEBUG_CONFIGURATION_VISITOR" yaml:"configuration_visitor"` - PlanningVisitor bool `envDefault:"false" env:"ENGINE_DEBUG_PLANNING_VISITOR" yaml:"planning_visitor"` - DatasourceVisitor bool `envDefault:"false" env:"ENGINE_DEBUG_DATASOURCE_VISITOR" yaml:"datasource_visitor"` - ReportWebSocketConnections bool `envDefault:"false" env:"ENGINE_DEBUG_REPORT_WEBSOCKET_CONNECTIONS" yaml:"report_websocket_connections"` - ReportMemoryUsage bool `envDefault:"false" env:"ENGINE_DEBUG_REPORT_MEMORY_USAGE" yaml:"report_memory_usage"` - EnableResolverDebugging bool `envDefault:"false" env:"ENGINE_DEBUG_ENABLE_RESOLVER_DEBUGGING" yaml:"enable_resolver_debugging"` + PrintOperationTransformations bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_OPERATION_TRANSFORMATIONS" yaml:"print_operation_transformations"` + PrintOperationEnableASTRefs bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_OPERATION_ENABLE_AST_REFS" yaml:"print_operation_enable_ast_refs"` + PrintPlanningPaths bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_PLANNING_PATHS" yaml:"print_planning_paths"` + PrintQueryPlans bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_QUERY_PLANS" yaml:"print_query_plans"` + PrintIntermediateQueryPlans bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_INTERMEDIATE_QUERY_PLANS" yaml:"print_intermediate_query_plans"` + PrintNodeSuggestions bool `envDefault:"false" env:"ENGINE_DEBUG_PRINT_NODE_SUGGESTIONS" yaml:"print_node_suggestions"` + ConfigurationVisitor bool `envDefault:"false" env:"ENGINE_DEBUG_CONFIGURATION_VISITOR" yaml:"configuration_visitor"` + PlanningVisitor bool `envDefault:"false" env:"ENGINE_DEBUG_PLANNING_VISITOR" yaml:"planning_visitor"` + DatasourceVisitor bool `envDefault:"false" env:"ENGINE_DEBUG_DATASOURCE_VISITOR" yaml:"datasource_visitor"` + ReportWebSocketConnections bool `envDefault:"false" env:"ENGINE_DEBUG_REPORT_WEBSOCKET_CONNECTIONS" yaml:"report_websocket_connections"` + ReportMemoryUsage bool `envDefault:"false" env:"ENGINE_DEBUG_REPORT_MEMORY_USAGE" yaml:"report_memory_usage"` + EnableResolverDebugging bool `envDefault:"false" env:"ENGINE_DEBUG_ENABLE_RESOLVER_DEBUGGING" yaml:"enable_resolver_debugging"` // EnablePersistedOperationsCacheResponseHeader is deprecated, use EnableCacheResponseHeaders instead. EnablePersistedOperationsCacheResponseHeader bool `envDefault:"false" env:"ENGINE_DEBUG_ENABLE_PERSISTED_OPERATIONS_CACHE_RESPONSE_HEADER" yaml:"enable_persisted_operations_cache_response_header"` // EnableNormalizationCacheResponseHeader is deprecated, use EnableCacheResponseHeaders instead. - EnableNormalizationCacheResponseHeader bool `envDefault:"false" env:"ENGINE_DEBUG_ENABLE_NORMALIZATION_CACHE_RESPONSE_HEADER" yaml:"enable_normalization_cache_response_header"` - EnableCacheResponseHeaders bool `envDefault:"false" env:"ENGINE_DEBUG_ENABLE_CACHE_RESPONSE_HEADERS" yaml:"enable_cache_response_headers"` - AlwaysIncludeQueryPlan bool `envDefault:"false" env:"ENGINE_DEBUG_ALWAYS_INCLUDE_QUERY_PLAN" yaml:"always_include_query_plan"` - AlwaysSkipLoader bool `envDefault:"false" env:"ENGINE_DEBUG_ALWAYS_SKIP_LOADER" yaml:"always_skip_loader"` + EnableNormalizationCacheResponseHeader bool `envDefault:"false" env:"ENGINE_DEBUG_ENABLE_NORMALIZATION_CACHE_RESPONSE_HEADER" yaml:"enable_normalization_cache_response_header"` + EnableCacheResponseHeaders bool `envDefault:"false" env:"ENGINE_DEBUG_ENABLE_CACHE_RESPONSE_HEADERS" yaml:"enable_cache_response_headers"` + AlwaysIncludeQueryPlan bool `envDefault:"false" env:"ENGINE_DEBUG_ALWAYS_INCLUDE_QUERY_PLAN" yaml:"always_include_query_plan"` + AlwaysSkipLoader bool `envDefault:"false" env:"ENGINE_DEBUG_ALWAYS_SKIP_LOADER" yaml:"always_skip_loader"` } type EngineExecutionConfiguration struct { @@ -554,12 +554,19 @@ type RedisConfiguration struct { } type RateLimitSimpleStrategy struct { - Rate int `yaml:"rate" envDefault:"10" env:"RATE_LIMIT_SIMPLE_RATE"` - Burst int `yaml:"burst" envDefault:"10" env:"RATE_LIMIT_SIMPLE_BURST"` - Period time.Duration `yaml:"period" envDefault:"1s" env:"RATE_LIMIT_SIMPLE_PERIOD"` - RejectExceedingRequests bool `yaml:"reject_exceeding_requests" envDefault:"false" env:"RATE_LIMIT_SIMPLE_REJECT_EXCEEDING_REQUESTS"` - RejectStatusCode int `yaml:"reject_status_code" envDefault:"200" env:"RATE_LIMIT_SIMPLE_REJECT_STATUS_CODE"` - HideStatsFromResponseExtension bool `yaml:"hide_stats_from_response_extension" envDefault:"false" env:"RATE_LIMIT_SIMPLE_HIDE_STATS_FROM_RESPONSE_EXTENSION"` + Rate int `yaml:"rate" envDefault:"10" env:"RATE_LIMIT_SIMPLE_RATE"` + Burst int `yaml:"burst" envDefault:"10" env:"RATE_LIMIT_SIMPLE_BURST"` + Period time.Duration `yaml:"period" envDefault:"1s" env:"RATE_LIMIT_SIMPLE_PERIOD"` + RejectExceedingRequests bool `yaml:"reject_exceeding_requests" envDefault:"false" env:"RATE_LIMIT_SIMPLE_REJECT_EXCEEDING_REQUESTS"` + RejectStatusCode int `yaml:"reject_status_code" envDefault:"200" env:"RATE_LIMIT_SIMPLE_REJECT_STATUS_CODE"` + HideStatsFromResponseExtension bool `yaml:"hide_stats_from_response_extension" envDefault:"false" env:"RATE_LIMIT_SIMPLE_HIDE_STATS_FROM_RESPONSE_EXTENSION"` + Overrides map[string]RateLimitSimpleOverride `yaml:"overrides,omitempty" json:"overrides,omitempty"` +} + +type RateLimitSimpleOverride struct { + Rate int `yaml:"rate"` + Burst int `yaml:"burst"` + Period time.Duration `yaml:"period"` } type CDNConfiguration struct { From 1e1b8641a16f3536091465a72448e1623e909d3a Mon Sep 17 00:00:00 2001 From: "rajinder.saini" Date: Thu, 11 Dec 2025 12:02:24 -0800 Subject: [PATCH 2/4] test with mini redis --- router/core/ratelimiter_test.go | 36 +++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/router/core/ratelimiter_test.go b/router/core/ratelimiter_test.go index affb479d59..f79f2e4624 100644 --- a/router/core/ratelimiter_test.go +++ b/router/core/ratelimiter_test.go @@ -6,7 +6,9 @@ import ( "testing" "time" + "github.com/alicebob/miniredis/v2" "github.com/go-redis/redis_rate/v10" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/internal/expr" @@ -245,6 +247,40 @@ func TestRateLimiterOverrides(t *testing.T) { }) } +func TestRateLimiterOverrideEndToEnd(t *testing.T) { + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { + _ = client.Close() + }) + + rl, err := NewCosmoRateLimiter(&CosmoRateLimiterOptions{ + RedisClient: client, + KeySuffixExpression: "request.auth.claims.client_id", + ExprManager: expr.CreateNewExprManager(), + Overrides: map[string]RateLimitOverride{ + "cosmo_rate_limit:id_my_client": {Rate: 1, Burst: 1, Period: time.Second}, + }, + }) + require.NoError(t, err) + + ctx := expressionResolveContext(t, nil, map[string]any{"client_id": "id_my_client"}) + ctx.RateLimitOptions.RateLimitKey = "cosmo_rate_limit" + ctx.RateLimitOptions.Rate = 5 + ctx.RateLimitOptions.Burst = 5 + ctx.RateLimitOptions.Period = time.Second + + info := &resolve.FetchInfo{RootFields: []resolve.GraphCoordinate{{TypeName: "Query", FieldName: "product"}}} + + result, err := rl.RateLimitPreFetch(ctx, info, nil) + require.NoError(t, err) + assert.Nil(t, result) + + result, err = rl.RateLimitPreFetch(ctx, info, nil) + require.NoError(t, err) + assert.NotNil(t, result) +} + type fakeLimiter struct { lastKey string lastLimit redis_rate.Limit From 1b08158276e6a7227e3e0141bce361cee27d77e8 Mon Sep 17 00:00:00 2001 From: "rajinder.saini" Date: Thu, 11 Dec 2025 15:48:16 -0800 Subject: [PATCH 3/4] update router config schema --- router/pkg/config/config.schema.json | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 0b5fc698c3..eac56563db 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1969,6 +1969,27 @@ "type": "boolean", "default": false, "description": "Hide the rate limit stats from the response extension. If the value is true, the rate limit stats are not included in the response extension." + }, + "overrides": { + "type": "object", + "description": "Override the default rate limit for specific keys. The key is the full Redis key (prefix:suffix), and the value is the rate limit configuration for that key.", + "additionalProperties": { + "type": "object", + "required": ["rate", "burst", "period"], + "properties": { + "rate": { + "type": "integer", + "minimum": 1 + }, + "burst": { + "type": "integer", + "minimum": 1 + }, + "period": { + "type": "string" + } + } + } } }, "required": ["rate", "burst", "period"] From cc81d9450d3cbccb742d15c11a7ee72eef23adf2 Mon Sep 17 00:00:00 2001 From: rajinder Date: Thu, 11 Dec 2025 16:00:12 -0800 Subject: [PATCH 4/4] Update router/pkg/config/config.schema.json Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- router/pkg/config/config.schema.json | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index eac56563db..91cd680a7a 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1986,7 +1986,10 @@ "minimum": 1 }, "period": { - "type": "string" + "type": "string", + "duration": { + "minimum": "1s" + } } } }