diff --git a/config.go b/config.go index 61de3b9..72a5a94 100644 --- a/config.go +++ b/config.go @@ -40,6 +40,14 @@ type Config struct { // // If zero, a sensible default is chosen based on MaxIdleTime. MaintainerInterval time.Duration + + // DialTimeout is the maximum duration allowed for creating a new connection. + // + // If the dial operation takes longer than this duration, it will be + // canceled and an error will be returned. + // + // A zero value means no timeout (dial may block indefinitely). + DialTimeout time.Duration } // Opt represents a functional option used to configure a Netpool. @@ -92,3 +100,13 @@ func WithMaintainerInterval(d time.Duration) Opt { c.MaintainerInterval = d } } + +// WithDialTimeout sets the maximum duration for creating new connections. +// +// If a dial operation exceeds this timeout, it fails with a timeout error. +// A zero value disables the timeout. +func WithDialTimeout(d time.Duration) Opt { + return func(c *Config) { + c.DialTimeout = d + } +} diff --git a/errors.go b/errors.go index 54d126a..62af424 100644 --- a/errors.go +++ b/errors.go @@ -14,4 +14,7 @@ var ( // ErrConnReturned is returned when connection already returned ErrConnReturned = errors.New("connection already returned") + + // ErrDialTimeout is returned when dial operation exceeds the configured timeout + ErrDialTimeout = errors.New("netpool: dial timeout exceeded") ) diff --git a/netpool.go b/netpool.go index 87c2ea6..1935bfa 100644 --- a/netpool.go +++ b/netpool.go @@ -192,7 +192,39 @@ func New(fn netpoolFunc, opts ...Opt) (*Netpool, error) { } func (netpool *Netpool) createConnection() (net.Conn, error) { - conn, err := netpool.fn() + var conn net.Conn + var err error + + // If DialTimeout is set, wrap the dial in a timeout + if netpool.config.DialTimeout > 0 { + type dialResult struct { + conn net.Conn + err error + } + result := make(chan dialResult, 1) + + go func() { + c, e := netpool.fn() + result <- dialResult{conn: c, err: e} + }() + + select { + case r := <-result: + conn, err = r.conn, r.err + case <-time.After(netpool.config.DialTimeout): + // Dial timed out - if dial eventually succeeds, connection will leak + // To prevent this, we spawn a cleanup goroutine + go func() { + if r := <-result; r.conn != nil { + r.conn.Close() + } + }() + return nil, ErrDialTimeout + } + } else { + conn, err = netpool.fn() + } + if err != nil { return nil, err } @@ -259,20 +291,15 @@ func (netpool *Netpool) getWithContext(ctx context.Context) (net.Conn, error) { return c, nil } - done := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - netpool.mu.Lock() - netpool.cond.Broadcast() - netpool.mu.Unlock() - case <-done: - } - }() + // Use context.AfterFunc for cleaner context cancellation handling (Go 1.21+) + stop := context.AfterFunc(ctx, func() { + netpool.mu.Lock() + netpool.cond.Broadcast() + netpool.mu.Unlock() + }) netpool.cond.Wait() - close(done) + stop() } } diff --git a/netpool_test.go b/netpool_test.go index c3f2c5c..e4de610 100644 --- a/netpool_test.go +++ b/netpool_test.go @@ -1165,6 +1165,67 @@ func TestExtremeContention(t *testing.T) { } } +func TestDialTimeout(t *testing.T) { + slowDial := func() (net.Conn, error) { + time.Sleep(500 * time.Millisecond) + return net.Dial("tcp", "127.0.0.1:9999") + } + + start := time.Now() + _, err := New(slowDial, + WithMinPool(1), + WithMaxPool(5), + WithDialTimeout(100*time.Millisecond), + ) + elapsed := time.Since(start) + + if err != ErrDialTimeout { + t.Errorf("expected ErrDialTimeout, got %v", err) + } + + if elapsed >= 400*time.Millisecond { + t.Errorf("timeout took too long: %v, expected ~100ms", elapsed) + } + if elapsed < 50*time.Millisecond { + t.Errorf("timeout happened too quickly: %v", elapsed) + } + + t.Logf("Dial timeout detected after %v", elapsed) +} + +func TestDialTimeoutWithWorkingServer(t *testing.T) { + listener, addr := createTestServer(t) + defer listener.Close() + + pool, err := New(func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, + WithMinPool(2), + WithMaxPool(5), + WithDialTimeout(5*time.Second), + ) + + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + defer pool.Close() + + conn, err := pool.Get() + if err != nil { + t.Fatalf("failed to get connection: %v", err) + } + defer conn.Close() + + _, err = conn.Write([]byte("test")) + if err != nil { + t.Fatalf("write failed: %v", err) + } + + stats := pool.Stats() + t.Logf("Pool stats with DialTimeout: Active=%d, Idle=%d, InUse=%d", + stats.Active, stats.Idle, stats.InUse) +} + func BenchmarkPoolConcurrent(b *testing.B) { listener, addr := createBenchmarkTestServer(b) defer listener.Close() diff --git a/pool_connection.go b/pool_connection.go index 7e863fb..7ee9c48 100644 --- a/pool_connection.go +++ b/pool_connection.go @@ -4,6 +4,7 @@ import ( "net" "sync" "sync/atomic" + "time" ) var pooledConnPool = sync.Pool{ @@ -48,6 +49,12 @@ func (pc *pooledConn) Close() error { if v := pc.lastErr.Load(); v != nil { err = *v } + + // Reset deadline before returning to pool so next user gets clean connection + if pc.Conn != nil { + _ = pc.Conn.SetDeadline(time.Time{}) + } + pc.pool.Put(pc.Conn, err) pc.Conn = nil @@ -62,8 +69,11 @@ func (pc *pooledConn) MarkUnusable() error { if !pc.returned.CompareAndSwap(false, true) { return nil } - pc.pool.Put(pc.Conn, ErrInvalidConn) + + pc.Conn = nil + pc.pool = nil + pooledConnPool.Put(pc) return nil } @@ -73,7 +83,12 @@ func (pc *pooledConn) Read(b []byte) (n int, err error) { return 0, ErrConnReturned } - n, err = pc.Conn.Read(b) + conn := pc.Conn + if conn == nil { + return 0, ErrConnReturned + } + + n, err = conn.Read(b) if err != nil { pc.setErr(err) } @@ -86,9 +101,84 @@ func (pc *pooledConn) Write(b []byte) (n int, err error) { return 0, ErrConnReturned } - n, err = pc.Conn.Write(b) + conn := pc.Conn + if conn == nil { + return 0, ErrConnReturned + } + + n, err = conn.Write(b) if err != nil { pc.setErr(err) } return n, err } + +// SetDeadline implements net.Conn.SetDeadline +func (pc *pooledConn) SetDeadline(t time.Time) error { + if pc.returned.Load() { + return ErrConnReturned + } + + conn := pc.Conn + if conn == nil { + return ErrConnReturned + } + + return conn.SetDeadline(t) +} + +// SetReadDeadline implements net.Conn.SetReadDeadline +func (pc *pooledConn) SetReadDeadline(t time.Time) error { + if pc.returned.Load() { + return ErrConnReturned + } + + conn := pc.Conn + if conn == nil { + return ErrConnReturned + } + + return conn.SetReadDeadline(t) +} + +// SetWriteDeadline implements net.Conn.SetWriteDeadline +func (pc *pooledConn) SetWriteDeadline(t time.Time) error { + if pc.returned.Load() { + return ErrConnReturned + } + + conn := pc.Conn + if conn == nil { + return ErrConnReturned + } + + return conn.SetWriteDeadline(t) +} + +// LocalAddr implements net.Conn.LocalAddr +func (pc *pooledConn) LocalAddr() net.Addr { + if pc.returned.Load() { + return nil + } + + conn := pc.Conn + if conn == nil { + return nil + } + + return conn.LocalAddr() +} + +// RemoteAddr implements net.Conn.RemoteAddr +func (pc *pooledConn) RemoteAddr() net.Addr { + if pc.returned.Load() { + return nil + } + + conn := pc.Conn + if conn == nil { + return nil + } + + return conn.RemoteAddr() +}