diff --git a/ratelimit.go b/ratelimit.go index 504d03d..325f2a2 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -76,7 +76,7 @@ func (r *RateLimiter) RateLimit(fn RateLimiterFn) error { cancel context.CancelFunc ) - ctx = context.WithValue(context.Background(), rateLimitCtxKey, rateLimitCtxVal) + ctx = r.makeCtx() ctx, cancel = context.WithTimeout(context.Background(), r.bucketTimeout) defer cancel() @@ -84,6 +84,12 @@ func (r *RateLimiter) RateLimit(fn RateLimiterFn) error { return r.RateLimitContext(ctx, fn) } +func (r *RateLimiter) RateLimitSync(fn RateLimiterFn) error { + ctx := r.makeCtx() + + return r.RateLimitContextSync(ctx, fn) +} + // RateLimitContext calls the passed function fn using the passed context.Context. // If the passed context's deadline is exceeded while waiting to acquire bucket space, a wrapped error is returned. // Otherwise, the result of calling fn() is returned. @@ -103,6 +109,10 @@ func (r *RateLimiter) RateLimitContext(ctx context.Context, fn RateLimiterFn) er return err } +func (r *RateLimiter) RateLimitContextSync(ctx context.Context, fn RateLimiterFn) error { + return fn() +} + // Wrap wraps the passed function within another function, which no parameters // and calls RateLimit using fn(). Helpful for building functions which are "preloaded" with // a rate limiter, rather than calling RateLimit around every function call. @@ -224,6 +234,10 @@ func (r *RateLimiter) withWriteLock(f func()) { f() } +func (r RateLimiter) makeCtx() context.Context { + return context.WithValue(context.Background(), rateLimitCtxKey, rateLimitCtxVal) +} + func (r RateLimiter) checkRateLimiterCtx(ctx context.Context) bool { // check if context was created by r.RateLimit instead of // a context which was passed from elsewhere