From 247d9bef3a4a45f61b557c0ff16e465e7a92bb98 Mon Sep 17 00:00:00 2001 From: Aman-Cool Date: Sun, 25 Jan 2026 15:05:50 +0530 Subject: [PATCH] fix(ipc): Add timeout to AwaitMessage to prevent indefinite blocking - Add IPCAcceptTimeout (60s) and IPCReadTimeout (10s) to prevent orphaned processes when counterpart never connects - Fix closure bug in executeHooksConcurrently using wrong loop variable - Fix isRunning() using annotType instead of annotHypervisor - Add tests for timeout and wrong message handling Signed-off-by: Aman-Cool --- pkg/unikontainers/ipc.go | 41 +++++++++++++++++++---- pkg/unikontainers/ipc_test.go | 53 ++++++++++++++++++++++++++++++ pkg/unikontainers/unikontainers.go | 6 ++-- 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/pkg/unikontainers/ipc.go b/pkg/unikontainers/ipc.go index d7c19c45..a4ad121c 100644 --- a/pkg/unikontainers/ipc.go +++ b/pkg/unikontainers/ipc.go @@ -41,6 +41,12 @@ const ( maxRetries = 50 waitTime = 5 * time.Millisecond FromReexec = true + // IPCAcceptTimeout is the maximum time to wait for a connection on the IPC socket. + // This prevents processes from hanging indefinitely if the counterpart never connects + // (e.g., due to containerd restart, node pressure, or orchestration failures). + IPCAcceptTimeout = 60 * time.Second + // IPCReadTimeout is the maximum time to wait for reading a message after connection. + IPCReadTimeout = 10 * time.Second ) func getSockAddr(dir string, name string) string { @@ -145,27 +151,48 @@ func createListener(socketAddress string, mustBeValid bool) (*net.UnixListener, return listener, nil } -// awaitMessage opens a new connection to socketAddress -// and waits for a given message +// AwaitMessage waits for a connection on the listener and reads an expected message. +// It uses timeouts to prevent indefinite blocking if the counterpart process +// never connects (e.g., due to orchestration failures, crashes, or restarts). func AwaitMessage(listener *net.UnixListener, expectedMessage IPCMessage) error { + // Set accept deadline to prevent indefinite blocking. + // This is critical for preventing orphaned processes when urunc start + // never runs after urunc create, or when reexec fails silently. + if err := listener.SetDeadline(time.Now().Add(IPCAcceptTimeout)); err != nil { + return fmt.Errorf("failed to set listener deadline: %w", err) + } + conn, err := listener.AcceptUnix() if err != nil { - return err + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return fmt.Errorf("timeout waiting for IPC connection (waited %v): counterpart process may have failed or not started", IPCAcceptTimeout) + } + return fmt.Errorf("failed to accept connection: %w", err) } defer func() { - err = conn.Close() - if err != nil { - logrus.WithError(err).Error("failed to close connection") + if closeErr := conn.Close(); closeErr != nil { + logrus.WithError(closeErr).Error("failed to close connection") } }() + + // Set read deadline to prevent hanging on slow or stuck writers + if err := conn.SetReadDeadline(time.Now().Add(IPCReadTimeout)); err != nil { + return fmt.Errorf("failed to set read deadline: %w", err) + } + buf := make([]byte, len(expectedMessage)) n, err := conn.Read(buf) if err != nil { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return fmt.Errorf("timeout reading IPC message (waited %v): counterpart process may be stuck", IPCReadTimeout) + } return fmt.Errorf("failed to read from socket: %w", err) } msg := string(buf[0:n]) if msg != string(expectedMessage) { - return fmt.Errorf("received unexpected message: %s", msg) + return fmt.Errorf("received unexpected message: %s (expected: %s)", msg, expectedMessage) } return nil } diff --git a/pkg/unikontainers/ipc_test.go b/pkg/unikontainers/ipc_test.go index dbc63a9a..fcd66127 100644 --- a/pkg/unikontainers/ipc_test.go +++ b/pkg/unikontainers/ipc_test.go @@ -176,3 +176,56 @@ func TestAwaitMessage(t *testing.T) { err = AwaitMessage(listener, expectedMessage) assert.NoError(t, err, "Expected no error in awaiting message") } + +func TestAwaitMessageTimeout(t *testing.T) { + socketAddress := "/tmp/test_await_message_timeout.sock" + expectedMessage := ReexecStarted + + listener, err := createListener(socketAddress, true) + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer listener.Close() + + // Don't send any message - this should trigger a timeout + // Note: For testing, we need shorter timeouts than production. + // The actual timeout check is that it returns an error containing "timeout" + // rather than blocking forever. + + // Set a shorter deadline for testing purposes + listener.SetDeadline(time.Now().Add(100 * time.Millisecond)) + + err = AwaitMessage(listener, expectedMessage) + assert.Error(t, err, "Expected timeout error when no connection arrives") + assert.Contains(t, err.Error(), "timeout", "Expected error message to mention timeout") +} + +func TestAwaitMessageWrongMessage(t *testing.T) { + socketAddress := "/tmp/test_await_wrong_message.sock" + expectedMessage := ReexecStarted + wrongMessage := StartExecve + + listener, err := createListener(socketAddress, true) + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer listener.Close() + + go func() { + conn, err := net.Dial("unix", socketAddress) + if err != nil { + t.Errorf("Failed to dial connection: %v", err) + } + defer conn.Close() + + // Send wrong message + _, err = conn.Write([]byte(wrongMessage)) + if err != nil { + t.Errorf("Failed to send message: %v", err) + } + }() + + err = AwaitMessage(listener, expectedMessage) + assert.Error(t, err, "Expected error for unexpected message") + assert.Contains(t, err.Error(), "unexpected message", "Expected error to mention unexpected message") +} diff --git a/pkg/unikontainers/unikontainers.go b/pkg/unikontainers/unikontainers.go index 85dc1e2a..1b7235a1 100644 --- a/pkg/unikontainers/unikontainers.go +++ b/pkg/unikontainers/unikontainers.go @@ -753,8 +753,8 @@ func (u *Unikontainer) executeHooksConcurrently(name string, hooks []specs.Hook, uniklog.WithFields(logrus.Fields{ "id": u.State.ID, "name": name, - "path": hooks[i].Path, - "args": hooks[i].Args, + "path": h.Path, + "args": h.Args, "error": err, }).Error("Executing hook failed") errChan <- err @@ -1121,7 +1121,7 @@ func (u *Unikontainer) SendMessage(message IPCMessage) error { // isRunning returns true if the PID is alive or hedge.ListVMs returns our containerID func (u *Unikontainer) isRunning() bool { - vmmType := hypervisors.VmmType(u.State.Annotations[annotType]) + vmmType := hypervisors.VmmType(u.State.Annotations[annotHypervisor]) if vmmType != hypervisors.HedgeVmm { return syscall.Kill(u.State.Pid, syscall.Signal(0)) == nil }