diff --git a/netpool.go b/netpool.go
index 47c06b2..87c2ea6 100644
--- a/netpool.go
+++ b/netpool.go
@@ -238,7 +238,6 @@ func (netpool *Netpool) getWithContext(ctx context.Context) (net.Conn, error) {
 			return nil, err
 		}
 
-		// If we have an idle conn, use it.
 		if netpool.connections.Len() > 0 {
 			entry := netpool.connections.Remove(netpool.connections.Front()).(*connEntry)
 			entry.lastUsed = time.Now()
@@ -246,13 +245,11 @@ func (netpool *Netpool) getWithContext(ctx context.Context) (net.Conn, error) {
 			return entry.conn, nil
 		}
 
-		// If we can create a new conn, do it.
 		if len(netpool.allConns) < int(netpool.config.MaxPool) {
 			c, err := netpool.createConnection()
 			if err != nil {
 				return nil, err
 			}
-
 			entry := &connEntry{
 				conn:     c,
 				lastUsed: time.Now(),
@@ -262,10 +259,23 @@ func (netpool *Netpool) getWithContext(ctx context.Context) (net.Conn, error) {
 			return c, nil
 		}
 
-		// Otherwise wait for a conn to be returned or pool closed.
-		// Note: sync.Cond can't be context-canceled without extra plumbing,
-		// but we re-check ctx.Err() on each wakeup.
+		// sync.Cond has no built-in cancellation, so spawn a helper that
+		// broadcasts when ctx is done; every waiter wakes up and re-checks
+		// ctx.Err() at the top of the loop.
+		done := make(chan struct{})
+
+		go func() {
+			select {
+			case <-ctx.Done():
+				netpool.mu.Lock()
+				netpool.cond.Broadcast()
+				netpool.mu.Unlock()
+			case <-done:
+			}
+		}()
+
 		netpool.cond.Wait()
+		close(done) // stop the helper goroutine once this waiter is awake
 	}
 }
 
diff --git a/netpool_test.go b/netpool_test.go
index 57efd7d..c3f2c5c 100644
--- a/netpool_test.go
+++ b/netpool_test.go
@@ -1,7 +1,9 @@
 package netpool
 
 import (
+	"context"
 	"errors"
+	"fmt"
 	"net"
 	"sync"
 	"testing"
@@ -693,6 +695,476 @@ func BenchmarkPoolGet(b *testing.B) {
 	})
 }
 
+func TestContextCancelDuringWait(t *testing.T) {
+	listener, addr := createTestServer(t)
+	defer listener.Close()
+
+	pool, err := New(func() (net.Conn, error) {
+		return net.Dial("tcp", addr)
+	}, WithMinPool(0), WithMaxPool(1))
+
+	if err != nil {
+		t.Fatalf("failed to create pool: %v", err)
+	}
+	defer pool.Close()
+
+	// Exhaust the pool
+	conn1, err := pool.Get()
+	if err != nil {
+		t.Fatalf("failed to get first connection: %v", err)
+	}
+	defer conn1.Close()
+
+	// Verify pool is exhausted
+	stats := pool.Stats()
+	if stats.InUse != 1 || stats.Idle != 0 {
+		t.Errorf("pool should be exhausted: InUse=%d, Idle=%d", stats.InUse, stats.Idle)
+	}
+
+	// Try to get with timeout context
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	start := time.Now()
+	_, err = pool.GetWithContext(ctx)
+	elapsed := time.Since(start)
+
+	if err != context.DeadlineExceeded {
+		t.Errorf("expected DeadlineExceeded, got %v", err)
+	}
+
+	// Should timeout around 100ms, not instantly
+	if elapsed < 50*time.Millisecond {
+		t.Errorf("timeout happened too quickly: %v", elapsed)
+	}
+	if elapsed > 200*time.Millisecond {
+		t.Errorf("timeout took too long: %v", elapsed)
+	}
+
+	t.Logf("Context cancellation detected after %v", elapsed)
+}
+
+func TestCloseWithWaiters(t *testing.T) {
+	listener, addr := createTestServer(t)
+	defer listener.Close()
+
+	pool, err := New(func() (net.Conn, error) {
+		return net.Dial("tcp", addr)
+	}, WithMinPool(0), WithMaxPool(1))
+
+	if err != nil {
+		t.Fatalf("failed to create pool: %v", err)
+	}
+
+	// Exhaust the pool
+	conn1, err := pool.Get()
+	if err != nil {
+		t.Fatalf("failed to get connection: %v", err)
+	}
+
+	// Start goroutine that will block waiting for a connection
+	done := make(chan error, 1)
+	started := make(chan struct{})
+
+	go func() {
+		close(started) // Signal that goroutine started
+		_, err := pool.Get()
+		done <- err
+	}()
+
+	// Wait for goroutine to start and begin waiting
+	<-started
+	time.Sleep(50 * time.Millisecond)
+
+	// Close pool while goroutine is waiting
+	pool.Close()
+
+	// The waiting goroutine should unblock with ErrPoolClosed
+	select {
+	case err := <-done:
+		if err != ErrPoolClosed {
+			t.Errorf("expected ErrPoolClosed, got %v", err)
+		}
+		t.Logf("Waiter properly unblocked with: %v", err)
+	case <-time.After(1 * time.Second):
+		t.Fatal("waiting goroutine did not unblock after pool.Close()")
+	}
+
+	// Verify connection is closed
+	_, err = conn1.Write([]byte("test"))
+	if err == nil {
+		t.Error("connection should be closed after pool.Close()")
+	}
+
+	// Subsequent Get should return ErrPoolClosed
+	_, err = pool.Get()
+	if err != ErrPoolClosed {
+		t.Errorf("Get after Close should return ErrPoolClosed, got %v", err)
+	}
+}
+
+func TestStatsConsistency(t *testing.T) {
+	listener, addr := createTestServer(t)
+	defer listener.Close()
+
+	pool, err := New(func() (net.Conn, error) {
+		return net.Dial("tcp", addr)
+	}, WithMinPool(2), WithMaxPool(10))
+
+	if err != nil {
+		t.Fatalf("failed to create pool: %v", err)
+	}
+	defer pool.Close()
+
+	var wg sync.WaitGroup
+	errCh := make(chan string, 4000) // worst case: 10 goroutines x 100 iterations x 4 checks, so sends never block
+
+	// Goroutines that continuously check stats consistency
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+			for j := 0; j < 100; j++ {
+				stats := pool.Stats()
+
+				// Active should always equal Idle + InUse
+				if stats.Active != stats.Idle+stats.InUse {
+					errCh <- fmt.Sprintf("goroutine %d: inconsistent stats: Active=%d, Idle=%d, InUse=%d",
+						id, stats.Active, stats.Idle, stats.InUse)
+				}
+
+				// Active should never exceed MaxPool
+				if stats.Active > int(stats.MaxPool) {
+					errCh <- fmt.Sprintf("goroutine %d: Active(%d) > MaxPool(%d)",
+						id, stats.Active, stats.MaxPool)
+				}
+
+				// Idle should never exceed Active
+				if stats.Idle > stats.Active {
+					errCh <- fmt.Sprintf("goroutine %d: Idle(%d) > Active(%d)",
+						id, stats.Idle, stats.Active)
+				}
+
+				// InUse should never exceed Active
+				if stats.InUse > stats.Active {
+					errCh <- fmt.Sprintf("goroutine %d: InUse(%d) > Active(%d)",
+						id, stats.InUse, stats.Active)
+				}
+
+				time.Sleep(1 * time.Millisecond)
+			}
+		}(i)
+	}
+
+	// Goroutines that Get and Put connections
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for j := 0; j < 50; j++ {
+				conn, err := pool.Get()
+				if err != nil {
+					continue
+				}
+
+				// Use the connection briefly
+				time.Sleep(5 * time.Millisecond)
+
+				// Mark every tenth connection unusable
+				if j%10 == 0 {
+					if pc, ok := conn.(interface{ MarkUnusable() error }); ok {
+						pc.MarkUnusable()
+					}
+				} else {
+					conn.Close()
+				}
+			}
+		}()
+	}
+
+	wg.Wait()
+	close(errCh)
+
+	// Check for any consistency errors
+	errorCount := 0
+	for errMsg := range errCh {
+		t.Error(errMsg)
+		errorCount++
+		if errorCount >= 10 {
+			t.Log("... (more errors suppressed)")
+			break
+		}
+	}
+
+	if errorCount > 0 {
+		t.Fatalf("found %d stats consistency violations", errorCount)
+	}
+
+	// Final verification
+	finalStats := pool.Stats()
+	t.Logf("Final stats: Active=%d, Idle=%d, InUse=%d, MaxPool=%d, MinPool=%d",
+		finalStats.Active, finalStats.Idle, finalStats.InUse, finalStats.MaxPool, finalStats.MinPool)
+
+	if finalStats.Active != finalStats.Idle+finalStats.InUse {
+		t.Errorf("final stats inconsistent: Active=%d != Idle(%d) + InUse(%d)",
+			finalStats.Active, finalStats.Idle, finalStats.InUse)
+	}
+}
+
+func TestMultipleWaitersUnblockOnClose(t *testing.T) {
+	listener, addr := createTestServer(t)
+	defer listener.Close()
+
+	pool, err := New(func() (net.Conn, error) {
+		return net.Dial("tcp", addr)
+	}, WithMinPool(0), WithMaxPool(1))
+
+	if err != nil {
+		t.Fatalf("failed to create pool: %v", err)
+	}
+
+	// Exhaust the pool
+	conn1, err := pool.Get()
+	if err != nil {
+		t.Fatalf("failed to get connection: %v", err)
+	}
+
+	// Start multiple goroutines waiting for connections
+	numWaiters := 5
+	done := make(chan error, numWaiters)
+
+	for i := 0; i < numWaiters; i++ {
+		go func() {
+			_, err := pool.Get()
+			done <- err
+		}()
+	}
+
+	// Give them time to start waiting
+	time.Sleep(100 * time.Millisecond)
+
+	// Close the pool
+	pool.Close()
+
+	// All waiters should unblock with ErrPoolClosed
+	timeout := time.After(2 * time.Second)
+	for i := 0; i < numWaiters; i++ {
+		select {
+		case err := <-done:
+			if err != ErrPoolClosed {
+				t.Errorf("waiter %d: expected ErrPoolClosed, got %v", i, err)
+			}
+		case <-timeout:
+			t.Fatalf("only %d/%d waiters unblocked", i, numWaiters)
+		}
+	}
+
+	t.Logf("All %d waiters properly unblocked", numWaiters)
+
+	// Cleanup
+	conn1.Close()
+}
+
+func TestContextCancelWithMultipleWaiters(t *testing.T) {
+	listener, addr := createTestServer(t)
+	defer listener.Close()
+
+	pool, err := New(func() (net.Conn, error) {
+		return net.Dial("tcp", addr)
+	}, WithMinPool(0), WithMaxPool(2))
+
+	if err != nil {
+		t.Fatalf("failed to create pool: %v", err)
+	}
+	defer pool.Close()
+
+	// Exhaust the pool
+	conn1, _ := pool.Get()
+	conn2, _ := pool.Get()
+	defer conn1.Close()
+	defer conn2.Close()
+
+	// Start 3 goroutines with different context timeouts
+	type result struct {
+		id      int
+		err     error
+		elapsed time.Duration
+	}
+
+	results := make(chan result, 3)
+
+	// Waiter 1: 50ms timeout
+	go func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+		start := time.Now()
+		_, err := pool.GetWithContext(ctx)
+		results <- result{1, err, time.Since(start)}
+	}()
+
+	// Waiter 2: 100ms timeout
+	go func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+		defer cancel()
+		start := time.Now()
+		_, err := pool.GetWithContext(ctx)
+		results <- result{2, err, time.Since(start)}
+	}()
+
+	// Waiter 3: 150ms timeout
+	go func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
+		defer cancel()
+		start := time.Now()
+		_, err := pool.GetWithContext(ctx)
+		results <- result{3, err, time.Since(start)}
+	}()
+
+	// Collect results
+	for i := 0; i < 3; i++ {
+		r := <-results
+		if r.err != context.DeadlineExceeded {
+			t.Errorf("waiter %d: expected DeadlineExceeded, got %v", r.id, r.err)
+		}
+		t.Logf("Waiter %d timed out after %v", r.id, r.elapsed)
+	}
+}
+
+func TestDoubleCloseConnection(t *testing.T) {
+	listener, addr := createTestServer(t)
+	defer listener.Close()
+
+	pool, err := New(func() (net.Conn, error) {
+		return net.Dial("tcp", addr)
+	}, WithMinPool(1), WithMaxPool(5))
+
+	if err != nil {
+		t.Fatalf("failed to create pool: %v", err)
+	}
+	defer pool.Close()
+
+	conn, err := pool.Get()
+	if err != nil {
+		t.Fatalf("failed to get connection: %v", err)
+	}
+
+	// First close - should return to pool
+	err = conn.Close()
+	if err != nil {
+		t.Errorf("first Close() failed: %v", err)
+	}
+
+	time.Sleep(50 * time.Millisecond)
+
+	statsAfterFirstClose := pool.Stats()
+	t.Logf("After first close: Active=%d, Idle=%d, InUse=%d",
+		statsAfterFirstClose.Active, statsAfterFirstClose.Idle, statsAfterFirstClose.InUse)
+
+	// Second close - should be a no-op or return error
+	err = conn.Close()
+	if err != nil {
+		t.Logf("second Close() returned expected error: %v", err)
+	}
+
+	time.Sleep(50 * time.Millisecond)
+
+	statsAfterSecondClose := pool.Stats()
+	t.Logf("After second close: Active=%d, Idle=%d, InUse=%d",
+		statsAfterSecondClose.Active, statsAfterSecondClose.Idle, statsAfterSecondClose.InUse)
+
+	// Stats should not change after second close
+	if statsAfterFirstClose.Active != statsAfterSecondClose.Active {
+		t.Errorf("stats changed after double close: %+v -> %+v",
+			statsAfterFirstClose, statsAfterSecondClose)
+	}
+
+	// Using connection after double close should fail
+	_, err = conn.Write([]byte("test"))
+	if err == nil {
+		t.Error("Write should fail on double-closed connection")
+	}
+}
+
+func TestExtremeContention(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping extreme contention test in short mode")
+	}
+
+	listener, addr := createTestServer(t)
+	defer listener.Close()
+
+	pool, err := New(func() (net.Conn, error) {
+		return net.Dial("tcp", addr)
+	}, WithMinPool(1), WithMaxPool(5))
+
+	if err != nil {
+		t.Fatalf("failed to create pool: %v", err)
+	}
+	defer pool.Close()
+
+	numGoroutines := 100
+	iterationsPerGoroutine := 100
+
+	var wg sync.WaitGroup
+	errCh := make(chan error, numGoroutines*iterationsPerGoroutine)
+
+	start := time.Now()
+
+	for i := 0; i < numGoroutines; i++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+
+			for j := 0; j < iterationsPerGoroutine; j++ {
+				conn, err := pool.Get()
+				if err != nil {
+					errCh <- fmt.Errorf("goroutine %d: Get failed: %w", id, err)
+					continue
+				}
+
+				// Simulate very brief use
+				time.Sleep(1 * time.Millisecond)
+
+				conn.Close()
+			}
+		}(i)
+	}
+
+	wg.Wait()
+	elapsed := time.Since(start)
+	close(errCh)
+
+	errorCount := 0
+	for err := range errCh {
+		t.Logf("Error: %v", err)
+		errorCount++
+		if errorCount >= 10 {
+			t.Log("... (more errors suppressed)")
+			break
+		}
+	}
+
+	if errorCount > 0 {
+		t.Errorf("got %d errors during extreme contention", errorCount)
+	}
+
+	stats := pool.Stats()
+	t.Logf("Extreme contention test completed in %v", elapsed)
+	t.Logf("Final stats: Active=%d, Idle=%d, InUse=%d",
+		stats.Active, stats.Idle, stats.InUse)
+	t.Logf("Total operations: %d", numGoroutines*iterationsPerGoroutine)
+	t.Logf("Operations/sec: %.0f", float64(numGoroutines*iterationsPerGoroutine)/elapsed.Seconds())
+
+	// Pool should be stable
+	if stats.Active > int(stats.MaxPool) {
+		t.Errorf("pool exceeded MaxPool: Active=%d, MaxPool=%d", stats.Active, stats.MaxPool)
+	}
+
+	if stats.InUse != 0 {
+		t.Errorf("expected all connections returned: InUse=%d", stats.InUse)
+	}
+}
+
 func BenchmarkPoolConcurrent(b *testing.B) {
 	listener, addr := createBenchmarkTestServer(b)
 	defer listener.Close()