Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions docs/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Policy 5 configured
- [type RetryPolicy](<#RetryPolicy>)
- [type StopPolicy](<#StopPolicy>)
- [func NewStopPolicy\(policy RetryPolicy\) \*StopPolicy](<#NewStopPolicy>)
- [func \(p \*StopPolicy\) GetMaxDuration\(\) time.Duration](<#StopPolicy.GetMaxDuration>)
- [func \(p \*StopPolicy\) NextDelay\(attempt int\) \(time.Duration, bool\)](<#StopPolicy.NextDelay>)
- [func \(p \*StopPolicy\) WithMaxAttempts\(attempts int\) \*StopPolicy](<#StopPolicy.WithMaxAttempts>)
- [func \(p \*StopPolicy\) WithMaxDuration\(duration time.Duration\) \*StopPolicy](<#StopPolicy.WithMaxDuration>)
Expand Down Expand Up @@ -328,7 +329,7 @@ policy := NewExponentialBackoffPolicy(100*time.Millisecond, 10*time.Second)
func (p *ExponentialBackoffPolicy) NextDelay(attempt int) (time.Duration, bool)
```

NextDelay calculates the delay for the next retry attempt using exponential backoff. The delay grows exponentially with each attempt, capped at MaxDelay. If jitter is enabled, the delay is randomized between delay/2 and delay.
NextDelay calculates the delay for the next retry attempt using exponential backoff. The delay grows exponentially with each attempt, capped at MaxDelay. If jitter is enabled, the delay is randomized between delay/2 and delay using cryptographically secure randomness.

<a name="ExponentialBackoffPolicy.WithJitter"></a>
### func \(\*ExponentialBackoffPolicy\) WithJitter
Expand Down Expand Up @@ -839,7 +840,7 @@ type RetryPolicy interface {
<a name="StopPolicy"></a>
## type StopPolicy

StopPolicy wraps another policy and stops retrying after a specified condition
StopPolicy wraps another policy and stops retrying after a specified condition. Note: For duration\-based stopping, timing is managed by the retrier, not the policy, to ensure thread safety and correct timing behavior.

```go
type StopPolicy struct {
Expand All @@ -856,6 +857,15 @@ func NewStopPolicy(policy RetryPolicy) *StopPolicy



<a name="StopPolicy.GetMaxDuration"></a>
### func \(\*StopPolicy\) GetMaxDuration

```go
func (p *StopPolicy) GetMaxDuration() time.Duration
```

GetMaxDuration returns the maximum duration setting for use by the retrier

<a name="StopPolicy.NextDelay"></a>
### func \(\*StopPolicy\) NextDelay

Expand Down
127 changes: 109 additions & 18 deletions policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
package goretry

import (
"crypto/rand"
"fmt"
"math"
"math/rand"
"math/big"
"time"
)

Expand All @@ -28,6 +30,9 @@ type FixedDelayPolicy struct {
//
// policy := NewFixedDelayPolicy(1 * time.Second) // 1 second between each retry
func NewFixedDelayPolicy(delay time.Duration) *FixedDelayPolicy {
if delay < 0 {
panic(fmt.Sprintf("delay must be non-negative, got %v", delay))
}
return &FixedDelayPolicy{Delay: delay}
}

Expand Down Expand Up @@ -72,6 +77,15 @@ type ExponentialBackoffPolicy struct {
//
// policy := NewExponentialBackoffPolicy(100*time.Millisecond, 10*time.Second)
func NewExponentialBackoffPolicy(baseDelay, maxDelay time.Duration) *ExponentialBackoffPolicy {
if baseDelay < 0 {
panic(fmt.Sprintf("baseDelay must be non-negative, got %v", baseDelay))
}
if maxDelay < 0 {
panic(fmt.Sprintf("maxDelay must be non-negative, got %v", maxDelay))
}
if maxDelay < baseDelay {
panic(fmt.Sprintf("maxDelay (%v) must be >= baseDelay (%v)", maxDelay, baseDelay))
}
return &ExponentialBackoffPolicy{
BaseDelay: baseDelay,
MaxDelay: maxDelay,
Expand All @@ -89,6 +103,12 @@ func NewExponentialBackoffPolicy(baseDelay, maxDelay time.Duration) *Exponential
// policy := NewExponentialBackoffPolicy(100*time.Millisecond, 10*time.Second).
// WithMultiplier(1.5) // Slower growth than default
func (p *ExponentialBackoffPolicy) WithMultiplier(multiplier float64) *ExponentialBackoffPolicy {
if multiplier <= 1.0 {
panic(fmt.Sprintf("multiplier must be > 1.0, got %v", multiplier))
}
if math.IsInf(multiplier, 0) || math.IsNaN(multiplier) {
panic(fmt.Sprintf("multiplier must be a finite number, got %v", multiplier))
}
p.Multiplier = multiplier
return p
}
Expand All @@ -109,20 +129,56 @@ func (p *ExponentialBackoffPolicy) WithJitter(jitter bool) *ExponentialBackoffPo

// NextDelay calculates the delay for the next retry attempt using exponential backoff.
// The delay grows exponentially with each attempt, capped at MaxDelay.
// If jitter is enabled, the delay is randomized between delay/2 and delay.
// If jitter is enabled, the delay is randomized between delay/2 and delay using cryptographically secure randomness.
func (p *ExponentialBackoffPolicy) NextDelay(attempt int) (time.Duration, bool) {
delay := time.Duration(float64(p.BaseDelay) * math.Pow(p.Multiplier, float64(attempt-1)))
if attempt <= 0 {
return 0, false
}

if delay > p.MaxDelay {
// Calculate exponential backoff with overflow protection
exponent := float64(attempt - 1)
multiplierPower := math.Pow(p.Multiplier, exponent)

// Check for overflow or infinity
if math.IsInf(multiplierPower, 0) || multiplierPower > float64(p.MaxDelay)/float64(p.BaseDelay) {
delay := p.MaxDelay
return p.applyJitter(delay), true
}

delay := time.Duration(float64(p.BaseDelay) * multiplierPower)

// Ensure we don't exceed MaxDelay
if delay > p.MaxDelay || delay < 0 { // negative check for overflow
delay = p.MaxDelay
}

if p.Jitter {
jitter := time.Duration(rand.Int63n(int64(delay)))
delay = delay/2 + jitter
return p.applyJitter(delay), true
}

// applyJitter applies cryptographically secure jitter to the delay if enabled.
// Jitter randomizes the delay between delay/2 and delay to prevent thundering herd.
func (p *ExponentialBackoffPolicy) applyJitter(delay time.Duration) time.Duration {
if !p.Jitter || delay <= 0 {
return delay
}

return delay, true
// Use crypto/rand for thread-safe, cryptographically secure randomness
halfDelay := delay / 2
maxJitter := int64(delay - halfDelay)

if maxJitter <= 0 {
return delay
}

// Generate secure random number
jitterBig, err := rand.Int(rand.Reader, big.NewInt(maxJitter))
if err != nil {
// Fallback to no jitter if crypto/rand fails
return delay
}

jitter := time.Duration(jitterBig.Int64())
return halfDelay + jitter
}

// LinearBackoffPolicy implements linear backoff
Expand All @@ -133,6 +189,18 @@ type LinearBackoffPolicy struct {
}

func NewLinearBackoffPolicy(baseDelay, increment, maxDelay time.Duration) *LinearBackoffPolicy {
if baseDelay < 0 {
panic(fmt.Sprintf("baseDelay must be non-negative, got %v", baseDelay))
}
if increment < 0 {
panic(fmt.Sprintf("increment must be non-negative, got %v", increment))
}
if maxDelay < 0 {
panic(fmt.Sprintf("maxDelay must be non-negative, got %v", maxDelay))
}
if maxDelay < baseDelay {
panic(fmt.Sprintf("maxDelay (%v) must be >= baseDelay (%v)", maxDelay, baseDelay))
}
return &LinearBackoffPolicy{
BaseDelay: baseDelay,
MaxDelay: maxDelay,
Expand All @@ -141,9 +209,20 @@ func NewLinearBackoffPolicy(baseDelay, increment, maxDelay time.Duration) *Linea
}

func (p *LinearBackoffPolicy) NextDelay(attempt int) (time.Duration, bool) {
delay := p.BaseDelay + time.Duration(attempt-1)*p.Increment
if attempt <= 0 {
return 0, false
}

// Protect against overflow in multiplication
increment := time.Duration(attempt-1) * p.Increment
if attempt > 1 && increment < 0 { // Overflow check
return p.MaxDelay, true
}

if delay > p.MaxDelay {
delay := p.BaseDelay + increment

// Check for overflow or exceeding max
if delay < p.BaseDelay || delay > p.MaxDelay {
delay = p.MaxDelay
}

Expand All @@ -161,27 +240,36 @@ func (p *NoDelayPolicy) NextDelay(attempt int) (time.Duration, bool) {
return 0, true
}

// StopPolicy wraps another policy and stops retrying after a specified condition
// StopPolicy wraps another policy and stops retrying after a specified condition.
// Note: For duration-based stopping, timing is managed by the retrier, not the policy,
// to ensure thread safety and correct timing behavior.
type StopPolicy struct {
policy RetryPolicy
maxAttempts int
maxDuration time.Duration
startTime time.Time
}

func NewStopPolicy(policy RetryPolicy) *StopPolicy {
if policy == nil {
panic("policy cannot be nil")
}
return &StopPolicy{
policy: policy,
startTime: time.Now(),
policy: policy,
}
}

func (p *StopPolicy) WithMaxAttempts(attempts int) *StopPolicy {
if attempts < 0 {
panic(fmt.Sprintf("maxAttempts must be non-negative, got %d", attempts))
}
p.maxAttempts = attempts
return p
}

func (p *StopPolicy) WithMaxDuration(duration time.Duration) *StopPolicy {
if duration < 0 {
panic(fmt.Sprintf("maxDuration must be non-negative, got %v", duration))
}
p.maxDuration = duration
return p
}
Expand All @@ -191,9 +279,12 @@ func (p *StopPolicy) NextDelay(attempt int) (time.Duration, bool) {
return 0, false
}

if p.maxDuration > 0 && time.Since(p.startTime) >= p.maxDuration {
return 0, false
}

// Duration checking will be handled by the retrier that maintains the start time
// This ensures thread safety and proper timing behavior per retry operation
return p.policy.NextDelay(attempt)
}

// GetMaxDuration returns the maximum duration setting for use by the retrier
func (p *StopPolicy) GetMaxDuration() time.Duration {
return p.maxDuration
}
22 changes: 14 additions & 8 deletions policies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,12 @@ func TestStopPolicy_MaxDuration(t *testing.T) {
basePolicy := NewFixedDelayPolicy(10 * time.Millisecond)
policy := NewStopPolicy(basePolicy).WithMaxDuration(50 * time.Millisecond)

// First attempt should work
// Test that policy correctly reports max duration
if policy.GetMaxDuration() != 50*time.Millisecond {
t.Errorf("Expected max duration 50ms, got %v", policy.GetMaxDuration())
}

// Duration checking is now handled by the retrier, so test NextDelay behavior
delay, shouldContinue := policy.NextDelay(1)
if !shouldContinue {
t.Error("Expected to continue at first attempt")
Expand All @@ -135,12 +140,13 @@ func TestStopPolicy_MaxDuration(t *testing.T) {
t.Errorf("Expected 10ms delay, got %v", delay)
}

// Wait longer than max duration
time.Sleep(60 * time.Millisecond)

// Second attempt should be stopped
_, shouldContinue = policy.NextDelay(2)
if shouldContinue {
t.Error("Expected to stop after max duration exceeded")
// Policy no longer tracks time internally - that's handled by retrier
// This ensures thread safety and proper timing per retry operation
delay, shouldContinue = policy.NextDelay(2)
if !shouldContinue {
t.Error("Expected policy to continue (duration checking moved to retrier)")
}
if delay != 10*time.Millisecond {
t.Errorf("Expected 10ms delay, got %v", delay)
}
}
34 changes: 32 additions & 2 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ func (e *OutOfRetriesError) Unwrap() error {
// WithTransientErrorFunc(customTransientFunc),
// )
func NewRetrier(policy RetryPolicy, options ...Option) *Retrier {
if policy == nil {
panic("policy cannot be nil")
}

r := &Retrier{
policy: policy,
maxAttempts: 3,
Expand All @@ -120,6 +124,10 @@ func NewRetrier(policy RetryPolicy, options ...Option) *Retrier {
opt(r)
}

if r.maxAttempts <= 0 {
panic(fmt.Sprintf("maxAttempts must be positive, got %d", r.maxAttempts))
}

return r
}

Expand All @@ -135,6 +143,9 @@ type Option func(*Retrier)
// retrier := NewRetrier(policy, WithMaxAttempts(5)) // Will try up to 5 times total
func WithMaxAttempts(attempts int) Option {
return func(r *Retrier) {
if attempts <= 0 {
panic(fmt.Sprintf("maxAttempts must be positive, got %d", attempts))
}
r.maxAttempts = attempts
}
}
Expand Down Expand Up @@ -220,8 +231,20 @@ func (r *Retrier) Do(fn func() error) error {
func (r *Retrier) DoWithContext(ctx context.Context, fn func(context.Context) error) error {
var allErrors []error
var lastErr error
startTime := time.Now()

// Check if policy has duration limit (for StopPolicy)
var maxDuration time.Duration
if stopPolicy, ok := r.policy.(*StopPolicy); ok {
maxDuration = stopPolicy.GetMaxDuration()
}

for attempt := 1; attempt <= r.maxAttempts; attempt++ {
// Check duration limit before each attempt
if maxDuration > 0 && time.Since(startTime) >= maxDuration {
break
}

err := fn(ctx)
if err == nil {
return nil
Expand Down Expand Up @@ -251,11 +274,18 @@ func (r *Retrier) DoWithContext(ctx context.Context, fn func(context.Context) er
break
}

// Wait for the delay or context cancellation
// Check duration limit before waiting
if maxDuration > 0 && time.Since(startTime)+delay >= maxDuration {
break
}

// Wait for the delay or context cancellation with proper timer cleanup
timer := time.NewTimer(delay)
select {
case <-ctx.Done():
timer.Stop()
return ctx.Err()
case <-time.After(delay):
case <-timer.C:
// Continue to next attempt
}
}
Expand Down
Loading
Loading