From 26ec8c3a3a683e74ba34245fbc8f51ebe76c4c8c Mon Sep 17 00:00:00 2001 From: Andrew Yuan Date: Thu, 5 Mar 2026 15:17:56 -0500 Subject: [PATCH] heartbeat ctx cancel fix --- internal/internal_worker_heartbeat.go | 25 ++++--- internal/internal_worker_heartbeat_test.go | 78 ++++++++++++++++++++++ 2 files changed, 95 insertions(+), 8 deletions(-) create mode 100644 internal/internal_worker_heartbeat_test.go diff --git a/internal/internal_worker_heartbeat.go b/internal/internal_worker_heartbeat.go index 74ddb9f73..86c0c35fd 100644 --- a/internal/internal_worker_heartbeat.go +++ b/internal/internal_worker_heartbeat.go @@ -60,14 +60,17 @@ func (m *heartbeatManager) registerWorker( hw, ok := m.workers[namespace] // If this is the first worker on the namespace, start a new shared namespace worker. if !ok { + heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background()) hw = &sharedNamespaceWorker{ - client: m.client, - namespace: namespace, - interval: m.interval, - callbacks: make(map[string]func() *workerpb.WorkerHeartbeat), - stopC: make(chan struct{}), - stoppedC: make(chan struct{}), - logger: m.logger, + client: m.client, + namespace: namespace, + interval: m.interval, + heartbeatCtx: heartbeatCtx, + heartbeatCancel: heartbeatCancel, + callbacks: make(map[string]func() *workerpb.WorkerHeartbeat), + stopC: make(chan struct{}), + stoppedC: make(chan struct{}), + logger: m.logger, } m.workers[namespace] = hw if hw.started.Swap(true) { @@ -113,6 +116,9 @@ type sharedNamespaceWorker struct { interval time.Duration logger log.Logger + heartbeatCtx context.Context + heartbeatCancel context.CancelFunc + // callbacksMutex should only be unlocked under callbacksMutex sync.RWMutex callbacks map[string]func() *workerpb.WorkerHeartbeat // workerInstanceKey -> callback @@ -159,7 +165,7 @@ func (hw *sharedNamespaceWorker) sendHeartbeats() error { heartbeats = append(heartbeats, hb) } - _, err := hw.client.recordWorkerHeartbeat(context.Background(), &workflowservice.RecordWorkerHeartbeatRequest{ + _, err := hw.client.recordWorkerHeartbeat(hw.heartbeatCtx, &workflowservice.RecordWorkerHeartbeatRequest{ Namespace: hw.namespace, WorkerHeartbeat: heartbeats, }) @@ -179,6 +185,9 @@ func (hw *sharedNamespaceWorker) stop() { if !hw.started.CompareAndSwap(true, false) { return } + if hw.heartbeatCancel != nil { + hw.heartbeatCancel() + } close(hw.stopC) <-hw.stoppedC diff --git a/internal/internal_worker_heartbeat_test.go b/internal/internal_worker_heartbeat_test.go new file mode 100644 index 000000000..af3d21473 --- /dev/null +++ b/internal/internal_worker_heartbeat_test.go @@ -0,0 +1,78 @@ +package internal + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + workerpb "go.temporal.io/api/worker/v1" + "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/api/workflowservicemock/v1" + ilog "go.temporal.io/sdk/internal/log" + "google.golang.org/grpc" +) + +// TestStopCancelsInFlightHeartbeatRPC verifies that calling stop() on a +// sharedNamespaceWorker cancels an in-flight heartbeat RPC. Without the fix +// (using context.Background() for the RPC), stop() would hang forever because +// the blocked RPC prevents run() from seeing stopC. With the fix +// (heartbeatCtx), stop() cancels the context first, unblocking the RPC. +func TestStopCancelsInFlightHeartbeatRPC(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockService := workflowservicemock.NewMockWorkflowServiceClient(ctrl) + + mockService.EXPECT().GetSystemInfo(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&workflowservice.GetSystemInfoResponse{}, nil).AnyTimes() + + // Simulate an RPC that blocks until its context is cancelled. + heartbeatStarted := make(chan struct{}) + mockService.EXPECT().RecordWorkerHeartbeat(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, _ *workflowservice.RecordWorkerHeartbeatRequest, _ ...grpc.CallOption) (*workflowservice.RecordWorkerHeartbeatResponse, error) { + close(heartbeatStarted) + <-ctx.Done() + return nil, ctx.Err() + }).AnyTimes() + + wfClient := NewServiceClient(mockService, nil, ClientOptions{}) + + heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background()) + hw := &sharedNamespaceWorker{ + client: wfClient, + namespace: "test-ns", + interval: 50 * time.Millisecond, + heartbeatCtx: heartbeatCtx, + heartbeatCancel: heartbeatCancel, + callbacks: map[string]func() *workerpb.WorkerHeartbeat{ + "worker1": func() *workerpb.WorkerHeartbeat { return &workerpb.WorkerHeartbeat{} }, + }, + stopC: make(chan struct{}), + stoppedC: make(chan struct{}), + logger: ilog.NewDefaultLogger(), + } + hw.started.Store(true) + go hw.run() + + // Wait for the heartbeat RPC to be in-flight. + select { + case <-heartbeatStarted: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for heartbeat RPC to start") + } + + // stop() should return promptly because heartbeatCancel() unblocks the + // in-flight RPC. Without the fix, this hangs forever. + done := make(chan struct{}) + go func() { + hw.stop() + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for stop() — in-flight heartbeat RPC was not cancelled") + } +}