Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ type Config struct {
//
// If zero, a sensible default is chosen based on MaxIdleTime.
MaintainerInterval time.Duration

// DialTimeout is the maximum duration allowed for creating a new connection.
//
// If the dial operation takes longer than this duration, it will be
// canceled and an error will be returned.
//
// A zero value means no timeout (dial may block indefinitely).
DialTimeout time.Duration
}

// Opt represents a functional option used to configure a Netpool.
Expand Down Expand Up @@ -92,3 +100,13 @@ func WithMaintainerInterval(d time.Duration) Opt {
c.MaintainerInterval = d
}
}

// WithDialTimeout sets the maximum duration for creating new connections.
//
// If a dial operation exceeds this timeout, it fails with a timeout error.
// A zero value disables the timeout.
func WithDialTimeout(d time.Duration) Opt {
return func(c *Config) {
c.DialTimeout = d
}
}
3 changes: 3 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ var (

// ErrConnReturned is returned when connection already returned
ErrConnReturned = errors.New("connection already returned")

// ErrDialTimeout is returned when dial operation exceeds the configured timeout
ErrDialTimeout = errors.New("netpool: dial timeout exceeded")
)
53 changes: 40 additions & 13 deletions netpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,39 @@ func New(fn netpoolFunc, opts ...Opt) (*Netpool, error) {
}

func (netpool *Netpool) createConnection() (net.Conn, error) {
conn, err := netpool.fn()
var conn net.Conn
var err error

// If DialTimeout is set, wrap the dial in a timeout
if netpool.config.DialTimeout > 0 {
type dialResult struct {
conn net.Conn
err error
}
result := make(chan dialResult, 1)

go func() {
c, e := netpool.fn()
result <- dialResult{conn: c, err: e}
}()

select {
case r := <-result:
conn, err = r.conn, r.err
case <-time.After(netpool.config.DialTimeout):
// Dial timed out - if dial eventually succeeds, connection will leak
// To prevent this, we spawn a cleanup goroutine
go func() {
if r := <-result; r.conn != nil {
r.conn.Close()
}
}()
return nil, ErrDialTimeout
}
} else {
conn, err = netpool.fn()
}

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -259,20 +291,15 @@ func (netpool *Netpool) getWithContext(ctx context.Context) (net.Conn, error) {
return c, nil
}

done := make(chan struct{})

go func() {
select {
case <-ctx.Done():
netpool.mu.Lock()
netpool.cond.Broadcast()
netpool.mu.Unlock()
case <-done:
}
}()
// Use context.AfterFunc for cleaner context cancellation handling (Go 1.21+)
stop := context.AfterFunc(ctx, func() {
netpool.mu.Lock()
netpool.cond.Broadcast()
netpool.mu.Unlock()
})

netpool.cond.Wait()
close(done)
stop()
}
}

Expand Down
61 changes: 61 additions & 0 deletions netpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,67 @@ func TestExtremeContention(t *testing.T) {
}
}

func TestDialTimeout(t *testing.T) {
slowDial := func() (net.Conn, error) {
time.Sleep(500 * time.Millisecond)
return net.Dial("tcp", "127.0.0.1:9999")
}

start := time.Now()
_, err := New(slowDial,
WithMinPool(1),
WithMaxPool(5),
WithDialTimeout(100*time.Millisecond),
)
elapsed := time.Since(start)

if err != ErrDialTimeout {
t.Errorf("expected ErrDialTimeout, got %v", err)
}

if elapsed >= 400*time.Millisecond {
t.Errorf("timeout took too long: %v, expected ~100ms", elapsed)
}
if elapsed < 50*time.Millisecond {
t.Errorf("timeout happened too quickly: %v", elapsed)
}

t.Logf("Dial timeout detected after %v", elapsed)
}

func TestDialTimeoutWithWorkingServer(t *testing.T) {
listener, addr := createTestServer(t)
defer listener.Close()

pool, err := New(func() (net.Conn, error) {
return net.Dial("tcp", addr)
},
WithMinPool(2),
WithMaxPool(5),
WithDialTimeout(5*time.Second),
)

if err != nil {
t.Fatalf("failed to create pool: %v", err)
}
defer pool.Close()

conn, err := pool.Get()
if err != nil {
t.Fatalf("failed to get connection: %v", err)
}
defer conn.Close()

_, err = conn.Write([]byte("test"))
if err != nil {
t.Fatalf("write failed: %v", err)
}

stats := pool.Stats()
t.Logf("Pool stats with DialTimeout: Active=%d, Idle=%d, InUse=%d",
stats.Active, stats.Idle, stats.InUse)
}

func BenchmarkPoolConcurrent(b *testing.B) {
listener, addr := createBenchmarkTestServer(b)
defer listener.Close()
Expand Down
96 changes: 93 additions & 3 deletions pool_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net"
"sync"
"sync/atomic"
"time"
)

var pooledConnPool = sync.Pool{
Expand Down Expand Up @@ -48,6 +49,12 @@ func (pc *pooledConn) Close() error {
if v := pc.lastErr.Load(); v != nil {
err = *v
}

// Reset deadline before returning to pool so next user gets clean connection
if pc.Conn != nil {
_ = pc.Conn.SetDeadline(time.Time{})
}

pc.pool.Put(pc.Conn, err)

pc.Conn = nil
Expand All @@ -62,8 +69,11 @@ func (pc *pooledConn) MarkUnusable() error {
if !pc.returned.CompareAndSwap(false, true) {
return nil
}

pc.pool.Put(pc.Conn, ErrInvalidConn)

pc.Conn = nil
pc.pool = nil
pooledConnPool.Put(pc)
return nil
}

Expand All @@ -73,7 +83,12 @@ func (pc *pooledConn) Read(b []byte) (n int, err error) {
return 0, ErrConnReturned
}

n, err = pc.Conn.Read(b)
conn := pc.Conn
if conn == nil {
return 0, ErrConnReturned
}

n, err = conn.Read(b)
if err != nil {
pc.setErr(err)
}
Expand All @@ -86,9 +101,84 @@ func (pc *pooledConn) Write(b []byte) (n int, err error) {
return 0, ErrConnReturned
}

n, err = pc.Conn.Write(b)
conn := pc.Conn
if conn == nil {
return 0, ErrConnReturned
}

n, err = conn.Write(b)
if err != nil {
pc.setErr(err)
}
return n, err
}

// SetDeadline implements net.Conn.SetDeadline
func (pc *pooledConn) SetDeadline(t time.Time) error {
if pc.returned.Load() {
return ErrConnReturned
}

conn := pc.Conn
if conn == nil {
return ErrConnReturned
}

return conn.SetDeadline(t)
}

// SetReadDeadline implements net.Conn.SetReadDeadline
func (pc *pooledConn) SetReadDeadline(t time.Time) error {
if pc.returned.Load() {
return ErrConnReturned
}

conn := pc.Conn
if conn == nil {
return ErrConnReturned
}

return conn.SetReadDeadline(t)
}

// SetWriteDeadline implements net.Conn.SetWriteDeadline
func (pc *pooledConn) SetWriteDeadline(t time.Time) error {
if pc.returned.Load() {
return ErrConnReturned
}

conn := pc.Conn
if conn == nil {
return ErrConnReturned
}

return conn.SetWriteDeadline(t)
}

// LocalAddr implements net.Conn.LocalAddr
func (pc *pooledConn) LocalAddr() net.Addr {
if pc.returned.Load() {
return nil
}

conn := pc.Conn
if conn == nil {
return nil
}

return conn.LocalAddr()
}

// RemoteAddr implements net.Conn.RemoteAddr
func (pc *pooledConn) RemoteAddr() net.Addr {
if pc.returned.Load() {
return nil
}

conn := pc.Conn
if conn == nil {
return nil
}

return conn.RemoteAddr()
}
Loading