diff --git a/core/pkg/evaluator/json.go b/core/pkg/evaluator/json.go index 45a13175d..b133142e6 100644 --- a/core/pkg/evaluator/json.go +++ b/core/pkg/evaluator/json.go @@ -335,8 +335,7 @@ func (je *Resolver) evaluateVariant(ctx context.Context, reqID string, flagKey s ) { var selector store.Selector - s := ctx.Value(store.SelectorContextKey{}) - if s != nil { + if s := ctx.Value(store.SelectorContextKey{}); s != nil { selector = s.(store.Selector) } flag, metadata, err := je.store.Get(ctx, flagKey, &selector) @@ -358,73 +357,89 @@ func (je *Resolver) evaluateVariant(ctx context.Context, reqID string, flagKey s return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.FlagDisabledErrorCode) } - // get the targeting logic, if any - targeting := flag.Targeting + if flag.Targeting != nil && string(flag.Targeting) != "{}" { + variant, reason, err := je.evaluateTargeting(ctx, reqID, flagKey, flag, evalCtx) + return variant, flag.Variants, reason, metadata, err + } - if targeting != nil && string(targeting) != "{}" { - targetingBytes, err := targeting.MarshalJSON() - if err != nil { - je.Logger.ErrorWithID(reqID, fmt.Sprintf("Error parsing rules for flag: %s, %s", flagKey, err)) - return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.ParseErrorCode) - } + variant, reason, err = je.resolveDefaultVariant(ctx, flag) + return variant, flag.Variants, reason, metadata, err +} - evalCtx = setFlagdProperties(je.Logger, evalCtx, flagdProperties{ - FlagKey: flagKey, - Timestamp: time.Now().Unix(), - }) +func (je *Resolver) evaluateTargeting( + ctx context.Context, + reqID string, + flagKey string, + flag model.Flag, + evalCtx map[string]any, +) (variant string, reason string, err error) { + targetingBytes, err := flag.Targeting.MarshalJSON() + if err != nil { + je.Logger.ErrorWithID(reqID, fmt.Sprintf("Error parsing rules for flag: %s, %s", flagKey, err)) + return "", model.ErrorReason, errors.New(model.ParseErrorCode) + } - b, err := json.Marshal(evalCtx) - if err != nil { - je.Logger.ErrorWithID(reqID, fmt.Sprintf("error parsing context for flag: %s, %s, %v", flagKey, err, evalCtx)) + evalCtx = setFlagdProperties(je.Logger, evalCtx, flagdProperties{ + FlagKey: flagKey, + Timestamp: time.Now().Unix(), + }) - return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.ErrorReason) - } + b, err := json.Marshal(evalCtx) + if err != nil { + je.Logger.ErrorWithID(reqID, fmt.Sprintf("error parsing context for flag: %s, %s, %v", flagKey, err, evalCtx)) + return "", model.ErrorReason, errors.New(model.ErrorReason) + } - var result bytes.Buffer - // evaluate JsonLogic rules to determine the variant - err = jsonlogic.Apply(bytes.NewReader(targetingBytes), bytes.NewReader(b), &result) - if err != nil { - je.Logger.ErrorWithID(reqID, fmt.Sprintf("error applying targeting rules: %s", err)) - return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.ParseErrorCode) - } + var result bytes.Buffer + // evaluate JsonLogic rules to determine the variant + if err = jsonlogic.Apply(bytes.NewReader(targetingBytes), bytes.NewReader(b), &result); err != nil { + je.Logger.ErrorWithID(reqID, fmt.Sprintf("error applying targeting rules: %s", err)) + return "", model.ErrorReason, errors.New(model.ParseErrorCode) + } - // check if string is "null" before we strip quotes, so we can differentiate between JSON null and "null" - trimmed := strings.TrimSpace(result.String()) + // check if string is "null" before we strip quotes, so we can differentiate between JSON null and "null" + trimmed := strings.TrimSpace(result.String()) - if trimmed == "null" { - if flag.DefaultVariant == "" { - if ctx.Value(ProtoVersionKey) != nil { - // old proto version behavior - return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.FlagNotFoundErrorCode) - } + if trimmed == "null" { + return je.handleNullTargetingResult(ctx, flag) + } - return "", flag.Variants, model.FallbackReason, metadata, nil - } + // strip whitespace and quotes from the variant + variant = strings.ReplaceAll(trimmed, "\"", "") - return flag.DefaultVariant, flag.Variants, model.DefaultReason, metadata, nil - } + // if this is a valid variant, return it + if _, ok := flag.Variants[variant]; ok { + return variant, model.TargetingMatchReason, nil + } - // strip whitespace and quotes from the variant - variant = strings.ReplaceAll(trimmed, "\"", "") + je.Logger.ErrorWithID(reqID, + fmt.Sprintf("invalid or missing variant: %s for flagKey: %s, variant is not valid", variant, flagKey)) + return "", model.ErrorReason, errors.New(model.GeneralErrorCode) +} - // if this is a valid variant, return it - if _, ok := flag.Variants[variant]; ok { - return variant, flag.Variants, model.TargetingMatchReason, metadata, nil +func (je *Resolver) handleNullTargetingResult(ctx context.Context, flag model.Flag) (string, string, error) { + if flag.DefaultVariant == "" { + if ctx.Value(ProtoVersionKey) != nil { + // old proto version behavior + return "", model.ErrorReason, errors.New(model.FlagNotFoundErrorCode) } - je.Logger.ErrorWithID(reqID, - fmt.Sprintf("invalid or missing variant: %s for flagKey: %s, variant is not valid", variant, flagKey)) - return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.GeneralErrorCode) + + return "", model.FallbackReason, nil } + return flag.DefaultVariant, model.DefaultReason, nil +} + +func (je *Resolver) resolveDefaultVariant(ctx context.Context, flag model.Flag) (string, string, error) { if flag.DefaultVariant == "" { if ctx.Value(ProtoVersionKey) != nil { // old proto version behavior - return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.FlagNotFoundErrorCode) + return "", model.ErrorReason, errors.New(model.FlagNotFoundErrorCode) } - return "", flag.Variants, model.FallbackReason, metadata, nil + return "", model.FallbackReason, nil } - return flag.DefaultVariant, flag.Variants, model.StaticReason, metadata, nil + return flag.DefaultVariant, model.StaticReason, nil } func setFlagdProperties( diff --git a/core/pkg/sync/grpc/nameresolvers/envoy_resolver.go b/core/pkg/sync/grpc/nameresolvers/envoy_resolver.go index 74dff8f31..e01f19941 100644 --- a/core/pkg/sync/grpc/nameresolvers/envoy_resolver.go +++ b/core/pkg/sync/grpc/nameresolvers/envoy_resolver.go @@ -52,9 +52,13 @@ func (r *envoyResolver) start() { } } -func (*envoyResolver) ResolveNow(resolver.ResolveNowOptions) {} +func (*envoyResolver) ResolveNow(resolver.ResolveNowOptions) { + // no-op: the resolver relies on static configuration provided during construction. +} -func (*envoyResolver) Close() {} +func (*envoyResolver) Close() { + // no-op: there are no resources to release for the static resolver. +} // Validate user specified target // diff --git a/core/pkg/sync/http/http_sync_test.go b/core/pkg/sync/http/http_sync_test.go index 45d1409d4..20213c510 100644 --- a/core/pkg/sync/http/http_sync_test.go +++ b/core/pkg/sync/http/http_sync_test.go @@ -401,25 +401,33 @@ func TestHTTPSync_Resync(t *testing.T) { if !tt.wantErr && err != nil { t.Errorf("got error for %s %s", name, err.Error()) } - for _, dataSync := range tt.wantNotifications { - select { - case x := <-d: - if x.FlagData != dataSync.FlagData || x.Source != dataSync.Source { - t.Errorf("unexpected datasync received %v vs %v", x, dataSync) - } - case <-time.After(2 * time.Second): - t.Error("expected datasync not received", dataSync) - } - } - select { - case x := <-d: - t.Error("unexpected datasync received", x) - case <-time.After(2 * time.Second): - } + assertDataSyncsDelivered(t, d, tt.wantNotifications) + assertNoUnexpectedDataSync(t, d) }) } } +func assertDataSyncsDelivered(t *testing.T, queue chan sync.DataSync, expected []sync.DataSync) { + for _, dataSync := range expected { + select { + case x := <-queue: + if x.FlagData != dataSync.FlagData || x.Source != dataSync.Source { + t.Errorf("unexpected datasync received %v vs %v", x, dataSync) + } + case <-time.After(2 * time.Second): + t.Error("expected datasync not received", dataSync) + } + } +} + +func assertNoUnexpectedDataSync(t *testing.T, queue chan sync.DataSync) { + select { + case x := <-queue: + t.Error("unexpected datasync received", x) + case <-time.After(2 * time.Second): + } +} + func TestHTTPSync_getClient(t *testing.T) { oauth := &sync.OAuthCredentialHandler{ ClientID: "myClientID", diff --git a/flagd/pkg/service/constants.go b/flagd/pkg/service/constants.go index da9e008c5..eb6d3e255 100644 --- a/flagd/pkg/service/constants.go +++ b/flagd/pkg/service/constants.go @@ -1,3 +1,5 @@ package service const FLAGD_SELECTOR_HEADER = "Flagd-Selector" + +const FLAG_SELECTOR_HEADER = "flag-selector" diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator.go b/flagd/pkg/service/flag-evaluation/flag_evaluator.go index ff376ebfb..365fec88f 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator.go @@ -75,11 +75,15 @@ func (s *OldFlagEvaluationService) ResolveAll( Flags: make(map[string]*schemaV1.AnyFlag), } - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) - ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) + selectorExpression := flagdService.SelectorExpressionFromHTTPHeaders(req.Header()) - values, _, err := s.eval.ResolveAllValues(ctx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string))) + values, _, err := ResolveAllWithSelectorMerge( + ctx, + reqID, + s.eval, + mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)), + selectorExpression, + ) if err != nil { s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err)) return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID) diff --git a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go index 747a8742b..3966532fd 100644 --- a/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go +++ b/flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go @@ -75,13 +75,11 @@ func (s *FlagEvaluationService) ResolveAll( Flags: make(map[string]*evalV1.AnyFlag), } - selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) + selectorExpression := flagdService.SelectorExpressionFromHTTPHeaders(req.Header()) evaluationContext := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings) - ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector) ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1") - resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(ctx, reqID, evaluationContext) + resolutions, flagSetMetadata, err := ResolveAllWithSelectorMerge(ctx, reqID, s.eval, evaluationContext, selectorExpression) if err != nil { s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err)) return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID) diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler.go b/flagd/pkg/service/flag-evaluation/ofrep/handler.go index ac3f2ddf8..491dc78a9 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler.go @@ -96,7 +96,7 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) { } } evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings) - selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER) + selectorExpression := service.SelectorExpressionFromHTTPHeaders(r.Header) selector := store.NewSelector(selectorExpression) ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector) @@ -121,11 +121,10 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) { } evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings) - selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER) - selector := store.NewSelector(selectorExpression) - ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector) + selectorExpression := service.SelectorExpressionFromHTTPHeaders(r.Header) + ctx := r.Context() - evaluations, metadata, err := h.evaluator.ResolveAllValues(ctx, requestID, evaluationContext) + evaluations, metadata, err := evalservice.ResolveAllWithSelectorMerge(ctx, requestID, h.evaluator, evaluationContext, selectorExpression) if err != nil { h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err)) diff --git a/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go b/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go index 3ae4d4635..f791b9d1e 100644 --- a/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go +++ b/flagd/pkg/service/flag-evaluation/ofrep/handler_test.go @@ -2,6 +2,7 @@ package ofrep import ( "bytes" + "context" "encoding/json" "errors" "io" @@ -16,6 +17,9 @@ import ( "github.com/open-feature/flagd/core/pkg/logger" "github.com/open-feature/flagd/core/pkg/model" "github.com/open-feature/flagd/core/pkg/service/ofrep" + "github.com/open-feature/flagd/core/pkg/store" + service "github.com/open-feature/flagd/flagd/pkg/service" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" ) @@ -270,6 +274,44 @@ func Test_handler_HandleBulkEvaluation(t *testing.T) { } } +func TestHandlerHandleBulkEvaluationUsesFlagSelectorHeader(t *testing.T) { + log := logger.NewLogger(nil, false) + eval := mock.NewMockIEvaluator(gomock.NewController(t)) + expectedOrder := []string{"'source=A'", "'source=C'", "'source=B'"} + callCount := 0 + + eval.EXPECT().ResolveAllValues(gomock.Any(), gomock.Any(), gomock.Any()).Times(3).DoAndReturn( + func(ctx context.Context, _ string, _ map[string]any) ([]evaluator.AnyValue, model.Metadata, error) { + selector, ok := ctx.Value(store.SelectorContextKey{}).(store.Selector) + if !ok { + t.Fatalf("selector not found in context") + } + + if callCount >= len(expectedOrder) { + t.Fatalf("unexpected extra selector call") + } + if selector.ToLogString() != expectedOrder[callCount] { + t.Fatalf("unexpected selector at call %d: %s", callCount, selector.ToLogString()) + } + callCount++ + + return []evaluator.AnyValue{successValue}, model.Metadata{}, nil + }, + ) + + h := handler{Logger: log, evaluator: eval} + request, err := http.NewRequest(http.MethodPost, "/ofrep/v1/evaluate/flags", bytes.NewReader([]byte("{}"))) + require.NoError(t, err) + request.Header.Set(service.FLAG_SELECTOR_HEADER, "A,C,B") + + recorder := httptest.NewRecorder() + router := mux.NewRouter() + router.HandleFunc(bulkEvaluation, h.HandleBulkEvaluation) + router.ServeHTTP(recorder, request) + + require.Equal(t, http.StatusOK, recorder.Code) +} + func TestWriteJSONResponse(t *testing.T) { log := logger.NewLogger(nil, false) h := handler{Logger: log} diff --git a/flagd/pkg/service/flag-evaluation/selector_merge.go b/flagd/pkg/service/flag-evaluation/selector_merge.go new file mode 100644 index 000000000..3ca6bdb03 --- /dev/null +++ b/flagd/pkg/service/flag-evaluation/selector_merge.go @@ -0,0 +1,108 @@ +package service + +import ( + "context" + "sort" + "strings" + + "github.com/open-feature/flagd/core/pkg/evaluator" + "github.com/open-feature/flagd/core/pkg/model" + "github.com/open-feature/flagd/core/pkg/store" +) + +func ResolveAllWithSelectorMerge( + ctx context.Context, + reqID string, + eval evaluator.IEvaluator, + evaluationContext map[string]any, + selectorExpression string, +) ([]evaluator.AnyValue, model.Metadata, error) { + selectors := splitSelectorExpression(selectorExpression) + + switch len(selectors) { + case 0: + return eval.ResolveAllValues(ctx, reqID, evaluationContext) + case 1: + return resolveWithSingleSelector(ctx, reqID, eval, evaluationContext, selectors[0]) + default: + return resolveWithMultipleSelectors(ctx, reqID, eval, evaluationContext, selectors) + } +} + +func resolveWithSingleSelector( + ctx context.Context, + reqID string, + eval evaluator.IEvaluator, + evaluationContext map[string]any, + selectorExpression string, +) ([]evaluator.AnyValue, model.Metadata, error) { + selector := store.NewSelector(selectorExpression) + selectorCtx := context.WithValue(ctx, store.SelectorContextKey{}, selector) + return eval.ResolveAllValues(selectorCtx, reqID, evaluationContext) +} + +func resolveWithMultipleSelectors( + ctx context.Context, + reqID string, + eval evaluator.IEvaluator, + evaluationContext map[string]any, + selectors []string, +) ([]evaluator.AnyValue, model.Metadata, error) { + mergedValues := map[string]evaluator.AnyValue{} + mergedMetadata := model.Metadata{} + + for _, selectorExpression := range selectors { + selector := store.NewSelector(selectorExpression) + selectorCtx := context.WithValue(ctx, store.SelectorContextKey{}, selector) + values, metadata, err := eval.ResolveAllValues(selectorCtx, reqID, evaluationContext) + if err != nil { + return nil, nil, err + } + + mergeMetadata(mergedMetadata, metadata) + for _, value := range values { + mergedValues[value.FlagKey] = value + } + } + + resolutions := flattenMergedValues(mergedValues) + return resolutions, mergedMetadata, nil +} + +func mergeMetadata(dest, src model.Metadata) { + for key, value := range src { + dest[key] = value + } +} + +func flattenMergedValues(merged map[string]evaluator.AnyValue) []evaluator.AnyValue { + keys := make([]string, 0, len(merged)) + for key := range merged { + keys = append(keys, key) + } + sort.Strings(keys) + + resolutions := make([]evaluator.AnyValue, 0, len(keys)) + for _, key := range keys { + resolutions = append(resolutions, merged[key]) + } + + return resolutions +} + +func splitSelectorExpression(selectorExpression string) []string { + if strings.TrimSpace(selectorExpression) == "" { + return nil + } + + parts := strings.Split(selectorExpression, ",") + selectors := make([]string, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed == "" { + continue + } + selectors = append(selectors, trimmed) + } + return selectors +} diff --git a/flagd/pkg/service/flag-sync/handler.go b/flagd/pkg/service/flag-sync/handler.go index 559a33f44..3517c4089 100644 --- a/flagd/pkg/service/flag-sync/handler.go +++ b/flagd/pkg/service/flag-sync/handler.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "maps" + "sort" + "strings" "time" "github.com/open-feature/flagd/core/pkg/model" @@ -64,25 +66,56 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F watcher := make(chan store.FlagQueryResult, 1) selector := store.NewSelector(selectorExpression) ctx := server.Context() - - syncContextMap := make(map[string]any) - maps.Copy(syncContextMap, s.contextValues) - syncContext, err := structpb.NewStruct(syncContextMap) + syncContext, err := s.buildSyncContext() if err != nil { exitReason = "error" return status.Error(codes.DataLoss, "error constructing sync context") } - // attach server-side stream deadline to context - if s.deadline != 0 { - streamDeadline := time.Now().Add(s.deadline) - deadlineCtx, cancel := context.WithDeadline(ctx, streamDeadline) - ctx = deadlineCtx + var cancel context.CancelFunc + ctx, cancel = s.withDeadline(ctx) + if cancel != nil { defer cancel() } - s.store.Watch(ctx, &selector, watcher) + s.watchSelectors(ctx, selectors, watcher) + return s.streamFlagUpdates(ctx, selectors, watcher, syncContext, server) +} + +func (s syncHandler) buildSyncContext() (*structpb.Struct, error) { + syncContextMap := make(map[string]any) + maps.Copy(syncContextMap, s.contextValues) + return structpb.NewStruct(syncContextMap) +} + +func (s syncHandler) withDeadline(ctx context.Context) (context.Context, context.CancelFunc) { + if s.deadline == 0 { + return ctx, nil + } + streamDeadline := time.Now().Add(s.deadline) + return context.WithDeadline(ctx, streamDeadline) +} + +func (s syncHandler) watchSelectors(ctx context.Context, selectors []store.Selector, watcher chan store.FlagQueryResult) { + switch len(selectors) { + case 0: + s.store.Watch(ctx, nil, watcher) + case 1: + s.store.Watch(ctx, &selectors[0], watcher) + default: + // For multi-selector requests, watch all updates and recompute merged view in order. + s.store.Watch(ctx, nil, watcher) + } +} + +func (s syncHandler) streamFlagUpdates( + ctx context.Context, + selectors []store.Selector, + watcher chan store.FlagQueryResult, + syncContext *structpb.Struct, + server syncv1grpc.FlagSyncService_SyncFlagsServer, +) error { for { select { case payload := <-watcher: @@ -112,6 +145,13 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F } } +func (s syncHandler) resolveFlagsForSelectors(ctx context.Context, selectors []store.Selector, payloadFlags []model.Flag) ([]model.Flag, error) { + if len(selectors) == 1 { + return payloadFlags, nil + } + return s.fetchMergedFlags(ctx, selectors) +} + func (s syncHandler) generateResponse(payload []model.Flag) ([]byte, error) { flagConfig := map[string]interface{}{ "flags": s.convertMap(payload), @@ -130,8 +170,7 @@ func (s syncHandler) generateResponse(payload []model.Flag) ([]byte, error) { func (s syncHandler) getSelectorExpression(ctx context.Context, req interface{}) string { // Try to get selector from metadata (header) if md, ok := metadata.FromIncomingContext(ctx); ok { - if values := md.Get(flagdService.FLAGD_SELECTOR_HEADER); len(values) > 0 { - headerSelector := values[0] + if headerSelector := flagdService.SelectorExpressionFromGRPCMetadata(md); headerSelector != "" { s.log.Debug(fmt.Sprintf("using selector from request header: %s", headerSelector)) return headerSelector } @@ -166,8 +205,7 @@ func (s syncHandler) FetchAllFlags(ctx context.Context, req *syncv1.FetchAllFlag *syncv1.FetchAllFlagsResponse, error, ) { selectorExpression := s.getSelectorExpression(ctx, req) - selector := store.NewSelector(selectorExpression) - flags, _, err := s.store.GetAll(ctx, &selector) + flags, err := s.fetchMergedFlags(ctx, parseSelectorList(selectorExpression)) if err != nil { s.log.Error(fmt.Sprintf("error retrieving flags from store: %v", err)) return nil, status.Error(codes.Internal, "error retrieving flags from store") @@ -184,8 +222,66 @@ func (s syncHandler) FetchAllFlags(ctx context.Context, req *syncv1.FetchAllFlag }, nil } +func parseSelectorList(selectorExpression string) []store.Selector { + if strings.TrimSpace(selectorExpression) == "" { + return nil + } + + parts := strings.Split(selectorExpression, ",") + selectors := make([]store.Selector, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed == "" { + continue + } + selector := store.NewSelector(trimmed) + selectors = append(selectors, selector) + } + return selectors +} + +func (s syncHandler) fetchMergedFlags(ctx context.Context, selectors []store.Selector) ([]model.Flag, error) { + switch len(selectors) { + case 0: + flags, _, err := s.store.GetAll(ctx, nil) + return flags, err + case 1: + flags, _, err := s.store.GetAll(ctx, &selectors[0]) + return flags, err + default: + type flagIdentifier struct { + flagSetID string + key string + } + + merged := map[flagIdentifier]model.Flag{} + for _, selector := range selectors { + flags, _, err := s.store.GetAll(ctx, &selector) + if err != nil { + return nil, err + } + for _, flag := range flags { + merged[flagIdentifier{flagSetID: flag.FlagSetId, key: flag.Key}] = flag + } + } + + out := make([]model.Flag, 0, len(merged)) + for _, flag := range merged { + out = append(out, flag) + } + sort.Slice(out, func(i, j int) bool { + if out[i].FlagSetId != out[j].FlagSetId { + return out[i].FlagSetId < out[j].FlagSetId + } + return out[i].Key < out[j].Key + }) + return out, nil + } +} + // Deprecated - GetMetadata is deprecated and will be removed in a future release. // Use the sync_context field in syncv1.SyncFlagsResponse, providing same info. +// //nolint:staticcheck // SA1019 temporarily suppress deprecation warning func (s syncHandler) GetMetadata(_ context.Context, _ *syncv1.GetMetadataRequest) ( *syncv1.GetMetadataResponse, error, diff --git a/flagd/pkg/service/flag-sync/handler_test.go b/flagd/pkg/service/flag-sync/handler_test.go index 85af91eb8..3eca5275f 100644 --- a/flagd/pkg/service/flag-sync/handler_test.go +++ b/flagd/pkg/service/flag-sync/handler_test.go @@ -19,6 +19,34 @@ import ( "google.golang.org/grpc/metadata" ) +type mockMergeStore struct { + flagsBySource map[string][]model.Flag +} + +func (m *mockMergeStore) Get(_ context.Context, _ string, _ *store.Selector) (model.Flag, model.Metadata, error) { + return model.Flag{}, model.Metadata{}, nil +} + +func (m *mockMergeStore) GetAll(_ context.Context, selector *store.Selector) ([]model.Flag, model.Metadata, error) { + if selector == nil || selector.IsEmpty() { + out := []model.Flag{} + for _, flags := range m.flagsBySource { + out = append(out, flags...) + } + return out, model.Metadata{}, nil + } + + source, _ := selector.ToMetadata()["source"].(string) + return append([]model.Flag(nil), m.flagsBySource[source]...), model.Metadata{}, nil +} + +func (m *mockMergeStore) Watch(_ context.Context, _ *store.Selector, _ chan<- store.FlagQueryResult) { + // no-op: the mock does not track watches for the current tests. +} +func (m *mockMergeStore) Update(_ string, _ []model.Flag, _ model.Metadata) { + // intentionally empty because test coverage does not exercise store updates. +} + func TestSyncHandler_SyncFlags(t *testing.T) { tests := []struct { name string @@ -258,3 +286,57 @@ func TestSyncHandler_SelectorLocationPrecedence(t *testing.T) { }) } } + +func TestSyncHandlerFetchAllFlagsMultiSelectorOverrideOrder(t *testing.T) { + handler := syncHandler{ + store: &mockMergeStore{ + flagsBySource: map[string][]model.Flag{ + "A": { + {Key: "shared", FlagSetId: "set", Source: "A", DefaultVariant: "a", State: "ENABLED", Variants: testVariants}, + }, + "B": { + {Key: "shared", FlagSetId: "set", Source: "B", DefaultVariant: "b", State: "ENABLED", Variants: testVariants}, + }, + "C": { + {Key: "shared", FlagSetId: "set", Source: "C", DefaultVariant: "c", State: "ENABLED", Variants: testVariants}, + }, + }, + }, + log: logger.NewLogger(nil, false), + contextValues: map[string]any{}, + } + + md := metadata.New(map[string]string{ + flagdService.FLAG_SELECTOR_HEADER: "A,C,B", + }) + ctx := metadata.NewIncomingContext(context.Background(), md) + + resp, err := handler.FetchAllFlags(ctx, &syncv1.FetchAllFlagsRequest{}) + require.NoError(t, err) + assert.Contains(t, resp.FlagConfiguration, "\"source\":\"B\"") + assert.Contains(t, resp.FlagConfiguration, "\"defaultVariant\":\"b\"") +} + +func TestSyncHandlerFetchMergedFlagsOrderOverride(t *testing.T) { + handler := syncHandler{ + store: &mockMergeStore{ + flagsBySource: map[string][]model.Flag{ + "A": { + {Key: "shared", FlagSetId: "set", Source: "A", DefaultVariant: "a"}, + }, + "B": { + {Key: "shared", FlagSetId: "set", Source: "B", DefaultVariant: "b"}, + }, + "C": { + {Key: "shared", FlagSetId: "set", Source: "C", DefaultVariant: "c"}, + }, + }, + }, + } + + flags, err := handler.fetchMergedFlags(context.Background(), parseSelectorList("A,C,B")) + require.NoError(t, err) + require.Len(t, flags, 1) + assert.Equal(t, "B", flags[0].Source) + assert.Equal(t, "b", flags[0].DefaultVariant) +} diff --git a/flagd/pkg/service/selector.go b/flagd/pkg/service/selector.go new file mode 100644 index 000000000..719300a45 --- /dev/null +++ b/flagd/pkg/service/selector.go @@ -0,0 +1,28 @@ +package service + +import ( + "net/http" + "strings" + + "google.golang.org/grpc/metadata" +) + +func SelectorExpressionFromHTTPHeaders(headers http.Header) string { + if selectors := strings.TrimSpace(strings.Join(headers.Values(FLAG_SELECTOR_HEADER), ",")); selectors != "" { + return selectors + } + return strings.TrimSpace(strings.Join(headers.Values(FLAGD_SELECTOR_HEADER), ",")) +} + +func SelectorExpressionFromGRPCMetadata(md metadata.MD) string { + if selectors := strings.TrimSpace(strings.Join(md.Get(strings.ToLower(FLAG_SELECTOR_HEADER)), ",")); selectors != "" { + return selectors + } + if selectors := strings.TrimSpace(strings.Join(md.Get(strings.ToLower(FLAGD_SELECTOR_HEADER)), ",")); selectors != "" { + return selectors + } + if selectors := strings.TrimSpace(strings.Join(md.Get(FLAG_SELECTOR_HEADER), ",")); selectors != "" { + return selectors + } + return strings.TrimSpace(strings.Join(md.Get(FLAGD_SELECTOR_HEADER), ",")) +} diff --git a/flagd/pkg/service/selector_test.go b/flagd/pkg/service/selector_test.go new file mode 100644 index 000000000..92b43f70a --- /dev/null +++ b/flagd/pkg/service/selector_test.go @@ -0,0 +1,30 @@ +package service + +import ( + "net/http" + "testing" + + "google.golang.org/grpc/metadata" +) + +func TestSelectorExpressionFromHTTPHeaders(t *testing.T) { + headers := http.Header{} + headers.Add(FLAG_SELECTOR_HEADER, "A") + headers.Add(FLAG_SELECTOR_HEADER, "B") + + got := SelectorExpressionFromHTTPHeaders(headers) + if got != "A,B" { + t.Fatalf("expected A,B, got %s", got) + } +} + +func TestSelectorExpressionFromGRPCMetadata(t *testing.T) { + md := metadata.New(map[string]string{ + FLAG_SELECTOR_HEADER: "A,C,B", + }) + + got := SelectorExpressionFromGRPCMetadata(md) + if got != "A,C,B" { + t.Fatalf("expected A,C,B, got %s", got) + } +}