Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 64 additions & 49 deletions core/pkg/evaluator/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,7 @@ func (je *Resolver) evaluateVariant(ctx context.Context, reqID string, flagKey s
) {

var selector store.Selector
s := ctx.Value(store.SelectorContextKey{})
if s != nil {
if s := ctx.Value(store.SelectorContextKey{}); s != nil {
selector = s.(store.Selector)
}
flag, metadata, err := je.store.Get(ctx, flagKey, &selector)
Expand All @@ -358,73 +357,89 @@ func (je *Resolver) evaluateVariant(ctx context.Context, reqID string, flagKey s
return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.FlagDisabledErrorCode)
}

// get the targeting logic, if any
targeting := flag.Targeting
if flag.Targeting != nil && string(flag.Targeting) != "{}" {
variant, reason, err := je.evaluateTargeting(ctx, reqID, flagKey, flag, evalCtx)
return variant, flag.Variants, reason, metadata, err
}

if targeting != nil && string(targeting) != "{}" {
targetingBytes, err := targeting.MarshalJSON()
if err != nil {
je.Logger.ErrorWithID(reqID, fmt.Sprintf("Error parsing rules for flag: %s, %s", flagKey, err))
return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.ParseErrorCode)
}
variant, reason, err = je.resolveDefaultVariant(ctx, flag)
return variant, flag.Variants, reason, metadata, err
}

evalCtx = setFlagdProperties(je.Logger, evalCtx, flagdProperties{
FlagKey: flagKey,
Timestamp: time.Now().Unix(),
})
func (je *Resolver) evaluateTargeting(
ctx context.Context,
reqID string,
flagKey string,
flag model.Flag,
evalCtx map[string]any,
) (variant string, reason string, err error) {
targetingBytes, err := flag.Targeting.MarshalJSON()
if err != nil {
je.Logger.ErrorWithID(reqID, fmt.Sprintf("Error parsing rules for flag: %s, %s", flagKey, err))
return "", model.ErrorReason, errors.New(model.ParseErrorCode)
}

b, err := json.Marshal(evalCtx)
if err != nil {
je.Logger.ErrorWithID(reqID, fmt.Sprintf("error parsing context for flag: %s, %s, %v", flagKey, err, evalCtx))
evalCtx = setFlagdProperties(je.Logger, evalCtx, flagdProperties{
FlagKey: flagKey,
Timestamp: time.Now().Unix(),
})

return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.ErrorReason)
}
b, err := json.Marshal(evalCtx)
if err != nil {
je.Logger.ErrorWithID(reqID, fmt.Sprintf("error parsing context for flag: %s, %s, %v", flagKey, err, evalCtx))
return "", model.ErrorReason, errors.New(model.ErrorReason)
}

var result bytes.Buffer
// evaluate JsonLogic rules to determine the variant
err = jsonlogic.Apply(bytes.NewReader(targetingBytes), bytes.NewReader(b), &result)
if err != nil {
je.Logger.ErrorWithID(reqID, fmt.Sprintf("error applying targeting rules: %s", err))
return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.ParseErrorCode)
}
var result bytes.Buffer
// evaluate JsonLogic rules to determine the variant
if err = jsonlogic.Apply(bytes.NewReader(targetingBytes), bytes.NewReader(b), &result); err != nil {
je.Logger.ErrorWithID(reqID, fmt.Sprintf("error applying targeting rules: %s", err))
return "", model.ErrorReason, errors.New(model.ParseErrorCode)
}

// check if string is "null" before we strip quotes, so we can differentiate between JSON null and "null"
trimmed := strings.TrimSpace(result.String())
// check if string is "null" before we strip quotes, so we can differentiate between JSON null and "null"
trimmed := strings.TrimSpace(result.String())

if trimmed == "null" {
if flag.DefaultVariant == "" {
if ctx.Value(ProtoVersionKey) != nil {
// old proto version behavior
return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.FlagNotFoundErrorCode)
}
if trimmed == "null" {
return je.handleNullTargetingResult(ctx, flag)
}

return "", flag.Variants, model.FallbackReason, metadata, nil
}
// strip whitespace and quotes from the variant
variant = strings.ReplaceAll(trimmed, "\"", "")

return flag.DefaultVariant, flag.Variants, model.DefaultReason, metadata, nil
}
// if this is a valid variant, return it
if _, ok := flag.Variants[variant]; ok {
return variant, model.TargetingMatchReason, nil
}

// strip whitespace and quotes from the variant
variant = strings.ReplaceAll(trimmed, "\"", "")
je.Logger.ErrorWithID(reqID,
fmt.Sprintf("invalid or missing variant: %s for flagKey: %s, variant is not valid", variant, flagKey))
return "", model.ErrorReason, errors.New(model.GeneralErrorCode)
}

// if this is a valid variant, return it
if _, ok := flag.Variants[variant]; ok {
return variant, flag.Variants, model.TargetingMatchReason, metadata, nil
func (je *Resolver) handleNullTargetingResult(ctx context.Context, flag model.Flag) (string, string, error) {
if flag.DefaultVariant == "" {
if ctx.Value(ProtoVersionKey) != nil {
// old proto version behavior
return "", model.ErrorReason, errors.New(model.FlagNotFoundErrorCode)
}
je.Logger.ErrorWithID(reqID,
fmt.Sprintf("invalid or missing variant: %s for flagKey: %s, variant is not valid", variant, flagKey))
return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.GeneralErrorCode)

return "", model.FallbackReason, nil
}

return flag.DefaultVariant, model.DefaultReason, nil
}

func (je *Resolver) resolveDefaultVariant(ctx context.Context, flag model.Flag) (string, string, error) {
if flag.DefaultVariant == "" {
if ctx.Value(ProtoVersionKey) != nil {
// old proto version behavior
return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.FlagNotFoundErrorCode)
return "", model.ErrorReason, errors.New(model.FlagNotFoundErrorCode)
}
return "", flag.Variants, model.FallbackReason, metadata, nil
return "", model.FallbackReason, nil
}

return flag.DefaultVariant, flag.Variants, model.StaticReason, metadata, nil
return flag.DefaultVariant, model.StaticReason, nil
}

func setFlagdProperties(
Expand Down
8 changes: 6 additions & 2 deletions core/pkg/sync/grpc/nameresolvers/envoy_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ func (r *envoyResolver) start() {
}
}

func (*envoyResolver) ResolveNow(resolver.ResolveNowOptions) {}
func (*envoyResolver) ResolveNow(resolver.ResolveNowOptions) {
// no-op: the resolver relies on static configuration provided during construction.
}

func (*envoyResolver) Close() {}
func (*envoyResolver) Close() {
// no-op: there are no resources to release for the static resolver.
}

// Validate user specified target
//
Expand Down
38 changes: 23 additions & 15 deletions core/pkg/sync/http/http_sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,25 +401,33 @@ func TestHTTPSync_Resync(t *testing.T) {
if !tt.wantErr && err != nil {
t.Errorf("got error for %s %s", name, err.Error())
}
for _, dataSync := range tt.wantNotifications {
select {
case x := <-d:
if x.FlagData != dataSync.FlagData || x.Source != dataSync.Source {
t.Errorf("unexpected datasync received %v vs %v", x, dataSync)
}
case <-time.After(2 * time.Second):
t.Error("expected datasync not received", dataSync)
}
}
select {
case x := <-d:
t.Error("unexpected datasync received", x)
case <-time.After(2 * time.Second):
}
assertDataSyncsDelivered(t, d, tt.wantNotifications)
assertNoUnexpectedDataSync(t, d)
})
}
}

func assertDataSyncsDelivered(t *testing.T, queue chan sync.DataSync, expected []sync.DataSync) {
for _, dataSync := range expected {
select {
case x := <-queue:
if x.FlagData != dataSync.FlagData || x.Source != dataSync.Source {
t.Errorf("unexpected datasync received %v vs %v", x, dataSync)
}
case <-time.After(2 * time.Second):
t.Error("expected datasync not received", dataSync)
}
}
}

func assertNoUnexpectedDataSync(t *testing.T, queue chan sync.DataSync) {
select {
case x := <-queue:
t.Error("unexpected datasync received", x)
case <-time.After(2 * time.Second):
}
}

func TestHTTPSync_getClient(t *testing.T) {
oauth := &sync.OAuthCredentialHandler{
ClientID: "myClientID",
Expand Down
2 changes: 2 additions & 0 deletions flagd/pkg/service/constants.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package service

const FLAGD_SELECTOR_HEADER = "Flagd-Selector"

const FLAG_SELECTOR_HEADER = "flag-selector"
12 changes: 8 additions & 4 deletions flagd/pkg/service/flag-evaluation/flag_evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,15 @@ func (s *OldFlagEvaluationService) ResolveAll(
Flags: make(map[string]*schemaV1.AnyFlag),
}

selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)
selectorExpression := flagdService.SelectorExpressionFromHTTPHeaders(req.Header())

values, _, err := s.eval.ResolveAllValues(ctx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)))
values, _, err := ResolveAllWithSelectorMerge(
ctx,
reqID,
s.eval,
mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)),
selectorExpression,
)
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
Expand Down
6 changes: 2 additions & 4 deletions flagd/pkg/service/flag-evaluation/flag_evaluator_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,11 @@ func (s *FlagEvaluationService) ResolveAll(
Flags: make(map[string]*evalV1.AnyFlag),
}

selectorExpression := req.Header().Get(flagdService.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
selectorExpression := flagdService.SelectorExpressionFromHTTPHeaders(req.Header())
evaluationContext := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings)
ctx = context.WithValue(ctx, store.SelectorContextKey{}, selector)
ctx = context.WithValue(ctx, evaluator.ProtoVersionKey, "v1")

resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(ctx, reqID, evaluationContext)
resolutions, flagSetMetadata, err := ResolveAllWithSelectorMerge(ctx, reqID, s.eval, evaluationContext, selectorExpression)
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
Expand Down
9 changes: 4 additions & 5 deletions flagd/pkg/service/flag-evaluation/ofrep/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) {
}
}
evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER)
selectorExpression := service.SelectorExpressionFromHTTPHeaders(r.Header)
selector := store.NewSelector(selectorExpression)
Comment on lines +99 to 100
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

In HandleFlagEvaluation, the selectorExpression (which can now be a comma-separated list of multiple selectors) is passed directly to store.NewSelector. This is a critical logic error as store.NewSelector is designed for a single selector, leading to incorrect parsing or an empty selector. This can cause the evaluator to fall back to the highest priority flag across all sources, potentially bypassing intended isolation between flag sets (e.g., tenant isolation), which is a broken access control vulnerability. This behavior is inconsistent with HandleBulkEvaluation. Update this endpoint to correctly handle multi-selectors, for example by splitting the expression and using the first valid selector if multi-merge is not yet supported for single evaluations.

ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector)

Expand All @@ -121,11 +121,10 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) {
}

evaluationContext := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
selectorExpression := r.Header.Get(service.FLAGD_SELECTOR_HEADER)
selector := store.NewSelector(selectorExpression)
ctx := context.WithValue(r.Context(), store.SelectorContextKey{}, selector)
selectorExpression := service.SelectorExpressionFromHTTPHeaders(r.Header)
ctx := r.Context()

evaluations, metadata, err := h.evaluator.ResolveAllValues(ctx, requestID, evaluationContext)
evaluations, metadata, err := evalservice.ResolveAllWithSelectorMerge(ctx, requestID, h.evaluator, evaluationContext, selectorExpression)
if err != nil {
h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err))

Expand Down
42 changes: 42 additions & 0 deletions flagd/pkg/service/flag-evaluation/ofrep/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ofrep

import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
Expand All @@ -16,6 +17,9 @@ import (
"github.com/open-feature/flagd/core/pkg/logger"
"github.com/open-feature/flagd/core/pkg/model"
"github.com/open-feature/flagd/core/pkg/service/ofrep"
"github.com/open-feature/flagd/core/pkg/store"
service "github.com/open-feature/flagd/flagd/pkg/service"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

Expand Down Expand Up @@ -270,6 +274,44 @@ func Test_handler_HandleBulkEvaluation(t *testing.T) {
}
}

func TestHandlerHandleBulkEvaluationUsesFlagSelectorHeader(t *testing.T) {
log := logger.NewLogger(nil, false)
eval := mock.NewMockIEvaluator(gomock.NewController(t))
expectedOrder := []string{"'source=A'", "'source=C'", "'source=B'"}
callCount := 0

eval.EXPECT().ResolveAllValues(gomock.Any(), gomock.Any(), gomock.Any()).Times(3).DoAndReturn(
func(ctx context.Context, _ string, _ map[string]any) ([]evaluator.AnyValue, model.Metadata, error) {
selector, ok := ctx.Value(store.SelectorContextKey{}).(store.Selector)
if !ok {
t.Fatalf("selector not found in context")
}

if callCount >= len(expectedOrder) {
t.Fatalf("unexpected extra selector call")
}
if selector.ToLogString() != expectedOrder[callCount] {
t.Fatalf("unexpected selector at call %d: %s", callCount, selector.ToLogString())
}
callCount++

return []evaluator.AnyValue{successValue}, model.Metadata{}, nil
},
)

h := handler{Logger: log, evaluator: eval}
request, err := http.NewRequest(http.MethodPost, "/ofrep/v1/evaluate/flags", bytes.NewReader([]byte("{}")))
require.NoError(t, err)
request.Header.Set(service.FLAG_SELECTOR_HEADER, "A,C,B")

recorder := httptest.NewRecorder()
router := mux.NewRouter()
router.HandleFunc(bulkEvaluation, h.HandleBulkEvaluation)
router.ServeHTTP(recorder, request)

require.Equal(t, http.StatusOK, recorder.Code)
}

func TestWriteJSONResponse(t *testing.T) {
log := logger.NewLogger(nil, false)
h := handler{Logger: log}
Expand Down
Loading
Loading