From a02f1e4d79c65194551a05fc3c8144df152be5a7 Mon Sep 17 00:00:00 2001 From: losfair Date: Sun, 14 Aug 2022 14:58:18 +0000 Subject: [PATCH] Improve startup state handling --- server.go | 56 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/server.go b/server.go index 190e0a8..e4fdd6a 100644 --- a/server.go +++ b/server.go @@ -244,29 +244,46 @@ func (s *Server) serveConn(ctx context.Context, c *Conn) error { } } +type startupState struct { + *Server + done bool + expectingStartup bool + expectingSSL bool +} + func (s *Server) serveConnStartup(ctx context.Context, c *Conn) error { - msg, err := c.backend.ReceiveStartupMessage() - if err != nil { - return fmt.Errorf("receive startup message: %w", err) - } + state := &startupState{Server: s, done: false, expectingStartup: true, expectingSSL: true} - switch msg := msg.(type) { - case *pgproto3.StartupMessage: - if err := s.handleStartupMessage(ctx, c, msg); err != nil { - return fmt.Errorf("startup message: %w", err) + for !state.done { + state.done = true + msg, err := c.backend.ReceiveStartupMessage() + if err != nil { + return fmt.Errorf("receive startup message: %w", err) } - return nil - case *pgproto3.SSLRequest: - if err := s.handleSSLRequestMessage(ctx, c, msg); err != nil { - return fmt.Errorf("ssl request message: %w", err) + + switch msg := msg.(type) { + case *pgproto3.StartupMessage: + if !state.expectingStartup { + return fmt.Errorf("unexpected StartupMessage") + } + if err := state.handleStartupMessage(ctx, c, msg); err != nil { + return fmt.Errorf("startup message: %w", err) + } + case *pgproto3.SSLRequest: + if !state.expectingSSL { + return fmt.Errorf("unexpected SSLRequest") + } + if err := state.handleSSLRequestMessage(ctx, c, msg); err != nil { + return fmt.Errorf("ssl request message: %w", err) + } + default: + return fmt.Errorf("unexpected startup message: %#v", msg) } - return nil - default: - return fmt.Errorf("unexpected startup message: %#v", msg) } + return nil } -func (s *Server) handleStartupMessage(ctx context.Context, c *Conn, msg *pgproto3.StartupMessage) (err error) { +func (s *startupState) handleStartupMessage(ctx context.Context, c *Conn, msg *pgproto3.StartupMessage) (err error) { log.Printf("received startup message: %#v", msg) // Validate @@ -317,12 +334,15 @@ func (s *Server) handleStartupMessage(ctx context.Context, c *Conn, msg *pgproto ) } -func (s *Server) handleSSLRequestMessage(ctx context.Context, c *Conn, msg *pgproto3.SSLRequest) error { +func (s *startupState) handleSSLRequestMessage(ctx context.Context, c *Conn, msg *pgproto3.SSLRequest) error { log.Printf("received ssl request message: %#v", msg) if _, err := c.Write([]byte("N")); err != nil { return err } - return s.serveConnStartup(ctx, c) + + s.done = false + s.expectingSSL = false + return nil } func (s *Server) handleQueryMessage(ctx context.Context, c *Conn, msg *pgproto3.Query) error {