diff --git a/listener.go b/listener.go index 63d2db5..a0df657 100644 --- a/listener.go +++ b/listener.go @@ -2,14 +2,16 @@ package manners import ( "crypto/tls" - "errors" "fmt" "net" "os" "sync" + "sync/atomic" "time" ) +var listenerClosed = fmt.Errorf("listener is closed") + // NewListener wraps an existing listener for use with // GracefulServer. // @@ -17,26 +19,20 @@ import ( // GracefulServer will automatically wrap any non-graceful listeners // supplied to it. func NewListener(l net.Listener) *GracefulListener { - return &GracefulListener{ - listener: l, - mutex: &sync.RWMutex{}, - open: true, - } + return &GracefulListener{listener: l} } // A GracefulListener differs from a standard net.Listener in one way: if // Accept() is called after it is gracefully closed, it returns a // listenerAlreadyClosed error. The GracefulServer will ignore this error. type GracefulListener struct { - listener net.Listener - open bool - mutex *sync.RWMutex + listener net.Listener + closeOnce sync.Once + closed int32 // accessed atomically } func (l *GracefulListener) isClosed() bool { - l.mutex.RLock() - defer l.mutex.RUnlock() - return !l.open + return atomic.LoadInt32(&l.closed) == 1 } func (l *GracefulListener) Addr() net.Addr { @@ -45,25 +41,19 @@ func (l *GracefulListener) Addr() net.Addr { // Accept implements the Accept method in the Listener interface. func (l *GracefulListener) Accept() (net.Conn, error) { - conn, err := l.listener.Accept() - if err != nil { - if l.isClosed() { - err = fmt.Errorf("listener already closed: err=(%s)", err) - } - return nil, err + if l.isClosed() { + return nil, listenerClosed } - return conn, nil + return l.listener.Accept() } // Close tells the wrapped listener to stop listening. It is idempotent. -func (l *GracefulListener) Close() error { - l.mutex.Lock() - defer l.mutex.Unlock() - if !l.open { - return nil - } - l.open = false - return l.listener.Close() +func (l *GracefulListener) Close() (err error) { + l.closeOnce.Do(func() { + atomic.StoreInt32(&l.closed, 1) + err = l.listener.Close() + }) + return } func (l *GracefulListener) GetFile() (*os.File, error) { @@ -71,11 +61,8 @@ func (l *GracefulListener) GetFile() (*os.File, error) { } func (l *GracefulListener) Clone() (net.Listener, error) { - l.mutex.Lock() - defer l.mutex.Unlock() - - if !l.open { - return nil, errors.New("listener is already closed") + if l.isClosed() { + return nil, listenerClosed } file, err := l.GetFile()