From 2743913af7745397ec0a92c5c23c2b23907379a6 Mon Sep 17 00:00:00 2001 From: Maxim Vladimirsky Date: Fri, 20 Nov 2015 16:57:10 -0800 Subject: [PATCH 1/2] Rebase to master --- helpers_test.go | 2 +- listener.go | 190 +++++++++++++++++++++++++++++++++++++ server.go | 145 +++++++++++++++++++++------- server_test.go | 124 +++++++++++++++++++++++- test_helpers/conn.go | 7 +- test_helpers/wait_group.go | 10 +- transition_test.go | 4 +- 7 files changed, 439 insertions(+), 43 deletions(-) create mode 100644 listener.go diff --git a/helpers_test.go b/helpers_test.go index dd9a8ba..5a2e575 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -86,7 +86,7 @@ func startGenericServer(t *testing.T, server *GracefulServer, statechanged chan // Wrap the ConnState handler with something that will notify // the statechanged channel when a state change happens server.ConnState = func(conn net.Conn, newState http.ConnState) { - statechanged <- newState + statechanged <- conn.LocalAddr().(*gracefulAddr).gconn.lastHTTPState } } diff --git a/listener.go b/listener.go new file mode 100644 index 0000000..ccb1c9f --- /dev/null +++ b/listener.go @@ -0,0 +1,190 @@ +package manners + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "os" + "sync" + "time" +) + +// NewListener wraps an existing listener for use with +// GracefulServer. +// +// Note that you generally don't need to use this directly as +// GracefulServer will automatically wrap any non-graceful listeners +// supplied to it. +func NewListener(l net.Listener) *GracefulListener { + return &GracefulListener{ + listener: l, + mutex: &sync.RWMutex{}, + open: true, + } +} + +// A gracefulCon wraps a normal net.Conn and tracks the last known http state. +type gracefulConn struct { + net.Conn + lastHTTPState http.ConnState + // protected tells whether the connection is going to defer server shutdown + // until the current HTTP request is completed. + protected bool +} + +type gracefulAddr struct { + net.Addr + gconn *gracefulConn +} + +func (g *gracefulConn) LocalAddr() net.Addr { + return &gracefulAddr{g.Conn.LocalAddr(), g} +} + +// retrieveGracefulConn retrieves a concrete gracefulConn instance from an +// interface value that can either refer to it directly or refer to a tls.Conn +// instance wrapping around a gracefulConn one. +func retrieveGracefulConn(conn net.Conn) *gracefulConn { + return conn.LocalAddr().(*gracefulAddr).gconn +} + +// A GracefulListener differs from a standard net.Listener in one way: if +// Accept() is called after it is gracefully closed, it returns a +// listenerAlreadyClosed error. The GracefulServer will ignore this error. +type GracefulListener struct { + listener net.Listener + open bool + mutex *sync.RWMutex +} + +func (l *GracefulListener) isClosed() bool { + l.mutex.RLock() + defer l.mutex.RUnlock() + return !l.open +} + +func (l *GracefulListener) Addr() net.Addr { + return l.listener.Addr() +} + +// Accept implements the Accept method in the Listener interface. +func (l *GracefulListener) Accept() (net.Conn, error) { + conn, err := l.listener.Accept() + if err != nil { + if l.isClosed() { + err = listenerAlreadyClosed{err} + } + return nil, err + } + + // don't wrap connection if it's tls so we won't break + // http server internal logic that relies on the type + if _, ok := conn.(*tls.Conn); ok { + return conn, nil + } + return &gracefulConn{Conn: conn}, nil +} + +// Close tells the wrapped listener to stop listening. It is idempotent. +func (l *GracefulListener) Close() error { + l.mutex.Lock() + defer l.mutex.Unlock() + if !l.open { + return nil + } + l.open = false + return l.listener.Close() +} + +func (l *GracefulListener) GetFile() (*os.File, error) { + return getListenerFile(l.listener) +} + +func (l *GracefulListener) Clone() (net.Listener, error) { + l.mutex.Lock() + defer l.mutex.Unlock() + + if !l.open { + return nil, fmt.Errorf("listener is already closed") + } + + file, err := l.GetFile() + if err != nil { + return nil, err + } + defer file.Close() + + fl, err := net.FileListener(file) + if nil != err { + return nil, err + } + return fl, nil +} + +// A listener implements a network listener (net.Listener) for TLS connections. +// direct lift from crypto/tls.go +type TLSListener struct { + net.Listener + config *tls.Config +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection c is a *tls.Conn. +func (l *TLSListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + if err != nil { + return + } + c = tls.Server(&gracefulConn{Conn: c}, l.config) + return +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must have +// at least one certificate. +func NewTLSListener(inner net.Listener, config *tls.Config) net.Listener { + l := new(TLSListener) + l.Listener = inner + l.config = config + return l +} + +type listenerAlreadyClosed struct { + error +} + +// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +// +// direct lift from net/http/server.go +type TCPKeepAliveListener struct { + *net.TCPListener +} + +func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +func getListenerFile(listener net.Listener) (*os.File, error) { + switch t := listener.(type) { + case *net.TCPListener: + return t.File() + case *net.UnixListener: + return t.File() + case TCPKeepAliveListener: + return t.TCPListener.File() + case *TLSListener: + return getListenerFile(t.Listener) + } + return nil, fmt.Errorf("Unsupported listener: %T", listener) +} diff --git a/server.go b/server.go index e45f5c6..91664ad 100644 --- a/server.go +++ b/server.go @@ -44,10 +44,21 @@ import ( "crypto/tls" "net" "net/http" + "os" "sync" "sync/atomic" ) +// StateHandler can be called by the server if the state of the connection changes. +// Notice that it passed previous state and the new state as parameters. +type StateHandler func(net.Conn, http.ConnState, http.ConnState) + +type Options struct { + Server *http.Server + StateHandler StateHandler + Listener net.Listener +} + // A GracefulServer maintains a WaitGroup that counts how many in-flight // requests the server is handling. When it receives a shutdown signal, // it stops accepting new requests but does not actually shut down until @@ -56,20 +67,25 @@ import ( // GracefulServer embeds the underlying net/http.Server making its non-override // methods and properties avaiable. // -// It must be initialized by calling NewWithServer. +// It must be initialized by calling NewServer or NewWithServer type GracefulServer struct { *http.Server shutdown chan bool shutdownFinished chan bool wg waitGroup - - lcsmu sync.RWMutex - connections map[net.Conn]bool + listener *GracefulListener + stateHandler StateHandler up chan net.Listener // Only used by test code. } +// NewServer creates a new GracefulServer. The server will begin shutting down +// when a value is passed to the Shutdown channel. +func NewServer() *GracefulServer { + return NewWithServer(new(http.Server)) +} + // NewWithServer wraps an existing http.Server object and returns a // GracefulServer that supports all of the original Server operations. func NewWithServer(s *http.Server) *GracefulServer { @@ -78,7 +94,28 @@ func NewWithServer(s *http.Server) *GracefulServer { shutdown: make(chan bool), shutdownFinished: make(chan bool, 1), wg: new(sync.WaitGroup), - connections: make(map[net.Conn]bool), + } +} + +func NewWithOptions(o Options) *GracefulServer { + // Set up listener + var listener *GracefulListener + if o.Listener != nil { + g, ok := o.Listener.(*GracefulListener) + if !ok { + listener = NewListener(o.Listener) + } else { + listener = g + } + } + + return &GracefulServer{ + listener: listener, + Server: o.Server, + stateHandler: o.StateHandler, + shutdown: make(chan bool), + shutdownFinished: make(chan bool, 1), + wg: new(sync.WaitGroup), } } @@ -98,16 +135,14 @@ func (s *GracefulServer) BlockingClose() bool { // ListenAndServe provides a graceful equivalent of net/http.Serve.ListenAndServe. func (s *GracefulServer) ListenAndServe() error { - addr := s.Addr - if addr == "" { - addr = ":http" - } - listener, err := net.Listen("tcp", addr) - if err != nil { - return err + if s.listener == nil { + oldListener, err := net.Listen("tcp", s.Addr) + if err != nil { + return err + } + s.listener = NewListener(oldListener.(*net.TCPListener)) } - - return s.Serve(listener) + return s.Serve(s.listener) } // ListenAndServeTLS provides a graceful equivalent of net/http.Serve.ListenAndServeTLS. @@ -132,16 +167,62 @@ func (s *GracefulServer) ListenAndServeTLS(certFile, keyFile string) error { return err } - ln, err := net.Listen("tcp", addr) + return s.ListenAndServeTLSWithConfig(config) +} + +// ListenAndServeTLS provides a graceful equivalent of net/http.Serve.ListenAndServeTLS. +func (s *GracefulServer) ListenAndServeTLSWithConfig(config *tls.Config) error { + addr := s.Addr + if addr == "" { + addr = ":https" + } + + if s.listener == nil { + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + tlsListener := NewTLSListener(TCPKeepAliveListener{ln.(*net.TCPListener)}, config) + s.listener = NewListener(tlsListener) + } + return s.Serve(s.listener) +} + +func (gs *GracefulServer) GetFile() (*os.File, error) { + return gs.listener.GetFile() +} + +func (gs *GracefulServer) HijackListener(s *http.Server, config *tls.Config) (*GracefulServer, error) { + listener, err := gs.listener.Clone() if err != nil { - return err + return nil, err } - return s.Serve(tls.NewListener(ln, config)) + if config != nil { + listener = NewTLSListener(TCPKeepAliveListener{listener.(*net.TCPListener)}, config) + } + + other := NewWithServer(s) + other.listener = NewListener(listener) + return other, nil } // Serve provides a graceful equivalent net/http.Server.Serve. +// +// If listener is not an instance of *GracefulListener it will be wrapped +// to become one. func (s *GracefulServer) Serve(listener net.Listener) error { + // Accept a net.Listener to preserve the interface compatibility with the + // standard http.Server. If it is not a GracefulListener then wrap it into + // one. + gracefulListener, ok := listener.(*GracefulListener) + if !ok { + gracefulListener = NewListener(listener) + listener = gracefulListener + } + s.listener = gracefulListener + // Wrap the server HTTP handler into graceful one, that will close kept // alive connections if a new request is received after shutdown. gracefulHandler := newGracefulHandler(s.Server.Handler) @@ -155,7 +236,7 @@ func (s *GracefulServer) Serve(listener net.Listener) error { close(s.shutdown) gracefulHandler.Close() s.Server.SetKeepAlivesEnabled(false) - listener.Close() + gracefulListener.Close() }() originalConnState := s.Server.ConnState @@ -164,44 +245,40 @@ func (s *GracefulServer) Serve(listener net.Listener) error { // changes state. It keeps track of each connection's state over time, // enabling manners to handle persisted connections correctly. s.ConnState = func(conn net.Conn, newState http.ConnState) { - s.lcsmu.RLock() - protected := s.connections[conn] - s.lcsmu.RUnlock() + gracefulConn := retrieveGracefulConn(conn) + oldState := gracefulConn.lastHTTPState + gracefulConn.lastHTTPState = newState switch newState { case http.StateNew: // New connection -> StateNew - protected = true + gracefulConn.protected = true s.StartRoutine() case http.StateActive: // (StateNew, StateIdle) -> StateActive if gracefulHandler.IsClosed() { - conn.Close() + gracefulConn.Close() break } - if !protected { - protected = true + if !gracefulConn.protected { + gracefulConn.protected = true s.StartRoutine() } default: // (StateNew, StateActive) -> (StateIdle, StateClosed, StateHiJacked) - if protected { + if gracefulConn.protected { s.FinishRoutine() - protected = false + gracefulConn.protected = false } } - s.lcsmu.Lock() - if newState == http.StateClosed || newState == http.StateHijacked { - delete(s.connections, conn) - } else { - s.connections[conn] = protected + if s.stateHandler != nil { + s.stateHandler(conn, oldState, newState) } - s.lcsmu.Unlock() if originalConnState != nil { originalConnState(conn, newState) @@ -216,7 +293,7 @@ func (s *GracefulServer) Serve(listener net.Listener) error { err := s.Server.Serve(listener) // An error returned on shutdown is not worth reporting. - if err != nil && gracefulHandler.IsClosed() { + if _, ok = err.(listenerAlreadyClosed); ok { err = nil } diff --git a/server_test.go b/server_test.go index 1ab7f19..073bdd2 100644 --- a/server_test.go +++ b/server_test.go @@ -1,7 +1,7 @@ package manners import ( - helpers "github.com/braintree/manners/test_helpers" + helpers "github.com/mailgun/manners/test_helpers" "net" "net/http" "testing" @@ -252,3 +252,125 @@ func TestStateTransitionActiveIdleClosed(t *testing.T) { } } } + +// Test that supplying a non GracefulListener to Serve works +// correctly (ie. that the listener is wrapped to become graceful) +func TestWrapConnection(t *testing.T) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Failed to create listener", err) + } + + s := NewServer() + s.up = make(chan net.Listener) + + var called bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + s.Close() // clean shutdown as soon as handler exits + }) + s.Handler = handler + + serverr := make(chan error) + + go func() { + serverr <- s.Serve(l) + }() + + gl := <-s.up + if _, ok := gl.(*GracefulListener); !ok { + t.Fatal("connection was not wrapped into a GracefulListener") + } + + addr := l.Addr() + if _, err := http.Get("http://" + addr.String()); err != nil { + t.Fatal("Get failed", err) + } + + if err := <-serverr; err != nil { + t.Fatal("Error from Serve()", err) + } + + if !called { + t.Error("Handler was not called") + } + +} + +// Hijack listener +func TestHijackListener(t *testing.T) { + server := NewServer() + wg := helpers.NewWaitGroup() + server.wg = wg + listener, exitchan := startServer(t, server, nil) + + client := newClient(listener.Addr(), false) + client.Run() + + // wait for client to connect, but don't let it send the request yet + if err := <-client.connected; err != nil { + t.Fatal("Client failed to connect to server", err) + } + + // Make sure server1 got the request and added it to the waiting group + <-wg.CountChanged + + wg2 := helpers.NewWaitGroup() + server2, err := server.HijackListener(new(http.Server), nil) + server2.wg = wg2 + if err != nil { + t.Fatal("Failed to hijack listener", err) + } + + listener2, exitchan2 := startServer(t, server2, nil) + + // Close the first server + server.Close() + + // First server waits for the first request to finish + waiting := <-wg.WaitCalled + if waiting < 1 { + t.Errorf("Expected the waitgroup to equal 1 at shutdown; actually %d", waiting) + } + + // allow the client to finish sending the request and make sure the server exits after + // (client will be in connected but idle state at that point) + client.sendrequest <- true + close(client.sendrequest) + if err := <-exitchan; err != nil { + t.Error("Unexpected error during shutdown", err) + } + + client2 := newClient(listener2.Addr(), false) + client2.Run() + + // wait for client to connect, but don't let it send the request yet + select { + case err := <-client2.connected: + if err != nil { + t.Fatal("Client failed to connect to server", err) + } + case <-time.After(time.Second): + t.Fatal("Timeout connecting to the server", err) + } + + // Close the second server + server2.Close() + + waiting = <-wg2.WaitCalled + if waiting < 1 { + t.Errorf("Expected the waitgroup to equal 1 at shutdown; actually %d", waiting) + } + + // allow the client to finish sending the request and make sure the server exits after + // (client will be in connected but idle state at that point) + client2.sendrequest <- true + // Make sure that request resulted in success + if rr := <-client2.response; rr.err != nil { + t.Errorf("Client failed to write the request, error: %s", err) + } + close(client2.sendrequest) + if err := <-exitchan2; err != nil { + t.Error("Unexpected error during shutdown", err) + } +} diff --git a/test_helpers/conn.go b/test_helpers/conn.go index 8c610f5..915c534 100644 --- a/test_helpers/conn.go +++ b/test_helpers/conn.go @@ -4,10 +4,13 @@ import "net" type Conn struct { net.Conn - CloseCalled bool + localAddr net.Addr +} + +func (f *Conn) LocalAddr() net.Addr { + return &net.IPAddr{} } func (c *Conn) Close() error { - c.CloseCalled = true return nil } diff --git a/test_helpers/wait_group.go b/test_helpers/wait_group.go index 1df590d..192a121 100644 --- a/test_helpers/wait_group.go +++ b/test_helpers/wait_group.go @@ -4,25 +4,29 @@ import "sync" type WaitGroup struct { sync.Mutex - Count int - WaitCalled chan int + Count int + WaitCalled chan int + CountChanged chan int } func NewWaitGroup() *WaitGroup { return &WaitGroup{ - WaitCalled: make(chan int, 1), + WaitCalled: make(chan int, 1), + CountChanged: make(chan int, 1024), } } func (wg *WaitGroup) Add(delta int) { wg.Lock() wg.Count++ + wg.CountChanged <- wg.Count wg.Unlock() } func (wg *WaitGroup) Done() { wg.Lock() wg.Count-- + wg.CountChanged <- wg.Count wg.Unlock() } diff --git a/transition_test.go b/transition_test.go index 34fe5c6..3ca63e6 100644 --- a/transition_test.go +++ b/transition_test.go @@ -1,7 +1,7 @@ package manners import ( - helpers "github.com/braintree/manners/test_helpers" + helpers "github.com/mailgun/manners/test_helpers" "net/http" "strings" "testing" @@ -36,7 +36,7 @@ func testStateTransition(t *testing.T, test transitionTest) { server.wg = wg startServer(t, server, nil) - conn := &helpers.Conn{} + conn := &gracefulConn{Conn: &helpers.Conn{}} for _, newState := range test.states { server.ConnState(conn, newState) } From 54d2603390df2bf8b8ada2006b5a17b5c2379552 Mon Sep 17 00:00:00 2001 From: Russell Haering Date: Mon, 4 Jan 2016 13:56:32 -0800 Subject: [PATCH 2/2] Extract and utilize an implicit "filer" interface for listeners --- listener.go | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/listener.go b/listener.go index ccb1c9f..d99d9aa 100644 --- a/listener.go +++ b/listener.go @@ -140,6 +140,10 @@ func (l *TLSListener) Accept() (c net.Conn, err error) { return } +func (l *TLSListener) File() (*os.File, error) { + return getListenerFile(l.Listener) +} + // NewListener creates a Listener which accepts connections from an inner // Listener and wraps each connection with Server. // The configuration config must be non-nil and must have @@ -155,6 +159,10 @@ type listenerAlreadyClosed struct { error } +type filer interface { + File() (*os.File, error) +} + // TCPKeepAliveListener sets TCP keep-alive timeouts on accepted // connections. It's used by ListenAndServe and ListenAndServeTLS so // dead TCP connections (e.g. closing laptop mid-download) eventually @@ -175,16 +183,15 @@ func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) { return tc, nil } +func (ln TCPKeepAliveListener) File() (*os.File, error) { + return ln.TCPListener.File() +} + func getListenerFile(listener net.Listener) (*os.File, error) { - switch t := listener.(type) { - case *net.TCPListener: - return t.File() - case *net.UnixListener: - return t.File() - case TCPKeepAliveListener: - return t.TCPListener.File() - case *TLSListener: - return getListenerFile(t.Listener) + fl, ok := listener.(filer) + if !ok { + return nil, fmt.Errorf("Unsupported listener: %T", listener) } - return nil, fmt.Errorf("Unsupported listener: %T", listener) + + return fl.File() }