diff --git a/service/matching/forwarder.go b/service/matching/forwarder.go index cfdf18757f..018aeffca1 100644 --- a/service/matching/forwarder.go +++ b/service/matching/forwarder.go @@ -248,6 +248,7 @@ func (fwdr *Forwarder) ForwardPoll(ctx context.Context, pollMetadata *pollMetada Identity: identity, WorkerVersionCapabilities: pollMetadata.workerVersionCapabilities, DeploymentOptions: pollMetadata.deploymentOptions, + WorkerInstanceKey: pollMetadata.workerInstanceKey, }, ForwardedSource: fwdr.partition.RpcName(), Conditions: pollMetadata.conditions, @@ -271,6 +272,7 @@ func (fwdr *Forwarder) ForwardPoll(ctx context.Context, pollMetadata *pollMetada TaskQueueMetadata: pollMetadata.taskQueueMetadata, WorkerVersionCapabilities: pollMetadata.workerVersionCapabilities, DeploymentOptions: pollMetadata.deploymentOptions, + WorkerInstanceKey: pollMetadata.workerInstanceKey, }, ForwardedSource: fwdr.partition.RpcName(), Conditions: pollMetadata.conditions, @@ -293,6 +295,7 @@ func (fwdr *Forwarder) ForwardPoll(ctx context.Context, pollMetadata *pollMetada Identity: identity, WorkerVersionCapabilities: pollMetadata.workerVersionCapabilities, DeploymentOptions: pollMetadata.deploymentOptions, + WorkerInstanceKey: pollMetadata.workerInstanceKey, // Namespace is ignored here. }, ForwardedSource: fwdr.partition.RpcName(), diff --git a/service/matching/forwarder_test.go b/service/matching/forwarder_test.go index 8298d23160..3d545431a1 100644 --- a/service/matching/forwarder_test.go +++ b/service/matching/forwarder_test.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/suite" enumspb "go.temporal.io/api/enums/v1" + "go.temporal.io/api/workflowservice/v1" enumsspb "go.temporal.io/server/api/enums/v1" "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/api/matchingservicemock/v1" @@ -282,6 +283,34 @@ func (t *ForwarderTestSuite) TestForwardPollWorkflowTaskQueue() { t.Nil(task.pollActivityTaskQueueResponse()) } +func (t *ForwarderTestSuite) TestForwardPollWorkflowTaskQueuePreservesWorkerInstanceKey() { + t.usingTaskqueuePartition(enumspb.TASK_QUEUE_TYPE_WORKFLOW) + + pollerID := uuid.NewString() + workerInstanceKey := "test-worker-instance-" + uuid.NewString() + ctx := context.WithValue(context.Background(), pollerIDKey, pollerID) + ctx = context.WithValue(ctx, identityKey, "id1") + resp := &matchingservice.PollWorkflowTaskQueueResponse{ + TaskToken: []byte("token1"), + } + + var request *matchingservice.PollWorkflowTaskQueueRequest + t.client.EXPECT().PollWorkflowTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).Do( + func(arg0 context.Context, arg1 *matchingservice.PollWorkflowTaskQueueRequest, arg2 ...interface{}) { + request = arg1 + }, + ).Return(resp, nil) + + task, err := t.fwdr.ForwardPoll(ctx, &pollMetadata{ + workerInstanceKey: workerInstanceKey, + }) + t.Require().NoError(err) + t.NotNil(task) + t.NotNil(request) + t.Equal(workerInstanceKey, request.GetPollRequest().GetWorkerInstanceKey(), + "WorkerInstanceKey should be preserved when forwarding workflow poll") +} + func (t *ForwarderTestSuite) TestForwardPollForActivity() { t.usingTaskqueuePartition(enumspb.TASK_QUEUE_TYPE_ACTIVITY) @@ -300,7 +329,7 @@ func (t *ForwarderTestSuite) TestForwardPollForActivity() { ).Return(resp, nil) task, err := t.fwdr.ForwardPoll(ctx, &pollMetadata{}) - t.NoError(err) + t.Require().NoError(err) t.NotNil(task) t.NotNil(request) t.Equal(pollerID, request.GetPollerId()) @@ -312,6 +341,64 @@ func (t *ForwarderTestSuite) TestForwardPollForActivity() { t.Nil(task.pollWorkflowTaskQueueResponse()) } +func (t *ForwarderTestSuite) TestForwardPollForActivityPreservesWorkerInstanceKey() { + t.usingTaskqueuePartition(enumspb.TASK_QUEUE_TYPE_ACTIVITY) + + pollerID := uuid.NewString() + workerInstanceKey := "test-worker-instance-" + uuid.NewString() + ctx := context.WithValue(context.Background(), pollerIDKey, pollerID) + ctx = context.WithValue(ctx, identityKey, "id1") + resp := &matchingservice.PollActivityTaskQueueResponse{ + TaskToken: []byte("token1"), + } + + var request *matchingservice.PollActivityTaskQueueRequest + t.client.EXPECT().PollActivityTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).Do( + func(arg0 context.Context, arg1 *matchingservice.PollActivityTaskQueueRequest, arg2 ...interface{}) { + request = arg1 + }, + ).Return(resp, nil) + + task, err := t.fwdr.ForwardPoll(ctx, &pollMetadata{ + workerInstanceKey: workerInstanceKey, + }) + t.Require().NoError(err) + t.NotNil(task) + t.NotNil(request) + t.Equal(workerInstanceKey, request.GetPollRequest().GetWorkerInstanceKey(), + "WorkerInstanceKey should be preserved when forwarding activity poll") +} + +func (t *ForwarderTestSuite) TestForwardPollForNexusPreservesWorkerInstanceKey() { + t.usingTaskqueuePartition(enumspb.TASK_QUEUE_TYPE_NEXUS) + + pollerID := uuid.NewString() + workerInstanceKey := "test-worker-instance-" + uuid.NewString() + ctx := context.WithValue(context.Background(), pollerIDKey, pollerID) + ctx = context.WithValue(ctx, identityKey, "id1") + resp := &matchingservice.PollNexusTaskQueueResponse{ + Response: &workflowservice.PollNexusTaskQueueResponse{ + TaskToken: []byte("token1"), + }, + } + + var request *matchingservice.PollNexusTaskQueueRequest + t.client.EXPECT().PollNexusTaskQueue(gomock.Any(), gomock.Any(), gomock.Any()).Do( + func(arg0 context.Context, arg1 *matchingservice.PollNexusTaskQueueRequest, arg2 ...interface{}) { + request = arg1 + }, + ).Return(resp, nil) + + task, err := t.fwdr.ForwardPoll(ctx, &pollMetadata{ + workerInstanceKey: workerInstanceKey, + }) + t.Require().NoError(err) + t.NotNil(task) + t.NotNil(request) + t.Equal(workerInstanceKey, request.GetRequest().GetWorkerInstanceKey(), + "WorkerInstanceKey should be preserved when forwarding Nexus poll") +} + // TODO(pri): old matcher cleanup func (t *ForwarderTestSuite) TestMaxOutstandingConcurrency() { if t.newFwdr { diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index e0f1271bda..1122bbc829 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -94,6 +94,7 @@ type ( conditions *matchingservice.PollConditions forwardedFrom string localPollStartTime time.Time + workerInstanceKey string } userDataUpdate struct { @@ -638,6 +639,7 @@ pollLoop: deploymentOptions: request.DeploymentOptions, forwardedFrom: req.ForwardedSource, conditions: req.Conditions, + workerInstanceKey: request.WorkerInstanceKey, } task, versionSetUsed, err := e.pollTask(pollerCtx, partition, pollMetadata) if err != nil { @@ -942,6 +944,7 @@ pollLoop: deploymentOptions: request.DeploymentOptions, forwardedFrom: req.ForwardedSource, conditions: req.Conditions, + workerInstanceKey: request.WorkerInstanceKey, } task, versionSetUsed, err := e.pollTask(pollerCtx, partition, pollMetadata) if err != nil { @@ -2495,6 +2498,7 @@ pollLoop: deploymentOptions: request.DeploymentOptions, forwardedFrom: req.ForwardedSource, conditions: req.Conditions, + workerInstanceKey: request.WorkerInstanceKey, } task, _, err := e.pollTask(pollerCtx, partition, pollMetadata) if err != nil { diff --git a/service/matching/pri_forwarder.go b/service/matching/pri_forwarder.go index 0f68bcf245..ca1e1f6f5c 100644 --- a/service/matching/pri_forwarder.go +++ b/service/matching/pri_forwarder.go @@ -218,6 +218,7 @@ func ForwardPollWithTarget( Identity: identity, WorkerVersionCapabilities: pollMetadata.workerVersionCapabilities, DeploymentOptions: pollMetadata.deploymentOptions, + WorkerInstanceKey: pollMetadata.workerInstanceKey, }, ForwardedSource: source.RpcName(), Conditions: pollMetadata.conditions, @@ -241,6 +242,7 @@ func ForwardPollWithTarget( TaskQueueMetadata: pollMetadata.taskQueueMetadata, WorkerVersionCapabilities: pollMetadata.workerVersionCapabilities, DeploymentOptions: pollMetadata.deploymentOptions, + WorkerInstanceKey: pollMetadata.workerInstanceKey, }, ForwardedSource: source.RpcName(), Conditions: pollMetadata.conditions, @@ -263,6 +265,7 @@ func ForwardPollWithTarget( Identity: identity, WorkerVersionCapabilities: pollMetadata.workerVersionCapabilities, DeploymentOptions: pollMetadata.deploymentOptions, + WorkerInstanceKey: pollMetadata.workerInstanceKey, // Namespace is ignored here. }, ForwardedSource: source.RpcName(),