diff --git a/docs/api/README.md b/docs/api/README.md
index bf225c1..0a04809 100644
--- a/docs/api/README.md
+++ b/docs/api/README.md
@@ -109,6 +109,7 @@ Policy 5 configured
- [type RetryPolicy](<#RetryPolicy>)
- [type StopPolicy](<#StopPolicy>)
- [func NewStopPolicy\(policy RetryPolicy\) \*StopPolicy](<#NewStopPolicy>)
+ - [func \(p \*StopPolicy\) GetMaxDuration\(\) time.Duration](<#StopPolicy.GetMaxDuration>)
- [func \(p \*StopPolicy\) NextDelay\(attempt int\) \(time.Duration, bool\)](<#StopPolicy.NextDelay>)
- [func \(p \*StopPolicy\) WithMaxAttempts\(attempts int\) \*StopPolicy](<#StopPolicy.WithMaxAttempts>)
- [func \(p \*StopPolicy\) WithMaxDuration\(duration time.Duration\) \*StopPolicy](<#StopPolicy.WithMaxDuration>)
@@ -328,7 +329,7 @@ policy := NewExponentialBackoffPolicy(100*time.Millisecond, 10*time.Second)
func (p *ExponentialBackoffPolicy) NextDelay(attempt int) (time.Duration, bool)
```
-NextDelay calculates the delay for the next retry attempt using exponential backoff. The delay grows exponentially with each attempt, capped at MaxDelay. If jitter is enabled, the delay is randomized between delay/2 and delay.
+NextDelay calculates the delay for the next retry attempt using exponential backoff. The delay grows exponentially with each attempt, capped at MaxDelay. If jitter is enabled, the delay is randomized between delay/2 and delay using cryptographically secure randomness.
### func \(\*ExponentialBackoffPolicy\) WithJitter
@@ -839,7 +840,7 @@ type RetryPolicy interface {
## type StopPolicy
-StopPolicy wraps another policy and stops retrying after a specified condition
+StopPolicy wraps another policy and stops retrying after a specified condition. Note: For duration\-based stopping, timing is managed by the retrier, not the policy, to ensure thread safety and correct timing behavior.
```go
type StopPolicy struct {
@@ -856,6 +857,15 @@ func NewStopPolicy(policy RetryPolicy) *StopPolicy
+
+### func \(\*StopPolicy\) GetMaxDuration
+
+```go
+func (p *StopPolicy) GetMaxDuration() time.Duration
+```
+
+GetMaxDuration returns the maximum duration setting for use by the retrier
+
### func \(\*StopPolicy\) NextDelay
diff --git a/policies.go b/policies.go
index 5ad9b0a..03b32f5 100644
--- a/policies.go
+++ b/policies.go
@@ -3,8 +3,10 @@
package goretry
import (
+ "crypto/rand"
+ "fmt"
"math"
- "math/rand"
+ "math/big"
"time"
)
@@ -28,6 +30,9 @@ type FixedDelayPolicy struct {
//
// policy := NewFixedDelayPolicy(1 * time.Second) // 1 second between each retry
func NewFixedDelayPolicy(delay time.Duration) *FixedDelayPolicy {
+ if delay < 0 {
+ panic(fmt.Sprintf("delay must be non-negative, got %v", delay))
+ }
return &FixedDelayPolicy{Delay: delay}
}
@@ -72,6 +77,15 @@ type ExponentialBackoffPolicy struct {
//
// policy := NewExponentialBackoffPolicy(100*time.Millisecond, 10*time.Second)
func NewExponentialBackoffPolicy(baseDelay, maxDelay time.Duration) *ExponentialBackoffPolicy {
+ if baseDelay < 0 {
+ panic(fmt.Sprintf("baseDelay must be non-negative, got %v", baseDelay))
+ }
+ if maxDelay < 0 {
+ panic(fmt.Sprintf("maxDelay must be non-negative, got %v", maxDelay))
+ }
+ if maxDelay < baseDelay {
+ panic(fmt.Sprintf("maxDelay (%v) must be >= baseDelay (%v)", maxDelay, baseDelay))
+ }
return &ExponentialBackoffPolicy{
BaseDelay: baseDelay,
MaxDelay: maxDelay,
@@ -89,6 +103,12 @@ func NewExponentialBackoffPolicy(baseDelay, maxDelay time.Duration) *Exponential
// policy := NewExponentialBackoffPolicy(100*time.Millisecond, 10*time.Second).
// WithMultiplier(1.5) // Slower growth than default
func (p *ExponentialBackoffPolicy) WithMultiplier(multiplier float64) *ExponentialBackoffPolicy {
+ switch {
+ case multiplier <= 1.0: // also rejects -Inf; NaN compares false here and falls through
+ panic(fmt.Sprintf("multiplier must be > 1.0, got %v", multiplier))
+ case math.IsInf(multiplier, 0) || math.IsNaN(multiplier): // +Inf and NaN
+ panic(fmt.Sprintf("multiplier must be a finite number, got %v", multiplier))
+ }
p.Multiplier = multiplier
return p
}
@@ -109,20 +129,56 @@ func (p *ExponentialBackoffPolicy) WithJitter(jitter bool) *ExponentialBackoffPo
// NextDelay calculates the delay for the next retry attempt using exponential backoff.
// The delay grows exponentially with each attempt, capped at MaxDelay.
-// If jitter is enabled, the delay is randomized between delay/2 and delay.
+// If jitter is enabled, the delay is randomized between delay/2 and delay using cryptographically secure randomness.
func (p *ExponentialBackoffPolicy) NextDelay(attempt int) (time.Duration, bool) {
- delay := time.Duration(float64(p.BaseDelay) * math.Pow(p.Multiplier, float64(attempt-1)))
+ if attempt <= 0 {
+ return 0, false
+ }
- if delay > p.MaxDelay {
+ // Calculate exponential backoff with overflow protection
+ exponent := float64(attempt - 1)
+ multiplierPower := math.Pow(p.Multiplier, exponent)
+
+ // Check for overflow or infinity
+ if math.IsInf(multiplierPower, 0) || multiplierPower > float64(p.MaxDelay)/float64(p.BaseDelay) {
+ delay := p.MaxDelay
+ return p.applyJitter(delay), true
+ }
+
+ delay := time.Duration(float64(p.BaseDelay) * multiplierPower)
+
+ // Ensure we don't exceed MaxDelay
+ if delay > p.MaxDelay || delay < 0 { // negative check for overflow
delay = p.MaxDelay
}
- if p.Jitter {
- jitter := time.Duration(rand.Int63n(int64(delay)))
- delay = delay/2 + jitter
+ return p.applyJitter(delay), true
+}
+
+// applyJitter applies cryptographically secure jitter to the delay if enabled.
+// Jitter randomizes the delay between delay/2 and delay to prevent thundering herd.
+func (p *ExponentialBackoffPolicy) applyJitter(delay time.Duration) time.Duration {
+ if !p.Jitter || delay <= 0 {
+ return delay
}
- return delay, true
+ // Use crypto/rand for thread-safe, cryptographically secure randomness
+ halfDelay := delay / 2
+ maxJitter := int64(delay - halfDelay)
+
+ if maxJitter <= 0 {
+ return delay
+ }
+
+ // Generate secure random number
+ jitterBig, err := rand.Int(rand.Reader, big.NewInt(maxJitter))
+ if err != nil {
+ // Fallback to no jitter if crypto/rand fails
+ return delay
+ }
+
+ jitter := time.Duration(jitterBig.Int64())
+ return halfDelay + jitter
}
// LinearBackoffPolicy implements linear backoff
@@ -133,6 +189,18 @@ type LinearBackoffPolicy struct {
}
func NewLinearBackoffPolicy(baseDelay, increment, maxDelay time.Duration) *LinearBackoffPolicy {
+ if baseDelay < 0 {
+ panic(fmt.Sprintf("baseDelay must be non-negative, got %v", baseDelay))
+ }
+ if increment < 0 {
+ panic(fmt.Sprintf("increment must be non-negative, got %v", increment))
+ }
+ if maxDelay < 0 {
+ panic(fmt.Sprintf("maxDelay must be non-negative, got %v", maxDelay))
+ }
+ if maxDelay < baseDelay {
+ panic(fmt.Sprintf("maxDelay (%v) must be >= baseDelay (%v)", maxDelay, baseDelay))
+ }
return &LinearBackoffPolicy{
BaseDelay: baseDelay,
MaxDelay: maxDelay,
@@ -141,9 +209,20 @@ func NewLinearBackoffPolicy(baseDelay, increment, maxDelay time.Duration) *Linea
}
func (p *LinearBackoffPolicy) NextDelay(attempt int) (time.Duration, bool) {
- delay := p.BaseDelay + time.Duration(attempt-1)*p.Increment
+ if attempt <= 0 {
+ return 0, false
+ }
+
+ // Saturate before multiplying: a wrapped int64 product is not reliably negative.
+ if p.Increment > 0 && time.Duration(attempt-1) > (p.MaxDelay-p.BaseDelay)/p.Increment {
+ return p.MaxDelay, true
+ }
+ increment := time.Duration(attempt-1) * p.Increment
- if delay > p.MaxDelay {
+ delay := p.BaseDelay + increment
+
+ // Clamp anything past MaxDelay; delay < BaseDelay means the addition itself wrapped
+ if delay < p.BaseDelay || delay > p.MaxDelay {
delay = p.MaxDelay
}
@@ -161,27 +240,36 @@ func (p *NoDelayPolicy) NextDelay(attempt int) (time.Duration, bool) {
return 0, true
}
-// StopPolicy wraps another policy and stops retrying after a specified condition
+// StopPolicy wraps another policy and stops retrying after a specified condition.
+// Note: For duration-based stopping, timing is managed by the retrier, not the policy,
+// to ensure thread safety and correct timing behavior.
type StopPolicy struct {
policy RetryPolicy
maxAttempts int
maxDuration time.Duration
- startTime time.Time
}
func NewStopPolicy(policy RetryPolicy) *StopPolicy {
+ if policy == nil {
+ panic("policy cannot be nil")
+ }
return &StopPolicy{
- policy: policy,
- startTime: time.Now(),
+ policy: policy,
}
}
func (p *StopPolicy) WithMaxAttempts(attempts int) *StopPolicy {
+ if attempts < 0 {
+ panic(fmt.Sprintf("maxAttempts must be non-negative, got %d", attempts))
+ }
p.maxAttempts = attempts
return p
}
func (p *StopPolicy) WithMaxDuration(duration time.Duration) *StopPolicy {
+ if duration < 0 {
+ panic(fmt.Sprintf("maxDuration must be non-negative, got %v", duration))
+ }
p.maxDuration = duration
return p
}
@@ -191,9 +279,12 @@ func (p *StopPolicy) NextDelay(attempt int) (time.Duration, bool) {
return 0, false
}
- if p.maxDuration > 0 && time.Since(p.startTime) >= p.maxDuration {
- return 0, false
- }
-
+ // Duration checking will be handled by the retrier that maintains the start time
+ // This ensures thread safety and proper timing behavior per retry operation
return p.policy.NextDelay(attempt)
}
+
+// GetMaxDuration returns the limit set by WithMaxDuration (zero if unset) so the retrier can enforce it.
+func (p *StopPolicy) GetMaxDuration() time.Duration {
+ return p.maxDuration
+}
diff --git a/policies_test.go b/policies_test.go
index 61db8ff..9d5abc6 100644
--- a/policies_test.go
+++ b/policies_test.go
@@ -126,7 +126,12 @@ func TestStopPolicy_MaxDuration(t *testing.T) {
basePolicy := NewFixedDelayPolicy(10 * time.Millisecond)
policy := NewStopPolicy(basePolicy).WithMaxDuration(50 * time.Millisecond)
- // First attempt should work
+ // Test that policy correctly reports max duration
+ if policy.GetMaxDuration() != 50*time.Millisecond {
+ t.Errorf("Expected max duration 50ms, got %v", policy.GetMaxDuration())
+ }
+
+ // Duration checking is now handled by the retrier, so test NextDelay behavior
delay, shouldContinue := policy.NextDelay(1)
if !shouldContinue {
t.Error("Expected to continue at first attempt")
@@ -135,12 +140,13 @@ func TestStopPolicy_MaxDuration(t *testing.T) {
t.Errorf("Expected 10ms delay, got %v", delay)
}
- // Wait longer than max duration
- time.Sleep(60 * time.Millisecond)
-
- // Second attempt should be stopped
- _, shouldContinue = policy.NextDelay(2)
- if shouldContinue {
- t.Error("Expected to stop after max duration exceeded")
+ // Policy no longer tracks time internally - that's handled by retrier
+ // This ensures thread safety and proper timing per retry operation
+ delay, shouldContinue = policy.NextDelay(2)
+ if !shouldContinue {
+ t.Error("Expected policy to continue (duration checking moved to retrier)")
+ }
+ if delay != 10*time.Millisecond {
+ t.Errorf("Expected 10ms delay, got %v", delay)
}
}
diff --git a/retry.go b/retry.go
index 8e995d4..c66853d 100644
--- a/retry.go
+++ b/retry.go
@@ -110,6 +110,10 @@ func (e *OutOfRetriesError) Unwrap() error {
// WithTransientErrorFunc(customTransientFunc),
// )
func NewRetrier(policy RetryPolicy, options ...Option) *Retrier {
+ if policy == nil {
+ panic("policy cannot be nil")
+ }
+
r := &Retrier{
policy: policy,
maxAttempts: 3,
@@ -120,6 +124,10 @@ func NewRetrier(policy RetryPolicy, options ...Option) *Retrier {
opt(r)
}
+ if r.maxAttempts <= 0 {
+ panic(fmt.Sprintf("maxAttempts must be positive, got %d", r.maxAttempts))
+ }
+
return r
}
@@ -135,6 +143,9 @@ type Option func(*Retrier)
// retrier := NewRetrier(policy, WithMaxAttempts(5)) // Will try up to 5 times total
func WithMaxAttempts(attempts int) Option {
return func(r *Retrier) {
+ if attempts <= 0 {
+ panic(fmt.Sprintf("maxAttempts must be positive, got %d", attempts))
+ }
r.maxAttempts = attempts
}
}
@@ -220,8 +231,20 @@ func (r *Retrier) Do(fn func() error) error {
func (r *Retrier) DoWithContext(ctx context.Context, fn func(context.Context) error) error {
var allErrors []error
var lastErr error
+ startTime := time.Now()
+
+ // Check if policy has duration limit (for StopPolicy)
+ var maxDuration time.Duration
+ if stopPolicy, ok := r.policy.(*StopPolicy); ok {
+ maxDuration = stopPolicy.GetMaxDuration()
+ }
for attempt := 1; attempt <= r.maxAttempts; attempt++ {
+ // Check duration limit before each attempt
+ if maxDuration > 0 && time.Since(startTime) >= maxDuration {
+ break
+ }
+
err := fn(ctx)
if err == nil {
return nil
@@ -251,11 +274,18 @@ func (r *Retrier) DoWithContext(ctx context.Context, fn func(context.Context) er
break
}
- // Wait for the delay or context cancellation
+ // Check duration limit before waiting
+ if maxDuration > 0 && time.Since(startTime)+delay >= maxDuration {
+ break
+ }
+
+ // Wait for the delay or context cancellation with proper timer cleanup
+ timer := time.NewTimer(delay)
select {
case <-ctx.Done():
+ timer.Stop()
return ctx.Err()
- case <-time.After(delay):
+ case <-timer.C:
// Continue to next attempt
}
}
diff --git a/validation_test.go b/validation_test.go
new file mode 100644
index 0000000..8ed834a
--- /dev/null
+++ b/validation_test.go
@@ -0,0 +1,291 @@
+package goretry
+
+import (
+ "math"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestValidation_FixedDelayPolicy(t *testing.T) {
+ // Test negative delay panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative delay")
+ }
+ }()
+ NewFixedDelayPolicy(-1 * time.Second)
+ }()
+}
+
+func TestValidation_ExponentialBackoffPolicy(t *testing.T) {
+ // Test negative base delay panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative base delay")
+ }
+ }()
+ NewExponentialBackoffPolicy(-1*time.Second, 1*time.Second)
+ }()
+
+ // Test negative max delay panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative max delay")
+ }
+ }()
+ NewExponentialBackoffPolicy(1*time.Second, -1*time.Second)
+ }()
+
+ // Test max delay < base delay panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for max delay < base delay")
+ }
+ }()
+ NewExponentialBackoffPolicy(2*time.Second, 1*time.Second)
+ }()
+
+ // Test invalid multiplier panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for multiplier <= 1.0")
+ }
+ }()
+ policy := NewExponentialBackoffPolicy(100*time.Millisecond, 1*time.Second)
+ policy.WithMultiplier(0.5)
+ }()
+
+ // Test NaN multiplier panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for NaN multiplier")
+ }
+ }()
+ policy := NewExponentialBackoffPolicy(100*time.Millisecond, 1*time.Second)
+ policy.WithMultiplier(math.NaN())
+ }()
+}
+
+func TestValidation_LinearBackoffPolicy(t *testing.T) {
+ // Test negative base delay panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative base delay")
+ }
+ }()
+ NewLinearBackoffPolicy(-1*time.Second, 100*time.Millisecond, 1*time.Second)
+ }()
+
+ // Test negative increment panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative increment")
+ }
+ }()
+ NewLinearBackoffPolicy(100*time.Millisecond, -100*time.Millisecond, 1*time.Second)
+ }()
+
+ // Test negative max delay panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative max delay")
+ }
+ }()
+ NewLinearBackoffPolicy(100*time.Millisecond, 100*time.Millisecond, -1*time.Second)
+ }()
+
+ // Test max delay < base delay panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for max delay < base delay")
+ }
+ }()
+ NewLinearBackoffPolicy(2*time.Second, 100*time.Millisecond, 1*time.Second)
+ }()
+}
+
+func TestValidation_StopPolicy(t *testing.T) {
+ // Test nil policy panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for nil policy")
+ }
+ }()
+ NewStopPolicy(nil)
+ }()
+
+ // Test negative max attempts panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative max attempts")
+ }
+ }()
+ policy := NewStopPolicy(NewFixedDelayPolicy(100 * time.Millisecond))
+ policy.WithMaxAttempts(-1)
+ }()
+
+ // Test negative max duration panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative max duration")
+ }
+ }()
+ policy := NewStopPolicy(NewFixedDelayPolicy(100 * time.Millisecond))
+ policy.WithMaxDuration(-1 * time.Second)
+ }()
+}
+
+func TestValidation_Retrier(t *testing.T) {
+ // Test nil policy panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for nil policy")
+ }
+ }()
+ NewRetrier(nil)
+ }()
+
+ // Test zero max attempts panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for zero max attempts")
+ }
+ }()
+ policy := NewFixedDelayPolicy(100 * time.Millisecond)
+ NewRetrier(policy, WithMaxAttempts(0))
+ }()
+
+ // Test negative max attempts panics
+ func() {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("Expected panic for negative max attempts")
+ }
+ }()
+ policy := NewFixedDelayPolicy(100 * time.Millisecond)
+ NewRetrier(policy, WithMaxAttempts(-1)) // The panic happens when this option is applied
+ }()
+}
+
+func TestOverflowProtection_ExponentialBackoff(t *testing.T) {
+ policy := NewExponentialBackoffPolicy(1*time.Millisecond, 1*time.Hour).
+ WithMultiplier(10.0).
+ WithJitter(false)
+
+ // Test very large attempt numbers don't cause overflow
+ delay, shouldContinue := policy.NextDelay(1000)
+ if !shouldContinue {
+ t.Error("Expected to continue")
+ }
+ if delay != 1*time.Hour {
+ t.Errorf("Expected max delay (1 hour), got %v", delay)
+ }
+
+ // Test zero attempt returns false
+ delay, shouldContinue = policy.NextDelay(0)
+ if shouldContinue {
+ t.Error("Expected to not continue for zero attempt")
+ }
+ if delay != 0 {
+ t.Errorf("Expected zero delay for zero attempt, got %v", delay)
+ }
+}
+
+func TestOverflowProtection_LinearBackoff(t *testing.T) {
+ policy := NewLinearBackoffPolicy(1*time.Millisecond, 1*time.Hour, 2*time.Hour)
+
+ // Test very large attempt numbers don't cause overflow
+ delay, shouldContinue := policy.NextDelay(1000000)
+ if !shouldContinue {
+ t.Error("Expected to continue")
+ }
+ if delay != 2*time.Hour {
+ t.Errorf("Expected max delay (2 hours), got %v", delay)
+ }
+
+ // Test zero attempt returns false
+ delay, shouldContinue = policy.NextDelay(0)
+ if shouldContinue {
+ t.Error("Expected to not continue for zero attempt")
+ }
+ if delay != 0 {
+ t.Errorf("Expected zero delay for zero attempt, got %v", delay)
+ }
+}
+
+func TestJitterSecurity(t *testing.T) {
+ policy := NewExponentialBackoffPolicy(100*time.Millisecond, 1*time.Second).
+ WithJitter(true)
+
+ // Test that jitter produces different values (with very high probability)
+ delays := make(map[time.Duration]bool)
+ for i := 0; i < 100; i++ {
+ delay, _ := policy.NextDelay(5) // Use a consistent attempt number
+ delays[delay] = true
+ }
+
+ // With crypto/rand, we should get significant variation
+ if len(delays) < 50 { // Allow some duplicates but expect good distribution
+ t.Errorf("Expected significant jitter variation, got %d unique values out of 100", len(delays))
+ }
+
+ // Test jitter range - should be between delay/2 and delay
+ expectedBase := time.Duration(float64(100*time.Millisecond) * 16) // 2^4 for attempt 5
+ if expectedBase > 1*time.Second {
+ expectedBase = 1 * time.Second
+ }
+ minExpected := expectedBase / 2
+ maxExpected := expectedBase
+
+ for delay := range delays {
+ if delay < minExpected || delay > maxExpected {
+ t.Errorf("Jitter delay %v outside expected range [%v, %v]", delay, minExpected, maxExpected)
+ }
+ }
+}
+
+func TestStopPolicyIntegration(t *testing.T) {
+ // Verify that a StopPolicy's max duration is enforced by the retrier, not by the policy itself
+ basePolicy := NewFixedDelayPolicy(10 * time.Millisecond)
+ stopPolicy := NewStopPolicy(basePolicy).WithMaxDuration(50 * time.Millisecond)
+ retrier := NewRetrier(stopPolicy, WithMaxAttempts(10))
+
+ callCount := 0
+ start := time.Now()
+ err := retrier.Do(func() error {
+ callCount++
+ return &testError{message: "always fails", timeout: true}
+ })
+
+ elapsed := time.Since(start)
+
+ // Should stop due to duration limit, not attempt limit (10 attempts x 10ms delay would need ~90ms)
+ if elapsed < 40*time.Millisecond || elapsed > 80*time.Millisecond {
+ t.Errorf("Expected to stop around 50ms, took %v", elapsed)
+ }
+
+ if callCount >= 10 {
+ t.Errorf("Expected fewer calls due to duration limit, got %d", callCount)
+ }
+
+ // Should return OutOfRetriesError; guard nil so a missing error fails the test instead of panicking
+ if err == nil || !strings.Contains(err.Error(), "retry failed after") {
+ t.Errorf("Expected OutOfRetriesError, got %T: %v", err, err)
+ }
+}