diff --git a/initiator.go b/initiator.go index a68cb64e..525ff5b6 100644 --- a/initiator.go +++ b/initiator.go @@ -20,6 +20,7 @@ import ( "context" "crypto/tls" "fmt" + "net" "strings" "sync" "time" @@ -161,6 +162,11 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di connectionAttempt := 0 + // pendingConn holds the raw TCP connection during TLS handshake so the + // stop-goroutine can close it to unblock a stuck Handshake() call. + var pendingConn net.Conn + var connMu sync.Mutex + for { if !i.waitForInSessionTime(session) { return @@ -168,12 +174,18 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di ctx, cancel := context.WithCancel(context.Background()) - // We start a goroutine in order to be able to cancel the dialer mid-connection - // on receiving a stop signal to stop the initiator. + // We start a goroutine in order to be able to cancel the dialer + // mid-connection and to close the raw TCP socket if a TLS handshake + // is in progress when a stop signal is received. go func() { select { case <-i.stopChan: cancel() + connMu.Lock() + if pendingConn != nil { + pendingConn.Close() + } + connMu.Unlock() case <-ctx.Done(): return } @@ -200,11 +212,26 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di } tlsConfig.ServerName = serverName } + + // Store the raw TCP connection so the stop-goroutine can close it + // to unblock Handshake() if stop is signaled during TLS negotiation. + connMu.Lock() + pendingConn = netConn + connMu.Unlock() + tlsConn := tls.Client(netConn, tlsConfig) if err = tlsConn.Handshake(); err != nil { + connMu.Lock() + pendingConn = nil + connMu.Unlock() session.log.OnEventf("Failed handshake: %v", err) goto reconnect } + + connMu.Lock() + pendingConn = nil + connMu.Unlock() + netConn = tlsConn }