From 07d8536434ce33977e7cf339a200c38d57262967 Mon Sep 17 00:00:00 2001 From: Nikolay Petrov Date: Tue, 17 Feb 2026 18:46:28 -0500 Subject: [PATCH] wait for client conns to close --- cmd/connet/control.go | 4 +++- cmd/connet/relay.go | 4 +++- cmd/connet/server.go | 7 ++++++- server/control/clients.go | 15 +++++++++++++-- server/control/relays.go | 10 +++++++++- server/control/server.go | 13 +++++++++---- server/relay/clients.go | 10 ++++++++-- server/relay/server.go | 13 +++++++++---- server/server.go | 5 +++++ 9 files changed, 65 insertions(+), 16 deletions(-) diff --git a/cmd/connet/control.go b/cmd/connet/control.go index 21f92261..f321fdf9 100644 --- a/cmd/connet/control.go +++ b/cmd/connet/control.go @@ -194,7 +194,9 @@ func controlRun(ctx context.Context, cfg ControlConfig, logger *slog.Logger) err if err != nil { return fmt.Errorf("create control server: %w", err) } - return runWithStatus(ctx, srv, statusAddr, logger) + err = runWithStatus(ctx, srv, statusAddr, logger) + srv.WaitDrainConns(ctx) + return err } func (cfg ControlIngress) parse() (control.Ingress, error) { diff --git a/cmd/connet/relay.go b/cmd/connet/relay.go index 9902ca1c..deae4d5c 100644 --- a/cmd/connet/relay.go +++ b/cmd/connet/relay.go @@ -185,7 +185,9 @@ func relayRun(ctx context.Context, cfg RelayConfig, logger *slog.Logger) error { if err != nil { return fmt.Errorf("create relay server: %w", err) } - return runWithStatus(ctx, srv, statusAddr, logger) + err = runWithStatus(ctx, srv, statusAddr, logger) + srv.WaitDrainConns(ctx) + return err } func (cfg RelayIngress) parse() (relay.Ingress, error) { diff --git a/cmd/connet/server.go b/cmd/connet/server.go index a92c5f63..4ee066c3 100644 --- a/cmd/connet/server.go +++ b/cmd/connet/server.go @@ -148,7 +148,12 @@ func serverRun(ctx context.Context, cfg ServerConfig, logger *slog.Logger) error if err != nil { return fmt.Errorf("create server: %w", err) } - return runWithStatus(ctx, srv, statusAddr, logger) + + err = runWithStatus(ctx, srv, statusAddr, logger) + logger.Info("server stopped, draining conns") + srv.WaitDrainConns(ctx) + logger.Info("server stopped, completed conns") + return err } func (c *ServerConfig) merge(o ServerConfig) { diff --git a/server/control/clients.go b/server/control/clients.go index 32dfdb91..15f444e3 100644 --- a/server/control/clients.go +++ b/server/control/clients.go @@ -163,6 +163,8 @@ type clientServer struct { reactivate map[ClientID]reactivateValue reactivateMu sync.RWMutex + + connsWg sync.WaitGroup } type peerKey struct { @@ -268,6 +270,7 @@ func (s *clientServer) listen(ctx context.Context, endpoint model.Endpoint, role func (s *clientServer) run(ctx context.Context) error { g := reliable.NewGroup(ctx) + s.connsWg.Add(len(s.ingresses)) for _, ingress := range s.ingresses { g.Go(reliable.Bind(ingress, s.runListener)) } @@ -281,6 +284,8 @@ func (s *clientServer) run(ctx context.Context) error { } func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error { + defer s.connsWg.Done() + s.logger.Debug("start udp listener", "addr", ingress.Addr) udpConn, err := net.ListenUDP("udp", ingress.Addr) if err != nil { @@ -339,7 +344,10 @@ func (s *clientServer) runListener(ctx context.Context, ingress Ingress) error { conn: conn, logger: s.logger, } - go cc.run(ctx) + s.connsWg.Go(func() { + cc.run(ctx) + }) + // go cc.run(ctx) } } @@ -518,7 +526,10 @@ func (c *clientConn) runErr(ctx context.Context) error { conn: c, stream: stream, } - go cs.run(ctx) + c.server.connsWg.Go(func() { + cs.run(ctx) + }) + // go cs.run(ctx) } } diff --git a/server/control/relays.go b/server/control/relays.go index 78725f83..0ab0ccd0 100644 --- a/server/control/relays.go +++ b/server/control/relays.go @@ -132,6 +132,8 @@ type relayServer struct { connsCache map[RelayID]cachedRelay connsOffset int64 connsMu sync.RWMutex + + connsWg sync.WaitGroup } type cachedRelay struct { @@ -223,6 +225,7 @@ func (s *relayServer) Relays(ctx context.Context, endpoint model.Endpoint, role func (s *relayServer) run(ctx context.Context) error { g := reliable.NewGroup(ctx) + s.connsWg.Add(len(s.ingresses)) for _, ingress := range s.ingresses { g.Go(reliable.Bind(ingress, s.runListener)) } @@ -234,6 +237,8 @@ func (s *relayServer) run(ctx context.Context) error { } func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error { + defer s.connsWg.Done() + s.logger.Debug("start udp listener", "addr", ingress.Addr) udpConn, err := net.ListenUDP("udp", ingress.Addr) if err != nil { @@ -292,7 +297,10 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error { conn: conn, logger: s.logger, } - go rc.run(ctx) + s.connsWg.Go(func() { + rc.run(ctx) + }) + // go rc.run(ctx) } } diff --git a/server/control/server.go b/server/control/server.go index b7dbffd8..794e04be 100644 --- a/server/control/server.go +++ b/server/control/server.go @@ -25,15 +25,15 @@ type Config struct { } func NewServer(cfg Config) (*Server, error) { + if err := cfg.Stores.RemoveDeprecated(); err != nil { + cfg.Logger.Warn("could not remove deprecated stores", "err", err) + } + configStore, err := cfg.Stores.Config() if err != nil { return nil, fmt.Errorf("config store open: %w", err) } - if err := cfg.Stores.RemoveDeprecated(); err != nil { - cfg.Logger.Warn("could not remove deprecated stores", "err", err) - } - relays, err := newRelayServer(cfg.RelaysIngress, cfg.RelaysAuth, configStore, cfg.Stores, cfg.Logger) if err != nil { return nil, fmt.Errorf("create relay server: %w", err) @@ -67,6 +67,11 @@ func (s *Server) Run(ctx context.Context) error { ) } +func (s *Server) WaitDrainConns(ctx context.Context) { + s.relays.connsWg.Wait() + s.clients.connsWg.Wait() +} + func (s *Server) Status(ctx context.Context) (Status, error) { clients, err := s.getClients() if err != nil { diff --git a/server/relay/clients.go b/server/relay/clients.go index 9e4916d5..c4102af9 100644 --- a/server/relay/clients.go +++ b/server/relay/clients.go @@ -46,7 +46,8 @@ type clientsServer struct { endpoints map[model.Endpoint]*endpointClients endpointsMu sync.RWMutex - logger *slog.Logger + connsWg sync.WaitGroup + logger *slog.Logger } func newClientsServer(cfg Config, cert *certc.Cert, auth ClientAuthenticator) (*clientsServer, error) { @@ -190,6 +191,8 @@ type clientsServerCfg struct { } func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error { + defer s.connsWg.Done() + s.logger.Debug("start udp listener", "addr", cfg.ingress.Addr) udpConn, err := net.ListenUDP("udp", cfg.ingress.Addr) if err != nil { @@ -246,7 +249,10 @@ func (s *clientsServer) run(ctx context.Context, cfg clientsServerCfg) error { conn: conn, logger: s.logger, } - go rc.run(ctx) + s.connsWg.Go(func() { + rc.run(ctx) + }) + // go rc.run(ctx) } } diff --git a/server/relay/server.go b/server/relay/server.go index d5137eef..12607e57 100644 --- a/server/relay/server.go +++ b/server/relay/server.go @@ -42,15 +42,15 @@ func NewServer(cfg Config) (*Server, error) { return nil, fmt.Errorf("relay server is missing ingresses") } + if err := cfg.Stores.RemoveDeprecated(); err != nil { + cfg.Logger.Warn("could not remove deprecated stores", "err", err) + } + configStore, err := cfg.Stores.Config() if err != nil { return nil, fmt.Errorf("relay stores: %w", err) } - if err := cfg.Stores.RemoveDeprecated(); err != nil { - cfg.Logger.Warn("could not remove deprecated stores", "err", err) - } - statelessResetVal, err := configStore.GetOrInit(configStatelessReset, func(ck ConfigKey) (ConfigValue, error) { var key quic.StatelessResetKey if _, err := io.ReadFull(rand.Reader, key[:]); err != nil { @@ -112,6 +112,7 @@ func (s *Server) Run(ctx context.Context) error { g := reliable.NewGroup(ctx) + s.clients.connsWg.Add(len(s.ingress)) for _, ingress := range s.ingress { cfg := clientsServerCfg{ ingress: ingress, @@ -131,6 +132,10 @@ func (s *Server) Run(ctx context.Context) error { return g.Wait() } +func (s *Server) WaitDrainConns(ctx context.Context) { + s.clients.connsWg.Wait() +} + type Status struct { Status statusc.Status `json:"status"` BuildVersion string `json:"build-version"` diff --git a/server/server.go b/server/server.go index ff008014..3eb91b59 100644 --- a/server/server.go +++ b/server/server.go @@ -107,6 +107,11 @@ func (s *Server) Run(ctx context.Context) error { return g.Wait() } +func (s *Server) WaitDrainConns(ctx context.Context) { + s.control.WaitDrainConns(ctx) + s.relay.WaitDrainConns(ctx) +} + func (s *Server) Status(ctx context.Context) (ServerStatus, error) { control, err := s.control.Status(ctx) if err != nil {