Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions internal/internal_worker_heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,17 @@ func (m *heartbeatManager) registerWorker(
hw, ok := m.workers[namespace]
// If this is the first worker on the namespace, start a new shared namespace worker.
if !ok {
heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background())
hw = &sharedNamespaceWorker{
client: m.client,
namespace: namespace,
interval: m.interval,
callbacks: make(map[string]func() *workerpb.WorkerHeartbeat),
stopC: make(chan struct{}),
stoppedC: make(chan struct{}),
logger: m.logger,
client: m.client,
namespace: namespace,
interval: m.interval,
heartbeatCtx: heartbeatCtx,
heartbeatCancel: heartbeatCancel,
callbacks: make(map[string]func() *workerpb.WorkerHeartbeat),
stopC: make(chan struct{}),
stoppedC: make(chan struct{}),
logger: m.logger,
}
m.workers[namespace] = hw
if hw.started.Swap(true) {
Expand Down Expand Up @@ -113,6 +116,9 @@ type sharedNamespaceWorker struct {
interval time.Duration
logger log.Logger

heartbeatCtx context.Context
heartbeatCancel context.CancelFunc

// callbacksMutex should only be unlocked under
callbacksMutex sync.RWMutex
callbacks map[string]func() *workerpb.WorkerHeartbeat // workerInstanceKey -> callback
Expand Down Expand Up @@ -159,7 +165,7 @@ func (hw *sharedNamespaceWorker) sendHeartbeats() error {
heartbeats = append(heartbeats, hb)
}

_, err := hw.client.recordWorkerHeartbeat(context.Background(), &workflowservice.RecordWorkerHeartbeatRequest{
_, err := hw.client.recordWorkerHeartbeat(hw.heartbeatCtx, &workflowservice.RecordWorkerHeartbeatRequest{
Namespace: hw.namespace,
WorkerHeartbeat: heartbeats,
})
Expand All @@ -179,6 +185,9 @@ func (hw *sharedNamespaceWorker) stop() {
if !hw.started.CompareAndSwap(true, false) {
return
}
if hw.heartbeatCancel != nil {
hw.heartbeatCancel()
}

close(hw.stopC)
<-hw.stoppedC
Expand Down
78 changes: 78 additions & 0 deletions internal/internal_worker_heartbeat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package internal

import (
"context"
"testing"
"time"

"github.com/golang/mock/gomock"
workerpb "go.temporal.io/api/worker/v1"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/api/workflowservicemock/v1"
ilog "go.temporal.io/sdk/internal/log"
"google.golang.org/grpc"
)

// TestStopCancelsInFlightHeartbeatRPC verifies that calling stop() on a
// sharedNamespaceWorker cancels an in-flight heartbeat RPC.
//
// The mocked RecordWorkerHeartbeat blocks until the context it receives is
// cancelled. If the RPC were issued with context.Background(), nothing would
// ever cancel it and stop() would hang forever because the blocked RPC
// prevents run() from seeing stopC. With heartbeatCtx, stop() cancels the
// context first, which unblocks the RPC and lets stop() return.
func TestStopCancelsInFlightHeartbeatRPC(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	mockService := workflowservicemock.NewMockWorkflowServiceClient(ctrl)

	mockService.EXPECT().GetSystemInfo(gomock.Any(), gomock.Any(), gomock.Any()).
		Return(&workflowservice.GetSystemInfoResponse{}, nil).AnyTimes()

	// Simulate an RPC that blocks until its context is cancelled.
	//
	// The expectation is AnyTimes, so the stub may run more than once (for
	// example if another heartbeat is attempted before the loop exits).
	// Signalling with a non-blocking send on a buffered channel is safe to
	// repeat, whereas close() would panic on the second invocation.
	heartbeatStarted := make(chan struct{}, 1)
	mockService.EXPECT().RecordWorkerHeartbeat(gomock.Any(), gomock.Any(), gomock.Any()).
		DoAndReturn(func(ctx context.Context, _ *workflowservice.RecordWorkerHeartbeatRequest, _ ...grpc.CallOption) (*workflowservice.RecordWorkerHeartbeatResponse, error) {
			select {
			case heartbeatStarted <- struct{}{}:
			default:
				// A start signal is already pending; nothing more to report.
			}
			<-ctx.Done()
			return nil, ctx.Err()
		}).AnyTimes()

	wfClient := NewServiceClient(mockService, nil, ClientOptions{})

	// Build the worker by hand so the test controls the heartbeat context
	// and the single registered callback.
	heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background())
	hw := &sharedNamespaceWorker{
		client:          wfClient,
		namespace:       "test-ns",
		interval:        50 * time.Millisecond,
		heartbeatCtx:    heartbeatCtx,
		heartbeatCancel: heartbeatCancel,
		callbacks: map[string]func() *workerpb.WorkerHeartbeat{
			"worker1": func() *workerpb.WorkerHeartbeat { return &workerpb.WorkerHeartbeat{} },
		},
		stopC:    make(chan struct{}),
		stoppedC: make(chan struct{}),
		logger:   ilog.NewDefaultLogger(),
	}
	// stop() only proceeds when started is true, so mark the worker as
	// started before launching its loop.
	hw.started.Store(true)
	go hw.run()

	// Wait for the heartbeat RPC to be in-flight.
	select {
	case <-heartbeatStarted:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for heartbeat RPC to start")
	}

	// stop() should return promptly because heartbeatCancel() unblocks the
	// in-flight RPC. Without the fix, this hangs forever. Run it in its own
	// goroutine so a hang surfaces as a test failure instead of a deadlock.
	done := make(chan struct{})
	go func() {
		hw.stop()
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for stop() — in-flight heartbeat RPC was not cancelled")
	}
}
Loading