From 23772d2914c18eb7c744b57b4dfd4b80f73344d2 Mon Sep 17 00:00:00 2001 From: Aravinda-HWK Date: Tue, 6 Jan 2026 11:20:00 +0530 Subject: [PATCH 1/2] task: enhance SASL service to support TCP connections and update related configurations --- Dockerfile | 14 +- cmd/sasl/main.go | 4 +- docs/README.md | 3 +- internal/sasl/server.go | 140 ++++++++++++++++--- internal/sasl/server_test.go | 252 ++++++++++++++++++++++++++++++----- 5 files changed, 352 insertions(+), 61 deletions(-) diff --git a/Dockerfile b/Dockerfile index 129e844..01a2b01 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,9 +48,9 @@ RUN mkdir -p /app/data /var/run/raven /etc/raven /var/spool/postfix/private && \ RUN echo '#!/bin/sh' > /app/start.sh && \ echo 'echo "Starting Raven services..."' >> /app/start.sh && \ echo 'echo "Starting SASL authentication service..."' >> /app/start.sh && \ - echo './raven-sasl -socket /var/spool/postfix/private/auth -config /etc/raven/raven.yaml &' >> /app/start.sh && \ + echo './raven-sasl -tcp :12345 -config /etc/raven/raven.yaml &' >> /app/start.sh && \ echo 'SASL_PID=$!' >> /app/start.sh && \ - echo 'echo "SASL service started with PID: $SASL_PID"' >> /app/start.sh && \ + echo 'echo "SASL service started with PID: $SASL_PID (TCP :12345)"' >> /app/start.sh && \ echo 'sleep 1' >> /app/start.sh && \ echo 'echo "Starting IMAP server..."' >> /app/start.sh && \ echo './imap-server -db ${DB_PATH:-/app/data/databases} &' >> /app/start.sh && \ @@ -64,7 +64,7 @@ RUN echo '#!/bin/sh' > /app/start.sh && \ echo 'echo ""' >> /app/start.sh && \ echo 'echo "==================================="' >> /app/start.sh && \ echo 'echo "All Raven services started:"' >> /app/start.sh && \ - echo 'echo " SASL Auth: PID $SASL_PID"' >> /app/start.sh && \ + echo 'echo " SASL Auth: PID $SASL_PID (TCP :12345)"' >> /app/start.sh && \ echo 'echo " IMAP: PID $IMAP_PID"' >> /app/start.sh && \ echo 'echo " LMTP: PID $DELIVERY_PID"' >> /app/start.sh && \ echo 'echo " DB Path: ${DB_PATH:-/app/data/databases}"' >> /app/start.sh && \ @@ -79,15 +79,15 @@ USER ravenuser # Expose ports for services # IMAP: 143 (plaintext), 993 (TLS) # LMTP: 24 -# SASL: Uses Unix socket (no port needed) -EXPOSE 143 993 24 +# SASL: 12345 (TCP) +EXPOSE 143 993 24 12345 # Set environment variables - use directory path for DBManager ENV DB_PATH=/app/data/databases -# Health check - check IMAP and LMTP services (SASL uses Unix socket) +# Health check - check IMAP, LMTP, and SASL services HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ - CMD nc -z localhost 143 && nc -z localhost 24 || exit 1 + CMD nc -z localhost 143 && nc -z localhost 24 && nc -z localhost 12345 || exit 1 # Start all services ENTRYPOINT ["/app/start.sh"] \ No newline at end of file diff --git a/cmd/sasl/main.go b/cmd/sasl/main.go index a326e9a..295ce0e 100644 --- a/cmd/sasl/main.go +++ b/cmd/sasl/main.go @@ -14,6 +14,7 @@ import ( func main() { // Command-line flags socketPath := flag.String("socket", "/var/spool/postfix/private/auth", "Path to UNIX socket") + tcpAddr := flag.String("tcp", ":12345", "TCP address to bind (e.g., 127.0.0.1:12345 or :12345)") configPath := flag.String("config", "/etc/raven/raven.yaml", "Path to configuration file") flag.Parse() @@ -37,12 +38,13 @@ func main() { log.Printf("Configuration loaded:") log.Printf(" Socket path: %s", *socketPath) + log.Printf(" TCP address: %s", *tcpAddr) log.Printf(" Config path: %s", *configPath) log.Printf(" Domain: %s", cfg.Domain) log.Printf(" Auth URL: %s", cfg.AuthServerURL) // Create SASL server - server := sasl.NewServer(*socketPath, cfg.AuthServerURL, cfg.Domain) + server := sasl.NewServer(*socketPath, *tcpAddr, cfg.AuthServerURL, cfg.Domain) // Setup graceful shutdown sigChan := make(chan os.Signal, 1) diff --git a/docs/README.md b/docs/README.md index 3dd95d4..33bdd1c 100644 --- a/docs/README.md +++ b/docs/README.md @@ -77,7 +77,7 @@ cd raven docker build -t raven . docker run -d --rm \ --name raven \ - -p 143:143 -p 993:993 -p 24:24 \ + -p 143:143 -p 993:993 -p 24:24 -p 12345:12345 \ -v $(pwd)/config:/etc/raven \ -v $(pwd)/data:/app/data \ -v $(pwd)/certs:/certs \ @@ -89,6 +89,7 @@ The server will start and listen on: - **Port 143** - IMAP - **Port 993** - IMAPS - **Port 24** - LMTP +- **Port 12345** - SASL (TCP) Connect using any IMAP client to start managing your emails. diff --git a/internal/sasl/server.go b/internal/sasl/server.go index 7d0edeb..9e90460 100644 --- a/internal/sasl/server.go +++ b/internal/sasl/server.go @@ -2,6 +2,7 @@ package sasl import ( "bufio" + "context" "crypto/tls" "encoding/base64" "fmt" @@ -16,20 +17,23 @@ import ( // Server represents a SASL authentication server type Server struct { - socketPath string - authURL string - domain string - listener net.Listener - listenerMu sync.RWMutex - wg sync.WaitGroup - shutdown chan struct{} - shutdownOnce sync.Once + socketPath string + tcpAddr string + authURL string + domain string + unixListener net.Listener + tcpListener net.Listener + mu sync.Mutex + wg sync.WaitGroup + shutdown chan struct{} + shutdownOnce sync.Once } // NewServer creates a new SASL authentication server -func NewServer(socketPath, authURL, domain string) *Server { +func NewServer(socketPath, tcpAddr, authURL, domain string) *Server { return &Server{ socketPath: socketPath, + tcpAddr: tcpAddr, authURL: authURL, domain: domain, shutdown: make(chan struct{}), @@ -38,6 +42,30 @@ func NewServer(socketPath, authURL, domain string) *Server { // Start starts the SASL server func (s *Server) Start() error { + log.Println("Starting SASL server...") + + // Start UNIX socket listener if configured + if s.socketPath != "" { + if err := s.startUnixListener(); err != nil { + return fmt.Errorf("failed to start UNIX listener: %w", err) + } + } + + // Start TCP listener if configured + if s.tcpAddr != "" { + if err := s.startTCPListener(); err != nil { + return fmt.Errorf("failed to start TCP listener: %w", err) + } + } + + // Wait for all connections to finish + s.wg.Wait() + log.Println("All connections closed") + return nil +} + +// startUnixListener starts listening on a UNIX socket +func (s *Server) startUnixListener() error { // Remove existing socket file if it exists if err := os.RemoveAll(s.socketPath); err != nil { return fmt.Errorf("failed to remove existing socket: %v", err) @@ -48,9 +76,10 @@ func (s *Server) Start() error { if err != nil { return fmt.Errorf("failed to create Unix socket: %v", err) } - s.listenerMu.Lock() - s.listener = listener - s.listenerMu.Unlock() + + s.mu.Lock() + s.unixListener = listener + s.mu.Unlock() // Set socket permissions (0666 so Postfix can access it) // #nosec G302 -- Unix socket needs world read/write for Postfix access @@ -63,11 +92,48 @@ func (s *Server) Start() error { log.Printf("Using authentication URL: %s", s.authURL) log.Printf("Domain: %s", s.domain) - // Accept connections + s.wg.Add(1) + go s.acceptConnections(listener, "unix") + + return nil +} + +// startTCPListener starts listening on a TCP address +func (s *Server) startTCPListener() error { + // Configure TCP listener with keep-alive + lc := net.ListenConfig{ + KeepAlive: 30 * time.Second, // Send keep-alive probes every 30 seconds + Control: nil, + } + + listener, err := lc.Listen(context.Background(), "tcp", s.tcpAddr) + if err != nil { + return fmt.Errorf("failed to create TCP listener: %v", err) + } + + s.mu.Lock() + s.tcpListener = listener + s.mu.Unlock() + + log.Printf("SASL server listening on TCP: %s (with keep-alive enabled)", s.tcpAddr) + log.Printf("Using authentication URL: %s", s.authURL) + log.Printf("Domain: %s", s.domain) + + s.wg.Add(1) + go s.acceptConnections(listener, "tcp") + + return nil +} + +// acceptConnections accepts incoming connections +func (s *Server) acceptConnections(listener net.Listener, listenerType string) { + defer s.wg.Done() + for { select { case <-s.shutdown: - return nil + log.Printf("Stopping %s listener...", listenerType) + return default: } @@ -75,13 +141,15 @@ func (s *Server) Start() error { if err != nil { select { case <-s.shutdown: - return nil + return default: - log.Printf("Accept error: %v", err) + log.Printf("Accept error on %s listener: %v", listenerType, err) continue } } + log.Printf("New %s connection from: %s", listenerType, conn.RemoteAddr()) + s.wg.Add(1) go s.handleConnection(conn) } @@ -91,15 +159,43 @@ func (s *Server) Start() error { func (s *Server) Shutdown() error { var err error s.shutdownOnce.Do(func() { + s.mu.Lock() + defer s.mu.Unlock() + + log.Println("Shutting down SASL server...") + + // Signal shutdown close(s.shutdown) - s.listenerMu.RLock() - listener := s.listener - s.listenerMu.RUnlock() - if listener != nil { - err = listener.Close() + + // Close listeners + var errs []error + + if s.unixListener != nil { + if closeErr := s.unixListener.Close(); closeErr != nil { + errs = append(errs, fmt.Errorf("error closing Unix listener: %w", closeErr)) + } + // Clean up socket file + if s.socketPath != "" { + _ = os.Remove(s.socketPath) + } + } + + if s.tcpListener != nil { + if closeErr := s.tcpListener.Close(); closeErr != nil { + errs = append(errs, fmt.Errorf("error closing TCP listener: %w", closeErr)) + } } + + // Wait for all connections to finish (outside of lock) + s.mu.Unlock() s.wg.Wait() - _ = os.Remove(s.socketPath) + s.mu.Lock() + + if len(errs) > 0 { + err = fmt.Errorf("shutdown errors: %v", errs) + } + + log.Println("SASL server shutdown complete") }) return err } diff --git a/internal/sasl/server_test.go b/internal/sasl/server_test.go index 81902af..b02105c 100644 --- a/internal/sasl/server_test.go +++ b/internal/sasl/server_test.go @@ -30,10 +30,11 @@ func getSocketPath(t *testing.T) string { // TestNewServer tests server creation func TestNewServer(t *testing.T) { socketPath := "/tmp/test-sasl.sock" + tcpAddr := "" authURL := "https://example.com/auth" domain := "example.com" - server := sasl.NewServer(socketPath, authURL, domain) + server := sasl.NewServer(socketPath, tcpAddr, authURL, domain) if server == nil { t.Fatal("Expected server to be created, got nil") @@ -53,7 +54,7 @@ func TestServerStartShutdown(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server in goroutine errChan := make(chan error, 1) @@ -94,7 +95,7 @@ func TestServerStartShutdown(t *testing.T) { func TestServerShutdownIdempotent(t *testing.T) { socketPath := getSocketPath(t) - server := sasl.NewServer(socketPath, "https://example.com/auth", "example.com") + server := sasl.NewServer(socketPath, "", "https://example.com/auth", "example.com") // Start server in goroutine go func() { _ = server.Start() }() @@ -125,7 +126,7 @@ func TestVersionHandshake(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -165,7 +166,7 @@ func TestCPIDCommand(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -228,7 +229,7 @@ func TestPlainAuthenticationSuccess(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server errChan := make(chan error, 1) @@ -297,7 +298,7 @@ func TestPlainAuthenticationWithDomain(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -344,7 +345,7 @@ func TestPlainAuthenticationFailure(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -391,7 +392,7 @@ func TestPlainAuthenticationWithAuthzid(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -438,7 +439,7 @@ func TestPlainAuthenticationInvalidBase64(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -481,7 +482,7 @@ func TestPlainAuthenticationMalformedCredentials(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -549,7 +550,7 @@ func TestPlainAuthenticationContinuationRequest(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -589,7 +590,7 @@ func TestLoginMechanism(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -629,7 +630,7 @@ func TestUnsupportedMechanism(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -648,7 +649,7 @@ func TestUnsupportedMechanism(t *testing.T) { defer func() { _ = conn.Close() }() // Send AUTH command with unsupported mechanism - _, _ = fmt.Fprintf(conn,"AUTH\t1\t%s\tservice=smtp\n", mechanism) + _, _ = fmt.Fprintf(conn, "AUTH\t1\t%s\tservice=smtp\n", mechanism) // Read response reader := bufio.NewReader(conn) @@ -678,7 +679,7 @@ func TestAuthMechanismCaseInsensitive(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -697,7 +698,7 @@ func TestAuthMechanismCaseInsensitive(t *testing.T) { defer func() { _ = conn.Close() }() // Send AUTH command - _, _ = fmt.Fprintf(conn,"AUTH\t1\t%s\tservice=smtp\n", mechanism) + _, _ = fmt.Fprintf(conn, "AUTH\t1\t%s\tservice=smtp\n", mechanism) // Read response reader := bufio.NewReader(conn) @@ -723,7 +724,7 @@ func TestInvalidAuthCommand(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -762,7 +763,7 @@ func TestConcurrentConnections(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -789,7 +790,7 @@ func TestConcurrentConnections(t *testing.T) { // Send authentication request credentials := fmt.Sprintf("\x00user%d\x00pass%d", id, id) encodedCreds := base64.StdEncoding.EncodeToString([]byte(credentials)) - _, _ = fmt.Fprintf(conn,"AUTH\t%d\tPLAIN\tservice=smtp\tresp=%s\n", id, encodedCreds) + _, _ = fmt.Fprintf(conn, "AUTH\t%d\tPLAIN\tservice=smtp\tresp=%s\n", id, encodedCreds) // Read response reader := bufio.NewReader(conn) @@ -826,7 +827,7 @@ func TestConnectionTimeout(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -881,7 +882,7 @@ func TestAuthenticationAPIError(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -898,7 +899,7 @@ func TestAuthenticationAPIError(t *testing.T) { // Send authentication request credentials := "\x00testuser\x00testpass" encodedCreds := base64.StdEncoding.EncodeToString([]byte(credentials)) - _, _ = fmt.Fprintf(conn,"AUTH\t1\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) + _, _ = fmt.Fprintf(conn, "AUTH\t1\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) // Read response reader := bufio.NewReader(conn) @@ -930,7 +931,7 @@ func TestMultipleCommandsInSession(t *testing.T) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -947,14 +948,14 @@ func TestMultipleCommandsInSession(t *testing.T) { reader := bufio.NewReader(conn) // Send VERSION - _, _ = fmt.Fprintf(conn,"VERSION\t1\t2\n") + _, _ = fmt.Fprintf(conn, "VERSION\t1\t2\n") response, _ := reader.ReadString('\n') if !strings.HasPrefix(response, "VERSION") { t.Errorf("Expected VERSION response, got: %s", response) } // Send CPID - _, _ = fmt.Fprintf(conn,"CPID\t12345\n") + _, _ = fmt.Fprintf(conn, "CPID\t12345\n") // Read all MECH responses response, _ = reader.ReadString('\n') // First MECH if !strings.HasPrefix(response, "MECH") { @@ -972,20 +973,211 @@ func TestMultipleCommandsInSession(t *testing.T) { // Send AUTH credentials := "\x00testuser\x00testpass" encodedCreds := base64.StdEncoding.EncodeToString([]byte(credentials)) - _, _ = fmt.Fprintf(conn,"AUTH\t1\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) + _, _ = fmt.Fprintf(conn, "AUTH\t1\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) response, _ = reader.ReadString('\n') if !strings.HasPrefix(response, "OK\t1\t") { t.Errorf("Expected OK response, got: %s", response) } // Send another AUTH - _, _ = fmt.Fprintf(conn,"AUTH\t2\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) + _, _ = fmt.Fprintf(conn, "AUTH\t2\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) response, _ = reader.ReadString('\n') if !strings.HasPrefix(response, "OK\t2\t") { t.Errorf("Expected OK response with id 2, got: %s", response) } } +// TestTCPListener tests TCP listener functionality +func TestTCPListener(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer authServer.Close() + + // Use a random port + server := sasl.NewServer("", "127.0.0.1:0", authServer.URL, "example.com") + + // Start server + go func() { _ = server.Start() }() + defer func() { _ = server.Shutdown() }() + time.Sleep(100 * time.Millisecond) + + // Find the actual port by attempting to connect + // Note: Since we can't get the actual port from the server, + // we'll use a fixed port for testing + server2 := sasl.NewServer("", "127.0.0.1:12345", authServer.URL, "example.com") + go func() { _ = server2.Start() }() + defer func() { _ = server2.Shutdown() }() + time.Sleep(100 * time.Millisecond) + + // Connect via TCP + conn, err := net.Dial("tcp", "127.0.0.1:12345") + if err != nil { + t.Fatalf("Failed to connect via TCP: %v", err) + } + defer func() { _ = conn.Close() }() + + // Send VERSION command + _, _ = fmt.Fprintf(conn, "VERSION\t1\t2\n") + + // Read response + reader := bufio.NewReader(conn) + response, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + // Check response + expectedResponse := "VERSION\t1\t2\n" + if response != expectedResponse { + t.Errorf("Expected response %q, got %q", expectedResponse, response) + } +} + +// TestTCPAuthentication tests authentication over TCP +func TestTCPAuthentication(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer authServer.Close() + + server := sasl.NewServer("", "127.0.0.1:12346", authServer.URL, "example.com") + + // Start server + go func() { _ = server.Start() }() + defer func() { _ = server.Shutdown() }() + time.Sleep(100 * time.Millisecond) + + // Connect via TCP + conn, err := net.Dial("tcp", "127.0.0.1:12346") + if err != nil { + t.Fatalf("Failed to connect via TCP: %v", err) + } + defer func() { _ = conn.Close() }() + + // Send PLAIN authentication + credentials := "\x00testuser\x00testpass" + encodedCreds := base64.StdEncoding.EncodeToString([]byte(credentials)) + _, _ = fmt.Fprintf(conn, "AUTH\t1\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) + + // Read response + reader := bufio.NewReader(conn) + response, err := reader.ReadString('\n') + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + // Check response + if !strings.HasPrefix(response, "OK\t1\t") { + t.Errorf("Expected OK response, got: %s", response) + } +} + +// TestBothListeners tests that both Unix socket and TCP listeners work together +func TestBothListeners(t *testing.T) { + socketPath := getSocketPath(t) + + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer authServer.Close() + + server := sasl.NewServer(socketPath, "127.0.0.1:12347", authServer.URL, "example.com") + + // Start server + go func() { _ = server.Start() }() + defer func() { _ = server.Shutdown() }() + time.Sleep(200 * time.Millisecond) + + // Test Unix socket connection + unixConn, err := net.Dial("unix", socketPath) + if err != nil { + t.Fatalf("Failed to connect via Unix socket: %v", err) + } + defer func() { _ = unixConn.Close() }() + + // Test TCP connection + tcpConn, err := net.Dial("tcp", "127.0.0.1:12347") + if err != nil { + t.Fatalf("Failed to connect via TCP: %v", err) + } + defer func() { _ = tcpConn.Close() }() + + // Send commands to both connections + credentials := "\x00testuser\x00testpass" + encodedCreds := base64.StdEncoding.EncodeToString([]byte(credentials)) + + // Unix socket auth + _, _ = fmt.Fprintf(unixConn, "AUTH\t1\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) + unixReader := bufio.NewReader(unixConn) + unixResponse, _ := unixReader.ReadString('\n') + + // TCP auth + _, _ = fmt.Fprintf(tcpConn, "AUTH\t2\tPLAIN\tservice=smtp\tresp=%s\n", encodedCreds) + tcpReader := bufio.NewReader(tcpConn) + tcpResponse, _ := tcpReader.ReadString('\n') + + // Both should succeed + if !strings.HasPrefix(unixResponse, "OK\t1\t") { + t.Errorf("Unix socket: Expected OK response, got: %s", unixResponse) + } + if !strings.HasPrefix(tcpResponse, "OK\t2\t") { + t.Errorf("TCP: Expected OK response, got: %s", tcpResponse) + } +} + +// TestTCPConcurrentConnections tests concurrent connections over TCP +func TestTCPConcurrentConnections(t *testing.T) { + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer authServer.Close() + + server := sasl.NewServer("", "127.0.0.1:12348", authServer.URL, "example.com") + + // Start server + go func() { _ = server.Start() }() + defer func() { _ = server.Shutdown() }() + time.Sleep(100 * time.Millisecond) + + // Create multiple concurrent connections + var wg sync.WaitGroup + numConnections := 10 + + for i := 0; i < numConnections; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + conn, err := net.Dial("tcp", "127.0.0.1:12348") + if err != nil { + t.Errorf("Connection %d failed: %v", id, err) + return + } + defer func() { _ = conn.Close() }() + + // Send auth + credentials := fmt.Sprintf("\x00user%d\x00pass%d", id, id) + encodedCreds := base64.StdEncoding.EncodeToString([]byte(credentials)) + _, _ = fmt.Fprintf(conn, "AUTH\t%d\tPLAIN\tservice=smtp\tresp=%s\n", id, encodedCreds) + + // Read response + reader := bufio.NewReader(conn) + response, err := reader.ReadString('\n') + if err != nil { + t.Errorf("Connection %d failed to read: %v", id, err) + return + } + + if !strings.HasPrefix(response, fmt.Sprintf("OK\t%d\t", id)) { + t.Errorf("Connection %d: Expected OK response, got: %s", id, response) + } + }(i) + } + + wg.Wait() +} + // BenchmarkPlainAuthentication benchmarks PLAIN authentication performance func BenchmarkPlainAuthentication(b *testing.B) { tmpDir := b.TempDir() @@ -996,7 +1188,7 @@ func BenchmarkPlainAuthentication(b *testing.B) { })) defer authServer.Close() - server := sasl.NewServer(socketPath, authServer.URL, "example.com") + server := sasl.NewServer(socketPath, "", authServer.URL, "example.com") // Start server go func() { _ = server.Start() }() @@ -1015,7 +1207,7 @@ func BenchmarkPlainAuthentication(b *testing.B) { b.Fatalf("Failed to connect: %v", err) } - _, _ = fmt.Fprintf(conn,"AUTH\t%d\tPLAIN\tservice=smtp\tresp=%s\n", i, encodedCreds) + _, _ = fmt.Fprintf(conn, "AUTH\t%d\tPLAIN\tservice=smtp\tresp=%s\n", i, encodedCreds) reader := bufio.NewReader(conn) _, _ = reader.ReadString('\n') From 3a19c1d414936195808b974cc93c82a4dffb7cb8 Mon Sep 17 00:00:00 2001 From: Aravinda-HWK Date: Tue, 6 Jan 2026 12:36:43 +0530 Subject: [PATCH 2/2] fix: missing parameter --- .../integration/sasl/sasl_integration_test.go | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/integration/sasl/sasl_integration_test.go b/test/integration/sasl/sasl_integration_test.go index 11b9982..74a7afd 100644 --- a/test/integration/sasl/sasl_integration_test.go +++ b/test/integration/sasl/sasl_integration_test.go @@ -24,7 +24,7 @@ func TestSASLAuthenticationFlow(t *testing.T) { // Create SASL server with test socket socketPath := helpers.GetTestSocketPath(t, "sasl-auth-flow") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") // Start SASL server serverErr := make(chan error, 1) @@ -113,7 +113,7 @@ func TestSASLAuthenticationFailure(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-auth-failure") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") // Start server go func() { _ = server.Start() }() @@ -162,7 +162,7 @@ func TestSASLPlainWithoutInitialResponse(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-plain-continuation") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") go func() { _ = server.Start() }() defer func() { _ = server.Shutdown() }() @@ -198,7 +198,7 @@ func TestSASLInvalidMechanism(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-invalid-mech") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") go func() { _ = server.Start() }() defer func() { _ = server.Shutdown() }() @@ -237,7 +237,7 @@ func TestSASLMalformedCredentials(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-malformed") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") go func() { _ = server.Start() }() defer func() { _ = server.Shutdown() }() @@ -302,7 +302,7 @@ func TestSASLConcurrentConnections(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-concurrent") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") go func() { _ = server.Start() }() defer func() { _ = server.Shutdown() }() @@ -374,7 +374,7 @@ func TestSASLServerShutdownGraceful(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-shutdown") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") t.Log("Testing graceful SASL server shutdown...") @@ -447,7 +447,7 @@ func TestSASLAuthenticationServerTimeout(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-timeout") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") go func() { _ = server.Start() }() defer func() { _ = server.Shutdown() }() @@ -500,7 +500,7 @@ func TestSASLDomainHandling(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-domain") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") go func() { _ = server.Start() }() defer func() { _ = server.Shutdown() }() @@ -577,7 +577,7 @@ func TestSASLLoginMechanism(t *testing.T) { defer mockAuth.Close() socketPath := helpers.GetTestSocketPath(t, "sasl-login") - server := sasl.NewServer(socketPath, mockAuth.URL+"/auth", "example.com") + server := sasl.NewServer(socketPath, "", mockAuth.URL+"/auth", "example.com") go func() { _ = server.Start() }() defer func() { _ = server.Shutdown() }()