From 48c9392b172543a1b8eb26dad6ee7e352fd10629 Mon Sep 17 00:00:00 2001 From: PJ Doerner Date: Thu, 11 Dec 2025 09:03:50 -0800 Subject: [PATCH 01/26] Update Nexus error model to use Temporal failures --- api/matchingservice/v1/request_response.pb.go | 163 ++++++++++-------- chasm/lib/callback/chasm_invocation.go | 2 +- common/nexus/failure.go | 135 ++++++++++++--- common/nexus/nexusrpc/client.go | 125 ++++++-------- common/nexus/nexusrpc/completion.go | 16 +- common/nexus/nexusrpc/completion_test.go | 6 +- common/nexus/nexusrpc/server.go | 14 +- common/nexus/nexusrpc/setup_test.go | 10 +- components/callbacks/chasm_invocation.go | 2 +- components/nexusoperations/executors.go | 47 +---- go.mod | 2 +- go.sum | 4 +- .../matchingservice/v1/request_response.proto | 7 +- service/frontend/nexus_handler.go | 60 +++++-- service/frontend/workflow_handler.go | 21 ++- 15 files changed, 366 insertions(+), 248 deletions(-) diff --git a/api/matchingservice/v1/request_response.pb.go b/api/matchingservice/v1/request_response.pb.go index 88ac1fa827..852dae0299 100644 --- a/api/matchingservice/v1/request_response.pb.go +++ b/api/matchingservice/v1/request_response.pb.go @@ -14,16 +14,17 @@ import ( v11 "go.temporal.io/api/common/v1" v112 "go.temporal.io/api/deployment/v1" v19 "go.temporal.io/api/enums/v1" + v114 "go.temporal.io/api/failure/v1" v16 "go.temporal.io/api/history/v1" v113 "go.temporal.io/api/nexus/v1" v15 "go.temporal.io/api/protocol/v1" v12 "go.temporal.io/api/query/v1" v14 "go.temporal.io/api/taskqueue/v1" - v114 "go.temporal.io/api/worker/v1" + v115 "go.temporal.io/api/worker/v1" v1 "go.temporal.io/api/workflowservice/v1" v17 "go.temporal.io/server/api/clock/v1" v110 "go.temporal.io/server/api/deployment/v1" - v115 "go.temporal.io/server/api/enums/v1" + v116 "go.temporal.io/server/api/enums/v1" v13 "go.temporal.io/server/api/history/v1" v111 "go.temporal.io/server/api/persistence/v1" v18 "go.temporal.io/server/api/taskqueue/v1" @@ -3523,6 +3524,7 @@ type DispatchNexusTaskResponse struct { // *DispatchNexusTaskResponse_HandlerError // *DispatchNexusTaskResponse_Response // *DispatchNexusTaskResponse_RequestTimeout + // *DispatchNexusTaskResponse_Failure Outcome isDispatchNexusTaskResponse_Outcome `protobuf_oneof:"outcome"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -3565,6 +3567,7 @@ func (x *DispatchNexusTaskResponse) GetOutcome() isDispatchNexusTaskResponse_Out return nil } +// Deprecated: Marked as deprecated in temporal/server/api/matchingservice/v1/request_response.proto. func (x *DispatchNexusTaskResponse) GetHandlerError() *v113.HandlerError { if x != nil { if x, ok := x.Outcome.(*DispatchNexusTaskResponse_HandlerError); ok { @@ -3592,12 +3595,23 @@ func (x *DispatchNexusTaskResponse) GetRequestTimeout() *DispatchNexusTaskRespon return nil } +func (x *DispatchNexusTaskResponse) GetFailure() *v114.Failure { + if x != nil { + if x, ok := x.Outcome.(*DispatchNexusTaskResponse_Failure); ok { + return x.Failure + } + } + return nil +} + type isDispatchNexusTaskResponse_Outcome interface { isDispatchNexusTaskResponse_Outcome() } type DispatchNexusTaskResponse_HandlerError struct { - // Set if the worker's handler failed the nexus task. + // Deprecated. Use failure field instead. + // + // Deprecated: Marked as deprecated in temporal/server/api/matchingservice/v1/request_response.proto. HandlerError *v113.HandlerError `protobuf:"bytes,1,opt,name=handler_error,json=handlerError,proto3,oneof"` } @@ -3610,12 +3624,19 @@ type DispatchNexusTaskResponse_RequestTimeout struct { RequestTimeout *DispatchNexusTaskResponse_Timeout `protobuf:"bytes,3,opt,name=request_timeout,json=requestTimeout,proto3,oneof"` } +type DispatchNexusTaskResponse_Failure struct { + // Set if the worker's handler failed the nexus task. Must contain a NexusHandlerFailureInfo object. + Failure *v114.Failure `protobuf:"bytes,4,opt,name=failure,proto3,oneof"` +} + func (*DispatchNexusTaskResponse_HandlerError) isDispatchNexusTaskResponse_Outcome() {} func (*DispatchNexusTaskResponse_Response) isDispatchNexusTaskResponse_Outcome() {} func (*DispatchNexusTaskResponse_RequestTimeout) isDispatchNexusTaskResponse_Outcome() {} +func (*DispatchNexusTaskResponse_Failure) isDispatchNexusTaskResponse_Outcome() {} + type PollNexusTaskQueueRequest struct { state protoimpl.MessageState `protogen:"open.v1"` NamespaceId string `protobuf:"bytes,1,opt,name=namespace_id,json=namespaceId,proto3" json:"namespace_id,omitempty"` @@ -4536,7 +4557,7 @@ func (x *ListWorkersRequest) GetListRequest() *v1.ListWorkersRequest { type ListWorkersResponse struct { state protoimpl.MessageState `protogen:"open.v1"` - WorkersInfo []*v114.WorkerInfo `protobuf:"bytes,1,rep,name=workers_info,json=workersInfo,proto3" json:"workers_info,omitempty"` + WorkersInfo []*v115.WorkerInfo `protobuf:"bytes,1,rep,name=workers_info,json=workersInfo,proto3" json:"workers_info,omitempty"` NextPageToken []byte `protobuf:"bytes,2,opt,name=next_page_token,json=nextPageToken,proto3" json:"next_page_token,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -4572,7 +4593,7 @@ func (*ListWorkersResponse) Descriptor() ([]byte, []int) { return file_temporal_server_api_matchingservice_v1_request_response_proto_rawDescGZIP(), []int{69} } -func (x *ListWorkersResponse) GetWorkersInfo() []*v114.WorkerInfo { +func (x *ListWorkersResponse) GetWorkersInfo() []*v115.WorkerInfo { if x != nil { return x.WorkersInfo } @@ -4747,7 +4768,7 @@ func (x *DescribeWorkerRequest) GetRequest() *v1.DescribeWorkerRequest { type DescribeWorkerResponse struct { state protoimpl.MessageState `protogen:"open.v1"` - WorkerInfo *v114.WorkerInfo `protobuf:"bytes,1,opt,name=worker_info,json=workerInfo,proto3" json:"worker_info,omitempty"` + WorkerInfo *v115.WorkerInfo `protobuf:"bytes,1,opt,name=worker_info,json=workerInfo,proto3" json:"worker_info,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -4782,7 +4803,7 @@ func (*DescribeWorkerResponse) Descriptor() ([]byte, []int) { return file_temporal_server_api_matchingservice_v1_request_response_proto_rawDescGZIP(), []int{73} } -func (x *DescribeWorkerResponse) GetWorkerInfo() *v114.WorkerInfo { +func (x *DescribeWorkerResponse) GetWorkerInfo() *v115.WorkerInfo { if x != nil { return x.WorkerInfo } @@ -4805,7 +4826,7 @@ type UpdateFairnessStateRequest struct { NamespaceId string `protobuf:"bytes,1,opt,name=namespace_id,json=namespaceId,proto3" json:"namespace_id,omitempty"` TaskQueue string `protobuf:"bytes,2,opt,name=task_queue,json=taskQueue,proto3" json:"task_queue,omitempty"` TaskQueueType v19.TaskQueueType `protobuf:"varint,3,opt,name=task_queue_type,json=taskQueueType,proto3,enum=temporal.api.enums.v1.TaskQueueType" json:"task_queue_type,omitempty"` - FairnessState v115.FairnessState `protobuf:"varint,4,opt,name=fairness_state,json=fairnessState,proto3,enum=temporal.server.api.enums.v1.FairnessState" json:"fairness_state,omitempty"` + FairnessState v116.FairnessState `protobuf:"varint,4,opt,name=fairness_state,json=fairnessState,proto3,enum=temporal.server.api.enums.v1.FairnessState" json:"fairness_state,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -4861,11 +4882,11 @@ func (x *UpdateFairnessStateRequest) GetTaskQueueType() v19.TaskQueueType { return v19.TaskQueueType(0) } -func (x *UpdateFairnessStateRequest) GetFairnessState() v115.FairnessState { +func (x *UpdateFairnessStateRequest) GetFairnessState() v116.FairnessState { if x != nil { return x.FairnessState } - return v115.FairnessState(0) + return v116.FairnessState(0) } type UpdateFairnessStateResponse struct { @@ -5340,7 +5361,7 @@ var File_temporal_server_api_matchingservice_v1_request_response_proto protorefl const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc = "" + "\n" + - "=temporal/server/api/matchingservice/v1/request_response.proto\x12&temporal.server.api.matchingservice.v1\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a$temporal/api/common/v1/message.proto\x1a(temporal/api/deployment/v1/message.proto\x1a&temporal/api/enums/v1/task_queue.proto\x1a%temporal/api/history/v1/message.proto\x1a'temporal/api/taskqueue/v1/message.proto\x1a#temporal/api/query/v1/message.proto\x1a&temporal/api/protocol/v1/message.proto\x1a*temporal/server/api/clock/v1/message.proto\x1a/temporal/server/api/deployment/v1/message.proto\x1a,temporal/server/api/history/v1/message.proto\x1a.temporal/server/api/persistence/v1/nexus.proto\x1a4temporal/server/api/persistence/v1/task_queues.proto\x1a.temporal/server/api/taskqueue/v1/message.proto\x1a1temporal/server/api/enums/v1/fairness_state.proto\x1a6temporal/api/workflowservice/v1/request_response.proto\x1a#temporal/api/nexus/v1/message.proto\x1a$temporal/api/worker/v1/message.proto\"\xc3\x02\n" + + "=temporal/server/api/matchingservice/v1/request_response.proto\x12&temporal.server.api.matchingservice.v1\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a$temporal/api/common/v1/message.proto\x1a(temporal/api/deployment/v1/message.proto\x1a&temporal/api/enums/v1/task_queue.proto\x1a%temporal/api/failure/v1/message.proto\x1a%temporal/api/history/v1/message.proto\x1a'temporal/api/taskqueue/v1/message.proto\x1a#temporal/api/query/v1/message.proto\x1a&temporal/api/protocol/v1/message.proto\x1a*temporal/server/api/clock/v1/message.proto\x1a/temporal/server/api/deployment/v1/message.proto\x1a,temporal/server/api/history/v1/message.proto\x1a.temporal/server/api/persistence/v1/nexus.proto\x1a4temporal/server/api/persistence/v1/task_queues.proto\x1a.temporal/server/api/taskqueue/v1/message.proto\x1a1temporal/server/api/enums/v1/fairness_state.proto\x1a6temporal/api/workflowservice/v1/request_response.proto\x1a#temporal/api/nexus/v1/message.proto\x1a$temporal/api/worker/v1/message.proto\"\xc3\x02\n" + "\x1cPollWorkflowTaskQueueRequest\x12!\n" + "\fnamespace_id\x18\x01 \x01(\tR\vnamespaceId\x12\x1b\n" + "\tpoller_id\x18\x02 \x01(\tR\bpollerId\x12`\n" + @@ -5632,11 +5653,12 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\n" + "task_queue\x18\x02 \x01(\v2$.temporal.api.taskqueue.v1.TaskQueueR\ttaskQueue\x128\n" + "\arequest\x18\x03 \x01(\v2\x1e.temporal.api.nexus.v1.RequestR\arequest\x12T\n" + - "\fforward_info\x18\x04 \x01(\v21.temporal.server.api.taskqueue.v1.TaskForwardInfoR\vforwardInfo\"\xb2\x02\n" + - "\x19DispatchNexusTaskResponse\x12J\n" + - "\rhandler_error\x18\x01 \x01(\v2#.temporal.api.nexus.v1.HandlerErrorH\x00R\fhandlerError\x12=\n" + + "\fforward_info\x18\x04 \x01(\v21.temporal.server.api.taskqueue.v1.TaskForwardInfoR\vforwardInfo\"\xf4\x02\n" + + "\x19DispatchNexusTaskResponse\x12N\n" + + "\rhandler_error\x18\x01 \x01(\v2#.temporal.api.nexus.v1.HandlerErrorB\x02\x18\x01H\x00R\fhandlerError\x12=\n" + "\bresponse\x18\x02 \x01(\v2\x1f.temporal.api.nexus.v1.ResponseH\x00R\bresponse\x12t\n" + - "\x0frequest_timeout\x18\x03 \x01(\v2I.temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.TimeoutH\x00R\x0erequestTimeout\x1a\t\n" + + "\x0frequest_timeout\x18\x03 \x01(\v2I.temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.TimeoutH\x00R\x0erequestTimeout\x12<\n" + + "\afailure\x18\x04 \x01(\v2 .temporal.api.failure.v1.FailureH\x00R\afailure\x1a\t\n" + "\aTimeoutB\t\n" + "\aoutcome\"\xb4\x02\n" + "\x19PollNexusTaskQueueRequest\x12!\n" + @@ -5871,23 +5893,24 @@ var file_temporal_server_api_matchingservice_v1_request_response_proto_goTypes = (*v113.Request)(nil), // 129: temporal.api.nexus.v1.Request (*v113.HandlerError)(nil), // 130: temporal.api.nexus.v1.HandlerError (*v113.Response)(nil), // 131: temporal.api.nexus.v1.Response - (*v1.PollNexusTaskQueueRequest)(nil), // 132: temporal.api.workflowservice.v1.PollNexusTaskQueueRequest - (*v1.PollNexusTaskQueueResponse)(nil), // 133: temporal.api.workflowservice.v1.PollNexusTaskQueueResponse - (*v1.RespondNexusTaskCompletedRequest)(nil), // 134: temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest - (*v1.RespondNexusTaskFailedRequest)(nil), // 135: temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest - (*v111.NexusEndpointSpec)(nil), // 136: temporal.server.api.persistence.v1.NexusEndpointSpec - (*v111.NexusEndpointEntry)(nil), // 137: temporal.server.api.persistence.v1.NexusEndpointEntry - (*v1.RecordWorkerHeartbeatRequest)(nil), // 138: temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest - (*v1.ListWorkersRequest)(nil), // 139: temporal.api.workflowservice.v1.ListWorkersRequest - (*v114.WorkerInfo)(nil), // 140: temporal.api.worker.v1.WorkerInfo - (*v1.UpdateTaskQueueConfigRequest)(nil), // 141: temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest - (*v14.TaskQueueConfig)(nil), // 142: temporal.api.taskqueue.v1.TaskQueueConfig - (*v1.DescribeWorkerRequest)(nil), // 143: temporal.api.workflowservice.v1.DescribeWorkerRequest - (v115.FairnessState)(0), // 144: temporal.server.api.enums.v1.FairnessState - (*v14.TaskQueueStats)(nil), // 145: temporal.api.taskqueue.v1.TaskQueueStats - (*v18.TaskQueueVersionInfoInternal)(nil), // 146: temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal - (*v1.UpdateWorkerBuildIdCompatibilityRequest)(nil), // 147: temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest - (*v110.WorkerDeploymentVersionData)(nil), // 148: temporal.server.api.deployment.v1.WorkerDeploymentVersionData + (*v114.Failure)(nil), // 132: temporal.api.failure.v1.Failure + (*v1.PollNexusTaskQueueRequest)(nil), // 133: temporal.api.workflowservice.v1.PollNexusTaskQueueRequest + (*v1.PollNexusTaskQueueResponse)(nil), // 134: temporal.api.workflowservice.v1.PollNexusTaskQueueResponse + (*v1.RespondNexusTaskCompletedRequest)(nil), // 135: temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest + (*v1.RespondNexusTaskFailedRequest)(nil), // 136: temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest + (*v111.NexusEndpointSpec)(nil), // 137: temporal.server.api.persistence.v1.NexusEndpointSpec + (*v111.NexusEndpointEntry)(nil), // 138: temporal.server.api.persistence.v1.NexusEndpointEntry + (*v1.RecordWorkerHeartbeatRequest)(nil), // 139: temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest + (*v1.ListWorkersRequest)(nil), // 140: temporal.api.workflowservice.v1.ListWorkersRequest + (*v115.WorkerInfo)(nil), // 141: temporal.api.worker.v1.WorkerInfo + (*v1.UpdateTaskQueueConfigRequest)(nil), // 142: temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest + (*v14.TaskQueueConfig)(nil), // 143: temporal.api.taskqueue.v1.TaskQueueConfig + (*v1.DescribeWorkerRequest)(nil), // 144: temporal.api.workflowservice.v1.DescribeWorkerRequest + (v116.FairnessState)(0), // 145: temporal.server.api.enums.v1.FairnessState + (*v14.TaskQueueStats)(nil), // 146: temporal.api.taskqueue.v1.TaskQueueStats + (*v18.TaskQueueVersionInfoInternal)(nil), // 147: temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal + (*v1.UpdateWorkerBuildIdCompatibilityRequest)(nil), // 148: temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest + (*v110.WorkerDeploymentVersionData)(nil), // 149: temporal.server.api.deployment.v1.WorkerDeploymentVersionData } var file_temporal_server_api_matchingservice_v1_request_response_proto_depIdxs = []int32{ 88, // 0: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueRequest.poll_request:type_name -> temporal.api.workflowservice.v1.PollWorkflowTaskQueueRequest @@ -5987,43 +6010,44 @@ var file_temporal_server_api_matchingservice_v1_request_response_proto_depIdxs = 130, // 94: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.handler_error:type_name -> temporal.api.nexus.v1.HandlerError 131, // 95: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.response:type_name -> temporal.api.nexus.v1.Response 87, // 96: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.request_timeout:type_name -> temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.Timeout - 132, // 97: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.request:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueRequest - 78, // 98: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.conditions:type_name -> temporal.server.api.matchingservice.v1.PollConditions - 133, // 99: temporal.server.api.matchingservice.v1.PollNexusTaskQueueResponse.response:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueResponse - 93, // 100: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue - 134, // 101: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest - 93, // 102: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue - 135, // 103: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest - 136, // 104: temporal.server.api.matchingservice.v1.CreateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec - 137, // 105: temporal.server.api.matchingservice.v1.CreateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 136, // 106: temporal.server.api.matchingservice.v1.UpdateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec - 137, // 107: temporal.server.api.matchingservice.v1.UpdateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 137, // 108: temporal.server.api.matchingservice.v1.ListNexusEndpointsResponse.entries:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 138, // 109: temporal.server.api.matchingservice.v1.RecordWorkerHeartbeatRequest.heartbeart_request:type_name -> temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest - 139, // 110: temporal.server.api.matchingservice.v1.ListWorkersRequest.list_request:type_name -> temporal.api.workflowservice.v1.ListWorkersRequest - 140, // 111: temporal.server.api.matchingservice.v1.ListWorkersResponse.workers_info:type_name -> temporal.api.worker.v1.WorkerInfo - 141, // 112: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigRequest.update_taskqueue_config:type_name -> temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest - 142, // 113: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigResponse.updated_taskqueue_config:type_name -> temporal.api.taskqueue.v1.TaskQueueConfig - 143, // 114: temporal.server.api.matchingservice.v1.DescribeWorkerRequest.request:type_name -> temporal.api.workflowservice.v1.DescribeWorkerRequest - 140, // 115: temporal.server.api.matchingservice.v1.DescribeWorkerResponse.worker_info:type_name -> temporal.api.worker.v1.WorkerInfo - 111, // 116: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType - 144, // 117: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.fairness_state:type_name -> temporal.server.api.enums.v1.FairnessState - 111, // 118: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType - 113, // 119: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.version:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersion - 91, // 120: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponse.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery - 111, // 121: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesRequest.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType - 111, // 122: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType - 145, // 123: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats:type_name -> temporal.api.taskqueue.v1.TaskQueueStats - 82, // 124: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats_by_priority_key:type_name -> temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry - 145, // 125: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry.value:type_name -> temporal.api.taskqueue.v1.TaskQueueStats - 146, // 126: temporal.server.api.matchingservice.v1.DescribeTaskQueuePartitionResponse.VersionsInfoInternalEntry.value:type_name -> temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal - 147, // 127: temporal.server.api.matchingservice.v1.UpdateWorkerBuildIdCompatibilityRequest.ApplyPublicRequest.request:type_name -> temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest - 148, // 128: temporal.server.api.matchingservice.v1.SyncDeploymentUserDataRequest.UpsertVersionsDataEntry.value:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersionData - 129, // [129:129] is the sub-list for method output_type - 129, // [129:129] is the sub-list for method input_type - 129, // [129:129] is the sub-list for extension type_name - 129, // [129:129] is the sub-list for extension extendee - 0, // [0:129] is the sub-list for field type_name + 132, // 97: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.failure:type_name -> temporal.api.failure.v1.Failure + 133, // 98: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.request:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueRequest + 78, // 99: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.conditions:type_name -> temporal.server.api.matchingservice.v1.PollConditions + 134, // 100: temporal.server.api.matchingservice.v1.PollNexusTaskQueueResponse.response:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueResponse + 93, // 101: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue + 135, // 102: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest + 93, // 103: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue + 136, // 104: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest + 137, // 105: temporal.server.api.matchingservice.v1.CreateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec + 138, // 106: temporal.server.api.matchingservice.v1.CreateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry + 137, // 107: temporal.server.api.matchingservice.v1.UpdateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec + 138, // 108: temporal.server.api.matchingservice.v1.UpdateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry + 138, // 109: temporal.server.api.matchingservice.v1.ListNexusEndpointsResponse.entries:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry + 139, // 110: temporal.server.api.matchingservice.v1.RecordWorkerHeartbeatRequest.heartbeart_request:type_name -> temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest + 140, // 111: temporal.server.api.matchingservice.v1.ListWorkersRequest.list_request:type_name -> temporal.api.workflowservice.v1.ListWorkersRequest + 141, // 112: temporal.server.api.matchingservice.v1.ListWorkersResponse.workers_info:type_name -> temporal.api.worker.v1.WorkerInfo + 142, // 113: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigRequest.update_taskqueue_config:type_name -> temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest + 143, // 114: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigResponse.updated_taskqueue_config:type_name -> temporal.api.taskqueue.v1.TaskQueueConfig + 144, // 115: temporal.server.api.matchingservice.v1.DescribeWorkerRequest.request:type_name -> temporal.api.workflowservice.v1.DescribeWorkerRequest + 141, // 116: temporal.server.api.matchingservice.v1.DescribeWorkerResponse.worker_info:type_name -> temporal.api.worker.v1.WorkerInfo + 111, // 117: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType + 145, // 118: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.fairness_state:type_name -> temporal.server.api.enums.v1.FairnessState + 111, // 119: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType + 113, // 120: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.version:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersion + 91, // 121: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponse.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery + 111, // 122: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesRequest.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType + 111, // 123: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType + 146, // 124: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats:type_name -> temporal.api.taskqueue.v1.TaskQueueStats + 82, // 125: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats_by_priority_key:type_name -> temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry + 146, // 126: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry.value:type_name -> temporal.api.taskqueue.v1.TaskQueueStats + 147, // 127: temporal.server.api.matchingservice.v1.DescribeTaskQueuePartitionResponse.VersionsInfoInternalEntry.value:type_name -> temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal + 148, // 128: temporal.server.api.matchingservice.v1.UpdateWorkerBuildIdCompatibilityRequest.ApplyPublicRequest.request:type_name -> temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest + 149, // 129: temporal.server.api.matchingservice.v1.SyncDeploymentUserDataRequest.UpsertVersionsDataEntry.value:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersionData + 130, // [130:130] is the sub-list for method output_type + 130, // [130:130] is the sub-list for method input_type + 130, // [130:130] is the sub-list for extension type_name + 130, // [130:130] is the sub-list for extension extendee + 0, // [0:130] is the sub-list for field type_name } func init() { file_temporal_server_api_matchingservice_v1_request_response_proto_init() } @@ -6050,6 +6074,7 @@ func file_temporal_server_api_matchingservice_v1_request_response_proto_init() { (*DispatchNexusTaskResponse_HandlerError)(nil), (*DispatchNexusTaskResponse_Response)(nil), (*DispatchNexusTaskResponse_RequestTimeout)(nil), + (*DispatchNexusTaskResponse_Failure)(nil), } type x struct{} out := protoimpl.TypeBuilder{ diff --git a/chasm/lib/callback/chasm_invocation.go b/chasm/lib/callback/chasm_invocation.go index bd6f161d61..7cd10772a6 100644 --- a/chasm/lib/callback/chasm_invocation.go +++ b/chasm/lib/callback/chasm_invocation.go @@ -162,7 +162,7 @@ func (c chasmInvocation) getHistoryRequest( Completion: completion, } case *nexusrpc.OperationCompletionUnsuccessful: - apiFailure, err := commonnexus.NexusFailureToAPIFailure(op.Failure, true) + apiFailure, err := commonnexus.NexusFailureToAPIFailure(op.Failure) if err != nil { return nil, fmt.Errorf("failed to convert failure type: %v", err) } diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 35522dcd52..008af7a892 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -4,11 +4,13 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "sync/atomic" "github.com/nexus-rpc/sdk-go/nexus" commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" failurepb "go.temporal.io/api/failure/v1" nexuspb "go.temporal.io/api/nexus/v1" "go.temporal.io/api/serviceerror" @@ -74,17 +76,22 @@ func NexusFailureToProtoFailure(failure nexus.Failure) *nexuspb.Failure { // Mutates the failure temporarily, unsetting the Message field to avoid duplicating the information in the serialized // failure. Mutating was chosen over cloning for performance reasons since this function may be called frequently. func APIFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) { - // Unset message so it's not serialized in the details. + // Unset message and stack trace so it's not serialized in the details. var message string message, failure.Message = failure.Message, "" + var stackTrace string + stackTrace, failure.StackTrace = failure.StackTrace, "" + data, err := protojson.Marshal(failure) failure.Message = message + failure.StackTrace = stackTrace if err != nil { return nexus.Failure{}, err } return nexus.Failure{ - Message: failure.GetMessage(), + Message: failure.GetMessage(), + StackTrace: failure.GetStackTrace(), Metadata: map[string]string{ "type": failureTypeString, }, @@ -92,32 +99,103 @@ func APIFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) }, nil } +type serializedOperationError struct { + State string `json:"state,omitempty"` + // Bytes as base64 encoded string. + EncodedAttributes string `json:"encodedAttributes,omitempty"` +} + +type serializedHandlerError struct { + Type string `json:"type,omitempty"` + RetryableOverride *bool `json:"retryableOverride,omitempty"` + // Bytes as base64 encoded string. + EncodedAttributes string `json:"encodedAttributes,omitempty"` +} + // NexusFailureToAPIFailure converts a Nexus Failure to an API proto Failure. // If the failure metadata "type" field is set to the fullname of the temporal API Failure message, the failure is // reconstructed using protojson.Unmarshal on the failure details field. -func NexusFailureToAPIFailure(failure nexus.Failure, retryable bool) (*failurepb.Failure, error) { - apiFailure := &failurepb.Failure{} +func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { + apiFailure := &failurepb.Failure{ + Message: f.Message, + StackTrace: f.StackTrace, + } - if failure.Metadata != nil && failure.Metadata["type"] == failureTypeString { - if err := protojson.Unmarshal(failure.Details, apiFailure); err != nil { - return nil, err - } - } else { - payloads, err := nexusFailureMetadataToPayloads(failure) - if err != nil { - return nil, err - } - apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ - ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ - // Make up a type here, it's not part of the Nexus Failure spec. - Type: "NexusFailure", - Details: payloads, - NonRetryable: !retryable, - }, + if f.Metadata != nil { + switch f.Metadata["type"] { + case failureTypeString: + if err := protojson.Unmarshal(f.Details, apiFailure); err != nil { + return nil, err + } + // Restore these fields as they are not included in the marshalled failure. + apiFailure.Message = f.Message + apiFailure.StackTrace = f.StackTrace + return apiFailure, nil + case "nexus.OperationError": + var se serializedOperationError + err := json.Unmarshal(f.Details, &se) + if err != nil { + return nil, fmt.Errorf("failed to deserialize OperationError: %w", err) + } + apiFailure.FailureInfo = &failurepb.Failure_NexusSdkOperationFailureInfo{ + NexusSdkOperationFailureInfo: &failurepb.NexusSDKOperationFailureInfo{ + State: se.State, + }, + } + if err := protojson.Unmarshal([]byte(se.EncodedAttributes), apiFailure.EncodedAttributes); err != nil { + return nil, fmt.Errorf("failed to deserialize OperationError attributes: %w", err) + } + if f.Cause != nil { + apiFailure.Cause, err = NexusFailureToAPIFailure(*f.Cause) + if err != nil { + return nil, err + } + } + return apiFailure, nil + case "nexus.HandlerError": + var se serializedHandlerError + err := json.Unmarshal(f.Details, &se) + if err != nil { + return nil, fmt.Errorf("failed to deserialize HandlerError: %w", err) + } + var retryBehavior enumspb.NexusHandlerErrorRetryBehavior + if se.RetryableOverride == nil { + retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_UNSPECIFIED + } else if *se.RetryableOverride { + retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE + } else { + retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE + } + apiFailure.FailureInfo = &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: se.Type, + RetryBehavior: retryBehavior, + }, + } + if err := protojson.Unmarshal([]byte(se.EncodedAttributes), apiFailure.EncodedAttributes); err != nil { + return nil, fmt.Errorf("failed to deserialize HandlerError attributes: %w", err) + } + if f.Cause != nil { + apiFailure.Cause, err = NexusFailureToAPIFailure(*f.Cause) + if err != nil { + return nil, err + } + } + return apiFailure, nil } } - // Ensure this always gets written. - apiFailure.Message = failure.Message + + payloads, err := nexusFailureMetadataToPayloads(f) + if err != nil { + return nil, err + } + apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + // Make up a type here, it's not part of the Nexus Failure spec. + Type: "NexusFailure", + Details: payloads, + }, + } return apiFailure, nil } @@ -133,7 +211,7 @@ func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Fa // Canceled must be translated into a CanceledFailure to match the SDK expectation. if opErr.State == nexus.OperationStateCanceled { if nexusFailure.Metadata != nil && nexusFailure.Metadata["type"] == failureTypeString { - temporalFailure, err := NexusFailureToAPIFailure(nexusFailure, false) + temporalFailure, err := NexusFailureToAPIFailure(nexusFailure) if err != nil { return nil, err } @@ -158,16 +236,19 @@ func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Fa }, nil } - return NexusFailureToAPIFailure(nexusFailure, false) + return NexusFailureToAPIFailure(nexusFailure) } func nexusFailureMetadataToPayloads(failure nexus.Failure) (*commonpb.Payloads, error) { if len(failure.Metadata) == 0 && len(failure.Details) == 0 { return nil, nil } - // Delete before serializing. - failure.Message = "" - data, err := json.Marshal(failure) + // Only serialize metadata and details. + cpy := nexus.Failure{ + Metadata: failure.Metadata, + Details: failure.Details, + } + data, err := json.Marshal(cpy) if err != nil { return nil, err } diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index 126c1f4815..8c3045558f 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -281,21 +281,29 @@ func (c *HTTPClient) StartOperation( Links: links, }, nil case statusOperationFailed: - state, err := getUnsuccessfulStateFromHeader(response, body) + failure, err := c.failureFromResponse(response, body) if err != nil { return nil, err } - failure, err := c.failureFromResponse(response, body) + opErr, err := c.options.FailureConverter.FailureToError(failure) if err != nil { return nil, err } - failureErr := c.options.FailureConverter.FailureToError(failure) - return nil, &nexus.OperationError{ - State: state, - Cause: failureErr, + // For compatibility with older servers. + if _, ok := opErr.(*nexus.OperationError); !ok { + state, err := getUnsuccessfulStateFromHeader(response, body) + if err != nil { + return nil, err + } + opErr = &nexus.OperationError{ + State: state, + Cause: opErr, + } } + + return nil, opErr default: return nil, c.bestEffortHandlerErrorFromResponse(response, body) } @@ -355,90 +363,63 @@ func (c *HTTPClient) failureFromResponse(response *http.Response, body []byte) ( return failure, err } -func (c *HTTPClient) failureFromResponseOrDefault(response *http.Response, body []byte, defaultMessage string) nexus.Failure { - failure, err := c.failureFromResponse(response, body) +func (c *HTTPClient) defaultErrorFromResponse(response *http.Response, body []byte, cause error) error { + errorType, err := httpStatusCodeToHandlerErrorType(response) if err != nil { - failure.Message = defaultMessage + // TODO: use the provided cause, it's already a deserialized failure. + return newUnexpectedResponseError(err.Error(), response, body) + } + return &nexus.HandlerError{ + Type: errorType, + Message: response.Status, + // For compatibility with older servers. + RetryBehavior: retryBehaviorFromHeader(response.Header), + Cause: cause, } - return failure } -func (c *HTTPClient) failureErrorFromResponseOrDefault(response *http.Response, body []byte, defaultMessage string) error { - failure := c.failureFromResponseOrDefault(response, body, defaultMessage) - failureErr := c.options.FailureConverter.FailureToError(failure) - return failureErr +func (c *HTTPClient) bestEffortHandlerErrorFromResponse(response *http.Response, body []byte) error { + // TODO: support old servers + failure, err := c.failureFromResponse(response, body) + if err != nil { + return c.defaultErrorFromResponse(response, body, nil) + } + convErr, err := c.options.FailureConverter.FailureToError(failure) + if err != nil { + return fmt.Errorf("failed to convert Failure to error: %w", err) + } + if _, ok := convErr.(*nexus.HandlerError); !ok { + convErr = c.defaultErrorFromResponse(response, body, convErr) + } + return convErr } -func (c *HTTPClient) bestEffortHandlerErrorFromResponse(response *http.Response, body []byte) error { +func httpStatusCodeToHandlerErrorType(response *http.Response) (nexus.HandlerErrorType, error) { switch response.StatusCode { case http.StatusBadRequest: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeBadRequest, - Cause: c.failureErrorFromResponseOrDefault(response, body, "bad request"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } - case http.StatusUnauthorized: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnauthenticated, - Cause: c.failureErrorFromResponseOrDefault(response, body, "unauthenticated"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeBadRequest, nil case http.StatusRequestTimeout: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeRequestTimeout, - Cause: c.failureErrorFromResponseOrDefault(response, body, "request timeout"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeRequestTimeout, nil case http.StatusConflict: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeConflict, - Cause: c.failureErrorFromResponseOrDefault(response, body, "conflict"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeConflict, nil + case http.StatusUnauthorized: + return nexus.HandlerErrorTypeUnauthenticated, nil case http.StatusForbidden: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnauthorized, - Cause: c.failureErrorFromResponseOrDefault(response, body, "unauthorized"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeUnauthorized, nil case http.StatusNotFound: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeNotFound, - Cause: c.failureErrorFromResponseOrDefault(response, body, "not found"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeNotFound, nil case http.StatusTooManyRequests: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeResourceExhausted, - Cause: c.failureErrorFromResponseOrDefault(response, body, "resource exhausted"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeResourceExhausted, nil case http.StatusInternalServerError: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeInternal, - Cause: c.failureErrorFromResponseOrDefault(response, body, "internal error"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeInternal, nil case http.StatusNotImplemented: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeNotImplemented, - Cause: c.failureErrorFromResponseOrDefault(response, body, "not implemented"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeNotImplemented, nil case http.StatusServiceUnavailable: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnavailable, - Cause: c.failureErrorFromResponseOrDefault(response, body, "unavailable"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeUnavailable, nil case nexus.StatusUpstreamTimeout: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUpstreamTimeout, - Cause: c.failureErrorFromResponseOrDefault(response, body, "upstream timeout"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeUpstreamTimeout, nil default: - return newUnexpectedResponseError(fmt.Sprintf("unexpected response status: %q", response.Status), response, body) + return "", fmt.Errorf("unexpected response status: %q", response.Status) } } diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index 9c5a49df64..691b09e722 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -165,6 +165,9 @@ type OperationCompletionUnsuccessful struct { type OperationCompletionUnsuccessfulOptions struct { // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to // [DefaultFailureConverter]. + // + // NOTE: To call server versions <= 0.4.0, use a FailureConverter that unwraps the error cause if message is not + // present. FailureConverter nexus.FailureConverter // OperationID is the unique ID for this operation. Used when a completion callback is received before a started response. // @@ -186,7 +189,10 @@ func NewOperationCompletionUnsuccessful(opErr *nexus.OperationError, options Ope if options.FailureConverter == nil { options.FailureConverter = nexus.DefaultFailureConverter() } - failure := options.FailureConverter.ErrorToFailure(opErr.Cause) + failure, err := options.FailureConverter.ErrorToFailure(opErr) + if err != nil { + return nil, err + } return &OperationCompletionUnsuccessful{ Header: make(nexus.Header), @@ -206,6 +212,8 @@ func (c *OperationCompletionUnsuccessful) applyToHTTPRequest(request *http.Reque if c.Header != nil { addNexusHeaderToHTTPHeader(c.Header, request.Header) } + + // Set the operation state header for backwards compatibility. request.Header.Set(headerOperationState, string(c.State)) request.Header.Set("Content-Type", contentTypeJSON) @@ -322,7 +330,11 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to read Failure from request body")) return } - completion.Error = h.failureConverter.FailureToError(failure) + completion.Error, err = h.failureConverter.FailureToError(failure) + if err != nil { + h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode failure from request body")) + return + } case nexus.OperationStateSucceeded: completion.Result = nexus.NewLazyValue( h.options.Serializer, diff --git a/common/nexus/nexusrpc/completion_test.go b/common/nexus/nexusrpc/completion_test.go index 4436634430..a566705749 100644 --- a/common/nexus/nexusrpc/completion_test.go +++ b/common/nexus/nexusrpc/completion_test.go @@ -178,10 +178,10 @@ func TestFailureCompletion(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &failureExpectingCompletionHandler{ errorChecker: func(err error) error { - if err.Error() != "expected message" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure: %v", err) + if opErr, ok := err.(*nexus.OperationError); ok && opErr.Message == "expected message" { + return nil } - return nil + return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure: %v", err) }, expectedStartTime: startTime, expectedCloseTime: closeTime, diff --git a/common/nexus/nexusrpc/server.go b/common/nexus/nexusrpc/server.go index a4b16ae893..8347a00d00 100644 --- a/common/nexus/nexusrpc/server.go +++ b/common/nexus/nexusrpc/server.go @@ -98,7 +98,12 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { if errors.As(err, &opError) { operationState = opError.State - failure = h.failureConverter.ErrorToFailure(opError.Cause) + failure, err = h.failureConverter.ErrorToFailure(opError.Cause) + if err != nil { + h.logger.Error("failed to convert operation error cause to failure", "error", err) + writer.WriteHeader(http.StatusInternalServerError) + return + } statusCode = statusOperationFailed if operationState != nexus.OperationStateFailed && operationState != nexus.OperationStateCanceled { @@ -108,7 +113,12 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { } writer.Header().Set(headerOperationState, string(operationState)) } else if errors.As(err, &handlerError) { - failure = h.failureConverter.ErrorToFailure(handlerError.Cause) + failure, err = h.failureConverter.ErrorToFailure(handlerError.Cause) + if err != nil { + h.logger.Error("failed to convert handler error cause to failure", "error", err) + writer.WriteHeader(http.StatusInternalServerError) + return + } switch handlerError.Type { case nexus.HandlerErrorTypeBadRequest: statusCode = http.StatusBadRequest diff --git a/common/nexus/nexusrpc/setup_test.go b/common/nexus/nexusrpc/setup_test.go index 63c65e42b6..58f7f5bdf7 100644 --- a/common/nexus/nexusrpc/setup_test.go +++ b/common/nexus/nexusrpc/setup_test.go @@ -112,19 +112,19 @@ type customFailureConverter struct{} var errCustom = errors.New("custom") // ErrorToFailure implements FailureConverter. -func (c customFailureConverter) ErrorToFailure(err error) nexus.Failure { +func (c customFailureConverter) ErrorToFailure(err error) (nexus.Failure, error) { return nexus.Failure{ Message: err.Error(), Metadata: map[string]string{ "type": "custom", }, - } + }, nil } // FailureToError implements FailureConverter. -func (c customFailureConverter) FailureToError(f nexus.Failure) error { +func (c customFailureConverter) FailureToError(f nexus.Failure) (error, error) { if f.Metadata["type"] != "custom" { - return errors.New(f.Message) + return errors.New(f.Message), nil } - return fmt.Errorf("%w: %s", errCustom, f.Message) + return fmt.Errorf("%w: %s", errCustom, f.Message), nil } diff --git a/components/callbacks/chasm_invocation.go b/components/callbacks/chasm_invocation.go index 87484e2c4c..c5b1355921 100644 --- a/components/callbacks/chasm_invocation.go +++ b/components/callbacks/chasm_invocation.go @@ -114,7 +114,7 @@ func (c chasmInvocation) getHistoryRequest( Completion: completion, } case *nexusrpc.OperationCompletionUnsuccessful: - apiFailure, err := commonnexus.NexusFailureToAPIFailure(op.Failure, true) + apiFailure, err := commonnexus.NexusFailureToAPIFailure(op.Failure) if err != nil { return nil, fmt.Errorf("failed to convert failure type: %v", err) } diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 222a5a07e9..4f3e3c4b1b 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -436,11 +436,11 @@ func (e taskExecutor) saveResult(ctx context.Context, env hsm.Environment, ref h func (e taskExecutor) handleStartOperationError(env hsm.Environment, node *hsm.Node, operation Operation, callErr error) error { var handlerErr *nexus.HandlerError - var opFailedErr *nexus.OperationError + var opErr *nexus.OperationError switch { - case errors.As(callErr, &opFailedErr): - return handleOperationError(node, operation, opFailedErr) + case errors.As(callErr, &opErr): + return handleOperationError(node, operation, opErr) case errors.As(callErr, &handlerErr) && !handlerErr.Retryable(): // The StartOperation request got an unexpected response that is not retryable, fail the operation. // Although Failure is nullable, Nexus SDK is expected to always populate this field @@ -863,44 +863,11 @@ func isDestinationDown(err error) bool { func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) { var handlerErr *nexus.HandlerError if errors.As(callErr, &handlerErr) { - var retryBehavior enumspb.NexusHandlerErrorRetryBehavior - // nolint:exhaustive // unspecified is the default - switch handlerErr.RetryBehavior { - case nexus.HandlerErrorRetryBehaviorRetryable: - retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE - case nexus.HandlerErrorRetryBehaviorNonRetryable: - retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE - } - failure := &failurepb.Failure{ - Message: handlerErr.Error(), - FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ - NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ - Type: string(handlerErr.Type), - RetryBehavior: retryBehavior, - }, - }, - } - var failureError *nexus.FailureError - if errors.As(handlerErr.Cause, &failureError) { - var err error - failure.Cause, err = commonnexus.NexusFailureToAPIFailure(failureError.Failure, retryable) - if err != nil { - return nil, err - } - } else { - cause := handlerErr.Cause - if cause == nil { - cause = errors.New("unknown cause") - } - failure.Cause = &failurepb.Failure{ - Message: cause.Error(), - FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ - ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{}, - }, - } + nf, err := nexus.DefaultFailureConverter().ErrorToFailure(handlerErr) + if err != nil { + return nil, err } - - return failure, nil + return commonnexus.NexusFailureToAPIFailure(nf) } return &failurepb.Failure{ diff --git a/go.mod b/go.mod index a660043654..ca726d34fe 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/lib/pq v1.10.9 github.com/maruel/panicparse/v2 v2.4.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/nexus-rpc/sdk-go v0.5.1 + github.com/nexus-rpc/sdk-go v0.5.2-0.20251205193432-20a8501a0c1b github.com/olekukonko/tablewriter v0.0.5 github.com/olivere/elastic/v7 v7.0.32 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 9a46dfd672..2c5d18c1d0 100644 --- a/go.sum +++ b/go.sum @@ -234,8 +234,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/nexus-rpc/sdk-go v0.5.1 h1:UFYYfoHlQc+Pn9gQpmn9QE7xluewAn2AO1OSkAh7YFU= -github.com/nexus-rpc/sdk-go v0.5.1/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= +github.com/nexus-rpc/sdk-go v0.5.2-0.20251205193432-20a8501a0c1b h1:JOkfj0lBDtxFiHaU0eCSfsM6mI/gc8GCOd5AFNxjAjQ= +github.com/nexus-rpc/sdk-go v0.5.2-0.20251205193432-20a8501a0c1b/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= diff --git a/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto b/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto index da9cbc2bb5..4d18819755 100644 --- a/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto +++ b/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto @@ -9,6 +9,7 @@ import "google/protobuf/timestamp.proto"; import "temporal/api/common/v1/message.proto"; import "temporal/api/deployment/v1/message.proto"; import "temporal/api/enums/v1/task_queue.proto"; +import "temporal/api/failure/v1/message.proto"; import "temporal/api/history/v1/message.proto"; import "temporal/api/taskqueue/v1/message.proto"; import "temporal/api/query/v1/message.proto"; @@ -488,11 +489,13 @@ message DispatchNexusTaskResponse { message Timeout {} oneof outcome { - // Set if the worker's handler failed the nexus task. - temporal.api.nexus.v1.HandlerError handler_error = 1; + // Deprecated. Use failure field instead. + temporal.api.nexus.v1.HandlerError handler_error = 1 [deprecated = true]; // Set if the worker's handler responded successfully to the nexus task. temporal.api.nexus.v1.Response response = 2; Timeout request_timeout = 3; + // Set if the worker's handler failed the nexus task. Must contain a NexusHandlerFailureInfo object. + temporal.api.failure.v1.Failure failure = 4; } } diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index 55bf5f28cb..366c5feb3b 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -444,19 +444,32 @@ func (h *nexusHandler) StartOperation( } // Convert to standard Nexus SDK response. switch t := response.GetOutcome().(type) { + case *matchingservice.DispatchNexusTaskResponse_Failure: + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) + nf, err := commonnexus.APIFailureToNexusFailure(t.Failure) + if err != nil { + oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + he, err := nexus.DefaultFailureConverter().FailureToError(nf) + if err != nil { + oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + // Failure conversions are our fault so only set this after converting the Temporal failure to a HandlerError. + oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + return nil, he + case *matchingservice.DispatchNexusTaskResponse_HandlerError: + // Deprecated case. Replaced with DispatchNexusTaskResponse_Failure oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.HandlerError.GetErrorType())) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - err := h.convertOutcomeToNexusHandlerError(t) return nil, err case *matchingservice.DispatchNexusTaskResponse_RequestTimeout: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_timeout")) - oc.setFailureSource(commonnexus.FailureSourceWorker) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") case *matchingservice.DispatchNexusTaskResponse_Response: @@ -471,7 +484,6 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_AsyncSuccess: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("async_success")) - token := t.AsyncSuccess.GetOperationToken() if token == "" { token = t.AsyncSuccess.GetOperationId() @@ -484,9 +496,7 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_OperationError: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("operation_error")) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - err := &nexus.OperationError{ State: nexus.OperationState(t.OperationError.GetOperationState()), Cause: &nexus.FailureError{ @@ -498,7 +508,6 @@ func (h *nexusHandler) StartOperation( } // This is the worker's fault. oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:EMPTY_OUTCOME")) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") @@ -609,19 +618,32 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, } // Convert to standard Nexus SDK response. switch t := response.GetOutcome().(type) { + case *matchingservice.DispatchNexusTaskResponse_Failure: + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) + nf, err := commonnexus.APIFailureToNexusFailure(t.Failure) + if err != nil { + oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + he, err := nexus.DefaultFailureConverter().FailureToError(nf) + if err != nil { + oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + // Failure conversions are our fault so only set this after converting the Temporal failure to a HandlerError. + oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + return he + case *matchingservice.DispatchNexusTaskResponse_HandlerError: + // Deprecated case. Replaced with DispatchNexusTaskResponse_Failure oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.HandlerError.GetErrorType())) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - err := h.convertOutcomeToNexusHandlerError(t) return err case *matchingservice.DispatchNexusTaskResponse_RequestTimeout: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_timeout")) - oc.setFailureSource(commonnexus.FailureSourceWorker) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") case *matchingservice.DispatchNexusTaskResponse_Response: @@ -630,7 +652,6 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, } // This is the worker's fault. oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:EMPTY_OUTCOME")) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") @@ -738,13 +759,18 @@ func (h *nexusHandler) convertOutcomeToNexusHandlerError(resp *matchingservice.D case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: retryBehavior = nexus.HandlerErrorRetryBehaviorNonRetryable } + + nf := commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()) handlerError := &nexus.HandlerError{ - Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), - Cause: &nexus.FailureError{ - Failure: commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()), - }, + Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), + Message: nf.Message, RetryBehavior: retryBehavior, } + if nf.Cause != nil { + handlerError.Cause = &nexus.FailureError{ + Failure: *nf.Cause, + } + } switch handlerError.Type { case nexus.HandlerErrorTypeUpstreamTimeout, diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 9f51ad6bb1..13293663d4 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -5425,8 +5425,13 @@ func (wh *WorkflowHandler) RespondNexusTaskCompleted(ctx context.Context, reques // doesn't go into workflow history, and the Nexus request caller is unknown, there doesn't seem like there's a // good reason to fail at this point. - if details := request.GetResponse().GetStartOperation().GetOperationError().GetFailure().GetDetails(); details != nil && !json.Valid(details) { - return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") + if opErr := request.GetResponse().GetStartOperation().GetOperationError(); opErr != nil { + if details := opErr.GetFailure().GetDetails(); details != nil && !json.Valid(details) { + return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") + } + } + if f := request.GetResponse().GetStartOperation().GetFailure(); f != nil && f.GetNexusSdkOperationFailureInfo() == nil { + return nil, serviceerror.NewInvalidArgument("request StartOperation Failure must contain failure with NexusSdkOperationFailureInfo") } matchingRequest := &matchingservice.RespondNexusTaskCompletedRequest{ @@ -5466,8 +5471,16 @@ func (wh *WorkflowHandler) RespondNexusTaskFailed(ctx context.Context, request * } namespaceId := namespace.ID(tt.GetNamespaceId()) - if details := request.GetError().GetFailure().GetDetails(); details != nil && !json.Valid(details) { - return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") + if request.Error == nil && request.Failure == nil { + return nil, serviceerror.NewInvalidArgument("request must contain error or failure") + } + if request.GetError() != nil { + if details := request.GetError().GetFailure().GetDetails(); details != nil && !json.Valid(details) { + return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") + } + } + if request.GetFailure() != nil && request.GetFailure().GetNexusHandlerFailureInfo() == nil { + return nil, serviceerror.NewInvalidArgument("request Failure must contain error or failure with NexusHandlerFailureInfo") } // NOTE: Not checking blob size limit here as we already enforce the 4 MB gRPC request limit and since this From bc39bbc67fc00419b1d9dbe6fcc5d4ffd51e46ad Mon Sep 17 00:00:00 2001 From: PJ Doerner Date: Tue, 16 Dec 2025 10:42:21 -0800 Subject: [PATCH 02/26] Upgrade SDK dep to 1.38 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index ca726d34fe..3ddedf3faa 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,7 @@ require ( go.opentelemetry.io/otel/sdk v1.34.0 go.opentelemetry.io/otel/sdk/metric v1.34.0 go.opentelemetry.io/otel/trace v1.34.0 - go.temporal.io/api v1.61.1-0.20260123144430-3418f5100388 + go.temporal.io/api v1.59.1-0.20251210194227-b7eb0896bf95 go.temporal.io/sdk v1.38.0 go.uber.org/fx v1.24.0 go.uber.org/mock v0.6.0 diff --git a/go.sum b/go.sum index 2c5d18c1d0..4f1ad52761 100644 --- a/go.sum +++ b/go.sum @@ -371,8 +371,8 @@ go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4= go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4= -go.temporal.io/api v1.61.1-0.20260123144430-3418f5100388 h1:Rahqpgjqalbv28RLoOtnNNZvwtnes/sQP0+cisO70Hw= -go.temporal.io/api v1.61.1-0.20260123144430-3418f5100388/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= +go.temporal.io/api v1.59.1-0.20251210194227-b7eb0896bf95 h1:G7ndB6smLBO+JWNs7bBmNdfMHRrhCab7kM19DH1qZBs= +go.temporal.io/api v1.59.1-0.20251210194227-b7eb0896bf95/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= go.temporal.io/sdk v1.38.0 h1:4Bok5LEdED7YKpsSjIa3dDqram5VOq+ydBf4pyx0Wo4= go.temporal.io/sdk v1.38.0/go.mod h1:a+R2Ej28ObvHoILbHaxMyind7M6D+W0L7edt5UJF4SE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= From e0ebac6707c1f031c35551f571d806596cccd021 Mon Sep 17 00:00:00 2001 From: PJ Doerner Date: Tue, 16 Dec 2025 11:13:54 -0800 Subject: [PATCH 03/26] Set capabilities --- service/frontend/nexus_handler.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index 366c5feb3b..cb8c53491e 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -414,6 +414,9 @@ func (h *nexusHandler) StartOperation( Variant: &nexuspb.Request_StartOperation{ StartOperation: &startOperationRequest, }, + Capabilities: &nexuspb.Request_Capabilities{ + TemporalFailureResponses: true, + }, }) if err := oc.interceptRequest(ctx, request, options.Header); err != nil { @@ -457,7 +460,7 @@ func (h *nexusHandler) StartOperation( return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } // Failure conversions are our fault so only set this after converting the Temporal failure to a HandlerError. - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + oc.setFailureSource(commonnexus.FailureSourceWorker) return nil, he case *matchingservice.DispatchNexusTaskResponse_HandlerError: @@ -599,6 +602,9 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, OperationId: token, }, }, + Capabilities: &nexuspb.Request_Capabilities{ + TemporalFailureResponses: true, + }, }) if err := oc.interceptRequest(ctx, request, options.Header); err != nil { var notActiveErr *serviceerror.NamespaceNotActive @@ -631,7 +637,7 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } // Failure conversions are our fault so only set this after converting the Temporal failure to a HandlerError. - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + oc.setFailureSource(commonnexus.FailureSourceWorker) return he case *matchingservice.DispatchNexusTaskResponse_HandlerError: From 6a6ea6e1070b82eaee6c8cf6b145b2d618132b62 Mon Sep 17 00:00:00 2001 From: PJ Doerner Date: Tue, 16 Dec 2025 11:44:00 -0800 Subject: [PATCH 04/26] fix writeFailure --- common/nexus/nexusrpc/server.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/common/nexus/nexusrpc/server.go b/common/nexus/nexusrpc/server.go index 8347a00d00..7f62f29fcc 100644 --- a/common/nexus/nexusrpc/server.go +++ b/common/nexus/nexusrpc/server.go @@ -98,11 +98,11 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { if errors.As(err, &opError) { operationState = opError.State - failure, err = h.failureConverter.ErrorToFailure(opError.Cause) - if err != nil { - h.logger.Error("failed to convert operation error cause to failure", "error", err) + var convErr error + failure, convErr = h.failureConverter.ErrorToFailure(opError) + if convErr != nil { + h.logger.Error("failed to convert operation error to failure", "error", convErr) writer.WriteHeader(http.StatusInternalServerError) - return } statusCode = statusOperationFailed @@ -113,11 +113,11 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { } writer.Header().Set(headerOperationState, string(operationState)) } else if errors.As(err, &handlerError) { - failure, err = h.failureConverter.ErrorToFailure(handlerError.Cause) - if err != nil { - h.logger.Error("failed to convert handler error cause to failure", "error", err) + var convErr error + failure, convErr = h.failureConverter.ErrorToFailure(handlerError) + if convErr != nil { + h.logger.Error("failed to convert handler error to failure", "error", convErr) writer.WriteHeader(http.StatusInternalServerError) - return } switch handlerError.Type { case nexus.HandlerErrorTypeBadRequest: From 273c5eb38d43d2cb124bd97845c5acaab10b068a Mon Sep 17 00:00:00 2001 From: PJ Doerner Date: Wed, 17 Dec 2025 09:24:59 -0800 Subject: [PATCH 05/26] Bug fixes and test updates --- common/nexus/failure.go | 28 ++++++++++++++----- components/nexusoperations/completion.go | 2 +- components/nexusoperations/executors.go | 9 +++++- components/nexusoperations/executors_test.go | 16 +++++------ .../nexusoperations/frontend/handler.go | 21 +++++++++----- go.mod | 2 +- go.sum | 4 +-- service/frontend/nexus_handler.go | 13 +++------ service/matching/matching_engine.go | 10 +++++-- tests/nexus_api_test.go | 26 ++++++++++------- tests/xdc/nexus_request_forwarding_test.go | 8 ++++-- 11 files changed, 89 insertions(+), 50 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 008af7a892..b111fd98b5 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -142,8 +142,11 @@ func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { State: se.State, }, } - if err := protojson.Unmarshal([]byte(se.EncodedAttributes), apiFailure.EncodedAttributes); err != nil { - return nil, fmt.Errorf("failed to deserialize OperationError attributes: %w", err) + if len(se.EncodedAttributes) > 0 { + apiFailure.EncodedAttributes = &commonpb.Payload{} + if err := protojson.Unmarshal([]byte(se.EncodedAttributes), apiFailure.EncodedAttributes); err != nil { + return nil, fmt.Errorf("failed to deserialize OperationError attributes: %w", err) + } } if f.Cause != nil { apiFailure.Cause, err = NexusFailureToAPIFailure(*f.Cause) @@ -172,8 +175,11 @@ func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { RetryBehavior: retryBehavior, }, } - if err := protojson.Unmarshal([]byte(se.EncodedAttributes), apiFailure.EncodedAttributes); err != nil { - return nil, fmt.Errorf("failed to deserialize HandlerError attributes: %w", err) + if len(se.EncodedAttributes) > 0 { + apiFailure.EncodedAttributes = &commonpb.Payload{} + if err := protojson.Unmarshal([]byte(se.EncodedAttributes), apiFailure.EncodedAttributes); err != nil { + return nil, fmt.Errorf("failed to deserialize HandlerError attributes: %w", err) + } } if f.Cause != nil { apiFailure.Cause, err = NexusFailureToAPIFailure(*f.Cause) @@ -199,9 +205,10 @@ func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { return apiFailure, nil } -func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Failure, error) { +func OperationErrorToTemporalFailure(opErr *nexus.OperationError, retryable bool) (*failurepb.Failure, error) { var nexusFailure nexus.Failure - failureErr, ok := opErr.Cause.(*nexus.FailureError) + var failureErr *nexus.FailureError + ok := errors.As(opErr.Cause, &failureErr) if ok { nexusFailure = failureErr.Failure } else if opErr.Cause != nil { @@ -236,7 +243,14 @@ func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Fa }, nil } - return NexusFailureToAPIFailure(nexusFailure) + f, err := NexusFailureToAPIFailure(nexusFailure) + if err != nil { + return nil, err + } + if f.GetApplicationFailureInfo() != nil { + f.GetApplicationFailureInfo().NonRetryable = !retryable + } + return f, nil } func nexusFailureMetadataToPayloads(failure nexus.Failure) (*commonpb.Payloads, error) { diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index 90d59bbecf..e7ffc5b356 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -49,7 +49,7 @@ func handleOperationError( if err != nil { return err } - failure, err := commonnexus.OperationErrorToTemporalFailure(opFailedError) + failure, err := commonnexus.OperationErrorToTemporalFailure(opFailedError, false) if err != nil { return err } diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 4f3e3c4b1b..4380f7e657 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -867,7 +867,14 @@ func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) if err != nil { return nil, err } - return commonnexus.NexusFailureToAPIFailure(nf) + f, err := commonnexus.NexusFailureToAPIFailure(nf) + if err != nil { + return nil, err + } + if f.GetApplicationFailureInfo() != nil { + f.GetApplicationFailureInfo().NonRetryable = !retryable + } + return f, nil } return &failurepb.Failure{ diff --git a/components/nexusoperations/executors_test.go b/components/nexusoperations/executors_test.go index 95dcb3f585..989743a922 100644 --- a/components/nexusoperations/executors_test.go +++ b/components/nexusoperations/executors_test.go @@ -288,7 +288,7 @@ func TestProcessInvocationTask(t *testing.T) { checkOutcome: func(t *testing.T, op nexusoperations.Operation, events []*historypb.HistoryEvent) { require.Equal(t, enumsspb.NEXUS_OPERATION_STATE_BACKING_OFF, op.State()) require.NotNil(t, op.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (INTERNAL): internal server error", op.LastAttemptFailure.Message) + require.Equal(t, "internal server error", op.LastAttemptFailure.Message) require.Equal(t, 0, len(events)) }, }, @@ -352,7 +352,7 @@ func TestProcessInvocationTask(t *testing.T) { require.Equal(t, 1, len(events)) failure := events[0].GetNexusOperationFailedEventAttributes().Failure.Cause require.NotNil(t, failure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (NOT_FOUND): endpoint not registered", failure.Message) + require.Equal(t, "endpoint not registered", failure.Message) }, }, { @@ -366,9 +366,7 @@ func TestProcessInvocationTask(t *testing.T) { require.Equal(t, 1, len(events)) failure := events[0].GetNexusOperationFailedEventAttributes().Failure.Cause require.NotNil(t, failure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (NOT_FOUND): endpoint not registered", failure.Message) - require.NotNil(t, failure.Cause.GetApplicationFailureInfo()) - require.Equal(t, "endpoint not registered", failure.Cause.Message) + require.Equal(t, "endpoint not registered", failure.Message) }, }, { @@ -654,7 +652,9 @@ func TestProcessCancelationTask(t *testing.T) { checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_FAILED, c.State()) require.NotNil(t, c.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (INTERNAL): operation not found", c.LastAttemptFailure.Message) + require.Equal(t, "500 Internal Server Error", c.LastAttemptFailure.Message) + require.NotNil(t, c.LastAttemptFailure.Cause) + require.Equal(t, "operation not found", c.LastAttemptFailure.Cause.Message) }, }, { @@ -698,7 +698,7 @@ func TestProcessCancelationTask(t *testing.T) { checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_BACKING_OFF, c.State()) require.NotNil(t, c.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (INTERNAL): internal server error", c.LastAttemptFailure.Message) + require.Equal(t, "internal server error", c.LastAttemptFailure.Message) }, }, { @@ -738,7 +738,7 @@ func TestProcessCancelationTask(t *testing.T) { checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_FAILED, c.State()) require.NotNil(t, c.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (NOT_FOUND): endpoint not registered", c.LastAttemptFailure.Message) + require.Equal(t, "endpoint not registered", c.LastAttemptFailure.Message) }, }, } diff --git a/components/nexusoperations/frontend/handler.go b/components/nexusoperations/frontend/handler.go index 843274ccfc..988c4870da 100644 --- a/components/nexusoperations/frontend/handler.go +++ b/components/nexusoperations/frontend/handler.go @@ -201,16 +201,23 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C } switch r.State { // nolint:exhaustive case nexus.OperationStateFailed, nexus.OperationStateCanceled: - failureErr, ok := r.Error.(*nexus.FailureError) - if !ok { + var failureErr *nexus.FailureError + var operationErr *nexus.OperationError + switch { + case errors.As(r.Error, &failureErr): + hr.Outcome = &historyservice.CompleteNexusOperationRequest_Failure{ + Failure: commonnexus.NexusFailureToProtoFailure(failureErr.Failure), + } + case errors.As(r.Error, &operationErr): + hr.Outcome = &historyservice.CompleteNexusOperationRequest_Failure{ + Failure: commonnexus.NexusFailureToProtoFailure(*operationErr.OriginalFailure), + } + default: // This shouldn't happen as the Nexus SDK is always expected to convert Failures from the wire to - // FailureErrors. - logger.Error("result error is not a FailureError", tag.Error(err)) + // FailureError or OperationErrors. + logger.Error("result error is not an OperationError or FailureError", tag.Error(err)) return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") } - hr.Outcome = &historyservice.CompleteNexusOperationRequest_Failure{ - Failure: commonnexus.NexusFailureToProtoFailure(failureErr.Failure), - } case nexus.OperationStateSucceeded: var result *commonpb.Payload if err := r.Result.Consume(&result); err != nil { diff --git a/go.mod b/go.mod index 3ddedf3faa..cdf97f9f0f 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/lib/pq v1.10.9 github.com/maruel/panicparse/v2 v2.4.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/nexus-rpc/sdk-go v0.5.2-0.20251205193432-20a8501a0c1b + github.com/nexus-rpc/sdk-go v0.5.2-0.20251217172131-63a8027ef960 github.com/olekukonko/tablewriter v0.0.5 github.com/olivere/elastic/v7 v7.0.32 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 4f1ad52761..bdf2fcd825 100644 --- a/go.sum +++ b/go.sum @@ -234,8 +234,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/nexus-rpc/sdk-go v0.5.2-0.20251205193432-20a8501a0c1b h1:JOkfj0lBDtxFiHaU0eCSfsM6mI/gc8GCOd5AFNxjAjQ= -github.com/nexus-rpc/sdk-go v0.5.2-0.20251205193432-20a8501a0c1b/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= +github.com/nexus-rpc/sdk-go v0.5.2-0.20251217172131-63a8027ef960 h1:ljAYqlX3IFBf7zqF8JGAgn21k7PBq4qyS8d45LcLDmQ= +github.com/nexus-rpc/sdk-go v0.5.2-0.20251217172131-63a8027ef960/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index cb8c53491e..536a5b4189 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -765,18 +765,13 @@ func (h *nexusHandler) convertOutcomeToNexusHandlerError(resp *matchingservice.D case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: retryBehavior = nexus.HandlerErrorRetryBehaviorNonRetryable } - - nf := commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()) handlerError := &nexus.HandlerError{ - Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), - Message: nf.Message, + Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), + Cause: &nexus.FailureError{ + Failure: commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()), + }, RetryBehavior: retryBehavior, } - if nf.Cause != nil { - handlerError.Cause = &nexus.FailureError{ - Failure: *nf.Cause, - } - } switch handlerError.Type { case nexus.HandlerErrorTypeUpstreamTimeout, diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index 445fb5e8f9..c20cfdd0cb 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -2405,8 +2405,14 @@ func (e *matchingEngineImpl) DispatchNexusTask(ctx context.Context, request *mat return nil, result.internalError } if result.failedWorkerResponse != nil { - return &matchingservice.DispatchNexusTaskResponse{Outcome: &matchingservice.DispatchNexusTaskResponse_HandlerError{ - HandlerError: result.failedWorkerResponse.GetRequest().GetError(), + if result.failedWorkerResponse.GetRequest().GetError() != nil { + // Deprecated case. Kept for backwards-compatibility with older SDKs that are sending errors instead of failures. + return &matchingservice.DispatchNexusTaskResponse{Outcome: &matchingservice.DispatchNexusTaskResponse_HandlerError{ + HandlerError: result.failedWorkerResponse.GetRequest().GetError(), + }}, nil + } + return &matchingservice.DispatchNexusTaskResponse{Outcome: &matchingservice.DispatchNexusTaskResponse_Failure{ + Failure: result.failedWorkerResponse.GetRequest().GetFailure(), }}, nil } diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index 5bca780c1e..4eefd0ddf3 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -203,6 +203,8 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, nexus.HandlerErrorRetryBehaviorUnspecified, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) + require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.NotNil(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, }, @@ -223,6 +225,8 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, nexus.HandlerErrorRetryBehaviorNonRetryable, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) + require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.NotNil(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, }, @@ -248,7 +252,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUpstreamTimeout, handlerErr.Type) - require.Equal(t, "upstream timeout", handlerErr.Cause.Error()) + require.Equal(t, "upstream timeout", handlerErr.Message) }, }, } @@ -343,7 +347,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_Na var handlerError *nexus.HandlerError s.ErrorAs(err, &handlerError) s.Equal(nexus.HandlerErrorTypeNotFound, handlerError.Type) - s.Equal(fmt.Sprintf("namespace not found: %q", namespace), handlerError.Cause.Error()) + s.Equal(fmt.Sprintf("namespace not found: %q", namespace), handlerError.Message) snap := capture.Snapshot() @@ -406,7 +410,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { }, checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied: unauthorized in test", handlerErr.Cause.Error()) + require.Equal(t, "permission denied: unauthorized in test", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, @@ -427,7 +431,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { }, checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", handlerErr.Cause.Error()) + require.Equal(t, "permission denied", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, @@ -448,7 +452,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { }, checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", handlerErr.Cause.Error()) + require.Equal(t, "permission denied", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, @@ -469,7 +473,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { }, checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { require.Equal(t, nexus.HandlerErrorTypeUnavailable, handlerErr.Type) - require.Equal(t, "exposed error", handlerErr.Cause.Error()) + require.Equal(t, "exposed error", handlerErr.Message) }, expectedOutcomeMetric: "internal_auth_error", exposeAuthorizerErrors: true, @@ -536,7 +540,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", handlerErr.Cause.Error()) + require.Equal(t, "permission denied", handlerErr.Message) require.Equal(t, 0, len(snap["nexus_request_preprocess_errors"])) }, }, @@ -549,7 +553,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnauthenticated, handlerErr.Type) - require.Equal(t, "unauthorized", handlerErr.Cause.Error()) + require.Equal(t, "401 Unauthorized", handlerErr.Message) require.Equal(t, 1, len(snap["nexus_request_preprocess_errors"])) }, }, @@ -662,7 +666,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_PayloadSizeLimit() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeBadRequest, handlerErr.Type) - require.Equal(t, "input exceeds size limit", handlerErr.Cause.Error()) + require.Equal(t, "input exceeds size limit", handlerErr.Message) } s.T().Run("ByNamespaceAndTaskQueue", func(t *testing.T) { @@ -722,6 +726,8 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) + require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.NotNil(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, }, @@ -741,7 +747,7 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUpstreamTimeout, handlerErr.Type) - require.Equal(t, "upstream timeout", handlerErr.Cause.Error()) + require.Equal(t, "upstream timeout", handlerErr.Message) }, }, } diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go index de8670a14c..d0eef023c4 100644 --- a/tests/xdc/nexus_request_forwarding_test.go +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -191,6 +191,8 @@ func (s *NexusRequestForwardingSuite) TestStartOperationForwardedFromStandbyToAc var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) + require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.NotNil(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "StartNexusOperation", "handler_error:INTERNAL") requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartNexusOperation", "forwarded_request_error") @@ -211,7 +213,7 @@ func (s *NexusRequestForwardingSuite) TestStartOperationForwardedFromStandbyToAc var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnavailable, handlerErr.Type) - require.Equal(t, "cluster inactive", handlerErr.Cause.Error()) + require.Equal(t, "cluster inactive", handlerErr.Message) requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartNexusOperation", "namespace_inactive_forwarding_disabled") }, }, @@ -315,6 +317,8 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) + require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.NotNil(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "CancelNexusOperation", "handler_error:INTERNAL") requireExpectedMetricsCaptured(t, passiveSnap, ns, "CancelNexusOperation", "forwarded_request_error") @@ -335,7 +339,7 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnavailable, handlerErr.Type) - require.Equal(t, "cluster inactive", handlerErr.Cause.Error()) + require.Equal(t, "cluster inactive", handlerErr.Message) requireExpectedMetricsCaptured(t, passiveSnap, ns, "CancelNexusOperation", "namespace_inactive_forwarding_disabled") }, }, From 6dbbfed9f424d92e2abd551e40f131f322b95718 Mon Sep 17 00:00:00 2001 From: PJ Doerner Date: Wed, 7 Jan 2026 10:09:00 -0800 Subject: [PATCH 06/26] lint --- common/nexus/failure.go | 1 + common/nexus/nexusrpc/setup_test.go | 1 + service/frontend/workflow_handler.go | 6 +++--- service/matching/matching_engine.go | 4 ++-- tests/nexus_api_test.go | 6 +++--- tests/xdc/nexus_request_forwarding_test.go | 4 ++-- 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index b111fd98b5..8409fe8a01 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -115,6 +115,7 @@ type serializedHandlerError struct { // NexusFailureToAPIFailure converts a Nexus Failure to an API proto Failure. // If the failure metadata "type" field is set to the fullname of the temporal API Failure message, the failure is // reconstructed using protojson.Unmarshal on the failure details field. +// nolint:revive // cognitive-complexity is high but justified to keep each case together func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { apiFailure := &failurepb.Failure{ Message: f.Message, diff --git a/common/nexus/nexusrpc/setup_test.go b/common/nexus/nexusrpc/setup_test.go index 58f7f5bdf7..b910cff02e 100644 --- a/common/nexus/nexusrpc/setup_test.go +++ b/common/nexus/nexusrpc/setup_test.go @@ -122,6 +122,7 @@ func (c customFailureConverter) ErrorToFailure(err error) (nexus.Failure, error) } // FailureToError implements FailureConverter. +// nolint:revive // unnamed results of the same type is fine for test func (c customFailureConverter) FailureToError(f nexus.Failure) (error, error) { if f.Metadata["type"] != "custom" { return errors.New(f.Message), nil diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 13293663d4..9dc5410337 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -5471,11 +5471,11 @@ func (wh *WorkflowHandler) RespondNexusTaskFailed(ctx context.Context, request * } namespaceId := namespace.ID(tt.GetNamespaceId()) - if request.Error == nil && request.Failure == nil { + if request.Error == nil && request.Failure == nil { // nolint:staticcheck // checking deprecated field for backwards compatibility return nil, serviceerror.NewInvalidArgument("request must contain error or failure") } - if request.GetError() != nil { - if details := request.GetError().GetFailure().GetDetails(); details != nil && !json.Valid(details) { + if request.GetError() != nil { // nolint:staticcheck // checking deprecated field for backwards compatibility + if details := request.GetError().GetFailure().GetDetails(); details != nil && !json.Valid(details) { // nolint:staticcheck // checking deprecated field for backwards compatibility return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") } } diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index c20cfdd0cb..17aea97f7c 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -2405,10 +2405,10 @@ func (e *matchingEngineImpl) DispatchNexusTask(ctx context.Context, request *mat return nil, result.internalError } if result.failedWorkerResponse != nil { - if result.failedWorkerResponse.GetRequest().GetError() != nil { + if result.failedWorkerResponse.GetRequest().GetError() != nil { // nolint:staticcheck // checking deprecated field for backwards compatibility // Deprecated case. Kept for backwards-compatibility with older SDKs that are sending errors instead of failures. return &matchingservice.DispatchNexusTaskResponse{Outcome: &matchingservice.DispatchNexusTaskResponse_HandlerError{ - HandlerError: result.failedWorkerResponse.GetRequest().GetError(), + HandlerError: result.failedWorkerResponse.GetRequest().GetError(), // nolint:staticcheck // checking deprecated field for backwards compatibility }}, nil } return &matchingservice.DispatchNexusTaskResponse{Outcome: &matchingservice.DispatchNexusTaskResponse_Failure{ diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index 4eefd0ddf3..a05fee0e17 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -204,7 +204,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorRetryBehaviorUnspecified, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) require.Equal(t, "500 Internal Server Error", handlerErr.Message) - require.NotNil(t, handlerErr.Cause) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, }, @@ -226,7 +226,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorRetryBehaviorNonRetryable, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) require.Equal(t, "500 Internal Server Error", handlerErr.Message) - require.NotNil(t, handlerErr.Cause) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, }, @@ -727,7 +727,7 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) require.Equal(t, "500 Internal Server Error", handlerErr.Message) - require.NotNil(t, handlerErr.Cause) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, }, diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go index d0eef023c4..ab827f97ac 100644 --- a/tests/xdc/nexus_request_forwarding_test.go +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -192,7 +192,7 @@ func (s *NexusRequestForwardingSuite) TestStartOperationForwardedFromStandbyToAc require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, "500 Internal Server Error", handlerErr.Message) - require.NotNil(t, handlerErr.Cause) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "StartNexusOperation", "handler_error:INTERNAL") requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartNexusOperation", "forwarded_request_error") @@ -318,7 +318,7 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, "500 Internal Server Error", handlerErr.Message) - require.NotNil(t, handlerErr.Cause) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "CancelNexusOperation", "handler_error:INTERNAL") requireExpectedMetricsCaptured(t, passiveSnap, ns, "CancelNexusOperation", "forwarded_request_error") From 781bc00801f0b7cd0591a5707c2b93f7fa037a75 Mon Sep 17 00:00:00 2001 From: PJ Doerner Date: Fri, 16 Jan 2026 11:28:06 -0800 Subject: [PATCH 07/26] Add handling for StartOperationResponse_Failure --- common/nexus/failure.go | 15 ++++++++++--- common/nexus/nexusrpc/client.go | 3 +-- common/nexus/nexusrpc/completion.go | 4 ++-- common/nexus/nexusrpc/server.go | 5 +++++ components/nexusoperations/completion.go | 2 +- service/frontend/nexus_handler.go | 15 +++++++++++++ .../history/workflow/mutable_state_impl.go | 22 ++++++++++++++----- 7 files changed, 52 insertions(+), 14 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 8409fe8a01..06d9580951 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -85,10 +85,18 @@ func APIFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) data, err := protojson.Marshal(failure) failure.Message = message failure.StackTrace = stackTrace - if err != nil { return nexus.Failure{}, err } + + var cause nexus.Failure + if failure.GetCause() != nil { + cause, err = APIFailureToNexusFailure(failure.GetCause()) + if err != nil { + return nexus.Failure{}, err + } + } + return nexus.Failure{ Message: failure.GetMessage(), StackTrace: failure.GetStackTrace(), @@ -96,6 +104,7 @@ func APIFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) "type": failureTypeString, }, Details: data, + Cause: &cause, }, nil } @@ -206,7 +215,7 @@ func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { return apiFailure, nil } -func OperationErrorToTemporalFailure(opErr *nexus.OperationError, retryable bool) (*failurepb.Failure, error) { +func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Failure, error) { var nexusFailure nexus.Failure var failureErr *nexus.FailureError ok := errors.As(opErr.Cause, &failureErr) @@ -249,7 +258,7 @@ func OperationErrorToTemporalFailure(opErr *nexus.OperationError, retryable bool return nil, err } if f.GetApplicationFailureInfo() != nil { - f.GetApplicationFailureInfo().NonRetryable = !retryable + f.GetApplicationFailureInfo().NonRetryable = true } return f, nil } diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index 8c3045558f..cd03e3aa5e 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -379,14 +379,13 @@ func (c *HTTPClient) defaultErrorFromResponse(response *http.Response, body []by } func (c *HTTPClient) bestEffortHandlerErrorFromResponse(response *http.Response, body []byte) error { - // TODO: support old servers failure, err := c.failureFromResponse(response, body) if err != nil { return c.defaultErrorFromResponse(response, body, nil) } convErr, err := c.options.FailureConverter.FailureToError(failure) if err != nil { - return fmt.Errorf("failed to convert Failure to error: %w", err) + return newUnexpectedResponseError(fmt.Sprintf("failed to convert Failure to error: %s", err.Error()), response, body) } if _, ok := convErr.(*nexus.HandlerError); !ok { convErr = c.defaultErrorFromResponse(response, body, convErr) diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index 691b09e722..12efb631cc 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -166,12 +166,12 @@ type OperationCompletionUnsuccessfulOptions struct { // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to // [DefaultFailureConverter]. // - // NOTE: To call server versions <= 0.4.0, use a FailureConverter that unwraps the error cause if message is not + // NOTE: To call server versions <= 1.31.0, use a FailureConverter that unwraps the error cause if message is not // present. FailureConverter nexus.FailureConverter // OperationID is the unique ID for this operation. Used when a completion callback is received before a started response. // - // Deprecated: Use OperatonToken instead. + // Deprecated: Use OperationToken instead. OperationID string // OperationToken is the unique token for this operation. Used when a completion callback is received before a // started response. diff --git a/common/nexus/nexusrpc/server.go b/common/nexus/nexusrpc/server.go index 7f62f29fcc..9119c1a21e 100644 --- a/common/nexus/nexusrpc/server.go +++ b/common/nexus/nexusrpc/server.go @@ -91,6 +91,7 @@ func (h *httpHandler) writeResult(writer http.ResponseWriter, result any) { func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { var failure nexus.Failure + var failureError *nexus.FailureError var opError *nexus.OperationError var handlerError *nexus.HandlerError var operationState nexus.OperationState @@ -103,6 +104,7 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { if convErr != nil { h.logger.Error("failed to convert operation error to failure", "error", convErr) writer.WriteHeader(http.StatusInternalServerError) + return } statusCode = statusOperationFailed @@ -118,6 +120,7 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { if convErr != nil { h.logger.Error("failed to convert handler error to failure", "error", convErr) writer.WriteHeader(http.StatusInternalServerError) + return } switch handlerError.Type { case nexus.HandlerErrorTypeBadRequest: @@ -145,6 +148,8 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { default: h.logger.Error("unexpected handler error type", "type", handlerError.Type) } + } else if errors.As(err, &failureError) { + failure = failureError.Failure } else { failure = nexus.Failure{ Message: "internal server error", diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index e7ffc5b356..90d59bbecf 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -49,7 +49,7 @@ func handleOperationError( if err != nil { return err } - failure, err := commonnexus.OperationErrorToTemporalFailure(opFailedError, false) + failure, err := commonnexus.OperationErrorToTemporalFailure(opFailedError) if err != nil { return err } diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index 536a5b4189..77d2758d82 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -507,6 +507,21 @@ func (h *nexusHandler) StartOperation( }, } return nil, err + + case *nexuspb.StartOperationResponse_Failure: + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("failure")) + nf, err := commonnexus.APIFailureToNexusFailure(t.Failure) + if err != nil { + oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + oe, err := nexus.DefaultFailureConverter().FailureToError(nf) + if err != nil { + oc.logger.Error("error converting Nexus failure to Nexus OperationError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + return nil, oe } } // This is the worker's fault. diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index 7896a61a06..9ff18dd95b 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -782,7 +782,11 @@ func (ms *MutableStateImpl) GetNexusCompletion( return nil, err } return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: f}}, + &nexus.OperationError{ + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: f}, + OriginalFailure: &f, + }, nexusrpc.OperationCompletionUnsuccessfulOptions{ StartTime: ms.executionState.GetStartTime().AsTime(), CloseTime: ce.GetEventTime().AsTime(), @@ -802,8 +806,9 @@ func (ms *MutableStateImpl) GetNexusCompletion( } return nexusrpc.NewOperationCompletionUnsuccessful( &nexus.OperationError{ - State: nexus.OperationStateCanceled, - Cause: &nexus.FailureError{Failure: f}, + State: nexus.OperationStateCanceled, + Cause: &nexus.FailureError{Failure: f}, + OriginalFailure: &f, }, nexusrpc.OperationCompletionUnsuccessfulOptions{ StartTime: ms.executionState.GetStartTime().AsTime(), @@ -821,7 +826,11 @@ func (ms *MutableStateImpl) GetNexusCompletion( return nil, err } return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: f}}, + &nexus.OperationError{ + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: f}, + OriginalFailure: &f, + }, nexusrpc.OperationCompletionUnsuccessfulOptions{ StartTime: ms.executionState.GetStartTime().AsTime(), CloseTime: ce.GetEventTime().AsTime(), @@ -842,8 +851,9 @@ func (ms *MutableStateImpl) GetNexusCompletion( } return nexusrpc.NewOperationCompletionUnsuccessful( &nexus.OperationError{ - State: nexus.OperationStateFailed, - Cause: &nexus.FailureError{Failure: f}, + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: f}, + OriginalFailure: &f, }, nexusrpc.OperationCompletionUnsuccessfulOptions{ StartTime: ms.executionState.GetStartTime().AsTime(), From 699b0ff725fa05b71316ae22a515eed31f046aaa Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Fri, 23 Jan 2026 11:42:54 -0800 Subject: [PATCH 08/26] Upgrade api --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index cdf97f9f0f..4dbe2f9a4a 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,7 @@ require ( go.opentelemetry.io/otel/sdk v1.34.0 go.opentelemetry.io/otel/sdk/metric v1.34.0 go.opentelemetry.io/otel/trace v1.34.0 - go.temporal.io/api v1.59.1-0.20251210194227-b7eb0896bf95 + go.temporal.io/api v1.61.1-0.20260123194132-ee4a47298624 go.temporal.io/sdk v1.38.0 go.uber.org/fx v1.24.0 go.uber.org/mock v0.6.0 diff --git a/go.sum b/go.sum index bdf2fcd825..c998bfdc3c 100644 --- a/go.sum +++ b/go.sum @@ -371,8 +371,8 @@ go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4= go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4= -go.temporal.io/api v1.59.1-0.20251210194227-b7eb0896bf95 h1:G7ndB6smLBO+JWNs7bBmNdfMHRrhCab7kM19DH1qZBs= -go.temporal.io/api v1.59.1-0.20251210194227-b7eb0896bf95/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= +go.temporal.io/api v1.61.1-0.20260123194132-ee4a47298624 h1:3dqSdkDfWugeg8QEf1WsAFhtTyNV6Fj4I6wpD51gtVc= +go.temporal.io/api v1.61.1-0.20260123194132-ee4a47298624/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= go.temporal.io/sdk v1.38.0 h1:4Bok5LEdED7YKpsSjIa3dDqram5VOq+ydBf4pyx0Wo4= go.temporal.io/sdk v1.38.0/go.mod h1:a+R2Ej28ObvHoILbHaxMyind7M6D+W0L7edt5UJF4SE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= From 2c9cb58cddf340cde6431c5d3c8bee5ecae7a69c Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Fri, 23 Jan 2026 13:27:57 -0800 Subject: [PATCH 09/26] Fix conversion from nexus failure to temporal failure --- chasm/lib/callback/chasm_invocation.go | 2 +- common/nexus/failure.go | 196 ++++++++---- common/nexus/failure_test.go | 302 ++++++++++++++++++ components/callbacks/chasm_invocation.go | 2 +- components/nexusoperations/executors.go | 2 +- go.mod | 2 +- go.sum | 2 + service/frontend/nexus_handler.go | 6 +- .../history/workflow/mutable_state_impl.go | 8 +- 9 files changed, 455 insertions(+), 67 deletions(-) create mode 100644 common/nexus/failure_test.go diff --git a/chasm/lib/callback/chasm_invocation.go b/chasm/lib/callback/chasm_invocation.go index 7cd10772a6..35b7389907 100644 --- a/chasm/lib/callback/chasm_invocation.go +++ b/chasm/lib/callback/chasm_invocation.go @@ -162,7 +162,7 @@ func (c chasmInvocation) getHistoryRequest( Completion: completion, } case *nexusrpc.OperationCompletionUnsuccessful: - apiFailure, err := commonnexus.NexusFailureToAPIFailure(op.Failure) + apiFailure, err := commonnexus.NexusFailureToTemporalFailure(op.Failure) if err != nil { return nil, fmt.Errorf("failed to convert failure type: %v", err) } diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 06d9580951..2248141bb7 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -2,6 +2,7 @@ package nexus import ( "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -71,11 +72,111 @@ func NexusFailureToProtoFailure(failure nexus.Failure) *nexuspb.Failure { } } -// APIFailureToNexusFailure converts an API proto Failure to a Nexus SDK Failure setting the metadata "type" field to -// the proto fullname of the temporal API Failure message. +type serializedOperationError struct { + State string `json:"state,omitempty"` + // Bytes as base64 encoded string. + EncodedAttributes string `json:"encodedAttributes,omitempty"` +} + +type serializedHandlerError struct { + Type string `json:"type,omitempty"` + RetryableOverride *bool `json:"retryableOverride,omitempty"` + // Bytes as base64 encoded string. + EncodedAttributes string `json:"encodedAttributes,omitempty"` +} + +// TemporalFailureToNexusFailure converts an API proto Failure to a Nexus SDK Failure setting the metadata "type" field to +// the proto fullname of the temporal API Failure message or the standard Nexus SDK failure types. +// Returns an error if the failure cannot be converted. // Mutates the failure temporarily, unsetting the Message field to avoid duplicating the information in the serialized // failure. Mutating was chosen over cloning for performance reasons since this function may be called frequently. -func APIFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) { +func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) { + var causep *nexus.Failure + if failure.GetCause() != nil { + var cause nexus.Failure + var err error + cause, err = TemporalFailureToNexusFailure(failure.GetCause()) + if err != nil { + return nexus.Failure{}, err + } + causep = &cause + } + + switch info := failure.GetFailureInfo().(type) { + case *failurepb.Failure_NexusSdkOperationFailureInfo: + var encodedAttributes string + if failure.EncodedAttributes != nil { + b, err := protojson.Marshal(failure.EncodedAttributes) + if err != nil { + return nexus.Failure{}, fmt.Errorf("failed to deserialize OperationError attributes: %w", err) + } + encodedAttributes = base64.StdEncoding.EncodeToString(b) + } + operationError := serializedOperationError{ + State: info.NexusSdkOperationFailureInfo.GetState(), + EncodedAttributes: encodedAttributes, + } + + details, err := json.Marshal(operationError) + if err != nil { + return nexus.Failure{}, err + } + return nexus.Failure{ + Message: failure.GetMessage(), + StackTrace: failure.GetStackTrace(), + Metadata: map[string]string{ + "type": "nexus.OperationError", + }, + Details: details, + Cause: causep, + }, nil + case *failurepb.Failure_NexusHandlerFailureInfo: + var encodedAttributes string + if failure.EncodedAttributes != nil { + b, err := protojson.Marshal(failure.EncodedAttributes) + if err != nil { + return nexus.Failure{}, fmt.Errorf("failed to deserialize HandlerError attributes: %w", err) + } + encodedAttributes = base64.StdEncoding.EncodeToString(b) + } + var retryableOverride *bool + switch info.NexusHandlerFailureInfo.GetRetryBehavior() { + case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: + val := true + retryableOverride = &val + case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: + val := false + retryableOverride = &val + } + + handlerError := serializedHandlerError{ + Type: info.NexusHandlerFailureInfo.GetType(), + RetryableOverride: retryableOverride, + EncodedAttributes: encodedAttributes, + } + + details, err := json.Marshal(handlerError) + if err != nil { + return nexus.Failure{}, err + } + return nexus.Failure{ + Message: failure.GetMessage(), + StackTrace: failure.GetStackTrace(), + Metadata: map[string]string{ + "type": "nexus.HandlerError", + }, + Details: details, + Cause: causep, + }, nil + case *failurepb.Failure_NexusSdkFailureErrorInfo: + return nexus.Failure{ + Message: failure.GetMessage(), + StackTrace: failure.GetStackTrace(), + Metadata: info.NexusSdkFailureErrorInfo.GetMetadata(), + Details: info.NexusSdkFailureErrorInfo.GetDetails(), + Cause: causep, + }, nil + } // Unset message and stack trace so it's not serialized in the details. var message string message, failure.Message = failure.Message, "" @@ -89,14 +190,6 @@ func APIFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) return nexus.Failure{}, err } - var cause nexus.Failure - if failure.GetCause() != nil { - cause, err = APIFailureToNexusFailure(failure.GetCause()) - if err != nil { - return nexus.Failure{}, err - } - } - return nexus.Failure{ Message: failure.GetMessage(), StackTrace: failure.GetStackTrace(), @@ -104,28 +197,17 @@ func APIFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) "type": failureTypeString, }, Details: data, - Cause: &cause, + Cause: causep, }, nil } -type serializedOperationError struct { - State string `json:"state,omitempty"` - // Bytes as base64 encoded string. - EncodedAttributes string `json:"encodedAttributes,omitempty"` -} - -type serializedHandlerError struct { - Type string `json:"type,omitempty"` - RetryableOverride *bool `json:"retryableOverride,omitempty"` - // Bytes as base64 encoded string. - EncodedAttributes string `json:"encodedAttributes,omitempty"` -} - -// NexusFailureToAPIFailure converts a Nexus Failure to an API proto Failure. +// NexusFailureToTemporalFailure converts a Nexus Failure to an API proto Failure. // If the failure metadata "type" field is set to the fullname of the temporal API Failure message, the failure is -// reconstructed using protojson.Unmarshal on the failure details field. +// reconstructed using protojson.Unmarshal on the failure details field. Otherwise, the failure is reconstructed +// based on the known Nexus SDK failure types. +// Returns an error if the failure cannot be converted. // nolint:revive // cognitive-complexity is high but justified to keep each case together -func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { +func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) { apiFailure := &failurepb.Failure{ Message: f.Message, StackTrace: f.StackTrace, @@ -140,7 +222,6 @@ func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { // Restore these fields as they are not included in the marshalled failure. apiFailure.Message = f.Message apiFailure.StackTrace = f.StackTrace - return apiFailure, nil case "nexus.OperationError": var se serializedOperationError err := json.Unmarshal(f.Details, &se) @@ -153,18 +234,15 @@ func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { }, } if len(se.EncodedAttributes) > 0 { + decoded, err := base64.StdEncoding.DecodeString(se.EncodedAttributes) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 OperationError attributes: %w", err) + } apiFailure.EncodedAttributes = &commonpb.Payload{} - if err := protojson.Unmarshal([]byte(se.EncodedAttributes), apiFailure.EncodedAttributes); err != nil { + if err := protojson.Unmarshal(decoded, apiFailure.EncodedAttributes); err != nil { return nil, fmt.Errorf("failed to deserialize OperationError attributes: %w", err) } } - if f.Cause != nil { - apiFailure.Cause, err = NexusFailureToAPIFailure(*f.Cause) - if err != nil { - return nil, err - } - } - return apiFailure, nil case "nexus.HandlerError": var se serializedHandlerError err := json.Unmarshal(f.Details, &se) @@ -186,31 +264,37 @@ func NexusFailureToAPIFailure(f nexus.Failure) (*failurepb.Failure, error) { }, } if len(se.EncodedAttributes) > 0 { + decoded, err := base64.StdEncoding.DecodeString(se.EncodedAttributes) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 HandlerError attributes: %w", err) + } apiFailure.EncodedAttributes = &commonpb.Payload{} - if err := protojson.Unmarshal([]byte(se.EncodedAttributes), apiFailure.EncodedAttributes); err != nil { + if err := protojson.Unmarshal(decoded, apiFailure.EncodedAttributes); err != nil { return nil, fmt.Errorf("failed to deserialize HandlerError attributes: %w", err) } } - if f.Cause != nil { - apiFailure.Cause, err = NexusFailureToAPIFailure(*f.Cause) - if err != nil { - return nil, err - } + default: + apiFailure.FailureInfo = &failurepb.Failure_NexusSdkFailureErrorInfo{ + NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ + Metadata: f.Metadata, + Details: f.Details, + }, } - return apiFailure, nil + } + } else if len(f.Details) > 0 { + apiFailure.FailureInfo = &failurepb.Failure_NexusSdkFailureErrorInfo{ + NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ + Details: f.Details, + }, } } - payloads, err := nexusFailureMetadataToPayloads(f) - if err != nil { - return nil, err - } - apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ - ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ - // Make up a type here, it's not part of the Nexus Failure spec. - Type: "NexusFailure", - Details: payloads, - }, + if f.Cause != nil { + var err error + apiFailure.Cause, err = NexusFailureToTemporalFailure(*f.Cause) + if err != nil { + return nil, err + } } return apiFailure, nil } @@ -228,7 +312,7 @@ func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Fa // Canceled must be translated into a CanceledFailure to match the SDK expectation. if opErr.State == nexus.OperationStateCanceled { if nexusFailure.Metadata != nil && nexusFailure.Metadata["type"] == failureTypeString { - temporalFailure, err := NexusFailureToAPIFailure(nexusFailure) + temporalFailure, err := NexusFailureToTemporalFailure(nexusFailure) if err != nil { return nil, err } @@ -253,7 +337,7 @@ func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Fa }, nil } - f, err := NexusFailureToAPIFailure(nexusFailure) + f, err := NexusFailureToTemporalFailure(nexusFailure) if err != nil { return nil, err } diff --git a/common/nexus/failure_test.go b/common/nexus/failure_test.go new file mode 100644 index 0000000000..6f04b2bdc8 --- /dev/null +++ b/common/nexus/failure_test.go @@ -0,0 +1,302 @@ +package nexus + +import ( + "testing" + + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" + failurepb "go.temporal.io/api/failure/v1" + "go.temporal.io/server/common/testing/protorequire" +) + +func TestRoundTrip_ApplicationFailure(t *testing.T) { + original := &failurepb.Failure{ + Message: "application error", + StackTrace: "stack trace here", + EncodedAttributes: mustToPayload(t, "encoded"), + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "CustomError", + NonRetryable: false, + Details: &commonpb.Payloads{ + Payloads: []*commonpb.Payload{mustToPayload(t, "encoded")}, + }, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusSDKOperationFailure_WithoutAttributes(t *testing.T) { + original := &failurepb.Failure{ + Message: "operation failed", + StackTrace: "operation stack trace", + FailureInfo: &failurepb.Failure_NexusSdkOperationFailureInfo{ + NexusSdkOperationFailureInfo: &failurepb.NexusSDKOperationFailureInfo{ + State: "failed", + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusSDKOperationFailure_WithAttributes(t *testing.T) { + original := &failurepb.Failure{ + Message: "operation failed with details", + StackTrace: "operation stack trace", + FailureInfo: &failurepb.Failure_NexusSdkOperationFailureInfo{ + NexusSdkOperationFailureInfo: &failurepb.NexusSDKOperationFailureInfo{ + State: "failed", + }, + }, + EncodedAttributes: &commonpb.Payload{ + Metadata: map[string][]byte{"encoding": []byte("json/plain")}, + Data: []byte(`{"custom":"attribute"}`), + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusHandlerFailure_Retryable(t *testing.T) { + original := &failurepb.Failure{ + Message: "handler error - retryable", + StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "CustomHandlerError", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusHandlerFailure_NonRetryable(t *testing.T) { + original := &failurepb.Failure{ + Message: "handler error - non-retryable", + StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "FatalHandlerError", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusHandlerFailure_Unspecified(t *testing.T) { + original := &failurepb.Failure{ + Message: "handler error - unspecified retry", + StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "UnspecifiedHandlerError", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_UNSPECIFIED, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusHandlerFailure_WithAttributes(t *testing.T) { + original := &failurepb.Failure{ + Message: "handler error with attributes", + StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "ComplexHandlerError", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE, + }, + }, + EncodedAttributes: mustToPayload(t, "encoded attributes"), + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusSDKFailureErrorInfo(t *testing.T) { + original := &failurepb.Failure{ + Message: "sdk failure error", + StackTrace: "sdk stack trace", + FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ + NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ + Metadata: map[string]string{ + "custom-key": "custom-value", + "error-type": "SomeError", + }, + Details: []byte(`{"field":"value"}`), + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_WithNestedCauses(t *testing.T) { + original := &failurepb.Failure{ + Message: "top level failure", + StackTrace: "top stack trace", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "TopLevelError", + }, + }, + Cause: &failurepb.Failure{ + Message: "middle failure", + StackTrace: "middle stack trace", + FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ + TimeoutFailureInfo: &failurepb.TimeoutFailureInfo{ + TimeoutType: enumspb.TIMEOUT_TYPE_START_TO_CLOSE, + }, + }, + Cause: &failurepb.Failure{ + Message: "root cause", + StackTrace: "root stack trace", + FailureInfo: &failurepb.Failure_ServerFailureInfo{ + ServerFailureInfo: &failurepb.ServerFailureInfo{ + NonRetryable: true, + }, + }, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusOperationFailureWithNexusHandlerCause(t *testing.T) { + original := &failurepb.Failure{ + Message: "operation failed", + StackTrace: "operation stack trace", + FailureInfo: &failurepb.Failure_NexusSdkOperationFailureInfo{ + NexusSdkOperationFailureInfo: &failurepb.NexusSDKOperationFailureInfo{ + State: "failed", + }, + }, + Cause: &failurepb.Failure{ + Message: "handler caused the failure", + StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "BadRequest", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE, + }, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_EmptyFailure(t *testing.T) { + original := &failurepb.Failure{ + Message: "simple message", + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_OnlyStackTrace(t *testing.T) { + original := &failurepb.Failure{ + StackTrace: "just a stack trace", + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + protorequire.ProtoEqual(t, original, converted) +} + + +func TestRoundTrip_OnlyDetails(t *testing.T) { + original := &failurepb.Failure{ + FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ + NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ + Details: []byte(`{"only":"details"}`), + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + protorequire.ProtoEqual(t, original, converted) +} diff --git a/components/callbacks/chasm_invocation.go b/components/callbacks/chasm_invocation.go index c5b1355921..6b30ed404b 100644 --- a/components/callbacks/chasm_invocation.go +++ b/components/callbacks/chasm_invocation.go @@ -114,7 +114,7 @@ func (c chasmInvocation) getHistoryRequest( Completion: completion, } case *nexusrpc.OperationCompletionUnsuccessful: - apiFailure, err := commonnexus.NexusFailureToAPIFailure(op.Failure) + apiFailure, err := commonnexus.NexusFailureToTemporalFailure(op.Failure) if err != nil { return nil, fmt.Errorf("failed to convert failure type: %v", err) } diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 4380f7e657..8c072c64d2 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -867,7 +867,7 @@ func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) if err != nil { return nil, err } - f, err := commonnexus.NexusFailureToAPIFailure(nf) + f, err := commonnexus.NexusFailureToTemporalFailure(nf) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index 4dbe2f9a4a..d9e1e6e5cf 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,7 @@ require ( go.opentelemetry.io/otel/sdk v1.34.0 go.opentelemetry.io/otel/sdk/metric v1.34.0 go.opentelemetry.io/otel/trace v1.34.0 - go.temporal.io/api v1.61.1-0.20260123194132-ee4a47298624 + go.temporal.io/api v1.61.1-0.20260123211420-03a0445cb1c7 go.temporal.io/sdk v1.38.0 go.uber.org/fx v1.24.0 go.uber.org/mock v0.6.0 diff --git a/go.sum b/go.sum index c998bfdc3c..b160a64b7c 100644 --- a/go.sum +++ b/go.sum @@ -373,6 +373,8 @@ go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4= go.temporal.io/api v1.61.1-0.20260123194132-ee4a47298624 h1:3dqSdkDfWugeg8QEf1WsAFhtTyNV6Fj4I6wpD51gtVc= go.temporal.io/api v1.61.1-0.20260123194132-ee4a47298624/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= +go.temporal.io/api v1.61.1-0.20260123211420-03a0445cb1c7 h1:v5k7tdCrSAI1zNj7jClndKc/2Sq3C/B5g5NEQmkZrWs= +go.temporal.io/api v1.61.1-0.20260123211420-03a0445cb1c7/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= go.temporal.io/sdk v1.38.0 h1:4Bok5LEdED7YKpsSjIa3dDqram5VOq+ydBf4pyx0Wo4= go.temporal.io/sdk v1.38.0/go.mod h1:a+R2Ej28ObvHoILbHaxMyind7M6D+W0L7edt5UJF4SE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index 77d2758d82..d624bc4a2f 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -449,7 +449,7 @@ func (h *nexusHandler) StartOperation( switch t := response.GetOutcome().(type) { case *matchingservice.DispatchNexusTaskResponse_Failure: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) - nf, err := commonnexus.APIFailureToNexusFailure(t.Failure) + nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") @@ -510,7 +510,7 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_Failure: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("failure")) - nf, err := commonnexus.APIFailureToNexusFailure(t.Failure) + nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") @@ -641,7 +641,7 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, switch t := response.GetOutcome().(type) { case *matchingservice.DispatchNexusTaskResponse_Failure: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) - nf, err := commonnexus.APIFailureToNexusFailure(t.Failure) + nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index 9ff18dd95b..9034c228e9 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -777,7 +777,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( } return completion, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_FAILED: - f, err := commonnexus.APIFailureToNexusFailure(ce.GetWorkflowExecutionFailedEventAttributes().GetFailure()) + f, err := commonnexus.TemporalFailureToNexusFailure(ce.GetWorkflowExecutionFailedEventAttributes().GetFailure()) if err != nil { return nil, err } @@ -793,7 +793,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( Links: []nexus.Link{startLink}, }) case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED: - f, err := commonnexus.APIFailureToNexusFailure(&failurepb.Failure{ + f, err := commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation canceled", FailureInfo: &failurepb.Failure_CanceledFailureInfo{ CanceledFailureInfo: &failurepb.CanceledFailureInfo{ @@ -816,7 +816,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( Links: []nexus.Link{startLink}, }) case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED: - f, err := commonnexus.APIFailureToNexusFailure(&failurepb.Failure{ + f, err := commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation terminated", FailureInfo: &failurepb.Failure_TerminatedFailureInfo{ TerminatedFailureInfo: &failurepb.TerminatedFailureInfo{}, @@ -837,7 +837,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( Links: []nexus.Link{startLink}, }) case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT: - f, err := commonnexus.APIFailureToNexusFailure(&failurepb.Failure{ + f, err := commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation exceeded internal timeout", FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ TimeoutFailureInfo: &failurepb.TimeoutFailureInfo{ From 7671e018cb4db75cde51ae443ca1f72eb7f530c8 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Fri, 23 Jan 2026 20:35:57 -0800 Subject: [PATCH 10/26] Allow preserving more information for operation errors --- common/nexus/failure.go | 104 ++++++------------ common/nexus/nexusrpc/completion.go | 20 +++- components/nexusoperations/completion.go | 6 +- .../nexusoperations/frontend/handler.go | 29 +---- go.mod | 2 +- go.sum | 2 + service/history/handler.go | 23 +++- .../history/workflow/mutable_state_impl.go | 21 ++-- 8 files changed, 93 insertions(+), 114 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 2248141bb7..58a8cf8a57 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -55,21 +55,32 @@ var failureTypeString = string((&failurepb.Failure{}).ProtoReflect().Descriptor( // ProtoFailureToNexusFailure converts a proto Nexus Failure to a Nexus SDK Failure. func ProtoFailureToNexusFailure(failure *nexuspb.Failure) nexus.Failure { - return nexus.Failure{ - Message: failure.GetMessage(), - Metadata: failure.GetMetadata(), - Details: failure.GetDetails(), + nf := nexus.Failure{ + Message: failure.GetMessage(), + StackTrace: failure.GetStackTrace(), + Metadata: failure.GetMetadata(), + Details: failure.GetDetails(), + } + if failure.GetCause() != nil { + cause := ProtoFailureToNexusFailure(failure.GetCause()) + nf.Cause = &cause } + return nf } // NexusFailureToProtoFailure converts a Nexus SDK Failure to a proto Nexus Failure. // Always returns a non-nil value. func NexusFailureToProtoFailure(failure nexus.Failure) *nexuspb.Failure { - return &nexuspb.Failure{ - Message: failure.Message, - Metadata: failure.Metadata, - Details: failure.Details, + pf := &nexuspb.Failure{ + Message: failure.Message, + Metadata: failure.Metadata, + Details: failure.Details, + StackTrace: failure.StackTrace, } + if failure.Cause != nil { + pf.Cause = NexusFailureToProtoFailure(*failure.Cause) + } + return pf } type serializedOperationError struct { @@ -299,77 +310,28 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) return apiFailure, nil } -func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Failure, error) { - var nexusFailure nexus.Failure - var failureErr *nexus.FailureError - ok := errors.As(opErr.Cause, &failureErr) - if ok { - nexusFailure = failureErr.Failure - } else if opErr.Cause != nil { - nexusFailure = nexus.Failure{Message: opErr.Cause.Error()} +func OperationErrorToTemporalFailureCause(opErr *nexus.OperationError) (*failurepb.Failure, error) { + if opErr == nil || opErr.OriginalFailure == nil { + return nil, nil + } + nexusFailure := opErr.OriginalFailure + temporalFailure, err := NexusFailureToTemporalFailure(*nexusFailure) + if err != nil { + return nil, serviceerror.NewInvalidArgument("Malformed failure") } - - // Canceled must be translated into a CanceledFailure to match the SDK expectation. if opErr.State == nexus.OperationStateCanceled { - if nexusFailure.Metadata != nil && nexusFailure.Metadata["type"] == failureTypeString { - temporalFailure, err := NexusFailureToTemporalFailure(nexusFailure) - if err != nil { - return nil, err - } - if temporalFailure.GetCanceledFailureInfo() != nil { - // We already have a CanceledFailure, use it. - return temporalFailure, nil - } - // Fallback to encoding the Nexus failure into a Temporal canceled failure, we expect operations that end up - // as canceled to have a CanceledFailureInfo object. - } - payloads, err := nexusFailureMetadataToPayloads(nexusFailure) - if err != nil { - return nil, err - } return &failurepb.Failure{ - Message: nexusFailure.Message, FailureInfo: &failurepb.Failure_CanceledFailureInfo{ - CanceledFailureInfo: &failurepb.CanceledFailureInfo{ - Details: payloads, - }, + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, }, + // Preserve the original cause. + Cause: temporalFailure, }, nil } - - f, err := NexusFailureToTemporalFailure(nexusFailure) - if err != nil { - return nil, err - } - if f.GetApplicationFailureInfo() != nil { - f.GetApplicationFailureInfo().NonRetryable = true - } - return f, nil -} - -func nexusFailureMetadataToPayloads(failure nexus.Failure) (*commonpb.Payloads, error) { - if len(failure.Metadata) == 0 && len(failure.Details) == 0 { - return nil, nil - } - // Only serialize metadata and details. - cpy := nexus.Failure{ - Metadata: failure.Metadata, - Details: failure.Details, - } - data, err := json.Marshal(cpy) - if err != nil { - return nil, err + if temporalFailure.GetApplicationFailureInfo() != nil { + temporalFailure.GetApplicationFailureInfo().NonRetryable = true } - return &commonpb.Payloads{ - Payloads: []*commonpb.Payload{ - { - Metadata: map[string][]byte{ - "encoding": []byte("json/plain"), - }, - Data: data, - }, - }, - }, err + return temporalFailure, nil } // ConvertGRPCError converts either a serviceerror or a gRPC status error into a Nexus HandlerError if possible. diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index 12efb631cc..591c4c234e 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -257,7 +257,7 @@ type CompletionRequest struct { // Links are used to link back to the operation when a completion callback is received before a started response. Links []nexus.Link // Parsed from request and set if State is failed or canceled. - Error error + Error *nexus.OperationError // Extracted from request and set if State is succeeded. Result *nexus.LazyValue } @@ -330,11 +330,27 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to read Failure from request body")) return } - completion.Error, err = h.failureConverter.FailureToError(failure) + completionErr, err := h.failureConverter.FailureToError(failure) if err != nil { h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode failure from request body")) return } + opErr, ok := completionErr.(*nexus.OperationError) + if !ok { + // Backwards compatibility: wrap non-OperationError errors in an OperationError with the appropriate state. + completion.Error = &nexus.OperationError{ + State: completion.State, + Cause: completionErr, + } + originalFailure, err := h.failureConverter.ErrorToFailure(completion.Error) + if err != nil { + h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode failure from request body")) + return + } + completion.Error.OriginalFailure = &originalFailure + } else { + completion.Error = opErr + } case nexus.OperationStateSucceeded: completion.Result = nexus.NewLazyValue( h.options.Serializer, diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index 90d59bbecf..11dcb28619 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -49,7 +49,7 @@ func handleOperationError( if err != nil { return err } - failure, err := commonnexus.OperationErrorToTemporalFailure(opFailedError) + cause, err := commonnexus.OperationErrorToTemporalFailureCause(opFailedError) if err != nil { return err } @@ -61,7 +61,7 @@ func handleOperationError( // nolint:revive e.Attributes = &historypb.HistoryEvent_NexusOperationFailedEventAttributes{ NexusOperationFailedEventAttributes: &historypb.NexusOperationFailedEventAttributes{ - Failure: nexusOperationFailure(operation, eventID, failure), + Failure: nexusOperationFailure(operation, eventID, cause), ScheduledEventId: eventID, RequestId: operation.RequestId, }, @@ -75,7 +75,7 @@ func handleOperationError( // nolint:revive e.Attributes = &historypb.HistoryEvent_NexusOperationCanceledEventAttributes{ NexusOperationCanceledEventAttributes: &historypb.NexusOperationCanceledEventAttributes{ - Failure: nexusOperationFailure(operation, eventID, failure), + Failure: nexusOperationFailure(operation, eventID, cause), ScheduledEventId: eventID, RequestId: operation.RequestId, }, diff --git a/components/nexusoperations/frontend/handler.go b/components/nexusoperations/frontend/handler.go index 988c4870da..6c6bfe7a92 100644 --- a/components/nexusoperations/frontend/handler.go +++ b/components/nexusoperations/frontend/handler.go @@ -201,22 +201,8 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C } switch r.State { // nolint:exhaustive case nexus.OperationStateFailed, nexus.OperationStateCanceled: - var failureErr *nexus.FailureError - var operationErr *nexus.OperationError - switch { - case errors.As(r.Error, &failureErr): - hr.Outcome = &historyservice.CompleteNexusOperationRequest_Failure{ - Failure: commonnexus.NexusFailureToProtoFailure(failureErr.Failure), - } - case errors.As(r.Error, &operationErr): - hr.Outcome = &historyservice.CompleteNexusOperationRequest_Failure{ - Failure: commonnexus.NexusFailureToProtoFailure(*operationErr.OriginalFailure), - } - default: - // This shouldn't happen as the Nexus SDK is always expected to convert Failures from the wire to - // FailureError or OperationErrors. - logger.Error("result error is not an OperationError or FailureError", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") + hr.Outcome = &historyservice.CompleteNexusOperationRequest_Failure{ + Failure: commonnexus.NexusFailureToProtoFailure(*r.Error.OriginalFailure), } case nexus.OperationStateSucceeded: var result *commonpb.Payload @@ -291,19 +277,14 @@ func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nex case nexus.OperationStateFailed, nexus.OperationStateCanceled: // For unsuccessful operations, the Nexus framework reads and closes the original request body to deserialize // the failure, so we must construct a new completion to forward. - var failureErr *nexus.FailureError - if !errors.As(r.Error, &failureErr) { - // This shouldn't happen as the Nexus SDK is always expected to convert Failures from the wire to - // FailureErrors. - h.Logger.Error("received unexpected error type when trying to forward Nexus operation completion", tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } c := &nexusrpc.OperationCompletionUnsuccessful{ State: r.State, OperationToken: r.OperationToken, StartTime: r.StartTime, Links: r.Links, - Failure: failureErr.Failure, + } + if r.Error != nil && r.Error.OriginalFailure != nil { + c.Failure = *r.Error.OriginalFailure } forwardReq, err = nexusrpc.NewCompletionHTTPRequest(ctx, forwardURL, c) if err != nil { diff --git a/go.mod b/go.mod index d9e1e6e5cf..6265c0cec1 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,7 @@ require ( go.opentelemetry.io/otel/sdk v1.34.0 go.opentelemetry.io/otel/sdk/metric v1.34.0 go.opentelemetry.io/otel/trace v1.34.0 - go.temporal.io/api v1.61.1-0.20260123211420-03a0445cb1c7 + go.temporal.io/api v1.61.1-0.20260123235933-4c15176b0e79 go.temporal.io/sdk v1.38.0 go.uber.org/fx v1.24.0 go.uber.org/mock v0.6.0 diff --git a/go.sum b/go.sum index b160a64b7c..ca2455f31a 100644 --- a/go.sum +++ b/go.sum @@ -375,6 +375,8 @@ go.temporal.io/api v1.61.1-0.20260123194132-ee4a47298624 h1:3dqSdkDfWugeg8QEf1Ws go.temporal.io/api v1.61.1-0.20260123194132-ee4a47298624/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= go.temporal.io/api v1.61.1-0.20260123211420-03a0445cb1c7 h1:v5k7tdCrSAI1zNj7jClndKc/2Sq3C/B5g5NEQmkZrWs= go.temporal.io/api v1.61.1-0.20260123211420-03a0445cb1c7/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= +go.temporal.io/api v1.61.1-0.20260123235933-4c15176b0e79 h1:iv07QpG6uAWY67jK1WK/NSwkm2P5pyT2gnh2mqsTyCQ= +go.temporal.io/api v1.61.1-0.20260123235933-4c15176b0e79/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= go.temporal.io/sdk v1.38.0 h1:4Bok5LEdED7YKpsSjIa3dDqram5VOq+ydBf4pyx0Wo4= go.temporal.io/sdk v1.38.0/go.mod h1:a+R2Ej28ObvHoILbHaxMyind7M6D+W0L7edt5UJF4SE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= diff --git a/service/history/handler.go b/service/history/handler.go index 6982f6591d..0259e64736 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -2443,11 +2443,24 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse } var opErr *nexus.OperationError if request.State != string(nexus.OperationStateSucceeded) { - opErr = &nexus.OperationError{ - State: nexus.OperationState(request.GetState()), - Cause: &nexus.FailureError{ - Failure: commonnexus.ProtoFailureToNexusFailure(request.GetFailure()), - }, + failure := commonnexus.ProtoFailureToNexusFailure(request.GetFailure()) + recvdErr, err := nexus.DefaultFailureConverter().FailureToError(failure) + if err != nil { + return nil, serviceerror.NewInvalidArgument("unable to convert failure to error") + } + // Backward compatibility: if the received error is not of type OperationError, wrap the error in OperationError. + var ok bool + if opErr, ok = recvdErr.(*nexus.OperationError); !ok { + opErr = &nexus.OperationError{ + State: nexus.OperationState(request.GetState()), + Message: "operation completed as " + request.GetState(), + Cause: recvdErr, + } + origFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(opErr) + if err != nil { + return nil, serviceerror.NewInvalidArgument("unable to convert operation error to failure") + } + opErr.OriginalFailure = &origFailure } } err = nexusoperations.CompletionHandler( diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index 9034c228e9..83c356657e 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -783,8 +783,9 @@ func (ms *MutableStateImpl) GetNexusCompletion( } return nexusrpc.NewOperationCompletionUnsuccessful( &nexus.OperationError{ - State: nexus.OperationStateFailed, - Cause: &nexus.FailureError{Failure: f}, + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: f}, + // Store the original failure to bypass the Nexus failure converter. OriginalFailure: &f, }, nexusrpc.OperationCompletionUnsuccessfulOptions{ @@ -806,8 +807,9 @@ func (ms *MutableStateImpl) GetNexusCompletion( } return nexusrpc.NewOperationCompletionUnsuccessful( &nexus.OperationError{ - State: nexus.OperationStateCanceled, - Cause: &nexus.FailureError{Failure: f}, + State: nexus.OperationStateCanceled, + Cause: &nexus.FailureError{Failure: f}, + // Store the original failure to bypass the Nexus failure converter. OriginalFailure: &f, }, nexusrpc.OperationCompletionUnsuccessfulOptions{ @@ -826,9 +828,11 @@ func (ms *MutableStateImpl) GetNexusCompletion( return nil, err } return nexusrpc.NewOperationCompletionUnsuccessful( + // NOTE: Not setting a message for compatibility with older servers than don't support both cause and message. &nexus.OperationError{ - State: nexus.OperationStateFailed, - Cause: &nexus.FailureError{Failure: f}, + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: f}, + // Store the original failure to bypass the Nexus failure converter. OriginalFailure: &f, }, nexusrpc.OperationCompletionUnsuccessfulOptions{ @@ -851,8 +855,9 @@ func (ms *MutableStateImpl) GetNexusCompletion( } return nexusrpc.NewOperationCompletionUnsuccessful( &nexus.OperationError{ - State: nexus.OperationStateFailed, - Cause: &nexus.FailureError{Failure: f}, + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: f}, + // Store the original failure to bypass the Nexus failure converter. OriginalFailure: &f, }, nexusrpc.OperationCompletionUnsuccessfulOptions{ From c0eb657acf5ef8970c2908eca05dac56610dd2f8 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Fri, 23 Jan 2026 21:30:23 -0800 Subject: [PATCH 11/26] Fix retryable flag --- components/nexusoperations/executors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 8c072c64d2..7982f34723 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -482,7 +482,7 @@ func handleNonRetryableStartOperationError(node *hsm.Node, operation Operation, if err != nil { return err } - failure, err := callErrToFailure(callErr, true) + failure, err := callErrToFailure(callErr, false) if err != nil { return err } From d64256073b1a377e8572d41d730c1ae7f5e0b8df Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Mon, 26 Jan 2026 08:51:52 -0800 Subject: [PATCH 12/26] Various edge cases and better handling of operation errors --- common/nexus/failure.go | 24 ------------- common/nexus/nexusrpc/client.go | 27 ++++++++++---- components/nexusoperations/completion.go | 36 ++++++++++++++----- components/nexusoperations/executors.go | 46 ++++++++++++++++-------- service/frontend/nexus_handler.go | 37 ++++++------------- service/history/handler.go | 3 +- tests/nexus_workflow_test.go | 26 +++++++------- 7 files changed, 105 insertions(+), 94 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 58a8cf8a57..3b4ae207cd 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -310,30 +310,6 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) return apiFailure, nil } -func OperationErrorToTemporalFailureCause(opErr *nexus.OperationError) (*failurepb.Failure, error) { - if opErr == nil || opErr.OriginalFailure == nil { - return nil, nil - } - nexusFailure := opErr.OriginalFailure - temporalFailure, err := NexusFailureToTemporalFailure(*nexusFailure) - if err != nil { - return nil, serviceerror.NewInvalidArgument("Malformed failure") - } - if opErr.State == nexus.OperationStateCanceled { - return &failurepb.Failure{ - FailureInfo: &failurepb.Failure_CanceledFailureInfo{ - CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, - }, - // Preserve the original cause. - Cause: temporalFailure, - }, nil - } - if temporalFailure.GetApplicationFailureInfo() != nil { - temporalFailure.GetApplicationFailureInfo().NonRetryable = true - } - return temporalFailure, nil -} - // ConvertGRPCError converts either a serviceerror or a gRPC status error into a Nexus HandlerError if possible. // If exposeDetails is true, the error message from the given error is exposed in the converted HandlerError, otherwise, // a default message with minimal information is attached to the returned error. diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index cd03e3aa5e..7648b2e313 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -286,24 +286,31 @@ func (c *HTTPClient) StartOperation( return nil, err } - opErr, err := c.options.FailureConverter.FailureToError(failure) + wireErr, err := c.options.FailureConverter.FailureToError(failure) if err != nil { return nil, err } // For compatibility with older servers. - if _, ok := opErr.(*nexus.OperationError); !ok { + if _, ok := wireErr.(*nexus.OperationError); !ok { state, err := getUnsuccessfulStateFromHeader(response, body) if err != nil { return nil, err } - opErr = &nexus.OperationError{ - State: state, - Cause: opErr, + opErr := &nexus.OperationError{ + State: state, + Message: "operation failed", + Cause: wireErr, } + originalFailure, err := c.options.FailureConverter.ErrorToFailure(wireErr) + if err != nil { + return nil, err + } + opErr.OriginalFailure = &originalFailure + wireErr = opErr } - return nil, opErr + return nil, wireErr default: return nil, c.bestEffortHandlerErrorFromResponse(response, body) } @@ -369,13 +376,19 @@ func (c *HTTPClient) defaultErrorFromResponse(response *http.Response, body []by // TODO: use the provided cause, it's already a deserialized failure. return newUnexpectedResponseError(err.Error(), response, body) } - return &nexus.HandlerError{ + handlerErr := &nexus.HandlerError{ Type: errorType, Message: response.Status, // For compatibility with older servers. RetryBehavior: retryBehaviorFromHeader(response.Header), Cause: cause, } + originalFailure, err := c.options.FailureConverter.ErrorToFailure(handlerErr) + if err != nil { + return newUnexpectedResponseError("failed to construct handler error from response: "+err.Error(), response, body) + } + handlerErr.OriginalFailure = &originalFailure + return handlerErr } func (c *HTTPClient) bestEffortHandlerErrorFromResponse(response *http.Response, body []byte) error { diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index 11dcb28619..38fb1159f6 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -8,6 +8,7 @@ import ( "github.com/nexus-rpc/sdk-go/nexus" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" + failurepb "go.temporal.io/api/failure/v1" historypb "go.temporal.io/api/history/v1" "go.temporal.io/api/serviceerror" commonnexus "go.temporal.io/server/common/nexus" @@ -43,25 +44,30 @@ func handleSuccessfulOperationResult( func handleOperationError( node *hsm.Node, operation Operation, - opFailedError *nexus.OperationError, + opErr *nexus.OperationError, ) error { eventID, err := hsm.EventIDFromToken(operation.ScheduledEventToken) if err != nil { return err } - cause, err := commonnexus.OperationErrorToTemporalFailureCause(opFailedError) - if err != nil { - return err + var originalFailure *failurepb.Failure + if opErr.OriginalFailure != nil { + nexusFailure := opErr.OriginalFailure + originalFailure, err = commonnexus.NexusFailureToTemporalFailure(*nexusFailure) + if err != nil { + return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) + } } - switch opFailedError.State { // nolint:exhaustive + switch opErr.State { // nolint:exhaustive case nexus.OperationStateFailed: event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_FAILED, func(e *historypb.HistoryEvent) { + failure := convertToNexusOperationFailure(operation, eventID, originalFailure) // We must assign to this property, linter doesn't like this. // nolint:revive e.Attributes = &historypb.HistoryEvent_NexusOperationFailedEventAttributes{ NexusOperationFailedEventAttributes: &historypb.NexusOperationFailedEventAttributes{ - Failure: nexusOperationFailure(operation, eventID, cause), + Failure: failure, ScheduledEventId: eventID, RequestId: operation.RequestId, }, @@ -70,12 +76,26 @@ func handleOperationError( return FailedEventDefinition{}.Apply(node.Parent, event) case nexus.OperationStateCanceled: + var originalCause *failurepb.Failure + if originalFailure.GetCause().GetCanceledFailureInfo() != nil { + originalCause = originalFailure.GetCause() + } else { + // Wrap the original failure in a CanceledFailureInfo to indicate cancellation. All workflow commands expected a + // nested CanceledFailure. + originalFailure.Cause = &failurepb.Failure{ + FailureInfo: &failurepb.Failure_CanceledFailureInfo{ + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, + }, + // Preserve the original cause. + Cause: originalCause, + } + } event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCELED, func(e *historypb.HistoryEvent) { // We must assign to this property, linter doesn't like this. // nolint:revive e.Attributes = &historypb.HistoryEvent_NexusOperationCanceledEventAttributes{ NexusOperationCanceledEventAttributes: &historypb.NexusOperationCanceledEventAttributes{ - Failure: nexusOperationFailure(operation, eventID, cause), + Failure: convertToNexusOperationFailure(operation, eventID, originalFailure), ScheduledEventId: eventID, RequestId: operation.RequestId, }, @@ -86,7 +106,7 @@ func handleOperationError( default: // Both the Nexus Client and CompletionHandler reject invalid states, but just in case, we return this as a // transition error. - return fmt.Errorf("unexpected operation state: %v", opFailedError.State) + return fmt.Errorf("unexpected operation state: %v", opErr.State) } } diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 7982f34723..09ada93b78 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -487,10 +487,13 @@ func handleNonRetryableStartOperationError(node *hsm.Node, operation Operation, return err } attrs := &historypb.NexusOperationFailedEventAttributes{ - Failure: nexusOperationFailure( + Failure: convertToNexusOperationFailure( operation, eventID, - failure, + &failurepb.Failure{ + Message: "failed to start Nexus operation", + Cause: failure, + }, ), ScheduledEventId: eventID, RequestId: operation.RequestId, @@ -527,14 +530,17 @@ func (e taskExecutor) recordOperationTimeout(node *hsm.Node) error { // nolint:revive // We must mutate here even if the linter doesn't like it. e.Attributes = &historypb.HistoryEvent_NexusOperationTimedOutEventAttributes{ NexusOperationTimedOutEventAttributes: &historypb.NexusOperationTimedOutEventAttributes{ - Failure: nexusOperationFailure( + Failure: convertToNexusOperationFailure( op, eventID, &failurepb.Failure{ Message: "operation timed out", - FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ - TimeoutFailureInfo: &failurepb.TimeoutFailureInfo{ - TimeoutType: enumspb.TIMEOUT_TYPE_SCHEDULE_TO_CLOSE, + Cause: &failurepb.Failure{ + Message: "operation timed out", + FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ + TimeoutFailureInfo: &failurepb.TimeoutFailureInfo{ + TimeoutType: enumspb.TIMEOUT_TYPE_SCHEDULE_TO_CLOSE, + }, }, }, }, @@ -780,9 +786,13 @@ func (e taskExecutor) lookupEndpoint(ctx context.Context, namespaceID namespace. return entry, nil } -func nexusOperationFailure(operation Operation, scheduledEventID int64, cause *failurepb.Failure) *failurepb.Failure { +// Copy over message and stack trace if present since to preserve as much as possible from the original operation +// error. +func convertToNexusOperationFailure(operation Operation, scheduledEventID int64, originalFailure *failurepb.Failure) *failurepb.Failure { return &failurepb.Failure{ - Message: "nexus operation completed unsuccessfully", + Message: originalFailure.GetMessage(), + StackTrace: originalFailure.GetStackTrace(), + EncodedAttributes: originalFailure.GetEncodedAttributes(), FailureInfo: &failurepb.Failure_NexusOperationExecutionFailureInfo{ NexusOperationExecutionFailureInfo: &failurepb.NexusOperationFailureInfo{ Endpoint: operation.Endpoint, @@ -794,7 +804,7 @@ func nexusOperationFailure(operation Operation, scheduledEventID int64, cause *f ScheduledEventId: scheduledEventID, }, }, - Cause: cause, + Cause: originalFailure.GetCause(), } } @@ -863,17 +873,23 @@ func isDestinationDown(err error) bool { func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) { var handlerErr *nexus.HandlerError if errors.As(callErr, &handlerErr) { - nf, err := nexus.DefaultFailureConverter().ErrorToFailure(handlerErr) - if err != nil { - return nil, err + var nf nexus.Failure + if handlerErr.OriginalFailure != nil { + nf = *handlerErr.OriginalFailure + } else { + var err error + if handlerErr.Message == "" { + handlerErr.Message = "handler error" + } + nf, err = nexus.DefaultFailureConverter().ErrorToFailure(handlerErr) + if err != nil { + return nil, err + } } f, err := commonnexus.NexusFailureToTemporalFailure(nf) if err != nil { return nil, err } - if f.GetApplicationFailureInfo() != nil { - f.GetApplicationFailureInfo().NonRetryable = !retryable - } return f, nil } diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index d624bc4a2f..e4a2a552f8 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -467,7 +467,7 @@ func (h *nexusHandler) StartOperation( // Deprecated case. Replaced with DispatchNexusTaskResponse_Failure oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.HandlerError.GetErrorType())) oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - err := h.convertOutcomeToNexusHandlerError(t) + err := convertOutcomeToNexusHandlerError(t) return nil, err case *matchingservice.DispatchNexusTaskResponse_RequestTimeout: @@ -510,6 +510,7 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_Failure: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("failure")) + oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) @@ -520,7 +521,6 @@ func (h *nexusHandler) StartOperation( oc.logger.Error("error converting Nexus failure to Nexus OperationError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) return nil, oe } } @@ -641,6 +641,7 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, switch t := response.GetOutcome().(type) { case *matchingservice.DispatchNexusTaskResponse_Failure: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) + oc.setFailureSource(commonnexus.FailureSourceWorker) nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) @@ -652,14 +653,13 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } // Failure conversions are our fault so only set this after converting the Temporal failure to a HandlerError. - oc.setFailureSource(commonnexus.FailureSourceWorker) return he case *matchingservice.DispatchNexusTaskResponse_HandlerError: // Deprecated case. Replaced with DispatchNexusTaskResponse_Failure oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.HandlerError.GetErrorType())) oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - err := h.convertOutcomeToNexusHandlerError(t) + err := convertOutcomeToNexusHandlerError(t) return err case *matchingservice.DispatchNexusTaskResponse_RequestTimeout: @@ -771,7 +771,7 @@ func (h *nexusHandler) nexusClientForActiveCluster(oc *operationContext, service }) } -func (h *nexusHandler) convertOutcomeToNexusHandlerError(resp *matchingservice.DispatchNexusTaskResponse_HandlerError) *nexus.HandlerError { +func convertOutcomeToNexusHandlerError(resp *matchingservice.DispatchNexusTaskResponse_HandlerError) *nexus.HandlerError { var retryBehavior nexus.HandlerErrorRetryBehavior // nolint:exhaustive // unspecified is the default switch resp.HandlerError.RetryBehavior { @@ -780,28 +780,11 @@ func (h *nexusHandler) convertOutcomeToNexusHandlerError(resp *matchingservice.D case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: retryBehavior = nexus.HandlerErrorRetryBehaviorNonRetryable } - handlerError := &nexus.HandlerError{ - Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), - Cause: &nexus.FailureError{ - Failure: commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()), - }, - RetryBehavior: retryBehavior, - } - - switch handlerError.Type { - case nexus.HandlerErrorTypeUpstreamTimeout, - nexus.HandlerErrorTypeUnauthenticated, - nexus.HandlerErrorTypeUnauthorized, - nexus.HandlerErrorTypeBadRequest, - nexus.HandlerErrorTypeResourceExhausted, - nexus.HandlerErrorTypeNotFound, - nexus.HandlerErrorTypeNotImplemented, - nexus.HandlerErrorTypeUnavailable, - nexus.HandlerErrorTypeInternal: - return handlerError - default: - h.logger.Warn("received unknown or unset Nexus handler error type", tag.Value(handlerError.Type)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + originalFailure := commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()) + return &nexus.HandlerError{ + Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), + RetryBehavior: retryBehavior, + OriginalFailure: &originalFailure, } } diff --git a/service/history/handler.go b/service/history/handler.go index 0259e64736..d40b7d9582 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -2452,7 +2452,8 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse var ok bool if opErr, ok = recvdErr.(*nexus.OperationError); !ok { opErr = &nexus.OperationError{ - State: nexus.OperationState(request.GetState()), + State: nexus.OperationState(request.GetState()), + // Setting a message here will bypass the Nexus SDK's failure converter backward compatibility logic. Message: "operation completed as " + request.GetState(), Cause: recvdErr, } diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index 7f1d8980ec..519c206ed4 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -3,7 +3,6 @@ package tests import ( "bytes" "context" - "encoding/json" "errors" "fmt" "io" @@ -1211,9 +1210,11 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { var result string err = run.Get(ctx, &result) var wee *temporal.WorkflowExecutionError - s.ErrorAs(err, &wee) - s.True(strings.HasPrefix(wee.Unwrap().Error(), "nexus operation completed unsuccessfully")) + + var noe *temporal.NexusOperationError + s.ErrorAs(wee, &noe) + s.Contains(noe.Error(), "test operation failed") } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { @@ -2133,9 +2134,7 @@ func (s *NexusWorkflowTestSuite) TestNexusSyncOperationErrorRehydration() { checkWorkflowError: func(t *testing.T, wfErr error) { var opErr *temporal.NexusOperationError require.ErrorAs(t, wfErr, &opErr) - var appErr *temporal.ApplicationError - require.ErrorAs(t, opErr, &appErr) - require.Equal(t, "some error", appErr.Message()) + require.Equal(t, "some error", opErr.Message) }, }, { @@ -2503,15 +2502,18 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncNexusFailure() { var handlerErr *nexus.HandlerError s.ErrorAs(wfErr, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + // Old SDK path var appErr *temporal.ApplicationError s.ErrorAs(handlerErr.Cause, &appErr) s.Equal(appErr.Message(), "fail me") - var failure nexus.Failure - s.NoError(appErr.Details(&failure)) - s.Equal(map[string]string{"key": "val"}, failure.Metadata) - var details string - s.NoError(json.Unmarshal(failure.Details, &details)) - s.Equal("details", details) + // NOTE: We broke compatibility here but the likelyhood of anyone using FailureErrors directly is practically zero. + // TODO(bergundy): test when the new SDK supports deserializing Nexus SDK failures + // var failure nexus.Failure + // s.NoError(appErr.Details(&failure)) + // s.Equal(map[string]string{"key": "val"}, failure.Metadata) + // var details string + // s.NoError(json.Unmarshal(failure.Details, &details)) + // s.Equal("details", details) snap := capture.Snapshot() s.Len(snap["nexus_outbound_requests"], 1) From 7d11cc88ef4700fc18ffa357f0744b7ba6c3164a Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Mon, 26 Jan 2026 09:01:51 -0800 Subject: [PATCH 13/26] Revert back to original error message --- components/nexusoperations/executors.go | 9 ++++++--- service/history/handler.go | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 09ada93b78..8ac371b916 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -491,8 +491,7 @@ func handleNonRetryableStartOperationError(node *hsm.Node, operation Operation, operation, eventID, &failurepb.Failure{ - Message: "failed to start Nexus operation", - Cause: failure, + Cause: failure, }, ), ScheduledEventId: eventID, @@ -789,7 +788,7 @@ func (e taskExecutor) lookupEndpoint(ctx context.Context, namespaceID namespace. // Copy over message and stack trace if present since to preserve as much as possible from the original operation // error. func convertToNexusOperationFailure(operation Operation, scheduledEventID int64, originalFailure *failurepb.Failure) *failurepb.Failure { - return &failurepb.Failure{ + f := &failurepb.Failure{ Message: originalFailure.GetMessage(), StackTrace: originalFailure.GetStackTrace(), EncodedAttributes: originalFailure.GetEncodedAttributes(), @@ -806,6 +805,10 @@ func convertToNexusOperationFailure(operation Operation, scheduledEventID int64, }, Cause: originalFailure.GetCause(), } + if originalFailure.GetMessage() == "" { + f.Message = "nexus operation completed unsuccessfully" + } + return f } func startCallOutcomeTag(callCtx context.Context, result *nexusrpc.ClientStartOperationResponse[*nexus.LazyValue], callErr error) string { diff --git a/service/history/handler.go b/service/history/handler.go index d40b7d9582..e070b92d7d 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -2454,7 +2454,7 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse opErr = &nexus.OperationError{ State: nexus.OperationState(request.GetState()), // Setting a message here will bypass the Nexus SDK's failure converter backward compatibility logic. - Message: "operation completed as " + request.GetState(), + Message: "nexus operation completed unsuccessfully", Cause: recvdErr, } origFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(opErr) From 64da4dfff78953d89d6d829bf4d67c9933a746fd Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Mon, 26 Jan 2026 10:25:27 -0800 Subject: [PATCH 14/26] Lint and some self-review --- common/nexus/failure.go | 2 ++ common/nexus/failure_test.go | 1 - common/nexus/nexusrpc/client.go | 5 +++-- components/nexusoperations/completion.go | 23 ++++++++++++++++------ components/nexusoperations/executors.go | 2 +- service/frontend/nexus_handler.go | 10 ++++++---- tests/nexus_api_test.go | 6 +++--- tests/xdc/nexus_request_forwarding_test.go | 4 ++-- 8 files changed, 34 insertions(+), 19 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 3b4ae207cd..9dc510b984 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -158,6 +158,8 @@ func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, e case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: val := false retryableOverride = &val + case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_UNSPECIFIED: + // noop } handlerError := serializedHandlerError{ diff --git a/common/nexus/failure_test.go b/common/nexus/failure_test.go index 6f04b2bdc8..7949452127 100644 --- a/common/nexus/failure_test.go +++ b/common/nexus/failure_test.go @@ -283,7 +283,6 @@ func TestRoundTrip_OnlyStackTrace(t *testing.T) { protorequire.ProtoEqual(t, original, converted) } - func TestRoundTrip_OnlyDetails(t *testing.T) { original := &failurepb.Failure{ FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index 7648b2e313..355ee89d49 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -373,12 +373,13 @@ func (c *HTTPClient) failureFromResponse(response *http.Response, body []byte) ( func (c *HTTPClient) defaultErrorFromResponse(response *http.Response, body []byte, cause error) error { errorType, err := httpStatusCodeToHandlerErrorType(response) if err != nil { - // TODO: use the provided cause, it's already a deserialized failure. + // TODO(bergundy): optimization - use the provided cause, it's already a deserialized failure. return newUnexpectedResponseError(err.Error(), response, body) } + statusText := strings.TrimPrefix(response.Status, fmt.Sprintf("%d ", response.StatusCode)) handlerErr := &nexus.HandlerError{ Type: errorType, - Message: response.Status, + Message: statusText, // For compatibility with older servers. RetryBehavior: retryBehaviorFromHeader(response.Header), Cause: cause, diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index 38fb1159f6..feba86c36d 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -51,12 +51,24 @@ func handleOperationError( return err } var originalFailure *failurepb.Failure - if opErr.OriginalFailure != nil { - nexusFailure := opErr.OriginalFailure - originalFailure, err = commonnexus.NexusFailureToTemporalFailure(*nexusFailure) + + // This should never be nil, but better be defensive here than to panic. + if opErr.OriginalFailure == nil { + if opErr.Message == "" { + // Add a generic message to ensure the failure converter does not unwrap the failure for compatibility. + opErr.Message = "nexus operation completed unsuccessfully" + } + originalNexusFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(opErr) if err != nil { return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) } + opErr.OriginalFailure = &originalNexusFailure + } + + nexusFailure := opErr.OriginalFailure + originalFailure, err = commonnexus.NexusFailureToTemporalFailure(*nexusFailure) + if err != nil { + return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) } switch opErr.State { // nolint:exhaustive @@ -77,15 +89,14 @@ func handleOperationError( return FailedEventDefinition{}.Apply(node.Parent, event) case nexus.OperationStateCanceled: var originalCause *failurepb.Failure - if originalFailure.GetCause().GetCanceledFailureInfo() != nil { - originalCause = originalFailure.GetCause() - } else { + if originalFailure.GetCause().GetCanceledFailureInfo() == nil { // Wrap the original failure in a CanceledFailureInfo to indicate cancellation. All workflow commands expected a // nested CanceledFailure. originalFailure.Cause = &failurepb.Failure{ FailureInfo: &failurepb.Failure_CanceledFailureInfo{ CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, }, + // TODO(bergundy): This might be confusing. // Preserve the original cause. Cause: originalCause, } diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 8ac371b916..f4888aeeec 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -533,7 +533,6 @@ func (e taskExecutor) recordOperationTimeout(node *hsm.Node) error { op, eventID, &failurepb.Failure{ - Message: "operation timed out", Cause: &failurepb.Failure{ Message: "operation timed out", FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ @@ -881,6 +880,7 @@ func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) nf = *handlerErr.OriginalFailure } else { var err error + // Ensure the error message is set to ensure the failure converter does not unwrap the cause. if handlerErr.Message == "" { handlerErr.Message = "handler error" } diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index e4a2a552f8..c8cfeee1c4 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -466,7 +466,7 @@ func (h *nexusHandler) StartOperation( case *matchingservice.DispatchNexusTaskResponse_HandlerError: // Deprecated case. Replaced with DispatchNexusTaskResponse_Failure oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.HandlerError.GetErrorType())) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + oc.setFailureSource(commonnexus.FailureSourceWorker) err := convertOutcomeToNexusHandlerError(t) return nil, err @@ -499,7 +499,7 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_OperationError: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("operation_error")) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + oc.setFailureSource(commonnexus.FailureSourceWorker) err := &nexus.OperationError{ State: nexus.OperationState(t.OperationError.GetOperationState()), Cause: &nexus.FailureError{ @@ -510,7 +510,7 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_Failure: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("failure")) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + oc.setFailureSource(commonnexus.FailureSourceWorker) nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) @@ -526,7 +526,7 @@ func (h *nexusHandler) StartOperation( } // This is the worker's fault. oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:EMPTY_OUTCOME")) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) + oc.setFailureSource(commonnexus.FailureSourceWorker) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") } @@ -780,8 +780,10 @@ func convertOutcomeToNexusHandlerError(resp *matchingservice.DispatchNexusTaskRe case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: retryBehavior = nexus.HandlerErrorRetryBehaviorNonRetryable } + // nolint:staticcheck // Deprecated function still in use for backward compatibility. originalFailure := commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()) return &nexus.HandlerError{ + // nolint:staticcheck // Deprecated function still in use for backward compatibility. Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), RetryBehavior: retryBehavior, OriginalFailure: &originalFailure, diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index a05fee0e17..d5e53eda74 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -203,7 +203,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, nexus.HandlerErrorRetryBehaviorUnspecified, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.Equal(t, "Internal Server Error", handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, @@ -225,7 +225,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, nexus.HandlerErrorRetryBehaviorNonRetryable, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.Equal(t, "Internal Server Error", handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, @@ -553,7 +553,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnauthenticated, handlerErr.Type) - require.Equal(t, "401 Unauthorized", handlerErr.Message) + require.Equal(t, "Unauthorized", handlerErr.Message) require.Equal(t, 1, len(snap["nexus_request_preprocess_errors"])) }, }, diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go index ab827f97ac..31da7c9a59 100644 --- a/tests/xdc/nexus_request_forwarding_test.go +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -191,7 +191,7 @@ func (s *NexusRequestForwardingSuite) TestStartOperationForwardedFromStandbyToAc var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) - require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.Equal(t, "Internal Server Error", handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "StartNexusOperation", "handler_error:INTERNAL") @@ -317,7 +317,7 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) - require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.Equal(t, "Internal Server Error", handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "CancelNexusOperation", "handler_error:INTERNAL") From d6e1b1fa27b1ea8a812e0019716ee8792fa24e13 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Mon, 26 Jan 2026 10:57:10 -0800 Subject: [PATCH 15/26] Test fixes --- common/nexus/failure.go | 2 + components/nexusoperations/completion.go | 3 +- components/nexusoperations/executors_test.go | 55 ++++++++++---------- tests/nexus_api_test.go | 2 +- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 9dc510b984..52a6d0cf78 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -160,6 +160,8 @@ func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, e retryableOverride = &val case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_UNSPECIFIED: // noop + default: + // noop } handlerError := serializedHandlerError{ diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index feba86c36d..bd55314fa9 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -88,7 +88,6 @@ func handleOperationError( return FailedEventDefinition{}.Apply(node.Parent, event) case nexus.OperationStateCanceled: - var originalCause *failurepb.Failure if originalFailure.GetCause().GetCanceledFailureInfo() == nil { // Wrap the original failure in a CanceledFailureInfo to indicate cancellation. All workflow commands expected a // nested CanceledFailure. @@ -98,7 +97,7 @@ func handleOperationError( }, // TODO(bergundy): This might be confusing. // Preserve the original cause. - Cause: originalCause, + Cause: originalFailure.GetCause(), } } event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCELED, func(e *historypb.HistoryEvent) { diff --git a/components/nexusoperations/executors_test.go b/components/nexusoperations/executors_test.go index 989743a922..5c70d036f9 100644 --- a/components/nexusoperations/executors_test.go +++ b/components/nexusoperations/executors_test.go @@ -187,9 +187,10 @@ func TestProcessInvocationTask(t *testing.T) { destinationDown: false, onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { return nil, &nexus.OperationError{ - State: nexus.OperationStateFailed, + State: nexus.OperationStateFailed, + Message: "operation failed from handler", Cause: &nexus.FailureError{ - Failure: nexus.Failure{Message: "operation failed from handler", Metadata: map[string]string{"encoding": "json/plain"}, Details: json.RawMessage("\"details\"")}, + Failure: nexus.Failure{Message: "cause", Metadata: map[string]string{"encoding": "json/plain"}, Details: json.RawMessage("\"details\"")}, }, } }, @@ -202,7 +203,7 @@ func TestProcessInvocationTask(t *testing.T) { ScheduledEventId: 1, RequestId: op.RequestId, Failure: &failurepb.Failure{ - Message: "nexus operation completed unsuccessfully", + Message: "operation failed from handler", FailureInfo: &failurepb.Failure_NexusOperationExecutionFailureInfo{ NexusOperationExecutionFailureInfo: &failurepb.NexusOperationFailureInfo{ ScheduledEventId: 1, @@ -212,16 +213,11 @@ func TestProcessInvocationTask(t *testing.T) { }, }, Cause: &failurepb.Failure{ - Message: "operation failed from handler", - FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ - ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ - Type: "NexusFailure", - Details: &commonpb.Payloads{ - Payloads: []*commonpb.Payload{ - mustToPayload(t, nexus.Failure{Metadata: map[string]string{"encoding": "json/plain"}, Details: []byte(`"details"`)}), - }, - }, - NonRetryable: true, + Message: "cause", + FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ + NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ + Metadata: map[string]string{"encoding": "json/plain"}, + Details: []byte(`"details"`), }, }, }, @@ -236,9 +232,10 @@ func TestProcessInvocationTask(t *testing.T) { destinationDown: false, onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { return nil, &nexus.OperationError{ - State: nexus.OperationStateCanceled, + State: nexus.OperationStateCanceled, + Message: "operation canceled from handler", Cause: &nexus.FailureError{ - Failure: nexus.Failure{Message: "operation canceled from handler", Metadata: map[string]string{"encoding": "json/plain"}, Details: json.RawMessage("\"details\"")}, + Failure: nexus.Failure{Message: "cause", Metadata: map[string]string{"encoding": "json/plain"}, Details: json.RawMessage("\"details\"")}, }, } }, @@ -251,7 +248,7 @@ func TestProcessInvocationTask(t *testing.T) { ScheduledEventId: 1, RequestId: op.RequestId, Failure: &failurepb.Failure{ - Message: "nexus operation completed unsuccessfully", + Message: "operation canceled from handler", FailureInfo: &failurepb.Failure_NexusOperationExecutionFailureInfo{ NexusOperationExecutionFailureInfo: &failurepb.NexusOperationFailureInfo{ ScheduledEventId: 1, @@ -261,13 +258,15 @@ func TestProcessInvocationTask(t *testing.T) { }, }, Cause: &failurepb.Failure{ - Message: "operation canceled from handler", FailureInfo: &failurepb.Failure_CanceledFailureInfo{ - CanceledFailureInfo: &failurepb.CanceledFailureInfo{ - Details: &commonpb.Payloads{ - Payloads: []*commonpb.Payload{ - mustToPayload(t, nexus.Failure{Metadata: map[string]string{"encoding": "json/plain"}, Details: []byte(`"details"`)}), - }, + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, + }, + Cause: &failurepb.Failure{ + Message: "cause", + FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ + NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ + Metadata: map[string]string{"encoding": "json/plain"}, + Details: []byte(`"details"`), }, }, }, @@ -287,7 +286,7 @@ func TestProcessInvocationTask(t *testing.T) { expectedMetricOutcome: "handler-error:INTERNAL", checkOutcome: func(t *testing.T, op nexusoperations.Operation, events []*historypb.HistoryEvent) { require.Equal(t, enumsspb.NEXUS_OPERATION_STATE_BACKING_OFF, op.State()) - require.NotNil(t, op.LastAttemptFailure.GetNexusHandlerFailureInfo()) + require.Equal(t, string(nexus.HandlerErrorTypeInternal), op.LastAttemptFailure.GetNexusHandlerFailureInfo().GetType()) require.Equal(t, "internal server error", op.LastAttemptFailure.Message) require.Equal(t, 0, len(events)) }, @@ -351,7 +350,7 @@ func TestProcessInvocationTask(t *testing.T) { require.Equal(t, enumsspb.NEXUS_OPERATION_STATE_FAILED, op.State()) require.Equal(t, 1, len(events)) failure := events[0].GetNexusOperationFailedEventAttributes().Failure.Cause - require.NotNil(t, failure.GetNexusHandlerFailureInfo()) + require.Equal(t, string(nexus.HandlerErrorTypeNotFound), failure.GetNexusHandlerFailureInfo().GetType()) require.Equal(t, "endpoint not registered", failure.Message) }, }, @@ -365,7 +364,7 @@ func TestProcessInvocationTask(t *testing.T) { require.Equal(t, enumsspb.NEXUS_OPERATION_STATE_FAILED, op.State()) require.Equal(t, 1, len(events)) failure := events[0].GetNexusOperationFailedEventAttributes().Failure.Cause - require.NotNil(t, failure.GetNexusHandlerFailureInfo()) + require.Equal(t, string(nexus.HandlerErrorTypeNotFound), failure.GetNexusHandlerFailureInfo().GetType()) require.Equal(t, "endpoint not registered", failure.Message) }, }, @@ -651,8 +650,8 @@ func TestProcessCancelationTask(t *testing.T) { expectedMetricOutcome: "handler-error:INTERNAL", checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_FAILED, c.State()) - require.NotNil(t, c.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "500 Internal Server Error", c.LastAttemptFailure.Message) + require.Equal(t, string(nexus.HandlerErrorTypeInternal), c.LastAttemptFailure.GetNexusHandlerFailureInfo().GetType()) + require.Equal(t, "Internal Server Error", c.LastAttemptFailure.Message) require.NotNil(t, c.LastAttemptFailure.Cause) require.Equal(t, "operation not found", c.LastAttemptFailure.Cause.Message) }, @@ -737,7 +736,7 @@ func TestProcessCancelationTask(t *testing.T) { onCancelOperation: nil, // This should not be called if the endpoint is not found. checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_FAILED, c.State()) - require.NotNil(t, c.LastAttemptFailure.GetNexusHandlerFailureInfo()) + require.Equal(t, string(nexus.HandlerErrorTypeNotFound), c.LastAttemptFailure.GetNexusHandlerFailureInfo().GetType()) require.Equal(t, "endpoint not registered", c.LastAttemptFailure.Message) }, }, diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index d5e53eda74..8055b925be 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -726,7 +726,7 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Equal(t, "500 Internal Server Error", handlerErr.Message) + require.Equal(t, "Internal Server Error", handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, From 9083b1ddbe1a83eda6532354f97f7fae82334fef Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Tue, 27 Jan 2026 10:07:20 -0800 Subject: [PATCH 16/26] Minor stuff --- common/nexus/failure.go | 5 +---- common/nexus/nexusrpc/completion.go | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 52a6d0cf78..1d4a8c2d16 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -151,6 +151,7 @@ func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, e encodedAttributes = base64.StdEncoding.EncodeToString(b) } var retryableOverride *bool + // nolint:exhaustive // There are only two valid values other than unspecified. switch info.NexusHandlerFailureInfo.GetRetryBehavior() { case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: val := true @@ -158,10 +159,6 @@ func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, e case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: val := false retryableOverride = &val - case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_UNSPECIFIED: - // noop - default: - // noop } handlerError := serializedHandlerError{ diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index 591c4c234e..d6f07082f7 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -167,7 +167,7 @@ type OperationCompletionUnsuccessfulOptions struct { // [DefaultFailureConverter]. // // NOTE: To call server versions <= 1.31.0, use a FailureConverter that unwraps the error cause if message is not - // present. + // present. That is is the default FailureConverter behavior, this comment applies to custom failure converters only. FailureConverter nexus.FailureConverter // OperationID is the unique ID for this operation. Used when a completion callback is received before a started response. // From 53d941d0f5d129757b3b542d180d862f1a9596b9 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Mon, 2 Feb 2026 16:29:31 -0800 Subject: [PATCH 17/26] Bump API --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index e8d732068d..f5551ab39d 100644 --- a/go.mod +++ b/go.mod @@ -58,7 +58,7 @@ require ( go.opentelemetry.io/otel/sdk v1.34.0 go.opentelemetry.io/otel/sdk/metric v1.34.0 go.opentelemetry.io/otel/trace v1.34.0 - go.temporal.io/api v1.61.1-0.20260130224800-de862587eb7e + go.temporal.io/api v1.62.1-0.20260202183046-ee8fb30e5160 go.temporal.io/sdk v1.38.0 go.uber.org/fx v1.24.0 go.uber.org/mock v0.6.0 diff --git a/go.sum b/go.sum index 97555779ee..f3fbd0028d 100644 --- a/go.sum +++ b/go.sum @@ -371,8 +371,8 @@ go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4= go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4= -go.temporal.io/api v1.61.1-0.20260130224800-de862587eb7e h1:bgA9XL9Q//JP2cjuuPQewDC7/htbMJZIqgqObMAG6h8= -go.temporal.io/api v1.61.1-0.20260130224800-de862587eb7e/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= +go.temporal.io/api v1.62.1-0.20260202183046-ee8fb30e5160 h1:UAhOiYian0WPrPYPtI8agjeuxM2hh4nHS7ACoJC15/A= +go.temporal.io/api v1.62.1-0.20260202183046-ee8fb30e5160/go.mod h1:iaxoP/9OXMJcQkETTECfwYq4cw/bj4nwov8b3ZLVnXM= go.temporal.io/sdk v1.38.0 h1:4Bok5LEdED7YKpsSjIa3dDqram5VOq+ydBf4pyx0Wo4= go.temporal.io/sdk v1.38.0/go.mod h1:a+R2Ej28ObvHoILbHaxMyind7M6D+W0L7edt5UJF4SE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= From 4b11275983adf18a70f487b5e36500fe8e52f8b9 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Tue, 3 Feb 2026 13:57:22 -0800 Subject: [PATCH 18/26] wip --- common/nexus/failure.go | 107 ++++++----- common/nexus/failure_test.go | 176 ++++++++----------- common/nexus/nexusrpc/client.go | 3 +- components/nexusoperations/completion.go | 35 ++-- components/nexusoperations/executors.go | 36 ++-- components/nexusoperations/executors_test.go | 30 ++-- service/frontend/nexus_handler.go | 17 +- service/frontend/nexus_http_handler.go | 11 +- service/frontend/workflow_handler.go | 7 +- 9 files changed, 186 insertions(+), 236 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 1d4a8c2d16..3448f19ba4 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -114,33 +114,6 @@ func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, e } switch info := failure.GetFailureInfo().(type) { - case *failurepb.Failure_NexusSdkOperationFailureInfo: - var encodedAttributes string - if failure.EncodedAttributes != nil { - b, err := protojson.Marshal(failure.EncodedAttributes) - if err != nil { - return nexus.Failure{}, fmt.Errorf("failed to deserialize OperationError attributes: %w", err) - } - encodedAttributes = base64.StdEncoding.EncodeToString(b) - } - operationError := serializedOperationError{ - State: info.NexusSdkOperationFailureInfo.GetState(), - EncodedAttributes: encodedAttributes, - } - - details, err := json.Marshal(operationError) - if err != nil { - return nexus.Failure{}, err - } - return nexus.Failure{ - Message: failure.GetMessage(), - StackTrace: failure.GetStackTrace(), - Metadata: map[string]string{ - "type": "nexus.OperationError", - }, - Details: details, - Cause: causep, - }, nil case *failurepb.Failure_NexusHandlerFailureInfo: var encodedAttributes string if failure.EncodedAttributes != nil { @@ -180,14 +153,6 @@ func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, e Details: details, Cause: causep, }, nil - case *failurepb.Failure_NexusSdkFailureErrorInfo: - return nexus.Failure{ - Message: failure.GetMessage(), - StackTrace: failure.GetStackTrace(), - Metadata: info.NexusSdkFailureErrorInfo.GetMetadata(), - Details: info.NexusSdkFailureErrorInfo.GetDetails(), - Cause: causep, - }, nil } // Unset message and stack trace so it's not serialized in the details. var message string @@ -228,31 +193,29 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) if f.Metadata != nil { switch f.Metadata["type"] { case failureTypeString: - if err := protojson.Unmarshal(f.Details, apiFailure); err != nil { + opts := protojson.UnmarshalOptions{DiscardUnknown: true} + if err := opts.Unmarshal(f.Details, apiFailure); err != nil { return nil, err } // Restore these fields as they are not included in the marshalled failure. apiFailure.Message = f.Message apiFailure.StackTrace = f.StackTrace case "nexus.OperationError": - var se serializedOperationError - err := json.Unmarshal(f.Details, &se) + var operationError *nexus.OperationError + err := json.Unmarshal(f.Details, &operationError) if err != nil { return nil, fmt.Errorf("failed to deserialize OperationError: %w", err) } - apiFailure.FailureInfo = &failurepb.Failure_NexusSdkOperationFailureInfo{ - NexusSdkOperationFailureInfo: &failurepb.NexusSDKOperationFailureInfo{ - State: se.State, - }, - } - if len(se.EncodedAttributes) > 0 { - decoded, err := base64.StdEncoding.DecodeString(se.EncodedAttributes) - if err != nil { - return nil, fmt.Errorf("failed to decode base64 OperationError attributes: %w", err) + if operationError.State == nexus.OperationStateCanceled { + apiFailure.FailureInfo = &failurepb.Failure_CanceledFailureInfo{ + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, } - apiFailure.EncodedAttributes = &commonpb.Payload{} - if err := protojson.Unmarshal(decoded, apiFailure.EncodedAttributes); err != nil { - return nil, fmt.Errorf("failed to deserialize OperationError attributes: %w", err) + } else { + apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + NonRetryable: true, + Type: "OperationError", + }, } } case "nexus.HandlerError": @@ -286,17 +249,24 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) } } default: - apiFailure.FailureInfo = &failurepb.Failure_NexusSdkFailureErrorInfo{ - NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ - Metadata: f.Metadata, - Details: f.Details, + payloads, err := nexusFailureMetadataToPayloads(f) + if err != nil { + return nil, fmt.Errorf("failed to serialize failure: %w", err) + } + apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Details: payloads, }, } } } else if len(f.Details) > 0 { - apiFailure.FailureInfo = &failurepb.Failure_NexusSdkFailureErrorInfo{ - NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ - Details: f.Details, + payloads, err := nexusFailureMetadataToPayloads(f) + if err != nil { + return nil, fmt.Errorf("failed to serialize failure: %w", err) + } + apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Details: payloads, }, } } @@ -311,6 +281,29 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) return apiFailure, nil } +func nexusFailureMetadataToPayloads(failure nexus.Failure) (*commonpb.Payloads, error) { + if len(failure.Metadata) == 0 && len(failure.Details) == 0 { + return nil, nil + } + // Delete before serializing. + failure.Message = "" + failure.StackTrace = "" + data, err := json.Marshal(failure) + if err != nil { + return nil, err + } + return &commonpb.Payloads{ + Payloads: []*commonpb.Payload{ + { + Metadata: map[string][]byte{ + "encoding": []byte("json/plain"), + }, + Data: data, + }, + }, + }, err +} + // ConvertGRPCError converts either a serviceerror or a gRPC status error into a Nexus HandlerError if possible. // If exposeDetails is true, the error message from the given error is exposed in the converted HandlerError, otherwise, // a default message with minimal information is attached to the returned error. diff --git a/common/nexus/failure_test.go b/common/nexus/failure_test.go index 7949452127..e079ecc14e 100644 --- a/common/nexus/failure_test.go +++ b/common/nexus/failure_test.go @@ -3,10 +3,12 @@ package nexus import ( "testing" + "github.com/nexus-rpc/sdk-go/nexus" "github.com/stretchr/testify/require" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" failurepb "go.temporal.io/api/failure/v1" + "go.temporal.io/sdk/temporal" "go.temporal.io/server/common/testing/protorequire" ) @@ -35,50 +37,6 @@ func TestRoundTrip_ApplicationFailure(t *testing.T) { protorequire.ProtoEqual(t, original, converted) } -func TestRoundTrip_NexusSDKOperationFailure_WithoutAttributes(t *testing.T) { - original := &failurepb.Failure{ - Message: "operation failed", - StackTrace: "operation stack trace", - FailureInfo: &failurepb.Failure_NexusSdkOperationFailureInfo{ - NexusSdkOperationFailureInfo: &failurepb.NexusSDKOperationFailureInfo{ - State: "failed", - }, - }, - } - - nexusFailure, err := TemporalFailureToNexusFailure(original) - require.NoError(t, err) - - converted, err := NexusFailureToTemporalFailure(nexusFailure) - require.NoError(t, err) - - protorequire.ProtoEqual(t, original, converted) -} - -func TestRoundTrip_NexusSDKOperationFailure_WithAttributes(t *testing.T) { - original := &failurepb.Failure{ - Message: "operation failed with details", - StackTrace: "operation stack trace", - FailureInfo: &failurepb.Failure_NexusSdkOperationFailureInfo{ - NexusSdkOperationFailureInfo: &failurepb.NexusSDKOperationFailureInfo{ - State: "failed", - }, - }, - EncodedAttributes: &commonpb.Payload{ - Metadata: map[string][]byte{"encoding": []byte("json/plain")}, - Data: []byte(`{"custom":"attribute"}`), - }, - } - - nexusFailure, err := TemporalFailureToNexusFailure(original) - require.NoError(t, err) - - converted, err := NexusFailureToTemporalFailure(nexusFailure) - require.NoError(t, err) - - protorequire.ProtoEqual(t, original, converted) -} - func TestRoundTrip_NexusHandlerFailure_Retryable(t *testing.T) { original := &failurepb.Failure{ Message: "handler error - retryable", @@ -164,30 +122,6 @@ func TestRoundTrip_NexusHandlerFailure_WithAttributes(t *testing.T) { protorequire.ProtoEqual(t, original, converted) } -func TestRoundTrip_NexusSDKFailureErrorInfo(t *testing.T) { - original := &failurepb.Failure{ - Message: "sdk failure error", - StackTrace: "sdk stack trace", - FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ - NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ - Metadata: map[string]string{ - "custom-key": "custom-value", - "error-type": "SomeError", - }, - Details: []byte(`{"field":"value"}`), - }, - }, - } - - nexusFailure, err := TemporalFailureToNexusFailure(original) - require.NoError(t, err) - - converted, err := NexusFailureToTemporalFailure(nexusFailure) - require.NoError(t, err) - - protorequire.ProtoEqual(t, original, converted) -} - func TestRoundTrip_WithNestedCauses(t *testing.T) { original := &failurepb.Failure{ Message: "top level failure", @@ -226,25 +160,9 @@ func TestRoundTrip_WithNestedCauses(t *testing.T) { protorequire.ProtoEqual(t, original, converted) } -func TestRoundTrip_NexusOperationFailureWithNexusHandlerCause(t *testing.T) { +func TestRoundTrip_EmptyFailure(t *testing.T) { original := &failurepb.Failure{ - Message: "operation failed", - StackTrace: "operation stack trace", - FailureInfo: &failurepb.Failure_NexusSdkOperationFailureInfo{ - NexusSdkOperationFailureInfo: &failurepb.NexusSDKOperationFailureInfo{ - State: "failed", - }, - }, - Cause: &failurepb.Failure{ - Message: "handler caused the failure", - StackTrace: "handler stack trace", - FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ - NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ - Type: "BadRequest", - RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE, - }, - }, - }, + Message: "simple message", } nexusFailure, err := TemporalFailureToNexusFailure(original) @@ -256,9 +174,9 @@ func TestRoundTrip_NexusOperationFailureWithNexusHandlerCause(t *testing.T) { protorequire.ProtoEqual(t, original, converted) } -func TestRoundTrip_EmptyFailure(t *testing.T) { +func TestRoundTrip_OnlyStackTrace(t *testing.T) { original := &failurepb.Failure{ - Message: "simple message", + StackTrace: "just a stack trace", } nexusFailure, err := TemporalFailureToNexusFailure(original) @@ -266,36 +184,86 @@ func TestRoundTrip_EmptyFailure(t *testing.T) { converted, err := NexusFailureToTemporalFailure(nexusFailure) require.NoError(t, err) - protorequire.ProtoEqual(t, original, converted) } -func TestRoundTrip_OnlyStackTrace(t *testing.T) { - original := &failurepb.Failure{ - StackTrace: "just a stack trace", - } - - nexusFailure, err := TemporalFailureToNexusFailure(original) +func TestFromOperationFailedError(t *testing.T) { + nexusFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(&nexus.OperationError{ + State: nexus.OperationStateFailed, + Message: "operation failed", + StackTrace: "stack trace", + }) + cause, err := TemporalFailureToNexusFailure( + temporal.GetDefaultFailureConverter().ErrorToFailure( + temporal.NewApplicationError("app err", "CustomError", "details"), + ), + ) require.NoError(t, err) + nexusFailure.Cause = &cause converted, err := NexusFailureToTemporalFailure(nexusFailure) require.NoError(t, err) - protorequire.ProtoEqual(t, original, converted) -} -func TestRoundTrip_OnlyDetails(t *testing.T) { - original := &failurepb.Failure{ - FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ - NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ - Details: []byte(`{"only":"details"}`), + expected := &failurepb.Failure{ + Message: "operation failed", + StackTrace: "stack trace", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + NonRetryable: true, + Type: "OperationError", + }, + }, + Cause: &failurepb.Failure{ + Message: "app err", + Source: "GoSDK", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "CustomError", + Details: &commonpb.Payloads{ + Payloads: []*commonpb.Payload{mustToPayload(t, "details")}, + }, + }, }, }, } + protorequire.ProtoEqual(t, expected, converted) +} - nexusFailure, err := TemporalFailureToNexusFailure(original) +func TestFromOperationCanceledError(t *testing.T) { + nexusFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(&nexus.OperationError{ + State: nexus.OperationStateCanceled, + Message: "operation canceled", + StackTrace: "stack trace", + }) + cause, err := TemporalFailureToNexusFailure( + temporal.GetDefaultFailureConverter().ErrorToFailure( + temporal.NewApplicationError("app err", "CustomError", "details"), + ), + ) require.NoError(t, err) + nexusFailure.Cause = &cause converted, err := NexusFailureToTemporalFailure(nexusFailure) require.NoError(t, err) - protorequire.ProtoEqual(t, original, converted) + + expected := &failurepb.Failure{ + Message: "operation canceled", + StackTrace: "stack trace", + FailureInfo: &failurepb.Failure_CanceledFailureInfo{ + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, + }, + Cause: &failurepb.Failure{ + Message: "app err", + Source: "GoSDK", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "CustomError", + Details: &commonpb.Payloads{ + Payloads: []*commonpb.Payload{mustToPayload(t, "details")}, + }, + }, + }, + }, + } + protorequire.ProtoEqual(t, expected, converted) } diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index 355ee89d49..d84960b855 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -219,6 +219,7 @@ func (c *HTTPClient) StartOperation( } addContextTimeoutToHTTPHeader(ctx, request.Header) addNexusHeaderToHTTPHeader(options.Header, request.Header) + request.Header.Set("temporal-nexus-failure-support", "true") response, err := c.options.HTTPCaller(request) if err != nil { @@ -302,7 +303,7 @@ func (c *HTTPClient) StartOperation( Message: "operation failed", Cause: wireErr, } - originalFailure, err := c.options.FailureConverter.ErrorToFailure(wireErr) + originalFailure, err := c.options.FailureConverter.ErrorToFailure(opErr) if err != nil { return nil, err } diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index bd55314fa9..770599381c 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -11,6 +11,7 @@ import ( failurepb "go.temporal.io/api/failure/v1" historypb "go.temporal.io/api/history/v1" "go.temporal.io/api/serviceerror" + "go.temporal.io/server/common" commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/service/history/hsm" "google.golang.org/protobuf/types/known/timestamppb" @@ -50,36 +51,23 @@ func handleOperationError( if err != nil { return err } - var originalFailure *failurepb.Failure - - // This should never be nil, but better be defensive here than to panic. - if opErr.OriginalFailure == nil { - if opErr.Message == "" { - // Add a generic message to ensure the failure converter does not unwrap the failure for compatibility. - opErr.Message = "nexus operation completed unsuccessfully" - } - originalNexusFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(opErr) + var originalCause *failurepb.Failure + if opErr.OriginalFailure != nil && opErr.OriginalFailure.Cause != nil { + var err error + originalCause, err = commonnexus.NexusFailureToTemporalFailure(*opErr.OriginalFailure.Cause) if err != nil { return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) } - opErr.OriginalFailure = &originalNexusFailure - } - - nexusFailure := opErr.OriginalFailure - originalFailure, err = commonnexus.NexusFailureToTemporalFailure(*nexusFailure) - if err != nil { - return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) } switch opErr.State { // nolint:exhaustive case nexus.OperationStateFailed: event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_FAILED, func(e *historypb.HistoryEvent) { - failure := convertToNexusOperationFailure(operation, eventID, originalFailure) // We must assign to this property, linter doesn't like this. // nolint:revive e.Attributes = &historypb.HistoryEvent_NexusOperationFailedEventAttributes{ NexusOperationFailedEventAttributes: &historypb.NexusOperationFailedEventAttributes{ - Failure: failure, + Failure: createNexusOperationFailure(operation, eventID, originalCause), ScheduledEventId: eventID, RequestId: operation.RequestId, }, @@ -88,16 +76,15 @@ func handleOperationError( return FailedEventDefinition{}.Apply(node.Parent, event) case nexus.OperationStateCanceled: - if originalFailure.GetCause().GetCanceledFailureInfo() == nil { + if originalCause.GetCause().GetCanceledFailureInfo() == nil { // Wrap the original failure in a CanceledFailureInfo to indicate cancellation. All workflow commands expected a // nested CanceledFailure. - originalFailure.Cause = &failurepb.Failure{ + originalCause.Cause = &failurepb.Failure{ FailureInfo: &failurepb.Failure_CanceledFailureInfo{ CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, }, - // TODO(bergundy): This might be confusing. - // Preserve the original cause. - Cause: originalFailure.GetCause(), + // Preserve the original cause (+clone to avoid infinite self reference). + Cause: common.CloneProto(originalCause), } } event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCELED, func(e *historypb.HistoryEvent) { @@ -105,7 +92,7 @@ func handleOperationError( // nolint:revive e.Attributes = &historypb.HistoryEvent_NexusOperationCanceledEventAttributes{ NexusOperationCanceledEventAttributes: &historypb.NexusOperationCanceledEventAttributes{ - Failure: convertToNexusOperationFailure(operation, eventID, originalFailure), + Failure: createNexusOperationFailure(operation, eventID, originalCause), ScheduledEventId: eventID, RequestId: operation.RequestId, }, diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index e4812fadc7..0eec441269 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -526,17 +526,15 @@ func handleNonRetryableStartOperationError(node *hsm.Node, operation Operation, if err != nil { return err } - failure, err := callErrToFailure(callErr, false) + cause, err := callErrToFailure(callErr, false) if err != nil { return err } attrs := &historypb.NexusOperationFailedEventAttributes{ - Failure: convertToNexusOperationFailure( + Failure: createNexusOperationFailure( operation, eventID, - &failurepb.Failure{ - Cause: failure, - }, + cause, ), ScheduledEventId: eventID, RequestId: operation.RequestId, @@ -581,16 +579,14 @@ func (e taskExecutor) recordOperationTimeout(node *hsm.Node, timeoutType enumspb // nolint:revive // We must mutate here even if the linter doesn't like it. e.Attributes = &historypb.HistoryEvent_NexusOperationTimedOutEventAttributes{ NexusOperationTimedOutEventAttributes: &historypb.NexusOperationTimedOutEventAttributes{ - Failure: convertToNexusOperationFailure( + Failure: createNexusOperationFailure( op, eventID, &failurepb.Failure{ - Cause: &failurepb.Failure{ - Message: "operation timed out", - FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ - TimeoutFailureInfo: &failurepb.TimeoutFailureInfo{ - TimeoutType: timeoutType, - }, + Message: "operation timed out", + FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ + TimeoutFailureInfo: &failurepb.TimeoutFailureInfo{ + TimeoutType: timeoutType, }, }, }, @@ -845,13 +841,9 @@ func (e taskExecutor) lookupEndpoint(ctx context.Context, namespaceID namespace. return entry, nil } -// Copy over message and stack trace if present since to preserve as much as possible from the original operation -// error. -func convertToNexusOperationFailure(operation Operation, scheduledEventID int64, originalFailure *failurepb.Failure) *failurepb.Failure { - f := &failurepb.Failure{ - Message: originalFailure.GetMessage(), - StackTrace: originalFailure.GetStackTrace(), - EncodedAttributes: originalFailure.GetEncodedAttributes(), +func createNexusOperationFailure(operation Operation, scheduledEventID int64, cause *failurepb.Failure) *failurepb.Failure { + return &failurepb.Failure{ + Message: "nexus operation completed unsuccessfully", FailureInfo: &failurepb.Failure_NexusOperationExecutionFailureInfo{ NexusOperationExecutionFailureInfo: &failurepb.NexusOperationFailureInfo{ Endpoint: operation.Endpoint, @@ -863,12 +855,8 @@ func convertToNexusOperationFailure(operation Operation, scheduledEventID int64, ScheduledEventId: scheduledEventID, }, }, - Cause: originalFailure.GetCause(), - } - if originalFailure.GetMessage() == "" { - f.Message = "nexus operation completed unsuccessfully" + Cause: cause, } - return f } func startCallOutcomeTag(callCtx context.Context, result *nexusrpc.ClientStartOperationResponse[*nexus.LazyValue], callErr error) string { diff --git a/components/nexusoperations/executors_test.go b/components/nexusoperations/executors_test.go index 02fcf9f24c..b1ca8d7de2 100644 --- a/components/nexusoperations/executors_test.go +++ b/components/nexusoperations/executors_test.go @@ -221,12 +221,12 @@ func TestProcessInvocationTask(t *testing.T) { }, Cause: &failurepb.Failure{ Message: "cause", - FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ - NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ - Metadata: map[string]string{"encoding": "json/plain"}, - Details: []byte(`"details"`), - }, - }, + // FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ + // NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ + // Metadata: map[string]string{"encoding": "json/plain"}, + // Details: []byte(`"details"`), + // }, + // }, }, }, } @@ -255,7 +255,7 @@ func TestProcessInvocationTask(t *testing.T) { ScheduledEventId: 1, RequestId: op.RequestId, Failure: &failurepb.Failure{ - Message: "operation canceled from handler", + Message: "nexus operation completed unsuccessfully", FailureInfo: &failurepb.Failure_NexusOperationExecutionFailureInfo{ NexusOperationExecutionFailureInfo: &failurepb.NexusOperationFailureInfo{ ScheduledEventId: 1, @@ -265,17 +265,19 @@ func TestProcessInvocationTask(t *testing.T) { }, }, Cause: &failurepb.Failure{ + Message: "operation canceled from handler", FailureInfo: &failurepb.Failure_CanceledFailureInfo{ CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, }, Cause: &failurepb.Failure{ - Message: "cause", - FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ - NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ - Metadata: map[string]string{"encoding": "json/plain"}, - Details: []byte(`"details"`), - }, - }, + Message: "cause", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{}, + // FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ + // NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ + // Metadata: map[string]string{"encoding": "json/plain"}, + // Details: []byte(`"details"`), + // }, + // }, }, }, }, diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index acfa0c0fc9..d38c38b335 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -46,6 +46,7 @@ const ( // Generic Nexus context that is not bound to a specific operation. // Includes fields extracted from an incoming Nexus request before being handled by the Nexus HTTP handler. type nexusContext struct { + callerFailureSupport bool // Whether to use the new Temporal failure responses path. requestStartTime time.Time apiName string namespaceName string @@ -414,7 +415,7 @@ func (h *nexusHandler) StartOperation( StartOperation: &startOperationRequest, }, Capabilities: &nexuspb.Request_Capabilities{ - TemporalFailureResponses: true, + TemporalFailureResponses: oc.callerFailureSupport, }, }) @@ -515,12 +516,20 @@ func (h *nexusHandler) StartOperation( oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - oe, err := nexus.DefaultFailureConverter().FailureToError(nf) + cause, err := nexus.DefaultFailureConverter().FailureToError(nf) if err != nil { oc.logger.Error("error converting Nexus failure to Nexus OperationError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - return nil, oe + state := nexus.OperationStateFailed + if t.Failure.GetCanceledFailureInfo() != nil { + state = nexus.OperationStateCanceled + } + return nil, &nexus.OperationError{ + Message: "operation error", + State: state, + Cause: cause, + } } } // This is the worker's fault. @@ -617,7 +626,7 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, }, }, Capabilities: &nexuspb.Request_Capabilities{ - TemporalFailureResponses: true, + TemporalFailureResponses: oc.callerFailureSupport, }, }) if err := oc.interceptRequest(ctx, request, options.Header); err != nil { diff --git a/service/frontend/nexus_http_handler.go b/service/frontend/nexus_http_handler.go index 7f40edbdac..940e63fcfa 100644 --- a/service/frontend/nexus_http_handler.go +++ b/service/frontend/nexus_http_handler.go @@ -140,7 +140,7 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByNamespaceAndTaskQueue(w http.Respo } var err error - nc := h.baseNexusContext(configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName) + nc := h.baseNexusContext(configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName, r.Header) params := prepareRequest(commonnexus.RouteDispatchNexusTaskByNamespaceAndTaskQueue, w, r) if nc.taskQueue, err = url.PathUnescape(params.TaskQueue); err != nil { @@ -216,7 +216,7 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r return } - nc, ok := h.nexusContextFromEndpoint(endpointEntry, w) + nc, ok := h.nexusContextFromEndpoint(endpointEntry, w, r.Header) if !ok { // nexusContextFromEndpoint already writes the failure response. return @@ -239,7 +239,7 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r h.serveResolvedURL(w, r, u, nc) } -func (h *NexusHTTPHandler) baseNexusContext(apiName string) *nexusContext { +func (h *NexusHTTPHandler) baseNexusContext(apiName string, header http.Header) *nexusContext { return &nexusContext{ namespaceValidationInterceptor: h.namespaceValidationInterceptor, namespaceRateLimitInterceptor: h.namespaceRateLimitInterceptor, @@ -248,6 +248,7 @@ func (h *NexusHTTPHandler) baseNexusContext(apiName string) *nexusContext { apiName: apiName, requestStartTime: time.Now(), responseHeaders: make(map[string]string), + callerFailureSupport: header.Get("temporal-nexus-failure-support") == "true", } } @@ -255,7 +256,7 @@ func (h *NexusHTTPHandler) baseNexusContext(apiName string) *nexusContext { // endpoint is valid for dispatching. // For security reasons, at the moment only worker target endpoints are considered valid, in the future external // endpoints may also be supported. -func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusEndpointEntry, w http.ResponseWriter) (*nexusContext, bool) { +func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusEndpointEntry, w http.ResponseWriter, header http.Header) (*nexusContext, bool) { switch v := entry.Endpoint.Spec.GetTarget().GetVariant().(type) { case *persistencespb.NexusEndpointTarget_Worker_: nsName, err := h.namespaceRegistry.GetNamespaceName(namespace.ID(v.Worker.GetNamespaceId())) @@ -270,7 +271,7 @@ func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusE } return nil, false } - nc := h.baseNexusContext(configs.DispatchNexusTaskByEndpointAPIName) + nc := h.baseNexusContext(configs.DispatchNexusTaskByEndpointAPIName, header) nc.namespaceName = nsName.String() nc.taskQueue = v.Worker.GetTaskQueue() nc.endpointName = entry.Endpoint.Spec.Name diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 63d62746e9..e3b7749b44 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -5565,9 +5565,10 @@ func (wh *WorkflowHandler) RespondNexusTaskCompleted(ctx context.Context, reques return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") } } - if f := request.GetResponse().GetStartOperation().GetFailure(); f != nil && f.GetNexusSdkOperationFailureInfo() == nil { - return nil, serviceerror.NewInvalidArgument("request StartOperation Failure must contain failure with NexusSdkOperationFailureInfo") - } + // TODO: assert application error or canceled error? + // if f := request.GetResponse().GetStartOperation().GetFailure(); f != nil && f.GetNexusSdkOperationFailureInfo() == nil { + // return nil, serviceerror.NewInvalidArgument("request StartOperation Failure must contain failure with NexusSdkOperationFailureInfo") + // } matchingRequest := &matchingservice.RespondNexusTaskCompletedRequest{ NamespaceId: namespaceId.String(), From d76ef5aa143fa01f75cdabd35cfc5450af653dd9 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Tue, 3 Feb 2026 20:16:24 -0800 Subject: [PATCH 19/26] Fix tests and compat with generic Nexus --- common/nexus/failure.go | 8 +-- common/nexus/failure_test.go | 2 + common/nexus/nexusrpc/client.go | 1 + common/nexus/nexusrpc/completion.go | 2 + components/nexusoperations/completion.go | 23 ++++++--- components/nexusoperations/executors_test.go | 52 ++++++++++++++------ service/frontend/nexus_handler.go | 13 ++++- service/frontend/workflow_handler.go | 4 -- service/history/handler.go | 1 + tests/nexus_workflow_test.go | 5 +- 10 files changed, 73 insertions(+), 38 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 3448f19ba4..811b4b0038 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -83,12 +83,6 @@ func NexusFailureToProtoFailure(failure nexus.Failure) *nexuspb.Failure { return pf } -type serializedOperationError struct { - State string `json:"state,omitempty"` - // Bytes as base64 encoded string. - EncodedAttributes string `json:"encodedAttributes,omitempty"` -} - type serializedHandlerError struct { Type string `json:"type,omitempty"` RetryableOverride *bool `json:"retryableOverride,omitempty"` @@ -124,7 +118,7 @@ func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, e encodedAttributes = base64.StdEncoding.EncodeToString(b) } var retryableOverride *bool - // nolint:exhaustive // There are only two valid values other than unspecified. + // nolint:exhaustive,revive // There are only two valid values other than unspecified. switch info.NexusHandlerFailureInfo.GetRetryBehavior() { case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: val := true diff --git a/common/nexus/failure_test.go b/common/nexus/failure_test.go index e079ecc14e..dae81a91b8 100644 --- a/common/nexus/failure_test.go +++ b/common/nexus/failure_test.go @@ -193,6 +193,7 @@ func TestFromOperationFailedError(t *testing.T) { Message: "operation failed", StackTrace: "stack trace", }) + require.NoError(t, err) cause, err := TemporalFailureToNexusFailure( temporal.GetDefaultFailureConverter().ErrorToFailure( temporal.NewApplicationError("app err", "CustomError", "details"), @@ -235,6 +236,7 @@ func TestFromOperationCanceledError(t *testing.T) { Message: "operation canceled", StackTrace: "stack trace", }) + require.NoError(t, err) cause, err := TemporalFailureToNexusFailure( temporal.GetDefaultFailureConverter().ErrorToFailure( temporal.NewApplicationError("app err", "CustomError", "details"), diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index d84960b855..f0764dc752 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -307,6 +307,7 @@ func (c *HTTPClient) StartOperation( if err != nil { return nil, err } + originalFailure.Metadata["unwrap-error"] = "true" opErr.OriginalFailure = &originalFailure wireErr = opErr } diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index d6f07082f7..6ef1285fc7 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -339,6 +339,8 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h if !ok { // Backwards compatibility: wrap non-OperationError errors in an OperationError with the appropriate state. completion.Error = &nexus.OperationError{ + // Not adding Message here to ensure the failure is unwrapped. + // After server version 1.31.0 is out, we can add the message back. State: completion.State, Cause: completionErr, } diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index 770599381c..efbfe8d483 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -11,7 +11,6 @@ import ( failurepb "go.temporal.io/api/failure/v1" historypb "go.temporal.io/api/history/v1" "go.temporal.io/api/serviceerror" - "go.temporal.io/server/common" commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/service/history/hsm" "google.golang.org/protobuf/types/known/timestamppb" @@ -52,12 +51,20 @@ func handleOperationError( return err } var originalCause *failurepb.Failure - if opErr.OriginalFailure != nil && opErr.OriginalFailure.Cause != nil { + unwrapError := opErr.OriginalFailure.Metadata["unwrap-error"] == "true" + + if unwrapError && opErr.OriginalFailure.Cause != nil { var err error originalCause, err = commonnexus.NexusFailureToTemporalFailure(*opErr.OriginalFailure.Cause) if err != nil { return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) } + } else { + // Transform the OperationError to either ApplicationFailure or CanceledFailure based on the operation error state. + originalCause, err = commonnexus.NexusFailureToTemporalFailure(*opErr.OriginalFailure) + if err != nil { + return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) + } } switch opErr.State { // nolint:exhaustive @@ -76,15 +83,15 @@ func handleOperationError( return FailedEventDefinition{}.Apply(node.Parent, event) case nexus.OperationStateCanceled: - if originalCause.GetCause().GetCanceledFailureInfo() == nil { - // Wrap the original failure in a CanceledFailureInfo to indicate cancellation. All workflow commands expected a - // nested CanceledFailure. - originalCause.Cause = &failurepb.Failure{ + if originalCause.GetCanceledFailureInfo() == nil { + // Old SDKs may send an ApplicationFailure for canceled operation causes. + originalCause = &failurepb.Failure{ + Message: originalCause.GetMessage(), + StackTrace: originalCause.GetStackTrace(), FailureInfo: &failurepb.Failure_CanceledFailureInfo{ CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, }, - // Preserve the original cause (+clone to avoid infinite self reference). - Cause: common.CloneProto(originalCause), + Cause: originalCause.GetCause(), } } event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCELED, func(e *historypb.HistoryEvent) { diff --git a/components/nexusoperations/executors_test.go b/components/nexusoperations/executors_test.go index b1ca8d7de2..0e5099d108 100644 --- a/components/nexusoperations/executors_test.go +++ b/components/nexusoperations/executors_test.go @@ -210,7 +210,7 @@ func TestProcessInvocationTask(t *testing.T) { ScheduledEventId: 1, RequestId: op.RequestId, Failure: &failurepb.Failure{ - Message: "operation failed from handler", + Message: "nexus operation completed unsuccessfully", FailureInfo: &failurepb.Failure_NexusOperationExecutionFailureInfo{ NexusOperationExecutionFailureInfo: &failurepb.NexusOperationFailureInfo{ ScheduledEventId: 1, @@ -220,13 +220,28 @@ func TestProcessInvocationTask(t *testing.T) { }, }, Cause: &failurepb.Failure{ - Message: "cause", - // FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ - // NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ - // Metadata: map[string]string{"encoding": "json/plain"}, - // Details: []byte(`"details"`), - // }, - // }, + Message: "operation failed from handler", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "OperationError", + NonRetryable: true, + }, + }, + Cause: &failurepb.Failure{ + Message: "cause", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Details: &commonpb.Payloads{ + Payloads: []*commonpb.Payload{ + mustToPayload(t, nexus.Failure{ + Metadata: map[string]string{"encoding": "json/plain"}, + Details: json.RawMessage("\"details\""), + }), + }, + }, + }, + }, + }, }, }, } @@ -270,14 +285,19 @@ func TestProcessInvocationTask(t *testing.T) { CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, }, Cause: &failurepb.Failure{ - Message: "cause", - FailureInfo: &failurepb.Failure_ApplicationFailureInfo{}, - // FailureInfo: &failurepb.Failure_NexusSdkFailureErrorInfo{ - // NexusSdkFailureErrorInfo: &failurepb.NexusSDKFailureErrorFailureInfo{ - // Metadata: map[string]string{"encoding": "json/plain"}, - // Details: []byte(`"details"`), - // }, - // }, + Message: "cause", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Details: &commonpb.Payloads{ + Payloads: []*commonpb.Payload{ + mustToPayload(t, nexus.Failure{ + Metadata: map[string]string{"encoding": "json/plain"}, + Details: json.RawMessage("\"details\""), + }), + }, + }, + }, + }, }, }, }, diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index d38c38b335..c420206969 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -525,11 +525,20 @@ func (h *nexusHandler) StartOperation( if t.Failure.GetCanceledFailureInfo() != nil { state = nexus.OperationStateCanceled } - return nil, &nexus.OperationError{ - Message: "operation error", + opErr := &nexus.OperationError{ State: state, + Message: "operation error", Cause: cause, } + nf, err = nexus.DefaultFailureConverter().ErrorToFailure(opErr) + if err != nil { + oc.logger.Error("error converting OperationError to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + nf.Metadata["unwrap-error"] = "true" + opErr.OriginalFailure = &nf + + return nil, opErr } } // This is the worker's fault. diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index e3b7749b44..e613963006 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -5565,10 +5565,6 @@ func (wh *WorkflowHandler) RespondNexusTaskCompleted(ctx context.Context, reques return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") } } - // TODO: assert application error or canceled error? - // if f := request.GetResponse().GetStartOperation().GetFailure(); f != nil && f.GetNexusSdkOperationFailureInfo() == nil { - // return nil, serviceerror.NewInvalidArgument("request StartOperation Failure must contain failure with NexusSdkOperationFailureInfo") - // } matchingRequest := &matchingservice.RespondNexusTaskCompletedRequest{ NamespaceId: namespaceId.String(), diff --git a/service/history/handler.go b/service/history/handler.go index fc26b75cbf..91f4db3845 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -2175,6 +2175,7 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse if err != nil { return nil, serviceerror.NewInvalidArgument("unable to convert operation error to failure") } + origFailure.Metadata["unwrap-error"] = "true" opErr.OriginalFailure = &origFailure } } diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index 548748c87e..b48bdcd649 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -2132,7 +2132,10 @@ func (s *NexusWorkflowTestSuite) TestNexusSyncOperationErrorRehydration() { checkWorkflowError: func(t *testing.T, wfErr error) { var opErr *temporal.NexusOperationError require.ErrorAs(t, wfErr, &opErr) - require.Equal(t, "some error", opErr.Message) + require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) + var appErr *temporal.ApplicationError + require.ErrorAs(t, opErr.Cause, &appErr) + require.Equal(t, "some error", appErr.Message()) }, }, { From 033df6d69ff8a8523f93fda149eadd5e0d93f783 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Wed, 4 Feb 2026 13:53:31 -0800 Subject: [PATCH 20/26] Self review --- common/nexus/failure.go | 79 +++++++++++++----------- common/nexus/nexusrpc/client.go | 8 ++- common/nexus/nexusrpc/completion.go | 3 +- components/nexusoperations/completion.go | 3 + components/nexusoperations/executors.go | 2 +- service/frontend/nexus_handler.go | 20 ++++-- service/history/handler.go | 2 + tests/nexus_workflow_test.go | 15 +++-- 8 files changed, 81 insertions(+), 51 deletions(-) diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 811b4b0038..5426d5ef9d 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -93,8 +93,9 @@ type serializedHandlerError struct { // TemporalFailureToNexusFailure converts an API proto Failure to a Nexus SDK Failure setting the metadata "type" field to // the proto fullname of the temporal API Failure message or the standard Nexus SDK failure types. // Returns an error if the failure cannot be converted. -// Mutates the failure temporarily, unsetting the Message field to avoid duplicating the information in the serialized -// failure. Mutating was chosen over cloning for performance reasons since this function may be called frequently. +// Mutates the failure temporarily, unsetting the Message and StackTrace fields to avoid duplicating the information in +// the serialized failure. Mutating was chosen over cloning for performance reasons since this function may be called +// frequently. func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) { var causep *nexus.Failure if failure.GetCause() != nil { @@ -115,7 +116,7 @@ func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, e if err != nil { return nexus.Failure{}, fmt.Errorf("failed to deserialize HandlerError attributes: %w", err) } - encodedAttributes = base64.StdEncoding.EncodeToString(b) + encodedAttributes = base64.RawURLEncoding.EncodeToString(b) } var retryableOverride *bool // nolint:exhaustive,revive // There are only two valid values other than unspecified. @@ -195,16 +196,21 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) apiFailure.Message = f.Message apiFailure.StackTrace = f.StackTrace case "nexus.OperationError": + // Special case for OperationError that adapts from Nexus semantics to Temporal semantics. + // Note that Temporal -> Temporal doesn't go through this code path, operation errors are always used as empty + // wrappers for an underlying causes. var operationError *nexus.OperationError err := json.Unmarshal(f.Details, &operationError) if err != nil { return nil, fmt.Errorf("failed to deserialize OperationError: %w", err) } if operationError.State == nexus.OperationStateCanceled { + // Canceled operation errors are represented as CanceledFailure in Temporal. apiFailure.FailureInfo = &failurepb.Failure_CanceledFailureInfo{ CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, } } else { + // Failed operation errors are represented as non-retryable ApplicationFailure in Temporal. apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ NonRetryable: true, @@ -233,7 +239,7 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) }, } if len(se.EncodedAttributes) > 0 { - decoded, err := base64.StdEncoding.DecodeString(se.EncodedAttributes) + decoded, err := base64.RawURLEncoding.DecodeString(se.EncodedAttributes) if err != nil { return nil, fmt.Errorf("failed to decode base64 HandlerError attributes: %w", err) } @@ -243,26 +249,22 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) } } default: - payloads, err := nexusFailureMetadataToPayloads(f) + // We don't recognize this type, convert to a generic ApplicationFailure and preserve the original Nexus failure + // as serialized details. + applicationFailureInfo, err := nexusFailureMetadataToApplicationFailureInfo(f) if err != nil { - return nil, fmt.Errorf("failed to serialize failure: %w", err) - } - apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ - ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ - Details: payloads, - }, + return nil, fmt.Errorf("failed to serialize Nexus failure: %w", err) } + apiFailure.FailureInfo = applicationFailureInfo } } else if len(f.Details) > 0 { - payloads, err := nexusFailureMetadataToPayloads(f) + // We don't recognize this type, convert to a generic ApplicationFailure and preserve the original Nexus failure as + // serialized details. + applicationFailureInfo, err := nexusFailureMetadataToApplicationFailureInfo(f) if err != nil { - return nil, fmt.Errorf("failed to serialize failure: %w", err) - } - apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ - ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ - Details: payloads, - }, + return nil, fmt.Errorf("failed to serialize Nexus failure: %w", err) } + apiFailure.FailureInfo = applicationFailureInfo } if f.Cause != nil { @@ -275,27 +277,32 @@ func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) return apiFailure, nil } -func nexusFailureMetadataToPayloads(failure nexus.Failure) (*commonpb.Payloads, error) { - if len(failure.Metadata) == 0 && len(failure.Details) == 0 { - return nil, nil - } - // Delete before serializing. - failure.Message = "" - failure.StackTrace = "" - data, err := json.Marshal(failure) - if err != nil { - return nil, err - } - return &commonpb.Payloads{ - Payloads: []*commonpb.Payload{ - { - Metadata: map[string][]byte{ - "encoding": []byte("json/plain"), +func nexusFailureMetadataToApplicationFailureInfo(failure nexus.Failure) (*failurepb.Failure_ApplicationFailureInfo, error) { + var payloads *commonpb.Payloads + if len(failure.Metadata) > 0 || len(failure.Details) > 0 { + // Delete before serializing (note the failure here is passed by value). + failure.Message = "" + failure.StackTrace = "" + data, err := json.Marshal(failure) + if err != nil { + return nil, err + } + payloads = &commonpb.Payloads{ + Payloads: []*commonpb.Payload{ + { + Metadata: map[string][]byte{ + "encoding": []byte("json/plain"), + }, + Data: data, }, - Data: data, }, + } + } + return &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Details: payloads, }, - }, err + }, nil } // ConvertGRPCError converts either a serviceerror or a gRPC status error into a Nexus HandlerError if possible. diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index f0764dc752..ff908b20e0 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -219,6 +219,7 @@ func (c *HTTPClient) StartOperation( } addContextTimeoutToHTTPHeader(ctx, request.Header) addNexusHeaderToHTTPHeader(options.Header, request.Header) + // If this request is handled by a newer server that supports Nexus failure serialization, trigger that behavior. request.Header.Set("temporal-nexus-failure-support", "true") response, err := c.options.HTTPCaller(request) @@ -300,13 +301,16 @@ func (c *HTTPClient) StartOperation( } opErr := &nexus.OperationError{ State: state, - Message: "operation failed", + Message: "nexus operation completed unsuccessfully", Cause: wireErr, } + // Ensure we the original failure is available, the calling code expects it. originalFailure, err := c.options.FailureConverter.ErrorToFailure(opErr) if err != nil { return nil, err } + // Special header to signal that this error should be unwrapped by the completion handler as old servers will send + // back empty wrappers for underlying causes. originalFailure.Metadata["unwrap-error"] = "true" opErr.OriginalFailure = &originalFailure wireErr = opErr @@ -386,6 +390,8 @@ func (c *HTTPClient) defaultErrorFromResponse(response *http.Response, body []by RetryBehavior: retryBehaviorFromHeader(response.Header), Cause: cause, } + + // Ensure we the original failure is available, the calling code expects it. originalFailure, err := c.options.FailureConverter.ErrorToFailure(handlerErr) if err != nil { return newUnexpectedResponseError("failed to construct handler error from response: "+err.Error(), response, body) diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index 6ef1285fc7..add778e6f7 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -339,7 +339,8 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h if !ok { // Backwards compatibility: wrap non-OperationError errors in an OperationError with the appropriate state. completion.Error = &nexus.OperationError{ - // Not adding Message here to ensure the failure is unwrapped. + // Not adding Message here to ensure the failure is unwrapped (behavior of the Nexus failure converter to + // maintain backwards compatibility). // After server version 1.31.0 is out, we can add the message back. State: completion.State, Cause: completionErr, diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index efbfe8d483..cf37247030 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -51,6 +51,9 @@ func handleOperationError( return err } var originalCause *failurepb.Failure + // Special marker for Temporal->Temporal calls to indicate that the original failure should be unwrapped. + // Temporal uses a wrapper operation error with no additional information to transmit the OperationError over the network. + // The meaningful information is in the operation error's cause. unwrapError := opErr.OriginalFailure.Metadata["unwrap-error"] == "true" if unwrapError && opErr.OriginalFailure.Cause != nil { diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 0eec441269..595e51c8be 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -929,7 +929,7 @@ func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) nf = *handlerErr.OriginalFailure } else { var err error - // Ensure the error message is set to ensure the failure converter does not unwrap the cause. + // Ensure the error message is set to prevent the Nexus failure converter from unwrapping the cause. if handlerErr.Message == "" { handlerErr.Message = "handler error" } diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index c420206969..06c86180b6 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -46,7 +46,9 @@ const ( // Generic Nexus context that is not bound to a specific operation. // Includes fields extracted from an incoming Nexus request before being handled by the Nexus HTTP handler. type nexusContext struct { - callerFailureSupport bool // Whether to use the new Temporal failure responses path. + // Whether to use the new Temporal failure responses path. + // Set from the incoming nexus request's "temporal-nexus-failure-support" header. + callerFailureSupport bool requestStartTime time.Time apiName string namespaceName string @@ -448,6 +450,10 @@ func (h *nexusHandler) StartOperation( // Convert to standard Nexus SDK response. switch t := response.GetOutcome().(type) { case *matchingservice.DispatchNexusTaskResponse_Failure: + // Set the failure source to "worker" if we've reached this case. + // Failure conversions errors below are the user's fault, as it implies that malformed completions were sent from + // the worker. + oc.setFailureSource(commonnexus.FailureSourceWorker) oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { @@ -459,8 +465,6 @@ func (h *nexusHandler) StartOperation( oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - // Failure conversions are our fault so only set this after converting the Temporal failure to a HandlerError. - oc.setFailureSource(commonnexus.FailureSourceWorker) return nil, he case *matchingservice.DispatchNexusTaskResponse_HandlerError: @@ -509,6 +513,9 @@ func (h *nexusHandler) StartOperation( return nil, err case *nexuspb.StartOperationResponse_Failure: + // Set the failure source to "worker" if we've reached this case. + // Failure conversions errors below are the user's fault, as it implies that malformed completions were sent from + // the worker. oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("failure")) oc.setFailureSource(commonnexus.FailureSourceWorker) nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) @@ -535,6 +542,9 @@ func (h *nexusHandler) StartOperation( oc.logger.Error("error converting OperationError to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } + // Mark that the original failure is an OperationError wrapper to be unwrapped. + // Newer server callers will unwrap the cause automatically if this metadata key is set. + // This is required to support calling non Temporal based implementations where OpeationErrors carry additional information. nf.Metadata["unwrap-error"] = "true" opErr.OriginalFailure = &nf @@ -658,6 +668,9 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, switch t := response.GetOutcome().(type) { case *matchingservice.DispatchNexusTaskResponse_Failure: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) + // Set the failure source to "worker" if we've reached this case. + // Failure conversions errors below are the user's fault, as it implies that malformed completions were sent from + // the worker. oc.setFailureSource(commonnexus.FailureSourceWorker) nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { @@ -669,7 +682,6 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - // Failure conversions are our fault so only set this after converting the Temporal failure to a HandlerError. return he case *matchingservice.DispatchNexusTaskResponse_HandlerError: diff --git a/service/history/handler.go b/service/history/handler.go index 91f4db3845..3288c55edd 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -2175,6 +2175,8 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse if err != nil { return nil, serviceerror.NewInvalidArgument("unable to convert operation error to failure") } + // Special header to signal that this error should be unwrapped by the completion handler as old servers will send + // back empty wrappers for underlying causes. origFailure.Metadata["unwrap-error"] = "true" opErr.OriginalFailure = &origFailure } diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index b48bdcd649..5efe011614 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -3,6 +3,7 @@ package tests import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -2507,14 +2508,12 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncNexusFailure() { var appErr *temporal.ApplicationError s.ErrorAs(handlerErr.Cause, &appErr) s.Equal(appErr.Message(), "fail me") - // NOTE: We broke compatibility here but the likelyhood of anyone using FailureErrors directly is practically zero. - // TODO(bergundy): test when the new SDK supports deserializing Nexus SDK failures - // var failure nexus.Failure - // s.NoError(appErr.Details(&failure)) - // s.Equal(map[string]string{"key": "val"}, failure.Metadata) - // var details string - // s.NoError(json.Unmarshal(failure.Details, &details)) - // s.Equal("details", details) + var failure nexus.Failure + s.NoError(appErr.Details(&failure)) + s.Equal(map[string]string{"key": "val"}, failure.Metadata) + var details string + s.NoError(json.Unmarshal(failure.Details, &details)) + s.Equal("details", details) snap := capture.Snapshot() s.Len(snap["nexus_outbound_requests"], 1) From f6b7aa81a435da5473cb6fca925627c1d86c045d Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Mon, 9 Feb 2026 16:23:41 -0800 Subject: [PATCH 21/26] Inline the failure converter from the Nexus SDK and upgrade the Nexus SDK --- common/nexus/failure_test.go | 5 +- common/nexus/nexusrpc/client.go | 4 +- common/nexus/nexusrpc/completion.go | 11 +- common/nexus/nexusrpc/failure_converter.go | 218 ++++++++++++++++++ .../nexus/nexusrpc/failure_conveter_test.go | 139 +++++++++++ common/nexus/nexusrpc/server.go | 6 +- common/nexus/nexusrpc/setup_test.go | 4 +- components/nexusoperations/executors.go | 2 +- go.mod | 2 +- go.sum | 2 + service/frontend/nexus_handler.go | 8 +- service/history/handler.go | 5 +- 12 files changed, 382 insertions(+), 24 deletions(-) create mode 100644 common/nexus/nexusrpc/failure_converter.go create mode 100644 common/nexus/nexusrpc/failure_conveter_test.go diff --git a/common/nexus/failure_test.go b/common/nexus/failure_test.go index dae81a91b8..5f7dfc7cd9 100644 --- a/common/nexus/failure_test.go +++ b/common/nexus/failure_test.go @@ -9,6 +9,7 @@ import ( enumspb "go.temporal.io/api/enums/v1" failurepb "go.temporal.io/api/failure/v1" "go.temporal.io/sdk/temporal" + "go.temporal.io/server/common/nexus/nexusrpc" "go.temporal.io/server/common/testing/protorequire" ) @@ -188,7 +189,7 @@ func TestRoundTrip_OnlyStackTrace(t *testing.T) { } func TestFromOperationFailedError(t *testing.T) { - nexusFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(&nexus.OperationError{ + nexusFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(&nexus.OperationError{ State: nexus.OperationStateFailed, Message: "operation failed", StackTrace: "stack trace", @@ -231,7 +232,7 @@ func TestFromOperationFailedError(t *testing.T) { } func TestFromOperationCanceledError(t *testing.T) { - nexusFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(&nexus.OperationError{ + nexusFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(&nexus.OperationError{ State: nexus.OperationStateCanceled, Message: "operation canceled", StackTrace: "stack trace", diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index ff908b20e0..5985ccaf43 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -30,7 +30,7 @@ type HTTPClientOptions struct { Serializer nexus.Serializer // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to // [DefaultFailureConverter]. - FailureConverter nexus.FailureConverter + FailureConverter FailureConverter } // User-Agent header set on HTTP requests. @@ -115,7 +115,7 @@ func NewHTTPClient(options HTTPClientOptions) (*HTTPClient, error) { options.Serializer = nexus.DefaultSerializer() } if options.FailureConverter == nil { - options.FailureConverter = nexus.DefaultFailureConverter() + options.FailureConverter = DefaultFailureConverter() } return &HTTPClient{ diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index add778e6f7..926a00de6e 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -165,10 +165,7 @@ type OperationCompletionUnsuccessful struct { type OperationCompletionUnsuccessfulOptions struct { // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to // [DefaultFailureConverter]. - // - // NOTE: To call server versions <= 1.31.0, use a FailureConverter that unwraps the error cause if message is not - // present. That is is the default FailureConverter behavior, this comment applies to custom failure converters only. - FailureConverter nexus.FailureConverter + FailureConverter FailureConverter // OperationID is the unique ID for this operation. Used when a completion callback is received before a started response. // // Deprecated: Use OperationToken instead. @@ -187,7 +184,7 @@ type OperationCompletionUnsuccessfulOptions struct { // NewOperationCompletionUnsuccessful constructs an [OperationCompletionUnsuccessful] from a given error. func NewOperationCompletionUnsuccessful(opErr *nexus.OperationError, options OperationCompletionUnsuccessfulOptions) (*OperationCompletionUnsuccessful, error) { if options.FailureConverter == nil { - options.FailureConverter = nexus.DefaultFailureConverter() + options.FailureConverter = DefaultFailureConverter() } failure, err := options.FailureConverter.ErrorToFailure(opErr) if err != nil { @@ -280,7 +277,7 @@ type CompletionHandlerOptions struct { Serializer nexus.Serializer // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to // [DefaultFailureConverter]. - FailureConverter nexus.FailureConverter + FailureConverter FailureConverter } type completionHTTPHandler struct { @@ -380,7 +377,7 @@ func NewCompletionHTTPHandler(options CompletionHandlerOptions) http.Handler { options.Serializer = nexus.DefaultSerializer() } if options.FailureConverter == nil { - options.FailureConverter = nexus.DefaultFailureConverter() + options.FailureConverter = DefaultFailureConverter() } return &completionHTTPHandler{ options: options, diff --git a/common/nexus/nexusrpc/failure_converter.go b/common/nexus/nexusrpc/failure_converter.go new file mode 100644 index 0000000000..04bbe9baaf --- /dev/null +++ b/common/nexus/nexusrpc/failure_converter.go @@ -0,0 +1,218 @@ +package nexusrpc + +import ( + "encoding/json" + "fmt" + + "github.com/nexus-rpc/sdk-go/nexus" +) + +// FailureConverter is used by the framework to transform [error] instances to and from [Failure] instances. +// To customize conversion logic, implement this interface and provide your implementation to framework methods such as +// [NewHTTPClient] and [NewHTTPHandler]. +// By default the SDK translates [HandlerError], [OperationError] and [FailureError] to and from [Failure] +// objects maintaining their cause chain. +// Arbitrary errors are translated to a [Failure] object with its Message set to the Error() string, losing the cause +// chain. +type FailureConverter interface { + // ErrorToFailure converts an [error] to a [Failure]. + // Note that the provided error may be nil. + ErrorToFailure(error) (nexus.Failure, error) + // FailureToError converts a [Failure] to an [error]. + FailureToError(nexus.Failure) (error, error) +} + +type knownErrorFailureConverter struct{} + +type serializedHandlerError struct { + Type string `json:"type,omitempty"` + RetryableOverride *bool `json:"retryableOverride,omitempty"` +} + +func (e serializedHandlerError) RetryBehavior() nexus.HandlerErrorRetryBehavior { + if e.RetryableOverride == nil { + return nexus.HandlerErrorRetryBehaviorUnspecified + } + if *e.RetryableOverride { + return nexus.HandlerErrorRetryBehaviorRetryable + } else { + return nexus.HandlerErrorRetryBehaviorNonRetryable + } +} + +type serializedOperationError struct { + State string `json:"state,omitempty"` +} + +// ErrorToFailure implements FailureConverter. +func (e knownErrorFailureConverter) ErrorToFailure(err error) (nexus.Failure, error) { + if err == nil { + return nexus.Failure{}, nil + } + // NOTE: not using errors.Unwrap here we are intentionally only supporting unwrapping known errors. + switch typedErr := err.(type) { + case *nexus.FailureError: + f := typedErr.Failure + // Convert the erorr cause if set and failure cause is not already set. + if typedErr.Cause != nil && f.Cause == nil { + c, err := e.ErrorToFailure(typedErr.Cause) + if err != nil { + return nexus.Failure{}, err + } + f.Cause = &c + } + return f, nil + case *nexus.HandlerError: + if typedErr.OriginalFailure != nil { + return *typedErr.OriginalFailure, nil + } + // Temporary workaround for compatibility with old SDKs that don't support handler error messages. + if typedErr.Message == "" && typedErr.Cause != nil { + return e.ErrorToFailure(typedErr.Cause) + } + data := serializedHandlerError{ + Type: string(typedErr.Type), + RetryableOverride: retryBehaviorAsOptionalBool(typedErr), + } + var details []byte + details, err := json.Marshal(data) + if err != nil { + return nexus.Failure{}, err + } + f := nexus.Failure{ + Message: typedErr.Message, + StackTrace: typedErr.StackTrace, + Metadata: map[string]string{ + "type": "nexus.HandlerError", + }, + Details: details, + } + + if typedErr.Cause != nil { + c, err := e.ErrorToFailure(typedErr.Cause) + if err != nil { + return nexus.Failure{}, err + } + f.Cause = &c + } + return f, nil + case *nexus.OperationError: + if typedErr.OriginalFailure != nil { + return *typedErr.OriginalFailure, nil + } + // Temporary workaround for compatibility with old SDKs that don't support operation error messages. + if typedErr.Message == "" && typedErr.Cause != nil { + return e.ErrorToFailure(typedErr.Cause) + } + data := serializedOperationError{ + State: string(typedErr.State), + } + details, err := json.Marshal(data) + if err != nil { + return nexus.Failure{}, err + } + f := nexus.Failure{ + Message: typedErr.Message, + StackTrace: typedErr.StackTrace, + Metadata: map[string]string{ + "type": "nexus.OperationError", + }, + Details: details, + } + + if typedErr.Cause != nil { + c, err := e.ErrorToFailure(typedErr.Cause) + if err != nil { + return nexus.Failure{}, err + } + f.Cause = &c + } + return f, nil + default: + return nexus.Failure{ + Message: typedErr.Error(), + }, nil + } +} + +// FailureToError implements FailureConverter. +func (e knownErrorFailureConverter) FailureToError(f nexus.Failure) (error, error) { + if f.Metadata != nil { + switch f.Metadata["type"] { + case "nexus.HandlerError": + var se serializedHandlerError + err := json.Unmarshal(f.Details, &se) + if err != nil { + return nil, fmt.Errorf("failed to deserialize HandlerError: %w", err) + } + he := &nexus.HandlerError{ + Message: f.Message, + StackTrace: f.StackTrace, + Type: nexus.HandlerErrorType(se.Type), + RetryBehavior: se.RetryBehavior(), + OriginalFailure: &f, + } + if f.Cause != nil { + he.Cause, err = e.FailureToError(*f.Cause) + if err != nil { + return nil, err + } + } + return he, nil + case "nexus.OperationError": + var se serializedOperationError + err := json.Unmarshal(f.Details, &se) + if err != nil { + return nil, fmt.Errorf("failed to deserialize OperationError: %w", err) + } + oe := &nexus.OperationError{ + Message: f.Message, + StackTrace: f.StackTrace, + State: nexus.OperationState(se.State), + OriginalFailure: &f, + } + if f.Cause != nil { + oe.Cause, err = e.FailureToError(*f.Cause) + if err != nil { + return nil, err + } + } + return oe, nil + } + } + // Note that the original failure cause is retained on the FailureError's failure object. + fe := &nexus.FailureError{Failure: f} + if f.Cause != nil { + c, err := e.FailureToError(*f.Cause) + if err != nil { + return nil, err + } + fe.Cause = c + } + return fe, nil +} + +var defaultFailureConverter FailureConverter = knownErrorFailureConverter{} + +// DefaultFailureConverter returns the SDK's default [FailureConverter] implementation. Translates [HandlerError], +// [OperationError] and [FailureError] to and from [Failure] objects maintaining their cause chain. +// Arbitrary errors are translated to a [Failure] object with its Message set to the Error() string, losing the cause +// chain. +// [Failure] instances are converted to [FailureError] to allow access to the full failure metadata and details if +// available. +func DefaultFailureConverter() FailureConverter { + return defaultFailureConverter +} + +func retryBehaviorAsOptionalBool(e *nexus.HandlerError) *bool { + switch e.RetryBehavior { + case nexus.HandlerErrorRetryBehaviorRetryable: + ret := true + return &ret + case nexus.HandlerErrorRetryBehaviorNonRetryable: + ret := false + return &ret + } + return nil +} + diff --git a/common/nexus/nexusrpc/failure_conveter_test.go b/common/nexus/nexusrpc/failure_conveter_test.go new file mode 100644 index 0000000000..516eca86f7 --- /dev/null +++ b/common/nexus/nexusrpc/failure_conveter_test.go @@ -0,0 +1,139 @@ +package nexusrpc + +import ( + "errors" + "testing" + + "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/require" +) + +func TestFailureConverter_GenericError(t *testing.T) { + failure, err := defaultFailureConverter.ErrorToFailure(errors.New("test")) + require.NoError(t, err) + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + require.Equal(t, &nexus.FailureError{Failure: nexus.Failure{Message: "test"}}, actual) +} + +func TestFailureConverter_FailureError(t *testing.T) { + cause := &nexus.FailureError{ + Failure: nexus.Failure{Message: "cause"}, + } + fe := &nexus.FailureError{ + Failure: nexus.Failure{ + Message: "foo", + }, + Cause: cause, + } + failure, err := defaultFailureConverter.ErrorToFailure(fe) + require.NoError(t, err) + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + // The serialized failure cause is retained. + fe.Failure.Cause = failure.Cause + require.Equal(t, fe, actual) + + // Serialize again and verify the original failure is used. + fe.Cause = errors.New("should be ignored") + failure, err = defaultFailureConverter.ErrorToFailure(fe) + require.NoError(t, err) + actual, err = defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + // Reset back to the orignal cause before comparing. + fe.Cause = cause + require.Equal(t, fe, actual) +} + +func TestFailureConverter_HandlerError(t *testing.T) { + cause := &nexus.FailureError{Failure: nexus.Failure{Message: "cause"}} + he := nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "foo") + he.StackTrace = "stack" + he.Cause = cause + failure, err := defaultFailureConverter.ErrorToFailure(he) + require.NoError(t, err) + // Verify that the original failure is retained. + he.OriginalFailure = &failure + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + require.Equal(t, he, actual) + + // Serialize again and verify the original failure is used. + he.Cause = errors.New("should be ignored") + failure, err = defaultFailureConverter.ErrorToFailure(he) + require.NoError(t, err) + actual, err = defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + // Reset back to the orignal cause before comparing. + he.Cause = cause + require.Equal(t, he, actual) +} + +func TestFailureConverter_HandlerErrorRetryBehavior(t *testing.T) { + he := nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "foo") + he.StackTrace = "stack" + he.RetryBehavior = nexus.HandlerErrorRetryBehaviorRetryable + failure, err := defaultFailureConverter.ErrorToFailure(he) + require.NoError(t, err) + // Verify that the original failure is retained. + he.OriginalFailure = &failure + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + + // Failure is rehydrated as failure error if it has no known type. + he.Cause = &nexus.FailureError{Failure: nexus.Failure{Message: "foo"}} + require.Equal(t, he, actual) +} + +func TestFailureConverter_OperationError(t *testing.T) { + cause := &nexus.FailureError{Failure: nexus.Failure{Message: "cause"}} + oe := nexus.NewOperationCanceledError("foo") + oe.StackTrace = "stack" + oe.Cause = cause + failure, err := defaultFailureConverter.ErrorToFailure(oe) + require.NoError(t, err) + // Verify that the original failure is retained. + oe.OriginalFailure = &failure + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + require.Equal(t, oe, actual) + + // Serialize again and verify the original failure is used. + oe.Cause = errors.New("should be ignored") + failure, err = defaultFailureConverter.ErrorToFailure(oe) + require.NoError(t, err) + actual, err = defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + // Reset back to the orignal cause before comparing. + oe.Cause = cause + require.Equal(t, oe, actual) +} + + +func TestDefaultFailureConverterArbitraryError(t *testing.T) { + sourceErr := errors.New("test") + conv := defaultFailureConverter + + f, err := conv.ErrorToFailure(sourceErr) + require.NoError(t, err) + convErr, err := conv.FailureToError(f) + require.NoError(t, err) + require.Equal(t, sourceErr.Error(), convErr.Error()) +} + +func TestDefaultFailureConverterFailureError(t *testing.T) { + sourceErr := &nexus.FailureError{ + Failure: nexus.Failure{ + Message: "test", + Metadata: map[string]string{"key": "value"}, + Details: []byte(`"details"`), + }, + } + conv := defaultFailureConverter + + f, err := conv.ErrorToFailure(sourceErr) + require.NoError(t, err) + convErr, err := conv.FailureToError(f) + require.NoError(t, err) + require.Equal(t, sourceErr, convErr) +} diff --git a/common/nexus/nexusrpc/server.go b/common/nexus/nexusrpc/server.go index 9119c1a21e..db3247408a 100644 --- a/common/nexus/nexusrpc/server.go +++ b/common/nexus/nexusrpc/server.go @@ -44,7 +44,7 @@ func applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer type baseHTTPHandler struct { logger *slog.Logger - failureConverter nexus.FailureConverter + failureConverter FailureConverter } type httpHandler struct { @@ -302,7 +302,7 @@ type HandlerOptions struct { Serializer nexus.Serializer // A [FailureConverter] to convert a [Failure] instance to and from an [error]. // Defaults to [DefaultFailureConverter]. - FailureConverter nexus.FailureConverter + FailureConverter FailureConverter } func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Request) { @@ -378,7 +378,7 @@ func NewHTTPHandler(options HandlerOptions) http.Handler { options.Serializer = nexus.DefaultSerializer() } if options.FailureConverter == nil { - options.FailureConverter = nexus.DefaultFailureConverter() + options.FailureConverter = DefaultFailureConverter() } handler := &httpHandler{ baseHTTPHandler: baseHTTPHandler{ diff --git a/common/nexus/nexusrpc/setup_test.go b/common/nexus/nexusrpc/setup_test.go index b910cff02e..5a4369b075 100644 --- a/common/nexus/nexusrpc/setup_test.go +++ b/common/nexus/nexusrpc/setup_test.go @@ -19,7 +19,7 @@ const testTimeout = time.Second * 5 const testService = "Ser/vic e" const getResultMaxTimeout = time.Millisecond * 300 -func setupCustom(t *testing.T, handler nexus.Handler, serializer nexus.Serializer, failureConverter nexus.FailureConverter) (ctx context.Context, client *nexusrpc.HTTPClient, teardown func()) { +func setupCustom(t *testing.T, handler nexus.Handler, serializer nexus.Serializer, failureConverter nexusrpc.FailureConverter) (ctx context.Context, client *nexusrpc.HTTPClient, teardown func()) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) httpHandler := nexusrpc.NewHTTPHandler(nexusrpc.HandlerOptions{ @@ -55,7 +55,7 @@ func setup(t *testing.T, handler nexus.Handler) (ctx context.Context, client *ne return setupCustom(t, handler, nil, nil) } -func setupForCompletion(t *testing.T, handler nexusrpc.CompletionHandler, serializer nexus.Serializer, failureConverter nexus.FailureConverter) (ctx context.Context, callbackURL string, teardown func()) { +func setupForCompletion(t *testing.T, handler nexusrpc.CompletionHandler, serializer nexus.Serializer, failureConverter nexusrpc.FailureConverter) (ctx context.Context, callbackURL string, teardown func()) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) httpHandler := nexusrpc.NewCompletionHTTPHandler(nexusrpc.CompletionHandlerOptions{ diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 595e51c8be..2e2e5b4e43 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -933,7 +933,7 @@ func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) if handlerErr.Message == "" { handlerErr.Message = "handler error" } - nf, err = nexus.DefaultFailureConverter().ErrorToFailure(handlerErr) + nf, err = nexusrpc.DefaultFailureConverter().ErrorToFailure(handlerErr) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index f5551ab39d..a7bcdb85d1 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/lib/pq v1.10.9 github.com/maruel/panicparse/v2 v2.4.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/nexus-rpc/sdk-go v0.5.2-0.20251217172131-63a8027ef960 + github.com/nexus-rpc/sdk-go v0.5.2-0.20260210000428-3d8ad6dc9742 github.com/olekukonko/tablewriter v0.0.5 github.com/olivere/elastic/v7 v7.0.32 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index f3fbd0028d..da1cdacc97 100644 --- a/go.sum +++ b/go.sum @@ -236,6 +236,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nexus-rpc/sdk-go v0.5.2-0.20251217172131-63a8027ef960 h1:ljAYqlX3IFBf7zqF8JGAgn21k7PBq4qyS8d45LcLDmQ= github.com/nexus-rpc/sdk-go v0.5.2-0.20251217172131-63a8027ef960/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= +github.com/nexus-rpc/sdk-go v0.5.2-0.20260210000428-3d8ad6dc9742 h1:nE/NqmspHqHkLh8rsSsQ/dLXOrbx4N7dU7GoGsO0DdI= +github.com/nexus-rpc/sdk-go v0.5.2-0.20260210000428-3d8ad6dc9742/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index 06c86180b6..e5b9f4129e 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -460,7 +460,7 @@ func (h *nexusHandler) StartOperation( oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - he, err := nexus.DefaultFailureConverter().FailureToError(nf) + he, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) if err != nil { oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") @@ -523,7 +523,7 @@ func (h *nexusHandler) StartOperation( oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - cause, err := nexus.DefaultFailureConverter().FailureToError(nf) + cause, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) if err != nil { oc.logger.Error("error converting Nexus failure to Nexus OperationError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") @@ -537,7 +537,7 @@ func (h *nexusHandler) StartOperation( Message: "operation error", Cause: cause, } - nf, err = nexus.DefaultFailureConverter().ErrorToFailure(opErr) + nf, err = nexusrpc.DefaultFailureConverter().ErrorToFailure(opErr) if err != nil { oc.logger.Error("error converting OperationError to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") @@ -677,7 +677,7 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - he, err := nexus.DefaultFailureConverter().FailureToError(nf) + he, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) if err != nil { oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") diff --git a/service/history/handler.go b/service/history/handler.go index 3288c55edd..e091557232 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -36,6 +36,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" commonnexus "go.temporal.io/server/common/nexus" + "go.temporal.io/server/common/nexus/nexusrpc" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/visibility/manager" @@ -2158,7 +2159,7 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse var opErr *nexus.OperationError if request.State != string(nexus.OperationStateSucceeded) { failure := commonnexus.ProtoFailureToNexusFailure(request.GetFailure()) - recvdErr, err := nexus.DefaultFailureConverter().FailureToError(failure) + recvdErr, err := nexusrpc.DefaultFailureConverter().FailureToError(failure) if err != nil { return nil, serviceerror.NewInvalidArgument("unable to convert failure to error") } @@ -2171,7 +2172,7 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse Message: "nexus operation completed unsuccessfully", Cause: recvdErr, } - origFailure, err := nexus.DefaultFailureConverter().ErrorToFailure(opErr) + origFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(opErr) if err != nil { return nil, serviceerror.NewInvalidArgument("unable to convert operation error to failure") } From eaac970af437576ace1740c1946f9ccc26cb9e09 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Tue, 10 Feb 2026 21:41:48 -0800 Subject: [PATCH 22/26] Refactor --- chasm/lib/callback/chasm_invocation.go | 31 +- chasm/lib/callback/executors_test.go | 78 ++--- chasm/lib/callback/nexus_invocation.go | 140 +++------ common/nexus/failure.go | 74 ++--- common/nexus/nexusrpc/api.go | 18 +- common/nexus/nexusrpc/cancel_test.go | 16 +- common/nexus/nexusrpc/client.go | 141 +++++---- common/nexus/nexusrpc/completion.go | 262 ++++++++-------- common/nexus/nexusrpc/completion_test.go | 136 ++++----- common/nexus/nexusrpc/failure_converter.go | 17 +- .../nexus/nexusrpc/failure_conveter_test.go | 10 +- common/nexus/nexusrpc/handle.go | 4 +- common/nexus/nexusrpc/server.go | 90 +++--- common/nexus/nexusrpc/start_test.go | 18 +- common/nexus/payload_serializer.go | 4 + components/callbacks/chasm_invocation.go | 31 +- components/callbacks/executors_test.go | 67 ++--- components/callbacks/nexus_invocation.go | 134 ++------- components/nexusoperations/executors.go | 10 +- components/nexusoperations/executors_test.go | 31 +- .../nexusoperations/frontend/handler.go | 154 +++------- service/frontend/nexus_handler.go | 64 ++-- service/frontend/nexus_http_handler.go | 73 ++--- service/history/handler.go | 10 +- .../history/workflow/mutable_state_impl.go | 121 ++++---- .../workflow_test/mutable_state_impl_test.go | 22 +- tests/callbacks_test.go | 2 +- tests/nexus_api_test.go | 37 +-- tests/nexus_workflow_test.go | 283 +++++++++--------- tests/xdc/nexus_request_forwarding_test.go | 64 ++-- tests/xdc/nexus_state_replication_test.go | 41 +-- 31 files changed, 929 insertions(+), 1254 deletions(-) diff --git a/chasm/lib/callback/chasm_invocation.go b/chasm/lib/callback/chasm_invocation.go index bddfb00184..b0158db486 100644 --- a/chasm/lib/callback/chasm_invocation.go +++ b/chasm/lib/callback/chasm_invocation.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "fmt" - "io" "github.com/google/uuid" "github.com/nexus-rpc/sdk-go/nexus" @@ -137,20 +136,12 @@ func (c chasmInvocation) getHistoryRequest( switch op := c.completion.(type) { case *nexusrpc.OperationCompletionSuccessful: - payloadBody, err := io.ReadAll(op.Reader) - if err != nil { - return nil, fmt.Errorf("failed to read payload: %v", err) - } - var payload *commonpb.Payload - if payloadBody != nil { - content := &nexus.Content{ - Header: op.Reader.Header, - Data: payloadBody, - } - err := commonnexus.PayloadSerializer.Deserialize(content, &payload) - if err != nil { - return nil, fmt.Errorf("failed to deserialize payload: %v", err) + if op.Result != nil { + var ok bool + payload, ok = op.Result.(*commonpb.Payload) + if !ok { + return nil, fmt.Errorf("invalid result, expected a payload, got: %T", op.Result) } } @@ -162,9 +153,17 @@ func (c chasmInvocation) getHistoryRequest( Completion: completion, } case *nexusrpc.OperationCompletionUnsuccessful: - apiFailure, err := commonnexus.NexusFailureToTemporalFailure(op.Failure) + failure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(op.Error) + if err != nil { + return nil, fmt.Errorf("failed to convert error to failure: %w", err) + } + // Unwrap the operation error since it's not meant to be sent for Temporal->Temporal completions. + if failure.Cause != nil { + failure = *failure.Cause + } + apiFailure, err := commonnexus.NexusFailureToTemporalFailure(failure) if err != nil { - return nil, fmt.Errorf("failed to convert failure type: %v", err) + return nil, fmt.Errorf("failed to convert failure type: %w", err) } req = &historyservice.CompleteNexusOperationChasmRequest{ diff --git a/chasm/lib/callback/executors_test.go b/chasm/lib/callback/executors_test.go index a43883ab20..8101b14048 100644 --- a/chasm/lib/callback/executors_test.go +++ b/chasm/lib/callback/executors_test.go @@ -81,7 +81,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { caller: func(r *http.Request) (*http.Response, error) { return &http.Response{StatusCode: 200, Body: http.NoBody}, nil }, - expectedMetricOutcome: "status:200", + expectedMetricOutcome: "success", assertOutcome: func(t *testing.T, cb *Callback, err error) { require.NoError(t, err) require.Equal(t, callbackspb.CALLBACK_STATUS_SUCCEEDED, cb.Status) @@ -104,7 +104,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { caller: func(r *http.Request) (*http.Response, error) { return &http.Response{StatusCode: 500, Body: http.NoBody}, nil }, - expectedMetricOutcome: "status:500", + expectedMetricOutcome: "handler-error:INTERNAL", assertOutcome: func(t *testing.T, cb *Callback, err error) { var destDownErr *queueserrors.DestinationDownError require.ErrorAs(t, err, &destDownErr) @@ -116,7 +116,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { caller: func(r *http.Request) (*http.Response, error) { return &http.Response{StatusCode: 400, Body: http.NoBody}, nil }, - expectedMetricOutcome: "status:400", + expectedMetricOutcome: "handler-error:BAD_REQUEST", assertOutcome: func(t *testing.T, cb *Callback, err error) { require.NoError(t, err) require.Equal(t, callbackspb.CALLBACK_STATUS_FAILED, cb.Status) @@ -210,8 +210,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { } // Create completion - completion, err := nexusrpc.NewOperationCompletionSuccessful(nil, nexusrpc.OperationCompletionSuccessfulOptions{}) - require.NoError(t, err) + completion := &nexusrpc.OperationCompletionSuccessful{} // Set up the CompletionSource field to return our mock completion root.SetRootComponent(&mockNexusCompletionGetterComponent{ @@ -398,15 +397,10 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { return client }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - CloseTime: dummyTime, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayloadBytes([]byte("result-data")), + CloseTime: dummyTime, + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -431,17 +425,13 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { return client }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ + return &nexusrpc.OperationCompletionUnsuccessful{ + Error: &nexus.OperationError{ State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "operation failed"}}, }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - CloseTime: dummyTime, - }, - ) - require.NoError(t, err) - return comp + CloseTime: dummyTime, + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -460,14 +450,9 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { return client }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayloadBytes([]byte("result-data")), + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -486,14 +471,9 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { return client }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayloadBytes([]byte("result-data")), + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -508,14 +488,9 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { return historyservicemock.NewMockHistoryServiceClient(ctrl) }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayloadBytes([]byte("result-data")), + } }(), headerValue: "invalid-base64!!!", assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -530,14 +505,9 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { return historyservicemock.NewMockHistoryServiceClient(ctrl) }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayloadBytes([]byte("result-data")), + } }(), headerValue: base64.RawURLEncoding.EncodeToString([]byte("not-valid-protobuf")), assertOutcome: func(t *testing.T, cb *Callback, err error) { diff --git a/chasm/lib/callback/nexus_invocation.go b/chasm/lib/callback/nexus_invocation.go index 73eb1285b9..ceee0f8a08 100644 --- a/chasm/lib/callback/nexus_invocation.go +++ b/chasm/lib/callback/nexus_invocation.go @@ -1,15 +1,10 @@ package callback import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "mime" + "errors" "net/http" "net/http/httptrace" - "slices" "time" "github.com/nexus-rpc/sdk-go/nexus" @@ -37,20 +32,6 @@ type nexusInvocation struct { attempt int32 } -func isRetryableHTTPResponse(response *http.Response) bool { - return response.StatusCode >= 500 || slices.Contains(retryable4xxErrorTypes, response.StatusCode) -} - -func outcomeTag(callCtx context.Context, response *http.Response, callErr error) string { - if callErr != nil { - if callCtx.Err() != nil { - return "request-timeout" - } - return "unknown-error" - } - return fmt.Sprintf("status:%d", response.StatusCode) -} - func (n nexusInvocation) WrapError(result invocationResult, err error) error { if retry, ok := result.(invocationResultRetry); ok { return queueserrors.NewDestinationDownError(retry.err.Error(), err) @@ -80,107 +61,56 @@ func (n nexusInvocation) Invoke( } } - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, n.nexus.Url, n.completion) - if err != nil { - return invocationResultFail{queueserrors.NewUnprocessableTaskError( - fmt.Sprintf("failed to construct Nexus request: %v", err), - )} - } - if request.Header == nil { - request.Header = make(http.Header) - } - for k, v := range n.nexus.Header { - request.Header.Set(k, v) - } - - caller := e.httpCallerProvider(queuescommon.NamespaceIDAndDestination{ - NamespaceID: ns.ID().String(), - Destination: taskAttr.Destination, + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + HTTPCaller: e.httpCallerProvider(queuescommon.NamespaceIDAndDestination{ + NamespaceID: ns.ID().String(), + Destination: taskAttr.Destination, + }), + Serializer: commonnexus.PayloadSerializer, }) // Make the call and record metrics. startTime := time.Now() - response, err := caller(request) + + for k, v := range n.nexus.Header { + n.completion.SetHeader(k, v) + } + err := client.CompleteOperation(ctx, n.nexus.Url, n.completion) namespaceTag := metrics.NamespaceTag(ns.Name().String()) destTag := metrics.DestinationTag(taskAttr.Destination) - statusCodeTag := metrics.OutcomeTag(outcomeTag(ctx, response, err)) - e.metricsHandler.Counter(RequestCounter.Name()).Record(1, namespaceTag, destTag, statusCodeTag) - e.metricsHandler.Timer(RequestLatencyHistogram.Name()).Record(time.Since(startTime), namespaceTag, destTag, statusCodeTag) + outcomeTag := metrics.OutcomeTag(outcomeTag(ctx, err)) + e.metricsHandler.Counter(RequestCounter.Name()).Record(1, namespaceTag, destTag, outcomeTag) + e.metricsHandler.Timer(RequestLatencyHistogram.Name()).Record(time.Since(startTime), namespaceTag, destTag, outcomeTag) - if err != nil { - e.logger.Error("Callback request failed with error", tag.Error(err)) - return invocationResultRetry{err: err} - } - - if response.StatusCode >= 200 && response.StatusCode < 300 { - // Body is not read but should be discarded to keep the underlying TCP connection alive. - // Just in case something unexpected happens while discarding or closing the body, - // propagate errors to the machine. - if _, err = io.Copy(io.Discard, response.Body); err == nil { - if err = response.Body.Close(); err != nil { - e.logger.Error("Callback request failed with error", tag.Error(err)) - return invocationResultRetry{err: err} - } - } + if err == nil { return invocationResultOK{} } - - retryable := isRetryableHTTPResponse(response) - err = readHandlerErrFromResponse(response, e.logger) - e.logger.Error("Callback request failed", tag.Error(err), tag.String("status", response.Status), tag.Bool("retryable", retryable)) + retryable := isRetryableCallError(err) + e.logger.Error("Callback request failed", tag.Error(err), tag.Bool("retryable", retryable)) if retryable { - return invocationResultRetry{err: err} + return invocationResultRetry{err} } return invocationResultFail{err} } -// Reads and replaces the http response body and attempts to deserialize it into a Nexus failure. If successful, -// returns a nexus.HandlerError with the deserialized failure as the Cause. If there is an error reading the body or -// during deserialization, returns a nexus.HandlerError with a generic Cause based on response status. -// TODO: This logic is duplicated in the frontend handler for forwarded requests. Eventually it should live in the Nexus SDK. -func readHandlerErrFromResponse(response *http.Response, logger log.Logger) error { - handlerErr := &nexus.HandlerError{ - Type: commonnexus.HandlerErrorTypeFromHTTPStatus(response.StatusCode), - Cause: fmt.Errorf("request failed with: %v", response.Status), - } - - body, err := readAndReplaceBody(response) - if err != nil { - logger.Error("Error reading response body for non-ok callback request", tag.Error(err), tag.String("status", response.Status)) - return err - } - - if !isMediaTypeJSON(response.Header.Get("Content-Type")) { - logger.Error("received invalid content-type header for non-OK HTTP response to CompleteOperation request", tag.Value(response.Header.Get("Content-Type"))) - return handlerErr - } - - var failure nexus.Failure - err = json.Unmarshal(body, &failure) - if err != nil { - logger.Error("failed to deserialize Nexus Failure from HTTP response to CompleteOperation request", tag.Error(err)) - return handlerErr +func isRetryableCallError(err error) bool { + var handlerError *nexus.HandlerError + if errors.As(err, &handlerError) { + return handlerError.Retryable() } - - handlerErr.Cause = &nexus.FailureError{Failure: failure} - return handlerErr -} - -// readAndReplaceBody reads the response body in its entirety and closes it, and then replaces the original response -// body with an in-memory buffer. -// The body is replaced even when there was an error reading the entire body. -func readAndReplaceBody(response *http.Response) ([]byte, error) { - responseBody := response.Body - body, err := io.ReadAll(responseBody) - _ = responseBody.Close() - response.Body = io.NopCloser(bytes.NewReader(body)) - return body, err + return true } -func isMediaTypeJSON(contentType string) bool { - if contentType == "" { - return false +func outcomeTag(callCtx context.Context, callErr error) string { + if callErr != nil { + if callCtx.Err() != nil { + return "request-timeout" + } + var handlerErr *nexus.HandlerError + if errors.As(callErr, &handlerErr) { + return "handler-error:" + string(handlerErr.Type) + } + return "unknown-error" } - mediaType, _, err := mime.ParseMediaType(contentType) - return err == nil && mediaType == "application/json" + return "success" } diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 5426d5ef9d..c5543f1b83 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "sync/atomic" @@ -332,16 +331,16 @@ func ConvertGRPCError(err error, exposeDetails bool) error { errMessage = "bad request" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeBadRequest, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeBadRequest, + Message: errMessage, } case codes.Aborted, codes.Unavailable: if !exposeDetails { errMessage = "service unavailable" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnavailable, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeUnavailable, + Message: errMessage, } case codes.Canceled: // TODO: This should have a different status code (e.g. 499 which is semi standard but not supported by nexus). @@ -350,72 +349,72 @@ func ConvertGRPCError(err error, exposeDetails bool) error { errMessage = "canceled" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeInternal, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeInternal, + Message: errMessage, } case codes.DataLoss, codes.Internal, codes.Unknown: if !exposeDetails { errMessage = "internal error" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeInternal, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeInternal, + Message: errMessage, } case codes.Unauthenticated: if !exposeDetails { errMessage = "authentication failed" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnauthenticated, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeUnauthenticated, + Message: errMessage, } case codes.PermissionDenied: if !exposeDetails { errMessage = "permission denied" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnauthorized, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeUnauthorized, + Message: errMessage, } case codes.NotFound: if !exposeDetails { errMessage = "not found" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeNotFound, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeNotFound, + Message: errMessage, } case codes.ResourceExhausted: if !exposeDetails { errMessage = "resource exhausted" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeResourceExhausted, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeResourceExhausted, + Message: errMessage, } case codes.Unimplemented: if !exposeDetails { errMessage = "not implemented" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeNotImplemented, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeNotImplemented, + Message: errMessage, } case codes.DeadlineExceeded: if !exposeDetails { errMessage = "request timeout" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUpstreamTimeout, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeUpstreamTimeout, + Message: errMessage, } case codes.OK: return nil } if !exposeDetails { return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeInternal, - Cause: errors.New("internal error"), + Type: nexus.HandlerErrorTypeInternal, + Message: "internal error", } } // Let the nexus SDK handle this for us (log and convert to an internal error). @@ -424,32 +423,7 @@ func ConvertGRPCError(err error, exposeDetails bool) error { func AdaptAuthorizeError(permissionDeniedError *serviceerror.PermissionDenied) error { if permissionDeniedError.Reason != "" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied: %s", permissionDeniedError.Reason) - } - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied") -} - -func HandlerErrorTypeFromHTTPStatus(statusCode int) nexus.HandlerErrorType { - switch statusCode { - case http.StatusBadRequest: - return nexus.HandlerErrorTypeBadRequest - case http.StatusUnauthorized: - return nexus.HandlerErrorTypeUnauthenticated - case http.StatusForbidden: - return nexus.HandlerErrorTypeUnauthorized - case http.StatusNotFound: - return nexus.HandlerErrorTypeNotFound - case http.StatusTooManyRequests: - return nexus.HandlerErrorTypeResourceExhausted - case http.StatusInternalServerError: - return nexus.HandlerErrorTypeInternal - case http.StatusNotImplemented: - return nexus.HandlerErrorTypeNotImplemented - case http.StatusServiceUnavailable: - return nexus.HandlerErrorTypeUnavailable - case nexus.StatusUpstreamTimeout: - return nexus.HandlerErrorTypeUpstreamTimeout - default: - return nexus.HandlerErrorTypeInternal + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied: %s", permissionDeniedError.Reason) } + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied") } diff --git a/common/nexus/nexusrpc/api.go b/common/nexus/nexusrpc/api.go index ae20c54f61..5743725633 100644 --- a/common/nexus/nexusrpc/api.go +++ b/common/nexus/nexusrpc/api.go @@ -33,8 +33,8 @@ const contentTypeJSON = "application/json" // Query param for passing a callback URL. const queryCallbackURL = "callback" -// HTTP status code for failed operation responses. -const statusOperationFailed = http.StatusFailedDependency +// HTTP status code for unsuccessful (failed or canceled) operation responses. +const statusOperationUnsuccessful = http.StatusFailedDependency func isMediaTypeJSON(contentType string) bool { if contentType == "" { @@ -273,3 +273,17 @@ func ParseDuration(value string) (time.Duration, error) { func FormatDuration(d time.Duration) string { return strconv.FormatInt(d.Milliseconds(), 10) + "ms" } + +// MarkAsWrapperError adds the "unwrap-error" metadata to the original failure of the given OperationError, which +// signals the Temporal codebase to unwrap the underlying failure as the failure cause. +// This is used as a shim for Temporal->Temporal calls. Temporal already has a Failure type that represents an +// OperationError and does not need to record this wrapper error. +func MarkAsWrapperError(failureConverter FailureConverter, opErr *nexus.OperationError) error { + originalFailure, err := failureConverter.ErrorToFailure(opErr) + if err != nil { + return err + } + originalFailure.Metadata["unwrap-error"] = "true" + opErr.OriginalFailure = &originalFailure + return nil +} diff --git a/common/nexus/nexusrpc/cancel_test.go b/common/nexus/nexusrpc/cancel_test.go index c33e13d109..8ae2cd32e3 100644 --- a/common/nexus/nexusrpc/cancel_test.go +++ b/common/nexus/nexusrpc/cancel_test.go @@ -23,19 +23,19 @@ func (h *asyncWithCancelHandler) StartOperation(ctx context.Context, service, op func (h *asyncWithCancelHandler) CancelOperation(ctx context.Context, service, operation, token string, options nexus.CancelOperationOptions) error { if service != testService { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected service: %s", service) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected service: %s", service) } if operation != "f/o/o" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to be 'foo', got: %s", operation) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to be 'foo', got: %s", operation) } if token != "a/sync" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation ID to be 'async', got: %s", token) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation ID to be 'async', got: %s", token) } if h.expectHeader && options.Header.Get("foo") != "bar" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' request header") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' request header") } if options.Header.Get("User-Agent") != "temporalio/server" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) } return nil } @@ -78,14 +78,14 @@ func (h *echoTimeoutAsyncWithCancelHandler) StartOperation(ctx context.Context, func (h *echoTimeoutAsyncWithCancelHandler) CancelOperation(ctx context.Context, service, operation, token string, options nexus.CancelOperationOptions) error { deadline, set := ctx.Deadline() if h.expectedTimeout > 0 && !set { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to have timeout set but context has no deadline") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to have timeout set but context has no deadline") } if h.expectedTimeout <= 0 && set { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to have no timeout but context has deadline set") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to have no timeout but context has deadline set") } timeout := time.Until(deadline) if timeout > h.expectedTimeout { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation has timeout (%s) greater than expected (%s)", timeout.String(), h.expectedTimeout.String()) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation has timeout (%s) greater than expected (%s)", timeout.String(), h.expectedTimeout.String()) } return nil } diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index 5985ccaf43..aad46be6db 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -74,6 +74,66 @@ func newUnexpectedResponseError(message string, response *http.Response, body [] } } +type baseHTTPClient struct { + // A function for making HTTP requests. + // Defaults to [http.DefaultClient.Do]. + httpCaller func(*http.Request) (*http.Response, error) + // A [serializer] to customize client serialization behavior. + // By default the client handles JSONables, byte slices, and nil. + serializer nexus.Serializer + // A [failureConverter] to convert a [Failure] instance to and from an [error]. Defaults to + // [DefaultFailureConverter]. + failureConverter FailureConverter +} + +func (c *baseHTTPClient) failureFromResponse(response *http.Response, body []byte) (nexus.Failure, error) { + if !isMediaTypeJSON(response.Header.Get("Content-Type")) { + return nexus.Failure{}, newUnexpectedResponseError(fmt.Sprintf("invalid response content type: %q", response.Header.Get("Content-Type")), response, body) + } + var failure nexus.Failure + err := json.Unmarshal(body, &failure) + return failure, err +} + +func (c *baseHTTPClient) defaultErrorFromResponse(response *http.Response, body []byte, cause error) error { + errorType, err := httpStatusCodeToHandlerErrorType(response) + if err != nil { + // TODO(bergundy): optimization - use the provided cause, it's already a deserialized failure. + return newUnexpectedResponseError(err.Error(), response, body) + } + handlerErr := &nexus.HandlerError{ + Type: errorType, + // For compatibility with older servers. + RetryBehavior: retryBehaviorFromHeader(response.Header), + Cause: cause, + } + + // Ensure we the original failure is available, the calling code expects it. + originalFailure, err := c.failureConverter.ErrorToFailure(handlerErr) + if err != nil { + return newUnexpectedResponseError("failed to construct handler error from response: "+err.Error(), response, body) + } + handlerErr.OriginalFailure = &originalFailure + return handlerErr +} + +// BestEffortHandlerErrorFromResponse attempts to read a handler error from the response, but falls back to a default error. +// This method is exposed as a workaround because the functionality is required for completion. +func (c *baseHTTPClient) bestEffortHandlerErrorFromResponse(response *http.Response, body []byte) error { + failure, err := c.failureFromResponse(response, body) + if err != nil { + return c.defaultErrorFromResponse(response, body, nil) + } + convErr, err := c.failureConverter.FailureToError(failure) + if err != nil { + return newUnexpectedResponseError(fmt.Sprintf("failed to convert Failure to error: %s", err.Error()), response, body) + } + if _, ok := convErr.(*nexus.HandlerError); !ok { + convErr = c.defaultErrorFromResponse(response, body, convErr) + } + return convErr +} + // An HTTPClient makes Nexus service requests as defined in the [Nexus HTTP API]. // // It can start a new operation and get an [OperationHandle] to an existing, asynchronous operation. @@ -85,8 +145,8 @@ func newUnexpectedResponseError(message string, response *http.Response, body [] // // [Nexus HTTP API]: https://github.com/nexus-rpc/api type HTTPClient struct { - // The options this client was created with after applying defaults. - options HTTPClientOptions + baseHTTPClient + service string serviceBaseURL *url.URL } @@ -119,8 +179,13 @@ func NewHTTPClient(options HTTPClientOptions) (*HTTPClient, error) { } return &HTTPClient{ - options: options, + baseHTTPClient: baseHTTPClient{ + serializer: options.Serializer, + failureConverter: options.FailureConverter, + httpCaller: options.HTTPCaller, + }, serviceBaseURL: baseURL, + service: options.Service, }, nil } @@ -174,7 +239,7 @@ func (c *HTTPClient) StartOperation( content, ok := input.(*nexus.Content) if !ok { var err error - content, err = c.options.Serializer.Serialize(input) + content, err = c.serializer.Serialize(input) if err != nil { return nil, err } @@ -192,7 +257,7 @@ func (c *HTTPClient) StartOperation( } } - url := c.serviceBaseURL.JoinPath(url.PathEscape(c.options.Service), url.PathEscape(operation)) + url := c.serviceBaseURL.JoinPath(url.PathEscape(c.service), url.PathEscape(operation)) if options.CallbackURL != "" { q := url.Query() @@ -222,7 +287,7 @@ func (c *HTTPClient) StartOperation( // If this request is handled by a newer server that supports Nexus failure serialization, trigger that behavior. request.Header.Set("temporal-nexus-failure-support", "true") - response, err := c.options.HTTPCaller(request) + response, err := c.httpCaller(request) if err != nil { return nil, err } @@ -249,7 +314,7 @@ func (c *HTTPClient) StartOperation( if response.StatusCode == http.StatusOK { return &ClientStartOperationResponse[*nexus.LazyValue]{ Successful: nexus.NewLazyValue( - c.options.Serializer, + c.serializer, &nexus.Reader{ ReadCloser: response.Body, Header: prefixStrippedHTTPHeaderToNexusHeader(response.Header, "content-"), @@ -282,13 +347,13 @@ func (c *HTTPClient) StartOperation( Pending: handle, Links: links, }, nil - case statusOperationFailed: + case statusOperationUnsuccessful: failure, err := c.failureFromResponse(response, body) if err != nil { return nil, err } - wireErr, err := c.options.FailureConverter.FailureToError(failure) + wireErr, err := c.failureConverter.FailureToError(failure) if err != nil { return nil, err } @@ -304,15 +369,9 @@ func (c *HTTPClient) StartOperation( Message: "nexus operation completed unsuccessfully", Cause: wireErr, } - // Ensure we the original failure is available, the calling code expects it. - originalFailure, err := c.options.FailureConverter.ErrorToFailure(opErr) - if err != nil { - return nil, err + if err := MarkAsWrapperError(c.failureConverter, opErr); err != nil { + return nil, fmt.Errorf("failed to mark operation error as wrapper error: %w", err) } - // Special header to signal that this error should be unwrapped by the completion handler as old servers will send - // back empty wrappers for underlying causes. - originalFailure.Metadata["unwrap-error"] = "true" - opErr.OriginalFailure = &originalFailure wireErr = opErr } @@ -367,54 +426,6 @@ func operationInfoFromResponse(response *http.Response, body []byte) (*nexus.Ope return &info, nil } -func (c *HTTPClient) failureFromResponse(response *http.Response, body []byte) (nexus.Failure, error) { - if !isMediaTypeJSON(response.Header.Get("Content-Type")) { - return nexus.Failure{}, newUnexpectedResponseError(fmt.Sprintf("invalid response content type: %q", response.Header.Get("Content-Type")), response, body) - } - var failure nexus.Failure - err := json.Unmarshal(body, &failure) - return failure, err -} - -func (c *HTTPClient) defaultErrorFromResponse(response *http.Response, body []byte, cause error) error { - errorType, err := httpStatusCodeToHandlerErrorType(response) - if err != nil { - // TODO(bergundy): optimization - use the provided cause, it's already a deserialized failure. - return newUnexpectedResponseError(err.Error(), response, body) - } - statusText := strings.TrimPrefix(response.Status, fmt.Sprintf("%d ", response.StatusCode)) - handlerErr := &nexus.HandlerError{ - Type: errorType, - Message: statusText, - // For compatibility with older servers. - RetryBehavior: retryBehaviorFromHeader(response.Header), - Cause: cause, - } - - // Ensure we the original failure is available, the calling code expects it. - originalFailure, err := c.options.FailureConverter.ErrorToFailure(handlerErr) - if err != nil { - return newUnexpectedResponseError("failed to construct handler error from response: "+err.Error(), response, body) - } - handlerErr.OriginalFailure = &originalFailure - return handlerErr -} - -func (c *HTTPClient) bestEffortHandlerErrorFromResponse(response *http.Response, body []byte) error { - failure, err := c.failureFromResponse(response, body) - if err != nil { - return c.defaultErrorFromResponse(response, body, nil) - } - convErr, err := c.options.FailureConverter.FailureToError(failure) - if err != nil { - return newUnexpectedResponseError(fmt.Sprintf("failed to convert Failure to error: %s", err.Error()), response, body) - } - if _, ok := convErr.(*nexus.HandlerError); !ok { - convErr = c.defaultErrorFromResponse(response, body, convErr) - } - return convErr -} - func httpStatusCodeToHandlerErrorType(response *http.Response) (nexus.HandlerErrorType, error) { switch response.StatusCode { case http.StatusBadRequest: diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index 926a00de6e..b28569c3cd 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -6,44 +6,100 @@ import ( "encoding/json" "io" "log/slog" - "maps" "net/http" - "strconv" "time" "github.com/nexus-rpc/sdk-go/nexus" ) -// NewCompletionHTTPRequest creates an HTTP request that delivers an operation completion to a given URL. -func NewCompletionHTTPRequest(ctx context.Context, url string, completion OperationCompletion) (*http.Request, error) { +// CompletionHTTPClient is a client for sending Nexus operation completion callbacks via HTTP. +type CompletionHTTPClient struct { + baseHTTPClient +} + +// CompletionHTTPClientOptions are options for [NewCompletionHTTPClient]. +type CompletionHTTPClientOptions struct { + // A function for making HTTP requests. + // Defaults to [http.DefaultClient.Do]. + HTTPCaller func(*http.Request) (*http.Response, error) + // A [serializer] to customize client serialization behavior. + // By default the client handles JSONables, byte slices, and nil. + Serializer nexus.Serializer + // A [failureConverter] to convert a [Failure] instance to and from an [error]. Defaults to + // [DefaultFailureConverter]. + FailureConverter FailureConverter +} + +// NewCompletionHTTPClient constructs a [CompletionHTTPClient] from given options for sending Nexus operation completion +// callbacks via HTTP. +func NewCompletionHTTPClient(options CompletionHTTPClientOptions) *CompletionHTTPClient { + if options.HTTPCaller == nil { + options.HTTPCaller = http.DefaultClient.Do + } + if options.Serializer == nil { + options.Serializer = nexus.DefaultSerializer() + } + if options.FailureConverter == nil { + options.FailureConverter = DefaultFailureConverter() + } + return &CompletionHTTPClient{ + baseHTTPClient: baseHTTPClient{ + httpCaller: options.HTTPCaller, + serializer: options.Serializer, + failureConverter: options.FailureConverter, + }, + } +} + +// CompleteOperation sends a completion callback for a Nexus operation to the given URL with the given completion details. +func (c *CompletionHTTPClient) CompleteOperation(ctx context.Context, url string, completion OperationCompletion) error { httpReq, err := http.NewRequestWithContext(ctx, "POST", url, nil) if err != nil { - return nil, err + return err } - if err := completion.applyToHTTPRequest(httpReq); err != nil { - return nil, err + if err := completion.applyToHTTPRequest(c, httpReq); err != nil { + return err + } + + response, err := c.httpCaller(httpReq) + if err != nil { + return err } - httpReq.Header.Set(headerUserAgent, userAgent) - return httpReq, nil + if response.StatusCode >= 200 && response.StatusCode < 300 { + // Body is not read but should be discarded to keep the underlying TCP connection alive. + // Just in case something unexpected happens while discarding or closing the body, + // propagate errors to the machine. + if _, err = io.Copy(io.Discard, response.Body); err == nil { + if err = response.Body.Close(); err != nil { + return err + } + } + return nil + } + + body, err := readAndReplaceBody(response) + if err != nil { + return err + } + + return c.bestEffortHandlerErrorFromResponse(response, body) } -// OperationCompletion is input for [NewCompletionHTTPRequest]. +// OperationCompletion is input for CompleteOperation. // It has two implementations: [OperationCompletionSuccessful] and [OperationCompletionUnsuccessful]. type OperationCompletion interface { - applyToHTTPRequest(*http.Request) error + SetHeader(key, value string) + applyToHTTPRequest(*CompletionHTTPClient, *http.Request) error } -// OperationCompletionSuccessful is input for [NewCompletionHTTPRequest], used to deliver successful operation results. +// OperationCompletionSuccessful is input for CompleteOperation, used to deliver successful operation results. type OperationCompletionSuccessful struct { // Header to send in the completion request. // Note that this is a Nexus header, not an HTTP header. Header nexus.Header - - // A [Reader] that may be directly set on the completion or constructed when instantiating via - // [NewOperationCompletionSuccessful]. - // Automatically closed when the completion is delivered. - Reader *nexus.Reader + // Result to deliver with the completion. Uses the client's serializer to serialize the result into the request body. + Result any // OperationToken is the unique token for this operation. Used when a completion callback is received before a // started response. OperationToken string @@ -55,70 +111,44 @@ type OperationCompletionSuccessful struct { Links []nexus.Link } -// OperationCompletionSuccessfulOptions are options for [NewOperationCompletionSuccessful]. -type OperationCompletionSuccessfulOptions struct { - // Optional serializer for the result. Defaults to the SDK's default Serializer, which handles JSONables, byte - // slices and nils. - Serializer nexus.Serializer - // OperationToken is the unique token for this operation. Used when a completion callback is received before a - // started response. - OperationToken string - // StartTime is the time the operation started. Used when a completion callback is received before a started response. - StartTime time.Time - // CloseTime is the time the operation completed. Used when a completion callback is received before a started response. - CloseTime time.Time - // Links are used to link back to the operation when a completion callback is received before a started response. - Links []nexus.Link +func (c *OperationCompletionSuccessful) SetHeader(key, value string) { + if c.Header == nil { + c.Header = make(nexus.Header, 1) + } + c.Header[key] = value } -// NewOperationCompletionSuccessful constructs an [OperationCompletionSuccessful] from a given result. -func NewOperationCompletionSuccessful(result any, options OperationCompletionSuccessfulOptions) (*OperationCompletionSuccessful, error) { - reader, ok := result.(*nexus.Reader) +func (c *OperationCompletionSuccessful) applyToHTTPRequest(cc *CompletionHTTPClient, request *http.Request) error { + reader, ok := c.Result.(*nexus.Reader) if !ok { - content, ok := result.(*nexus.Content) + content, ok := c.Result.(*nexus.Content) if !ok { - serializer := options.Serializer - if serializer == nil { - serializer = nexus.DefaultSerializer() - } var err error - content, err = serializer.Serialize(result) + content, err = cc.serializer.Serialize(c.Result) if err != nil { - return nil, err + return err } } - header := maps.Clone(content.Header) - if header == nil { - header = make(nexus.Header, 1) - } - header["length"] = strconv.Itoa(len(content.Data)) + request.ContentLength = int64(len(content.Data)) reader = &nexus.Reader{ - Header: header, + Header: content.Header, ReadCloser: io.NopCloser(bytes.NewReader(content.Data)), } } - - return &OperationCompletionSuccessful{ - Header: make(nexus.Header), - Reader: reader, - OperationToken: options.OperationToken, - StartTime: options.StartTime, - CloseTime: options.CloseTime, - Links: options.Links, - }, nil -} - -func (c *OperationCompletionSuccessful) applyToHTTPRequest(request *http.Request) error { if request.Header == nil { - request.Header = make(http.Header, len(c.Header)+len(c.Reader.Header)+1) // +1 for headerOperationState + request.Header = make(http.Header, len(c.Header)+len(reader.Header)+1) // +1 for headerOperationState } - if c.Reader.Header != nil { - addContentHeaderToHTTPHeader(c.Reader.Header, request.Header) + if reader.Header != nil { + addContentHeaderToHTTPHeader(reader.Header, request.Header) } if c.Header != nil { addNexusHeaderToHTTPHeader(c.Header, request.Header) } + if c.Header.Get(headerUserAgent) == "" { + request.Header.Set(headerUserAgent, userAgent) + } + request.Header.Set(headerOperationState, string(nexus.OperationStateSucceeded)) if c.Header.Get(nexus.HeaderOperationToken) == "" && c.OperationToken != "" { @@ -136,7 +166,7 @@ func (c *OperationCompletionSuccessful) applyToHTTPRequest(request *http.Request } } - request.Body = c.Reader.ReadCloser + request.Body = reader.ReadCloser return nil } @@ -146,8 +176,6 @@ type OperationCompletionUnsuccessful struct { // Header to send in the completion request. // Note that this is a Nexus header, not an HTTP header. Header nexus.Header - // State of the operation, should be failed or canceled. - State nexus.OperationState // OperationToken is the unique token for this operation. Used when a completion callback is received before a // started response. OperationToken string @@ -157,61 +185,39 @@ type OperationCompletionUnsuccessful struct { CloseTime time.Time // Links are used to link back to the operation when a completion callback is received before a started response. Links []nexus.Link - // Failure object to send with the completion. - Failure nexus.Failure + // Error to send with the completion. + Error *nexus.OperationError } -// OperationCompletionUnsuccessfulOptions are options for [NewOperationCompletionUnsuccessful]. -type OperationCompletionUnsuccessfulOptions struct { - // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to - // [DefaultFailureConverter]. - FailureConverter FailureConverter - // OperationID is the unique ID for this operation. Used when a completion callback is received before a started response. - // - // Deprecated: Use OperationToken instead. - OperationID string - // OperationToken is the unique token for this operation. Used when a completion callback is received before a - // started response. - OperationToken string - // StartTime is the time the operation started. Used when a completion callback is received before a started response. - StartTime time.Time - // CloseTime is the time the operation completed. This may be different from the time the completion callback is delivered. - CloseTime time.Time - // Links are used to link back to the operation when a completion callback is received before a started response. - Links []nexus.Link +func (c *OperationCompletionUnsuccessful) SetHeader(key, value string) { + if c.Header == nil { + c.Header = make(nexus.Header, 1) + } + c.Header[key] = value } -// NewOperationCompletionUnsuccessful constructs an [OperationCompletionUnsuccessful] from a given error. -func NewOperationCompletionUnsuccessful(opErr *nexus.OperationError, options OperationCompletionUnsuccessfulOptions) (*OperationCompletionUnsuccessful, error) { - if options.FailureConverter == nil { - options.FailureConverter = DefaultFailureConverter() - } - failure, err := options.FailureConverter.ErrorToFailure(opErr) +func (c *OperationCompletionUnsuccessful) applyToHTTPRequest(cc *CompletionHTTPClient, request *http.Request) error { + failure, err := cc.failureConverter.ErrorToFailure(c.Error) if err != nil { - return nil, err - } - - return &OperationCompletionUnsuccessful{ - Header: make(nexus.Header), - State: opErr.State, - Failure: failure, - OperationToken: options.OperationToken, - StartTime: options.StartTime, - CloseTime: options.CloseTime, - Links: options.Links, - }, nil -} + return err + } + // Backwards compatibility: if the failure has a cause, unwrap it to maintain the behavior as older servers. + if failure.Cause != nil { + failure = *failure.Cause + } -func (c *OperationCompletionUnsuccessful) applyToHTTPRequest(request *http.Request) error { if request.Header == nil { request.Header = make(http.Header, len(c.Header)+2) // +2 for headerOperationState and content-type } if c.Header != nil { addNexusHeaderToHTTPHeader(c.Header, request.Header) } + if c.Header.Get(headerUserAgent) == "" { + request.Header.Set(headerUserAgent, userAgent) + } // Set the operation state header for backwards compatibility. - request.Header.Set(headerOperationState, string(c.State)) + request.Header.Set(headerOperationState, string(c.Error.State)) request.Header.Set("Content-Type", contentTypeJSON) if c.Header.Get(nexus.HeaderOperationToken) == "" && c.OperationToken != "" { @@ -229,7 +235,7 @@ func (c *OperationCompletionUnsuccessful) applyToHTTPRequest(request *http.Reque } } - b, err := json.Marshal(c.Failure) + b, err := json.Marshal(failure) if err != nil { return err } @@ -281,7 +287,7 @@ type CompletionHandlerOptions struct { } type completionHTTPHandler struct { - baseHTTPHandler + BaseHTTPHandler options CompletionHandlerOptions } @@ -295,59 +301,55 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h if startTimeHeader := request.Header.Get(headerOperationStartTime); startTimeHeader != "" { var parseTimeErr error if completion.StartTime, parseTimeErr = http.ParseTime(startTimeHeader); parseTimeErr != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse operation start time header")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse operation start time header")) return } } if closeTimeHeader := request.Header.Get(headerOperationCloseTime); closeTimeHeader != "" { var parseTimeErr error if completion.CloseTime, parseTimeErr = unmarshalTimestamp(closeTimeHeader); parseTimeErr != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse operation close time header")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse operation close time header")) return } } var decodeErr error if completion.Links, decodeErr = getLinksFromHeader(request.Header); decodeErr != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode links from request headers")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode links from request headers")) return } switch completion.State { case nexus.OperationStateFailed, nexus.OperationStateCanceled: if !isMediaTypeJSON(request.Header.Get("Content-Type")) { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request content type: %q", request.Header.Get("Content-Type"))) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request content type: %q", request.Header.Get("Content-Type"))) return } var failure nexus.Failure b, err := io.ReadAll(request.Body) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to read Failure from request body")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to read Failure from request body")) return } if err := json.Unmarshal(b, &failure); err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to read Failure from request body")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode Failure from request body")) return } - completionErr, err := h.failureConverter.FailureToError(failure) + completionErr, err := h.FailureConverter.FailureToError(failure) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode failure from request body")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode Failure from request body")) return } opErr, ok := completionErr.(*nexus.OperationError) if !ok { // Backwards compatibility: wrap non-OperationError errors in an OperationError with the appropriate state. completion.Error = &nexus.OperationError{ - // Not adding Message here to ensure the failure is unwrapped (behavior of the Nexus failure converter to - // maintain backwards compatibility). - // After server version 1.31.0 is out, we can add the message back. - State: completion.State, - Cause: completionErr, + Message: "nexus operation completed unsuccessfully", + State: completion.State, + Cause: completionErr, } - originalFailure, err := h.failureConverter.ErrorToFailure(completion.Error) - if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode failure from request body")) + if err := MarkAsWrapperError(h.FailureConverter, completion.Error); err != nil { + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode Failure from request body")) return } - completion.Error.OriginalFailure = &originalFailure } else { completion.Error = opErr } @@ -360,11 +362,11 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h }, ) default: - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request operation state: %q", completion.State)) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request operation state: %q", completion.State)) return } if err := h.options.Handler.CompleteOperation(ctx, &completion); err != nil { - h.writeFailure(writer, err) + h.WriteFailure(writer, request, err) } } @@ -381,9 +383,9 @@ func NewCompletionHTTPHandler(options CompletionHandlerOptions) http.Handler { } return &completionHTTPHandler{ options: options, - baseHTTPHandler: baseHTTPHandler{ - logger: options.Logger, - failureConverter: options.FailureConverter, + BaseHTTPHandler: BaseHTTPHandler{ + Logger: options.Logger, + FailureConverter: options.FailureConverter, }, } } diff --git a/common/nexus/nexusrpc/completion_test.go b/common/nexus/nexusrpc/completion_test.go index a566705749..dc075b4f47 100644 --- a/common/nexus/nexusrpc/completion_test.go +++ b/common/nexus/nexusrpc/completion_test.go @@ -3,8 +3,6 @@ package nexusrpc_test import ( "context" "errors" - "io" - "net/http" "net/url" "testing" "time" @@ -33,28 +31,28 @@ func validateExpectedTime(expected, actual time.Time, resolution time.Duration) func (h *successfulCompletionHandler) CompleteOperation(ctx context.Context, completion *nexusrpc.CompletionRequest) error { if completion.HTTPRequest.URL.Path != "/callback" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL path: %q", completion.HTTPRequest.URL.Path) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL path: %q", completion.HTTPRequest.URL.Path) } if completion.HTTPRequest.URL.Query().Get("a") != "b" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'a' query param: %q", completion.HTTPRequest.URL.Query().Get("a")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'a' query param: %q", completion.HTTPRequest.URL.Query().Get("a")) } if completion.HTTPRequest.Header.Get("foo") != "bar" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' header: %q", completion.HTTPRequest.Header.Get("foo")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' header: %q", completion.HTTPRequest.Header.Get("foo")) } if completion.HTTPRequest.Header.Get("User-Agent") != "temporalio/server" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", completion.HTTPRequest.Header.Get("User-Agent")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", completion.HTTPRequest.Header.Get("User-Agent")) } if completion.OperationToken != "test-operation-token" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation token: %q", completion.OperationToken) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation token: %q", completion.OperationToken) } if len(completion.Links) == 0 { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected Links to be set on CompletionRequest") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected Links to be set on CompletionRequest") } if !validateExpectedTime(h.expectedStartTime, completion.StartTime, time.Second) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected StartTime to be equal") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected StartTime to be equal") } if !validateExpectedTime(h.expectedCloseTime, completion.CloseTime, time.Millisecond) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected CloseTime to be equal") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected CloseTime to be equal") } var result int err := completion.Result.Consume(&result) @@ -62,7 +60,7 @@ func (h *successfulCompletionHandler) CompleteOperation(ctx context.Context, com return err } if result != 666 { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result: %q", result) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result: %q", result) } return nil } @@ -77,7 +75,8 @@ func TestSuccessfulCompletion(t *testing.T) { }, nil, nil) defer teardown() - completion, err := nexusrpc.NewOperationCompletionSuccessful(666, nexusrpc.OperationCompletionSuccessfulOptions{ + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: 666, OperationToken: "test-operation-token", StartTime: startTime, CloseTime: closeTime, @@ -90,19 +89,11 @@ func TestSuccessfulCompletion(t *testing.T) { }, Type: "url", }}, - }) - completion.Header.Set("foo", "bar") - require.NoError(t, err) + Header: nexus.Header{"foo": "bar"}, + } - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) + err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}).CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) - require.NoError(t, err) - require.Equal(t, http.StatusOK, response.StatusCode) } func TestSuccessfulCompletion_CustomSerializer(t *testing.T) { @@ -110,8 +101,8 @@ func TestSuccessfulCompletion_CustomSerializer(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &successfulCompletionHandler{}, serializer, nil) defer teardown() - completion, err := nexusrpc.NewOperationCompletionSuccessful(666, nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: serializer, + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: 666, Links: []nexus.Link{{ URL: &url.URL{ Scheme: "https", @@ -121,20 +112,16 @@ func TestSuccessfulCompletion_CustomSerializer(t *testing.T) { }, Type: "url", }}, - }) + Header: nexus.Header{"foo": "bar"}, + } + completion.Header.Set("foo", "bar") completion.Header.Set(nexus.HeaderOperationToken, "test-operation-token") - require.NoError(t, err) - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) - require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) + err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: serializer, + }).CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) - require.Equal(t, http.StatusOK, response.StatusCode) require.Equal(t, 1, serializer.decoded) require.Equal(t, 1, serializer.encoded) @@ -148,25 +135,25 @@ type failureExpectingCompletionHandler struct { func (h *failureExpectingCompletionHandler) CompleteOperation(ctx context.Context, completion *nexusrpc.CompletionRequest) error { if completion.State != nexus.OperationStateCanceled { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected completion state: %q", completion.State) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected completion state: %q", completion.State) } if err := h.errorChecker(completion.Error); err != nil { return err } if completion.HTTPRequest.Header.Get("foo") != "bar" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' header: %q", completion.HTTPRequest.Header.Get("foo")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' header: %q", completion.HTTPRequest.Header.Get("foo")) } if completion.OperationToken != "test-operation-token" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation token: %q", completion.OperationToken) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation token: %q", completion.OperationToken) } if len(completion.Links) == 0 { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected Links to be set on CompletionRequest") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected Links to be set on CompletionRequest") } if !validateExpectedTime(h.expectedStartTime, completion.StartTime, time.Second) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected StartTime to be equal") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected StartTime to be equal") } if !validateExpectedTime(h.expectedCloseTime, completion.CloseTime, time.Millisecond) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected CloseTime to be equal") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected CloseTime to be equal") } return nil @@ -181,14 +168,15 @@ func TestFailureCompletion(t *testing.T) { if opErr, ok := err.(*nexus.OperationError); ok && opErr.Message == "expected message" { return nil } - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure: %v", err) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure: %v", err) }, expectedStartTime: startTime, expectedCloseTime: closeTime, }, nil, nil) defer teardown() - completion, err := nexusrpc.NewOperationCompletionUnsuccessful(nexus.NewOperationCanceledError("expected message"), nexusrpc.OperationCompletionUnsuccessfulOptions{ + completion := &nexusrpc.OperationCompletionUnsuccessful{ + Error: nexus.NewOperationCanceledErrorf("expected message"), OperationToken: "test-operation-token", StartTime: startTime, CloseTime: closeTime, @@ -201,18 +189,10 @@ func TestFailureCompletion(t *testing.T) { }, Type: "url", }}, - }) - require.NoError(t, err) - completion.Header.Set("foo", "bar") - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) - require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) + Header: nexus.Header{"foo": "bar"}, + } + err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}).CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) - require.Equal(t, http.StatusOK, response.StatusCode) } func TestFailureCompletion_CustomFailureConverter(t *testing.T) { @@ -223,7 +203,7 @@ func TestFailureCompletion_CustomFailureConverter(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &failureExpectingCompletionHandler{ errorChecker: func(err error) error { if !errors.Is(err, errCustom) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure, expected a custom error: %v", err) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure, expected a custom error: %v", err) } return nil }, @@ -232,11 +212,11 @@ func TestFailureCompletion_CustomFailureConverter(t *testing.T) { }, nil, fc) defer teardown() - completion, err := nexusrpc.NewOperationCompletionUnsuccessful(nexus.NewOperationCanceledError("expected message"), nexusrpc.OperationCompletionUnsuccessfulOptions{ - FailureConverter: fc, - OperationToken: "test-operation-token", - StartTime: startTime, - CloseTime: closeTime, + completion := &nexusrpc.OperationCompletionUnsuccessful{ + Error: nexus.NewOperationCanceledErrorf("expected message"), + OperationToken: "test-operation-token", + StartTime: startTime, + CloseTime: closeTime, Links: []nexus.Link{{ URL: &url.URL{ Scheme: "https", @@ -246,40 +226,30 @@ func TestFailureCompletion_CustomFailureConverter(t *testing.T) { }, Type: "url", }}, - }) - require.NoError(t, err) - completion.Header.Set("foo", "bar") - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) - require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) + Header: nexus.Header{"foo": "bar"}, + } + err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + FailureConverter: fc, + }).CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) - require.Equal(t, http.StatusOK, response.StatusCode) } type failingCompletionHandler struct { } func (h *failingCompletionHandler) CompleteOperation(ctx context.Context, completion *nexusrpc.CompletionRequest) error { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "I can't get no satisfaction") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "I can't get no satisfaction") } func TestBadRequestCompletion(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &failingCompletionHandler{}, nil, nil) defer teardown() - completion, err := nexusrpc.NewOperationCompletionSuccessful([]byte("success"), nexusrpc.OperationCompletionSuccessfulOptions{}) - require.NoError(t, err) - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) - require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, response.StatusCode) + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: []byte("success"), + } + err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}).CompleteOperation(ctx, callbackURL, completion) + var handlerErr *nexus.HandlerError + require.ErrorAs(t, err, &handlerErr) + require.Equal(t, nexus.HandlerErrorTypeBadRequest, handlerErr.Type) } diff --git a/common/nexus/nexusrpc/failure_converter.go b/common/nexus/nexusrpc/failure_converter.go index 04bbe9baaf..befdaec8d1 100644 --- a/common/nexus/nexusrpc/failure_converter.go +++ b/common/nexus/nexusrpc/failure_converter.go @@ -35,9 +35,8 @@ func (e serializedHandlerError) RetryBehavior() nexus.HandlerErrorRetryBehavior } if *e.RetryableOverride { return nexus.HandlerErrorRetryBehaviorRetryable - } else { - return nexus.HandlerErrorRetryBehaviorNonRetryable } + return nexus.HandlerErrorRetryBehaviorNonRetryable } type serializedOperationError struct { @@ -45,6 +44,7 @@ type serializedOperationError struct { } // ErrorToFailure implements FailureConverter. +// nolint:revive // Keeping all of the logic together for readability, even if it means the function is long. func (e knownErrorFailureConverter) ErrorToFailure(err error) (nexus.Failure, error) { if err == nil { return nexus.Failure{}, nil @@ -66,10 +66,6 @@ func (e knownErrorFailureConverter) ErrorToFailure(err error) (nexus.Failure, er if typedErr.OriginalFailure != nil { return *typedErr.OriginalFailure, nil } - // Temporary workaround for compatibility with old SDKs that don't support handler error messages. - if typedErr.Message == "" && typedErr.Cause != nil { - return e.ErrorToFailure(typedErr.Cause) - } data := serializedHandlerError{ Type: string(typedErr.Type), RetryableOverride: retryBehaviorAsOptionalBool(typedErr), @@ -100,10 +96,6 @@ func (e knownErrorFailureConverter) ErrorToFailure(err error) (nexus.Failure, er if typedErr.OriginalFailure != nil { return *typedErr.OriginalFailure, nil } - // Temporary workaround for compatibility with old SDKs that don't support operation error messages. - if typedErr.Message == "" && typedErr.Cause != nil { - return e.ErrorToFailure(typedErr.Cause) - } data := serializedOperationError{ State: string(typedErr.State), } @@ -136,6 +128,7 @@ func (e knownErrorFailureConverter) ErrorToFailure(err error) (nexus.Failure, er } // FailureToError implements FailureConverter. +// nolint:revive // Keeping all of the logic together for readability, even if it means the function is long. func (e knownErrorFailureConverter) FailureToError(f nexus.Failure) (error, error) { if f.Metadata != nil { switch f.Metadata["type"] { @@ -204,7 +197,8 @@ func DefaultFailureConverter() FailureConverter { return defaultFailureConverter } -func retryBehaviorAsOptionalBool(e *nexus.HandlerError) *bool { +func retryBehaviorAsOptionalBool(e *nexus.HandlerError) *bool { + // nolint:exhaustive // this is a simple optional boolean. switch e.RetryBehavior { case nexus.HandlerErrorRetryBehaviorRetryable: ret := true @@ -215,4 +209,3 @@ func retryBehaviorAsOptionalBool(e *nexus.HandlerError) *bool { } return nil } - diff --git a/common/nexus/nexusrpc/failure_conveter_test.go b/common/nexus/nexusrpc/failure_conveter_test.go index 516eca86f7..2cb94b9757 100644 --- a/common/nexus/nexusrpc/failure_conveter_test.go +++ b/common/nexus/nexusrpc/failure_conveter_test.go @@ -47,7 +47,7 @@ func TestFailureConverter_FailureError(t *testing.T) { func TestFailureConverter_HandlerError(t *testing.T) { cause := &nexus.FailureError{Failure: nexus.Failure{Message: "cause"}} - he := nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "foo") + he := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "foo") he.StackTrace = "stack" he.Cause = cause failure, err := defaultFailureConverter.ErrorToFailure(he) @@ -70,7 +70,7 @@ func TestFailureConverter_HandlerError(t *testing.T) { } func TestFailureConverter_HandlerErrorRetryBehavior(t *testing.T) { - he := nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "foo") + he := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "foo") he.StackTrace = "stack" he.RetryBehavior = nexus.HandlerErrorRetryBehaviorRetryable failure, err := defaultFailureConverter.ErrorToFailure(he) @@ -79,15 +79,12 @@ func TestFailureConverter_HandlerErrorRetryBehavior(t *testing.T) { he.OriginalFailure = &failure actual, err := defaultFailureConverter.FailureToError(failure) require.NoError(t, err) - - // Failure is rehydrated as failure error if it has no known type. - he.Cause = &nexus.FailureError{Failure: nexus.Failure{Message: "foo"}} require.Equal(t, he, actual) } func TestFailureConverter_OperationError(t *testing.T) { cause := &nexus.FailureError{Failure: nexus.Failure{Message: "cause"}} - oe := nexus.NewOperationCanceledError("foo") + oe := nexus.NewOperationCanceledErrorf("foo") oe.StackTrace = "stack" oe.Cause = cause failure, err := defaultFailureConverter.ErrorToFailure(oe) @@ -109,7 +106,6 @@ func TestFailureConverter_OperationError(t *testing.T) { require.Equal(t, oe, actual) } - func TestDefaultFailureConverterArbitraryError(t *testing.T) { sourceErr := errors.New("test") conv := defaultFailureConverter diff --git a/common/nexus/nexusrpc/handle.go b/common/nexus/nexusrpc/handle.go index a5f89b8154..ca8f75a9a7 100644 --- a/common/nexus/nexusrpc/handle.go +++ b/common/nexus/nexusrpc/handle.go @@ -22,7 +22,7 @@ type OperationHandle[T any] struct { // // Cancelation is asynchronous and may be not be respected by the operation's implementation. func (h *OperationHandle[T]) Cancel(ctx context.Context, options nexus.CancelOperationOptions) error { - u := h.client.serviceBaseURL.JoinPath(url.PathEscape(h.client.options.Service), url.PathEscape(h.Operation), "cancel") + u := h.client.serviceBaseURL.JoinPath(url.PathEscape(h.client.service), url.PathEscape(h.Operation), "cancel") request, err := http.NewRequestWithContext(ctx, "POST", u.String(), nil) if err != nil { return err @@ -31,7 +31,7 @@ func (h *OperationHandle[T]) Cancel(ctx context.Context, options nexus.CancelOpe addContextTimeoutToHTTPHeader(ctx, request.Header) request.Header.Set(headerUserAgent, userAgent) addNexusHeaderToHTTPHeader(options.Header, request.Header) - response, err := h.client.options.HTTPCaller(request) + response, err := h.client.httpCaller(request) if err != nil { return err } diff --git a/common/nexus/nexusrpc/server.go b/common/nexus/nexusrpc/server.go index db3247408a..5b0a55a9e0 100644 --- a/common/nexus/nexusrpc/server.go +++ b/common/nexus/nexusrpc/server.go @@ -18,10 +18,10 @@ import ( "github.com/nexus-rpc/sdk-go/nexus" ) -func applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer http.ResponseWriter, handler *httpHandler) { +func applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer http.ResponseWriter, request *http.Request, handler *httpHandler) { switch r := r.(type) { case interface{ ValueAsAny() any }: - handler.writeResult(writer, r.ValueAsAny()) + handler.writeResult(writer, request, r.ValueAsAny()) case *nexus.HandlerStartOperationResultAsync: info := nexus.OperationInfo{ Token: r.OperationToken, @@ -29,7 +29,7 @@ func applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer } b, err := json.Marshal(info) if err != nil { - handler.logger.Error("failed to serialize operation info", "error", err) + handler.Logger.Error("failed to serialize operation info", "error", err) writer.WriteHeader(http.StatusInternalServerError) return } @@ -37,22 +37,22 @@ func applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer writer.Header().Set("Content-Type", contentTypeJSON) writer.WriteHeader(http.StatusCreated) if _, err := writer.Write(b); err != nil { - handler.logger.Error("failed to write response body", "error", err) + handler.Logger.Error("failed to write response body", "error", err) } } } -type baseHTTPHandler struct { - logger *slog.Logger - failureConverter FailureConverter +type BaseHTTPHandler struct { + Logger *slog.Logger + FailureConverter FailureConverter } type httpHandler struct { - baseHTTPHandler + BaseHTTPHandler options HandlerOptions } -func (h *httpHandler) writeResult(writer http.ResponseWriter, result any) { +func (h *httpHandler) writeResult(writer http.ResponseWriter, request *http.Request, result any) { var reader *nexus.Reader if r, ok := result.(*nexus.Reader); ok { // Close the request body in case we error before sending the HTTP request (which may double close but @@ -66,7 +66,7 @@ func (h *httpHandler) writeResult(writer http.ResponseWriter, result any) { var err error content, err = h.options.Serializer.Serialize(result) if err != nil { - h.writeFailure(writer, fmt.Errorf("failed to serialize handler result: %w", err)) + h.WriteFailure(writer, request, fmt.Errorf("failed to serialize handler result: %w", err)) return } } @@ -85,11 +85,14 @@ func (h *httpHandler) writeResult(writer http.ResponseWriter, result any) { return } if _, err := io.Copy(writer, reader); err != nil { - h.logger.Error("failed to write response body", "error", err) + h.Logger.Error("failed to write response body", "error", err) } } -func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { +// WriteFailure writes the given error to the response, converting it to a Failure if possible. It also sets the HTTP +// status code based on the type of error. +// nolint:revive // Keeping all of the logic together for readability, even if it means the function is long. +func (h *BaseHTTPHandler) WriteFailure(writer http.ResponseWriter, r *http.Request, err error) { var failure nexus.Failure var failureError *nexus.FailureError var opError *nexus.OperationError @@ -100,28 +103,37 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { if errors.As(err, &opError) { operationState = opError.State var convErr error - failure, convErr = h.failureConverter.ErrorToFailure(opError) + failure, convErr = h.FailureConverter.ErrorToFailure(opError) if convErr != nil { - h.logger.Error("failed to convert operation error to failure", "error", convErr) + h.Logger.Error("failed to convert operation error to failure", "error", convErr) writer.WriteHeader(http.StatusInternalServerError) return } - statusCode = statusOperationFailed + // Backward compatibility, unwrap the failure cause. + if r.Header.Get("temporal-nexus-failure-support") != "true" && failure.Cause != nil { + failure = *failure.Cause + } + + statusCode = statusOperationUnsuccessful if operationState != nexus.OperationStateFailed && operationState != nexus.OperationStateCanceled { - h.logger.Error("unexpected operation state", "state", operationState) + h.Logger.Error("unexpected operation state", "state", operationState) writer.WriteHeader(http.StatusInternalServerError) return } writer.Header().Set(headerOperationState, string(operationState)) } else if errors.As(err, &handlerError) { var convErr error - failure, convErr = h.failureConverter.ErrorToFailure(handlerError) + failure, convErr = h.FailureConverter.ErrorToFailure(handlerError) if convErr != nil { - h.logger.Error("failed to convert handler error to failure", "error", convErr) + h.Logger.Error("failed to convert handler error to failure", "error", convErr) writer.WriteHeader(http.StatusInternalServerError) return } + // Backward compatibility, unwrap the failure cause. + if r.Header.Get("temporal-nexus-failure-support") != "true" && failure.Cause != nil { + failure = *failure.Cause + } switch handlerError.Type { case nexus.HandlerErrorTypeBadRequest: statusCode = http.StatusBadRequest @@ -146,7 +158,7 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { case nexus.HandlerErrorTypeUpstreamTimeout: statusCode = nexus.StatusUpstreamTimeout default: - h.logger.Error("unexpected handler error type", "type", handlerError.Type) + h.Logger.Error("unexpected handler error type", "type", handlerError.Type) } } else if errors.As(err, &failureError) { failure = failureError.Failure @@ -154,12 +166,12 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { failure = nexus.Failure{ Message: "internal server error", } - h.logger.Error("handler failed", "error", err) + h.Logger.Error("handler failed", "error", err) } b, err := json.Marshal(failure) if err != nil { - h.logger.Error("failed to marshal failure", "error", err) + h.Logger.Error("failed to marshal failure", "error", err) writer.WriteHeader(http.StatusInternalServerError) return } @@ -181,14 +193,14 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { writer.WriteHeader(statusCode) if _, err := writer.Write(b); err != nil { - h.logger.Error("failed to write response body", "error", err) + h.Logger.Error("failed to write response body", "error", err) } } func (h *httpHandler) startOperation(service, operation string, writer http.ResponseWriter, request *http.Request) { links, err := getLinksFromHeader(request.Header) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid %q header", headerLink)) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid %q header", headerLink)) return } options := nexus.StartOperationOptions{ @@ -219,16 +231,16 @@ func (h *httpHandler) startOperation(service, operation string, writer http.Resp }) response, err := h.options.Handler.StartOperation(ctx, service, operation, value, options) if err != nil { - h.writeFailure(writer, err) + h.WriteFailure(writer, request, err) } else { if err := addLinksToHTTPHeader(nexus.HandlerLinks(ctx), writer.Header()); err != nil { - h.logger.Error("failed to serialize links into header", "error", err) + h.Logger.Error("failed to serialize links into header", "error", err) // clear any previous links already written to the header writer.Header().Del(headerLink) writer.WriteHeader(http.StatusInternalServerError) return } - applyResultToHTTPResponse(response, writer, h) + applyResultToHTTPResponse(response, writer, request, h) } } @@ -247,7 +259,7 @@ func (h *httpHandler) cancelOperation(service, operation, token string, writer h Header: options.Header, }) if err := h.options.Handler.CancelOperation(ctx, service, operation, token, options); err != nil { - h.writeFailure(writer, err) + h.WriteFailure(writer, request, err) return } @@ -262,8 +274,8 @@ func (h *httpHandler) parseRequestTimeoutHeader(writer http.ResponseWriter, requ if timeoutStr != "" { timeoutDuration, err := ParseDuration(timeoutStr) if err != nil { - h.logger.Warn("invalid request timeout header", "timeout", timeoutStr) - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request timeout header")) + h.Logger.Warn("invalid request timeout header", "timeout", timeoutStr) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request timeout header")) return 0, false } return timeoutDuration, true @@ -307,23 +319,23 @@ type HandlerOptions struct { func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Request) { if request.Method != "POST" { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request method: expected POST, got %q", request.Method)) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request method: expected POST, got %q", request.Method)) return } parts := strings.Split(request.URL.EscapedPath(), "/") // First part is empty (due to leading /) if len(parts) < 3 { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "not found")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "not found")) return } service, err := url.PathUnescape(parts[1]) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) return } operation, err := url.PathUnescape(parts[2]) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) return } @@ -338,7 +350,7 @@ func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Re if len(parts) == 5 && parts[4] == "cancel" { token, err := url.PathUnescape(parts[3]) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) return } @@ -347,7 +359,7 @@ func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Re } if len(parts) != 4 || parts[3] != "cancel" { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "not found")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "not found")) return } @@ -355,7 +367,7 @@ func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Re if token == "" { token = request.URL.Query().Get("token") if token == "" { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "missing operation token")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "missing operation token")) return } } else { @@ -381,9 +393,9 @@ func NewHTTPHandler(options HandlerOptions) http.Handler { options.FailureConverter = DefaultFailureConverter() } handler := &httpHandler{ - baseHTTPHandler: baseHTTPHandler{ - logger: options.Logger, - failureConverter: options.FailureConverter, + BaseHTTPHandler: BaseHTTPHandler{ + Logger: options.Logger, + FailureConverter: options.FailureConverter, }, options: options, } diff --git a/common/nexus/nexusrpc/start_test.go b/common/nexus/nexusrpc/start_test.go index 20c90a5798..a06619c0fd 100644 --- a/common/nexus/nexusrpc/start_test.go +++ b/common/nexus/nexusrpc/start_test.go @@ -25,29 +25,29 @@ func (h *successHandler) StartOperation(ctx context.Context, service, operation return nil, err } if service != testService { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected service: %s", service) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected service: %s", service) } if operation != "i need to/be escaped" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected operation: %s", operation) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected operation: %s", operation) } if options.CallbackURL != "http://test/callback" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected callback URL: %s", options.CallbackURL) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected callback URL: %s", options.CallbackURL) } if options.CallbackHeader.Get("callback-test") != "ok" { - return nil, nexus.HandlerErrorf( + return nil, nexus.NewHandlerErrorf( nexus.HandlerErrorTypeBadRequest, "invalid 'callback-test' callback header: %q", options.CallbackHeader.Get("callback-test"), ) } if options.Header.Get("test") != "ok" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'test' header: %q", options.Header.Get("test")) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'test' header: %q", options.Header.Get("test")) } if options.Header.Get("nexus-callback-callback-test") != "" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "callback header not omitted from options Header") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "callback header not omitted from options Header") } if options.Header.Get("User-Agent") != "temporalio/server" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) } return &nexus.HandlerStartOperationResultSync[any]{Value: body}, nil @@ -191,7 +191,7 @@ func (h *echoHandler) StartOperation(ctx context.Context, service, operation str Data: data, } default: - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unknown input-type header") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unknown input-type header") } nexus.AddHandlerLinks(ctx, options.Links...) return &nexus.HandlerStartOperationResultSync[any]{Value: output}, nil @@ -377,7 +377,7 @@ func TestStart_NilContentHeaderDoesNotPanic(t *testing.T) { var numberValidatorOperation = nexus.NewSyncOperation("number-validator", func(ctx context.Context, input int, _ nexus.StartOperationOptions) (int, error) { if input != 3 { - return input, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid number: %d", input) + return input, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid number: %d", input) } return input, nil }) diff --git a/common/nexus/payload_serializer.go b/common/nexus/payload_serializer.go index 06e9f63e13..ff4a86438b 100644 --- a/common/nexus/payload_serializer.go +++ b/common/nexus/payload_serializer.go @@ -98,6 +98,10 @@ func setUnknownNexusContent(nexusHeader nexus.Header, payloadMetadata map[string // Serialize implements nexus.Serializer. func (payloadSerializer) Serialize(v any) (*nexus.Content, error) { + if v == nil { + // Use same structure as the nil serializer from the Nexus Go SDK. + return &nexus.Content{Header: nexus.Header{}}, nil + } payload, ok := v.(*commonpb.Payload) if !ok { return nil, fmt.Errorf("%w: cannot serialize %v", errSerializer, v) diff --git a/components/callbacks/chasm_invocation.go b/components/callbacks/chasm_invocation.go index 17e471a899..a88c258ab6 100644 --- a/components/callbacks/chasm_invocation.go +++ b/components/callbacks/chasm_invocation.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "fmt" - "io" "github.com/google/uuid" "github.com/nexus-rpc/sdk-go/nexus" @@ -89,20 +88,12 @@ func (c chasmInvocation) getHistoryRequest( switch op := c.completion.(type) { case *nexusrpc.OperationCompletionSuccessful: - payloadBody, err := io.ReadAll(op.Reader) - if err != nil { - return nil, fmt.Errorf("failed to read payload: %v", err) - } - var payload *commonpb.Payload - if payloadBody != nil { - content := &nexus.Content{ - Header: op.Reader.Header, - Data: payloadBody, - } - err := commonnexus.PayloadSerializer.Deserialize(content, &payload) - if err != nil { - return nil, fmt.Errorf("failed to deserialize payload: %v", err) + if op.Result != nil { + var ok bool + payload, ok = op.Result.(*commonpb.Payload) + if !ok { + return nil, fmt.Errorf("invalid result, expected a payload, got: %T", op.Result) } } @@ -114,9 +105,17 @@ func (c chasmInvocation) getHistoryRequest( Completion: completion, } case *nexusrpc.OperationCompletionUnsuccessful: - apiFailure, err := commonnexus.NexusFailureToTemporalFailure(op.Failure) + failure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(op.Error) + if err != nil { + return nil, fmt.Errorf("failed to convert error to failure: %w", err) + } + // Unwrap the operation error since it's not meant to be sent for Temporal->Temporal completions. + if failure.Cause != nil { + failure = *failure.Cause + } + apiFailure, err := commonnexus.NexusFailureToTemporalFailure(failure) if err != nil { - return nil, fmt.Errorf("failed to convert failure type: %v", err) + return nil, fmt.Errorf("failed to convert failure type: %w", err) } req = &historyservice.CompleteNexusOperationChasmRequest{ diff --git a/components/callbacks/executors_test.go b/components/callbacks/executors_test.go index 75c8c3010d..18d562dd54 100644 --- a/components/callbacks/executors_test.go +++ b/components/callbacks/executors_test.go @@ -80,7 +80,7 @@ func TestProcessInvocationTaskNexus_Outcomes(t *testing.T) { return &http.Response{StatusCode: 200, Body: http.NoBody}, nil }, retryable: false, - expectedMetricOutcome: "status:200", + expectedMetricOutcome: "success", assertOutcome: func(t *testing.T, cb callbacks.Callback) { require.Equal(t, enumsspb.CALLBACK_STATE_SUCCEEDED, cb.State()) }, @@ -102,7 +102,7 @@ func TestProcessInvocationTaskNexus_Outcomes(t *testing.T) { return &http.Response{StatusCode: 500, Body: http.NoBody}, nil }, retryable: true, - expectedMetricOutcome: "status:500", + expectedMetricOutcome: "handler-error:INTERNAL", assertOutcome: func(t *testing.T, cb callbacks.Callback) { require.Equal(t, enumsspb.CALLBACK_STATE_BACKING_OFF, cb.State()) }, @@ -113,7 +113,7 @@ func TestProcessInvocationTaskNexus_Outcomes(t *testing.T) { return &http.Response{StatusCode: 400, Body: http.NoBody}, nil }, retryable: false, - expectedMetricOutcome: "status:400", + expectedMetricOutcome: "handler-error:BAD_REQUEST", assertOutcome: func(t *testing.T, cb callbacks.Callback) { require.Equal(t, enumsspb.CALLBACK_STATE_FAILED, cb.State()) }, @@ -267,8 +267,7 @@ func TestProcessBackoffTask(t *testing.T) { } func newMutableState(t *testing.T) mutableState { - completionNexus, err := nexusrpc.NewOperationCompletionSuccessful(nil, nexusrpc.OperationCompletionSuccessfulOptions{}) - require.NoError(t, err) + completionNexus := &nexusrpc.OperationCompletionSuccessful{} hsmCallbackArg := &persistencespb.HSMCompletionCallbackArg{ NamespaceId: "mynsid", WorkflowId: "mywid", @@ -349,15 +348,10 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { return client }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayload([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - CloseTime: dummyTime, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayload([]byte("result-data")), + CloseTime: dummyTime, + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb callbacks.Callback) { @@ -381,17 +375,13 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { return client }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ + return &nexusrpc.OperationCompletionUnsuccessful{ + Error: &nexus.OperationError{ State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "operation failed"}}, }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - CloseTime: dummyTime, - }, - ) - require.NoError(t, err) - return comp + CloseTime: dummyTime, + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb callbacks.Callback) { @@ -409,14 +399,9 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { return client }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayload([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayload([]byte("result-data")), + } }(), headerValue: encodedRef, expectsInternalError: true, @@ -435,14 +420,9 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { return client }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayload([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayload([]byte("result-data")), + } }(), headerValue: encodedRef, expectsInternalError: true, @@ -457,14 +437,9 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { return historyservicemock.NewMockHistoryServiceClient(ctrl) }, completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayload([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + return &nexusrpc.OperationCompletionSuccessful{ + Result: createPayload([]byte("result-data")), + } }(), headerValue: "invalid-base64!!!", expectsInternalError: true, diff --git a/components/callbacks/nexus_invocation.go b/components/callbacks/nexus_invocation.go index 0d5c99da35..68866dfc35 100644 --- a/components/callbacks/nexus_invocation.go +++ b/components/callbacks/nexus_invocation.go @@ -1,15 +1,10 @@ package callbacks import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "mime" + "errors" "net/http" "net/http/httptrace" - "slices" "time" "github.com/nexus-rpc/sdk-go/nexus" @@ -40,20 +35,6 @@ type nexusInvocation struct { attempt int32 } -func isRetryableHTTPResponse(response *http.Response) bool { - return response.StatusCode >= 500 || slices.Contains(retryable4xxErrorTypes, response.StatusCode) -} - -func outcomeTag(callCtx context.Context, response *http.Response, callErr error) string { - if callErr != nil { - if callCtx.Err() != nil { - return "request-timeout" - } - return "unknown-error" - } - return fmt.Sprintf("status:%d", response.StatusCode) -} - func (n nexusInvocation) WrapError(result invocationResult, err error) error { if failure, ok := result.(invocationResultRetry); ok { return queueserrors.NewDestinationDownError(failure.err.Error(), err) @@ -77,107 +58,56 @@ func (n nexusInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e } } - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, n.nexus.Url, n.completion) - if err != nil { - return invocationResultFail{queueserrors.NewUnprocessableTaskError( - fmt.Sprintf("failed to construct Nexus request: %v", err), - )} - } - if request.Header == nil { - request.Header = make(http.Header) - } - for k, v := range n.nexus.Header { - request.Header.Set(k, v) - } - - caller := e.HTTPCallerProvider(queuescommon.NamespaceIDAndDestination{ - NamespaceID: ns.ID().String(), - Destination: task.Destination(), + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + HTTPCaller: e.HTTPCallerProvider(queuescommon.NamespaceIDAndDestination{ + NamespaceID: ns.ID().String(), + Destination: task.Destination(), + }), + Serializer: commonnexus.PayloadSerializer, }) // Make the call and record metrics. startTime := time.Now() - response, err := caller(request) + + for k, v := range n.nexus.Header { + n.completion.SetHeader(k, v) + } + err := client.CompleteOperation(ctx, n.nexus.Url, n.completion) namespaceTag := metrics.NamespaceTag(ns.Name().String()) destTag := metrics.DestinationTag(task.Destination()) - statusCodeTag := metrics.OutcomeTag(outcomeTag(ctx, response, err)) + statusCodeTag := metrics.OutcomeTag(outcomeTag(ctx, err)) e.MetricsHandler.Counter(RequestCounter.Name()).Record(1, namespaceTag, destTag, statusCodeTag) e.MetricsHandler.Timer(RequestLatencyHistogram.Name()).Record(time.Since(startTime), namespaceTag, destTag, statusCodeTag) - if err != nil { - e.Logger.Error("Callback request failed with error", tag.Error(err)) - return invocationResultRetry{err} - } - - if response.StatusCode >= 200 && response.StatusCode < 300 { - // Body is not read but should be discarded to keep the underlying TCP connection alive. - // Just in case something unexpected happens while discarding or closing the body, - // propagate errors to the machine. - if _, err = io.Copy(io.Discard, response.Body); err == nil { - if err = response.Body.Close(); err != nil { - e.Logger.Error("Callback request failed with error", tag.Error(err)) - return invocationResultRetry{err} - } - } + if err == nil { return invocationResultOK{} } - - retryable := isRetryableHTTPResponse(response) - err = readHandlerErrFromResponse(response, e.Logger) - e.Logger.Error("Callback request failed", tag.Error(err), tag.String("status", response.Status), tag.Bool("retryable", retryable)) + retryable := isRetryableCallError(err) + e.Logger.Error("Callback request failed", tag.Error(err), tag.Bool("retryable", retryable)) if retryable { return invocationResultRetry{err} } return invocationResultFail{err} } -// Reads and replaces the http response body and attempts to deserialize it into a Nexus failure. If successful, -// returns a nexus.HandlerError with the deserialized failure as the Cause. If there is an error reading the body or -// during deserialization, returns a nexus.HandlerError with a generic Cause based on response status. -// TODO: This logic is duplicated in the frontend handler for forwarded requests. Eventually it should live in the Nexus SDK. -func readHandlerErrFromResponse(response *http.Response, logger log.Logger) error { - handlerErr := &nexus.HandlerError{ - Type: commonnexus.HandlerErrorTypeFromHTTPStatus(response.StatusCode), - Cause: fmt.Errorf("request failed with: %v", response.Status), - } - - body, err := readAndReplaceBody(response) - if err != nil { - logger.Error("Error reading response body for non-ok callback request", tag.Error(err), tag.String("status", response.Status)) - return err - } - - if !isMediaTypeJSON(response.Header.Get("Content-Type")) { - logger.Error("received invalid content-type header for non-OK HTTP response to CompleteOperation request", tag.Value(response.Header.Get("Content-Type"))) - return handlerErr - } - - var failure nexus.Failure - err = json.Unmarshal(body, &failure) - if err != nil { - logger.Error("failed to deserialize Nexus Failure from HTTP response to CompleteOperation request", tag.Error(err)) - return handlerErr +func outcomeTag(callCtx context.Context, callErr error) string { + if callErr != nil { + if callCtx.Err() != nil { + return "request-timeout" + } + var handlerErr *nexus.HandlerError + if errors.As(callErr, &handlerErr) { + return "handler-error:" + string(handlerErr.Type) + } + return "unknown-error" } - - handlerErr.Cause = &nexus.FailureError{Failure: failure} - return handlerErr -} - -// readAndReplaceBody reads the response body in its entirety and closes it, and then replaces the original response -// body with an in-memory buffer. -// The body is replaced even when there was an error reading the entire body. -func readAndReplaceBody(response *http.Response) ([]byte, error) { - responseBody := response.Body - body, err := io.ReadAll(responseBody) - _ = responseBody.Close() - response.Body = io.NopCloser(bytes.NewReader(body)) - return body, err + return "success" } -func isMediaTypeJSON(contentType string) bool { - if contentType == "" { - return false +func isRetryableCallError(err error) bool { + var handlerError *nexus.HandlerError + if errors.As(err, &handlerError) { + return handlerError.Retryable() } - mediaType, _, err := mime.ParseMediaType(contentType) - return err == nil && mediaType == "application/json" + return true } diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 2e2e5b4e43..2a97803817 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -177,7 +177,7 @@ func (e taskExecutor) executeInvocationTask(ctx context.Context, env hsm.Environ // This happens when we accept the ScheduleNexusOperation command when the endpoint is not found in the registry as // indicated by the EndpointNotFoundAlwaysNonRetryable dynamic config. if args.endpointID == "" { - handlerError := nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") + handlerError := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") return e.saveResult(ctx, env, ref, nil, handlerError) } @@ -185,7 +185,7 @@ func (e taskExecutor) executeInvocationTask(ctx context.Context, env hsm.Environ if err != nil { if errors.As(err, new(*serviceerror.NotFound)) { // The endpoint is not registered, immediately fail the invocation. - handlerError := nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") + handlerError := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") return e.saveResult(ctx, env, ref, nil, handlerError) } return err @@ -629,7 +629,7 @@ func (e taskExecutor) executeCancelationTask(ctx context.Context, env hsm.Enviro endpoint, err := e.lookupEndpoint(ctx, namespace.ID(ref.WorkflowKey.NamespaceID), args.endpointID, args.endpointName) if err != nil { if errors.As(err, new(*serviceerror.NotFound)) { - handlerError := nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") + handlerError := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") // The endpoint is not registered, immediately fail the invocation. return e.saveCancelationResult(ctx, env, ref, handlerError, args.scheduledEventID) @@ -929,10 +929,6 @@ func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) nf = *handlerErr.OriginalFailure } else { var err error - // Ensure the error message is set to prevent the Nexus failure converter from unwrapping the cause. - if handlerErr.Message == "" { - handlerErr.Message = "handler error" - } nf, err = nexusrpc.DefaultFailureConverter().ErrorToFailure(handlerErr) if err != nil { return nil, err diff --git a/components/nexusoperations/executors_test.go b/components/nexusoperations/executors_test.go index 0e5099d108..b6fef19b21 100644 --- a/components/nexusoperations/executors_test.go +++ b/components/nexusoperations/executors_test.go @@ -3,7 +3,6 @@ package nexusoperations_test import ( "context" "encoding/json" - "errors" "testing" "time" @@ -155,23 +154,23 @@ func TestProcessInvocationTask(t *testing.T) { onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { // Also use this test case to check the input and options provided. if service != "service" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation name") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation name") } if operation != "operation" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation name") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation name") } if options.CallbackHeader.Get("temporal-callback-token") == "" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "empty callback token") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "empty callback token") } if options.CallbackURL != "http://localhost/callback" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback URL") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback URL") } if options.Header.Get(nexus.HeaderOperationTimeout) != "1ms" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) } var v string if err := input.Consume(&v); err != nil || v != "input" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid input") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid input") } return &nexus.HandlerStartOperationResultSync[any]{Value: "result"}, nil }, @@ -310,7 +309,7 @@ func TestProcessInvocationTask(t *testing.T) { requestTimeout: time.Hour, destinationDown: true, onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") }, expectedMetricOutcome: "handler-error:INTERNAL", checkOutcome: func(t *testing.T, op nexusoperations.Operation, events []*historypb.HistoryEvent) { @@ -346,7 +345,7 @@ func TestProcessInvocationTask(t *testing.T) { onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { opTimeout, err := time.ParseDuration(options.Header.Get(nexus.HeaderOperationTimeout)) if err != nil || opTimeout > 10*time.Millisecond { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) } time.Sleep(time.Millisecond * 100) //nolint:forbidigo // Allow time.Sleep for timeout tests return &nexus.HandlerStartOperationResultAsync{OperationToken: "op-token"}, nil @@ -366,7 +365,7 @@ func TestProcessInvocationTask(t *testing.T) { expectedMetricOutcome: "request-timeout", onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { if options.Header.Get(nexus.HeaderOperationTimeout) != "" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation timeout header should not be set, got: %s", options.Header.Get(nexus.HeaderOperationTimeout)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation timeout header should not be set, got: %s", options.Header.Get(nexus.HeaderOperationTimeout)) } time.Sleep(time.Millisecond * 100) //nolint:forbidigo // Allow time.Sleep for timeout tests return &nexus.HandlerStartOperationResultAsync{OperationToken: "op-token"}, nil @@ -386,7 +385,7 @@ func TestProcessInvocationTask(t *testing.T) { expectedMetricOutcome: "pending", onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { if options.Header.Get(nexus.HeaderOperationTimeout) != "60000ms" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) } return &nexus.HandlerStartOperationResultAsync{OperationToken: "op-token"}, nil }, @@ -818,7 +817,7 @@ func TestProcessCancelationTask(t *testing.T) { // Check non retryable internal error. return &nexus.HandlerError{ Type: nexus.HandlerErrorTypeInternal, - Cause: errors.New("operation not found"), + Message: "operation not found", RetryBehavior: nexus.HandlerErrorRetryBehaviorNonRetryable, } }, @@ -826,9 +825,7 @@ func TestProcessCancelationTask(t *testing.T) { checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_FAILED, c.State()) require.Equal(t, string(nexus.HandlerErrorTypeInternal), c.LastAttemptFailure.GetNexusHandlerFailureInfo().GetType()) - require.Equal(t, "Internal Server Error", c.LastAttemptFailure.Message) - require.NotNil(t, c.LastAttemptFailure.Cause) - require.Equal(t, "operation not found", c.LastAttemptFailure.Cause.Message) + require.Equal(t, "operation not found", c.LastAttemptFailure.Message) }, }, { @@ -851,7 +848,7 @@ func TestProcessCancelationTask(t *testing.T) { header: map[string]string{"key": "value"}, onCancelOperation: func(ctx context.Context, service, operation, token string, options nexus.CancelOperationOptions) error { if options.Header["key"] != "value" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, `"key" header is not equal to "value"`) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, `"key" header is not equal to "value"`) } return nil }, @@ -866,7 +863,7 @@ func TestProcessCancelationTask(t *testing.T) { requestTimeout: time.Hour, destinationDown: true, onCancelOperation: func(ctx context.Context, service, operation, token string, options nexus.CancelOperationOptions) error { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") }, expectedMetricOutcome: "handler-error:INTERNAL", checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { diff --git a/components/nexusoperations/frontend/handler.go b/components/nexusoperations/frontend/handler.go index fa9fdb0326..b4f6f495fa 100644 --- a/components/nexusoperations/frontend/handler.go +++ b/components/nexusoperations/frontend/handler.go @@ -1,13 +1,9 @@ package frontend import ( - "bytes" "context" - "encoding/json" "errors" "fmt" - "io" - "mime" "net/http" "net/http/httptrace" "net/url" @@ -21,6 +17,7 @@ import ( commonpb "go.temporal.io/api/common/v1" "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/historyservice/v1" + "go.temporal.io/server/common" "go.temporal.io/server/common/authorization" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/dynamicconfig" @@ -91,18 +88,18 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C startTime := time.Now() if !h.Config.Enabled() { h.preProcessErrorsCounter.Record(1) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "Nexus APIs are disabled") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "Nexus APIs are disabled") } token, err := commonnexus.DecodeCallbackToken(r.HTTPRequest.Header.Get(commonnexus.CallbackTokenHeader)) if err != nil { h.Logger.Error("failed to decode callback token", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") } completion, err := h.CallbackTokenGenerator.DecodeCompletion(token) if err != nil { h.Logger.Error("failed to decode completion from token", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") } ns, err := h.NamespaceRegistry.GetNamespaceByID(namespace.ID(completion.NamespaceId)) if err != nil { @@ -110,7 +107,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C h.preProcessErrorsCounter.Record(1) var nfe *serviceerror.NamespaceNotFound if errors.As(err, &nfe) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", completion.NamespaceId) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", completion.NamespaceId) } return commonnexus.ConvertGRPCError(err, false) } @@ -143,7 +140,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C if err != nil { h.Logger.Error("failed to extract namespace from request", tag.Error(err)) h.preProcessErrorsCounter.Record(1) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL") } if nsName != ns.Name().String() { logger.Error( @@ -152,7 +149,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C tag.Error(err), tag.String("completion-namespace-id", completion.GetNamespaceId()), ) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") } } @@ -165,7 +162,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C } tokenLimit := h.Config.MaxOperationTokenLength(ns.Name().String()) if len(r.OperationToken) > tokenLimit { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation token length exceeds allowed limit (%d/%d)", len(r.OperationToken), tokenLimit) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation token length exceeds allowed limit (%d/%d)", len(r.OperationToken), tokenLimit) } var links []*commonpb.Link @@ -208,11 +205,11 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C var result *commonpb.Payload if err := r.Result.Consume(&result); err != nil { logger.Error("cannot deserialize payload from completion result", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result content") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result content") } if result.Size() > h.Config.PayloadSizeLimit(ns.Name().String()) { logger.Error("payload size exceeds error limit for Nexus CompleteOperation request", tag.WorkflowNamespace(ns.Name().String())) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "result exceeds size limit") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "result exceeds size limit") } hr.Outcome = &historyservice.CompleteNexusOperationRequest_Success{ Success: result, @@ -220,14 +217,14 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C default: // The Nexus SDK ensures this never happens but just in case... logger.Error("invalid operation state in completion request", tag.String("state", string(r.State)), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid completion state") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid completion state") } _, err = h.HistoryClient.CompleteNexusOperation(ctx, hr) if err != nil { logger.Error("failed to process nexus completion request", tag.Error(err)) var namespaceInactiveErr *serviceerror.NamespaceNotActive if errors.As(err, &namespaceInactiveErr) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") } var notFoundErr *serviceerror.NotFound if errors.As(err, ¬FoundErr) { @@ -242,13 +239,13 @@ func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nex client, err := h.ForwardingClients.Get(rCtx.namespace.ActiveClusterName(rCtx.workflowID)) if err != nil { h.Logger.Error("unable to get HTTP client for forward request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err), tag.SourceCluster(h.ClusterMetadata.GetCurrentClusterName()), tag.TargetCluster(rCtx.namespace.ActiveClusterName(rCtx.workflowID))) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } forwardURL, err := url.JoinPath(client.BaseURL(), commonnexus.RouteCompletionCallback.Path(rCtx.namespace.Name().String())) if err != nil { h.Logger.Error("failed to construct forwarding request URL", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err), tag.TargetCluster(rCtx.namespace.ActiveClusterName(rCtx.workflowID))) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } if h.HTTPTraceProvider != nil { @@ -264,116 +261,55 @@ func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nex } } - var forwardReq *http.Request + var completion nexusrpc.OperationCompletion + switch r.State { case nexus.OperationStateSucceeded: - // For successful operations, the Nexus framework streams the result as a LazyValue, so we can reuse the - // incoming request body. - forwardReq, err = http.NewRequestWithContext(ctx, r.HTTPRequest.Method, forwardURL, r.HTTPRequest.Body) - if err != nil { - h.Logger.Error("failed to construct forwarding HTTP request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + completion = &nexusrpc.OperationCompletionSuccessful{ + Result: r.Result.Reader, + OperationToken: r.OperationToken, + StartTime: r.StartTime, + CloseTime: r.CloseTime, + Links: r.Links, } case nexus.OperationStateFailed, nexus.OperationStateCanceled: // For unsuccessful operations, the Nexus framework reads and closes the original request body to deserialize // the failure, so we must construct a new completion to forward. - c := &nexusrpc.OperationCompletionUnsuccessful{ - State: r.State, + completion = &nexusrpc.OperationCompletionUnsuccessful{ + Error: r.Error, OperationToken: r.OperationToken, StartTime: r.StartTime, + CloseTime: r.CloseTime, Links: r.Links, } - if r.Error != nil && r.Error.OriginalFailure != nil { - c.Failure = *r.Error.OriginalFailure - } - forwardReq, err = nexusrpc.NewCompletionHTTPRequest(ctx, forwardURL, c) - if err != nil { - h.Logger.Error("failed to construct forwarding HTTP request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } default: - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation state: %q", r.State) - } - - forwardReq.Header = rCtx.originalHeaders - if forwardReq.Header == nil { - forwardReq.Header = make(http.Header, 1) - } - forwardReq.Header.Set(interceptor.DCRedirectionApiHeaderName, "true") - - resp, err := client.Do(forwardReq) - if err != nil { - h.Logger.Error("received error from HTTP client when forwarding request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } - - // TODO: The following response handling logic is duplicated in the nexus_invocation executor. Eventually it should live in the Nexus SDK. - body, err := readAndReplaceBody(resp) - if err != nil { - h.Logger.Error("unable to read HTTP response for forwarded request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation state: %q", r.State) } - if resp.StatusCode == http.StatusOK { - return nil - } - - if !isMediaTypeJSON(resp.Header.Get("Content-Type")) { - h.Logger.Error("received invalid content-type header for non-OK HTTP response to forwarded request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Value(resp.Header.Get("Content-Type"))) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } - - var failure nexus.Failure - err = json.Unmarshal(body, &failure) - if err != nil { - h.Logger.Error("failed to deserialize Nexus Failure from HTTP response to forwarded request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } - - // TODO: Upgrade Nexus SDK in order to reduce HTTP exposure - handlerErr := &nexus.HandlerError{ - Type: commonnexus.HandlerErrorTypeFromHTTPStatus(resp.StatusCode), - Cause: &nexus.FailureError{Failure: failure}, - } - - if handlerErr.Type == nexus.HandlerErrorTypeInternal && resp.StatusCode != http.StatusInternalServerError { - h.Logger.Warn("received unknown status code on Nexus client unexpected response error", tag.Value(resp.StatusCode)) - handlerErr.Cause = errors.New("internal error") - } - - return handlerErr -} - -// readAndReplaceBody reads the response body in its entirety and closes it, and then replaces the original response -// body with an in-memory buffer. -// The body is replaced even when there was an error reading the entire body. -func readAndReplaceBody(response *http.Response) ([]byte, error) { - responseBody := response.Body - body, err := io.ReadAll(responseBody) - _ = responseBody.Close() - response.Body = io.NopCloser(bytes.NewReader(body)) - return body, err + rCtx.originalHeaders.Set(interceptor.DCRedirectionApiHeaderName, "true") + cc := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + HTTPCaller: (&forwardingHTTPHeaderWrapper{ + client: client, + originalRequestHeaders: rCtx.originalHeaders, + }).Do, + }) + return cc.CompleteOperation(ctx, forwardURL, completion) } -func isMediaTypeJSON(contentType string) bool { - if contentType == "" { - return false - } - mediaType, _, err := mime.ParseMediaType(contentType) - return err == nil && mediaType == "application/json" +type forwardingHTTPHeaderWrapper struct { + client *common.FrontendHTTPClient + originalRequestHeaders http.Header } -// Copies HTTP request headers to Nexus headers except those starting with content- since those will be added by the client. -func httpHeaderToNexusHeader(httpHeader http.Header) nexus.Header { - header := nexus.Header{} - for k, v := range httpHeader { - lowerK := strings.ToLower(k) - if !strings.HasPrefix(lowerK, "content-") { - // Nexus headers can only have single values, ignore multiple values. - header[lowerK] = v[0] +func (f *forwardingHTTPHeaderWrapper) Do(req *http.Request) (*http.Response, error) { + // For forwarded requests, copy the original HTTP headers without sanitization. + for k, v := range f.originalRequestHeaders { + if req.Header.Get(k) == "" { + req.Header.Set(k, v[0]) } } - return header + + return f.client.Do(req) } type requestContext struct { @@ -520,7 +456,7 @@ func (c *requestContext) interceptRequest(ctx context.Context, request *nexusrpc return serviceerror.NewNamespaceNotActive(c.namespace.Name().String(), c.ClusterMetadata.GetCurrentClusterName(), c.namespace.ActiveClusterName(c.workflowID)) } c.metricsHandler = c.metricsHandler.WithTags(metrics.OutcomeTag("namespace_inactive_forwarding_disabled")) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") } c.cleanupFunctions = append(c.cleanupFunctions, func(retErr error) { diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index e5b9f4129e..d245447c28 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -199,7 +199,7 @@ func (c *operationContext) interceptRequest( ) } c.metricsHandler = c.metricsHandler.WithTags(metrics.OutcomeTag("namespace_inactive_forwarding_disabled")) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") } c.cleanupFunctions = append(c.cleanupFunctions, func(respHeaders map[string]string, retErr error) { @@ -370,7 +370,7 @@ func (h *nexusHandler) getOperationContext(ctx context.Context, method string) ( var nfe *serviceerror.NamespaceNotFound if errors.As(err, &nfe) { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace not found: %q", nc.namespaceName) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace not found: %q", nc.namespaceName) } return nil, commonnexus.ConvertGRPCError(err, false) } @@ -432,11 +432,11 @@ func (h *nexusHandler) StartOperation( // Transform nexus Content to temporal Payload with common/nexus PayloadSerializer. if err = input.Consume(&startOperationRequest.Payload); err != nil { oc.logger.Warn("invalid input", tag.Error(err)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid input") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid input") } if startOperationRequest.Payload.Size() > h.payloadSizeLimit(oc.namespaceName) { oc.logger.Error("payload size exceeds error limit for Nexus StartOperation request", tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "input exceeds size limit") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "input exceeds size limit") } // Dispatch the request to be sync matched with a worker polling on the nexusContext taskQueue. @@ -458,12 +458,12 @@ func (h *nexusHandler) StartOperation( nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } he, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) if err != nil { oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } return nil, he @@ -477,7 +477,7 @@ func (h *nexusHandler) StartOperation( case *matchingservice.DispatchNexusTaskResponse_RequestTimeout: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_timeout")) oc.setFailureSource(commonnexus.FailureSourceWorker) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") case *matchingservice.DispatchNexusTaskResponse_Response: switch t := t.Response.GetStartOperation().GetVariant().(type) { @@ -504,13 +504,18 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_OperationError: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("operation_error")) oc.setFailureSource(commonnexus.FailureSourceWorker) - err := &nexus.OperationError{ - State: nexus.OperationState(t.OperationError.GetOperationState()), + opErr := &nexus.OperationError{ + Message: "operation error", + State: nexus.OperationState(t.OperationError.GetOperationState()), Cause: &nexus.FailureError{ Failure: commonnexus.ProtoFailureToNexusFailure(t.OperationError.GetFailure()), }, } - return nil, err + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + oc.logger.Error("error converting OperationError to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + return nil, opErr case *nexuspb.StartOperationResponse_Failure: // Set the failure source to "worker" if we've reached this case. @@ -521,12 +526,12 @@ func (h *nexusHandler) StartOperation( nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } cause, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) if err != nil { oc.logger.Error("error converting Nexus failure to Nexus OperationError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } state := nexus.OperationStateFailed if t.Failure.GetCanceledFailureInfo() != nil { @@ -537,17 +542,10 @@ func (h *nexusHandler) StartOperation( Message: "operation error", Cause: cause, } - nf, err = nexusrpc.DefaultFailureConverter().ErrorToFailure(opErr) - if err != nil { + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { oc.logger.Error("error converting OperationError to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } - // Mark that the original failure is an OperationError wrapper to be unwrapped. - // Newer server callers will unwrap the cause automatically if this metadata key is set. - // This is required to support calling non Temporal based implementations where OpeationErrors carry additional information. - nf.Metadata["unwrap-error"] = "true" - opErr.OriginalFailure = &nf - return nil, opErr } } @@ -555,7 +553,7 @@ func (h *nexusHandler) StartOperation( oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:EMPTY_OUTCOME")) oc.setFailureSource(commonnexus.FailureSourceWorker) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") } func parseLinks(links []*nexuspb.Link, logger log.Logger) []nexus.Link { @@ -675,12 +673,12 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) if err != nil { oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } he, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) if err != nil { oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } return he @@ -694,7 +692,7 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, case *matchingservice.DispatchNexusTaskResponse_RequestTimeout: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_timeout")) oc.setFailureSource(commonnexus.FailureSourceWorker) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") case *matchingservice.DispatchNexusTaskResponse_Response: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("success")) @@ -704,7 +702,7 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:EMPTY_OUTCOME")) oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") } func (h *nexusHandler) forwardCancelOperation( @@ -725,7 +723,7 @@ func (h *nexusHandler) forwardCancelOperation( handle, err := client.NewOperationHandle(operation, id) if err != nil { oc.logger.Warn("invalid Nexus cancel operation.", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation") } if h.httpTraceProvider != nil { @@ -758,7 +756,7 @@ func (h *nexusHandler) nexusClientForActiveCluster(oc *operationContext, service if err != nil { oc.logger.Error("failed to forward Nexus request. error creating HTTP client", tag.Error(err), tag.SourceCluster(oc.namespace.ActiveClusterName(namespace.EmptyBusinessID)), tag.TargetCluster(oc.namespace.ActiveClusterName(""))) oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("request_forwarding_failed")) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") } httpCaller := &forwardingHttpHeaderWrapper{ @@ -790,7 +788,7 @@ func (h *nexusHandler) nexusClientForActiveCluster(oc *operationContext, service tag.WorkflowTaskQueueName(oc.taskQueue), tag.Error(err)) oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("request_forwarding_failed")) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") } return nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{ @@ -810,12 +808,12 @@ func convertOutcomeToNexusHandlerError(resp *matchingservice.DispatchNexusTaskRe retryBehavior = nexus.HandlerErrorRetryBehaviorNonRetryable } // nolint:staticcheck // Deprecated function still in use for backward compatibility. - originalFailure := commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()) + cause := commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()) return &nexus.HandlerError{ // nolint:staticcheck // Deprecated function still in use for backward compatibility. - Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), - RetryBehavior: retryBehavior, - OriginalFailure: &originalFailure, + Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), + RetryBehavior: retryBehavior, + Cause: &nexus.FailureError{Failure: cause}, } } diff --git a/service/frontend/nexus_http_handler.go b/service/frontend/nexus_http_handler.go index 940e63fcfa..b5943d5f01 100644 --- a/service/frontend/nexus_http_handler.go +++ b/service/frontend/nexus_http_handler.go @@ -2,7 +2,6 @@ package frontend import ( "context" - "encoding/json" "errors" "net/http" "net/url" @@ -34,6 +33,7 @@ import ( // Small wrapper that does some pre-processing before handing requests over to the Nexus SDK's HTTP handler. type NexusHTTPHandler struct { + base nexusrpc.BaseHTTPHandler logger log.Logger nexusHandler http.Handler enpointRegistry commonnexus.EndpointRegistry @@ -67,6 +67,10 @@ func NewNexusHTTPHandler( httpTraceProvider commonnexus.HTTPClientTraceProvider, ) *NexusHTTPHandler { return &NexusHTTPHandler{ + base: nexusrpc.BaseHTTPHandler{ + Logger: log.NewSlogLogger(logger), + FailureConverter: nexusrpc.DefaultFailureConverter(), + }, logger: logger, enpointRegistry: endpointRegistry, namespaceRegistry: namespaceRegistry, @@ -110,32 +114,15 @@ func (h *NexusHTTPHandler) RegisterRoutes(r *mux.Router) { HandlerFunc(h.dispatchNexusTaskByEndpoint) } -func (h *NexusHTTPHandler) writeNexusFailure(writer http.ResponseWriter, statusCode int, failure *nexus.Failure) { +func (h *NexusHTTPHandler) writeFailure(writer http.ResponseWriter, r *http.Request, err error) { h.preprocessErrorCounter.Record(1) - - if failure == nil { - writer.WriteHeader(statusCode) - return - } - - bytes, err := json.Marshal(failure) - if err != nil { - h.logger.Error("failed to marshal failure", tag.Error(err)) - writer.WriteHeader(http.StatusInternalServerError) - return - } - writer.Header().Set("Content-Type", "application/json") - writer.WriteHeader(statusCode) - - if _, err := writer.Write(bytes); err != nil { - h.logger.Error("failed to write response body", tag.Error(err)) - } + h.base.WriteFailure(writer, r, err) } // Handler for [nexushttp.RouteSet.DispatchNexusTaskByNamespaceAndTaskQueue]. func (h *NexusHTTPHandler) dispatchNexusTaskByNamespaceAndTaskQueue(w http.ResponseWriter, r *http.Request) { if !h.enabled() { - h.writeNexusFailure(w, http.StatusNotFound, &nexus.Failure{Message: "nexus endpoints disabled"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "nexus endpoints disabled")) return } @@ -145,31 +132,32 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByNamespaceAndTaskQueue(w http.Respo if nc.taskQueue, err = url.PathUnescape(params.TaskQueue); err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: "invalid URL"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL")) return } if nc.namespaceName, err = url.PathUnescape(params.Namespace); err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: "invalid URL"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL")) return } if err = h.namespaceValidationInterceptor.ValidateName(nc.namespaceName); err != nil { h.logger.Error("invalid namespace name", tag.Error(err)) - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: err.Error()}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "%v", err.Error())) return } - r, err = h.parseTlsAndAuthInfo(r, nc) + rWithAuthCtx, err := h.parseTlsAndAuthInfo(r, nc) if err != nil { h.logger.Error("failed to get claims", tag.Error(err)) - h.writeNexusFailure(w, http.StatusUnauthorized, &nexus.Failure{Message: "unauthorized"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnauthenticated, "unauthorized")) return } + r = rWithAuthCtx u, err := mux.CurrentRoute(r).URL("namespace", params.Namespace, "task_queue", params.TaskQueue) if err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusInternalServerError, &nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) return } @@ -179,7 +167,7 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByNamespaceAndTaskQueue(w http.Respo // Handler for [nexushttp.RouteSet.DispatchNexusTaskByEndpoint]. func (h *NexusHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r *http.Request) { if !h.enabled() { - h.writeNexusFailure(w, http.StatusNotFound, &nexus.Failure{Message: "nexus endpoints disabled"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "nexus endpoints disabled")) return } @@ -188,7 +176,7 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r endpointID, err := url.PathUnescape(endpointIDEscaped) if err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: "invalid URL"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL")) return } endpointEntry, err := h.enpointRegistry.GetByID(r.Context(), endpointID) @@ -207,32 +195,33 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r w.Header().Set("nexus-request-retryable", "false") } } - h.writeNexusFailure(w, http.StatusNotFound, &nexus.Failure{Message: "nexus endpoint not found"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "nexus endpoint not found")) case codes.DeadlineExceeded: - h.writeNexusFailure(w, http.StatusRequestTimeout, &nexus.Failure{Message: "request timed out"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeRequestTimeout, "request timed out")) default: - h.writeNexusFailure(w, http.StatusInternalServerError, &nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) } return } - nc, ok := h.nexusContextFromEndpoint(endpointEntry, w, r.Header) + nc, ok := h.nexusContextFromEndpoint(endpointEntry, w, r) if !ok { // nexusContextFromEndpoint already writes the failure response. return } - r, err = h.parseTlsAndAuthInfo(r, nc) + rWithAuthCtx, err := h.parseTlsAndAuthInfo(r, nc) if err != nil { h.logger.Error("failed to get claims", tag.Error(err)) - h.writeNexusFailure(w, http.StatusUnauthorized, &nexus.Failure{Message: "unauthorized"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnauthenticated, "unauthorized")) return } + r = rWithAuthCtx u, err := mux.CurrentRoute(r).URL("endpoint", endpointIDEscaped) if err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusInternalServerError, &nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) return } @@ -256,7 +245,7 @@ func (h *NexusHTTPHandler) baseNexusContext(apiName string, header http.Header) // endpoint is valid for dispatching. // For security reasons, at the moment only worker target endpoints are considered valid, in the future external // endpoints may also be supported. -func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusEndpointEntry, w http.ResponseWriter, header http.Header) (*nexusContext, bool) { +func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusEndpointEntry, w http.ResponseWriter, r *http.Request) (*nexusContext, bool) { switch v := entry.Endpoint.Spec.GetTarget().GetVariant().(type) { case *persistencespb.NexusEndpointTarget_Worker_: nsName, err := h.namespaceRegistry.GetNamespaceName(namespace.ID(v.Worker.GetNamespaceId())) @@ -265,20 +254,20 @@ func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusE var notFoundErr *serviceerror.NamespaceNotFound if errors.As(err, ¬FoundErr) { w.Header().Set("nexus-request-retryable", "true") - h.writeNexusFailure(w, http.StatusNotFound, &nexus.Failure{Message: "invalid endpoint target"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "invalid endpoint target")) } else { - h.writeNexusFailure(w, http.StatusInternalServerError, &nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) } return nil, false } - nc := h.baseNexusContext(configs.DispatchNexusTaskByEndpointAPIName, header) + nc := h.baseNexusContext(configs.DispatchNexusTaskByEndpointAPIName, r.Header) nc.namespaceName = nsName.String() nc.taskQueue = v.Worker.GetTaskQueue() nc.endpointName = entry.Endpoint.Spec.Name nc.endpointID = entry.Id return nc, true default: - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: "invalid endpoint target"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid endpoint target")) return nil, false } } @@ -329,7 +318,7 @@ func (h *NexusHTTPHandler) serveResolvedURL(w http.ResponseWriter, r *http.Reque prefix, err := url.PathUnescape(u.Path) if err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusInternalServerError, &nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) return } prefix = path.Dir(prefix) diff --git a/service/history/handler.go b/service/history/handler.go index e091557232..3409602ac4 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -2167,19 +2167,13 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse var ok bool if opErr, ok = recvdErr.(*nexus.OperationError); !ok { opErr = &nexus.OperationError{ - State: nexus.OperationState(request.GetState()), - // Setting a message here will bypass the Nexus SDK's failure converter backward compatibility logic. + State: nexus.OperationState(request.GetState()), Message: "nexus operation completed unsuccessfully", Cause: recvdErr, } - origFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(opErr) - if err != nil { + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { return nil, serviceerror.NewInvalidArgument("unable to convert operation error to failure") } - // Special header to signal that this error should be unwrapped by the completion handler as old servers will send - // back empty wrappers for underlying causes. - origFailure.Metadata["unwrap-error"] = "true" - opErr.OriginalFailure = &origFailure } } err = nexusoperations.CompletionHandler( diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index 862b1af8cd..b43bdb16cb 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -767,33 +767,31 @@ func (ms *MutableStateImpl) GetNexusCompletion( // Nexus does not support it. p = payloads[0] } - completion, err := nexusrpc.NewOperationCompletionSuccessful(p, nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) - if err != nil { - return nil, serviceerror.NewInternalf("failed to construct Nexus completion: %v", err) - } - return completion, nil + return &nexusrpc.OperationCompletionSuccessful{ + Result: p, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_FAILED: f, err := commonnexus.TemporalFailureToNexusFailure(ce.GetWorkflowExecutionFailedEventAttributes().GetFailure()) if err != nil { return nil, err } - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ - State: nexus.OperationStateFailed, - Cause: &nexus.FailureError{Failure: f}, - // Store the original failure to bypass the Nexus failure converter. - OriginalFailure: &f, - }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) + opErr := &nexus.OperationError{ + Message: "operation failed", + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: f}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nil, err + } + return &nexusrpc.OperationCompletionUnsuccessful{ + Error: opErr, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED: f, err := commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation canceled", @@ -806,18 +804,20 @@ func (ms *MutableStateImpl) GetNexusCompletion( if err != nil { return nil, err } - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ - State: nexus.OperationStateCanceled, - Cause: &nexus.FailureError{Failure: f}, - // Store the original failure to bypass the Nexus failure converter. - OriginalFailure: &f, - }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) + opErr := &nexus.OperationError{ + State: nexus.OperationStateCanceled, + Message: "operation canceled", + Cause: &nexus.FailureError{Failure: f}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nil, err + } + return &nexusrpc.OperationCompletionUnsuccessful{ + Error: opErr, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED: f, err := commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation terminated", @@ -828,19 +828,20 @@ func (ms *MutableStateImpl) GetNexusCompletion( if err != nil { return nil, err } - return nexusrpc.NewOperationCompletionUnsuccessful( - // NOTE: Not setting a message for compatibility with older servers than don't support both cause and message. - &nexus.OperationError{ - State: nexus.OperationStateFailed, - Cause: &nexus.FailureError{Failure: f}, - // Store the original failure to bypass the Nexus failure converter. - OriginalFailure: &f, - }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) + opErr := &nexus.OperationError{ + State: nexus.OperationStateFailed, + Message: "operation failed", + Cause: &nexus.FailureError{Failure: f}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nil, err + } + return &nexusrpc.OperationCompletionUnsuccessful{ + Error: opErr, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT: f, err := commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation exceeded internal timeout", @@ -854,18 +855,20 @@ func (ms *MutableStateImpl) GetNexusCompletion( if err != nil { return nil, err } - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ - State: nexus.OperationStateFailed, - Cause: &nexus.FailureError{Failure: f}, - // Store the original failure to bypass the Nexus failure converter. - OriginalFailure: &f, - }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) + opErr := &nexus.OperationError{ + State: nexus.OperationStateFailed, + Message: "operation failed", + Cause: &nexus.FailureError{Failure: f}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nil, err + } + return &nexusrpc.OperationCompletionUnsuccessful{ + Error: opErr, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil } return nil, serviceerror.NewInternalf("invalid workflow execution status: %v", ce.GetEventType()) } diff --git a/service/history/workflow/workflow_test/mutable_state_impl_test.go b/service/history/workflow/workflow_test/mutable_state_impl_test.go index 1e06c10238..46384b0721 100644 --- a/service/history/workflow/workflow_test/mutable_state_impl_test.go +++ b/service/history/workflow/workflow_test/mutable_state_impl_test.go @@ -6,7 +6,6 @@ package workflow_test import ( "context" "fmt" - "io" "math" "testing" "time" @@ -382,11 +381,10 @@ func TestGetNexusCompletion(t *testing.T) { verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { success, ok := completion.(*nexusrpc.OperationCompletionSuccessful) require.True(t, ok) - require.Equal(t, "application/json", success.Reader.Header.Get("type")) - require.Equal(t, "1", success.Reader.Header.Get("length")) - buf, err := io.ReadAll(success.Reader) - require.NoError(t, err) - require.Equal(t, []byte("3"), buf) + require.Equal(t, commonpb.Payload{ + Metadata: map[string][]byte{"encoding": []byte("json/plain")}, + Data: []byte("3"), + }, success.Result) require.Equal(t, event.GetEventTime().AsTime(), success.CloseTime) }, }, @@ -402,8 +400,8 @@ func TestGetNexusCompletion(t *testing.T) { verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) require.True(t, ok) - require.Equal(t, nexus.OperationStateFailed, failure.State) - require.Equal(t, "workflow failed", failure.Failure.Message) + require.Equal(t, nexus.OperationStateFailed, failure.Error.State) + require.Equal(t, "workflow failed", failure.Error.Message) require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) }, }, @@ -415,8 +413,8 @@ func TestGetNexusCompletion(t *testing.T) { verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) require.True(t, ok) - require.Equal(t, nexus.OperationStateFailed, failure.State) - require.Equal(t, "operation terminated", failure.Failure.Message) + require.Equal(t, nexus.OperationStateFailed, failure.Error.State) + require.Equal(t, "operation terminated", failure.Error.Message) require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) }, }, @@ -428,8 +426,8 @@ func TestGetNexusCompletion(t *testing.T) { verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) require.True(t, ok) - require.Equal(t, nexus.OperationStateCanceled, failure.State) - require.Equal(t, "operation canceled", failure.Failure.Message) + require.Equal(t, nexus.OperationStateCanceled, failure.Error.State) + require.Equal(t, "operation canceled", failure.Error.Message) require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) }, }, diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index ac9d4a2309..5247903a7f 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -364,7 +364,7 @@ func (s *CallbacksSuite) TestWorkflowNexusCallbacks_CarriedOver() { var err error if attempt < numAttempts { // force retry - err = nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "intentional error") + err = nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "intentional error") } ch.requestCompleteCh <- err } diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index 8055b925be..566ac2ea04 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -203,7 +203,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, nexus.HandlerErrorRetryBehaviorUnspecified, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Equal(t, "Internal Server Error", handlerErr.Message) + require.Empty(t, handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, @@ -225,7 +225,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, nexus.HandlerErrorRetryBehaviorNonRetryable, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Equal(t, "Internal Server Error", handlerErr.Message) + require.Empty(t, handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, @@ -375,7 +375,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_Na s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) // I wish we'd never put periods in error messages :( - s.Equal("Namespace length exceeds limit.", handlerErr.Cause.Error()) + s.Equal("Namespace length exceeds limit.", handlerErr.Message) snap := capture.Snapshot() @@ -461,13 +461,13 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { name: "deny with exposed error", onAuthorize: func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { - return authorization.Result{}, nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") + return authorization.Result{}, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") } if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { if ct.NexusEndpointName != testEndpoint.Spec.Name { panic("expected nexus endpoint name") } - return authorization.Result{}, nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") + return authorization.Result{}, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") } return authorization.Result{Decision: authorization.DecisionAllow}, nil }, @@ -553,7 +553,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnauthenticated, handlerErr.Type) - require.Equal(t, "Unauthorized", handlerErr.Message) + require.Equal(t, "unauthorized", handlerErr.Message) require.Equal(t, 1, len(snap["nexus_request_preprocess_errors"])) }, }, @@ -602,21 +602,12 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { go s.nexusTaskPoller(ctx, taskQueue, tc.handler) } - var result *nexusrpc.ClientStartOperationResponse[string] - var snap map[string][]*metricstest.CapturedRecording - - // Wait until the endpoint is loaded into the registry. - s.Eventually(func() bool { - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - - result, err = nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{ - Header: tc.header, - }) - snap = capture.Snapshot() - var handlerErr *nexus.HandlerError - return err == nil || !(errors.As(err, &handlerErr) && handlerErr.Type == nexus.HandlerErrorTypeNotFound) - }, 10*time.Second, 1*time.Second) + capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + result, err := nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{ + Header: tc.header, + }) + snap := capture.Snapshot() + s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) tc.assertion(t, result, err, snap) } @@ -726,7 +717,7 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) - require.Equal(t, "Internal Server Error", handlerErr.Message) + require.Empty(t, handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, @@ -984,7 +975,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_ByEndpoint_EndpointNotFound( var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) - s.Equal("nexus endpoint not found", handlerErr.Cause.Error()) + s.Equal("nexus endpoint not found", handlerErr.Message) snap := capture.Snapshot() s.Equal(1, len(snap["nexus_request_preprocess_errors"])) } diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index 5efe011614..5fc5f5b97e 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -1,13 +1,10 @@ package tests import ( - "bytes" "context" "encoding/json" "errors" "fmt" - "io" - "net/http" "slices" "strings" "testing" @@ -656,23 +653,20 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { protorequire.ProtoEqual(s.T(), handlerLink, l) // Completion request fails if the result payload is too large. - largeCompletion, err := nexusrpc.NewOperationCompletionSuccessful( + largeCompletion := &nexusrpc.OperationCompletionSuccessful{ // Use -10 to avoid hitting MaxNexusAPIRequestBodyBytes. Actual payload will still exceed limit because of // additional Content headers. See common/rpc/grpc.go:66 - s.mustToPayload(strings.Repeat("a", (2*1024*1024)-10)), - nexusrpc.OperationCompletionSuccessfulOptions{Serializer: commonnexus.PayloadSerializer}, - ) + Result: s.mustToPayload(strings.Repeat("a", (2*1024*1024)-10)), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } s.NoError(err) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, largeCompletion, callbackToken) - s.Equal(http.StatusBadRequest, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, largeCompletion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload(nil), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - invalidNamespace := testcore.RandomizeStr("ns") _, err = s.FrontendClient().RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{ Namespace: invalidNamespace, @@ -683,11 +677,15 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { // Send an invalid completion request and verify that we get an error that the namespace in the URL doesn't match the namespace in the token. invalidCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(invalidNamespace) - res, _, body := s.sendNexusCompletionRequest(ctx, s.T(), invalidCallbackURL, completion, callbackToken) - + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + _, err = s.sendNexusCompletionRequest(ctx, invalidCallbackURL, completion) // Verify we get the correct error response - s.Equal(http.StatusBadRequest, res.StatusCode) - s.Contains(body, "invalid callback token", "Response should indicate namespace mismatch") + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + s.Contains(handlerErr.Error(), "invalid callback token", "Response should indicate namespace mismatch") // Manipulate the token to verify we get the expected errors in the API. gen := &commonnexus.CallbackTokenGenerator{} @@ -701,9 +699,11 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { workflowNotFoundToken.WorkflowId = "not-found" callbackToken, err = gen.Tokenize(workflowNotFoundToken) s.NoError(err) + completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - res, snap, _ = s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) @@ -712,23 +712,20 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { staleToken.Ref.MachineInitialVersionedTransition.NamespaceFailoverVersion++ callbackToken, err = gen.Tokenize(staleToken) s.NoError(err) + completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - res, snap, _ = s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) - // Send a valid - successful completion request. - completion, err = nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - callbackToken, err = gen.Tokenize(completionToken) s.NoError(err) + completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - res, snap, _ = s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusOK, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + s.NoError(err) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "success"}) // Ensure that CompleteOperation request is tracked as part of normal service telemetry metrics @@ -739,8 +736,9 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { s.Greater(idx, -1) // Resend the request and verify we get a not found error since the operation has already completed. - res, snap, _ = s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) @@ -987,7 +985,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() s.NoError(err) expectedLinks := []*commonpb.Link_WorkflowEvent{ - &commonpb.Link_WorkflowEvent{ + { Namespace: s.Namespace().String(), WorkflowId: completionWFID, RunId: completionWfRunIDs[0], @@ -998,7 +996,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() }, }, }, - &commonpb.Link_WorkflowEvent{ + { Namespace: s.Namespace().String(), WorkflowId: completionWFID, RunId: completionWfRunIDs[1], @@ -1169,10 +1167,12 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { s.Greater(startedEventIdx, 0) // Send a valid - failed completion request. - completion, err := nexusrpc.NewOperationCompletionUnsuccessful(nexus.NewOperationFailedError("test operation failed"), nexusrpc.OperationCompletionUnsuccessfulOptions{}) + completion := &nexusrpc.OperationCompletionUnsuccessful{ + Error: nexus.NewOperationFailedErrorf("test operation failed"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) s.NoError(err) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusOK, res.StatusCode) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "success"}) @@ -1219,24 +1219,27 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { ctx := testcore.NewContext() - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) + commonCompletion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + } s.Run("ConfigDisabled", func() { s.OverrideDynamicConfig(dynamicconfig.EnableNexus, false) publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, "") - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, commonCompletion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) s.Run("ConfigDisabledNoIdentifier", func() { s.OverrideDynamicConfig(dynamicconfig.EnableNexus, false) publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, "") - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, commonCompletion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) @@ -1246,8 +1249,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.NoError(err) publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path("namespace-doesnt-exist") - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, tokenWithBadNamespace) - s.Equal(http.StatusNotFound, res.StatusCode) + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: tokenWithBadNamespace}, + } + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) @@ -1257,26 +1266,34 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.NoError(err) publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, tokenWithBadNamespace) - s.Equal(http.StatusNotFound, res.StatusCode) + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: tokenWithBadNamespace}, + } + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) s.Run("OperationTokenTooLong", func() { publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - OperationToken: strings.Repeat("long", 2000), - }) - s.NoError(err) // Generate a valid callback token to get past initial validation namespaceID := s.GetNamespaceID(s.Namespace().String()) validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + OperationToken: strings.Repeat("long", 2000), + Header: nexus.Header{commonnexus.CallbackTokenHeader: validToken}, + } - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, validToken) - s.Equal(http.StatusBadRequest, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(0, len(snap["nexus_completion_request_preprocess_errors"])) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) @@ -1284,19 +1301,21 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.Run("OperationTokenTooLongNoIdentifier", func() { publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - OperationToken: strings.Repeat("long", 2000), - }) - s.NoError(err) - // Generate a valid callback token to get past initial validation namespaceID := s.GetNamespaceID(s.Namespace().String()) validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, validToken) - s.Equal(http.StatusBadRequest, res.StatusCode) + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + OperationToken: strings.Repeat("long", 2000), + Header: nexus.Header{commonnexus.CallbackTokenHeader: validToken}, + } + + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(0, len(snap["nexus_completion_request_preprocess_errors"])) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) @@ -1306,22 +1325,24 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) // metrics collection is not initialized before callback validation // Send request without callback token, helper does not add token if blank - res, _, body := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, "") - + _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, commonCompletion) // Verify we get the correct error response - s.Equal(http.StatusBadRequest, res.StatusCode) - s.Contains(string(body), "invalid callback token", "Response should indicate invalid callback token") + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + s.Contains(handlerErr.Error(), "invalid callback token", "Response should indicate invalid callback token") }) s.Run("InvalidCallbackTokenNoIdentifier", func() { publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier // metrics collection is not initialized before callback validation // Send request without callback token, helper does not add token if blank - res, _, body := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, "") - + _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, commonCompletion) // Verify we get the correct error response - s.Equal(http.StatusBadRequest, res.StatusCode) - s.Contains(string(body), "invalid callback token", "Response should indicate invalid callback token") + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + s.Contains(handlerErr.Error(), "invalid callback token", "Response should indicate invalid callback token") }) s.Run("InvalidClientVersion", func() { @@ -1334,22 +1355,21 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, publicCallbackURL, completion) - s.NoError(err) - req.Header.Set("User-Agent", "Nexus-go-sdk/v99.0.0") - req.Header.Add(commonnexus.CallbackTokenHeader, validToken) - - res, err := http.DefaultClient.Do(req) - s.NoError(err) - _, err = io.ReadAll(res.Body) - s.NoError(err) - defer func() { - err := res.Body.Close() - s.NoError(err) - }() - + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + Header: nexus.Header{ + commonnexus.CallbackTokenHeader: validToken, + "user-agent": "Nexus-go-sdk/v99.0.0", + }, + } + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: commonnexus.PayloadSerializer, + }) + err = client.CompleteOperation(ctx, publicCallbackURL, completion) snap := capture.Snapshot() - s.Equal(http.StatusBadRequest, res.StatusCode) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unsupported_client"}) }) @@ -1364,19 +1384,22 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, publicCallbackURL, completion) - s.NoError(err) - req.Header.Set("User-Agent", "Nexus-go-sdk/v99.0.0") - req.Header.Add(commonnexus.CallbackTokenHeader, validToken) - - res, err := http.DefaultClient.Do(req) - s.NoError(err) - _, err = io.ReadAll(res.Body) - s.NoError(err) - defer res.Body.Close() + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + Header: nexus.Header{ + commonnexus.CallbackTokenHeader: validToken, + "user-agent": "Nexus-go-sdk/v99.0.0", + }, + } + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: commonnexus.PayloadSerializer, + }) + err = client.CompleteOperation(ctx, publicCallbackURL, completion) snap := capture.Snapshot() - s.Equal(http.StatusBadRequest, res.StatusCode) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unsupported_client"}) }) @@ -1394,19 +1417,21 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrors() { s.GetTestCluster().Host().SetOnAuthorize(onAuthorize) defer s.GetTestCluster().Host().SetOnAuthorize(nil) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - // Generate a valid callback token for testing namespaceID := s.GetNamespaceID(s.Namespace().String()) callbackToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusForbidden, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unauthorized"}) } @@ -1423,19 +1448,20 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrorsNoId s.GetTestCluster().Host().SetOnAuthorize(onAuthorize) defer s.GetTestCluster().Host().SetOnAuthorize(nil) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - // Generate a valid callback token for testing namespaceID := s.GetNamespaceID(s.Namespace().String()) callbackToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusForbidden, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unauthorized"}) } @@ -1875,14 +1901,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { } } s.True(seenStartedEvent) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + _, err = s.sendNexusCompletionRequest(ctx, publicCallbackUrl, completion) s.NoError(err) - res, _, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackUrl, completion, callbackToken) - s.Equal(http.StatusOK, res.StatusCode) - // Poll again and verify the completion is recorded and triggers workflow progress. pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ Namespace: s.Namespace().String(), @@ -3062,25 +3087,15 @@ func (s *NexusWorkflowTestSuite) generateValidCallbackToken(namespaceID, workflo func (s *NexusWorkflowTestSuite) sendNexusCompletionRequest( ctx context.Context, - t *testing.T, url string, completion nexusrpc.OperationCompletion, - callbackToken string, -) (*http.Response, map[string][]*metricstest.CapturedRecording, string) { +) (map[string][]*metricstest.CapturedRecording, error) { capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, url, completion) - require.NoError(t, err) - if callbackToken != "" { - req.Header.Add(commonnexus.CallbackTokenHeader, callbackToken) - } - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) - responseBody := res.Body - body, err := io.ReadAll(responseBody) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - res.Body = io.NopCloser(bytes.NewReader(body)) - return res, capture.Snapshot(), string(body) + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: commonnexus.PayloadSerializer, + }) + err := c.CompleteOperation(ctx, url, completion) + return capture.Snapshot(), err } diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go index 31da7c9a59..0454daf3dd 100644 --- a/tests/xdc/nexus_request_forwarding_test.go +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "slices" "testing" @@ -191,7 +190,7 @@ func (s *NexusRequestForwardingSuite) TestStartOperationForwardedFromStandbyToAc var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) - require.Equal(t, "Internal Server Error", handlerErr.Message) + require.Empty(t, handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "StartNexusOperation", "handler_error:INTERNAL") @@ -317,7 +316,7 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) - require.Equal(t, "Internal Server Error", handlerErr.Message) + require.Empty(t, handlerErr.Message) require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "CancelNexusOperation", "handler_error:INTERNAL") @@ -382,16 +381,14 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandbyToActive() { testCases := []struct { name string - getCompletionFn func() (nexusrpc.OperationCompletion, error) + getCompletionFn func() nexusrpc.OperationCompletion assertHistoryAndGetCompleteWF func(*testing.T, []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest assertResult func(*testing.T, string) }{ { name: "success", - getCompletionFn: func() (nexusrpc.OperationCompletion, error) { - return nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: cnexus.PayloadSerializer, - }) + getCompletionFn: func() nexusrpc.OperationCompletion { + return &nexusrpc.OperationCompletionSuccessful{Result: s.mustToPayload("result")} }, assertHistoryAndGetCompleteWF: func(t *testing.T, events []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest { completedEventIdx := slices.IndexFunc(events, func(e *historypb.HistoryEvent) bool { @@ -415,11 +412,14 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }, { name: "operation error", - getCompletionFn: func() (nexusrpc.OperationCompletion, error) { - f := nexus.Failure{Message: "intentional operation failure"} - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: f}}, - nexusrpc.OperationCompletionUnsuccessfulOptions{}) + getCompletionFn: func() nexusrpc.OperationCompletion { + opErr := &nexus.OperationError{ + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "intentional operation failure"}}, + } + return &nexusrpc.OperationCompletionUnsuccessful{ + Error: opErr, + } }, assertHistoryAndGetCompleteWF: func(t *testing.T, events []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest { failedEventIdx := slices.IndexFunc(events, func(e *historypb.HistoryEvent) bool { @@ -443,11 +443,11 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }, { name: "canceled", - getCompletionFn: func() (nexusrpc.OperationCompletion, error) { + getCompletionFn: func() nexusrpc.OperationCompletion { f := nexus.Failure{Message: "operation canceled"} - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{State: nexus.OperationStateCanceled, Cause: &nexus.FailureError{Failure: f}}, - nexusrpc.OperationCompletionUnsuccessfulOptions{}) + return &nexusrpc.OperationCompletionUnsuccessful{ + Error: &nexus.OperationError{State: nexus.OperationStateCanceled, Cause: &nexus.FailureError{Failure: f}}, + } }, assertHistoryAndGetCompleteWF: func(t *testing.T, events []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest { canceledEventIdx := slices.IndexFunc(events, func(e *historypb.HistoryEvent) bool { @@ -596,10 +596,10 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb return err == nil && len(resp.PendingNexusOperations) > 0 }, 5*time.Second, 500*time.Millisecond) - completion, err := tc.getCompletionFn() + completion := tc.getCompletionFn() + completion.SetHeader(cnexus.CallbackTokenHeader, callbackToken) + snap, err := s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[1], publicCallbackUrl, completion) s.NoError(err) - res, snap := s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[1], publicCallbackUrl, completion, callbackToken) - s.Equal(http.StatusOK, res.StatusCode) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": ns, "outcome": "request_forwarded"}) @@ -614,8 +614,10 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }) // Resend the request and verify we get a not found error since the operation has already completed. - res, snap = s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[0], publicCallbackUrl, completion, callbackToken) - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[0], publicCallbackUrl, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": ns, "outcome": "error_not_found"}) @@ -683,25 +685,17 @@ func (s *NexusRequestForwardingSuite) sendNexusCompletionRequest( testCluster *testcore.TestCluster, url string, completion nexusrpc.OperationCompletion, - callbackToken string, -) (*http.Response, map[string][]*metricstest.CapturedRecording) { +) (map[string][]*metricstest.CapturedRecording, error) { metricsHandler, ok := testCluster.Host().GetMetricsHandler().(*metricstest.CaptureHandler) s.True(ok) capture := metricsHandler.StartCapture() defer metricsHandler.StopCapture(capture) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, url, completion) - require.NoError(t, err) - if callbackToken != "" { - req.Header.Add(cnexus.CallbackTokenHeader, callbackToken) - } + err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: cnexus.PayloadSerializer, + }).CompleteOperation(ctx, url, completion) - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) - _, err = io.ReadAll(res.Body) - require.NoError(t, err) - defer res.Body.Close() - return res, capture.Snapshot() + return capture.Snapshot(), err } func requireExpectedMetricsCaptured(t *testing.T, snap map[string][]*metricstest.CapturedRecording, ns string, method string, expectedOutcome string) { diff --git a/tests/xdc/nexus_state_replication_test.go b/tests/xdc/nexus_state_replication_test.go index 4d84b5cc7c..13805af62b 100644 --- a/tests/xdc/nexus_state_replication_test.go +++ b/tests/xdc/nexus_state_replication_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/http" "net/http/httptest" "slices" @@ -715,40 +714,26 @@ func (s *NexusStateReplicationSuite) waitCallback( } func (s *NexusStateReplicationSuite) completeNexusOperation(ctx context.Context, result any, callbackUrl, callbackToken string) { - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload(result), nexusrpc.OperationCompletionSuccessfulOptions{ + completion := &nexusrpc.OperationCompletionSuccessful{ + Result: s.mustToPayload(result), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ Serializer: commonnexus.PayloadSerializer, }) + err := client.CompleteOperation(ctx, callbackUrl, completion) s.NoError(err) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackUrl, completion) - s.NoError(err) - if callbackToken != "" { - req.Header.Add(commonnexus.CallbackTokenHeader, callbackToken) - } - - res, err := http.DefaultClient.Do(req) - s.NoError(err) - defer res.Body.Close() - _, err = io.ReadAll(res.Body) - s.NoError(err) - s.Equal(http.StatusOK, res.StatusCode) } func (s *NexusStateReplicationSuite) cancelNexusOperation(ctx context.Context, callbackUrl, callbackToken string) { - completion, err := nexusrpc.NewOperationCompletionUnsuccessful( - nexus.NewOperationCanceledError("operation canceled"), - nexusrpc.OperationCompletionUnsuccessfulOptions{}, - ) - s.NoError(err) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackUrl, completion) - s.NoError(err) + completion := &nexusrpc.OperationCompletionUnsuccessful{ + Error: nexus.NewOperationCanceledErrorf("operation canceled"), + } if callbackToken != "" { - req.Header.Add(commonnexus.CallbackTokenHeader, callbackToken) + completion.SetHeader(commonnexus.CallbackTokenHeader, callbackToken) } - - res, err := http.DefaultClient.Do(req) - s.NoError(err) - defer res.Body.Close() - _, err = io.ReadAll(res.Body) + err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: commonnexus.PayloadSerializer, + }).CompleteOperation(ctx, callbackUrl, completion) s.NoError(err) - s.Equal(http.StatusOK, res.StatusCode) } From ec8e7aea827559106f7a42b2a381bd921a6cff3c Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Tue, 10 Feb 2026 22:12:19 -0800 Subject: [PATCH 23/26] Replace OperationCompletion with CompleteOperationOptions --- chasm/lib/callback/chasm_invocation.go | 21 +-- chasm/lib/callback/component.go | 2 +- chasm/lib/callback/executors_test.go | 32 ++-- chasm/lib/callback/nexus_invocation.go | 2 +- chasm/lib/workflow/workflow.go | 2 +- chasm/ms_pointer.go | 2 +- chasm/node_backend_mock.go | 6 +- chasm/tree.go | 2 +- common/nexus/nexusrpc/completion.go | 161 ++++++------------ common/nexus/nexusrpc/completion_test.go | 10 +- components/callbacks/chasm_invocation.go | 21 +-- components/callbacks/executors_test.go | 28 +-- components/callbacks/nexus_invocation.go | 4 +- .../nexusoperations/frontend/handler.go | 6 +- .../history/workflow/mutable_state_impl.go | 32 ++-- .../workflow_test/mutable_state_impl_test.go | 44 +++-- tests/nexus_workflow_test.go | 28 +-- tests/xdc/nexus_request_forwarding_test.go | 16 +- tests/xdc/nexus_state_replication_test.go | 6 +- 19 files changed, 178 insertions(+), 247 deletions(-) diff --git a/chasm/lib/callback/chasm_invocation.go b/chasm/lib/callback/chasm_invocation.go index b0158db486..fe5248c325 100644 --- a/chasm/lib/callback/chasm_invocation.go +++ b/chasm/lib/callback/chasm_invocation.go @@ -27,7 +27,7 @@ import ( type chasmInvocation struct { nexus *callbackspb.Callback_Nexus attempt int32 - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions requestID string } @@ -134,14 +134,13 @@ func (c chasmInvocation) getHistoryRequest( RequestId: c.requestID, } - switch op := c.completion.(type) { - case *nexusrpc.OperationCompletionSuccessful: + if c.completion.Error == nil { var payload *commonpb.Payload - if op.Result != nil { + if c.completion.Result != nil { var ok bool - payload, ok = op.Result.(*commonpb.Payload) + payload, ok = c.completion.Result.(*commonpb.Payload) if !ok { - return nil, fmt.Errorf("invalid result, expected a payload, got: %T", op.Result) + return nil, fmt.Errorf("invalid result, expected a payload, got: %T", c.completion.Result) } } @@ -149,11 +148,11 @@ func (c chasmInvocation) getHistoryRequest( Outcome: &historyservice.CompleteNexusOperationChasmRequest_Success{ Success: payload, }, - CloseTime: timestamppb.New(op.CloseTime), + CloseTime: timestamppb.New(c.completion.CloseTime), Completion: completion, } - case *nexusrpc.OperationCompletionUnsuccessful: - failure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(op.Error) + } else { + failure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(c.completion.Error) if err != nil { return nil, fmt.Errorf("failed to convert error to failure: %w", err) } @@ -171,10 +170,8 @@ func (c chasmInvocation) getHistoryRequest( Outcome: &historyservice.CompleteNexusOperationChasmRequest_Failure{ Failure: apiFailure, }, - CloseTime: timestamppb.New(op.CloseTime), + CloseTime: timestamppb.New(c.completion.CloseTime), } - default: - return nil, fmt.Errorf("unexpected nexus.OperationCompletion: %v", completion) } return req, nil diff --git a/chasm/lib/callback/component.go b/chasm/lib/callback/component.go index a10734e19a..c288ff76c2 100644 --- a/chasm/lib/callback/component.go +++ b/chasm/lib/callback/component.go @@ -15,7 +15,7 @@ import ( ) type CompletionSource interface { - GetNexusCompletion(ctx chasm.Context, requestID string) (nexusrpc.OperationCompletion, error) + GetNexusCompletion(ctx chasm.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) } var _ chasm.Component = (*Callback)(nil) diff --git a/chasm/lib/callback/executors_test.go b/chasm/lib/callback/executors_test.go index 8101b14048..42dfcfd48d 100644 --- a/chasm/lib/callback/executors_test.go +++ b/chasm/lib/callback/executors_test.go @@ -40,13 +40,13 @@ type mockNexusCompletionGetterComponent struct { Empty *emptypb.Empty - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions err error Callback chasm.Field[*Callback] } -func (m *mockNexusCompletionGetterComponent) GetNexusCompletion(_ chasm.Context, requestID string) (nexusrpc.OperationCompletion, error) { +func (m *mockNexusCompletionGetterComponent) GetNexusCompletion(_ chasm.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) { return m.completion, m.err } @@ -210,7 +210,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { } // Create completion - completion := &nexusrpc.OperationCompletionSuccessful{} + completion := nexusrpc.CompleteOperationOptions{} // Set up the CompletionSource field to return our mock completion root.SetRootComponent(&mockNexusCompletionGetterComponent{ @@ -370,7 +370,7 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { cases := []struct { name string setupHistoryClient func(*testing.T, *gomock.Controller) resource.HistoryClient - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions headerValue string assertOutcome func(*testing.T, *Callback, error) }{ @@ -396,8 +396,8 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { }) return client }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayloadBytes([]byte("result-data")), CloseTime: dummyTime, } @@ -424,8 +424,8 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { }) return client }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionUnsuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Error: &nexus.OperationError{ State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "operation failed"}}, @@ -449,8 +449,8 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { ).Return(nil, status.Error(codes.Unavailable, "service unavailable")) return client }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayloadBytes([]byte("result-data")), } }(), @@ -470,8 +470,8 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { ).Return(nil, status.Error(codes.InvalidArgument, "invalid request")) return client }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayloadBytes([]byte("result-data")), } }(), @@ -487,8 +487,8 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { // No RPC call expected return historyservicemock.NewMockHistoryServiceClient(ctrl) }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayloadBytes([]byte("result-data")), } }(), @@ -504,8 +504,8 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { // No RPC call expected return historyservicemock.NewMockHistoryServiceClient(ctrl) }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayloadBytes([]byte("result-data")), } }(), diff --git a/chasm/lib/callback/nexus_invocation.go b/chasm/lib/callback/nexus_invocation.go index ceee0f8a08..63dc9b4735 100644 --- a/chasm/lib/callback/nexus_invocation.go +++ b/chasm/lib/callback/nexus_invocation.go @@ -27,7 +27,7 @@ var retryable4xxErrorTypes = []int{ type nexusInvocation struct { nexus *callbackspb.Callback_Nexus - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions workflowID, runID string attempt int32 } diff --git a/chasm/lib/workflow/workflow.go b/chasm/lib/workflow/workflow.go index 6e0acdd063..bf609a40b2 100644 --- a/chasm/lib/workflow/workflow.go +++ b/chasm/lib/workflow/workflow.go @@ -117,7 +117,7 @@ func (w *Workflow) AddCompletionCallbacks( func (w *Workflow) GetNexusCompletion( ctx chasm.Context, requestID string, -) (nexusrpc.OperationCompletion, error) { +) (nexusrpc.CompleteOperationOptions, error) { // Retrieve the completion data from the underlying mutable state via MSPointer return w.MSPointer.GetNexusCompletion(ctx, requestID) } diff --git a/chasm/ms_pointer.go b/chasm/ms_pointer.go index 75013b4a62..ff35e4afe5 100644 --- a/chasm/ms_pointer.go +++ b/chasm/ms_pointer.go @@ -20,6 +20,6 @@ func NewMSPointer(backend NodeBackend) MSPointer { } // GetNexusCompletion retrieves the Nexus operation completion data for the given request ID from the underlying mutable state. -func (m MSPointer) GetNexusCompletion(ctx Context, requestID string) (nexusrpc.OperationCompletion, error) { +func (m MSPointer) GetNexusCompletion(ctx Context, requestID string) (nexusrpc.CompleteOperationOptions, error) { return m.backend.GetNexusCompletion(ctx.getContext(), requestID) } diff --git a/chasm/node_backend_mock.go b/chasm/node_backend_mock.go index eee5aca836..7fa090b15d 100644 --- a/chasm/node_backend_mock.go +++ b/chasm/node_backend_mock.go @@ -26,7 +26,7 @@ type MockNodeBackend struct { HandleGetWorkflowKey func() definition.WorkflowKey HandleUpdateWorkflowStateStatus func(state enumsspb.WorkflowExecutionState, status enumspb.WorkflowExecutionStatus) (bool, error) HandleIsWorkflow func() bool - HandleGetNexusCompletion func(ctx context.Context, requestID string) (nexusrpc.OperationCompletion, error) + HandleGetNexusCompletion func(ctx context.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) // Recorded calls (protected by mu). mu sync.Mutex @@ -164,11 +164,11 @@ func (m *MockNodeBackend) IsWorkflow() bool { func (m *MockNodeBackend) GetNexusCompletion( ctx context.Context, requestID string, -) (nexusrpc.OperationCompletion, error) { +) (nexusrpc.CompleteOperationOptions, error) { if m.HandleGetNexusCompletion != nil { return m.HandleGetNexusCompletion(ctx, requestID) } - return nil, nil + return nexusrpc.CompleteOperationOptions{}, nil } func (m *MockNodeBackend) NumTasksAdded() int { diff --git a/chasm/tree.go b/chasm/tree.go index 34cad3cfb6..94d8440dce 100644 --- a/chasm/tree.go +++ b/chasm/tree.go @@ -199,7 +199,7 @@ type ( GetNexusCompletion( ctx context.Context, requestID string, - ) (nexusrpc.OperationCompletion, error) + ) (nexusrpc.CompleteOperationOptions, error) } // NodePathEncoder is an interface for encoding and decoding node paths. diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index b28569c3cd..efb9094c47 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -52,7 +52,7 @@ func NewCompletionHTTPClient(options CompletionHTTPClientOptions) *CompletionHTT } // CompleteOperation sends a completion callback for a Nexus operation to the given URL with the given completion details. -func (c *CompletionHTTPClient) CompleteOperation(ctx context.Context, url string, completion OperationCompletion) error { +func (c *CompletionHTTPClient) CompleteOperation(ctx context.Context, url string, completion CompleteOperationOptions) error { httpReq, err := http.NewRequestWithContext(ctx, "POST", url, nil) if err != nil { return err @@ -86,140 +86,88 @@ func (c *CompletionHTTPClient) CompleteOperation(ctx context.Context, url string return c.bestEffortHandlerErrorFromResponse(response, body) } -// OperationCompletion is input for CompleteOperation. -// It has two implementations: [OperationCompletionSuccessful] and [OperationCompletionUnsuccessful]. -type OperationCompletion interface { - SetHeader(key, value string) - applyToHTTPRequest(*CompletionHTTPClient, *http.Request) error -} - -// OperationCompletionSuccessful is input for CompleteOperation, used to deliver successful operation results. -type OperationCompletionSuccessful struct { +// CompleteOperationOptions is input for CompleteOperation. +// If Error is set, the completion is unsuccessful. Otherwise, it is successful. +type CompleteOperationOptions struct { // Header to send in the completion request. // Note that this is a Nexus header, not an HTTP header. Header nexus.Header - // Result to deliver with the completion. Uses the client's serializer to serialize the result into the request body. - Result any // OperationToken is the unique token for this operation. Used when a completion callback is received before a // started response. OperationToken string // StartTime is the time the operation started. Used when a completion callback is received before a started response. StartTime time.Time - // CloseTime is the time the operation completed. Used when a completion callback is received before a started response. + // CloseTime is the time the operation completed. This may be different from the time the completion callback is delivered. CloseTime time.Time // Links are used to link back to the operation when a completion callback is received before a started response. Links []nexus.Link + // Error to send with the completion. If set, the completion is unsuccessful. + Error *nexus.OperationError + // Result to deliver with the completion. Uses the client's serializer to serialize the result into the request body. + // Only used for successful completions. + Result any } -func (c *OperationCompletionSuccessful) SetHeader(key, value string) { +func (c CompleteOperationOptions) SetHeader(key, value string) { if c.Header == nil { c.Header = make(nexus.Header, 1) } c.Header[key] = value } -func (c *OperationCompletionSuccessful) applyToHTTPRequest(cc *CompletionHTTPClient, request *http.Request) error { - reader, ok := c.Result.(*nexus.Reader) - if !ok { - content, ok := c.Result.(*nexus.Content) - if !ok { - var err error - content, err = cc.serializer.Serialize(c.Result) - if err != nil { - return err - } - } - request.ContentLength = int64(len(content.Data)) - - reader = &nexus.Reader{ - Header: content.Header, - ReadCloser: io.NopCloser(bytes.NewReader(content.Data)), - } - } +func (c CompleteOperationOptions) applyToHTTPRequest(cc *CompletionHTTPClient, request *http.Request) error { if request.Header == nil { - request.Header = make(http.Header, len(c.Header)+len(reader.Header)+1) // +1 for headerOperationState - } - if reader.Header != nil { - addContentHeaderToHTTPHeader(reader.Header, request.Header) + request.Header = make(http.Header) } - if c.Header != nil { - addNexusHeaderToHTTPHeader(c.Header, request.Header) - } - if c.Header.Get(headerUserAgent) == "" { - request.Header.Set(headerUserAgent, userAgent) - } - - request.Header.Set(headerOperationState, string(nexus.OperationStateSucceeded)) - if c.Header.Get(nexus.HeaderOperationToken) == "" && c.OperationToken != "" { - request.Header.Set(nexus.HeaderOperationToken, c.OperationToken) - } - if c.Header.Get(headerOperationStartTime) == "" && !c.StartTime.IsZero() { - request.Header.Set(headerOperationStartTime, c.StartTime.Format(http.TimeFormat)) - } - if c.Header.Get(headerOperationCloseTime) == "" && !c.CloseTime.IsZero() { - request.Header.Set(headerOperationCloseTime, marshalTimestamp(c.CloseTime)) - } - if c.Header.Get(headerLink) == "" { - if err := addLinksToHTTPHeader(c.Links, request.Header); err != nil { + // Set the body and operation state based on whether the completion is successful or not. + if c.Error != nil { + failure, err := cc.failureConverter.ErrorToFailure(c.Error) + if err != nil { return err } + // Backwards compatibility: if the failure has a cause, unwrap it to maintain the behavior as older servers. + if failure.Cause != nil { + failure = *failure.Cause + } + b, err := json.Marshal(failure) + if err != nil { + return err + } + request.Body = io.NopCloser(bytes.NewReader(b)) + // Set the operation state header for backwards compatibility. + request.Header.Set(headerOperationState, string(c.Error.State)) + request.Header.Set("Content-Type", contentTypeJSON) + } else { + reader, ok := c.Result.(*nexus.Reader) + if !ok { + content, ok := c.Result.(*nexus.Content) + if !ok { + var err error + content, err = cc.serializer.Serialize(c.Result) + if err != nil { + return err + } + } + request.ContentLength = int64(len(content.Data)) + reader = &nexus.Reader{ + Header: content.Header, + ReadCloser: io.NopCloser(bytes.NewReader(content.Data)), + } + } + if reader.Header != nil { + addContentHeaderToHTTPHeader(reader.Header, request.Header) + } + request.Body = reader.ReadCloser + request.Header.Set(headerOperationState, string(nexus.OperationStateSucceeded)) } - request.Body = reader.ReadCloser - return nil -} - -// OperationCompletionUnsuccessful is input for [NewCompletionHTTPRequest], used to deliver unsuccessful operation -// results. -type OperationCompletionUnsuccessful struct { - // Header to send in the completion request. - // Note that this is a Nexus header, not an HTTP header. - Header nexus.Header - // OperationToken is the unique token for this operation. Used when a completion callback is received before a - // started response. - OperationToken string - // StartTime is the time the operation started. Used when a completion callback is received before a started response. - StartTime time.Time - // CloseTime is the time the operation completed. This may be different from the time the completion callback is delivered. - CloseTime time.Time - // Links are used to link back to the operation when a completion callback is received before a started response. - Links []nexus.Link - // Error to send with the completion. - Error *nexus.OperationError -} - -func (c *OperationCompletionUnsuccessful) SetHeader(key, value string) { - if c.Header == nil { - c.Header = make(nexus.Header, 1) - } - c.Header[key] = value -} - -func (c *OperationCompletionUnsuccessful) applyToHTTPRequest(cc *CompletionHTTPClient, request *http.Request) error { - failure, err := cc.failureConverter.ErrorToFailure(c.Error) - if err != nil { - return err - } - // Backwards compatibility: if the failure has a cause, unwrap it to maintain the behavior as older servers. - if failure.Cause != nil { - failure = *failure.Cause - } - - if request.Header == nil { - request.Header = make(http.Header, len(c.Header)+2) // +2 for headerOperationState and content-type - } if c.Header != nil { addNexusHeaderToHTTPHeader(c.Header, request.Header) } if c.Header.Get(headerUserAgent) == "" { request.Header.Set(headerUserAgent, userAgent) } - - // Set the operation state header for backwards compatibility. - request.Header.Set(headerOperationState, string(c.Error.State)) - request.Header.Set("Content-Type", contentTypeJSON) - if c.Header.Get(nexus.HeaderOperationToken) == "" && c.OperationToken != "" { request.Header.Set(nexus.HeaderOperationToken, c.OperationToken) } @@ -234,13 +182,6 @@ func (c *OperationCompletionUnsuccessful) applyToHTTPRequest(cc *CompletionHTTPC return err } } - - b, err := json.Marshal(failure) - if err != nil { - return err - } - - request.Body = io.NopCloser(bytes.NewReader(b)) return nil } diff --git a/common/nexus/nexusrpc/completion_test.go b/common/nexus/nexusrpc/completion_test.go index dc075b4f47..6a55c3f331 100644 --- a/common/nexus/nexusrpc/completion_test.go +++ b/common/nexus/nexusrpc/completion_test.go @@ -75,7 +75,7 @@ func TestSuccessfulCompletion(t *testing.T) { }, nil, nil) defer teardown() - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: 666, OperationToken: "test-operation-token", StartTime: startTime, @@ -101,7 +101,7 @@ func TestSuccessfulCompletion_CustomSerializer(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &successfulCompletionHandler{}, serializer, nil) defer teardown() - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: 666, Links: []nexus.Link{{ URL: &url.URL{ @@ -175,7 +175,7 @@ func TestFailureCompletion(t *testing.T) { }, nil, nil) defer teardown() - completion := &nexusrpc.OperationCompletionUnsuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Error: nexus.NewOperationCanceledErrorf("expected message"), OperationToken: "test-operation-token", StartTime: startTime, @@ -212,7 +212,7 @@ func TestFailureCompletion_CustomFailureConverter(t *testing.T) { }, nil, fc) defer teardown() - completion := &nexusrpc.OperationCompletionUnsuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Error: nexus.NewOperationCanceledErrorf("expected message"), OperationToken: "test-operation-token", StartTime: startTime, @@ -245,7 +245,7 @@ func TestBadRequestCompletion(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &failingCompletionHandler{}, nil, nil) defer teardown() - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: []byte("success"), } err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}).CompleteOperation(ctx, callbackURL, completion) diff --git a/components/callbacks/chasm_invocation.go b/components/callbacks/chasm_invocation.go index a88c258ab6..bb0419287c 100644 --- a/components/callbacks/chasm_invocation.go +++ b/components/callbacks/chasm_invocation.go @@ -24,7 +24,7 @@ import ( type chasmInvocation struct { nexus *persistencespb.Callback_Nexus attempt int32 - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions requestID string } @@ -86,14 +86,13 @@ func (c chasmInvocation) getHistoryRequest( RequestId: c.requestID, } - switch op := c.completion.(type) { - case *nexusrpc.OperationCompletionSuccessful: + if c.completion.Error == nil { var payload *commonpb.Payload - if op.Result != nil { + if c.completion.Result != nil { var ok bool - payload, ok = op.Result.(*commonpb.Payload) + payload, ok = c.completion.Result.(*commonpb.Payload) if !ok { - return nil, fmt.Errorf("invalid result, expected a payload, got: %T", op.Result) + return nil, fmt.Errorf("invalid result, expected a payload, got: %T", c.completion.Result) } } @@ -101,11 +100,11 @@ func (c chasmInvocation) getHistoryRequest( Outcome: &historyservice.CompleteNexusOperationChasmRequest_Success{ Success: payload, }, - CloseTime: timestamppb.New(op.CloseTime), + CloseTime: timestamppb.New(c.completion.CloseTime), Completion: completion, } - case *nexusrpc.OperationCompletionUnsuccessful: - failure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(op.Error) + } else { + failure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(c.completion.Error) if err != nil { return nil, fmt.Errorf("failed to convert error to failure: %w", err) } @@ -123,10 +122,8 @@ func (c chasmInvocation) getHistoryRequest( Outcome: &historyservice.CompleteNexusOperationChasmRequest_Failure{ Failure: apiFailure, }, - CloseTime: timestamppb.New(op.CloseTime), + CloseTime: timestamppb.New(c.completion.CloseTime), } - default: - return nil, fmt.Errorf("unexpected nexus.OperationCompletion: %v", completion) } return req, nil diff --git a/components/callbacks/executors_test.go b/components/callbacks/executors_test.go index 18d562dd54..1a0015b58c 100644 --- a/components/callbacks/executors_test.go +++ b/components/callbacks/executors_test.go @@ -54,11 +54,11 @@ func (fakeEnv) Now() time.Time { var _ hsm.Environment = fakeEnv{} type mutableState struct { - completionNexus nexusrpc.OperationCompletion + completionNexus nexusrpc.CompleteOperationOptions completionHsm *persistencespb.HSMCompletionCallbackArg } -func (ms mutableState) GetNexusCompletion(ctx context.Context, requestID string) (nexusrpc.OperationCompletion, error) { +func (ms mutableState) GetNexusCompletion(ctx context.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) { return ms.completionNexus, nil } @@ -267,7 +267,7 @@ func TestProcessBackoffTask(t *testing.T) { } func newMutableState(t *testing.T) mutableState { - completionNexus := &nexusrpc.OperationCompletionSuccessful{} + completionNexus := nexusrpc.CompleteOperationOptions{} hsmCallbackArg := &persistencespb.HSMCompletionCallbackArg{ NamespaceId: "mynsid", WorkflowId: "mywid", @@ -314,7 +314,7 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { cases := []struct { name string setupHistoryClient func(*testing.T, *gomock.Controller) *historyservicemock.MockHistoryServiceClient - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions headerValue string expectsInternalError bool assertOutcome func(*testing.T, callbacks.Callback) @@ -347,8 +347,8 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { }) return client }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayload([]byte("result-data")), CloseTime: dummyTime, } @@ -374,8 +374,8 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { }) return client }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionUnsuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Error: &nexus.OperationError{ State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "operation failed"}}, @@ -398,8 +398,8 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { ).Return(nil, status.Error(codes.Unavailable, "service unavailable")) return client }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayload([]byte("result-data")), } }(), @@ -419,8 +419,8 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { ).Return(nil, status.Error(codes.InvalidArgument, "invalid request")) return client }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayload([]byte("result-data")), } }(), @@ -436,8 +436,8 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { // No RPC call expected return historyservicemock.NewMockHistoryServiceClient(ctrl) }, - completion: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ Result: createPayload([]byte("result-data")), } }(), diff --git a/components/callbacks/nexus_invocation.go b/components/callbacks/nexus_invocation.go index 68866dfc35..43a2ed231f 100644 --- a/components/callbacks/nexus_invocation.go +++ b/components/callbacks/nexus_invocation.go @@ -25,12 +25,12 @@ var retryable4xxErrorTypes = []int{ } type CanGetNexusCompletion interface { - GetNexusCompletion(ctx context.Context, requestID string) (nexusrpc.OperationCompletion, error) + GetNexusCompletion(ctx context.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) } type nexusInvocation struct { nexus *persistencespb.Callback_Nexus - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions workflowID, runID string attempt int32 } diff --git a/components/nexusoperations/frontend/handler.go b/components/nexusoperations/frontend/handler.go index b4f6f495fa..1c95fe439e 100644 --- a/components/nexusoperations/frontend/handler.go +++ b/components/nexusoperations/frontend/handler.go @@ -261,11 +261,11 @@ func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nex } } - var completion nexusrpc.OperationCompletion + var completion nexusrpc.CompleteOperationOptions switch r.State { case nexus.OperationStateSucceeded: - completion = &nexusrpc.OperationCompletionSuccessful{ + completion = nexusrpc.CompleteOperationOptions{ Result: r.Result.Reader, OperationToken: r.OperationToken, StartTime: r.StartTime, @@ -275,7 +275,7 @@ func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nex case nexus.OperationStateFailed, nexus.OperationStateCanceled: // For unsuccessful operations, the Nexus framework reads and closes the original request body to deserialize // the failure, so we must construct a new completion to forward. - completion = &nexusrpc.OperationCompletionUnsuccessful{ + completion = nexusrpc.CompleteOperationOptions{ Error: r.Error, OperationToken: r.OperationToken, StartTime: r.StartTime, diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index b43bdb16cb..81df2b89eb 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -722,10 +722,10 @@ func (ms *MutableStateImpl) ChasmWorkflowComponentReadOnly(ctx context.Context) func (ms *MutableStateImpl) GetNexusCompletion( ctx context.Context, requestID string, -) (nexusrpc.OperationCompletion, error) { +) (nexusrpc.CompleteOperationOptions, error) { ce, err := ms.GetCompletionEvent(ctx) if err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } // Create the link information about the workflow to be attached to fabricated started event if completion is @@ -767,7 +767,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( // Nexus does not support it. p = payloads[0] } - return &nexusrpc.OperationCompletionSuccessful{ + return nexusrpc.CompleteOperationOptions{ Result: p, StartTime: ms.executionState.GetStartTime().AsTime(), CloseTime: ce.GetEventTime().AsTime(), @@ -776,7 +776,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_FAILED: f, err := commonnexus.TemporalFailureToNexusFailure(ce.GetWorkflowExecutionFailedEventAttributes().GetFailure()) if err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } opErr := &nexus.OperationError{ Message: "operation failed", @@ -784,9 +784,9 @@ func (ms *MutableStateImpl) GetNexusCompletion( Cause: &nexus.FailureError{Failure: f}, } if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } - return &nexusrpc.OperationCompletionUnsuccessful{ + return nexusrpc.CompleteOperationOptions{ Error: opErr, StartTime: ms.executionState.GetStartTime().AsTime(), CloseTime: ce.GetEventTime().AsTime(), @@ -802,7 +802,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( }, }) if err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } opErr := &nexus.OperationError{ State: nexus.OperationStateCanceled, @@ -810,9 +810,9 @@ func (ms *MutableStateImpl) GetNexusCompletion( Cause: &nexus.FailureError{Failure: f}, } if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } - return &nexusrpc.OperationCompletionUnsuccessful{ + return nexusrpc.CompleteOperationOptions{ Error: opErr, StartTime: ms.executionState.GetStartTime().AsTime(), CloseTime: ce.GetEventTime().AsTime(), @@ -826,7 +826,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( }, }) if err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } opErr := &nexus.OperationError{ State: nexus.OperationStateFailed, @@ -834,9 +834,9 @@ func (ms *MutableStateImpl) GetNexusCompletion( Cause: &nexus.FailureError{Failure: f}, } if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } - return &nexusrpc.OperationCompletionUnsuccessful{ + return nexusrpc.CompleteOperationOptions{ Error: opErr, StartTime: ms.executionState.GetStartTime().AsTime(), CloseTime: ce.GetEventTime().AsTime(), @@ -853,7 +853,7 @@ func (ms *MutableStateImpl) GetNexusCompletion( }, }) if err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } opErr := &nexus.OperationError{ State: nexus.OperationStateFailed, @@ -861,16 +861,16 @@ func (ms *MutableStateImpl) GetNexusCompletion( Cause: &nexus.FailureError{Failure: f}, } if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } - return &nexusrpc.OperationCompletionUnsuccessful{ + return nexusrpc.CompleteOperationOptions{ Error: opErr, StartTime: ms.executionState.GetStartTime().AsTime(), CloseTime: ce.GetEventTime().AsTime(), Links: []nexus.Link{startLink}, }, nil } - return nil, serviceerror.NewInternalf("invalid workflow execution status: %v", ce.GetEventType()) + return nexusrpc.CompleteOperationOptions{}, serviceerror.NewInternalf("invalid workflow execution status: %v", ce.GetEventType()) } // GetHSMCallbackArg converts a workflow completion event into a [persistencespb.HSMCallbackArg]. diff --git a/service/history/workflow/workflow_test/mutable_state_impl_test.go b/service/history/workflow/workflow_test/mutable_state_impl_test.go index 46384b0721..16d058d295 100644 --- a/service/history/workflow/workflow_test/mutable_state_impl_test.go +++ b/service/history/workflow/workflow_test/mutable_state_impl_test.go @@ -362,7 +362,7 @@ func TestGetNexusCompletion(t *testing.T) { cases := []struct { name string mutateState func(historyi.MutableState) (*historypb.HistoryEvent, error) - verifyCompletion func(*testing.T, *historypb.HistoryEvent, nexusrpc.OperationCompletion) + verifyCompletion func(*testing.T, *historypb.HistoryEvent, nexusrpc.CompleteOperationOptions) }{ { name: "success", @@ -378,14 +378,13 @@ func TestGetNexusCompletion(t *testing.T) { }, }, "") }, - verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { - success, ok := completion.(*nexusrpc.OperationCompletionSuccessful) - require.True(t, ok) + verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { + require.Nil(t, completion.Error) require.Equal(t, commonpb.Payload{ Metadata: map[string][]byte{"encoding": []byte("json/plain")}, Data: []byte("3"), - }, success.Result) - require.Equal(t, event.GetEventTime().AsTime(), success.CloseTime) + }, completion.Result) + require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, { @@ -397,12 +396,11 @@ func TestGetNexusCompletion(t *testing.T) { }, }, "") }, - verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { - failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) - require.True(t, ok) - require.Equal(t, nexus.OperationStateFailed, failure.Error.State) - require.Equal(t, "workflow failed", failure.Error.Message) - require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) + verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { + require.NotNil(t, completion.Error) + require.Equal(t, nexus.OperationStateFailed, completion.Error.State) + require.Equal(t, "workflow failed", completion.Error.Message) + require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, { @@ -410,12 +408,11 @@ func TestGetNexusCompletion(t *testing.T) { mutateState: func(mutableState historyi.MutableState) (*historypb.HistoryEvent, error) { return mutableState.AddWorkflowExecutionTerminatedEvent(mutableState.GetNextEventID(), "dont care", nil, "identity", false, nil) }, - verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { - failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) - require.True(t, ok) - require.Equal(t, nexus.OperationStateFailed, failure.Error.State) - require.Equal(t, "operation terminated", failure.Error.Message) - require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) + verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { + require.NotNil(t, completion.Error) + require.Equal(t, nexus.OperationStateFailed, completion.Error.State) + require.Equal(t, "operation terminated", completion.Error.Message) + require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, { @@ -423,12 +420,11 @@ func TestGetNexusCompletion(t *testing.T) { mutateState: func(mutableState historyi.MutableState) (*historypb.HistoryEvent, error) { return mutableState.AddWorkflowExecutionCanceledEvent(mutableState.GetNextEventID(), &commandpb.CancelWorkflowExecutionCommandAttributes{}) }, - verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { - failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) - require.True(t, ok) - require.Equal(t, nexus.OperationStateCanceled, failure.Error.State) - require.Equal(t, "operation canceled", failure.Error.Message) - require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) + verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { + require.NotNil(t, completion.Error) + require.Equal(t, nexus.OperationStateCanceled, completion.Error.State) + require.Equal(t, "operation canceled", completion.Error.Message) + require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, } diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index 5fc5f5b97e..bf4dcdf61f 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -653,7 +653,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { protorequire.ProtoEqual(s.T(), handlerLink, l) // Completion request fails if the result payload is too large. - largeCompletion := &nexusrpc.OperationCompletionSuccessful{ + largeCompletion := nexusrpc.CompleteOperationOptions{ // Use -10 to avoid hitting MaxNexusAPIRequestBodyBytes. Actual payload will still exceed limit because of // additional Content headers. See common/rpc/grpc.go:66 Result: s.mustToPayload(strings.Repeat("a", (2*1024*1024)-10)), @@ -677,7 +677,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { // Send an invalid completion request and verify that we get an error that the namespace in the URL doesn't match the namespace in the token. invalidCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(invalidNamespace) - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } @@ -1167,7 +1167,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { s.Greater(startedEventIdx, 0) // Send a valid - failed completion request. - completion := &nexusrpc.OperationCompletionUnsuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Error: nexus.NewOperationFailedErrorf("test operation failed"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } @@ -1219,7 +1219,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { ctx := testcore.NewContext() - commonCompletion := &nexusrpc.OperationCompletionSuccessful{ + commonCompletion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), } @@ -1249,7 +1249,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.NoError(err) publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path("namespace-doesnt-exist") - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: tokenWithBadNamespace}, } @@ -1266,7 +1266,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.NoError(err) publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: tokenWithBadNamespace}, } @@ -1284,7 +1284,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { namespaceID := s.GetNamespaceID(s.Namespace().String()) validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), OperationToken: strings.Repeat("long", 2000), Header: nexus.Header{commonnexus.CallbackTokenHeader: validToken}, @@ -1306,7 +1306,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), OperationToken: strings.Repeat("long", 2000), Header: nexus.Header{commonnexus.CallbackTokenHeader: validToken}, @@ -1355,7 +1355,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), Header: nexus.Header{ commonnexus.CallbackTokenHeader: validToken, @@ -1384,7 +1384,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), Header: nexus.Header{ commonnexus.CallbackTokenHeader: validToken, @@ -1422,7 +1422,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrors() { callbackToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } @@ -1453,7 +1453,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrorsNoId callbackToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } @@ -1901,7 +1901,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { } } s.True(seenStartedEvent) - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload("result"), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } @@ -3088,7 +3088,7 @@ func (s *NexusWorkflowTestSuite) generateValidCallbackToken(namespaceID, workflo func (s *NexusWorkflowTestSuite) sendNexusCompletionRequest( ctx context.Context, url string, - completion nexusrpc.OperationCompletion, + completion nexusrpc.CompleteOperationOptions, ) (map[string][]*metricstest.CapturedRecording, error) { capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go index 0454daf3dd..06eb0fcc0a 100644 --- a/tests/xdc/nexus_request_forwarding_test.go +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -381,14 +381,14 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandbyToActive() { testCases := []struct { name string - getCompletionFn func() nexusrpc.OperationCompletion + getCompletionFn func() nexusrpc.CompleteOperationOptions assertHistoryAndGetCompleteWF func(*testing.T, []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest assertResult func(*testing.T, string) }{ { name: "success", - getCompletionFn: func() nexusrpc.OperationCompletion { - return &nexusrpc.OperationCompletionSuccessful{Result: s.mustToPayload("result")} + getCompletionFn: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{Result: s.mustToPayload("result")} }, assertHistoryAndGetCompleteWF: func(t *testing.T, events []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest { completedEventIdx := slices.IndexFunc(events, func(e *historypb.HistoryEvent) bool { @@ -412,12 +412,12 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }, { name: "operation error", - getCompletionFn: func() nexusrpc.OperationCompletion { + getCompletionFn: func() nexusrpc.CompleteOperationOptions { opErr := &nexus.OperationError{ State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "intentional operation failure"}}, } - return &nexusrpc.OperationCompletionUnsuccessful{ + return nexusrpc.CompleteOperationOptions{ Error: opErr, } }, @@ -443,9 +443,9 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }, { name: "canceled", - getCompletionFn: func() nexusrpc.OperationCompletion { + getCompletionFn: func() nexusrpc.CompleteOperationOptions { f := nexus.Failure{Message: "operation canceled"} - return &nexusrpc.OperationCompletionUnsuccessful{ + return nexusrpc.CompleteOperationOptions{ Error: &nexus.OperationError{State: nexus.OperationStateCanceled, Cause: &nexus.FailureError{Failure: f}}, } }, @@ -684,7 +684,7 @@ func (s *NexusRequestForwardingSuite) sendNexusCompletionRequest( t *testing.T, testCluster *testcore.TestCluster, url string, - completion nexusrpc.OperationCompletion, + completion nexusrpc.CompleteOperationOptions, ) (map[string][]*metricstest.CapturedRecording, error) { metricsHandler, ok := testCluster.Host().GetMetricsHandler().(*metricstest.CaptureHandler) s.True(ok) diff --git a/tests/xdc/nexus_state_replication_test.go b/tests/xdc/nexus_state_replication_test.go index 13805af62b..38034b95df 100644 --- a/tests/xdc/nexus_state_replication_test.go +++ b/tests/xdc/nexus_state_replication_test.go @@ -714,7 +714,7 @@ func (s *NexusStateReplicationSuite) waitCallback( } func (s *NexusStateReplicationSuite) completeNexusOperation(ctx context.Context, result any, callbackUrl, callbackToken string) { - completion := &nexusrpc.OperationCompletionSuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Result: s.mustToPayload(result), Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, } @@ -726,11 +726,11 @@ func (s *NexusStateReplicationSuite) completeNexusOperation(ctx context.Context, } func (s *NexusStateReplicationSuite) cancelNexusOperation(ctx context.Context, callbackUrl, callbackToken string) { - completion := &nexusrpc.OperationCompletionUnsuccessful{ + completion := nexusrpc.CompleteOperationOptions{ Error: nexus.NewOperationCanceledErrorf("operation canceled"), } if callbackToken != "" { - completion.SetHeader(commonnexus.CallbackTokenHeader, callbackToken) + completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} } err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ Serializer: commonnexus.PayloadSerializer, From 1a0ce0aeac5d2e1ac81fac35a619f4fb62a5d3cf Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Tue, 10 Feb 2026 22:21:44 -0800 Subject: [PATCH 24/26] Separate client instantiation and function invocation --- common/nexus/nexusrpc/completion_test.go | 19 ++++++++++++------- tests/xdc/nexus_request_forwarding_test.go | 5 +++-- tests/xdc/nexus_state_replication_test.go | 5 +++-- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/common/nexus/nexusrpc/completion_test.go b/common/nexus/nexusrpc/completion_test.go index 6a55c3f331..92fa7e92e4 100644 --- a/common/nexus/nexusrpc/completion_test.go +++ b/common/nexus/nexusrpc/completion_test.go @@ -92,7 +92,8 @@ func TestSuccessfulCompletion(t *testing.T) { Header: nexus.Header{"foo": "bar"}, } - err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}).CompleteOperation(ctx, callbackURL, completion) + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}) + err := c.CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) } @@ -118,9 +119,10 @@ func TestSuccessfulCompletion_CustomSerializer(t *testing.T) { completion.Header.Set("foo", "bar") completion.Header.Set(nexus.HeaderOperationToken, "test-operation-token") - err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ Serializer: serializer, - }).CompleteOperation(ctx, callbackURL, completion) + }) + err := c.CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) require.Equal(t, 1, serializer.decoded) @@ -191,7 +193,8 @@ func TestFailureCompletion(t *testing.T) { }}, Header: nexus.Header{"foo": "bar"}, } - err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}).CompleteOperation(ctx, callbackURL, completion) + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}) + err := c.CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) } @@ -228,9 +231,10 @@ func TestFailureCompletion_CustomFailureConverter(t *testing.T) { }}, Header: nexus.Header{"foo": "bar"}, } - err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ FailureConverter: fc, - }).CompleteOperation(ctx, callbackURL, completion) + }) + err := c.CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) } @@ -248,7 +252,8 @@ func TestBadRequestCompletion(t *testing.T) { completion := nexusrpc.CompleteOperationOptions{ Result: []byte("success"), } - err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}).CompleteOperation(ctx, callbackURL, completion) + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}) + err := c.CompleteOperation(ctx, callbackURL, completion) var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeBadRequest, handlerErr.Type) diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go index 06eb0fcc0a..2d9989f065 100644 --- a/tests/xdc/nexus_request_forwarding_test.go +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -691,9 +691,10 @@ func (s *NexusRequestForwardingSuite) sendNexusCompletionRequest( capture := metricsHandler.StartCapture() defer metricsHandler.StopCapture(capture) - err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ Serializer: cnexus.PayloadSerializer, - }).CompleteOperation(ctx, url, completion) + }) + err := c.CompleteOperation(ctx, url, completion) return capture.Snapshot(), err } diff --git a/tests/xdc/nexus_state_replication_test.go b/tests/xdc/nexus_state_replication_test.go index 38034b95df..aa52046bed 100644 --- a/tests/xdc/nexus_state_replication_test.go +++ b/tests/xdc/nexus_state_replication_test.go @@ -732,8 +732,9 @@ func (s *NexusStateReplicationSuite) cancelNexusOperation(ctx context.Context, c if callbackToken != "" { completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} } - err := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ Serializer: commonnexus.PayloadSerializer, - }).CompleteOperation(ctx, callbackUrl, completion) + }) + err := c.CompleteOperation(ctx, callbackUrl, completion) s.NoError(err) } From 1709d698ce69c0b8dbe14bee826e3c624cf4e0a2 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Tue, 10 Feb 2026 22:53:28 -0800 Subject: [PATCH 25/26] Self review --- chasm/lib/callback/nexus_invocation.go | 20 +++++++--------- common/nexus/nexusrpc/completion.go | 9 ++----- components/callbacks/nexus_invocation.go | 20 +++++++--------- .../workflow_test/mutable_state_impl_test.go | 8 +++---- tests/nexus_workflow_test.go | 24 ++++++++++++------- tests/xdc/nexus_request_forwarding_test.go | 3 ++- 6 files changed, 42 insertions(+), 42 deletions(-) diff --git a/chasm/lib/callback/nexus_invocation.go b/chasm/lib/callback/nexus_invocation.go index 63dc9b4735..ec29fabbac 100644 --- a/chasm/lib/callback/nexus_invocation.go +++ b/chasm/lib/callback/nexus_invocation.go @@ -71,9 +71,7 @@ func (n nexusInvocation) Invoke( // Make the call and record metrics. startTime := time.Now() - for k, v := range n.nexus.Header { - n.completion.SetHeader(k, v) - } + n.completion.Header = n.nexus.Header err := client.CompleteOperation(ctx, n.nexus.Url, n.completion) namespaceTag := metrics.NamespaceTag(ns.Name().String()) @@ -82,15 +80,15 @@ func (n nexusInvocation) Invoke( e.metricsHandler.Counter(RequestCounter.Name()).Record(1, namespaceTag, destTag, outcomeTag) e.metricsHandler.Timer(RequestLatencyHistogram.Name()).Record(time.Since(startTime), namespaceTag, destTag, outcomeTag) - if err == nil { - return invocationResultOK{} - } - retryable := isRetryableCallError(err) - e.logger.Error("Callback request failed", tag.Error(err), tag.Bool("retryable", retryable)) - if retryable { - return invocationResultRetry{err} + if err != nil { + retryable := isRetryableCallError(err) + e.logger.Error("Callback request failed", tag.Error(err), tag.Bool("retryable", retryable)) + if retryable { + return invocationResultRetry{err} + } + return invocationResultFail{err} } - return invocationResultFail{err} + return invocationResultOK{} } func isRetryableCallError(err error) bool { diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index efb9094c47..581be0032e 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -108,13 +108,8 @@ type CompleteOperationOptions struct { Result any } -func (c CompleteOperationOptions) SetHeader(key, value string) { - if c.Header == nil { - c.Header = make(nexus.Header, 1) - } - c.Header[key] = value -} - +// nolint:revive // This method is long but it's more readable to keep the logic in one place since it's all related to +// constructing the completion request. func (c CompleteOperationOptions) applyToHTTPRequest(cc *CompletionHTTPClient, request *http.Request) error { if request.Header == nil { request.Header = make(http.Header) diff --git a/components/callbacks/nexus_invocation.go b/components/callbacks/nexus_invocation.go index 43a2ed231f..e24f070f45 100644 --- a/components/callbacks/nexus_invocation.go +++ b/components/callbacks/nexus_invocation.go @@ -68,9 +68,7 @@ func (n nexusInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e // Make the call and record metrics. startTime := time.Now() - for k, v := range n.nexus.Header { - n.completion.SetHeader(k, v) - } + n.completion.Header = n.nexus.Header err := client.CompleteOperation(ctx, n.nexus.Url, n.completion) namespaceTag := metrics.NamespaceTag(ns.Name().String()) @@ -79,15 +77,15 @@ func (n nexusInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e e.MetricsHandler.Counter(RequestCounter.Name()).Record(1, namespaceTag, destTag, statusCodeTag) e.MetricsHandler.Timer(RequestLatencyHistogram.Name()).Record(time.Since(startTime), namespaceTag, destTag, statusCodeTag) - if err == nil { - return invocationResultOK{} - } - retryable := isRetryableCallError(err) - e.Logger.Error("Callback request failed", tag.Error(err), tag.Bool("retryable", retryable)) - if retryable { - return invocationResultRetry{err} + if err != nil { + retryable := isRetryableCallError(err) + e.Logger.Error("Callback request failed", tag.Error(err), tag.Bool("retryable", retryable)) + if retryable { + return invocationResultRetry{err} + } + return invocationResultFail{err} } - return invocationResultFail{err} + return invocationResultOK{} } func outcomeTag(callCtx context.Context, callErr error) string { diff --git a/service/history/workflow/workflow_test/mutable_state_impl_test.go b/service/history/workflow/workflow_test/mutable_state_impl_test.go index 16d058d295..fe97c59f4b 100644 --- a/service/history/workflow/workflow_test/mutable_state_impl_test.go +++ b/service/history/workflow/workflow_test/mutable_state_impl_test.go @@ -380,7 +380,7 @@ func TestGetNexusCompletion(t *testing.T) { }, verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { require.Nil(t, completion.Error) - require.Equal(t, commonpb.Payload{ + require.Equal(t, &commonpb.Payload{ Metadata: map[string][]byte{"encoding": []byte("json/plain")}, Data: []byte("3"), }, completion.Result) @@ -399,7 +399,7 @@ func TestGetNexusCompletion(t *testing.T) { verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { require.NotNil(t, completion.Error) require.Equal(t, nexus.OperationStateFailed, completion.Error.State) - require.Equal(t, "workflow failed", completion.Error.Message) + require.Equal(t, "workflow failed", completion.Error.Cause.Error()) require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, @@ -411,7 +411,7 @@ func TestGetNexusCompletion(t *testing.T) { verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { require.NotNil(t, completion.Error) require.Equal(t, nexus.OperationStateFailed, completion.Error.State) - require.Equal(t, "operation terminated", completion.Error.Message) + require.Equal(t, "operation terminated", completion.Error.Cause.Error()) require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, @@ -423,7 +423,7 @@ func TestGetNexusCompletion(t *testing.T) { verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { require.NotNil(t, completion.Error) require.Equal(t, nexus.OperationStateCanceled, completion.Error.State) - require.Equal(t, "operation canceled", completion.Error.Message) + require.Equal(t, "operation canceled", completion.Error.Cause.Error()) require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index bf4dcdf61f..b1958bd90a 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -1219,14 +1219,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { ctx := testcore.NewContext() - commonCompletion := nexusrpc.CompleteOperationOptions{ - Result: s.mustToPayload("result"), - } - s.Run("ConfigDisabled", func() { s.OverrideDynamicConfig(dynamicconfig.EnableNexus, false) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + } publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, commonCompletion) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) @@ -1235,8 +1234,11 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.Run("ConfigDisabledNoIdentifier", func() { s.OverrideDynamicConfig(dynamicconfig.EnableNexus, false) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + } publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, commonCompletion) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) @@ -1322,10 +1324,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { }) s.Run("InvalidCallbackToken", func() { + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + } publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) // metrics collection is not initialized before callback validation // Send request without callback token, helper does not add token if blank - _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, commonCompletion) + _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) // Verify we get the correct error response var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) @@ -1334,10 +1339,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { }) s.Run("InvalidCallbackTokenNoIdentifier", func() { + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + } publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier // metrics collection is not initialized before callback validation // Send request without callback token, helper does not add token if blank - _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, commonCompletion) + _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) // Verify we get the correct error response var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go index 2d9989f065..610c4c62c6 100644 --- a/tests/xdc/nexus_request_forwarding_test.go +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -597,7 +597,8 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }, 5*time.Second, 500*time.Millisecond) completion := tc.getCompletionFn() - completion.SetHeader(cnexus.CallbackTokenHeader, callbackToken) + completion.Header = make(nexus.Header, 1) + completion.Header.Set(cnexus.CallbackTokenHeader, callbackToken) snap, err := s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[1], publicCallbackUrl, completion) s.NoError(err) s.Equal(1, len(snap["nexus_completion_requests"])) From f7a6fdaea471c7d81d68618ae4b844d765a82180 Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Tue, 10 Feb 2026 23:09:02 -0800 Subject: [PATCH 26/26] Bump Nexus SDK --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 9eb3d6b2fe..3c3afcb14e 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/lib/pq v1.10.9 github.com/maruel/panicparse/v2 v2.4.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/nexus-rpc/sdk-go v0.5.2-0.20260210000428-3d8ad6dc9742 + github.com/nexus-rpc/sdk-go v0.5.2-0.20260211051645-26b0b4c584e5 github.com/olekukonko/tablewriter v0.0.5 github.com/olivere/elastic/v7 v7.0.32 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index f5ced64579..5845e34a55 100644 --- a/go.sum +++ b/go.sum @@ -236,8 +236,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/nexus-rpc/sdk-go v0.5.2-0.20260210000428-3d8ad6dc9742 h1:nE/NqmspHqHkLh8rsSsQ/dLXOrbx4N7dU7GoGsO0DdI= -github.com/nexus-rpc/sdk-go v0.5.2-0.20260210000428-3d8ad6dc9742/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= +github.com/nexus-rpc/sdk-go v0.5.2-0.20260211051645-26b0b4c584e5 h1:Van9KGGs8lcDgxzSNFbDhEMNeJ80TbBxwZ45f9iBk9U= +github.com/nexus-rpc/sdk-go v0.5.2-0.20260211051645-26b0b4c584e5/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=