diff --git a/README.md b/README.md index d1bf7a8..f226d23 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ import ( ) func main() { - limiter := botrate.New( + limiter, err := botrate.New( // Rate limiting for blacklisted IPs only botrate.WithLimit(rate.Every(10*time.Minute)), @@ -61,22 +61,17 @@ func main() { botrate.WithAnalyzerPageThreshold(50), botrate.WithAnalyzerQueueCap(10000), ) + if err != nil { + log.Fatalf("Failed to create limiter: %v", err) + } defer limiter.Close() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ua := r.UserAgent() ip := extractIP(r) - result := limiter.Allow(ua, ip) - if !result.Allowed { - switch result.Reason { - case "fake bot": - http.Error(w, "Bot verification failed", http.StatusForbidden) - case "rate limited": - http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) - default: - http.Error(w, "Forbidden", http.StatusForbidden) - } + if !limiter.Allow(ua, ip) { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) return } @@ -192,7 +187,10 @@ if err != nil { Gracefully shuts down the limiter and releases resources. **Always call this when the limiter is no longer needed.** ```go -limiter := botrate.New(...) +limiter, err := botrate.New(...) +if err != nil { + log.Fatalf("Failed to create limiter: %v", err) +} defer limiter.Close() ``` @@ -301,44 +299,59 @@ if !allowed { ### Strict Rate Limiting ```go -limiter := botrate.New( +limiter, err := botrate.New( botrate.WithLimit(rate.Every(10*time.Minute)), ) +if err != nil { + log.Fatalf("Failed to create limiter: %v", err) +} ``` ### Aggressive Bot Detection ```go -limiter := botrate.New( +limiter, err := botrate.New( botrate.WithAnalyzerWindow(30*time.Second), botrate.WithAnalyzerPageThreshold(20), ) +if err != nil { + log.Fatalf("Failed to create limiter: %v", err) +} ``` ### High-Throughput Configuration ```go -limiter := botrate.New( +limiter, err := botrate.New( botrate.WithAnalyzerWindow(10*time.Minute), botrate.WithAnalyzerPageThreshold(100), botrate.WithAnalyzerQueueCap(50000), ) +if err != nil { + log.Fatalf("Failed to create limiter: %v", err) +} ``` ### Custom KnownBots Validator ```go // Create custom validator with specific configuration -customKB := knownbots.New( +customKB, err := knownbots.New( knownbots.WithRoot("./custom-bots"), knownbots.WithSchedulerInterval(12*time.Hour), ) +if err != nil { + log.Fatalf("Failed to create validator: %v", err) +} // Use custom validator -limiter := botrate.New( +limiter, err := botrate.New( botrate.WithKnownbots(customKB), botrate.WithLimit(rate.Every(5*time.Minute)), ) +if err != nil { + log.Fatalf("Failed to create limiter: %v", err) +} ``` ## Architecture diff --git a/botrate_test.go b/botrate_test.go index 5db3f03..434ce84 100644 --- a/botrate_test.go +++ b/botrate_test.go @@ -10,7 +10,11 @@ import ( ) func TestLimiter_New(t *testing.T) { - l := New() + l, err := New() + + if err != nil { + t.Fatalf("New() returned error: %v", err) + } if l == nil { t.Fatal("New() returned nil") @@ -32,12 +36,16 @@ func TestLimiter_New(t *testing.T) { } func TestLimiter_New_WithOptions(t *testing.T) { - l := New( + l, err := New( WithLimit(rate.Every(time.Second)), WithAnalyzerWindow(time.Minute), WithAnalyzerPageThreshold(100), WithAnalyzerQueueCap(5000), ) + + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() if l.cfg.Limit != rate.Every(time.Second) { @@ -58,7 +66,10 @@ func TestLimiter_New_WithOptions(t *testing.T) { } func TestLimiter_Allow_VerifiedBot(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() result := l.Allow("Googlebot/2.1", "66.249.66.1") @@ -66,23 +77,29 @@ func TestLimiter_Allow_VerifiedBot(t *testing.T) { } func TestLimiter_Wait_VerifiedBot(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() - err := l.Wait(context.Background(), "Googlebot/2.1", "66.249.66.1") + err = l.Wait(context.Background(), "Googlebot/2.1", "66.249.66.1") _ = err } func TestLimiter_Wait_ContextCanceled(t *testing.T) { - l := New( + l, err := New( WithLimit(rate.Every(time.Hour)), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() ctx, cancel := context.WithCancel(context.Background()) cancel() - err := l.Wait(ctx, "Mozilla/5.0", "192.168.1.1") + err = l.Wait(ctx, "Mozilla/5.0", "192.168.1.1") if err != nil && err != context.Canceled && err != ErrLimit { t.Errorf("expected nil, context.Canceled, or ErrLimit, got %v", err) @@ -90,10 +107,13 @@ func TestLimiter_Wait_ContextCanceled(t *testing.T) { } func TestLimiter_Allow_NormalUser(t *testing.T) { - l := New( + l, err := New( WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(1000), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() allowed := l.Allow("Mozilla/5.0", "192.168.1.1") @@ -104,7 +124,10 @@ func TestLimiter_Allow_NormalUser(t *testing.T) { } func TestLimiter_Allow_BotLike(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() allowed := l.Allow("Python-urllib/3.11", "192.168.1.1") @@ -112,10 +135,13 @@ func TestLimiter_Allow_BotLike(t *testing.T) { } func TestLimiter_Allow_BlacklistedIP(t *testing.T) { - l := New( + l, err := New( WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(1), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() allowed := l.Allow("Mozilla/5.0", "192.168.1.1") @@ -130,13 +156,16 @@ func TestLimiter_Allow_BlacklistedIP(t *testing.T) { } func TestLimiter_Wait_NormalUser(t *testing.T) { - l := New( + l, err := New( WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(1000), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() - err := l.Wait(context.Background(), "Mozilla/5.0", "192.168.1.1") + err = l.Wait(context.Background(), "Mozilla/5.0", "192.168.1.1") if err != nil { t.Errorf("normal user should not return error, got %v", err) @@ -144,9 +173,12 @@ func TestLimiter_Wait_NormalUser(t *testing.T) { } func TestLimiter_Wait_BotLike(t *testing.T) { - l := New( + l, err := New( WithLimit(rate.Every(time.Hour)), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) @@ -156,17 +188,23 @@ func TestLimiter_Wait_BotLike(t *testing.T) { } func TestLimiter_Close(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } l.Close() l.Close() } func TestLimiter_Allow_ManyRequests(t *testing.T) { - l := New( + l, err := New( WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(10000), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() for i := 0; i < 1000; i++ { @@ -180,7 +218,10 @@ func TestLimiter_Allow_ManyRequests(t *testing.T) { } func TestLimiter_Allow_IPv6(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() if !l.Allow("Mozilla/5.0", "2001:0db8:85a3:0000:0000:8a2e:0370:7334") { @@ -189,7 +230,10 @@ func TestLimiter_Allow_IPv6(t *testing.T) { } func TestLimiter_Allow_EmptyUserAgent(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() if !l.Allow("", "192.168.1.1") { @@ -198,7 +242,10 @@ func TestLimiter_Allow_EmptyUserAgent(t *testing.T) { } func TestLimiter_Allow_EmptyIP(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() if !l.Allow("Mozilla/5.0", "") { @@ -207,10 +254,16 @@ func TestLimiter_Allow_EmptyIP(t *testing.T) { } func TestLimiter_WithKnownbots(t *testing.T) { - l1 := New() + l1, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l1.Close() - l2 := New(WithKnownbots(nil)) + l2, err := New(WithKnownbots(nil)) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l2.Close() _ = l1.Allow("Googlebot/2.1", "66.249.66.1") @@ -218,11 +271,14 @@ func TestLimiter_WithKnownbots(t *testing.T) { } func TestLimiter_RateLimitPersistence(t *testing.T) { - l := New( + l, err := New( WithLimit(rate.Every(time.Hour)), WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(10000), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() _ = l.Allow("Python-urllib/3.11", "192.168.1.1") @@ -230,7 +286,10 @@ func TestLimiter_RateLimitPersistence(t *testing.T) { } func TestLimiter_DifferentBots(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() bots := []string{ @@ -264,10 +323,13 @@ func TestLimiter_BotScenarios(t *testing.T) { }, } - l := New( + l, err := New( WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(10000), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() for _, tc := range testCases { @@ -279,7 +341,10 @@ func TestLimiter_BotScenarios(t *testing.T) { } func TestLimiter_InvalidIPFormat(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() invalidIPs := []string{ @@ -295,7 +360,10 @@ func TestLimiter_InvalidIPFormat(t *testing.T) { } func TestLimiter_LongUserAgent(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() longUA := strings.Repeat("Mozilla/5.0 ", 1000) @@ -306,7 +374,10 @@ func TestLimiter_LongUserAgent(t *testing.T) { } func TestLimiter_LongPath(t *testing.T) { - l := New() + l, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() longPath := "/" + strings.Repeat("a", 10000) @@ -318,10 +389,13 @@ func TestLimiter_LongPath(t *testing.T) { } func TestLimiter_ConcurrentAccess(t *testing.T) { - l := New( + l, err := New( WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(10000), ) + if err != nil { + t.Fatalf("New() returned error: %v", err) + } defer l.Close() done := make(chan bool) @@ -342,7 +416,10 @@ func TestLimiter_ConcurrentAccess(t *testing.T) { } func BenchmarkLimiter_Allow_VerifiedBot(b *testing.B) { - l := New() + l, err := New() + if err != nil { + b.Fatalf("New() returned error: %v", err) + } defer l.Close() b.ResetTimer() @@ -354,10 +431,13 @@ func BenchmarkLimiter_Allow_VerifiedBot(b *testing.B) { } func BenchmarkLimiter_Allow_NormalUser(b *testing.B) { - l := New( + l, err := New( WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(10000), ) + if err != nil { + b.Fatalf("New() returned error: %v", err) + } defer l.Close() b.ResetTimer() @@ -369,10 +449,13 @@ func BenchmarkLimiter_Allow_NormalUser(b *testing.B) { } func BenchmarkLimiter_Allow_BlacklistedIP(b *testing.B) { - l := New( + l, err := New( WithAnalyzerWindow(time.Hour), WithAnalyzerPageThreshold(10000), ) + if err != nil { + b.Fatalf("New() returned error: %v", err) + } defer l.Close() l.Allow("Mozilla/5.0", "192.168.1.1") @@ -387,9 +470,12 @@ func BenchmarkLimiter_Allow_BlacklistedIP(b *testing.B) { } func BenchmarkLimiter_Wait(b *testing.B) { - l := New( + l, err := New( WithLimit(rate.Every(time.Hour)), ) + if err != nil { + b.Fatalf("New() returned error: %v", err) + } defer l.Close() ctx := context.Background() @@ -406,7 +492,10 @@ func BenchmarkLimiter_Close(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - l := New() + l, err := New() + if err != nil { + b.Fatalf("New() returned error: %v", err) + } l.Close() } } diff --git a/example/main.go b/example/main.go index bb6346c..449c33b 100644 --- a/example/main.go +++ b/example/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "log" "net/http" "strings" "time" @@ -11,12 +12,16 @@ import ( ) func main() { - limiter := botrate.New( + limiter, err := botrate.New( botrate.WithLimit(rate.Every(10*time.Minute)), botrate.WithAnalyzerWindow(time.Minute), botrate.WithAnalyzerPageThreshold(50), botrate.WithAnalyzerQueueCap(10000), ) + if err != nil { + log.Fatalf("Failed to create limiter: %v", err) + } + defer limiter.Close() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ua := r.UserAgent() diff --git a/limiter.go b/limiter.go index ee737f7..30d47c6 100644 --- a/limiter.go +++ b/limiter.go @@ -33,7 +33,7 @@ type Limiter struct { } // New creates a new rate limiter with default config and applies options. -func New(opts ...Option) *Limiter { +func New(opts ...Option) (*Limiter, error) { l := &Limiter{ cfg: Config{ Limit: DefaultLimit, @@ -50,7 +50,7 @@ func New(opts ...Option) *Limiter { if l.kb == nil { kb, err := knownbots.New() if err != nil { - panic(err) + return nil, err } l.kb = kb } @@ -61,7 +61,7 @@ func New(opts ...Option) *Limiter { QueueCap: l.cfg.QueueCap, }) - return l + return l, nil } // Allow reports whether the request should proceed.