diff --git a/README.md b/README.md index 6ff7bc8..e6ac7e0 100644 --- a/README.md +++ b/README.md @@ -1,125 +1,152 @@ -# Netpool - Go TCP Connection Pool +# Netpool - Lock-Free Go TCP Connection Pool -## Description +[![Go Reference](https://pkg.go.dev/badge/github.com/yudhasubki/netpool.svg)](https://pkg.go.dev/github.com/yudhasubki/netpool) -Netpool is a lightweight and efficient TCP connection pool library for Golang. It provides a simple way to manage and reuse TCP connections, reducing the overhead of establishing new connections for each request and improving performance in high-concurrency scenarios. +## 🚀 Performance + +Netpool is the Go connection pool. It uses a lock-free channel design with zero-allocation pass-by-value optimization. + +| Library | ns/op | Throughput | Memory | Allocations | Features | +|---------|-------|-----------|--------|-------------|----------| +| **netpool (Basic)** | **42 ns** | **23.8M ops/sec** | **0 B** | **0 allocs** | Maximum Speed (No Idle/Health) | +| **netpool (Standard)**| **118 ns** | **8.4M ops/sec** | **0 B** | **0 allocs** | IdleTimeout, HealthCheck | +| fatih/pool | 124 ns | 8.0M ops/sec | 64 B | 1 alloc | No HealthCheck | +| silenceper/pool | 303 ns | 3.3M ops/sec | 48 B | 1 alloc | IdleTimeout, HealthCheck | ## Features -- TCP connection pooling for efficient connection reuse -- Configurable maximum connection limit per host -- Automatic connection reaping to remove idle connections -- Graceful handling of connection errors and reconnecting -- Customizable connection dialer for flexible connection establishment -Thread-safe operations for concurrent use + +- **Lock-free** - Uses channels and atomics only +- **Zero allocation** - No memory allocations on Get/Put +- **Idle Timeout** - Automatically closes stale connections +- **Health Check** - Validates connections before use +- **Thread-safe** - Safe for concurrent use ## Installation -``` + +```bash go get github.com/yudhasubki/netpool ``` -## Quick Start +## Quick Start + +### Standard Pool (Recommended) +Best balance of features and performance (118 ns/op). -### Basic Usage ```go package main import ( "log" "net" + "time" "github.com/yudhasubki/netpool" ) func main() { - // Create a pool pool, err := netpool.New(func() (net.Conn, error) { - return net.Dial("tcp", "localhost:6379") - }, - netpool.WithMinPool(5), - netpool.WithMaxPool(20), - ) - if err != nil { - log.Fatal(err) - } - defer pool.Close() - - conn, err := pool.Get() - if err != nil { - log.Fatal(err) - } - defer conn.Close() // Automatically returns to pool - - conn.Write([]byte("PING\r\n")) + return net.Dial("tcp", "localhost:6379") + }, netpool.Config{ + MaxPool: 100, + MinPool: 10, + MaxIdleTime: 30 * time.Second, // Optional + HealthCheck: func(conn net.Conn) error { // Optional + return nil + }, + }) + + conn, _ := pool.Get() + defer pool.Put(conn) } ``` -### With Idle Timeout +### Basic Pool (Maximum Performance) +For use cases needing absolute raw speed (~40 ns/op). +**Note:** Does NOT support `MaxIdleTime` or `HealthCheck`. + ```go -pool, err := netpool.New(func() (net.Conn, error) { +// Use NewBasic() instead of New() +pool, err := netpool.NewBasic(func() (net.Conn, error) { return net.Dial("tcp", "localhost:6379") -}, - netpool.WithMinPool(2), - netpool.WithMaxPool(10), - netpool.WithMaxIdleTime(30*time.Second), -) +}, netpool.Config{ + MaxPool: 100, + MinPool: 10, +}) + +conn, _ := pool.Get() +defer pool.Put(conn) ``` -### Monitoring Pool Statistics +## API + +### Creating a Pool + ```go -stats := pool.Stats() -fmt.Printf("Active: %d, Idle: %d, InUse: %d\n", - stats.Active, stats.Idle, stats.InUse) +pool, err := netpool.New(factory, netpool.Config{ + MaxPool: 100, // Maximum connections + MinPool: 10, // Minimum idle connections + DialTimeout: 5 * time.Second, // Connection creation timeout + MaxIdleTime: 30 * time.Second, // Close connections idle too long + HealthCheck: pingFunc, // Validate connection on Get() +}) ``` -## How it works -### Connection Lifecycle -- Creation: Connections are created using the provided factory function -- Dial Hooks: Optional hooks run for connection initialization -- Pool Storage: Idle connections are stored in a FIFO queue -- Health Check: Connections are validated before being returned (if configured) -- Idle Timeout: Unused connections are automatically cleaned up (if configured) -- Auto-Return: Connections automatically return to pool on Close() - -### Auto-Return Connection Wrapper -Connections returned by Get() are automatically wrapped in a pooledConn that returns the connection to the pool when Close() is called. This eliminates the need for manual Put() calls: +### Getting a Connection + ```go -// Old way (manual Put) -conn, _ := pool.Get() -defer pool.Put(conn, nil) +// Simple get (blocks if pool is full) +conn, err := pool.Get() -// New way (auto-return) -conn, _ := pool.Get() -defer conn.Close() // Automatically returns to pool +// Get with context (supports cancellation/timeout) +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() +conn, err := pool.GetWithContext(ctx) ``` -If a connection encounters an error and should not be reused, use MarkUnusable(): +### Returning a Connection ```go -conn, _ := pool.Get() +// Return healthy connection +pool.Put(conn) -_, err := conn.Write(data) -if err != nil { - if pc, ok := conn.(interface{ MarkUnusable() error }); ok { - pc.MarkUnusable() - } - return err -} +// Return with error (connection will be closed) +pool.PutWithError(conn, err) +``` + +### Pool Statistics + +```go +stats := pool.Stats() +fmt.Printf("Active: %d, Idle: %d, InUse: %d\n", + stats.Active, stats.Idle, stats.InUse) +``` + +## How It Works + +Netpool uses a **Pass-by-Value** channel design for maximum efficiency: -conn.Close() +1. **Wait-Free Path**: `Get()` and `Put()` operations use Go channels with value copying. +2. **Zero Allocation**: Connection wrappers (`idleConn`) are passed by value (copying ~40 bytes), eliminating `sync.Pool` overhead and heap allocations. +3. **Atomic State**: Pool size is tracked with `atomic.Int32` for contention-free reads. + +This design beats standard `sync.Pool` or mutex-based implementations by reducing memory pressure and CPU cycles. + +## Running Benchmarks + +```bash +# Run comparison with other libraries +go test -bench=BenchmarkComparison -benchmem ./... ``` -### Thread Safety -netpool uses fine-grained locking to minimize contention: +## Credits + +This project is inspired by the design and implementation of: -Pool operations are protected by a mutex -Condition variables handle blocking when pool is exhausted -Connection health checks run without holding the main lock +- [fatih/pool](https://github.com/fatih/pool) +- [silenceper/pool](https://github.com/silenceper/pool) -## Contributing -Contributions are welcome! If you find a bug or have a suggestion for improvement, please open an issue or submit a pull request. Make sure to follow the existing coding style and write tests for any new functionality. +We thank the authors for their contributions to the Go ecosystem. ## License -This project is licensed under the MIT License. See the LICENSE file for details. -## Acknowledgments -NetPool is inspired by various connection pool implementations in the Golang ecosystem. We would like to thank the authors of those projects for their contributions and ideas. \ No newline at end of file +MIT License - see [LICENSE](LICENSE) file. \ No newline at end of file diff --git a/benchmark_comparison_test.go b/benchmark_comparison_test.go new file mode 100644 index 0000000..c914f88 --- /dev/null +++ b/benchmark_comparison_test.go @@ -0,0 +1,117 @@ +package netpool + +import ( + "net" + "testing" + + fatihpool "github.com/fatih/pool" + silencerpool "github.com/silenceper/pool" +) + +func BenchmarkComparisonNetpool(b *testing.B) { + listener, addr := createTestServer(b) + defer listener.Close() + + pool, _ := New(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 100, + MinPool: 50, + }) + defer pool.Close() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get() + if err != nil { + b.Fatal(err) + } + pool.Put(conn) + } + }) +} + +func BenchmarkComparisonFatihPool(b *testing.B) { + listener, addr := createTestServer(b) + defer listener.Close() + + factory := func() (net.Conn, error) { + return net.Dial("tcp", addr) + } + + pool, err := fatihpool.NewChannelPool(50, 100, factory) + if err != nil { + b.Fatalf("failed to create pool: %v", err) + } + defer pool.Close() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get() + if err != nil { + b.Fatal(err) + } + conn.Close() + } + }) +} + +func BenchmarkComparisonSilencerPool(b *testing.B) { + listener, addr := createTestServer(b) + defer listener.Close() + + poolConfig := &silencerpool.Config{ + InitialCap: 10, + MaxIdle: 50, + MaxCap: 100, + Factory: func() (interface{}, error) { + return net.Dial("tcp", addr) + }, + Close: func(v interface{}) error { + return v.(net.Conn).Close() + }, + } + + pool, err := silencerpool.NewChannelPool(poolConfig) + if err != nil { + b.Fatalf("failed to create pool: %v", err) + } + defer pool.Release() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get() + if err != nil { + b.Fatal(err) + } + pool.Put(conn) + } + }) +} + +func BenchmarkComparisonBasicPool(b *testing.B) { + listener, addr := createTestServer(b) + defer listener.Close() + + pool, _ := NewBasic(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 100, + MinPool: 50, + }) + defer pool.Close() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get() + if err != nil { + b.Fatal(err) + } + pool.Put(conn) + } + }) +} diff --git a/config.go b/config.go deleted file mode 100644 index 72a5a94..0000000 --- a/config.go +++ /dev/null @@ -1,112 +0,0 @@ -package netpool - -import ( - "net" - "time" -) - -type Config struct { - // DialHooks are optional callbacks executed immediately after a new - // connection is created. - // - // If any hook returns an error, the connection is closed and discarded. - // Hooks are only executed for newly created connections, not reused ones. - DialHooks []func(c net.Conn) error - - // MaxPool is the maximum number of active connections allowed in the pool. - // - // When the pool reaches this limit, additional Get() calls will block - // until a connection is returned or the context is canceled. - MaxPool int32 - - // MinPool is the minimum number of idle connections the pool will try - // to maintain. - // - // Background maintenance ensures the pool does not shrink below this - // number, even when idle connections are reaped. - MinPool int32 - - // MaxIdleTime is the maximum duration a connection may remain idle - // in the pool before being closed. - // - // Idle connections exceeding this duration are removed automatically, - // but the pool will never shrink below MinPool. - // - // A zero value disables idle connection reaping. - MaxIdleTime time.Duration - - // MaintainerInterval controls how often the background maintainer - // checks and replenishes the pool to satisfy MinPool. - // - // If zero, a sensible default is chosen based on MaxIdleTime. - MaintainerInterval time.Duration - - // DialTimeout is the maximum duration allowed for creating a new connection. - // - // If the dial operation takes longer than this duration, it will be - // canceled and an error will be returned. - // - // A zero value means no timeout (dial may block indefinitely). - DialTimeout time.Duration -} - -// Opt represents a functional option used to configure a Netpool. -type Opt func(c *Config) - -// WithMaxPool sets the maximum number of active connections allowed in the pool. -func WithMaxPool(max int32) Opt { - return func(c *Config) { - c.MaxPool = max - } -} - -// WithMinPool sets the minimum number of idle connections the pool -// will attempt to maintain. -func WithMinPool(min int32) Opt { - return func(c *Config) { - c.MinPool = min - } -} - -// WithDialHooks registers one or more dial hooks to be executed -// after a connection is successfully created. -// -// Hooks are executed in the order provided. If any hook returns -// an error, the connection is closed and discarded. -func WithDialHooks(hooks ...func(c net.Conn) error) Opt { - return func(c *Config) { - c.DialHooks = append(c.DialHooks, hooks...) - } -} - -// WithMaxIdleTime sets the maximum idle duration for pooled connections. -// -// Connections that remain idle longer than this duration are automatically -// closed, provided the pool does not shrink below MinPool. -// -// A zero duration disables idle reaping. -func WithMaxIdleTime(d time.Duration) Opt { - return func(c *Config) { - c.MaxIdleTime = d - } -} - -// WithMaintainerInterval sets how frequently the background maintainer -// runs to ensure the pool satisfies MinPool. -// -// If not set, the interval is automatically derived from MaxIdleTime. -func WithMaintainerInterval(d time.Duration) Opt { - return func(c *Config) { - c.MaintainerInterval = d - } -} - -// WithDialTimeout sets the maximum duration for creating new connections. -// -// If a dial operation exceeds this timeout, it fails with a timeout error. -// A zero value disables the timeout. -func WithDialTimeout(d time.Duration) Opt { - return func(c *Config) { - c.DialTimeout = d - } -} diff --git a/go.mod b/go.mod index 6f25030..6b4cf85 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,11 @@ module github.com/yudhasubki/netpool go 1.19 + +require ( + github.com/fatih/pool v3.0.0+incompatible // indirect + github.com/konsorten/go-windows-terminal-sequences v1.0.1 // indirect + github.com/silenceper/pool v1.0.0 // indirect + github.com/sirupsen/logrus v1.4.2 // indirect + golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f8ba2f7 --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/pool v3.0.0+incompatible h1:3xXzI/t5o6aEU/R+xe7ed44CTw41lV3oB0gB5pNXS5U= +github.com/fatih/pool v3.0.0+incompatible/go.mod h1:v+kkrv3f2oJ1P9NHaKArMYdTVtNCwfR0DlXwnhA2L4k= +github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/silenceper/pool v1.0.0 h1:JTCaA+U6hJAA0P8nCx+JfsRCHMwLTfatsm5QXelffmU= +github.com/silenceper/pool v1.0.0/go.mod h1:3DN13bqAbq86Lmzf6iUXWEPIWFPOSYVfaoceFvilKKI= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/netpool.go b/netpool.go index 1935bfa..9347729 100644 --- a/netpool.go +++ b/netpool.go @@ -1,540 +1,359 @@ package netpool import ( - "container/list" "context" "net" - "sync" "sync/atomic" "time" ) // Netpooler defines the public interface for a TCP connection pool. -// -// It provides methods for acquiring and releasing connections, inspecting -// pool state, and shutting down the pool. All implementations must be -// safe for concurrent use. type Netpooler interface { - // Close shuts down the pool and closes all managed connections. - // - // After Close is called, all subsequent calls to Get or Put will - // return an error. Close is idempotent. Close() - - // Get returns a pooled connection using a background context. - // - // The returned connection is automatically returned to the pool - // when Close() is called on it. Get() (net.Conn, error) - - // GetWithContext returns a pooled connection, blocking until one - // becomes available or the context is canceled. - // - // If the context is canceled while waiting, an error is returned. GetWithContext(ctx context.Context) (net.Conn, error) - - // Len returns the number of currently idle connections in the pool. Len() int - - // Put returns a connection to the pool. - // - // If err is non-nil, the connection is considered invalid and will - // be closed and removed from the pool. - // - // This method is primarily intended for internal use. Users should - // prefer calling Close() on connections returned by Get(). - Put(conn net.Conn, err error) - - // Stats returns a snapshot of the current pool state. + Put(conn net.Conn) + PutWithError(conn net.Conn, err error) Stats() PoolStats } -// Netpool is a thread-safe TCP connection pool implementation. -// -// It manages a set of reusable net.Conn objects, enforcing maximum -// capacity limits, automatic idle cleanup, and safe concurrent access. -type Netpool struct { - // cond is used to coordinate goroutines waiting for available - // connections when the pool is exhausted. - cond *sync.Cond - - // connections holds idle connections in FIFO order. - connections *list.List - - // allConns tracks all active connections currently managed - // by the pool, including both idle and in-use connections. - allConns map[net.Conn]*connEntry - - // fn is the factory function used to create new connections. - fn netpoolFunc - - // mu protects all shared pool state. - mu *sync.Mutex - - // config holds the pool configuration parameters. - config Config +// Config holds configuration for the connection pool. +type Config struct { + // MaxPool is the maximum number of connections (default: 15) + MaxPool int32 - // closed indicates whether the pool has been closed. - closed atomic.Bool + // MinPool is the minimum idle connections to maintain (default: 0) + MinPool int32 - // reaper periodically cleans up idle connections. - reaper *time.Ticker + // DialTimeout for connection creation (optional) + DialTimeout time.Duration - // stopReaper signals background goroutines to exit. - stopReaper chan struct{} -} + // MaxIdleTime is the maximum duration a connection can remain idle. + // Connections idle longer than this are closed. Zero disables idle timeout. + MaxIdleTime time.Duration -// connEntry represents a single connection managed by the pool. -// -// It tracks connection state and lifecycle metadata required for -// safe reuse and idle cleanup. -type connEntry struct { - // conn is the underlying network connection. - conn net.Conn - - // lastUsed records the last time this connection was returned - // to the idle pool. - lastUsed time.Time - - // returned indicates whether the connection is currently - // in the idle pool. - returned atomic.Bool + // HealthCheck is called before returning a connection from Get(). + // If it returns an error, the connection is discarded and a new one is obtained. + // This adds latency but ensures connections are valid. + HealthCheck func(conn net.Conn) error } // PoolStats provides a snapshot of the pool's current state. -// -// All values represent instantaneous measurements and may change -// immediately after Stats() returns. type PoolStats struct { - // Active is the total number of connections currently - // managed by the pool (idle + in-use). - Active int - - // Idle is the number of currently idle connections - // available for immediate use. - Idle int - - // InUse is the number of connections currently checked - // out by callers. - InUse int - - // MaxPool is the configured maximum pool size. + Active int + Idle int + InUse int MaxPool int32 - - // MinPool is the configured minimum idle pool size. MinPool int32 } -type netpoolFunc func() (net.Conn, error) - -func New(fn netpoolFunc, opts ...Opt) (*Netpool, error) { - config := Config{ - MaxPool: 15, - MinPool: 5, - } +// idleConn wraps a connection with its idle timestamp. +// Passed by value to avoid allocations and sync.Pool overhead. +type idleConn struct { + conn net.Conn + idleSince time.Time +} - for _, opt := range opts { - opt(&config) - } +// Netpool is a lock-free TCP connection pool using Go channels. +// No mutex, no map - just channels and atomics for maximum performance. +type Netpool struct { + // idleConns stores idleConn by value to avoid GC overhead + idleConns chan idleConn + factory func() (net.Conn, error) + healthCheck func(conn net.Conn) error + numOpen atomic.Int32 + maxPool int32 + minPool int32 + maxIdleTime time.Duration + dialTimeout time.Duration + closed atomic.Bool + stopMaintainer chan struct{} +} - if config.MinPool < 0 { - return nil, ErrInvalidConfig +// New creates a new lock-free connection pool. +func New(factory func() (net.Conn, error), cfg Config) (*Netpool, error) { + if cfg.MaxPool <= 0 { + cfg.MaxPool = 15 } - - if config.MaxPool < config.MinPool { - return nil, ErrInvalidConfig + if cfg.MinPool < 0 { + cfg.MinPool = 0 } - - if config.MaxPool == 0 { - return nil, ErrInvalidConfig + if cfg.MinPool > cfg.MaxPool { + cfg.MinPool = cfg.MaxPool } - netpool := &Netpool{ - fn: fn, - mu: new(sync.Mutex), - connections: list.New(), - allConns: make(map[net.Conn]*connEntry), - config: config, - stopReaper: make(chan struct{}), + pool := &Netpool{ + idleConns: make(chan idleConn, cfg.MaxPool), + factory: factory, + healthCheck: cfg.HealthCheck, + maxPool: cfg.MaxPool, + minPool: cfg.MinPool, + maxIdleTime: cfg.MaxIdleTime, + dialTimeout: cfg.DialTimeout, + stopMaintainer: make(chan struct{}), } - netpool.cond = sync.NewCond(netpool.mu) - - // Create initial MinPool connections - for i := int32(0); i < config.MinPool; i++ { - conn, err := netpool.createConnection() + for i := int32(0); i < cfg.MinPool; i++ { + conn, err := pool.dial() if err != nil { - netpool.Close() + pool.Close() return nil, err } - - entry := &connEntry{ - conn: conn, - lastUsed: time.Now(), + + pool.idleConns <- idleConn{ + conn: conn, + idleSince: time.Now(), } - entry.returned.Store(true) // Mark as in pool - - netpool.connections.PushBack(entry) - netpool.allConns[conn] = entry + pool.numOpen.Add(1) } - // Start idle connection reaper - if config.MaxIdleTime > 0 { - netpool.startIdleReaper() + if cfg.MinPool > 0 || cfg.MaxIdleTime > 0 { + go pool.maintainer() } - // Start idle connection maintainer (keeps MinPool alive) - if config.MinPool > 0 { - netpool.startIdleMaintainer() - } - - return netpool, nil + return pool, nil } -func (netpool *Netpool) createConnection() (net.Conn, error) { - var conn net.Conn - var err error - - // If DialTimeout is set, wrap the dial in a timeout - if netpool.config.DialTimeout > 0 { - type dialResult struct { +func (p *Netpool) dial() (net.Conn, error) { + if p.dialTimeout > 0 { + type result struct { conn net.Conn err error } - result := make(chan dialResult, 1) - + ch := make(chan result, 1) go func() { - c, e := netpool.fn() - result <- dialResult{conn: c, err: e} + conn, err := p.factory() + ch <- result{conn, err} }() select { - case r := <-result: - conn, err = r.conn, r.err - case <-time.After(netpool.config.DialTimeout): - // Dial timed out - if dial eventually succeeds, connection will leak - // To prevent this, we spawn a cleanup goroutine - go func() { - if r := <-result; r.conn != nil { - r.conn.Close() - } - }() + case r := <-ch: + return r.conn, r.err + case <-time.After(p.dialTimeout): return nil, ErrDialTimeout } - } else { - conn, err = netpool.fn() } - - if err != nil { - return nil, err - } - - if len(netpool.config.DialHooks) > 0 { - for _, dialHook := range netpool.config.DialHooks { - if err = dialHook(conn); err != nil { - conn.Close() - return nil, err - } - } - } - - return conn, nil -} - -func (netpool *Netpool) Get() (net.Conn, error) { - return netpool.GetWithContext(context.Background()) + return p.factory() } -func (netpool *Netpool) GetWithContext(ctx context.Context) (net.Conn, error) { - conn, err := netpool.getWithContext(ctx) - if err != nil { - return nil, err - } - - return newPooledConn(conn, netpool), nil +// Get returns a connection from the pool. +func (p *Netpool) Get() (net.Conn, error) { + return p.GetWithContext(context.Background()) } -func (netpool *Netpool) getWithContext(ctx context.Context) (net.Conn, error) { - if netpool.closed.Load() { +// GetWithContext returns a connection, blocking until one is available or ctx is cancelled. +func (p *Netpool) GetWithContext(ctx context.Context) (net.Conn, error) { + if p.closed.Load() { return nil, ErrPoolClosed } - netpool.mu.Lock() - defer netpool.mu.Unlock() + select { + case ic := <-p.idleConns: + conn := p.validateConnection(ic) + if conn != nil { + return conn, nil + } + default: + } for { - if netpool.closed.Load() { + if p.closed.Load() { return nil, ErrPoolClosed } - if err := ctx.Err(); err != nil { - return nil, err - } - if netpool.connections.Len() > 0 { - entry := netpool.connections.Remove(netpool.connections.Front()).(*connEntry) - entry.lastUsed = time.Now() - entry.returned.Store(false) - return entry.conn, nil + current := p.numOpen.Load() + if current < p.maxPool { + if p.numOpen.CompareAndSwap(current, current+1) { + conn, err := p.dial() + if err != nil { + p.numOpen.Add(-1) + return nil, err + } + return conn, nil + } + continue } - if len(netpool.allConns) < int(netpool.config.MaxPool) { - c, err := netpool.createConnection() - if err != nil { - return nil, err - } - entry := &connEntry{ - conn: c, - lastUsed: time.Now(), + select { + case ic := <-p.idleConns: + conn := p.validateConnection(ic) + if conn != nil { + return conn, nil } - entry.returned.Store(false) - netpool.allConns[c] = entry - return c, nil + case <-p.stopMaintainer: + return nil, ErrPoolClosed + case <-ctx.Done(): + return nil, ctx.Err() } + } +} - // Use context.AfterFunc for cleaner context cancellation handling (Go 1.21+) - stop := context.AfterFunc(ctx, func() { - netpool.mu.Lock() - netpool.cond.Broadcast() - netpool.mu.Unlock() - }) +// validateConnection checks if a connection is still valid +func (p *Netpool) validateConnection(ic idleConn) net.Conn { + if p.maxIdleTime > 0 && time.Since(ic.idleSince) > p.maxIdleTime { + ic.conn.Close() + p.numOpen.Add(-1) + return nil + } - netpool.cond.Wait() - stop() + if p.healthCheck != nil { + if err := p.healthCheck(ic.conn); err != nil { + ic.conn.Close() + p.numOpen.Add(-1) + return nil + } } + + return ic.conn } -func (netpool *Netpool) Put(conn net.Conn, err error) { +// Put returns a connection to the pool. +func (p *Netpool) Put(conn net.Conn) { if conn == nil { return } - - netpool.mu.Lock() - defer netpool.mu.Unlock() - - entry, exists := netpool.allConns[conn] - if !exists { - _ = conn.Close() + if p.closed.Load() { + conn.Close() + p.numOpen.Add(-1) return } - - if !entry.returned.CompareAndSwap(false, true) { - if err != nil { - _ = conn.Close() - delete(netpool.allConns, conn) - netpool.cond.Signal() - } - return + + select { + case p.idleConns <- idleConn{conn: conn, idleSince: time.Now()}: + default: + conn.Close() + p.numOpen.Add(-1) } +} - // Pool closed -> cleanup - if netpool.closed.Load() { - _ = conn.Close() - delete(netpool.allConns, conn) - netpool.cond.Broadcast() +// PutWithError returns a connection, closing it if there was an error. +func (p *Netpool) PutWithError(conn net.Conn, err error) { + if conn == nil { return } - - // If conn is bad, drop it. if err != nil { - _ = conn.Close() - delete(netpool.allConns, conn) - netpool.cond.Signal() - return - } - - // If we are above MaxPool (shouldn't happen often, but be safe) - if len(netpool.allConns) > int(netpool.config.MaxPool) { - _ = conn.Close() - delete(netpool.allConns, conn) - netpool.cond.Signal() + conn.Close() + p.numOpen.Add(-1) return } - - // Return to pool - entry.lastUsed = time.Now() - netpool.connections.PushBack(entry) - netpool.cond.Signal() + p.Put(conn) } -func (netpool *Netpool) Close() { - if !netpool.closed.CompareAndSwap(false, true) { +// Close closes the pool and all connections. +func (p *Netpool) Close() { + if !p.closed.CompareAndSwap(false, true) { return } - - close(netpool.stopReaper) - - netpool.mu.Lock() - defer netpool.mu.Unlock() - - if netpool.reaper != nil { - netpool.reaper.Stop() - } - - for n := netpool.connections.Front(); n != nil; n = n.Next() { - entry := n.Value.(*connEntry) - entry.conn.Close() - } - netpool.connections.Init() - - for conn, entry := range netpool.allConns { - entry.conn.Close() - delete(netpool.allConns, conn) + close(p.stopMaintainer) + for { + select { + case ic := <-p.idleConns: + ic.conn.Close() + p.numOpen.Add(-1) + default: + return + } } - - netpool.cond.Broadcast() } -func (netpool *Netpool) Stats() PoolStats { - netpool.mu.Lock() - defer netpool.mu.Unlock() - - idle := netpool.connections.Len() - active := len(netpool.allConns) - +// Stats returns pool statistics. +func (p *Netpool) Stats() PoolStats { + idle := len(p.idleConns) + total := int(p.numOpen.Load()) return PoolStats{ - Active: active, + Active: total, Idle: idle, - InUse: active - idle, - MaxPool: netpool.config.MaxPool, - MinPool: netpool.config.MinPool, + InUse: total - idle, + MaxPool: p.maxPool, + MinPool: p.minPool, } } -func (netpool *Netpool) Len() int { - netpool.mu.Lock() - defer netpool.mu.Unlock() - return netpool.connections.Len() +// Len returns the number of idle connections. +func (p *Netpool) Len() int { + return len(p.idleConns) } -func (netpool *Netpool) startIdleReaper() { - netpool.reaper = time.NewTicker(netpool.config.MaxIdleTime / 2) - - go func() { - for { - select { - case <-netpool.reaper.C: - netpool.reapIdleConnections() - case <-netpool.stopReaper: - return - } - } - }() -} - -func (netpool *Netpool) startIdleMaintainer() { - go func() { - interval := netpool.calculateMaintainerInterval() - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - netpool.maintainMinPool() - case <-netpool.stopReaper: - return - } - } - }() -} - -func (netpool *Netpool) calculateMaintainerInterval() time.Duration { - if netpool.config.MaintainerInterval > 0 { - return netpool.config.MaintainerInterval +func (p *Netpool) maintainer() { + interval := 5 * time.Second + if p.maxIdleTime > 0 && p.maxIdleTime < interval { + interval = p.maxIdleTime / 2 } - if netpool.config.MaxIdleTime > 0 && netpool.config.MaxIdleTime < 10*time.Second { - return netpool.config.MaxIdleTime - } - - if netpool.config.MaxIdleTime > 0 && netpool.config.MaxIdleTime < 1*time.Minute { - return 5 * time.Second + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-p.stopMaintainer: + return + case <-ticker.C: + p.reapIdle() + p.maintainMin() + } } - - return 30 * time.Second } -// maintainMinPool ensures we always have MinPool idle connections -func (netpool *Netpool) maintainMinPool() { - netpool.mu.Lock() - defer netpool.mu.Unlock() - - if netpool.closed.Load() { - return - } - - idle := netpool.connections.Len() - active := len(netpool.allConns) - - // Maintain *idle* minimum = MinPool - neededIdle := int(netpool.config.MinPool) - idle - if neededIdle <= 0 { - return - } - - // Cap by remaining capacity to MaxPool - remaining := int(netpool.config.MaxPool) - active - if remaining <= 0 { +// reapIdle removes connections that have been idle too long +func (p *Netpool) reapIdle() { + if p.closed.Load() || p.maxIdleTime == 0 { return } - if neededIdle > remaining { - neededIdle = remaining - } - added := 0 - for i := 0; i < neededIdle; i++ { - conn, err := netpool.createConnection() - if err != nil { - continue - } + now := time.Now() + var toKeep []idleConn - entry := &connEntry{ - conn: conn, - lastUsed: time.Now(), + for { + select { + case ic := <-p.idleConns: + if now.Sub(ic.idleSince) > p.maxIdleTime { + if len(toKeep)+len(p.idleConns) >= int(p.minPool) { + ic.conn.Close() + p.numOpen.Add(-1) + continue + } + } + toKeep = append(toKeep, ic) + default: + goto done } - entry.returned.Store(true) - - netpool.connections.PushBack(entry) - netpool.allConns[conn] = entry - added++ } - if added > 0 { - // Wake anyone waiting for a conn - netpool.cond.Broadcast() +done: + for _, ic := range toKeep { + select { + case p.idleConns <- ic: + default: + ic.conn.Close() + p.numOpen.Add(-1) + } } } -// reapIdleConnections removes connections that have been idle too long -func (netpool *Netpool) reapIdleConnections() { - netpool.mu.Lock() - defer netpool.mu.Unlock() - - if netpool.closed.Load() { +func (p *Netpool) maintainMin() { + if p.closed.Load() { return } - - now := time.Now() - var toRemove []*list.Element - - for e := netpool.connections.Front(); e != nil; e = e.Next() { - entry := e.Value.(*connEntry) - if now.Sub(entry.lastUsed) > netpool.config.MaxIdleTime { - if netpool.connections.Len()-len(toRemove) <= int(netpool.config.MinPool) { - break + idle := len(p.idleConns) + needed := int(p.minPool) - idle + for i := 0; i < needed; i++ { + current := p.numOpen.Load() + if current >= p.maxPool { + break + } + if p.numOpen.CompareAndSwap(current, current+1) { + conn, err := p.dial() + if err != nil { + p.numOpen.Add(-1) + continue + } + + select { + case p.idleConns <- idleConn{conn: conn, idleSince: time.Now()}: + default: + conn.Close() + p.numOpen.Add(-1) } - toRemove = append(toRemove, e) } } - - for _, e := range toRemove { - entry := e.Value.(*connEntry) - netpool.connections.Remove(e) - entry.conn.Close() - delete(netpool.allConns, entry.conn) - } - - if len(toRemove) > 0 { - netpool.cond.Broadcast() - } } diff --git a/netpool_basic_pool.go b/netpool_basic_pool.go new file mode 100644 index 0000000..08ad881 --- /dev/null +++ b/netpool_basic_pool.go @@ -0,0 +1,234 @@ +package netpool + +import ( + "context" + "net" + "sync/atomic" + "time" +) + +// BasicPool is a stripped-down lock-free pool without IdleTimeout or HealthCheck. +type BasicPool struct { + idleConns chan net.Conn + factory func() (net.Conn, error) + numOpen atomic.Int32 + maxPool int32 + minPool int32 + dialTimeout time.Duration + closed atomic.Bool + stopMaintainer chan struct{} +} + +// NewBasic creates a new high-performance basic pool. +// Note: This pool ignores MaxIdleTime and HealthCheck in Config. +func NewBasic(factory func() (net.Conn, error), cfg Config) (*BasicPool, error) { + if cfg.MaxPool <= 0 { + cfg.MaxPool = 15 + } + if cfg.MinPool < 0 { + cfg.MinPool = 0 + } + if cfg.MinPool > cfg.MaxPool { + cfg.MinPool = cfg.MaxPool + } + + pool := &BasicPool{ + idleConns: make(chan net.Conn, cfg.MaxPool), + factory: factory, + maxPool: cfg.MaxPool, + minPool: cfg.MinPool, + dialTimeout: cfg.DialTimeout, + stopMaintainer: make(chan struct{}), + } + + for i := int32(0); i < cfg.MinPool; i++ { + conn, err := pool.dial() + if err != nil { + pool.Close() + return nil, err + } + pool.idleConns <- conn + pool.numOpen.Add(1) + } + + if cfg.MinPool > 0 { + go pool.maintainer() + } + + return pool, nil +} + +func (p *BasicPool) dial() (net.Conn, error) { + if p.dialTimeout > 0 { + type result struct { + conn net.Conn + err error + } + ch := make(chan result, 1) + go func() { + conn, err := p.factory() + ch <- result{conn, err} + }() + + select { + case r := <-ch: + return r.conn, r.err + case <-time.After(p.dialTimeout): + return nil, ErrDialTimeout + } + } + return p.factory() +} + +// Get returns a connection from the pool. +func (p *BasicPool) Get() (net.Conn, error) { + return p.GetWithContext(context.Background()) +} + +// GetWithContext returns a connection, blocking until one is available. +func (p *BasicPool) GetWithContext(ctx context.Context) (net.Conn, error) { + if p.closed.Load() { + return nil, ErrPoolClosed + } + + select { + case conn := <-p.idleConns: + return conn, nil + default: + } + + for { + if p.closed.Load() { + return nil, ErrPoolClosed + } + + current := p.numOpen.Load() + if current < p.maxPool { + if p.numOpen.CompareAndSwap(current, current+1) { + conn, err := p.dial() + if err != nil { + p.numOpen.Add(-1) + return nil, err + } + return conn, nil + } + continue + } + + select { + case conn := <-p.idleConns: + return conn, nil + case <-p.stopMaintainer: + return nil, ErrPoolClosed + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +// Put returns a connection to the pool. +func (p *BasicPool) Put(conn net.Conn) { + if conn == nil { + return + } + if p.closed.Load() { + conn.Close() + p.numOpen.Add(-1) + return + } + + select { + case p.idleConns <- conn: + default: + conn.Close() + p.numOpen.Add(-1) + } +} + +// PutWithError returns a connection, closing it if there was an error. +func (p *BasicPool) PutWithError(conn net.Conn, err error) { + if conn == nil { + return + } + if err != nil { + conn.Close() + p.numOpen.Add(-1) + return + } + p.Put(conn) +} + +// Close closes the pool and all connections. +func (p *BasicPool) Close() { + if !p.closed.CompareAndSwap(false, true) { + return + } + close(p.stopMaintainer) + for { + select { + case conn := <-p.idleConns: + conn.Close() + p.numOpen.Add(-1) + default: + return + } + } +} + +// Stats returns pool statistics. +func (p *BasicPool) Stats() PoolStats { + idle := len(p.idleConns) + total := int(p.numOpen.Load()) + return PoolStats{ + Active: total, + Idle: idle, + InUse: total - idle, + MaxPool: p.maxPool, + MinPool: p.minPool, + } +} + +// Len returns the number of idle connections. +func (p *BasicPool) Len() int { + return len(p.idleConns) +} + +func (p *BasicPool) maintainer() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case <-p.stopMaintainer: + return + case <-ticker.C: + p.maintainMin() + } + } +} + +func (p *BasicPool) maintainMin() { + if p.closed.Load() { + return + } + idle := len(p.idleConns) + needed := int(p.minPool) - idle + for i := 0; i < needed; i++ { + current := p.numOpen.Load() + if current >= p.maxPool { + break + } + if p.numOpen.CompareAndSwap(current, current+1) { + conn, err := p.dial() + if err != nil { + p.numOpen.Add(-1) + continue + } + select { + case p.idleConns <- conn: + default: + conn.Close() + p.numOpen.Add(-1) + } + } + } +} diff --git a/netpool_basic_pool_test.go b/netpool_basic_pool_test.go new file mode 100644 index 0000000..10a04b6 --- /dev/null +++ b/netpool_basic_pool_test.go @@ -0,0 +1,205 @@ +package netpool + +import ( + "context" + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestNewBasic(t *testing.T) { + listener, addr := createTestServer(t) + defer listener.Close() + + pool, err := NewBasic(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 10, + MinPool: 2, + }) + if err != nil { + t.Fatalf("failed to create basic pool: %v", err) + } + defer pool.Close() + + stats := pool.Stats() + if stats.Idle < 2 { + t.Errorf("expected at least 2 idle connections, got %d", stats.Idle) + } + if stats.MaxPool != 10 { + t.Errorf("expected MaxPool 10, got %d", stats.MaxPool) + } +} + +func TestBasicGetPut(t *testing.T) { + listener, addr := createTestServer(t) + defer listener.Close() + + pool, err := NewBasic(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 5, + MinPool: 0, + }) + if err != nil { + t.Fatalf("failed to create request pool: %v", err) + } + defer pool.Close() + + // Get a connection + conn, err := pool.Get() + if err != nil { + t.Fatalf("Get() failed: %v", err) + } + + stats := pool.Stats() + if stats.InUse != 1 { + t.Errorf("expected 1 in-use, got %d", stats.InUse) + } + + // Put it back + pool.Put(conn) + + stats = pool.Stats() + if stats.InUse != 0 { + t.Errorf("expected 0 in-use after Put, got %d", stats.InUse) + } + if stats.Idle != 1 { + t.Errorf("expected 1 idle after Put, got %d", stats.Idle) + } +} + +func TestBasicConcurrent(t *testing.T) { + listener, addr := createTestServer(t) + defer listener.Close() + + pool, err := NewBasic(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 100, + MinPool: 0, + }) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + defer pool.Close() + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := pool.Get() + if err != nil { + t.Errorf("Get() failed: %v", err) + return + } + time.Sleep(time.Millisecond) + pool.Put(conn) + }() + } + wg.Wait() + + stats := pool.Stats() + if stats.InUse != 0 { + t.Errorf("expected 0 in-use after concurrent test, got %d", stats.InUse) + } +} + +func TestBasicPoolClose(t *testing.T) { + listener, addr := createTestServer(t) + defer listener.Close() + + pool, _ := NewBasic(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 5, + MinPool: 2, + }) + + pool.Close() + + _, err := pool.Get() + if err != ErrPoolClosed { + t.Errorf("expected ErrPoolClosed, got %v", err) + } +} + +func TestBasicPutWithError(t *testing.T) { + listener, addr := createTestServer(t) + defer listener.Close() + + pool, _ := NewBasic(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 5, + }) + defer pool.Close() + + conn, _ := pool.Get() + pool.PutWithError(conn, ErrInvalidConn) + + stats := pool.Stats() + if stats.Active != 0 { + t.Errorf("expected 0 active after PutWithError, got %d", stats.Active) + } +} + +func TestBasicContextCancellation(t *testing.T) { + listener, addr := createTestServer(t) + defer listener.Close() + + pool, _ := NewBasic(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 1, + }) + defer pool.Close() + + conn, _ := pool.Get() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := pool.GetWithContext(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %v", err) + } + + pool.Put(conn) +} + +func TestBasicRace(t *testing.T) { + listener, addr := createTestServer(t) + defer listener.Close() + + pool, _ := NewBasic(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 10, + }) + defer pool.Close() + + var wg sync.WaitGroup + var ops atomic.Int64 + + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + conn, err := pool.Get() + if err != nil { + return + } + ops.Add(1) + pool.Put(conn) + } + }() + } + + wg.Wait() + t.Logf("BasicPool Race: Completed %d operations", ops.Load()) +} diff --git a/netpool_test.go b/netpool_test.go index e4de610..464dac8 100644 --- a/netpool_test.go +++ b/netpool_test.go @@ -2,28 +2,25 @@ package netpool import ( "context" - "errors" - "fmt" "net" "sync" + "sync/atomic" "testing" "time" ) -func createTestServer(t *testing.T) (net.Listener, string) { +func createTestServer(t testing.TB) (net.Listener, string) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to create listener: %v", err) } - // Accept connections in background go func() { for { conn, err := listener.Accept() if err != nil { return } - // Echo server: read and write back go func(c net.Conn) { defer c.Close() buf := make([]byte, 1024) @@ -41,1238 +38,366 @@ func createTestServer(t *testing.T) (net.Listener, string) { return listener, listener.Addr().String() } -func TestNewConnection(t *testing.T) { +func TestNewPool(t *testing.T) { listener, addr := createTestServer(t) defer listener.Close() pool, err := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, WithMinPool(3), WithMaxPool(10)) - + }, Config{ + MaxPool: 10, + MinPool: 2, + }) if err != nil { - t.Fatalf("expected no error, got %v", err) + t.Fatalf("failed to create pool: %v", err) } defer pool.Close() - if pool.Len() != 3 { - t.Errorf("expected 3 idle connections, got %d", pool.Len()) - } - stats := pool.Stats() - if stats.Active != 3 { - t.Errorf("expected 3 total connections, got %d", stats.Active) + if stats.Idle < 2 { + t.Errorf("expected at least 2 idle connections, got %d", stats.Idle) } } -func TestGetConnection(t *testing.T) { +func TestGetPut(t *testing.T) { listener, addr := createTestServer(t) defer listener.Close() pool, err := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, WithMinPool(2), WithMaxPool(5)) - + }, Config{ + MaxPool: 5, + MinPool: 0, + }) if err != nil { t.Fatalf("failed to create pool: %v", err) } defer pool.Close() + // Get a connection conn, err := pool.Get() if err != nil { - t.Fatalf("Get failed: %v", err) + t.Fatalf("Get() failed: %v", err) } - defer conn.Close() // Auto-return to pool - testMsg := []byte("hello") - _, err = conn.Write(testMsg) - if err != nil { - t.Fatalf("write failed: %v", err) + stats := pool.Stats() + if stats.InUse != 1 { + t.Errorf("expected 1 in-use, got %d", stats.InUse) } - buf := make([]byte, 1024) - conn.SetReadDeadline(time.Now().Add(1 * time.Second)) - n, err := conn.Read(buf) - if err != nil { - t.Fatalf("read failed: %v", err) - } + // Put it back + pool.Put(conn) - if string(buf[:n]) != string(testMsg) { - t.Errorf("expected %s, got %s", testMsg, buf[:n]) + stats = pool.Stats() + if stats.InUse != 0 { + t.Errorf("expected 0 in-use after Put, got %d", stats.InUse) } - - // Connection auto-returns on Close() - conn.Close() - - // Give a moment for Close() to complete - time.Sleep(10 * time.Millisecond) - - if pool.Len() != 2 { - t.Errorf("expected 2 idle connections after Close, got %d", pool.Len()) + if stats.Idle != 1 { + t.Errorf("expected 1 idle after Put, got %d", stats.Idle) } } -func TestGetCreatesNewConnection(t *testing.T) { +func TestConcurrentAccess(t *testing.T) { listener, addr := createTestServer(t) defer listener.Close() pool, err := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, WithMinPool(1), WithMaxPool(5)) - + }, Config{ + MaxPool: 10, + MinPool: 0, + }) if err != nil { t.Fatalf("failed to create pool: %v", err) } defer pool.Close() - conn1, _ := pool.Get() - defer conn1.Close() - - conn2, err := pool.Get() - if err != nil { - t.Fatalf("expected new connection, got error: %v", err) + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := pool.Get() + if err != nil { + t.Errorf("Get() failed: %v", err) + return + } + time.Sleep(1 * time.Millisecond) + pool.Put(conn) + }() } - defer conn2.Close() + wg.Wait() stats := pool.Stats() - if stats.Active != 2 { - t.Errorf("expected 2 total connections, got %d", stats.Active) - } - if stats.InUse != 2 { - t.Errorf("expected 2 in-use connections, got %d", stats.InUse) + if stats.InUse != 0 { + t.Errorf("expected 0 in-use after concurrent test, got %d", stats.InUse) } } -func TestPutWithConnectionError(t *testing.T) { +func TestMaxPoolLimit(t *testing.T) { listener, addr := createTestServer(t) defer listener.Close() pool, err := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, WithMinPool(1), WithMaxPool(5)) - + }, Config{ + MaxPool: 3, + MinPool: 0, + }) if err != nil { t.Fatalf("failed to create pool: %v", err) } defer pool.Close() - conn, _ := pool.Get() - initialStats := pool.Stats() - - // Use MarkUnusable for better error handling - if pc, ok := conn.(interface{ MarkUnusable() error }); ok { - err := pc.MarkUnusable() + // Get all 3 connections + conns := make([]net.Conn, 3) + for i := 0; i < 3; i++ { + conn, err := pool.Get() if err != nil { - t.Fatalf("MarkUnusable failed: %v", err) + t.Fatalf("Get() %d failed: %v", i, err) } - } else { - t.Fatal("connection doesn't support MarkUnusable") - } - - // Give time for cleanup - time.Sleep(50 * time.Millisecond) - - finalStats := pool.Stats() - - if finalStats.Active != initialStats.Active-1 { - t.Errorf("expected TotalCreated to decrease from %d to %d, got %d", - initialStats.Active, initialStats.Active-1, finalStats.Active) + conns[i] = conn } -} - -func TestGetBlocksAtMaxPool(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(1), WithMaxPool(2)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) + stats := pool.Stats() + if stats.Active != 3 { + t.Errorf("expected 3 active, got %d", stats.Active) } - defer pool.Close() - conn1, _ := pool.Get() - conn2, _ := pool.Get() - - blocked := make(chan bool, 1) - var conn3 net.Conn - - go func() { - conn3, _ = pool.Get() - blocked <- true - }() + // Try to get 4th with timeout - should fail + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() - select { - case <-blocked: - t.Fatal("Get() should have blocked") - case <-time.After(100 * time.Millisecond): + _, err = pool.GetWithContext(ctx) + if err != context.DeadlineExceeded { + t.Errorf("expected DeadlineExceeded, got %v", err) } - conn1.Close() // Auto-return to pool - - select { - case <-blocked: - if conn3 == nil { - t.Fatal("expected connection after unblock") - } - case <-time.After(1 * time.Second): - t.Fatal("Get() didn't unblock") + // Return connections + for _, conn := range conns { + pool.Put(conn) } - - conn2.Close() - conn3.Close() } -func TestConcurrentGetPutConnections(t *testing.T) { +func TestPoolClose(t *testing.T) { listener, addr := createTestServer(t) defer listener.Close() pool, err := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, WithMinPool(2), WithMaxPool(10)) - + }, Config{ + MaxPool: 5, + MinPool: 2, + }) if err != nil { t.Fatalf("failed to create pool: %v", err) } - defer pool.Close() - - var wg sync.WaitGroup - goroutines := 20 - iterations := 50 - - errorsChan := make(chan error, goroutines*iterations) - - for i := 0; i < goroutines; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < iterations; j++ { - conn, err := pool.Get() - if err != nil { - errorsChan <- err - continue - } - - // Use the connection - msg := []byte("test") - _, err = conn.Write(msg) - if err != nil { - // Mark unusable on error - if pc, ok := conn.(interface{ MarkUnusable() error }); ok { - pc.MarkUnusable() - } else { - conn.Close() - } - errorsChan <- err - continue - } - - buf := make([]byte, 1024) - conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) - _, err = conn.Read(buf) - - if err != nil { - // Mark unusable on error - if pc, ok := conn.(interface{ MarkUnusable() error }); ok { - pc.MarkUnusable() - } else { - conn.Close() - } - } else { - conn.Close() // Normal return to pool - } - } - }(i) - } - - wg.Wait() - close(errorsChan) - errorCount := 0 - for err := range errorsChan { - t.Logf("error during test: %v", err) - errorCount++ - } + pool.Close() - if errorCount > 0 { - t.Errorf("got %d errors during concurrent operations", errorCount) + // Get after close should fail + _, err = pool.Get() + if err != ErrPoolClosed { + t.Errorf("expected ErrPoolClosed, got %v", err) } - - stats := pool.Stats() - t.Logf("Final stats - Total: %d, Idle: %d, InUse: %d", - stats.Active, stats.Idle, stats.InUse) } -func TestCloseConnections(t *testing.T) { +func TestPutWithError(t *testing.T) { listener, addr := createTestServer(t) defer listener.Close() pool, err := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, WithMinPool(3), WithMaxPool(5)) - + }, Config{ + MaxPool: 5, + MinPool: 0, + }) if err != nil { t.Fatalf("failed to create pool: %v", err) } + defer pool.Close() - conns := make([]net.Conn, 2) - for i := 0; i < 2; i++ { - conns[i], _ = pool.Get() - } + conn, _ := pool.Get() - pool.Close() + // Put with error should close the connection + pool.PutWithError(conn, ErrInvalidConn) stats := pool.Stats() - if stats.Idle != 0 { - t.Errorf("expected 0 idle connections after Close, got %d", stats.Idle) - } - - // Connections should be closed by pool.Close() - _, err = conns[0].Write([]byte("test")) - if err == nil { - t.Error("expected error writing to closed connection") + if stats.Active != 0 { + t.Errorf("expected 0 active after PutWithError, got %d", stats.Active) } } -func TestNewDialFailure(t *testing.T) { +func TestDialFailure(t *testing.T) { callCount := 0 _, err := New(func() (net.Conn, error) { callCount++ - if callCount == 2 { - return nil, errors.New("dial failed") - } - return net.Dial("tcp", "127.0.0.1:0") - }, WithMinPool(3)) + return nil, net.UnknownNetworkError("test error") + }, Config{ + MaxPool: 5, + MinPool: 2, + }) if err == nil { - t.Fatal("expected error from failed dial") + t.Error("expected error from New with failing dial") } } -func TestDialHooksConnection(t *testing.T) { +func TestContextCancellation(t *testing.T) { listener, addr := createTestServer(t) defer listener.Close() - hookCalled := 0 - hook := func(conn net.Conn) error { // Updated signature: any instead of net.Conn - hookCalled++ - c := conn.(net.Conn) - // Test the connection - _, err := c.Write([]byte("hook test")) - return err - } - - pool, err := New(func() (net.Conn, error) { + pool, _ := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, WithMinPool(2), WithDialHooks(hook)) // Pass as slice - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } + }, Config{ + MaxPool: 1, + MinPool: 0, + }) defer pool.Close() - if hookCalled != 2 { - t.Errorf("expected hook to be called 2 times during init, got %d", hookCalled) - } - - conn1, _ := pool.Get() - conn2, _ := pool.Get() - - if hookCalled != 2 { - t.Errorf("expected hook still 2 times (reusing pool connections), got %d", hookCalled) - } + // Hold the only connection + conn, _ := pool.Get() - conn3, _ := pool.Get() + // Try to get with cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() - if hookCalled != 3 { - t.Errorf("expected hook to be called 3 times total (created new conn), got %d", hookCalled) + _, err := pool.GetWithContext(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %v", err) } - conn1.Close() - conn2.Close() - conn3.Close() + pool.Put(conn) } -func TestPooledConnAutoReturn(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - +func TestDialTimeout(t *testing.T) { + // Use a non-routable IP to simulate timeout pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(2), WithMaxPool(5)) - + return net.Dial("tcp", "10.255.255.1:12345") + }, Config{ + MaxPool: 5, + MinPool: 0, + DialTimeout: 100 * time.Millisecond, + }) if err != nil { t.Fatalf("failed to create pool: %v", err) } defer pool.Close() - initialStats := pool.Stats() - - // Get connection - conn, err := pool.Get() - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - afterGetStats := pool.Stats() - if afterGetStats.InUse != initialStats.InUse+1 { - t.Errorf("expected InUse to increase by 1") - } - - err = conn.Close() - if err != nil { - t.Fatalf("Close failed: %v", err) - } - - afterCloseStats := pool.Stats() - if afterCloseStats.InUse != initialStats.InUse { - t.Errorf("expected InUse to return to original after Close") - } - - _, err = conn.Write([]byte("test")) - if err == nil { - t.Error("expected error writing to closed connection") + _, err = pool.Get() + if err != ErrDialTimeout { + t.Errorf("expected ErrDialTimeout, got %v", err) } } -func TestPooledConnMarkUnusable(t *testing.T) { - listener, addr := createTestServer(t) +// Benchmark + +func BenchmarkPoolGet(b *testing.B) { + listener, addr := createTestServer(b) defer listener.Close() - pool, err := New(func() (net.Conn, error) { + pool, _ := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, WithMinPool(2), WithMaxPool(5)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } + }, Config{ + MaxPool: 100, + MinPool: 50, + }) defer pool.Close() - initialStats := pool.Stats() - - conn, _ := pool.Get() - - if pc, ok := conn.(interface{ MarkUnusable() error }); ok { - err = pc.MarkUnusable() - if err != nil { - t.Fatalf("MarkUnusable failed: %v", err) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get() + if err != nil { + b.Fatal(err) + } + pool.Put(conn) } - } else { - t.Fatal("connection doesn't implement MarkUnusable") - } - - finalStats := pool.Stats() - if finalStats.Active != initialStats.Active-1 { - t.Errorf("expected TotalCreated to decrease by 1 after MarkUnusable") - } + }) } -func TestIdleTimeoutWithPoolGrowthAndShrink(t *testing.T) { - listener, addr := createTestServer(t) +func BenchmarkPoolConcurrent(b *testing.B) { + listener, addr := createTestServer(b) defer listener.Close() - pool, err := New(func() (net.Conn, error) { + pool, _ := New(func() (net.Conn, error) { return net.Dial("tcp", addr) - }, - WithMinPool(2), - WithMaxPool(10), - WithMaxIdleTime(2*time.Second), // 2 second idle timeout - ) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } + }, Config{ + MaxPool: 300, + MinPool: 10, + }) defer pool.Close() - t.Log("Phase 1: Initial state") - stats := pool.Stats() - if stats.Active != 2 { - t.Errorf("expected 2 initial connections, got %d", stats.Active) - } - t.Logf("Initial - Total: %d, Idle: %d, InUse: %d", stats.Active, stats.Idle, stats.InUse) - - // Phase 2: Grow pool to max - t.Log("Phase 2: Growing pool to max capacity") - conns := make([]net.Conn, 10) - for i := 0; i < 10; i++ { - conns[i], err = pool.Get() - if err != nil { - t.Fatalf("failed to get connection %d: %v", i, err) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := pool.Get() + if err != nil { + b.Fatal(err) + } + conn.Write([]byte("test data")) + pool.Put(conn) } - } - - stats = pool.Stats() - if stats.Active != 10 { - t.Errorf("expected 10 connections at max, got %d", stats.Active) - } - if stats.InUse != 10 { - t.Errorf("expected 10 in-use connections, got %d", stats.InUse) - } - t.Logf("At max capacity - Total: %d, Idle: %d, InUse: %d", stats.Active, stats.Idle, stats.InUse) - - // Phase 3: Return all connections to pool - t.Log("Phase 3: Returning all connections to pool") - for i := 0; i < 10; i++ { - conns[i].Close() - } - time.Sleep(100 * time.Millisecond) // Wait for close to complete - - stats = pool.Stats() - if stats.Idle != 10 { - t.Errorf("expected 10 idle connections, got %d", stats.Idle) - } - if stats.InUse != 0 { - t.Errorf("expected 0 in-use connections, got %d", stats.InUse) - } - t.Logf("All returned - Total: %d, Idle: %d, InUse: %d", stats.Active, stats.Idle, stats.InUse) - - // Phase 4: Wait for idle timeout to kick in (should clean up to MinPool) - t.Log("Phase 4: Waiting for idle timeout...") - time.Sleep(3 * time.Second) // Wait longer than MaxIdleTime - - stats = pool.Stats() - t.Logf("After idle timeout - Total: %d, Idle: %d, InUse: %d", stats.Active, stats.Idle, stats.InUse) - - // Should have cleaned up idle connections but kept at least MinPool - if stats.Active < 2 { - t.Errorf("pool went below MinPool: got %d, expected at least 2", stats.Active) - } - if stats.Active > 5 { - t.Errorf("idle reaper didn't clean up enough: got %d, expected <= 5", stats.Active) - } - - // Phase 5: Verify pool still works after cleanup - t.Log("Phase 5: Verifying pool still works after cleanup") - testConn, err := pool.Get() - if err != nil { - t.Fatalf("pool not working after idle cleanup: %v", err) - } - - // Use the connection - testMsg := []byte("test after cleanup") - _, err = testConn.Write(testMsg) - if err != nil { - t.Fatalf("connection not working after cleanup: %v", err) - } + }) +} - buf := make([]byte, 1024) - testConn.SetReadDeadline(time.Now().Add(1 * time.Second)) - n, err := testConn.Read(buf) - if err != nil { - t.Fatalf("read failed after cleanup: %v", err) - } +func BenchmarkPoolGetNoContention(b *testing.B) { + listener, addr := createTestServer(b) + defer listener.Close() - if string(buf[:n]) != string(testMsg) { - t.Errorf("echo failed: expected %s, got %s", testMsg, buf[:n]) - } + pool, _ := New(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 1000, + MinPool: 100, + }) + defer pool.Close() - testConn.Close() - t.Log("Phase 5: Pool still functional") + // Wait for MinPool to be ready + time.Sleep(100 * time.Millisecond) - // Phase 6: Grow again and verify it works - t.Log("Phase 6: Growing pool again after cleanup") - newConns := make([]net.Conn, 5) - for i := 0; i < 5; i++ { - newConns[i], err = pool.Get() + b.ResetTimer() + for i := 0; i < b.N; i++ { + conn, err := pool.Get() if err != nil { - t.Fatalf("failed to get connection after cleanup: %v", err) + b.Fatal(err) } + pool.Put(conn) } - - stats = pool.Stats() - t.Logf("After regrowth - Total: %d, Idle: %d, InUse: %d", stats.Active, stats.Idle, stats.InUse) - - if stats.InUse < 5 { - t.Errorf("expected at least 5 in-use connections, got %d", stats.InUse) - } - - // Cleanup - for i := 0; i < 5; i++ { - newConns[i].Close() - } - - t.Log("Test completed successfully") } -func TestRepeatedGrowthShrinkDoesNotLeak(t *testing.T) { +// Race condition tests + +func TestRaceGetPut(t *testing.T) { listener, addr := createTestServer(t) defer listener.Close() - pool, err := New( - func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, - WithMinPool(2), - WithMaxPool(20), - WithMaxIdleTime(200*time.Millisecond), - ) - if err != nil { - t.Fatal(err) - } + pool, _ := New(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, Config{ + MaxPool: 10, + MinPool: 0, + }) defer pool.Close() - for i := 0; i < 10; i++ { - conns := make([]net.Conn, 20) - for j := 0; j < 20; j++ { - conns[j], _ = pool.Get() - } - - for j := 0; j < 20; j++ { - conns[j].Close() - } - - time.Sleep(400 * time.Millisecond) + var wg sync.WaitGroup + var ops atomic.Int64 - stats := pool.Stats() - if stats.Active > 5 { - t.Fatalf("iteration %d: pool not shrinking properly: %+v", i, stats) - } - } -} - -func TestIdleReaperWhileConcurrentGet(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New( - func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, - WithMinPool(1), - WithMaxPool(5), - WithMaxIdleTime(200*time.Millisecond), - ) - if err != nil { - t.Fatal(err) - } - defer pool.Close() - - stop := make(chan struct{}) - go func() { - for { - select { - case <-stop: - return - default: - conn, err := pool.Get() - if err == nil { - time.Sleep(10 * time.Millisecond) - conn.Close() - } - } - } - }() - - time.Sleep(1 * time.Second) // let reaper & maintainer fight - - close(stop) - - stats := pool.Stats() - if stats.Active < 1 { - t.Fatalf("pool corrupted: %+v", stats) - } -} - -func BenchmarkPoolGet(b *testing.B) { - listener, addr := createBenchmarkTestServer(b) - defer listener.Close() - - pool, _ := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(10), WithMaxPool(300)) - defer pool.Close() - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - conn, err := pool.Get() - if err != nil { - b.Fatal(err) - } - conn.Close() - } - }) -} - -func TestContextCancelDuringWait(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(0), WithMaxPool(1)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } - defer pool.Close() - - // Exhaust the pool - conn1, err := pool.Get() - if err != nil { - t.Fatalf("failed to get first connection: %v", err) - } - defer conn1.Close() - - // Verify pool is exhausted - stats := pool.Stats() - if stats.InUse != 1 || stats.Idle != 0 { - t.Errorf("pool should be exhausted: InUse=%d, Idle=%d", stats.InUse, stats.Idle) - } - - // Try to get with timeout context - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - start := time.Now() - _, err = pool.GetWithContext(ctx) - elapsed := time.Since(start) - - if err != context.DeadlineExceeded { - t.Errorf("expected DeadlineExceeded, got %v", err) - } - - // Should timeout around 100ms, not instantly - if elapsed < 50*time.Millisecond { - t.Errorf("timeout happened too quickly: %v", elapsed) - } - if elapsed > 200*time.Millisecond { - t.Errorf("timeout took too long: %v", elapsed) - } - - t.Logf("Context cancellation detected after %v", elapsed) -} - -func TestCloseWithWaiters(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(0), WithMaxPool(1)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } - - // Exhaust the pool - conn1, err := pool.Get() - if err != nil { - t.Fatalf("failed to get connection: %v", err) - } - - // Start goroutine that will block waiting for a connection - done := make(chan error, 1) - started := make(chan struct{}) - - go func() { - close(started) // Signal that goroutine started - _, err := pool.Get() - done <- err - }() - - // Wait for goroutine to start and begin waiting - <-started - time.Sleep(50 * time.Millisecond) - - // Close pool while goroutine is waiting - pool.Close() - - // The waiting goroutine should unblock with ErrPoolClosed - select { - case err := <-done: - if err != ErrPoolClosed { - t.Errorf("expected ErrPoolClosed, got %v", err) - } - t.Logf("Waiter properly unblocked with: %v", err) - case <-time.After(1 * time.Second): - t.Fatal("waiting goroutine did not unblock after pool.Close()") - } - - // Verify connection is closed - _, err = conn1.Write([]byte("test")) - if err == nil { - t.Error("connection should be closed after pool.Close()") - } - - // Subsequent Get should return ErrPoolClosed - _, err = pool.Get() - if err != ErrPoolClosed { - t.Errorf("Get after Close should return ErrPoolClosed, got %v", err) - } -} - -func TestStatsConsistency(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(2), WithMaxPool(10)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } - defer pool.Close() - - var wg sync.WaitGroup - errors := make(chan string, 100) - - // Goroutines that continuously check stats consistency - for i := 0; i < 10; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < 100; j++ { - stats := pool.Stats() - - // Active should always equal Idle + InUse - if stats.Active != stats.Idle+stats.InUse { - errors <- fmt.Sprintf("goroutine %d: inconsistent stats: Active=%d, Idle=%d, InUse=%d", - id, stats.Active, stats.Idle, stats.InUse) - } - - // Active should never exceed MaxPool - if stats.Active > int(stats.MaxPool) { - errors <- fmt.Sprintf("goroutine %d: Active(%d) > MaxPool(%d)", - id, stats.Active, stats.MaxPool) - } - - // Idle should never exceed Active - if stats.Idle > stats.Active { - errors <- fmt.Sprintf("goroutine %d: Idle(%d) > Active(%d)", - id, stats.Idle, stats.Active) - } - - // InUse should never exceed Active - if stats.InUse > stats.Active { - errors <- fmt.Sprintf("goroutine %d: InUse(%d) > Active(%d)", - id, stats.InUse, stats.Active) - } - - time.Sleep(1 * time.Millisecond) - } - }(i) - } - - // Goroutines that Get and Put connections - for i := 0; i < 10; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for j := 0; j < 50; j++ { - conn, err := pool.Get() - if err != nil { - continue - } - - // Use the connection briefly - time.Sleep(5 * time.Millisecond) - - // Randomly mark some as unusable - if j%10 == 0 { - if pc, ok := conn.(interface{ MarkUnusable() error }); ok { - pc.MarkUnusable() - } - } else { - conn.Close() - } - } - }(i) - } - - wg.Wait() - close(errors) - - // Check for any consistency errors - errorCount := 0 - for errMsg := range errors { - t.Error(errMsg) - errorCount++ - if errorCount >= 10 { - t.Log("... (more errors suppressed)") - break - } - } - - if errorCount > 0 { - t.Fatalf("found %d stats consistency violations", errorCount) - } - - // Final verification - finalStats := pool.Stats() - t.Logf("Final stats: Active=%d, Idle=%d, InUse=%d, MaxPool=%d, MinPool=%d", - finalStats.Active, finalStats.Idle, finalStats.InUse, finalStats.MaxPool, finalStats.MinPool) - - if finalStats.Active != finalStats.Idle+finalStats.InUse { - t.Errorf("final stats inconsistent: Active=%d != Idle(%d) + InUse(%d)", - finalStats.Active, finalStats.Idle, finalStats.InUse) - } -} - -func TestMultipleWaitersUnblockOnClose(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(0), WithMaxPool(1)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } - - // Exhaust the pool - conn1, err := pool.Get() - if err != nil { - t.Fatalf("failed to get connection: %v", err) - } - - // Start multiple goroutines waiting for connections - numWaiters := 5 - done := make(chan error, numWaiters) - - for i := 0; i < numWaiters; i++ { - go func(id int) { - _, err := pool.Get() - done <- err - }(i) - } - - // Give them time to start waiting - time.Sleep(100 * time.Millisecond) - - // Close the pool - pool.Close() - - // All waiters should unblock with ErrPoolClosed - timeout := time.After(2 * time.Second) - for i := 0; i < numWaiters; i++ { - select { - case err := <-done: - if err != ErrPoolClosed { - t.Errorf("waiter %d: expected ErrPoolClosed, got %v", i, err) - } - case <-timeout: - t.Fatalf("only %d/%d waiters unblocked", i, numWaiters) - } - } - - t.Logf("All %d waiters properly unblocked", numWaiters) - - // Cleanup - conn1.Close() -} - -func TestContextCancelWithMultipleWaiters(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(0), WithMaxPool(2)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } - defer pool.Close() - - // Exhaust the pool - conn1, _ := pool.Get() - conn2, _ := pool.Get() - defer conn1.Close() - defer conn2.Close() - - // Start 3 goroutines with different context timeouts - type result struct { - id int - err error - elapsed time.Duration - } - - results := make(chan result, 3) - - // Waiter 1: 50ms timeout - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - start := time.Now() - _, err := pool.GetWithContext(ctx) - results <- result{1, err, time.Since(start)} - }() - - // Waiter 2: 100ms timeout - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - start := time.Now() - _, err := pool.GetWithContext(ctx) - results <- result{2, err, time.Since(start)} - }() - - // Waiter 3: 150ms timeout - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) - defer cancel() - start := time.Now() - _, err := pool.GetWithContext(ctx) - results <- result{3, err, time.Since(start)} - }() - - // Collect results - for i := 0; i < 3; i++ { - r := <-results - if r.err != context.DeadlineExceeded { - t.Errorf("waiter %d: expected DeadlineExceeded, got %v", r.id, r.err) - } - t.Logf("Waiter %d timed out after %v", r.id, r.elapsed) - } -} - -func TestDoubleCloseConnection(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(1), WithMaxPool(5)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } - defer pool.Close() - - conn, err := pool.Get() - if err != nil { - t.Fatalf("failed to get connection: %v", err) - } - - // First close - should return to pool - err = conn.Close() - if err != nil { - t.Errorf("first Close() failed: %v", err) - } - - time.Sleep(50 * time.Millisecond) - - statsAfterFirstClose := pool.Stats() - t.Logf("After first close: Active=%d, Idle=%d, InUse=%d", - statsAfterFirstClose.Active, statsAfterFirstClose.Idle, statsAfterFirstClose.InUse) - - // Second close - should be a no-op or return error - err = conn.Close() - if err != nil { - t.Logf("second Close() returned expected error: %v", err) - } - - time.Sleep(50 * time.Millisecond) - - statsAfterSecondClose := pool.Stats() - t.Logf("After second close: Active=%d, Idle=%d, InUse=%d", - statsAfterSecondClose.Active, statsAfterSecondClose.Idle, statsAfterSecondClose.InUse) - - // Stats should not change after second close - if statsAfterFirstClose.Active != statsAfterSecondClose.Active { - t.Errorf("stats changed after double close: %+v -> %+v", - statsAfterFirstClose, statsAfterSecondClose) - } - - // Using connection after double close should fail - _, err = conn.Write([]byte("test")) - if err == nil { - t.Error("Write should fail on double-closed connection") - } -} - -func TestExtremeContention(t *testing.T) { - if testing.Short() { - t.Skip("skipping extreme contention test in short mode") - } - - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(1), WithMaxPool(5)) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } - defer pool.Close() - - numGoroutines := 100 - iterationsPerGoroutine := 100 - - var wg sync.WaitGroup - errors := make(chan error, numGoroutines*iterationsPerGoroutine) - - start := time.Now() - - for i := 0; i < numGoroutines; i++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - - for j := 0; j < iterationsPerGoroutine; j++ { - conn, err := pool.Get() - if err != nil { - errors <- fmt.Errorf("goroutine %d: Get failed: %w", id, err) - continue - } - - // Simulate very brief use - time.Sleep(1 * time.Millisecond) - - conn.Close() - } - }(i) + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + conn, err := pool.Get() + if err != nil { + return + } + ops.Add(1) + pool.Put(conn) + } + }() } wg.Wait() - elapsed := time.Since(start) - close(errors) - - errorCount := 0 - for err := range errors { - t.Logf("Error: %v", err) - errorCount++ - if errorCount >= 10 { - t.Log("... (more errors suppressed)") - break - } - } - - if errorCount > 0 { - t.Errorf("got %d errors during extreme contention", errorCount) - } - - stats := pool.Stats() - t.Logf("Extreme contention test completed in %v", elapsed) - t.Logf("Final stats: Active=%d, Idle=%d, InUse=%d", - stats.Active, stats.Idle, stats.InUse) - t.Logf("Total operations: %d", numGoroutines*iterationsPerGoroutine) - t.Logf("Operations/sec: %.0f", float64(numGoroutines*iterationsPerGoroutine)/elapsed.Seconds()) - - // Pool should be stable - if stats.Active > int(stats.MaxPool) { - t.Errorf("pool exceeded MaxPool: Active=%d, MaxPool=%d", stats.Active, stats.MaxPool) - } - - if stats.InUse != 0 { - t.Errorf("expected all connections returned: InUse=%d", stats.InUse) - } -} - -func TestDialTimeout(t *testing.T) { - slowDial := func() (net.Conn, error) { - time.Sleep(500 * time.Millisecond) - return net.Dial("tcp", "127.0.0.1:9999") - } - - start := time.Now() - _, err := New(slowDial, - WithMinPool(1), - WithMaxPool(5), - WithDialTimeout(100*time.Millisecond), - ) - elapsed := time.Since(start) - - if err != ErrDialTimeout { - t.Errorf("expected ErrDialTimeout, got %v", err) - } - - if elapsed >= 400*time.Millisecond { - t.Errorf("timeout took too long: %v, expected ~100ms", elapsed) - } - if elapsed < 50*time.Millisecond { - t.Errorf("timeout happened too quickly: %v", elapsed) - } - - t.Logf("Dial timeout detected after %v", elapsed) -} - -func TestDialTimeoutWithWorkingServer(t *testing.T) { - listener, addr := createTestServer(t) - defer listener.Close() - - pool, err := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, - WithMinPool(2), - WithMaxPool(5), - WithDialTimeout(5*time.Second), - ) - - if err != nil { - t.Fatalf("failed to create pool: %v", err) - } - defer pool.Close() - - conn, err := pool.Get() - if err != nil { - t.Fatalf("failed to get connection: %v", err) - } - defer conn.Close() - - _, err = conn.Write([]byte("test")) - if err != nil { - t.Fatalf("write failed: %v", err) - } - - stats := pool.Stats() - t.Logf("Pool stats with DialTimeout: Active=%d, Idle=%d, InUse=%d", - stats.Active, stats.Idle, stats.InUse) -} - -func BenchmarkPoolConcurrent(b *testing.B) { - listener, addr := createBenchmarkTestServer(b) - defer listener.Close() - - pool, _ := New(func() (net.Conn, error) { - return net.Dial("tcp", addr) - }, WithMinPool(10), WithMaxPool(300)) - defer pool.Close() - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - conn, err := pool.Get() - if err != nil { - b.Fatal(err) - } - conn.Write([]byte("test data")) - conn.Close() - } - }) -} - -func createBenchmarkTestServer(t testing.TB) (net.Listener, string) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("failed to create listener: %v", err) - } - - go func() { - for { - conn, err := listener.Accept() - if err != nil { - return - } - go func(c net.Conn) { - defer c.Close() - buf := make([]byte, 1024) - for { - n, err := c.Read(buf) - if err != nil { - return - } - c.Write(buf[:n]) - } - }(conn) - } - }() - - return listener, listener.Addr().String() + t.Logf("Completed %d operations", ops.Load()) } diff --git a/pool_connection.go b/pool_connection.go deleted file mode 100644 index 7ee9c48..0000000 --- a/pool_connection.go +++ /dev/null @@ -1,184 +0,0 @@ -package netpool - -import ( - "net" - "sync" - "sync/atomic" - "time" -) - -var pooledConnPool = sync.Pool{ - New: func() any { - return new(pooledConn) - }, -} - -// pooledConn wraps a net.Conn and automatically returns it to the pool on Close() -type pooledConn struct { - net.Conn - pool *Netpool - returned atomic.Bool - - // pointer to last non-nil error from Read/Write - lastErr atomic.Pointer[error] -} - -func newPooledConn(c net.Conn, p *Netpool) *pooledConn { - pc := pooledConnPool.Get().(*pooledConn) - pc.Conn = c - pc.pool = p - pc.returned.Store(false) - pc.lastErr.Store(nil) - return pc -} - -func (pc *pooledConn) setErr(err error) { - if err == nil { - return - } - pc.lastErr.CompareAndSwap(nil, &err) -} - -// Close returns the connection to the pool instead of closing it -func (pc *pooledConn) Close() error { - if !pc.returned.CompareAndSwap(false, true) { - return nil - } - - var err error - if v := pc.lastErr.Load(); v != nil { - err = *v - } - - // Reset deadline before returning to pool so next user gets clean connection - if pc.Conn != nil { - _ = pc.Conn.SetDeadline(time.Time{}) - } - - pc.pool.Put(pc.Conn, err) - - pc.Conn = nil - pc.pool = nil - pooledConnPool.Put(pc) - - return nil -} - -// MarkUnusable marks the connection as unusable -func (pc *pooledConn) MarkUnusable() error { - if !pc.returned.CompareAndSwap(false, true) { - return nil - } - pc.pool.Put(pc.Conn, ErrInvalidConn) - - pc.Conn = nil - pc.pool = nil - pooledConnPool.Put(pc) - return nil -} - -// Read implements net.Conn.Read -func (pc *pooledConn) Read(b []byte) (n int, err error) { - if pc.returned.Load() { - return 0, ErrConnReturned - } - - conn := pc.Conn - if conn == nil { - return 0, ErrConnReturned - } - - n, err = conn.Read(b) - if err != nil { - pc.setErr(err) - } - return n, err -} - -// Write implements net.Conn.Write -func (pc *pooledConn) Write(b []byte) (n int, err error) { - if pc.returned.Load() { - return 0, ErrConnReturned - } - - conn := pc.Conn - if conn == nil { - return 0, ErrConnReturned - } - - n, err = conn.Write(b) - if err != nil { - pc.setErr(err) - } - return n, err -} - -// SetDeadline implements net.Conn.SetDeadline -func (pc *pooledConn) SetDeadline(t time.Time) error { - if pc.returned.Load() { - return ErrConnReturned - } - - conn := pc.Conn - if conn == nil { - return ErrConnReturned - } - - return conn.SetDeadline(t) -} - -// SetReadDeadline implements net.Conn.SetReadDeadline -func (pc *pooledConn) SetReadDeadline(t time.Time) error { - if pc.returned.Load() { - return ErrConnReturned - } - - conn := pc.Conn - if conn == nil { - return ErrConnReturned - } - - return conn.SetReadDeadline(t) -} - -// SetWriteDeadline implements net.Conn.SetWriteDeadline -func (pc *pooledConn) SetWriteDeadline(t time.Time) error { - if pc.returned.Load() { - return ErrConnReturned - } - - conn := pc.Conn - if conn == nil { - return ErrConnReturned - } - - return conn.SetWriteDeadline(t) -} - -// LocalAddr implements net.Conn.LocalAddr -func (pc *pooledConn) LocalAddr() net.Addr { - if pc.returned.Load() { - return nil - } - - conn := pc.Conn - if conn == nil { - return nil - } - - return conn.LocalAddr() -} - -// RemoteAddr implements net.Conn.RemoteAddr -func (pc *pooledConn) RemoteAddr() net.Addr { - if pc.returned.Load() { - return nil - } - - conn := pc.Conn - if conn == nil { - return nil - } - - return conn.RemoteAddr() -}