Skip to content

Commit af7a261

Browse files
Nilushiya.KNilushiya.K
authored andcommitted
Feature: Add multi-selector support to flagd
Signed-off-by: Nilushiya.K <Nilushiya.K@cloudsolutions.com.sa>
1 parent a176bc6 commit af7a261

10 files changed

Lines changed: 356 additions & 20 deletions

File tree

flagd/pkg/service/constants.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
package service
22

33
const FLAGD_SELECTOR_HEADER = "Flagd-Selector"
4+
5+
const FLAG_SELECTOR_HEADER = "flag-selector"

flagd/pkg/service/flag-evaluation/flag_evaluator.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,15 @@ func (s *OldFlagEvaluationService) ResolveAll(
7575
Flags: make(map[string]*schemaV1.AnyFlag),
7676
}
7777

78-
selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
79-
selector := store.NewSelector(selectorExpression)
80-
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)
78+
selectorExpression := flagdService.SelectorExpressionFromHTTPHeaders(req.Header())
8179

82-
values, _, err := s.eval.ResolveAllValues(ctx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)))
80+
values, _, err := ResolveAllWithSelectorMerge(
81+
ctx,
82+
reqID,
83+
s.eval,
84+
mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)),
85+
selectorExpression,
86+
)
8387
if err != nil {
8488
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
8589
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)

flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,11 @@ func (s *FlagEvaluationService) ResolveAll(
7575
Flags: make(map[string]*evalV1.AnyFlag),
7676
}
7777

78-
selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
79-
selector := store.NewSelector(selectorExpression)
78+
selectorExpression := flagdService.SelectorExpressionFromHTTPHeaders(req.Header())
8079
evaluationContext := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings)
81-
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)
8280
ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1")
8381

84-
resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(ctx, reqID, evaluationContext)
82+
resolutions, flagSetMetadata, err := ResolveAllWithSelectorMerge(ctx, reqID, s.eval, evaluationContext, selectorExpression)
8583
if err != nil {
8684
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
8785
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)

flagd/pkg/service/flag-evaluation/ofrep/handler.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) {
9494
return
9595
}
9696
evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
97-
selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER)
97+
selectorExpression := service.SelectorExpressionFromHTTPHeaders(r.Header)
9898
selector := store.NewSelector(selectorExpression)
9999
ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector)
100100

@@ -118,11 +118,10 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) {
118118
}
119119

120120
evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
121-
selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER)
122-
selector := store.NewSelector(selectorExpression)
123-
ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector)
121+
selectorExpression := service.SelectorExpressionFromHTTPHeaders(r.Header)
122+
ctx := r.Context()
124123

125-
evaluations, metadata, err := h.evaluator.ResolveAllValues(ctx, requestID, evaluationContext)
124+
evaluations, metadata, err := evalservice.ResolveAllWithSelectorMerge(ctx, requestID, h.evaluator, evaluationContext, selectorExpression)
126125
if err != nil {
127126
h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err))
128127

flagd/pkg/service/flag-evaluation/ofrep/handler_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package ofrep
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"errors"
78
"io"
@@ -16,6 +17,9 @@ import (
1617
"github.com/open-feature/flagd/core/pkg/logger"
1718
"github.com/open-feature/flagd/core/pkg/model"
1819
"github.com/open-feature/flagd/core/pkg/service/ofrep"
20+
"github.com/open-feature/flagd/core/pkg/store"
21+
service "github.com/open-feature/flagd/flagd/pkg/service"
22+
"github.com/stretchr/testify/require"
1923
"go.uber.org/mock/gomock"
2024
)
2125

@@ -270,6 +274,44 @@ func Test_handler_HandleBulkEvaluation(t *testing.T) {
270274
}
271275
}
272276

277+
func Test_handler_HandleBulkEvaluation_UsesFlagSelectorHeader(t *testing.T) {
278+
log := logger.NewLogger(nil, false)
279+
eval := mock.NewMockIEvaluator(gomock.NewController(t))
280+
expectedOrder := []string{"'source=A'", "'source=C'", "'source=B'"}
281+
callCount := 0
282+
283+
eval.EXPECT().ResolveAllValues(gomock.Any(), gomock.Any(), gomock.Any()).Times(3).DoAndReturn(
284+
func(ctx context.Context, _ string, _ map[string]any) ([]evaluator.AnyValue, model.Metadata, error) {
285+
selector, ok := ctx.Value(store.SelectorContextKey{}).(store.Selector)
286+
if !ok {
287+
t.Fatalf("selector not found in context")
288+
}
289+
290+
if callCount >= len(expectedOrder) {
291+
t.Fatalf("unexpected extra selector call")
292+
}
293+
if selector.ToLogString() != expectedOrder[callCount] {
294+
t.Fatalf("unexpected selector at call %d: %s", callCount, selector.ToLogString())
295+
}
296+
callCount++
297+
298+
return []evaluator.AnyValue{successValue}, model.Metadata{}, nil
299+
},
300+
)
301+
302+
h := handler{Logger: log, evaluator: eval}
303+
request, err := http.NewRequest(http.MethodPost, "/ofrep/v1/evaluate/flags", bytes.NewReader([]byte("{}")))
304+
require.NoError(t, err)
305+
request.Header.Set(service.FLAG_SELECTOR_HEADER, "A,C,B")
306+
307+
recorder := httptest.NewRecorder()
308+
router := mux.NewRouter()
309+
router.HandleFunc(bulkEvaluation, h.HandleBulkEvaluation)
310+
router.ServeHTTP(recorder, request)
311+
312+
require.Equal(t, http.StatusOK, recorder.Code)
313+
}
314+
273315
func TestWriteJSONResponse(t *testing.T) {
274316
log := logger.NewLogger(nil, false)
275317
h := handler{Logger: log}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package service
2+
3+
import (
4+
"context"
5+
"sort"
6+
"strings"
7+
8+
"github.com/open-feature/flagd/core/pkg/evaluator"
9+
"github.com/open-feature/flagd/core/pkg/model"
10+
"github.com/open-feature/flagd/core/pkg/store"
11+
)
12+
13+
func ResolveAllWithSelectorMerge(
14+
ctx context.Context,
15+
reqID string,
16+
eval evaluator.IEvaluator,
17+
evaluationContext map[string]any,
18+
selectorExpression string,
19+
) ([]evaluator.AnyValue, model.Metadata, error) {
20+
selectors := splitSelectorExpression(selectorExpression)
21+
22+
switch len(selectors) {
23+
case 0:
24+
return eval.ResolveAllValues(ctx, reqID, evaluationContext)
25+
case 1:
26+
selector := store.NewSelector(selectors[0])
27+
selectorCtx := context.WithValue(ctx, store.SelectorContextKey{}, selector)
28+
return eval.ResolveAllValues(selectorCtx, reqID, evaluationContext)
29+
default:
30+
mergedValues := map[string]evaluator.AnyValue{}
31+
mergedMetadata := model.Metadata{}
32+
33+
for _, selectorExpression := range selectors {
34+
selector := store.NewSelector(selectorExpression)
35+
selectorCtx := context.WithValue(ctx, store.SelectorContextKey{}, selector)
36+
values, metadata, err := eval.ResolveAllValues(selectorCtx, reqID, evaluationContext)
37+
if err != nil {
38+
return nil, nil, err
39+
}
40+
41+
for key, value := range metadata {
42+
mergedMetadata[key] = value
43+
}
44+
for _, value := range values {
45+
mergedValues[value.FlagKey] = value
46+
}
47+
}
48+
49+
keys := make([]string, 0, len(mergedValues))
50+
for key := range mergedValues {
51+
keys = append(keys, key)
52+
}
53+
sort.Strings(keys)
54+
55+
resolutions := make([]evaluator.AnyValue, 0, len(keys))
56+
for _, key := range keys {
57+
resolutions = append(resolutions, mergedValues[key])
58+
}
59+
60+
return resolutions, mergedMetadata, nil
61+
}
62+
}
63+
64+
func splitSelectorExpression(selectorExpression string) []string {
65+
if strings.TrimSpace(selectorExpression) == "" {
66+
return nil
67+
}
68+
69+
parts := strings.Split(selectorExpression, ",")
70+
selectors := make([]string, 0, len(parts))
71+
for _, part := range parts {
72+
trimmed := strings.TrimSpace(part)
73+
if trimmed == "" {
74+
continue
75+
}
76+
selectors = append(selectors, trimmed)
77+
}
78+
return selectors
79+
}

flagd/pkg/service/flag-sync/handler.go

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"errors"
77
"fmt"
88
"maps"
9+
"sort"
10+
"strings"
911
"time"
1012

1113
"github.com/open-feature/flagd/core/pkg/model"
@@ -35,7 +37,7 @@ type syncHandler struct {
3537
func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.FlagSyncService_SyncFlagsServer) error {
3638
watcher := make(chan store.FlagQueryResult, 1)
3739
selectorExpression := s.getSelectorExpression(server.Context(), req)
38-
selector := store.NewSelector(selectorExpression)
40+
selectors := parseSelectorList(selectorExpression)
3941
ctx := server.Context()
4042

4143
syncContextMap := make(map[string]any)
@@ -53,7 +55,15 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F
5355
defer cancel()
5456
}
5557

56-
s.store.Watch(ctx, &selector, watcher)
58+
switch len(selectors) {
59+
case 0:
60+
s.store.Watch(ctx, nil, watcher)
61+
case 1:
62+
s.store.Watch(ctx, &selectors[0], watcher)
63+
default:
64+
// For multi-selector requests, watch all updates and recompute merged view in order.
65+
s.store.Watch(ctx, nil, watcher)
66+
}
5767

5868
for {
5969
select {
@@ -63,7 +73,16 @@ func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.F
6373
return fmt.Errorf("error constructing metadata response")
6474
}
6575

66-
flags, err := s.generateResponse(payload.Flags)
76+
flagsToSend := payload.Flags
77+
if len(selectors) != 1 {
78+
flagsToSend, err = s.fetchMergedFlags(ctx, selectors)
79+
if err != nil {
80+
s.log.Error(fmt.Sprintf("error retrieving merged flags from store: %v", err))
81+
return status.Error(codes.Internal, "error retrieving flags from store")
82+
}
83+
}
84+
85+
flags, err := s.generateResponse(flagsToSend)
6786
if err != nil {
6887
s.log.Error(fmt.Sprintf("error retrieving flags from store: %v", err))
6988
return status.Error(codes.DataLoss, "error marshalling flags")
@@ -103,8 +122,7 @@ func (s syncHandler) generateResponse(payload []model.Flag) ([]byte, error) {
103122
func (s syncHandler) getSelectorExpression(ctx context.Context, req interface{}) string {
104123
// Try to get selector from metadata (header)
105124
if md, ok := metadata.FromIncomingContext(ctx); ok {
106-
if values := md.Get(flagdService.FLAGD_SELECTOR_HEADER); len(values) > 0 {
107-
headerSelector := values[0]
125+
if headerSelector := flagdService.SelectorExpressionFromGRPCMetadata(md); headerSelector != "" {
108126
s.log.Debug(fmt.Sprintf("using selector from request header: %s", headerSelector))
109127
return headerSelector
110128
}
@@ -139,8 +157,7 @@ func (s syncHandler) FetchAllFlags(ctx context.Context, req *syncv1.FetchAllFlag
139157
*syncv1.FetchAllFlagsResponse, error,
140158
) {
141159
selectorExpression := s.getSelectorExpression(ctx, req)
142-
selector := store.NewSelector(selectorExpression)
143-
flags, _, err := s.store.GetAll(ctx, &selector)
160+
flags, err := s.fetchMergedFlags(ctx, parseSelectorList(selectorExpression))
144161
if err != nil {
145162
s.log.Error(fmt.Sprintf("error retrieving flags from store: %v", err))
146163
return nil, status.Error(codes.Internal, "error retrieving flags from store")
@@ -157,8 +174,66 @@ func (s syncHandler) FetchAllFlags(ctx context.Context, req *syncv1.FetchAllFlag
157174
}, nil
158175
}
159176

177+
func parseSelectorList(selectorExpression string) []store.Selector {
178+
if strings.TrimSpace(selectorExpression) == "" {
179+
return nil
180+
}
181+
182+
parts := strings.Split(selectorExpression, ",")
183+
selectors := make([]store.Selector, 0, len(parts))
184+
for _, part := range parts {
185+
trimmed := strings.TrimSpace(part)
186+
if trimmed == "" {
187+
continue
188+
}
189+
selector := store.NewSelector(trimmed)
190+
selectors = append(selectors, selector)
191+
}
192+
return selectors
193+
}
194+
195+
func (s syncHandler) fetchMergedFlags(ctx context.Context, selectors []store.Selector) ([]model.Flag, error) {
196+
switch len(selectors) {
197+
case 0:
198+
flags, _, err := s.store.GetAll(ctx, nil)
199+
return flags, err
200+
case 1:
201+
flags, _, err := s.store.GetAll(ctx, &selectors[0])
202+
return flags, err
203+
default:
204+
type flagIdentifier struct {
205+
flagSetID string
206+
key string
207+
}
208+
209+
merged := map[flagIdentifier]model.Flag{}
210+
for _, selector := range selectors {
211+
flags, _, err := s.store.GetAll(ctx, &selector)
212+
if err != nil {
213+
return nil, err
214+
}
215+
for _, flag := range flags {
216+
merged[flagIdentifier{flagSetID: flag.FlagSetId, key: flag.Key}] = flag
217+
}
218+
}
219+
220+
out := make([]model.Flag, 0, len(merged))
221+
for _, flag := range merged {
222+
out = append(out, flag)
223+
}
224+
sort.Slice(out, func(i, j int) bool {
225+
if out[i].FlagSetId != out[j].FlagSetId {
226+
return out[i].FlagSetId < out[j].FlagSetId
227+
}
228+
return out[i].Key < out[j].Key
229+
})
230+
return out, nil
231+
}
232+
}
233+
160234
// Deprecated - GetMetadata is deprecated and will be removed in a future release.
161235
// Use the sync_context field in syncv1.SyncFlagsResponse, providing same info.
236+
//
162237
//nolint:staticcheck // SA1019 temporarily suppress deprecation warning
163238
func (s syncHandler) GetMetadata(_ context.Context, _ *syncv1.GetMetadataRequest) (
164239
*syncv1.GetMetadataResponse, error,

0 commit comments

Comments
 (0)