diff --git a/adaptive.go b/adaptive.go index f332548..cc111d5 100644 --- a/adaptive.go +++ b/adaptive.go @@ -42,8 +42,10 @@ type AdaptiveThrottle struct { k float64 minPerWindow float64 - requests []windowedCounter - accepts []windowedCounter + priorities int + requests []windowedCounter + accepts []windowedCounter + validate func(p Priority, priorities int) (Priority, error) } // NewAdaptiveThrottle returns an AdaptiveThrottle. @@ -51,6 +53,10 @@ type AdaptiveThrottle struct { // priorities is the number of priorities that the throttle will accept. Giving a priority outside // of `[0, priorities)` will panic. func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *AdaptiveThrottle { + if priorities <= 0 { + panic("bulwark: priorities must be greater than 0") + } + opts := adaptiveThrottleOptions{ d: time.Minute, k: K, @@ -70,9 +76,11 @@ func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *Ada return &AdaptiveThrottle{ k: opts.k, + priorities: priorities, requests: requests, accepts: accepts, minPerWindow: opts.minRate * opts.d.Seconds(), + validate: OnInvalidPriorityAdjust, } } @@ -93,7 +101,11 @@ func NewAdaptiveThrottle(priorities int, options ...AdaptiveThrottleOption) *Ada func (t *AdaptiveThrottle) Throttle( ctx context.Context, defaultPriority Priority, fn throttledFn, fallbackFn ...fallbackFn, ) error { - priority := PriorityFromContext(ctx, defaultPriority) + priority, err := t.validate(PriorityFromContext(ctx, defaultPriority), t.priorities) + if err != nil { + return err + } + now := Now() rejectionProbability := t.rejectionProbability(priority, now) if rand.Float64() < rejectionProbability { @@ -112,7 +124,7 @@ func (t *AdaptiveThrottle) Throttle( return ClientSideRejectionError } - err := fn(ctx) + err = fn(ctx) now = Now() switch { @@ -191,6 +203,7 @@ type adaptiveThrottleOptions struct { minRate float64 d time.Duration isErrorAccepted func(err error) bool + validate func(p Priority, priorities int) (Priority, error) } // WithAdaptiveThrottleRatio sets the ratio of the measured success rate and the rate that the throttle @@ -234,14 +247,38 @@ func WithAcceptedErrors(fn func(err error) bool) AdaptiveThrottleOption { }} } +// WithPriorityValidator sets the function that validates input priority values. +// +// The function should return the validated priority value. If the priority is +// invalid, the function should return an error. +func WithPriorityValidator(fn func(p Priority, priorities int) (Priority, error)) AdaptiveThrottleOption { + return AdaptiveThrottleOption{func(opts *adaptiveThrottleOptions) { + opts.validate = func(p Priority, priorities int) (Priority, error) { + p, err := fn(p, priorities) + if err != nil { + return p, err + } + + // This is a safeguard in case the validator function does not return a + // valid priority. It is better to panic with this functions, because + // the message is more informative. + return OnInvalidPriorityPanic(p, priorities) + } + }} +} + func Throttle[T any]( ctx context.Context, at *AdaptiveThrottle, defaultPriority Priority, throttledFn throttledArgsFn[T], fallbackFn ...fallbackArgsFn[T], -) (T, error) { - priority := PriorityFromContext(ctx, defaultPriority) +) (res T, err error) { + priority, err := at.validate(PriorityFromContext(ctx, defaultPriority), at.priorities) + if err != nil { + return res, err + } + now := Now() rejectionProbability := at.rejectionProbability(priority, now) if rand.Float64() < rejectionProbability { @@ -261,7 +298,7 @@ func Throttle[T any]( return zero, ClientSideRejectionError } - t, err := throttledFn(ctx) + res, err = throttledFn(ctx) now = Now() switch { @@ -282,7 +319,7 @@ func Throttle[T any]( return fallbackFn[0](ctx, err, false) } - return t, err + return res, err } // WithAdaptiveThrottle is used to send a request to a backend using the given AdaptiveThrottle for @@ -298,7 +335,12 @@ func WithAdaptiveThrottle[T any]( at *AdaptiveThrottle, priority Priority, throttledFn func() (T, error), -) (T, error) { +) (res T, err error) { + priority, err = at.validate(priority, at.priorities) + if err != nil { + return res, err + } + now := Now() rejectionProbability := at.rejectionProbability(priority, now) if rand.Float64() < rejectionProbability { @@ -314,7 +356,7 @@ func WithAdaptiveThrottle[T any]( return zero, ClientSideRejectionError } - t, err := throttledFn() + res, err = throttledFn() now = Now() switch { @@ -331,7 +373,7 @@ func WithAdaptiveThrottle[T any]( at.accept(priority, now) } - return t, err + return res, err } // RejectedError wraps an error to indicate that the error should be considered diff --git a/validator.go b/validator.go new file mode 100644 index 0000000..a2f6f77 --- /dev/null +++ b/validator.go @@ -0,0 +1,50 @@ +package bulwark + +import ( + "fmt" + "log/slog" + + "github.com/deixis/faults" +) + +var ( + // OnInvalidPriorityPanic panics when a priority is out of range. + // A priority is out of range when it is less than 0 or greater than or equal + // to priorities-1. + OnInvalidPriorityPanic = func(p Priority, priorities int) (Priority, error) { + if p < 0 || int(p) >= priorities-1 { + panic(fmt.Sprintf("bulwark: priority must be in the range [0, %d), but got %d", priorities, p)) + } + + return p, nil + } + + // OnInvalidPriorityAdjust adjusts the priority to the nearest valid value. + // A negative priority will be set to the lowest priority. + // A priority is out of range when it is less than 0 or greater than or equal + // to priorities-1. + OnInvalidPriorityAdjust = func(p Priority, priorities int) (Priority, error) { + if p >= 0 && int(p) < priorities { + return p, nil + } + slog.Warn("bulwark: priority is out of range", "max", priorities-1, "priority", p) + + // Receiving an invalid value is likely due to an input that was not properly + // validated, so this prevents abuse of the system. + return Priority(priorities - 1), nil + } + + // OnInvalidPriorityError returns an error when a priority is out of range. + // A priority is out of range when it is less than 0 or greater than or equal + // to priorities-1. + OnInvalidPriorityError = func(p Priority, priorities int) (Priority, error) { + if p < 0 || int(p) >= priorities-1 { + return p, faults.Bad(&faults.FieldViolation{ + Field: "priority", + Description: fmt.Sprintf("priority must be in the range [0, %d), but got %d", priorities, p), + }) + } + + return p, nil + } +)