Skip to content
This repository was archived by the owner on Oct 2, 2023. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,29 +244,46 @@ func (s *Server) serveConn(ctx context.Context, c *Conn) error {
}
}

type startupState struct {
*Server
done bool
expectingStartup bool
expectingSSL bool
}

func (s *Server) serveConnStartup(ctx context.Context, c *Conn) error {
msg, err := c.backend.ReceiveStartupMessage()
if err != nil {
return fmt.Errorf("receive startup message: %w", err)
}
state := &startupState{Server: s, done: false, expectingStartup: true, expectingSSL: true}

switch msg := msg.(type) {
case *pgproto3.StartupMessage:
if err := s.handleStartupMessage(ctx, c, msg); err != nil {
return fmt.Errorf("startup message: %w", err)
for !state.done {
state.done = true
msg, err := c.backend.ReceiveStartupMessage()
if err != nil {
return fmt.Errorf("receive startup message: %w", err)
}
return nil
case *pgproto3.SSLRequest:
if err := s.handleSSLRequestMessage(ctx, c, msg); err != nil {
return fmt.Errorf("ssl request message: %w", err)

switch msg := msg.(type) {
case *pgproto3.StartupMessage:
if !state.expectingStartup {
return fmt.Errorf("unexpected StartupMessage")
}
if err := state.handleStartupMessage(ctx, c, msg); err != nil {
return fmt.Errorf("startup message: %w", err)
}
case *pgproto3.SSLRequest:
if !state.expectingSSL {
return fmt.Errorf("unexpected SSLRequest")
}
if err := state.handleSSLRequestMessage(ctx, c, msg); err != nil {
return fmt.Errorf("ssl request message: %w", err)
}
default:
return fmt.Errorf("unexpected startup message: %#v", msg)
}
return nil
default:
return fmt.Errorf("unexpected startup message: %#v", msg)
}
return nil
}

func (s *Server) handleStartupMessage(ctx context.Context, c *Conn, msg *pgproto3.StartupMessage) (err error) {
func (s *startupState) handleStartupMessage(ctx context.Context, c *Conn, msg *pgproto3.StartupMessage) (err error) {
log.Printf("received startup message: %#v", msg)

// Validate
Expand Down Expand Up @@ -317,12 +334,15 @@ func (s *Server) handleStartupMessage(ctx context.Context, c *Conn, msg *pgproto
)
}

func (s *Server) handleSSLRequestMessage(ctx context.Context, c *Conn, msg *pgproto3.SSLRequest) error {
func (s *startupState) handleSSLRequestMessage(ctx context.Context, c *Conn, msg *pgproto3.SSLRequest) error {
log.Printf("received ssl request message: %#v", msg)
if _, err := c.Write([]byte("N")); err != nil {
return err
}
return s.serveConnStartup(ctx, c)

s.done = false
s.expectingSSL = false
return nil
}

func (s *Server) handleQueryMessage(ctx context.Context, c *Conn, msg *pgproto3.Query) error {
Expand Down