diff --git a/api/matchingservice/v1/request_response.pb.go b/api/matchingservice/v1/request_response.pb.go index 8093bd2e40..4a8dfb4145 100644 --- a/api/matchingservice/v1/request_response.pb.go +++ b/api/matchingservice/v1/request_response.pb.go @@ -14,16 +14,17 @@ import ( v11 "go.temporal.io/api/common/v1" v112 "go.temporal.io/api/deployment/v1" v19 "go.temporal.io/api/enums/v1" + v114 "go.temporal.io/api/failure/v1" v16 "go.temporal.io/api/history/v1" v113 "go.temporal.io/api/nexus/v1" v15 "go.temporal.io/api/protocol/v1" v12 "go.temporal.io/api/query/v1" v14 "go.temporal.io/api/taskqueue/v1" - v114 "go.temporal.io/api/worker/v1" + v115 "go.temporal.io/api/worker/v1" v1 "go.temporal.io/api/workflowservice/v1" v17 "go.temporal.io/server/api/clock/v1" v110 "go.temporal.io/server/api/deployment/v1" - v115 "go.temporal.io/server/api/enums/v1" + v116 "go.temporal.io/server/api/enums/v1" v13 "go.temporal.io/server/api/history/v1" v111 "go.temporal.io/server/api/persistence/v1" v18 "go.temporal.io/server/api/taskqueue/v1" @@ -3876,6 +3877,7 @@ type DispatchNexusTaskResponse struct { // *DispatchNexusTaskResponse_HandlerError // *DispatchNexusTaskResponse_Response // *DispatchNexusTaskResponse_RequestTimeout + // *DispatchNexusTaskResponse_Failure Outcome isDispatchNexusTaskResponse_Outcome `protobuf_oneof:"outcome"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -3918,6 +3920,7 @@ func (x *DispatchNexusTaskResponse) GetOutcome() isDispatchNexusTaskResponse_Out return nil } +// Deprecated: Marked as deprecated in temporal/server/api/matchingservice/v1/request_response.proto. 
func (x *DispatchNexusTaskResponse) GetHandlerError() *v113.HandlerError { if x != nil { if x, ok := x.Outcome.(*DispatchNexusTaskResponse_HandlerError); ok { @@ -3945,12 +3948,23 @@ func (x *DispatchNexusTaskResponse) GetRequestTimeout() *DispatchNexusTaskRespon return nil } +func (x *DispatchNexusTaskResponse) GetFailure() *v114.Failure { + if x != nil { + if x, ok := x.Outcome.(*DispatchNexusTaskResponse_Failure); ok { + return x.Failure + } + } + return nil +} + type isDispatchNexusTaskResponse_Outcome interface { isDispatchNexusTaskResponse_Outcome() } type DispatchNexusTaskResponse_HandlerError struct { - // Set if the worker's handler failed the nexus task. + // Deprecated. Use failure field instead. + // + // Deprecated: Marked as deprecated in temporal/server/api/matchingservice/v1/request_response.proto. HandlerError *v113.HandlerError `protobuf:"bytes,1,opt,name=handler_error,json=handlerError,proto3,oneof"` } @@ -3963,12 +3977,19 @@ type DispatchNexusTaskResponse_RequestTimeout struct { RequestTimeout *DispatchNexusTaskResponse_Timeout `protobuf:"bytes,3,opt,name=request_timeout,json=requestTimeout,proto3,oneof"` } +type DispatchNexusTaskResponse_Failure struct { + // Set if the worker's handler failed the nexus task. Must contain a NexusHandlerFailureInfo object. 
+ Failure *v114.Failure `protobuf:"bytes,4,opt,name=failure,proto3,oneof"` +} + func (*DispatchNexusTaskResponse_HandlerError) isDispatchNexusTaskResponse_Outcome() {} func (*DispatchNexusTaskResponse_Response) isDispatchNexusTaskResponse_Outcome() {} func (*DispatchNexusTaskResponse_RequestTimeout) isDispatchNexusTaskResponse_Outcome() {} +func (*DispatchNexusTaskResponse_Failure) isDispatchNexusTaskResponse_Outcome() {} + type PollNexusTaskQueueRequest struct { state protoimpl.MessageState `protogen:"open.v1"` NamespaceId string `protobuf:"bytes,1,opt,name=namespace_id,json=namespaceId,proto3" json:"namespace_id,omitempty"` @@ -4889,7 +4910,7 @@ func (x *ListWorkersRequest) GetListRequest() *v1.ListWorkersRequest { type ListWorkersResponse struct { state protoimpl.MessageState `protogen:"open.v1"` - WorkersInfo []*v114.WorkerInfo `protobuf:"bytes,1,rep,name=workers_info,json=workersInfo,proto3" json:"workers_info,omitempty"` + WorkersInfo []*v115.WorkerInfo `protobuf:"bytes,1,rep,name=workers_info,json=workersInfo,proto3" json:"workers_info,omitempty"` NextPageToken []byte `protobuf:"bytes,2,opt,name=next_page_token,json=nextPageToken,proto3" json:"next_page_token,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -4925,7 +4946,7 @@ func (*ListWorkersResponse) Descriptor() ([]byte, []int) { return file_temporal_server_api_matchingservice_v1_request_response_proto_rawDescGZIP(), []int{72} } -func (x *ListWorkersResponse) GetWorkersInfo() []*v114.WorkerInfo { +func (x *ListWorkersResponse) GetWorkersInfo() []*v115.WorkerInfo { if x != nil { return x.WorkersInfo } @@ -5100,7 +5121,7 @@ func (x *DescribeWorkerRequest) GetRequest() *v1.DescribeWorkerRequest { type DescribeWorkerResponse struct { state protoimpl.MessageState `protogen:"open.v1"` - WorkerInfo *v114.WorkerInfo `protobuf:"bytes,1,opt,name=worker_info,json=workerInfo,proto3" json:"worker_info,omitempty"` + WorkerInfo *v115.WorkerInfo 
`protobuf:"bytes,1,opt,name=worker_info,json=workerInfo,proto3" json:"worker_info,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -5135,7 +5156,7 @@ func (*DescribeWorkerResponse) Descriptor() ([]byte, []int) { return file_temporal_server_api_matchingservice_v1_request_response_proto_rawDescGZIP(), []int{76} } -func (x *DescribeWorkerResponse) GetWorkerInfo() *v114.WorkerInfo { +func (x *DescribeWorkerResponse) GetWorkerInfo() *v115.WorkerInfo { if x != nil { return x.WorkerInfo } @@ -5158,7 +5179,7 @@ type UpdateFairnessStateRequest struct { NamespaceId string `protobuf:"bytes,1,opt,name=namespace_id,json=namespaceId,proto3" json:"namespace_id,omitempty"` TaskQueue string `protobuf:"bytes,2,opt,name=task_queue,json=taskQueue,proto3" json:"task_queue,omitempty"` TaskQueueType v19.TaskQueueType `protobuf:"varint,3,opt,name=task_queue_type,json=taskQueueType,proto3,enum=temporal.api.enums.v1.TaskQueueType" json:"task_queue_type,omitempty"` - FairnessState v115.FairnessState `protobuf:"varint,4,opt,name=fairness_state,json=fairnessState,proto3,enum=temporal.server.api.enums.v1.FairnessState" json:"fairness_state,omitempty"` + FairnessState v116.FairnessState `protobuf:"varint,4,opt,name=fairness_state,json=fairnessState,proto3,enum=temporal.server.api.enums.v1.FairnessState" json:"fairness_state,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -5214,11 +5235,11 @@ func (x *UpdateFairnessStateRequest) GetTaskQueueType() v19.TaskQueueType { return v19.TaskQueueType(0) } -func (x *UpdateFairnessStateRequest) GetFairnessState() v115.FairnessState { +func (x *UpdateFairnessStateRequest) GetFairnessState() v116.FairnessState { if x != nil { return x.FairnessState } - return v115.FairnessState(0) + return v116.FairnessState(0) } type UpdateFairnessStateResponse struct { @@ -5693,7 +5714,7 @@ var File_temporal_server_api_matchingservice_v1_request_response_proto protorefl const 
file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc = "" + "\n" + - "=temporal/server/api/matchingservice/v1/request_response.proto\x12&temporal.server.api.matchingservice.v1\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a$temporal/api/common/v1/message.proto\x1a(temporal/api/deployment/v1/message.proto\x1a&temporal/api/enums/v1/task_queue.proto\x1a%temporal/api/history/v1/message.proto\x1a'temporal/api/taskqueue/v1/message.proto\x1a#temporal/api/query/v1/message.proto\x1a&temporal/api/protocol/v1/message.proto\x1a*temporal/server/api/clock/v1/message.proto\x1a/temporal/server/api/deployment/v1/message.proto\x1a,temporal/server/api/history/v1/message.proto\x1a.temporal/server/api/persistence/v1/nexus.proto\x1a4temporal/server/api/persistence/v1/task_queues.proto\x1a.temporal/server/api/taskqueue/v1/message.proto\x1a1temporal/server/api/enums/v1/fairness_state.proto\x1a6temporal/api/workflowservice/v1/request_response.proto\x1a#temporal/api/nexus/v1/message.proto\x1a$temporal/api/worker/v1/message.proto\"\xc3\x02\n" + + 
"=temporal/server/api/matchingservice/v1/request_response.proto\x12&temporal.server.api.matchingservice.v1\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a$temporal/api/common/v1/message.proto\x1a(temporal/api/deployment/v1/message.proto\x1a&temporal/api/enums/v1/task_queue.proto\x1a%temporal/api/failure/v1/message.proto\x1a%temporal/api/history/v1/message.proto\x1a'temporal/api/taskqueue/v1/message.proto\x1a#temporal/api/query/v1/message.proto\x1a&temporal/api/protocol/v1/message.proto\x1a*temporal/server/api/clock/v1/message.proto\x1a/temporal/server/api/deployment/v1/message.proto\x1a,temporal/server/api/history/v1/message.proto\x1a.temporal/server/api/persistence/v1/nexus.proto\x1a4temporal/server/api/persistence/v1/task_queues.proto\x1a.temporal/server/api/taskqueue/v1/message.proto\x1a1temporal/server/api/enums/v1/fairness_state.proto\x1a6temporal/api/workflowservice/v1/request_response.proto\x1a#temporal/api/nexus/v1/message.proto\x1a$temporal/api/worker/v1/message.proto\"\xc3\x02\n" + "\x1cPollWorkflowTaskQueueRequest\x12!\n" + "\fnamespace_id\x18\x01 \x01(\tR\vnamespaceId\x12\x1b\n" + "\tpoller_id\x18\x02 \x01(\tR\bpollerId\x12`\n" + @@ -6023,11 +6044,12 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\n" + "task_queue\x18\x02 \x01(\v2$.temporal.api.taskqueue.v1.TaskQueueR\ttaskQueue\x128\n" + "\arequest\x18\x03 \x01(\v2\x1e.temporal.api.nexus.v1.RequestR\arequest\x12T\n" + - "\fforward_info\x18\x04 \x01(\v21.temporal.server.api.taskqueue.v1.TaskForwardInfoR\vforwardInfo\"\xb2\x02\n" + - "\x19DispatchNexusTaskResponse\x12J\n" + - "\rhandler_error\x18\x01 \x01(\v2#.temporal.api.nexus.v1.HandlerErrorH\x00R\fhandlerError\x12=\n" + + "\fforward_info\x18\x04 \x01(\v21.temporal.server.api.taskqueue.v1.TaskForwardInfoR\vforwardInfo\"\xf4\x02\n" + + "\x19DispatchNexusTaskResponse\x12N\n" + + "\rhandler_error\x18\x01 
\x01(\v2#.temporal.api.nexus.v1.HandlerErrorB\x02\x18\x01H\x00R\fhandlerError\x12=\n" + "\bresponse\x18\x02 \x01(\v2\x1f.temporal.api.nexus.v1.ResponseH\x00R\bresponse\x12t\n" + - "\x0frequest_timeout\x18\x03 \x01(\v2I.temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.TimeoutH\x00R\x0erequestTimeout\x1a\t\n" + + "\x0frequest_timeout\x18\x03 \x01(\v2I.temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.TimeoutH\x00R\x0erequestTimeout\x12<\n" + + "\afailure\x18\x04 \x01(\v2 .temporal.api.failure.v1.FailureH\x00R\afailure\x1a\t\n" + "\aTimeoutB\t\n" + "\aoutcome\"\xb4\x02\n" + "\x19PollNexusTaskQueueRequest\x12!\n" + @@ -6266,23 +6288,24 @@ var file_temporal_server_api_matchingservice_v1_request_response_proto_goTypes = (*v113.Request)(nil), // 133: temporal.api.nexus.v1.Request (*v113.HandlerError)(nil), // 134: temporal.api.nexus.v1.HandlerError (*v113.Response)(nil), // 135: temporal.api.nexus.v1.Response - (*v1.PollNexusTaskQueueRequest)(nil), // 136: temporal.api.workflowservice.v1.PollNexusTaskQueueRequest - (*v1.PollNexusTaskQueueResponse)(nil), // 137: temporal.api.workflowservice.v1.PollNexusTaskQueueResponse - (*v1.RespondNexusTaskCompletedRequest)(nil), // 138: temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest - (*v1.RespondNexusTaskFailedRequest)(nil), // 139: temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest - (*v111.NexusEndpointSpec)(nil), // 140: temporal.server.api.persistence.v1.NexusEndpointSpec - (*v111.NexusEndpointEntry)(nil), // 141: temporal.server.api.persistence.v1.NexusEndpointEntry - (*v1.RecordWorkerHeartbeatRequest)(nil), // 142: temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest - (*v1.ListWorkersRequest)(nil), // 143: temporal.api.workflowservice.v1.ListWorkersRequest - (*v114.WorkerInfo)(nil), // 144: temporal.api.worker.v1.WorkerInfo - (*v1.UpdateTaskQueueConfigRequest)(nil), // 145: temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest - 
(*v14.TaskQueueConfig)(nil), // 146: temporal.api.taskqueue.v1.TaskQueueConfig - (*v1.DescribeWorkerRequest)(nil), // 147: temporal.api.workflowservice.v1.DescribeWorkerRequest - (v115.FairnessState)(0), // 148: temporal.server.api.enums.v1.FairnessState - (*v14.TaskQueueStats)(nil), // 149: temporal.api.taskqueue.v1.TaskQueueStats - (*v18.TaskQueueVersionInfoInternal)(nil), // 150: temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal - (*v1.UpdateWorkerBuildIdCompatibilityRequest)(nil), // 151: temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest - (*v110.WorkerDeploymentVersionData)(nil), // 152: temporal.server.api.deployment.v1.WorkerDeploymentVersionData + (*v114.Failure)(nil), // 136: temporal.api.failure.v1.Failure + (*v1.PollNexusTaskQueueRequest)(nil), // 137: temporal.api.workflowservice.v1.PollNexusTaskQueueRequest + (*v1.PollNexusTaskQueueResponse)(nil), // 138: temporal.api.workflowservice.v1.PollNexusTaskQueueResponse + (*v1.RespondNexusTaskCompletedRequest)(nil), // 139: temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest + (*v1.RespondNexusTaskFailedRequest)(nil), // 140: temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest + (*v111.NexusEndpointSpec)(nil), // 141: temporal.server.api.persistence.v1.NexusEndpointSpec + (*v111.NexusEndpointEntry)(nil), // 142: temporal.server.api.persistence.v1.NexusEndpointEntry + (*v1.RecordWorkerHeartbeatRequest)(nil), // 143: temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest + (*v1.ListWorkersRequest)(nil), // 144: temporal.api.workflowservice.v1.ListWorkersRequest + (*v115.WorkerInfo)(nil), // 145: temporal.api.worker.v1.WorkerInfo + (*v1.UpdateTaskQueueConfigRequest)(nil), // 146: temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest + (*v14.TaskQueueConfig)(nil), // 147: temporal.api.taskqueue.v1.TaskQueueConfig + (*v1.DescribeWorkerRequest)(nil), // 148: temporal.api.workflowservice.v1.DescribeWorkerRequest + (v116.FairnessState)(0), // 149: 
temporal.server.api.enums.v1.FairnessState + (*v14.TaskQueueStats)(nil), // 150: temporal.api.taskqueue.v1.TaskQueueStats + (*v18.TaskQueueVersionInfoInternal)(nil), // 151: temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal + (*v1.UpdateWorkerBuildIdCompatibilityRequest)(nil), // 152: temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest + (*v110.WorkerDeploymentVersionData)(nil), // 153: temporal.server.api.deployment.v1.WorkerDeploymentVersionData } var file_temporal_server_api_matchingservice_v1_request_response_proto_depIdxs = []int32{ 92, // 0: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueRequest.poll_request:type_name -> temporal.api.workflowservice.v1.PollWorkflowTaskQueueRequest @@ -6396,44 +6419,45 @@ var file_temporal_server_api_matchingservice_v1_request_response_proto_depIdxs = 134, // 108: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.handler_error:type_name -> temporal.api.nexus.v1.HandlerError 135, // 109: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.response:type_name -> temporal.api.nexus.v1.Response 91, // 110: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.request_timeout:type_name -> temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.Timeout - 136, // 111: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.request:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueRequest - 81, // 112: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.conditions:type_name -> temporal.server.api.matchingservice.v1.PollConditions - 137, // 113: temporal.server.api.matchingservice.v1.PollNexusTaskQueueResponse.response:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueResponse - 97, // 114: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue - 138, // 115: 
temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest - 97, // 116: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue - 139, // 117: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest - 140, // 118: temporal.server.api.matchingservice.v1.CreateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec - 141, // 119: temporal.server.api.matchingservice.v1.CreateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 140, // 120: temporal.server.api.matchingservice.v1.UpdateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec - 141, // 121: temporal.server.api.matchingservice.v1.UpdateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 141, // 122: temporal.server.api.matchingservice.v1.ListNexusEndpointsResponse.entries:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 142, // 123: temporal.server.api.matchingservice.v1.RecordWorkerHeartbeatRequest.heartbeart_request:type_name -> temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest - 143, // 124: temporal.server.api.matchingservice.v1.ListWorkersRequest.list_request:type_name -> temporal.api.workflowservice.v1.ListWorkersRequest - 144, // 125: temporal.server.api.matchingservice.v1.ListWorkersResponse.workers_info:type_name -> temporal.api.worker.v1.WorkerInfo - 145, // 126: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigRequest.update_taskqueue_config:type_name -> temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest - 146, // 127: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigResponse.updated_taskqueue_config:type_name 
-> temporal.api.taskqueue.v1.TaskQueueConfig - 147, // 128: temporal.server.api.matchingservice.v1.DescribeWorkerRequest.request:type_name -> temporal.api.workflowservice.v1.DescribeWorkerRequest - 144, // 129: temporal.server.api.matchingservice.v1.DescribeWorkerResponse.worker_info:type_name -> temporal.api.worker.v1.WorkerInfo - 115, // 130: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType - 148, // 131: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.fairness_state:type_name -> temporal.server.api.enums.v1.FairnessState - 115, // 132: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType - 117, // 133: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.version:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersion - 95, // 134: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponse.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery - 95, // 135: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponseWithRawHistory.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery - 115, // 136: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesRequest.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType - 115, // 137: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType - 149, // 138: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats:type_name -> temporal.api.taskqueue.v1.TaskQueueStats - 86, // 139: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats_by_priority_key:type_name -> 
temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry - 149, // 140: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry.value:type_name -> temporal.api.taskqueue.v1.TaskQueueStats - 150, // 141: temporal.server.api.matchingservice.v1.DescribeTaskQueuePartitionResponse.VersionsInfoInternalEntry.value:type_name -> temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal - 151, // 142: temporal.server.api.matchingservice.v1.UpdateWorkerBuildIdCompatibilityRequest.ApplyPublicRequest.request:type_name -> temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest - 152, // 143: temporal.server.api.matchingservice.v1.SyncDeploymentUserDataRequest.UpsertVersionsDataEntry.value:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersionData - 144, // [144:144] is the sub-list for method output_type - 144, // [144:144] is the sub-list for method input_type - 144, // [144:144] is the sub-list for extension type_name - 144, // [144:144] is the sub-list for extension extendee - 0, // [0:144] is the sub-list for field type_name + 136, // 111: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.failure:type_name -> temporal.api.failure.v1.Failure + 137, // 112: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.request:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueRequest + 81, // 113: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.conditions:type_name -> temporal.server.api.matchingservice.v1.PollConditions + 138, // 114: temporal.server.api.matchingservice.v1.PollNexusTaskQueueResponse.response:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueResponse + 97, // 115: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue + 139, // 116: 
temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest + 97, // 117: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue + 140, // 118: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest + 141, // 119: temporal.server.api.matchingservice.v1.CreateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec + 142, // 120: temporal.server.api.matchingservice.v1.CreateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry + 141, // 121: temporal.server.api.matchingservice.v1.UpdateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec + 142, // 122: temporal.server.api.matchingservice.v1.UpdateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry + 142, // 123: temporal.server.api.matchingservice.v1.ListNexusEndpointsResponse.entries:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry + 143, // 124: temporal.server.api.matchingservice.v1.RecordWorkerHeartbeatRequest.heartbeart_request:type_name -> temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest + 144, // 125: temporal.server.api.matchingservice.v1.ListWorkersRequest.list_request:type_name -> temporal.api.workflowservice.v1.ListWorkersRequest + 145, // 126: temporal.server.api.matchingservice.v1.ListWorkersResponse.workers_info:type_name -> temporal.api.worker.v1.WorkerInfo + 146, // 127: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigRequest.update_taskqueue_config:type_name -> temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest + 147, // 128: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigResponse.updated_taskqueue_config:type_name 
-> temporal.api.taskqueue.v1.TaskQueueConfig + 148, // 129: temporal.server.api.matchingservice.v1.DescribeWorkerRequest.request:type_name -> temporal.api.workflowservice.v1.DescribeWorkerRequest + 145, // 130: temporal.server.api.matchingservice.v1.DescribeWorkerResponse.worker_info:type_name -> temporal.api.worker.v1.WorkerInfo + 115, // 131: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType + 149, // 132: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.fairness_state:type_name -> temporal.server.api.enums.v1.FairnessState + 115, // 133: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType + 117, // 134: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.version:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersion + 95, // 135: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponse.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery + 95, // 136: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponseWithRawHistory.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery + 115, // 137: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesRequest.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType + 115, // 138: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType + 150, // 139: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats:type_name -> temporal.api.taskqueue.v1.TaskQueueStats + 86, // 140: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats_by_priority_key:type_name -> 
temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry + 150, // 141: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry.value:type_name -> temporal.api.taskqueue.v1.TaskQueueStats + 151, // 142: temporal.server.api.matchingservice.v1.DescribeTaskQueuePartitionResponse.VersionsInfoInternalEntry.value:type_name -> temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal + 152, // 143: temporal.server.api.matchingservice.v1.UpdateWorkerBuildIdCompatibilityRequest.ApplyPublicRequest.request:type_name -> temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest + 153, // 144: temporal.server.api.matchingservice.v1.SyncDeploymentUserDataRequest.UpsertVersionsDataEntry.value:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersionData + 145, // [145:145] is the sub-list for method output_type + 145, // [145:145] is the sub-list for method input_type + 145, // [145:145] is the sub-list for extension type_name + 145, // [145:145] is the sub-list for extension extendee + 0, // [0:145] is the sub-list for field type_name } func init() { file_temporal_server_api_matchingservice_v1_request_response_proto_init() } @@ -6460,6 +6484,7 @@ func file_temporal_server_api_matchingservice_v1_request_response_proto_init() { (*DispatchNexusTaskResponse_HandlerError)(nil), (*DispatchNexusTaskResponse_Response)(nil), (*DispatchNexusTaskResponse_RequestTimeout)(nil), + (*DispatchNexusTaskResponse_Failure)(nil), } type x struct{} out := protoimpl.TypeBuilder{ diff --git a/chasm/lib/callback/chasm_invocation.go b/chasm/lib/callback/chasm_invocation.go index 191d3e488c..fe5248c325 100644 --- a/chasm/lib/callback/chasm_invocation.go +++ b/chasm/lib/callback/chasm_invocation.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "fmt" - "io" "github.com/google/uuid" "github.com/nexus-rpc/sdk-go/nexus" @@ -28,7 +27,7 @@ import ( type 
chasmInvocation struct { nexus *callbackspb.Callback_Nexus attempt int32 - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions requestID string } @@ -135,22 +134,13 @@ func (c chasmInvocation) getHistoryRequest( RequestId: c.requestID, } - switch op := c.completion.(type) { - case *nexusrpc.OperationCompletionSuccessful: - payloadBody, err := io.ReadAll(op.Reader) - if err != nil { - return nil, fmt.Errorf("failed to read payload: %v", err) - } - + if c.completion.Error == nil { var payload *commonpb.Payload - if payloadBody != nil { - content := &nexus.Content{ - Header: op.Reader.Header, - Data: payloadBody, - } - err := commonnexus.PayloadSerializer.Deserialize(content, &payload) - if err != nil { - return nil, fmt.Errorf("failed to deserialize payload: %v", err) + if c.completion.Result != nil { + var ok bool + payload, ok = c.completion.Result.(*commonpb.Payload) + if !ok { + return nil, fmt.Errorf("invalid result, expected a payload, got: %T", c.completion.Result) } } @@ -158,13 +148,21 @@ func (c chasmInvocation) getHistoryRequest( Outcome: &historyservice.CompleteNexusOperationChasmRequest_Success{ Success: payload, }, - CloseTime: timestamppb.New(op.CloseTime), + CloseTime: timestamppb.New(c.completion.CloseTime), Completion: completion, } - case *nexusrpc.OperationCompletionUnsuccessful: - apiFailure, err := commonnexus.NexusFailureToAPIFailure(op.Failure, true) + } else { + failure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(c.completion.Error) if err != nil { - return nil, fmt.Errorf("failed to convert failure type: %v", err) + return nil, fmt.Errorf("failed to convert error to failure: %w", err) + } + // Unwrap the operation error since it's not meant to be sent for Temporal->Temporal completions. 
+ if failure.Cause != nil { + failure = *failure.Cause + } + apiFailure, err := commonnexus.NexusFailureToTemporalFailure(failure) + if err != nil { + return nil, fmt.Errorf("failed to convert failure type: %w", err) } req = &historyservice.CompleteNexusOperationChasmRequest{ @@ -172,10 +170,8 @@ func (c chasmInvocation) getHistoryRequest( Outcome: &historyservice.CompleteNexusOperationChasmRequest_Failure{ Failure: apiFailure, }, - CloseTime: timestamppb.New(op.CloseTime), + CloseTime: timestamppb.New(c.completion.CloseTime), } - default: - return nil, fmt.Errorf("unexpected nexus.OperationCompletion: %v", completion) } return req, nil diff --git a/chasm/lib/callback/component.go b/chasm/lib/callback/component.go index a10734e19a..c288ff76c2 100644 --- a/chasm/lib/callback/component.go +++ b/chasm/lib/callback/component.go @@ -15,7 +15,7 @@ import ( ) type CompletionSource interface { - GetNexusCompletion(ctx chasm.Context, requestID string) (nexusrpc.OperationCompletion, error) + GetNexusCompletion(ctx chasm.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) } var _ chasm.Component = (*Callback)(nil) diff --git a/chasm/lib/callback/executors_test.go b/chasm/lib/callback/executors_test.go index a43883ab20..42dfcfd48d 100644 --- a/chasm/lib/callback/executors_test.go +++ b/chasm/lib/callback/executors_test.go @@ -40,13 +40,13 @@ type mockNexusCompletionGetterComponent struct { Empty *emptypb.Empty - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions err error Callback chasm.Field[*Callback] } -func (m *mockNexusCompletionGetterComponent) GetNexusCompletion(_ chasm.Context, requestID string) (nexusrpc.OperationCompletion, error) { +func (m *mockNexusCompletionGetterComponent) GetNexusCompletion(_ chasm.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) { return m.completion, m.err } @@ -81,7 +81,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { caller: func(r *http.Request) 
(*http.Response, error) { return &http.Response{StatusCode: 200, Body: http.NoBody}, nil }, - expectedMetricOutcome: "status:200", + expectedMetricOutcome: "success", assertOutcome: func(t *testing.T, cb *Callback, err error) { require.NoError(t, err) require.Equal(t, callbackspb.CALLBACK_STATUS_SUCCEEDED, cb.Status) @@ -104,7 +104,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { caller: func(r *http.Request) (*http.Response, error) { return &http.Response{StatusCode: 500, Body: http.NoBody}, nil }, - expectedMetricOutcome: "status:500", + expectedMetricOutcome: "handler-error:INTERNAL", assertOutcome: func(t *testing.T, cb *Callback, err error) { var destDownErr *queueserrors.DestinationDownError require.ErrorAs(t, err, &destDownErr) @@ -116,7 +116,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { caller: func(r *http.Request) (*http.Response, error) { return &http.Response{StatusCode: 400, Body: http.NoBody}, nil }, - expectedMetricOutcome: "status:400", + expectedMetricOutcome: "handler-error:BAD_REQUEST", assertOutcome: func(t *testing.T, cb *Callback, err error) { require.NoError(t, err) require.Equal(t, callbackspb.CALLBACK_STATUS_FAILED, cb.Status) @@ -210,8 +210,7 @@ func TestExecuteInvocationTaskNexus_Outcomes(t *testing.T) { } // Create completion - completion, err := nexusrpc.NewOperationCompletionSuccessful(nil, nexusrpc.OperationCompletionSuccessfulOptions{}) - require.NoError(t, err) + completion := nexusrpc.CompleteOperationOptions{} // Set up the CompletionSource field to return our mock completion root.SetRootComponent(&mockNexusCompletionGetterComponent{ @@ -371,7 +370,7 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { cases := []struct { name string setupHistoryClient func(*testing.T, *gomock.Controller) resource.HistoryClient - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions headerValue string assertOutcome func(*testing.T, *Callback, error) }{ @@ -397,16 
+396,11 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { }) return client }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - CloseTime: dummyTime, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Result: createPayloadBytes([]byte("result-data")), + CloseTime: dummyTime, + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -430,18 +424,14 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { }) return client }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Error: &nexus.OperationError{ State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "operation failed"}}, }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - CloseTime: dummyTime, - }, - ) - require.NoError(t, err) - return comp + CloseTime: dummyTime, + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -459,15 +449,10 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { ).Return(nil, status.Error(codes.Unavailable, "service unavailable")) return client }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + 
Result: createPayloadBytes([]byte("result-data")), + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -485,15 +470,10 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { ).Return(nil, status.Error(codes.InvalidArgument, "invalid request")) return client }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Result: createPayloadBytes([]byte("result-data")), + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -507,15 +487,10 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { // No RPC call expected return historyservicemock.NewMockHistoryServiceClient(ctrl) }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Result: createPayloadBytes([]byte("result-data")), + } }(), headerValue: "invalid-base64!!!", assertOutcome: func(t *testing.T, cb *Callback, err error) { @@ -529,15 +504,10 @@ func TestExecuteInvocationTaskChasm_Outcomes(t *testing.T) { // No RPC call expected return historyservicemock.NewMockHistoryServiceClient(ctrl) }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayloadBytes([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: 
commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Result: createPayloadBytes([]byte("result-data")), + } }(), headerValue: base64.RawURLEncoding.EncodeToString([]byte("not-valid-protobuf")), assertOutcome: func(t *testing.T, cb *Callback, err error) { diff --git a/chasm/lib/callback/nexus_invocation.go b/chasm/lib/callback/nexus_invocation.go index 73eb1285b9..ec29fabbac 100644 --- a/chasm/lib/callback/nexus_invocation.go +++ b/chasm/lib/callback/nexus_invocation.go @@ -1,15 +1,10 @@ package callback import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "mime" + "errors" "net/http" "net/http/httptrace" - "slices" "time" "github.com/nexus-rpc/sdk-go/nexus" @@ -32,25 +27,11 @@ var retryable4xxErrorTypes = []int{ type nexusInvocation struct { nexus *callbackspb.Callback_Nexus - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions workflowID, runID string attempt int32 } -func isRetryableHTTPResponse(response *http.Response) bool { - return response.StatusCode >= 500 || slices.Contains(retryable4xxErrorTypes, response.StatusCode) -} - -func outcomeTag(callCtx context.Context, response *http.Response, callErr error) string { - if callErr != nil { - if callCtx.Err() != nil { - return "request-timeout" - } - return "unknown-error" - } - return fmt.Sprintf("status:%d", response.StatusCode) -} - func (n nexusInvocation) WrapError(result invocationResult, err error) error { if retry, ok := result.(invocationResultRetry); ok { return queueserrors.NewDestinationDownError(retry.err.Error(), err) @@ -80,107 +61,54 @@ func (n nexusInvocation) Invoke( } } - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, n.nexus.Url, n.completion) - if err != nil { - return invocationResultFail{queueserrors.NewUnprocessableTaskError( - fmt.Sprintf("failed to construct Nexus request: %v", err), - )} - } - if 
request.Header == nil { - request.Header = make(http.Header) - } - for k, v := range n.nexus.Header { - request.Header.Set(k, v) - } - - caller := e.httpCallerProvider(queuescommon.NamespaceIDAndDestination{ - NamespaceID: ns.ID().String(), - Destination: taskAttr.Destination, + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + HTTPCaller: e.httpCallerProvider(queuescommon.NamespaceIDAndDestination{ + NamespaceID: ns.ID().String(), + Destination: taskAttr.Destination, + }), + Serializer: commonnexus.PayloadSerializer, }) // Make the call and record metrics. startTime := time.Now() - response, err := caller(request) + + n.completion.Header = n.nexus.Header + err := client.CompleteOperation(ctx, n.nexus.Url, n.completion) namespaceTag := metrics.NamespaceTag(ns.Name().String()) destTag := metrics.DestinationTag(taskAttr.Destination) - statusCodeTag := metrics.OutcomeTag(outcomeTag(ctx, response, err)) - e.metricsHandler.Counter(RequestCounter.Name()).Record(1, namespaceTag, destTag, statusCodeTag) - e.metricsHandler.Timer(RequestLatencyHistogram.Name()).Record(time.Since(startTime), namespaceTag, destTag, statusCodeTag) + outcomeTag := metrics.OutcomeTag(outcomeTag(ctx, err)) + e.metricsHandler.Counter(RequestCounter.Name()).Record(1, namespaceTag, destTag, outcomeTag) + e.metricsHandler.Timer(RequestLatencyHistogram.Name()).Record(time.Since(startTime), namespaceTag, destTag, outcomeTag) if err != nil { - e.logger.Error("Callback request failed with error", tag.Error(err)) - return invocationResultRetry{err: err} - } - - if response.StatusCode >= 200 && response.StatusCode < 300 { - // Body is not read but should be discarded to keep the underlying TCP connection alive. - // Just in case something unexpected happens while discarding or closing the body, - // propagate errors to the machine. 
- if _, err = io.Copy(io.Discard, response.Body); err == nil { - if err = response.Body.Close(); err != nil { - e.logger.Error("Callback request failed with error", tag.Error(err)) - return invocationResultRetry{err: err} - } + retryable := isRetryableCallError(err) + e.logger.Error("Callback request failed", tag.Error(err), tag.Bool("retryable", retryable)) + if retryable { + return invocationResultRetry{err} } - return invocationResultOK{} + return invocationResultFail{err} } - - retryable := isRetryableHTTPResponse(response) - err = readHandlerErrFromResponse(response, e.logger) - e.logger.Error("Callback request failed", tag.Error(err), tag.String("status", response.Status), tag.Bool("retryable", retryable)) - if retryable { - return invocationResultRetry{err: err} - } - return invocationResultFail{err} + return invocationResultOK{} } -// Reads and replaces the http response body and attempts to deserialize it into a Nexus failure. If successful, -// returns a nexus.HandlerError with the deserialized failure as the Cause. If there is an error reading the body or -// during deserialization, returns a nexus.HandlerError with a generic Cause based on response status. -// TODO: This logic is duplicated in the frontend handler for forwarded requests. Eventually it should live in the Nexus SDK. 
-func readHandlerErrFromResponse(response *http.Response, logger log.Logger) error { - handlerErr := &nexus.HandlerError{ - Type: commonnexus.HandlerErrorTypeFromHTTPStatus(response.StatusCode), - Cause: fmt.Errorf("request failed with: %v", response.Status), - } - - body, err := readAndReplaceBody(response) - if err != nil { - logger.Error("Error reading response body for non-ok callback request", tag.Error(err), tag.String("status", response.Status)) - return err - } - - if !isMediaTypeJSON(response.Header.Get("Content-Type")) { - logger.Error("received invalid content-type header for non-OK HTTP response to CompleteOperation request", tag.Value(response.Header.Get("Content-Type"))) - return handlerErr - } - - var failure nexus.Failure - err = json.Unmarshal(body, &failure) - if err != nil { - logger.Error("failed to deserialize Nexus Failure from HTTP response to CompleteOperation request", tag.Error(err)) - return handlerErr +func isRetryableCallError(err error) bool { + var handlerError *nexus.HandlerError + if errors.As(err, &handlerError) { + return handlerError.Retryable() } - - handlerErr.Cause = &nexus.FailureError{Failure: failure} - return handlerErr -} - -// readAndReplaceBody reads the response body in its entirety and closes it, and then replaces the original response -// body with an in-memory buffer. -// The body is replaced even when there was an error reading the entire body. 
-func readAndReplaceBody(response *http.Response) ([]byte, error) { - responseBody := response.Body - body, err := io.ReadAll(responseBody) - _ = responseBody.Close() - response.Body = io.NopCloser(bytes.NewReader(body)) - return body, err + return true } -func isMediaTypeJSON(contentType string) bool { - if contentType == "" { - return false +func outcomeTag(callCtx context.Context, callErr error) string { + if callErr != nil { + if callCtx.Err() != nil { + return "request-timeout" + } + var handlerErr *nexus.HandlerError + if errors.As(callErr, &handlerErr) { + return "handler-error:" + string(handlerErr.Type) + } + return "unknown-error" } - mediaType, _, err := mime.ParseMediaType(contentType) - return err == nil && mediaType == "application/json" + return "success" } diff --git a/chasm/lib/workflow/workflow.go b/chasm/lib/workflow/workflow.go index 6e0acdd063..bf609a40b2 100644 --- a/chasm/lib/workflow/workflow.go +++ b/chasm/lib/workflow/workflow.go @@ -117,7 +117,7 @@ func (w *Workflow) AddCompletionCallbacks( func (w *Workflow) GetNexusCompletion( ctx chasm.Context, requestID string, -) (nexusrpc.OperationCompletion, error) { +) (nexusrpc.CompleteOperationOptions, error) { // Retrieve the completion data from the underlying mutable state via MSPointer return w.MSPointer.GetNexusCompletion(ctx, requestID) } diff --git a/chasm/ms_pointer.go b/chasm/ms_pointer.go index 75013b4a62..ff35e4afe5 100644 --- a/chasm/ms_pointer.go +++ b/chasm/ms_pointer.go @@ -20,6 +20,6 @@ func NewMSPointer(backend NodeBackend) MSPointer { } // GetNexusCompletion retrieves the Nexus operation completion data for the given request ID from the underlying mutable state. 
-func (m MSPointer) GetNexusCompletion(ctx Context, requestID string) (nexusrpc.OperationCompletion, error) { +func (m MSPointer) GetNexusCompletion(ctx Context, requestID string) (nexusrpc.CompleteOperationOptions, error) { return m.backend.GetNexusCompletion(ctx.getContext(), requestID) } diff --git a/chasm/node_backend_mock.go b/chasm/node_backend_mock.go index eee5aca836..7fa090b15d 100644 --- a/chasm/node_backend_mock.go +++ b/chasm/node_backend_mock.go @@ -26,7 +26,7 @@ type MockNodeBackend struct { HandleGetWorkflowKey func() definition.WorkflowKey HandleUpdateWorkflowStateStatus func(state enumsspb.WorkflowExecutionState, status enumspb.WorkflowExecutionStatus) (bool, error) HandleIsWorkflow func() bool - HandleGetNexusCompletion func(ctx context.Context, requestID string) (nexusrpc.OperationCompletion, error) + HandleGetNexusCompletion func(ctx context.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) // Recorded calls (protected by mu). mu sync.Mutex @@ -164,11 +164,11 @@ func (m *MockNodeBackend) IsWorkflow() bool { func (m *MockNodeBackend) GetNexusCompletion( ctx context.Context, requestID string, -) (nexusrpc.OperationCompletion, error) { +) (nexusrpc.CompleteOperationOptions, error) { if m.HandleGetNexusCompletion != nil { return m.HandleGetNexusCompletion(ctx, requestID) } - return nil, nil + return nexusrpc.CompleteOperationOptions{}, nil } func (m *MockNodeBackend) NumTasksAdded() int { diff --git a/chasm/tree.go b/chasm/tree.go index 34cad3cfb6..94d8440dce 100644 --- a/chasm/tree.go +++ b/chasm/tree.go @@ -199,7 +199,7 @@ type ( GetNexusCompletion( ctx context.Context, requestID string, - ) (nexusrpc.OperationCompletion, error) + ) (nexusrpc.CompleteOperationOptions, error) } // NodePathEncoder is an interface for encoding and decoding node paths. 
diff --git a/common/nexus/failure.go b/common/nexus/failure.go index 35522dcd52..c5543f1b83 100644 --- a/common/nexus/failure.go +++ b/common/nexus/failure.go @@ -2,13 +2,15 @@ package nexus import ( "context" + "encoding/base64" "encoding/json" - "errors" + "fmt" "net/http" "sync/atomic" "github.com/nexus-rpc/sdk-go/nexus" commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" failurepb "go.temporal.io/api/failure/v1" nexuspb "go.temporal.io/api/nexus/v1" "go.temporal.io/api/serviceerror" @@ -52,135 +54,254 @@ var failureTypeString = string((&failurepb.Failure{}).ProtoReflect().Descriptor( // ProtoFailureToNexusFailure converts a proto Nexus Failure to a Nexus SDK Failure. func ProtoFailureToNexusFailure(failure *nexuspb.Failure) nexus.Failure { - return nexus.Failure{ - Message: failure.GetMessage(), - Metadata: failure.GetMetadata(), - Details: failure.GetDetails(), + nf := nexus.Failure{ + Message: failure.GetMessage(), + StackTrace: failure.GetStackTrace(), + Metadata: failure.GetMetadata(), + Details: failure.GetDetails(), + } + if failure.GetCause() != nil { + cause := ProtoFailureToNexusFailure(failure.GetCause()) + nf.Cause = &cause } + return nf } // NexusFailureToProtoFailure converts a Nexus SDK Failure to a proto Nexus Failure. // Always returns a non-nil value. func NexusFailureToProtoFailure(failure nexus.Failure) *nexuspb.Failure { - return &nexuspb.Failure{ - Message: failure.Message, - Metadata: failure.Metadata, - Details: failure.Details, + pf := &nexuspb.Failure{ + Message: failure.Message, + Metadata: failure.Metadata, + Details: failure.Details, + StackTrace: failure.StackTrace, + } + if failure.Cause != nil { + pf.Cause = NexusFailureToProtoFailure(*failure.Cause) } + return pf +} + +type serializedHandlerError struct { + Type string `json:"type,omitempty"` + RetryableOverride *bool `json:"retryableOverride,omitempty"` + // Bytes as base64 encoded string. 
+ EncodedAttributes string `json:"encodedAttributes,omitempty"` } -// APIFailureToNexusFailure converts an API proto Failure to a Nexus SDK Failure setting the metadata "type" field to -// the proto fullname of the temporal API Failure message. -// Mutates the failure temporarily, unsetting the Message field to avoid duplicating the information in the serialized -// failure. Mutating was chosen over cloning for performance reasons since this function may be called frequently. -func APIFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) { - // Unset message so it's not serialized in the details. +// TemporalFailureToNexusFailure converts an API proto Failure to a Nexus SDK Failure setting the metadata "type" field to +// the proto fullname of the temporal API Failure message or the standard Nexus SDK failure types. +// Returns an error if the failure cannot be converted. +// Mutates the failure temporarily, unsetting the Message and StackTrace fields to avoid duplicating the information in +// the serialized failure. Mutating was chosen over cloning for performance reasons since this function may be called +// frequently. 
+func TemporalFailureToNexusFailure(failure *failurepb.Failure) (nexus.Failure, error) { + var causep *nexus.Failure + if failure.GetCause() != nil { + var cause nexus.Failure + var err error + cause, err = TemporalFailureToNexusFailure(failure.GetCause()) + if err != nil { + return nexus.Failure{}, err + } + causep = &cause + } + + switch info := failure.GetFailureInfo().(type) { + case *failurepb.Failure_NexusHandlerFailureInfo: + var encodedAttributes string + if failure.EncodedAttributes != nil { + b, err := protojson.Marshal(failure.EncodedAttributes) + if err != nil { + return nexus.Failure{}, fmt.Errorf("failed to deserialize HandlerError attributes: %w", err) + } + encodedAttributes = base64.RawURLEncoding.EncodeToString(b) + } + var retryableOverride *bool + // nolint:exhaustive,revive // There are only two valid values other than unspecified. + switch info.NexusHandlerFailureInfo.GetRetryBehavior() { + case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE: + val := true + retryableOverride = &val + case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: + val := false + retryableOverride = &val + } + + handlerError := serializedHandlerError{ + Type: info.NexusHandlerFailureInfo.GetType(), + RetryableOverride: retryableOverride, + EncodedAttributes: encodedAttributes, + } + + details, err := json.Marshal(handlerError) + if err != nil { + return nexus.Failure{}, err + } + return nexus.Failure{ + Message: failure.GetMessage(), + StackTrace: failure.GetStackTrace(), + Metadata: map[string]string{ + "type": "nexus.HandlerError", + }, + Details: details, + Cause: causep, + }, nil + } + // Unset message and stack trace so it's not serialized in the details. 
var message string message, failure.Message = failure.Message, "" + var stackTrace string + stackTrace, failure.StackTrace = failure.StackTrace, "" + data, err := protojson.Marshal(failure) failure.Message = message - + failure.StackTrace = stackTrace if err != nil { return nexus.Failure{}, err } + return nexus.Failure{ - Message: failure.GetMessage(), + Message: failure.GetMessage(), + StackTrace: failure.GetStackTrace(), Metadata: map[string]string{ "type": failureTypeString, }, Details: data, + Cause: causep, }, nil } -// NexusFailureToAPIFailure converts a Nexus Failure to an API proto Failure. +// NexusFailureToTemporalFailure converts a Nexus Failure to an API proto Failure. // If the failure metadata "type" field is set to the fullname of the temporal API Failure message, the failure is -// reconstructed using protojson.Unmarshal on the failure details field. -func NexusFailureToAPIFailure(failure nexus.Failure, retryable bool) (*failurepb.Failure, error) { - apiFailure := &failurepb.Failure{} +// reconstructed using protojson.Unmarshal on the failure details field. Otherwise, the failure is reconstructed +// based on the known Nexus SDK failure types. +// Returns an error if the failure cannot be converted. +// nolint:revive // cognitive-complexity is high but justified to keep each case together +func NexusFailureToTemporalFailure(f nexus.Failure) (*failurepb.Failure, error) { + apiFailure := &failurepb.Failure{ + Message: f.Message, + StackTrace: f.StackTrace, + } - if failure.Metadata != nil && failure.Metadata["type"] == failureTypeString { - if err := protojson.Unmarshal(failure.Details, apiFailure); err != nil { - return nil, err + if f.Metadata != nil { + switch f.Metadata["type"] { + case failureTypeString: + opts := protojson.UnmarshalOptions{DiscardUnknown: true} + if err := opts.Unmarshal(f.Details, apiFailure); err != nil { + return nil, err + } + // Restore these fields as they are not included in the marshalled failure. 
+ apiFailure.Message = f.Message + apiFailure.StackTrace = f.StackTrace + case "nexus.OperationError": + // Special case for OperationError that adapts from Nexus semantics to Temporal semantics. + // Note that Temporal -> Temporal doesn't go through this code path, operation errors are always used as empty + // wrappers for an underlying causes. + var operationError *nexus.OperationError + err := json.Unmarshal(f.Details, &operationError) + if err != nil { + return nil, fmt.Errorf("failed to deserialize OperationError: %w", err) + } + if operationError.State == nexus.OperationStateCanceled { + // Canceled operation errors are represented as CanceledFailure in Temporal. + apiFailure.FailureInfo = &failurepb.Failure_CanceledFailureInfo{ + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, + } + } else { + // Failed operation errors are represented as non-retryable ApplicationFailure in Temporal. + apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + NonRetryable: true, + Type: "OperationError", + }, + } + } + case "nexus.HandlerError": + var se serializedHandlerError + err := json.Unmarshal(f.Details, &se) + if err != nil { + return nil, fmt.Errorf("failed to deserialize HandlerError: %w", err) + } + var retryBehavior enumspb.NexusHandlerErrorRetryBehavior + if se.RetryableOverride == nil { + retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_UNSPECIFIED + } else if *se.RetryableOverride { + retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE + } else { + retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE + } + apiFailure.FailureInfo = &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: se.Type, + RetryBehavior: retryBehavior, + }, + } + if len(se.EncodedAttributes) > 0 { + decoded, err := base64.RawURLEncoding.DecodeString(se.EncodedAttributes) + if err != nil { + return nil, 
fmt.Errorf("failed to decode base64 HandlerError attributes: %w", err) + } + apiFailure.EncodedAttributes = &commonpb.Payload{} + if err := protojson.Unmarshal(decoded, apiFailure.EncodedAttributes); err != nil { + return nil, fmt.Errorf("failed to deserialize HandlerError attributes: %w", err) + } + } + default: + // We don't recognize this type, convert to a generic ApplicationFailure and preserve the original Nexus failure + // as serialized details. + applicationFailureInfo, err := nexusFailureMetadataToApplicationFailureInfo(f) + if err != nil { + return nil, fmt.Errorf("failed to serialize Nexus failure: %w", err) + } + apiFailure.FailureInfo = applicationFailureInfo } - } else { - payloads, err := nexusFailureMetadataToPayloads(failure) + } else if len(f.Details) > 0 { + // We don't recognize this type, convert to a generic ApplicationFailure and preserve the original Nexus failure as + // serialized details. + applicationFailureInfo, err := nexusFailureMetadataToApplicationFailureInfo(f) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to serialize Nexus failure: %w", err) } - apiFailure.FailureInfo = &failurepb.Failure_ApplicationFailureInfo{ - ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ - // Make up a type here, it's not part of the Nexus Failure spec. - Type: "NexusFailure", - Details: payloads, - NonRetryable: !retryable, - }, + apiFailure.FailureInfo = applicationFailureInfo + } + + if f.Cause != nil { + var err error + apiFailure.Cause, err = NexusFailureToTemporalFailure(*f.Cause) + if err != nil { + return nil, err } } - // Ensure this always gets written. 
- apiFailure.Message = failure.Message return apiFailure, nil } -func OperationErrorToTemporalFailure(opErr *nexus.OperationError) (*failurepb.Failure, error) { - var nexusFailure nexus.Failure - failureErr, ok := opErr.Cause.(*nexus.FailureError) - if ok { - nexusFailure = failureErr.Failure - } else if opErr.Cause != nil { - nexusFailure = nexus.Failure{Message: opErr.Cause.Error()} - } - - // Canceled must be translated into a CanceledFailure to match the SDK expectation. - if opErr.State == nexus.OperationStateCanceled { - if nexusFailure.Metadata != nil && nexusFailure.Metadata["type"] == failureTypeString { - temporalFailure, err := NexusFailureToAPIFailure(nexusFailure, false) - if err != nil { - return nil, err - } - if temporalFailure.GetCanceledFailureInfo() != nil { - // We already have a CanceledFailure, use it. - return temporalFailure, nil - } - // Fallback to encoding the Nexus failure into a Temporal canceled failure, we expect operations that end up - // as canceled to have a CanceledFailureInfo object. - } - payloads, err := nexusFailureMetadataToPayloads(nexusFailure) +func nexusFailureMetadataToApplicationFailureInfo(failure nexus.Failure) (*failurepb.Failure_ApplicationFailureInfo, error) { + var payloads *commonpb.Payloads + if len(failure.Metadata) > 0 || len(failure.Details) > 0 { + // Delete before serializing (note the failure here is passed by value). 
+ failure.Message = "" + failure.StackTrace = "" + data, err := json.Marshal(failure) if err != nil { return nil, err } - return &failurepb.Failure{ - Message: nexusFailure.Message, - FailureInfo: &failurepb.Failure_CanceledFailureInfo{ - CanceledFailureInfo: &failurepb.CanceledFailureInfo{ - Details: payloads, + payloads = &commonpb.Payloads{ + Payloads: []*commonpb.Payload{ + { + Metadata: map[string][]byte{ + "encoding": []byte("json/plain"), + }, + Data: data, }, }, - }, nil - } - - return NexusFailureToAPIFailure(nexusFailure, false) -} - -func nexusFailureMetadataToPayloads(failure nexus.Failure) (*commonpb.Payloads, error) { - if len(failure.Metadata) == 0 && len(failure.Details) == 0 { - return nil, nil - } - // Delete before serializing. - failure.Message = "" - data, err := json.Marshal(failure) - if err != nil { - return nil, err + } } - return &commonpb.Payloads{ - Payloads: []*commonpb.Payload{ - { - Metadata: map[string][]byte{ - "encoding": []byte("json/plain"), - }, - Data: data, - }, + return &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Details: payloads, }, - }, err + }, nil } // ConvertGRPCError converts either a serviceerror or a gRPC status error into a Nexus HandlerError if possible. @@ -210,16 +331,16 @@ func ConvertGRPCError(err error, exposeDetails bool) error { errMessage = "bad request" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeBadRequest, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeBadRequest, + Message: errMessage, } case codes.Aborted, codes.Unavailable: if !exposeDetails { errMessage = "service unavailable" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnavailable, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeUnavailable, + Message: errMessage, } case codes.Canceled: // TODO: This should have a different status code (e.g. 499 which is semi standard but not supported by nexus). 
@@ -228,72 +349,72 @@ func ConvertGRPCError(err error, exposeDetails bool) error { errMessage = "canceled" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeInternal, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeInternal, + Message: errMessage, } case codes.DataLoss, codes.Internal, codes.Unknown: if !exposeDetails { errMessage = "internal error" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeInternal, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeInternal, + Message: errMessage, } case codes.Unauthenticated: if !exposeDetails { errMessage = "authentication failed" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnauthenticated, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeUnauthenticated, + Message: errMessage, } case codes.PermissionDenied: if !exposeDetails { errMessage = "permission denied" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnauthorized, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeUnauthorized, + Message: errMessage, } case codes.NotFound: if !exposeDetails { errMessage = "not found" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeNotFound, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeNotFound, + Message: errMessage, } case codes.ResourceExhausted: if !exposeDetails { errMessage = "resource exhausted" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeResourceExhausted, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeResourceExhausted, + Message: errMessage, } case codes.Unimplemented: if !exposeDetails { errMessage = "not implemented" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeNotImplemented, - Cause: errors.New(errMessage), + Type: nexus.HandlerErrorTypeNotImplemented, + Message: errMessage, } case codes.DeadlineExceeded: if !exposeDetails { errMessage = "request timeout" } return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUpstreamTimeout, - Cause: 
errors.New(errMessage), + Type: nexus.HandlerErrorTypeUpstreamTimeout, + Message: errMessage, } case codes.OK: return nil } if !exposeDetails { return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeInternal, - Cause: errors.New("internal error"), + Type: nexus.HandlerErrorTypeInternal, + Message: "internal error", } } // Let the nexus SDK handle this for us (log and convert to an internal error). @@ -302,32 +423,7 @@ func ConvertGRPCError(err error, exposeDetails bool) error { func AdaptAuthorizeError(permissionDeniedError *serviceerror.PermissionDenied) error { if permissionDeniedError.Reason != "" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied: %s", permissionDeniedError.Reason) - } - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied") -} - -func HandlerErrorTypeFromHTTPStatus(statusCode int) nexus.HandlerErrorType { - switch statusCode { - case http.StatusBadRequest: - return nexus.HandlerErrorTypeBadRequest - case http.StatusUnauthorized: - return nexus.HandlerErrorTypeUnauthenticated - case http.StatusForbidden: - return nexus.HandlerErrorTypeUnauthorized - case http.StatusNotFound: - return nexus.HandlerErrorTypeNotFound - case http.StatusTooManyRequests: - return nexus.HandlerErrorTypeResourceExhausted - case http.StatusInternalServerError: - return nexus.HandlerErrorTypeInternal - case http.StatusNotImplemented: - return nexus.HandlerErrorTypeNotImplemented - case http.StatusServiceUnavailable: - return nexus.HandlerErrorTypeUnavailable - case nexus.StatusUpstreamTimeout: - return nexus.HandlerErrorTypeUpstreamTimeout - default: - return nexus.HandlerErrorTypeInternal + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied: %s", permissionDeniedError.Reason) } + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnauthorized, "permission denied") } diff --git a/common/nexus/failure_test.go b/common/nexus/failure_test.go new file mode 100644 index 
0000000000..5f7dfc7cd9 --- /dev/null +++ b/common/nexus/failure_test.go @@ -0,0 +1,272 @@ +package nexus + +import ( + "testing" + + "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" + failurepb "go.temporal.io/api/failure/v1" + "go.temporal.io/sdk/temporal" + "go.temporal.io/server/common/nexus/nexusrpc" + "go.temporal.io/server/common/testing/protorequire" +) + +func TestRoundTrip_ApplicationFailure(t *testing.T) { + original := &failurepb.Failure{ + Message: "application error", + StackTrace: "stack trace here", + EncodedAttributes: mustToPayload(t, "encoded"), + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "CustomError", + NonRetryable: false, + Details: &commonpb.Payloads{ + Payloads: []*commonpb.Payload{mustToPayload(t, "encoded")}, + }, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusHandlerFailure_Retryable(t *testing.T) { + original := &failurepb.Failure{ + Message: "handler error - retryable", + StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "CustomHandlerError", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusHandlerFailure_NonRetryable(t *testing.T) { + original := &failurepb.Failure{ + Message: "handler error - non-retryable", + 
StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "FatalHandlerError", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusHandlerFailure_Unspecified(t *testing.T) { + original := &failurepb.Failure{ + Message: "handler error - unspecified retry", + StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "UnspecifiedHandlerError", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_UNSPECIFIED, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_NexusHandlerFailure_WithAttributes(t *testing.T) { + original := &failurepb.Failure{ + Message: "handler error with attributes", + StackTrace: "handler stack trace", + FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ + NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ + Type: "ComplexHandlerError", + RetryBehavior: enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE, + }, + }, + EncodedAttributes: mustToPayload(t, "encoded attributes"), + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_WithNestedCauses(t *testing.T) { + original := &failurepb.Failure{ + 
Message: "top level failure", + StackTrace: "top stack trace", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "TopLevelError", + }, + }, + Cause: &failurepb.Failure{ + Message: "middle failure", + StackTrace: "middle stack trace", + FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ + TimeoutFailureInfo: &failurepb.TimeoutFailureInfo{ + TimeoutType: enumspb.TIMEOUT_TYPE_START_TO_CLOSE, + }, + }, + Cause: &failurepb.Failure{ + Message: "root cause", + StackTrace: "root stack trace", + FailureInfo: &failurepb.Failure_ServerFailureInfo{ + ServerFailureInfo: &failurepb.ServerFailureInfo{ + NonRetryable: true, + }, + }, + }, + }, + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_EmptyFailure(t *testing.T) { + original := &failurepb.Failure{ + Message: "simple message", + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + protorequire.ProtoEqual(t, original, converted) +} + +func TestRoundTrip_OnlyStackTrace(t *testing.T) { + original := &failurepb.Failure{ + StackTrace: "just a stack trace", + } + + nexusFailure, err := TemporalFailureToNexusFailure(original) + require.NoError(t, err) + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + protorequire.ProtoEqual(t, original, converted) +} + +func TestFromOperationFailedError(t *testing.T) { + nexusFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(&nexus.OperationError{ + State: nexus.OperationStateFailed, + Message: "operation failed", + StackTrace: "stack trace", + }) + require.NoError(t, err) + cause, err := TemporalFailureToNexusFailure( 
+ temporal.GetDefaultFailureConverter().ErrorToFailure( + temporal.NewApplicationError("app err", "CustomError", "details"), + ), + ) + require.NoError(t, err) + nexusFailure.Cause = &cause + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + expected := &failurepb.Failure{ + Message: "operation failed", + StackTrace: "stack trace", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + NonRetryable: true, + Type: "OperationError", + }, + }, + Cause: &failurepb.Failure{ + Message: "app err", + Source: "GoSDK", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "CustomError", + Details: &commonpb.Payloads{ + Payloads: []*commonpb.Payload{mustToPayload(t, "details")}, + }, + }, + }, + }, + } + protorequire.ProtoEqual(t, expected, converted) +} + +func TestFromOperationCanceledError(t *testing.T) { + nexusFailure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(&nexus.OperationError{ + State: nexus.OperationStateCanceled, + Message: "operation canceled", + StackTrace: "stack trace", + }) + require.NoError(t, err) + cause, err := TemporalFailureToNexusFailure( + temporal.GetDefaultFailureConverter().ErrorToFailure( + temporal.NewApplicationError("app err", "CustomError", "details"), + ), + ) + require.NoError(t, err) + nexusFailure.Cause = &cause + + converted, err := NexusFailureToTemporalFailure(nexusFailure) + require.NoError(t, err) + + expected := &failurepb.Failure{ + Message: "operation canceled", + StackTrace: "stack trace", + FailureInfo: &failurepb.Failure_CanceledFailureInfo{ + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, + }, + Cause: &failurepb.Failure{ + Message: "app err", + Source: "GoSDK", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Type: "CustomError", + Details: 
&commonpb.Payloads{ + Payloads: []*commonpb.Payload{mustToPayload(t, "details")}, + }, + }, + }, + }, + } + protorequire.ProtoEqual(t, expected, converted) +} diff --git a/common/nexus/nexusrpc/api.go b/common/nexus/nexusrpc/api.go index ae20c54f61..5743725633 100644 --- a/common/nexus/nexusrpc/api.go +++ b/common/nexus/nexusrpc/api.go @@ -33,8 +33,8 @@ const contentTypeJSON = "application/json" // Query param for passing a callback URL. const queryCallbackURL = "callback" -// HTTP status code for failed operation responses. -const statusOperationFailed = http.StatusFailedDependency +// HTTP status code for unsuccessful (failed or canceled) operation responses. +const statusOperationUnsuccessful = http.StatusFailedDependency func isMediaTypeJSON(contentType string) bool { if contentType == "" { @@ -273,3 +273,17 @@ func ParseDuration(value string) (time.Duration, error) { func FormatDuration(d time.Duration) string { return strconv.FormatInt(d.Milliseconds(), 10) + "ms" } + +// MarkAsWrapperError adds the "unwrap-error" metadata to the original failure of the given OperationError, which +// signals the Temporal codebase to unwrap the underlying failure as the failure cause. +// This is used as a shim for Temporal->Temporal calls. Temporal already has a Failure type that represents an +// OperationError and does not need to record this wrapper error. 
+func MarkAsWrapperError(failureConverter FailureConverter, opErr *nexus.OperationError) error { + originalFailure, err := failureConverter.ErrorToFailure(opErr) + if err != nil { + return err + } + originalFailure.Metadata["unwrap-error"] = "true" + opErr.OriginalFailure = &originalFailure + return nil +} diff --git a/common/nexus/nexusrpc/cancel_test.go b/common/nexus/nexusrpc/cancel_test.go index c33e13d109..8ae2cd32e3 100644 --- a/common/nexus/nexusrpc/cancel_test.go +++ b/common/nexus/nexusrpc/cancel_test.go @@ -23,19 +23,19 @@ func (h *asyncWithCancelHandler) StartOperation(ctx context.Context, service, op func (h *asyncWithCancelHandler) CancelOperation(ctx context.Context, service, operation, token string, options nexus.CancelOperationOptions) error { if service != testService { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected service: %s", service) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected service: %s", service) } if operation != "f/o/o" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to be 'foo', got: %s", operation) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to be 'foo', got: %s", operation) } if token != "a/sync" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation ID to be 'async', got: %s", token) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation ID to be 'async', got: %s", token) } if h.expectHeader && options.Header.Get("foo") != "bar" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' request header") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' request header") } if options.Header.Get("User-Agent") != "temporalio/server" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) + return 
nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) } return nil } @@ -78,14 +78,14 @@ func (h *echoTimeoutAsyncWithCancelHandler) StartOperation(ctx context.Context, func (h *echoTimeoutAsyncWithCancelHandler) CancelOperation(ctx context.Context, service, operation, token string, options nexus.CancelOperationOptions) error { deadline, set := ctx.Deadline() if h.expectedTimeout > 0 && !set { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to have timeout set but context has no deadline") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to have timeout set but context has no deadline") } if h.expectedTimeout <= 0 && set { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to have no timeout but context has deadline set") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected operation to have no timeout but context has deadline set") } timeout := time.Until(deadline) if timeout > h.expectedTimeout { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation has timeout (%s) greater than expected (%s)", timeout.String(), h.expectedTimeout.String()) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation has timeout (%s) greater than expected (%s)", timeout.String(), h.expectedTimeout.String()) } return nil } diff --git a/common/nexus/nexusrpc/client.go b/common/nexus/nexusrpc/client.go index 126c1f4815..aad46be6db 100644 --- a/common/nexus/nexusrpc/client.go +++ b/common/nexus/nexusrpc/client.go @@ -30,7 +30,7 @@ type HTTPClientOptions struct { Serializer nexus.Serializer // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to // [DefaultFailureConverter]. - FailureConverter nexus.FailureConverter + FailureConverter FailureConverter } // User-Agent header set on HTTP requests. 
@@ -74,6 +74,66 @@ func newUnexpectedResponseError(message string, response *http.Response, body [] } } +type baseHTTPClient struct { + // A function for making HTTP requests. + // Defaults to [http.DefaultClient.Do]. + httpCaller func(*http.Request) (*http.Response, error) + // A [serializer] to customize client serialization behavior. + // By default the client handles JSONables, byte slices, and nil. + serializer nexus.Serializer + // A [failureConverter] to convert a [Failure] instance to and from an [error]. Defaults to + // [DefaultFailureConverter]. + failureConverter FailureConverter +} + +func (c *baseHTTPClient) failureFromResponse(response *http.Response, body []byte) (nexus.Failure, error) { + if !isMediaTypeJSON(response.Header.Get("Content-Type")) { + return nexus.Failure{}, newUnexpectedResponseError(fmt.Sprintf("invalid response content type: %q", response.Header.Get("Content-Type")), response, body) + } + var failure nexus.Failure + err := json.Unmarshal(body, &failure) + return failure, err +} + +func (c *baseHTTPClient) defaultErrorFromResponse(response *http.Response, body []byte, cause error) error { + errorType, err := httpStatusCodeToHandlerErrorType(response) + if err != nil { + // TODO(bergundy): optimization - use the provided cause, it's already a deserialized failure. + return newUnexpectedResponseError(err.Error(), response, body) + } + handlerErr := &nexus.HandlerError{ + Type: errorType, + // For compatibility with older servers. + RetryBehavior: retryBehaviorFromHeader(response.Header), + Cause: cause, + } + + // Ensure we the original failure is available, the calling code expects it. 
+ originalFailure, err := c.failureConverter.ErrorToFailure(handlerErr) + if err != nil { + return newUnexpectedResponseError("failed to construct handler error from response: "+err.Error(), response, body) + } + handlerErr.OriginalFailure = &originalFailure + return handlerErr +} + +// BestEffortHandlerErrorFromResponse attempts to read a handler error from the response, but falls back to a default error. +// This method is exposed as a workaround because the functionality is required for completion. +func (c *baseHTTPClient) bestEffortHandlerErrorFromResponse(response *http.Response, body []byte) error { + failure, err := c.failureFromResponse(response, body) + if err != nil { + return c.defaultErrorFromResponse(response, body, nil) + } + convErr, err := c.failureConverter.FailureToError(failure) + if err != nil { + return newUnexpectedResponseError(fmt.Sprintf("failed to convert Failure to error: %s", err.Error()), response, body) + } + if _, ok := convErr.(*nexus.HandlerError); !ok { + convErr = c.defaultErrorFromResponse(response, body, convErr) + } + return convErr +} + // An HTTPClient makes Nexus service requests as defined in the [Nexus HTTP API]. // // It can start a new operation and get an [OperationHandle] to an existing, asynchronous operation. @@ -85,8 +145,8 @@ func newUnexpectedResponseError(message string, response *http.Response, body [] // // [Nexus HTTP API]: https://github.com/nexus-rpc/api type HTTPClient struct { - // The options this client was created with after applying defaults. 
- options HTTPClientOptions + baseHTTPClient + service string serviceBaseURL *url.URL } @@ -115,12 +175,17 @@ func NewHTTPClient(options HTTPClientOptions) (*HTTPClient, error) { options.Serializer = nexus.DefaultSerializer() } if options.FailureConverter == nil { - options.FailureConverter = nexus.DefaultFailureConverter() + options.FailureConverter = DefaultFailureConverter() } return &HTTPClient{ - options: options, + baseHTTPClient: baseHTTPClient{ + serializer: options.Serializer, + failureConverter: options.FailureConverter, + httpCaller: options.HTTPCaller, + }, serviceBaseURL: baseURL, + service: options.Service, }, nil } @@ -174,7 +239,7 @@ func (c *HTTPClient) StartOperation( content, ok := input.(*nexus.Content) if !ok { var err error - content, err = c.options.Serializer.Serialize(input) + content, err = c.serializer.Serialize(input) if err != nil { return nil, err } @@ -192,7 +257,7 @@ func (c *HTTPClient) StartOperation( } } - url := c.serviceBaseURL.JoinPath(url.PathEscape(c.options.Service), url.PathEscape(operation)) + url := c.serviceBaseURL.JoinPath(url.PathEscape(c.service), url.PathEscape(operation)) if options.CallbackURL != "" { q := url.Query() @@ -219,8 +284,10 @@ func (c *HTTPClient) StartOperation( } addContextTimeoutToHTTPHeader(ctx, request.Header) addNexusHeaderToHTTPHeader(options.Header, request.Header) + // If this request is handled by a newer server that supports Nexus failure serialization, trigger that behavior. 
+ request.Header.Set("temporal-nexus-failure-support", "true") - response, err := c.options.HTTPCaller(request) + response, err := c.httpCaller(request) if err != nil { return nil, err } @@ -247,7 +314,7 @@ func (c *HTTPClient) StartOperation( if response.StatusCode == http.StatusOK { return &ClientStartOperationResponse[*nexus.LazyValue]{ Successful: nexus.NewLazyValue( - c.options.Serializer, + c.serializer, &nexus.Reader{ ReadCloser: response.Body, Header: prefixStrippedHTTPHeaderToNexusHeader(response.Header, "content-"), @@ -280,22 +347,35 @@ func (c *HTTPClient) StartOperation( Pending: handle, Links: links, }, nil - case statusOperationFailed: - state, err := getUnsuccessfulStateFromHeader(response, body) + case statusOperationUnsuccessful: + failure, err := c.failureFromResponse(response, body) if err != nil { return nil, err } - failure, err := c.failureFromResponse(response, body) + wireErr, err := c.failureConverter.FailureToError(failure) if err != nil { return nil, err } - failureErr := c.options.FailureConverter.FailureToError(failure) - return nil, &nexus.OperationError{ - State: state, - Cause: failureErr, + // For compatibility with older servers. 
+ if _, ok := wireErr.(*nexus.OperationError); !ok { + state, err := getUnsuccessfulStateFromHeader(response, body) + if err != nil { + return nil, err + } + opErr := &nexus.OperationError{ + State: state, + Message: "nexus operation completed unsuccessfully", + Cause: wireErr, + } + if err := MarkAsWrapperError(c.failureConverter, opErr); err != nil { + return nil, fmt.Errorf("failed to mark operation error as wrapper error: %w", err) + } + wireErr = opErr } + + return nil, wireErr default: return nil, c.bestEffortHandlerErrorFromResponse(response, body) } @@ -346,99 +426,32 @@ func operationInfoFromResponse(response *http.Response, body []byte) (*nexus.Ope return &info, nil } -func (c *HTTPClient) failureFromResponse(response *http.Response, body []byte) (nexus.Failure, error) { - if !isMediaTypeJSON(response.Header.Get("Content-Type")) { - return nexus.Failure{}, newUnexpectedResponseError(fmt.Sprintf("invalid response content type: %q", response.Header.Get("Content-Type")), response, body) - } - var failure nexus.Failure - err := json.Unmarshal(body, &failure) - return failure, err -} - -func (c *HTTPClient) failureFromResponseOrDefault(response *http.Response, body []byte, defaultMessage string) nexus.Failure { - failure, err := c.failureFromResponse(response, body) - if err != nil { - failure.Message = defaultMessage - } - return failure -} - -func (c *HTTPClient) failureErrorFromResponseOrDefault(response *http.Response, body []byte, defaultMessage string) error { - failure := c.failureFromResponseOrDefault(response, body, defaultMessage) - failureErr := c.options.FailureConverter.FailureToError(failure) - return failureErr -} - -func (c *HTTPClient) bestEffortHandlerErrorFromResponse(response *http.Response, body []byte) error { +func httpStatusCodeToHandlerErrorType(response *http.Response) (nexus.HandlerErrorType, error) { switch response.StatusCode { case http.StatusBadRequest: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeBadRequest, - 
Cause: c.failureErrorFromResponseOrDefault(response, body, "bad request"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } - case http.StatusUnauthorized: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnauthenticated, - Cause: c.failureErrorFromResponseOrDefault(response, body, "unauthenticated"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeBadRequest, nil case http.StatusRequestTimeout: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeRequestTimeout, - Cause: c.failureErrorFromResponseOrDefault(response, body, "request timeout"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeRequestTimeout, nil case http.StatusConflict: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeConflict, - Cause: c.failureErrorFromResponseOrDefault(response, body, "conflict"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeConflict, nil + case http.StatusUnauthorized: + return nexus.HandlerErrorTypeUnauthenticated, nil case http.StatusForbidden: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnauthorized, - Cause: c.failureErrorFromResponseOrDefault(response, body, "unauthorized"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeUnauthorized, nil case http.StatusNotFound: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeNotFound, - Cause: c.failureErrorFromResponseOrDefault(response, body, "not found"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeNotFound, nil case http.StatusTooManyRequests: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeResourceExhausted, - Cause: c.failureErrorFromResponseOrDefault(response, body, "resource exhausted"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeResourceExhausted, nil case 
http.StatusInternalServerError: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeInternal, - Cause: c.failureErrorFromResponseOrDefault(response, body, "internal error"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeInternal, nil case http.StatusNotImplemented: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeNotImplemented, - Cause: c.failureErrorFromResponseOrDefault(response, body, "not implemented"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeNotImplemented, nil case http.StatusServiceUnavailable: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUnavailable, - Cause: c.failureErrorFromResponseOrDefault(response, body, "unavailable"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeUnavailable, nil case nexus.StatusUpstreamTimeout: - return &nexus.HandlerError{ - Type: nexus.HandlerErrorTypeUpstreamTimeout, - Cause: c.failureErrorFromResponseOrDefault(response, body, "upstream timeout"), - RetryBehavior: retryBehaviorFromHeader(response.Header), - } + return nexus.HandlerErrorTypeUpstreamTimeout, nil default: - return newUnexpectedResponseError(fmt.Sprintf("unexpected response status: %q", response.Status), response, body) + return "", fmt.Errorf("unexpected response status: %q", response.Status) } } diff --git a/common/nexus/nexusrpc/completion.go b/common/nexus/nexusrpc/completion.go index 9c5a49df64..581be0032e 100644 --- a/common/nexus/nexusrpc/completion.go +++ b/common/nexus/nexusrpc/completion.go @@ -6,170 +6,92 @@ import ( "encoding/json" "io" "log/slog" - "maps" "net/http" - "strconv" "time" "github.com/nexus-rpc/sdk-go/nexus" ) -// NewCompletionHTTPRequest creates an HTTP request that delivers an operation completion to a given URL. 
-func NewCompletionHTTPRequest(ctx context.Context, url string, completion OperationCompletion) (*http.Request, error) { - httpReq, err := http.NewRequestWithContext(ctx, "POST", url, nil) - if err != nil { - return nil, err - } - if err := completion.applyToHTTPRequest(httpReq); err != nil { - return nil, err - } - - httpReq.Header.Set(headerUserAgent, userAgent) - return httpReq, nil -} - -// OperationCompletion is input for [NewCompletionHTTPRequest]. -// It has two implementations: [OperationCompletionSuccessful] and [OperationCompletionUnsuccessful]. -type OperationCompletion interface { - applyToHTTPRequest(*http.Request) error -} - -// OperationCompletionSuccessful is input for [NewCompletionHTTPRequest], used to deliver successful operation results. -type OperationCompletionSuccessful struct { - // Header to send in the completion request. - // Note that this is a Nexus header, not an HTTP header. - Header nexus.Header - - // A [Reader] that may be directly set on the completion or constructed when instantiating via - // [NewOperationCompletionSuccessful]. - // Automatically closed when the completion is delivered. - Reader *nexus.Reader - // OperationToken is the unique token for this operation. Used when a completion callback is received before a - // started response. - OperationToken string - // StartTime is the time the operation started. Used when a completion callback is received before a started response. - StartTime time.Time - // CloseTime is the time the operation completed. Used when a completion callback is received before a started response. - CloseTime time.Time - // Links are used to link back to the operation when a completion callback is received before a started response. - Links []nexus.Link +// CompletionHTTPClient is a client for sending Nexus operation completion callbacks via HTTP. +type CompletionHTTPClient struct { + baseHTTPClient } -// OperationCompletionSuccessfulOptions are options for [NewOperationCompletionSuccessful]. 
-type OperationCompletionSuccessfulOptions struct { - // Optional serializer for the result. Defaults to the SDK's default Serializer, which handles JSONables, byte - // slices and nils. +// CompletionHTTPClientOptions are options for [NewCompletionHTTPClient]. +type CompletionHTTPClientOptions struct { + // A function for making HTTP requests. + // Defaults to [http.DefaultClient.Do]. + HTTPCaller func(*http.Request) (*http.Response, error) + // A [serializer] to customize client serialization behavior. + // By default the client handles JSONables, byte slices, and nil. Serializer nexus.Serializer - // OperationToken is the unique token for this operation. Used when a completion callback is received before a - // started response. - OperationToken string - // StartTime is the time the operation started. Used when a completion callback is received before a started response. - StartTime time.Time - // CloseTime is the time the operation completed. Used when a completion callback is received before a started response. - CloseTime time.Time - // Links are used to link back to the operation when a completion callback is received before a started response. - Links []nexus.Link + // A [failureConverter] to convert a [Failure] instance to and from an [error]. Defaults to + // [DefaultFailureConverter]. + FailureConverter FailureConverter } -// NewOperationCompletionSuccessful constructs an [OperationCompletionSuccessful] from a given result. 
-func NewOperationCompletionSuccessful(result any, options OperationCompletionSuccessfulOptions) (*OperationCompletionSuccessful, error) { - reader, ok := result.(*nexus.Reader) - if !ok { - content, ok := result.(*nexus.Content) - if !ok { - serializer := options.Serializer - if serializer == nil { - serializer = nexus.DefaultSerializer() - } - var err error - content, err = serializer.Serialize(result) - if err != nil { - return nil, err - } - } - header := maps.Clone(content.Header) - if header == nil { - header = make(nexus.Header, 1) - } - header["length"] = strconv.Itoa(len(content.Data)) - - reader = &nexus.Reader{ - Header: header, - ReadCloser: io.NopCloser(bytes.NewReader(content.Data)), - } +// NewCompletionHTTPClient constructs a [CompletionHTTPClient] from given options for sending Nexus operation completion +// callbacks via HTTP. +func NewCompletionHTTPClient(options CompletionHTTPClientOptions) *CompletionHTTPClient { + if options.HTTPCaller == nil { + options.HTTPCaller = http.DefaultClient.Do } - - return &OperationCompletionSuccessful{ - Header: make(nexus.Header), - Reader: reader, - OperationToken: options.OperationToken, - StartTime: options.StartTime, - CloseTime: options.CloseTime, - Links: options.Links, - }, nil -} - -func (c *OperationCompletionSuccessful) applyToHTTPRequest(request *http.Request) error { - if request.Header == nil { - request.Header = make(http.Header, len(c.Header)+len(c.Reader.Header)+1) // +1 for headerOperationState + if options.Serializer == nil { + options.Serializer = nexus.DefaultSerializer() } - if c.Reader.Header != nil { - addContentHeaderToHTTPHeader(c.Reader.Header, request.Header) + if options.FailureConverter == nil { + options.FailureConverter = DefaultFailureConverter() } - if c.Header != nil { - addNexusHeaderToHTTPHeader(c.Header, request.Header) + return &CompletionHTTPClient{ + baseHTTPClient: baseHTTPClient{ + httpCaller: options.HTTPCaller, + serializer: options.Serializer, + failureConverter: 
options.FailureConverter, + }, } - request.Header.Set(headerOperationState, string(nexus.OperationStateSucceeded)) +} - if c.Header.Get(nexus.HeaderOperationToken) == "" && c.OperationToken != "" { - request.Header.Set(nexus.HeaderOperationToken, c.OperationToken) +// CompleteOperation sends a completion callback for a Nexus operation to the given URL with the given completion details. +func (c *CompletionHTTPClient) CompleteOperation(ctx context.Context, url string, completion CompleteOperationOptions) error { + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, nil) + if err != nil { + return err } - if c.Header.Get(headerOperationStartTime) == "" && !c.StartTime.IsZero() { - request.Header.Set(headerOperationStartTime, c.StartTime.Format(http.TimeFormat)) + if err := completion.applyToHTTPRequest(c, httpReq); err != nil { + return err } - if c.Header.Get(headerOperationCloseTime) == "" && !c.CloseTime.IsZero() { - request.Header.Set(headerOperationCloseTime, marshalTimestamp(c.CloseTime)) + + response, err := c.httpCaller(httpReq) + if err != nil { + return err } - if c.Header.Get(headerLink) == "" { - if err := addLinksToHTTPHeader(c.Links, request.Header); err != nil { - return err + + if response.StatusCode >= 200 && response.StatusCode < 300 { + // Body is not read but should be discarded to keep the underlying TCP connection alive. + // Just in case something unexpected happens while discarding or closing the body, + // propagate errors to the machine. + if _, err = io.Copy(io.Discard, response.Body); err == nil { + if err = response.Body.Close(); err != nil { + return err + } } + return nil } - request.Body = c.Reader.ReadCloser - return nil + body, err := readAndReplaceBody(response) + if err != nil { + return err + } + + return c.bestEffortHandlerErrorFromResponse(response, body) } -// OperationCompletionUnsuccessful is input for [NewCompletionHTTPRequest], used to deliver unsuccessful operation -// results. 
-type OperationCompletionUnsuccessful struct { +// CompleteOperationOptions is input for CompleteOperation. +// If Error is set, the completion is unsuccessful. Otherwise, it is successful. +type CompleteOperationOptions struct { // Header to send in the completion request. // Note that this is a Nexus header, not an HTTP header. Header nexus.Header - // State of the operation, should be failed or canceled. - State nexus.OperationState - // OperationToken is the unique token for this operation. Used when a completion callback is received before a - // started response. - OperationToken string - // StartTime is the time the operation started. Used when a completion callback is received before a started response. - StartTime time.Time - // CloseTime is the time the operation completed. This may be different from the time the completion callback is delivered. - CloseTime time.Time - // Links are used to link back to the operation when a completion callback is received before a started response. - Links []nexus.Link - // Failure object to send with the completion. - Failure nexus.Failure -} - -// OperationCompletionUnsuccessfulOptions are options for [NewOperationCompletionUnsuccessful]. -type OperationCompletionUnsuccessfulOptions struct { - // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to - // [DefaultFailureConverter]. - FailureConverter nexus.FailureConverter - // OperationID is the unique ID for this operation. Used when a completion callback is received before a started response. - // - // Deprecated: Use OperatonToken instead. - OperationID string // OperationToken is the unique token for this operation. Used when a completion callback is received before a // started response. OperationToken string @@ -179,36 +101,68 @@ type OperationCompletionUnsuccessfulOptions struct { CloseTime time.Time // Links are used to link back to the operation when a completion callback is received before a started response. 
Links []nexus.Link + // Error to send with the completion. If set, the completion is unsuccessful. + Error *nexus.OperationError + // Result to deliver with the completion. Uses the client's serializer to serialize the result into the request body. + // Only used for successful completions. + Result any } -// NewOperationCompletionUnsuccessful constructs an [OperationCompletionUnsuccessful] from a given error. -func NewOperationCompletionUnsuccessful(opErr *nexus.OperationError, options OperationCompletionUnsuccessfulOptions) (*OperationCompletionUnsuccessful, error) { - if options.FailureConverter == nil { - options.FailureConverter = nexus.DefaultFailureConverter() +// nolint:revive // This method is long but it's more readable to keep the logic in one place since it's all related to +// constructing the completion request. +func (c CompleteOperationOptions) applyToHTTPRequest(cc *CompletionHTTPClient, request *http.Request) error { + if request.Header == nil { + request.Header = make(http.Header) } - failure := options.FailureConverter.ErrorToFailure(opErr.Cause) - - return &OperationCompletionUnsuccessful{ - Header: make(nexus.Header), - State: opErr.State, - Failure: failure, - OperationToken: options.OperationToken, - StartTime: options.StartTime, - CloseTime: options.CloseTime, - Links: options.Links, - }, nil -} -func (c *OperationCompletionUnsuccessful) applyToHTTPRequest(request *http.Request) error { - if request.Header == nil { - request.Header = make(http.Header, len(c.Header)+2) // +2 for headerOperationState and content-type + // Set the body and operation state based on whether the completion is successful or not. + if c.Error != nil { + failure, err := cc.failureConverter.ErrorToFailure(c.Error) + if err != nil { + return err + } + // Backwards compatibility: if the failure has a cause, unwrap it to maintain the behavior as older servers. 
+ if failure.Cause != nil { + failure = *failure.Cause + } + b, err := json.Marshal(failure) + if err != nil { + return err + } + request.Body = io.NopCloser(bytes.NewReader(b)) + // Set the operation state header for backwards compatibility. + request.Header.Set(headerOperationState, string(c.Error.State)) + request.Header.Set("Content-Type", contentTypeJSON) + } else { + reader, ok := c.Result.(*nexus.Reader) + if !ok { + content, ok := c.Result.(*nexus.Content) + if !ok { + var err error + content, err = cc.serializer.Serialize(c.Result) + if err != nil { + return err + } + } + request.ContentLength = int64(len(content.Data)) + reader = &nexus.Reader{ + Header: content.Header, + ReadCloser: io.NopCloser(bytes.NewReader(content.Data)), + } + } + if reader.Header != nil { + addContentHeaderToHTTPHeader(reader.Header, request.Header) + } + request.Body = reader.ReadCloser + request.Header.Set(headerOperationState, string(nexus.OperationStateSucceeded)) } + if c.Header != nil { addNexusHeaderToHTTPHeader(c.Header, request.Header) } - request.Header.Set(headerOperationState, string(c.State)) - request.Header.Set("Content-Type", contentTypeJSON) - + if c.Header.Get(headerUserAgent) == "" { + request.Header.Set(headerUserAgent, userAgent) + } if c.Header.Get(nexus.HeaderOperationToken) == "" && c.OperationToken != "" { request.Header.Set(nexus.HeaderOperationToken, c.OperationToken) } @@ -223,13 +177,6 @@ func (c *OperationCompletionUnsuccessful) applyToHTTPRequest(request *http.Reque return err } } - - b, err := json.Marshal(c.Failure) - if err != nil { - return err - } - - request.Body = io.NopCloser(bytes.NewReader(b)) return nil } @@ -249,7 +196,7 @@ type CompletionRequest struct { // Links are used to link back to the operation when a completion callback is received before a started response. Links []nexus.Link // Parsed from request and set if State is failed or canceled. 
- Error error + Error *nexus.OperationError // Extracted from request and set if State is succeeded. Result *nexus.LazyValue } @@ -272,11 +219,11 @@ type CompletionHandlerOptions struct { Serializer nexus.Serializer // A [FailureConverter] to convert a [Failure] instance to and from an [error]. Defaults to // [DefaultFailureConverter]. - FailureConverter nexus.FailureConverter + FailureConverter FailureConverter } type completionHTTPHandler struct { - baseHTTPHandler + BaseHTTPHandler options CompletionHandlerOptions } @@ -290,39 +237,58 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h if startTimeHeader := request.Header.Get(headerOperationStartTime); startTimeHeader != "" { var parseTimeErr error if completion.StartTime, parseTimeErr = http.ParseTime(startTimeHeader); parseTimeErr != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse operation start time header")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse operation start time header")) return } } if closeTimeHeader := request.Header.Get(headerOperationCloseTime); closeTimeHeader != "" { var parseTimeErr error if completion.CloseTime, parseTimeErr = unmarshalTimestamp(closeTimeHeader); parseTimeErr != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse operation close time header")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse operation close time header")) return } } var decodeErr error if completion.Links, decodeErr = getLinksFromHeader(request.Header); decodeErr != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode links from request headers")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode links from request headers")) return } switch 
completion.State { case nexus.OperationStateFailed, nexus.OperationStateCanceled: if !isMediaTypeJSON(request.Header.Get("Content-Type")) { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request content type: %q", request.Header.Get("Content-Type"))) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request content type: %q", request.Header.Get("Content-Type"))) return } var failure nexus.Failure b, err := io.ReadAll(request.Body) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to read Failure from request body")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to read Failure from request body")) return } if err := json.Unmarshal(b, &failure); err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to read Failure from request body")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode Failure from request body")) + return + } + completionErr, err := h.FailureConverter.FailureToError(failure) + if err != nil { + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode Failure from request body")) return } - completion.Error = h.failureConverter.FailureToError(failure) + opErr, ok := completionErr.(*nexus.OperationError) + if !ok { + // Backwards compatibility: wrap non-OperationError errors in an OperationError with the appropriate state. 
+ completion.Error = &nexus.OperationError{ + Message: "nexus operation completed unsuccessfully", + State: completion.State, + Cause: completionErr, + } + if err := MarkAsWrapperError(h.FailureConverter, completion.Error); err != nil { + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to decode Failure from request body")) + return + } + } else { + completion.Error = opErr + } case nexus.OperationStateSucceeded: completion.Result = nexus.NewLazyValue( h.options.Serializer, @@ -332,11 +298,11 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h }, ) default: - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request operation state: %q", completion.State)) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request operation state: %q", completion.State)) return } if err := h.options.Handler.CompleteOperation(ctx, &completion); err != nil { - h.writeFailure(writer, err) + h.WriteFailure(writer, request, err) } } @@ -349,13 +315,13 @@ func NewCompletionHTTPHandler(options CompletionHandlerOptions) http.Handler { options.Serializer = nexus.DefaultSerializer() } if options.FailureConverter == nil { - options.FailureConverter = nexus.DefaultFailureConverter() + options.FailureConverter = DefaultFailureConverter() } return &completionHTTPHandler{ options: options, - baseHTTPHandler: baseHTTPHandler{ - logger: options.Logger, - failureConverter: options.FailureConverter, + BaseHTTPHandler: BaseHTTPHandler{ + Logger: options.Logger, + FailureConverter: options.FailureConverter, }, } } diff --git a/common/nexus/nexusrpc/completion_test.go b/common/nexus/nexusrpc/completion_test.go index 4436634430..92fa7e92e4 100644 --- a/common/nexus/nexusrpc/completion_test.go +++ b/common/nexus/nexusrpc/completion_test.go @@ -3,8 +3,6 @@ package nexusrpc_test import ( "context" "errors" - "io" - "net/http" "net/url" 
"testing" "time" @@ -33,28 +31,28 @@ func validateExpectedTime(expected, actual time.Time, resolution time.Duration) func (h *successfulCompletionHandler) CompleteOperation(ctx context.Context, completion *nexusrpc.CompletionRequest) error { if completion.HTTPRequest.URL.Path != "/callback" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL path: %q", completion.HTTPRequest.URL.Path) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL path: %q", completion.HTTPRequest.URL.Path) } if completion.HTTPRequest.URL.Query().Get("a") != "b" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'a' query param: %q", completion.HTTPRequest.URL.Query().Get("a")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'a' query param: %q", completion.HTTPRequest.URL.Query().Get("a")) } if completion.HTTPRequest.Header.Get("foo") != "bar" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' header: %q", completion.HTTPRequest.Header.Get("foo")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' header: %q", completion.HTTPRequest.Header.Get("foo")) } if completion.HTTPRequest.Header.Get("User-Agent") != "temporalio/server" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", completion.HTTPRequest.Header.Get("User-Agent")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", completion.HTTPRequest.Header.Get("User-Agent")) } if completion.OperationToken != "test-operation-token" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation token: %q", completion.OperationToken) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation token: %q", completion.OperationToken) } if len(completion.Links) == 0 { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected Links to be 
set on CompletionRequest") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected Links to be set on CompletionRequest") } if !validateExpectedTime(h.expectedStartTime, completion.StartTime, time.Second) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected StartTime to be equal") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected StartTime to be equal") } if !validateExpectedTime(h.expectedCloseTime, completion.CloseTime, time.Millisecond) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected CloseTime to be equal") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected CloseTime to be equal") } var result int err := completion.Result.Consume(&result) @@ -62,7 +60,7 @@ func (h *successfulCompletionHandler) CompleteOperation(ctx context.Context, com return err } if result != 666 { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result: %q", result) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result: %q", result) } return nil } @@ -77,7 +75,8 @@ func TestSuccessfulCompletion(t *testing.T) { }, nil, nil) defer teardown() - completion, err := nexusrpc.NewOperationCompletionSuccessful(666, nexusrpc.OperationCompletionSuccessfulOptions{ + completion := nexusrpc.CompleteOperationOptions{ + Result: 666, OperationToken: "test-operation-token", StartTime: startTime, CloseTime: closeTime, @@ -90,19 +89,12 @@ func TestSuccessfulCompletion(t *testing.T) { }, Type: "url", }}, - }) - completion.Header.Set("foo", "bar") - require.NoError(t, err) + Header: nexus.Header{"foo": "bar"}, + } - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}) + err := c.CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) - // 
nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) - require.NoError(t, err) - require.Equal(t, http.StatusOK, response.StatusCode) } func TestSuccessfulCompletion_CustomSerializer(t *testing.T) { @@ -110,8 +102,8 @@ func TestSuccessfulCompletion_CustomSerializer(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &successfulCompletionHandler{}, serializer, nil) defer teardown() - completion, err := nexusrpc.NewOperationCompletionSuccessful(666, nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: serializer, + completion := nexusrpc.CompleteOperationOptions{ + Result: 666, Links: []nexus.Link{{ URL: &url.URL{ Scheme: "https", @@ -121,20 +113,17 @@ func TestSuccessfulCompletion_CustomSerializer(t *testing.T) { }, Type: "url", }}, - }) + Header: nexus.Header{"foo": "bar"}, + } + completion.Header.Set("foo", "bar") completion.Header.Set(nexus.HeaderOperationToken, "test-operation-token") - require.NoError(t, err) - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) - require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: serializer, + }) + err := c.CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) - require.Equal(t, http.StatusOK, response.StatusCode) require.Equal(t, 1, serializer.decoded) require.Equal(t, 1, serializer.encoded) @@ -148,25 +137,25 @@ type failureExpectingCompletionHandler struct { func (h *failureExpectingCompletionHandler) CompleteOperation(ctx context.Context, completion *nexusrpc.CompletionRequest) error { if completion.State != nexus.OperationStateCanceled { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected completion state: %q", completion.State) + return 
nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected completion state: %q", completion.State) } if err := h.errorChecker(completion.Error); err != nil { return err } if completion.HTTPRequest.Header.Get("foo") != "bar" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' header: %q", completion.HTTPRequest.Header.Get("foo")) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'foo' header: %q", completion.HTTPRequest.Header.Get("foo")) } if completion.OperationToken != "test-operation-token" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation token: %q", completion.OperationToken) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation token: %q", completion.OperationToken) } if len(completion.Links) == 0 { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected Links to be set on CompletionRequest") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected Links to be set on CompletionRequest") } if !validateExpectedTime(h.expectedStartTime, completion.StartTime, time.Second) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected StartTime to be equal") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected StartTime to be equal") } if !validateExpectedTime(h.expectedCloseTime, completion.CloseTime, time.Millisecond) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected CloseTime to be equal") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "expected CloseTime to be equal") } return nil @@ -178,17 +167,18 @@ func TestFailureCompletion(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &failureExpectingCompletionHandler{ errorChecker: func(err error) error { - if err.Error() != "expected message" { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure: %v", err) + if opErr, ok := 
err.(*nexus.OperationError); ok && opErr.Message == "expected message" { + return nil } - return nil + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure: %v", err) }, expectedStartTime: startTime, expectedCloseTime: closeTime, }, nil, nil) defer teardown() - completion, err := nexusrpc.NewOperationCompletionUnsuccessful(nexus.NewOperationCanceledError("expected message"), nexusrpc.OperationCompletionUnsuccessfulOptions{ + completion := nexusrpc.CompleteOperationOptions{ + Error: nexus.NewOperationCanceledErrorf("expected message"), OperationToken: "test-operation-token", StartTime: startTime, CloseTime: closeTime, @@ -201,18 +191,11 @@ func TestFailureCompletion(t *testing.T) { }, Type: "url", }}, - }) - require.NoError(t, err) - completion.Header.Set("foo", "bar") - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) - require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) + Header: nexus.Header{"foo": "bar"}, + } + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}) + err := c.CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) - require.Equal(t, http.StatusOK, response.StatusCode) } func TestFailureCompletion_CustomFailureConverter(t *testing.T) { @@ -223,7 +206,7 @@ func TestFailureCompletion_CustomFailureConverter(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &failureExpectingCompletionHandler{ errorChecker: func(err error) error { if !errors.Is(err, errCustom) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure, expected a custom error: %v", err) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid failure, expected a custom error: %v", err) } return nil }, @@ -232,11 +215,11 @@ func TestFailureCompletion_CustomFailureConverter(t *testing.T) { }, nil, 
fc) defer teardown() - completion, err := nexusrpc.NewOperationCompletionUnsuccessful(nexus.NewOperationCanceledError("expected message"), nexusrpc.OperationCompletionUnsuccessfulOptions{ - FailureConverter: fc, - OperationToken: "test-operation-token", - StartTime: startTime, - CloseTime: closeTime, + completion := nexusrpc.CompleteOperationOptions{ + Error: nexus.NewOperationCanceledErrorf("expected message"), + OperationToken: "test-operation-token", + StartTime: startTime, + CloseTime: closeTime, Links: []nexus.Link{{ URL: &url.URL{ Scheme: "https", @@ -246,40 +229,32 @@ func TestFailureCompletion_CustomFailureConverter(t *testing.T) { }, Type: "url", }}, + Header: nexus.Header{"foo": "bar"}, + } + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + FailureConverter: fc, }) + err := c.CompleteOperation(ctx, callbackURL, completion) require.NoError(t, err) - completion.Header.Set("foo", "bar") - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) - require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) - require.NoError(t, err) - require.Equal(t, http.StatusOK, response.StatusCode) } type failingCompletionHandler struct { } func (h *failingCompletionHandler) CompleteOperation(ctx context.Context, completion *nexusrpc.CompletionRequest) error { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "I can't get no satisfaction") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "I can't get no satisfaction") } func TestBadRequestCompletion(t *testing.T) { ctx, callbackURL, teardown := setupForCompletion(t, &failingCompletionHandler{}, nil, nil) defer teardown() - completion, err := nexusrpc.NewOperationCompletionSuccessful([]byte("success"), nexusrpc.OperationCompletionSuccessfulOptions{}) - require.NoError(t, err) - request, err := 
nexusrpc.NewCompletionHTTPRequest(ctx, callbackURL, completion) - require.NoError(t, err) - response, err := http.DefaultClient.Do(request) - require.NoError(t, err) - // nolint:errcheck - defer response.Body.Close() - _, err = io.ReadAll(response.Body) - require.NoError(t, err) - require.Equal(t, http.StatusBadRequest, response.StatusCode) + completion := nexusrpc.CompleteOperationOptions{ + Result: []byte("success"), + } + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{}) + err := c.CompleteOperation(ctx, callbackURL, completion) + var handlerErr *nexus.HandlerError + require.ErrorAs(t, err, &handlerErr) + require.Equal(t, nexus.HandlerErrorTypeBadRequest, handlerErr.Type) } diff --git a/common/nexus/nexusrpc/failure_converter.go b/common/nexus/nexusrpc/failure_converter.go new file mode 100644 index 0000000000..befdaec8d1 --- /dev/null +++ b/common/nexus/nexusrpc/failure_converter.go @@ -0,0 +1,211 @@ +package nexusrpc + +import ( + "encoding/json" + "fmt" + + "github.com/nexus-rpc/sdk-go/nexus" +) + +// FailureConverter is used by the framework to transform [error] instances to and from [Failure] instances. +// To customize conversion logic, implement this interface and provide your implementation to framework methods such as +// [NewHTTPClient] and [NewHTTPHandler]. +// By default the SDK translates [HandlerError], [OperationError] and [FailureError] to and from [Failure] +// objects maintaining their cause chain. +// Arbitrary errors are translated to a [Failure] object with its Message set to the Error() string, losing the cause +// chain. +type FailureConverter interface { + // ErrorToFailure converts an [error] to a [Failure]. + // Note that the provided error may be nil. + ErrorToFailure(error) (nexus.Failure, error) + // FailureToError converts a [Failure] to an [error]. 
+ FailureToError(nexus.Failure) (error, error) +} + +type knownErrorFailureConverter struct{} + +type serializedHandlerError struct { + Type string `json:"type,omitempty"` + RetryableOverride *bool `json:"retryableOverride,omitempty"` +} + +func (e serializedHandlerError) RetryBehavior() nexus.HandlerErrorRetryBehavior { + if e.RetryableOverride == nil { + return nexus.HandlerErrorRetryBehaviorUnspecified + } + if *e.RetryableOverride { + return nexus.HandlerErrorRetryBehaviorRetryable + } + return nexus.HandlerErrorRetryBehaviorNonRetryable +} + +type serializedOperationError struct { + State string `json:"state,omitempty"` +} + +// ErrorToFailure implements FailureConverter. +// nolint:revive // Keeping all of the logic together for readability, even if it means the function is long. +func (e knownErrorFailureConverter) ErrorToFailure(err error) (nexus.Failure, error) { + if err == nil { + return nexus.Failure{}, nil + } + // NOTE: not using errors.Unwrap here; we intentionally only support unwrapping known errors. + switch typedErr := err.(type) { + case *nexus.FailureError: + f := typedErr.Failure + // Convert the error cause if set and failure cause is not already set. 
+ if typedErr.Cause != nil && f.Cause == nil { + c, err := e.ErrorToFailure(typedErr.Cause) + if err != nil { + return nexus.Failure{}, err + } + f.Cause = &c + } + return f, nil + case *nexus.HandlerError: + if typedErr.OriginalFailure != nil { + return *typedErr.OriginalFailure, nil + } + data := serializedHandlerError{ + Type: string(typedErr.Type), + RetryableOverride: retryBehaviorAsOptionalBool(typedErr), + } + var details []byte + details, err := json.Marshal(data) + if err != nil { + return nexus.Failure{}, err + } + f := nexus.Failure{ + Message: typedErr.Message, + StackTrace: typedErr.StackTrace, + Metadata: map[string]string{ + "type": "nexus.HandlerError", + }, + Details: details, + } + + if typedErr.Cause != nil { + c, err := e.ErrorToFailure(typedErr.Cause) + if err != nil { + return nexus.Failure{}, err + } + f.Cause = &c + } + return f, nil + case *nexus.OperationError: + if typedErr.OriginalFailure != nil { + return *typedErr.OriginalFailure, nil + } + data := serializedOperationError{ + State: string(typedErr.State), + } + details, err := json.Marshal(data) + if err != nil { + return nexus.Failure{}, err + } + f := nexus.Failure{ + Message: typedErr.Message, + StackTrace: typedErr.StackTrace, + Metadata: map[string]string{ + "type": "nexus.OperationError", + }, + Details: details, + } + + if typedErr.Cause != nil { + c, err := e.ErrorToFailure(typedErr.Cause) + if err != nil { + return nexus.Failure{}, err + } + f.Cause = &c + } + return f, nil + default: + return nexus.Failure{ + Message: typedErr.Error(), + }, nil + } +} + +// FailureToError implements FailureConverter. +// nolint:revive // Keeping all of the logic together for readability, even if it means the function is long. 
+func (e knownErrorFailureConverter) FailureToError(f nexus.Failure) (error, error) { + if f.Metadata != nil { + switch f.Metadata["type"] { + case "nexus.HandlerError": + var se serializedHandlerError + err := json.Unmarshal(f.Details, &se) + if err != nil { + return nil, fmt.Errorf("failed to deserialize HandlerError: %w", err) + } + he := &nexus.HandlerError{ + Message: f.Message, + StackTrace: f.StackTrace, + Type: nexus.HandlerErrorType(se.Type), + RetryBehavior: se.RetryBehavior(), + OriginalFailure: &f, + } + if f.Cause != nil { + he.Cause, err = e.FailureToError(*f.Cause) + if err != nil { + return nil, err + } + } + return he, nil + case "nexus.OperationError": + var se serializedOperationError + err := json.Unmarshal(f.Details, &se) + if err != nil { + return nil, fmt.Errorf("failed to deserialize OperationError: %w", err) + } + oe := &nexus.OperationError{ + Message: f.Message, + StackTrace: f.StackTrace, + State: nexus.OperationState(se.State), + OriginalFailure: &f, + } + if f.Cause != nil { + oe.Cause, err = e.FailureToError(*f.Cause) + if err != nil { + return nil, err + } + } + return oe, nil + } + } + // Note that the original failure cause is retained on the FailureError's failure object. + fe := &nexus.FailureError{Failure: f} + if f.Cause != nil { + c, err := e.FailureToError(*f.Cause) + if err != nil { + return nil, err + } + fe.Cause = c + } + return fe, nil +} + +var defaultFailureConverter FailureConverter = knownErrorFailureConverter{} + +// DefaultFailureConverter returns the SDK's default [FailureConverter] implementation. Translates [HandlerError], +// [OperationError] and [FailureError] to and from [Failure] objects maintaining their cause chain. +// Arbitrary errors are translated to a [Failure] object with its Message set to the Error() string, losing the cause +// chain. +// [Failure] instances are converted to [FailureError] to allow access to the full failure metadata and details if +// available. 
+func DefaultFailureConverter() FailureConverter { + return defaultFailureConverter +} + +func retryBehaviorAsOptionalBool(e *nexus.HandlerError) *bool { + // nolint:exhaustive // this is a simple optional boolean. + switch e.RetryBehavior { + case nexus.HandlerErrorRetryBehaviorRetryable: + ret := true + return &ret + case nexus.HandlerErrorRetryBehaviorNonRetryable: + ret := false + return &ret + } + return nil +} diff --git a/common/nexus/nexusrpc/failure_conveter_test.go b/common/nexus/nexusrpc/failure_conveter_test.go new file mode 100644 index 0000000000..2cb94b9757 --- /dev/null +++ b/common/nexus/nexusrpc/failure_conveter_test.go @@ -0,0 +1,135 @@ +package nexusrpc + +import ( + "errors" + "testing" + + "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/require" +) + +func TestFailureConverter_GenericError(t *testing.T) { + failure, err := defaultFailureConverter.ErrorToFailure(errors.New("test")) + require.NoError(t, err) + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + require.Equal(t, &nexus.FailureError{Failure: nexus.Failure{Message: "test"}}, actual) +} + +func TestFailureConverter_FailureError(t *testing.T) { + cause := &nexus.FailureError{ + Failure: nexus.Failure{Message: "cause"}, + } + fe := &nexus.FailureError{ + Failure: nexus.Failure{ + Message: "foo", + }, + Cause: cause, + } + failure, err := defaultFailureConverter.ErrorToFailure(fe) + require.NoError(t, err) + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + // The serialized failure cause is retained. + fe.Failure.Cause = failure.Cause + require.Equal(t, fe, actual) + + // Serialize again and verify the original failure is used. 
+ fe.Cause = errors.New("should be ignored") + failure, err = defaultFailureConverter.ErrorToFailure(fe) + require.NoError(t, err) + actual, err = defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + // Reset back to the original cause before comparing. + fe.Cause = cause + require.Equal(t, fe, actual) +} + +func TestFailureConverter_HandlerError(t *testing.T) { + cause := &nexus.FailureError{Failure: nexus.Failure{Message: "cause"}} + he := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "foo") + he.StackTrace = "stack" + he.Cause = cause + failure, err := defaultFailureConverter.ErrorToFailure(he) + require.NoError(t, err) + // Verify that the original failure is retained. + he.OriginalFailure = &failure + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + require.Equal(t, he, actual) + + // Serialize again and verify the original failure is used. + he.Cause = errors.New("should be ignored") + failure, err = defaultFailureConverter.ErrorToFailure(he) + require.NoError(t, err) + actual, err = defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + // Reset back to the original cause before comparing. + he.Cause = cause + require.Equal(t, he, actual) +} + +func TestFailureConverter_HandlerErrorRetryBehavior(t *testing.T) { + he := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "foo") + he.StackTrace = "stack" + he.RetryBehavior = nexus.HandlerErrorRetryBehaviorRetryable + failure, err := defaultFailureConverter.ErrorToFailure(he) + require.NoError(t, err) + // Verify that the original failure is retained. 
+ he.OriginalFailure = &failure + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + require.Equal(t, he, actual) +} + +func TestFailureConverter_OperationError(t *testing.T) { + cause := &nexus.FailureError{Failure: nexus.Failure{Message: "cause"}} + oe := nexus.NewOperationCanceledErrorf("foo") + oe.StackTrace = "stack" + oe.Cause = cause + failure, err := defaultFailureConverter.ErrorToFailure(oe) + require.NoError(t, err) + // Verify that the original failure is retained. + oe.OriginalFailure = &failure + actual, err := defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + require.Equal(t, oe, actual) + + // Serialize again and verify the original failure is used. + oe.Cause = errors.New("should be ignored") + failure, err = defaultFailureConverter.ErrorToFailure(oe) + require.NoError(t, err) + actual, err = defaultFailureConverter.FailureToError(failure) + require.NoError(t, err) + // Reset back to the original cause before comparing. 
+ oe.Cause = cause + require.Equal(t, oe, actual) +} + +func TestDefaultFailureConverterArbitraryError(t *testing.T) { + sourceErr := errors.New("test") + conv := defaultFailureConverter + + f, err := conv.ErrorToFailure(sourceErr) + require.NoError(t, err) + convErr, err := conv.FailureToError(f) + require.NoError(t, err) + require.Equal(t, sourceErr.Error(), convErr.Error()) +} + +func TestDefaultFailureConverterFailureError(t *testing.T) { + sourceErr := &nexus.FailureError{ + Failure: nexus.Failure{ + Message: "test", + Metadata: map[string]string{"key": "value"}, + Details: []byte(`"details"`), + }, + } + conv := defaultFailureConverter + + f, err := conv.ErrorToFailure(sourceErr) + require.NoError(t, err) + convErr, err := conv.FailureToError(f) + require.NoError(t, err) + require.Equal(t, sourceErr, convErr) +} diff --git a/common/nexus/nexusrpc/handle.go b/common/nexus/nexusrpc/handle.go index a5f89b8154..ca8f75a9a7 100644 --- a/common/nexus/nexusrpc/handle.go +++ b/common/nexus/nexusrpc/handle.go @@ -22,7 +22,7 @@ type OperationHandle[T any] struct { // // Cancelation is asynchronous and may be not be respected by the operation's implementation. 
func (h *OperationHandle[T]) Cancel(ctx context.Context, options nexus.CancelOperationOptions) error { - u := h.client.serviceBaseURL.JoinPath(url.PathEscape(h.client.options.Service), url.PathEscape(h.Operation), "cancel") + u := h.client.serviceBaseURL.JoinPath(url.PathEscape(h.client.service), url.PathEscape(h.Operation), "cancel") request, err := http.NewRequestWithContext(ctx, "POST", u.String(), nil) if err != nil { return err @@ -31,7 +31,7 @@ func (h *OperationHandle[T]) Cancel(ctx context.Context, options nexus.CancelOpe addContextTimeoutToHTTPHeader(ctx, request.Header) request.Header.Set(headerUserAgent, userAgent) addNexusHeaderToHTTPHeader(options.Header, request.Header) - response, err := h.client.options.HTTPCaller(request) + response, err := h.client.httpCaller(request) if err != nil { return err } diff --git a/common/nexus/nexusrpc/server.go b/common/nexus/nexusrpc/server.go index a4b16ae893..5b0a55a9e0 100644 --- a/common/nexus/nexusrpc/server.go +++ b/common/nexus/nexusrpc/server.go @@ -18,10 +18,10 @@ import ( "github.com/nexus-rpc/sdk-go/nexus" ) -func applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer http.ResponseWriter, handler *httpHandler) { +func applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer http.ResponseWriter, request *http.Request, handler *httpHandler) { switch r := r.(type) { case interface{ ValueAsAny() any }: - handler.writeResult(writer, r.ValueAsAny()) + handler.writeResult(writer, request, r.ValueAsAny()) case *nexus.HandlerStartOperationResultAsync: info := nexus.OperationInfo{ Token: r.OperationToken, @@ -29,7 +29,7 @@ func applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer } b, err := json.Marshal(info) if err != nil { - handler.logger.Error("failed to serialize operation info", "error", err) + handler.Logger.Error("failed to serialize operation info", "error", err) writer.WriteHeader(http.StatusInternalServerError) return } @@ -37,22 +37,22 @@ func 
applyResultToHTTPResponse(r nexus.HandlerStartOperationResult[any], writer writer.Header().Set("Content-Type", contentTypeJSON) writer.WriteHeader(http.StatusCreated) if _, err := writer.Write(b); err != nil { - handler.logger.Error("failed to write response body", "error", err) + handler.Logger.Error("failed to write response body", "error", err) } } } -type baseHTTPHandler struct { - logger *slog.Logger - failureConverter nexus.FailureConverter +type BaseHTTPHandler struct { + Logger *slog.Logger + FailureConverter FailureConverter } type httpHandler struct { - baseHTTPHandler + BaseHTTPHandler options HandlerOptions } -func (h *httpHandler) writeResult(writer http.ResponseWriter, result any) { +func (h *httpHandler) writeResult(writer http.ResponseWriter, request *http.Request, result any) { var reader *nexus.Reader if r, ok := result.(*nexus.Reader); ok { // Close the request body in case we error before sending the HTTP request (which may double close but @@ -66,7 +66,7 @@ func (h *httpHandler) writeResult(writer http.ResponseWriter, result any) { var err error content, err = h.options.Serializer.Serialize(result) if err != nil { - h.writeFailure(writer, fmt.Errorf("failed to serialize handler result: %w", err)) + h.WriteFailure(writer, request, fmt.Errorf("failed to serialize handler result: %w", err)) return } } @@ -85,12 +85,16 @@ func (h *httpHandler) writeResult(writer http.ResponseWriter, result any) { return } if _, err := io.Copy(writer, reader); err != nil { - h.logger.Error("failed to write response body", "error", err) + h.Logger.Error("failed to write response body", "error", err) } } -func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { +// WriteFailure writes the given error to the response, converting it to a Failure if possible. It also sets the HTTP +// status code based on the type of error. +// nolint:revive // Keeping all of the logic together for readability, even if it means the function is long. 
+func (h *BaseHTTPHandler) WriteFailure(writer http.ResponseWriter, r *http.Request, err error) { var failure nexus.Failure + var failureError *nexus.FailureError var opError *nexus.OperationError var handlerError *nexus.HandlerError var operationState nexus.OperationState @@ -98,17 +102,38 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { if errors.As(err, &opError) { operationState = opError.State - failure = h.failureConverter.ErrorToFailure(opError.Cause) - statusCode = statusOperationFailed + var convErr error + failure, convErr = h.FailureConverter.ErrorToFailure(opError) + if convErr != nil { + h.Logger.Error("failed to convert operation error to failure", "error", convErr) + writer.WriteHeader(http.StatusInternalServerError) + return + } + // Backward compatibility, unwrap the failure cause. + if r.Header.Get("temporal-nexus-failure-support") != "true" && failure.Cause != nil { + failure = *failure.Cause + } + + statusCode = statusOperationUnsuccessful if operationState != nexus.OperationStateFailed && operationState != nexus.OperationStateCanceled { - h.logger.Error("unexpected operation state", "state", operationState) + h.Logger.Error("unexpected operation state", "state", operationState) writer.WriteHeader(http.StatusInternalServerError) return } writer.Header().Set(headerOperationState, string(operationState)) } else if errors.As(err, &handlerError) { - failure = h.failureConverter.ErrorToFailure(handlerError.Cause) + var convErr error + failure, convErr = h.FailureConverter.ErrorToFailure(handlerError) + if convErr != nil { + h.Logger.Error("failed to convert handler error to failure", "error", convErr) + writer.WriteHeader(http.StatusInternalServerError) + return + } + // Backward compatibility, unwrap the failure cause. 
+ if r.Header.Get("temporal-nexus-failure-support") != "true" && failure.Cause != nil { + failure = *failure.Cause + } switch handlerError.Type { case nexus.HandlerErrorTypeBadRequest: statusCode = http.StatusBadRequest @@ -133,18 +158,20 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { case nexus.HandlerErrorTypeUpstreamTimeout: statusCode = nexus.StatusUpstreamTimeout default: - h.logger.Error("unexpected handler error type", "type", handlerError.Type) + h.Logger.Error("unexpected handler error type", "type", handlerError.Type) } + } else if errors.As(err, &failureError) { + failure = failureError.Failure } else { failure = nexus.Failure{ Message: "internal server error", } - h.logger.Error("handler failed", "error", err) + h.Logger.Error("handler failed", "error", err) } b, err := json.Marshal(failure) if err != nil { - h.logger.Error("failed to marshal failure", "error", err) + h.Logger.Error("failed to marshal failure", "error", err) writer.WriteHeader(http.StatusInternalServerError) return } @@ -166,14 +193,14 @@ func (h *baseHTTPHandler) writeFailure(writer http.ResponseWriter, err error) { writer.WriteHeader(statusCode) if _, err := writer.Write(b); err != nil { - h.logger.Error("failed to write response body", "error", err) + h.Logger.Error("failed to write response body", "error", err) } } func (h *httpHandler) startOperation(service, operation string, writer http.ResponseWriter, request *http.Request) { links, err := getLinksFromHeader(request.Header) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid %q header", headerLink)) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid %q header", headerLink)) return } options := nexus.StartOperationOptions{ @@ -204,16 +231,16 @@ func (h *httpHandler) startOperation(service, operation string, writer http.Resp }) response, err := h.options.Handler.StartOperation(ctx, service, 
operation, value, options) if err != nil { - h.writeFailure(writer, err) + h.WriteFailure(writer, request, err) } else { if err := addLinksToHTTPHeader(nexus.HandlerLinks(ctx), writer.Header()); err != nil { - h.logger.Error("failed to serialize links into header", "error", err) + h.Logger.Error("failed to serialize links into header", "error", err) // clear any previous links already written to the header writer.Header().Del(headerLink) writer.WriteHeader(http.StatusInternalServerError) return } - applyResultToHTTPResponse(response, writer, h) + applyResultToHTTPResponse(response, writer, request, h) } } @@ -232,7 +259,7 @@ func (h *httpHandler) cancelOperation(service, operation, token string, writer h Header: options.Header, }) if err := h.options.Handler.CancelOperation(ctx, service, operation, token, options); err != nil { - h.writeFailure(writer, err) + h.WriteFailure(writer, request, err) return } @@ -247,8 +274,8 @@ func (h *httpHandler) parseRequestTimeoutHeader(writer http.ResponseWriter, requ if timeoutStr != "" { timeoutDuration, err := ParseDuration(timeoutStr) if err != nil { - h.logger.Warn("invalid request timeout header", "timeout", timeoutStr) - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request timeout header")) + h.Logger.Warn("invalid request timeout header", "timeout", timeoutStr) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request timeout header")) return 0, false } return timeoutDuration, true @@ -287,28 +314,28 @@ type HandlerOptions struct { Serializer nexus.Serializer // A [FailureConverter] to convert a [Failure] instance to and from an [error]. // Defaults to [DefaultFailureConverter]. 
- FailureConverter nexus.FailureConverter + FailureConverter FailureConverter } func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Request) { if request.Method != "POST" { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request method: expected POST, got %q", request.Method)) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid request method: expected POST, got %q", request.Method)) return } parts := strings.Split(request.URL.EscapedPath(), "/") // First part is empty (due to leading /) if len(parts) < 3 { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "not found")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "not found")) return } service, err := url.PathUnescape(parts[1]) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) return } operation, err := url.PathUnescape(parts[2]) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) return } @@ -323,7 +350,7 @@ func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Re if len(parts) == 5 && parts[4] == "cancel" { token, err := url.PathUnescape(parts[3]) if err != nil { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "failed to parse URL path")) return } @@ -332,7 +359,7 @@ func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Re } if len(parts) != 4 || parts[3] 
!= "cancel" { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "not found")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "not found")) return } @@ -340,7 +367,7 @@ func (h *httpHandler) handleRequest(writer http.ResponseWriter, request *http.Re if token == "" { token = request.URL.Query().Get("token") if token == "" { - h.writeFailure(writer, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "missing operation token")) + h.WriteFailure(writer, request, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "missing operation token")) return } } else { @@ -363,12 +390,12 @@ func NewHTTPHandler(options HandlerOptions) http.Handler { options.Serializer = nexus.DefaultSerializer() } if options.FailureConverter == nil { - options.FailureConverter = nexus.DefaultFailureConverter() + options.FailureConverter = DefaultFailureConverter() } handler := &httpHandler{ - baseHTTPHandler: baseHTTPHandler{ - logger: options.Logger, - failureConverter: options.FailureConverter, + BaseHTTPHandler: BaseHTTPHandler{ + Logger: options.Logger, + FailureConverter: options.FailureConverter, }, options: options, } diff --git a/common/nexus/nexusrpc/setup_test.go b/common/nexus/nexusrpc/setup_test.go index 63c65e42b6..5a4369b075 100644 --- a/common/nexus/nexusrpc/setup_test.go +++ b/common/nexus/nexusrpc/setup_test.go @@ -19,7 +19,7 @@ const testTimeout = time.Second * 5 const testService = "Ser/vic e" const getResultMaxTimeout = time.Millisecond * 300 -func setupCustom(t *testing.T, handler nexus.Handler, serializer nexus.Serializer, failureConverter nexus.FailureConverter) (ctx context.Context, client *nexusrpc.HTTPClient, teardown func()) { +func setupCustom(t *testing.T, handler nexus.Handler, serializer nexus.Serializer, failureConverter nexusrpc.FailureConverter) (ctx context.Context, client *nexusrpc.HTTPClient, teardown func()) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 
httpHandler := nexusrpc.NewHTTPHandler(nexusrpc.HandlerOptions{ @@ -55,7 +55,7 @@ func setup(t *testing.T, handler nexus.Handler) (ctx context.Context, client *ne return setupCustom(t, handler, nil, nil) } -func setupForCompletion(t *testing.T, handler nexusrpc.CompletionHandler, serializer nexus.Serializer, failureConverter nexus.FailureConverter) (ctx context.Context, callbackURL string, teardown func()) { +func setupForCompletion(t *testing.T, handler nexusrpc.CompletionHandler, serializer nexus.Serializer, failureConverter nexusrpc.FailureConverter) (ctx context.Context, callbackURL string, teardown func()) { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) httpHandler := nexusrpc.NewCompletionHTTPHandler(nexusrpc.CompletionHandlerOptions{ @@ -112,19 +112,20 @@ type customFailureConverter struct{} var errCustom = errors.New("custom") // ErrorToFailure implements FailureConverter. -func (c customFailureConverter) ErrorToFailure(err error) nexus.Failure { +func (c customFailureConverter) ErrorToFailure(err error) (nexus.Failure, error) { return nexus.Failure{ Message: err.Error(), Metadata: map[string]string{ "type": "custom", }, - } + }, nil } // FailureToError implements FailureConverter. 
-func (c customFailureConverter) FailureToError(f nexus.Failure) error { +// nolint:revive // unnamed results of the same type is fine for test +func (c customFailureConverter) FailureToError(f nexus.Failure) (error, error) { if f.Metadata["type"] != "custom" { - return errors.New(f.Message) + return errors.New(f.Message), nil } - return fmt.Errorf("%w: %s", errCustom, f.Message) + return fmt.Errorf("%w: %s", errCustom, f.Message), nil } diff --git a/common/nexus/nexusrpc/start_test.go b/common/nexus/nexusrpc/start_test.go index 20c90a5798..a06619c0fd 100644 --- a/common/nexus/nexusrpc/start_test.go +++ b/common/nexus/nexusrpc/start_test.go @@ -25,29 +25,29 @@ func (h *successHandler) StartOperation(ctx context.Context, service, operation return nil, err } if service != testService { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected service: %s", service) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected service: %s", service) } if operation != "i need to/be escaped" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected operation: %s", operation) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected operation: %s", operation) } if options.CallbackURL != "http://test/callback" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected callback URL: %s", options.CallbackURL) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unexpected callback URL: %s", options.CallbackURL) } if options.CallbackHeader.Get("callback-test") != "ok" { - return nil, nexus.HandlerErrorf( + return nil, nexus.NewHandlerErrorf( nexus.HandlerErrorTypeBadRequest, "invalid 'callback-test' callback header: %q", options.CallbackHeader.Get("callback-test"), ) } if options.Header.Get("test") != "ok" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'test' header: %q", options.Header.Get("test")) + return 
nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'test' header: %q", options.Header.Get("test")) } if options.Header.Get("nexus-callback-callback-test") != "" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "callback header not omitted from options Header") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "callback header not omitted from options Header") } if options.Header.Get("User-Agent") != "temporalio/server" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) } return &nexus.HandlerStartOperationResultSync[any]{Value: body}, nil @@ -191,7 +191,7 @@ func (h *echoHandler) StartOperation(ctx context.Context, service, operation str Data: data, } default: - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unknown input-type header") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "unknown input-type header") } nexus.AddHandlerLinks(ctx, options.Links...) 
return &nexus.HandlerStartOperationResultSync[any]{Value: output}, nil @@ -377,7 +377,7 @@ func TestStart_NilContentHeaderDoesNotPanic(t *testing.T) { var numberValidatorOperation = nexus.NewSyncOperation("number-validator", func(ctx context.Context, input int, _ nexus.StartOperationOptions) (int, error) { if input != 3 { - return input, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid number: %d", input) + return input, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid number: %d", input) } return input, nil }) diff --git a/common/nexus/payload_serializer.go b/common/nexus/payload_serializer.go index 06e9f63e13..ff4a86438b 100644 --- a/common/nexus/payload_serializer.go +++ b/common/nexus/payload_serializer.go @@ -98,6 +98,10 @@ func setUnknownNexusContent(nexusHeader nexus.Header, payloadMetadata map[string // Serialize implements nexus.Serializer. func (payloadSerializer) Serialize(v any) (*nexus.Content, error) { + if v == nil { + // Use same structure as the nil serializer from the Nexus Go SDK. 
+ return &nexus.Content{Header: nexus.Header{}}, nil + } payload, ok := v.(*commonpb.Payload) if !ok { return nil, fmt.Errorf("%w: cannot serialize %v", errSerializer, v) diff --git a/components/callbacks/chasm_invocation.go b/components/callbacks/chasm_invocation.go index 843c251cdc..bb0419287c 100644 --- a/components/callbacks/chasm_invocation.go +++ b/components/callbacks/chasm_invocation.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "fmt" - "io" "github.com/google/uuid" "github.com/nexus-rpc/sdk-go/nexus" @@ -25,7 +24,7 @@ import ( type chasmInvocation struct { nexus *persistencespb.Callback_Nexus attempt int32 - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions requestID string } @@ -87,22 +86,13 @@ func (c chasmInvocation) getHistoryRequest( RequestId: c.requestID, } - switch op := c.completion.(type) { - case *nexusrpc.OperationCompletionSuccessful: - payloadBody, err := io.ReadAll(op.Reader) - if err != nil { - return nil, fmt.Errorf("failed to read payload: %v", err) - } - + if c.completion.Error == nil { var payload *commonpb.Payload - if payloadBody != nil { - content := &nexus.Content{ - Header: op.Reader.Header, - Data: payloadBody, - } - err := commonnexus.PayloadSerializer.Deserialize(content, &payload) - if err != nil { - return nil, fmt.Errorf("failed to deserialize payload: %v", err) + if c.completion.Result != nil { + var ok bool + payload, ok = c.completion.Result.(*commonpb.Payload) + if !ok { + return nil, fmt.Errorf("invalid result, expected a payload, got: %T", c.completion.Result) } } @@ -110,13 +100,21 @@ func (c chasmInvocation) getHistoryRequest( Outcome: &historyservice.CompleteNexusOperationChasmRequest_Success{ Success: payload, }, - CloseTime: timestamppb.New(op.CloseTime), + CloseTime: timestamppb.New(c.completion.CloseTime), Completion: completion, } - case *nexusrpc.OperationCompletionUnsuccessful: - apiFailure, err := commonnexus.NexusFailureToAPIFailure(op.Failure, true) + } else { 
+ failure, err := nexusrpc.DefaultFailureConverter().ErrorToFailure(c.completion.Error) if err != nil { - return nil, fmt.Errorf("failed to convert failure type: %v", err) + return nil, fmt.Errorf("failed to convert error to failure: %w", err) + } + // Unwrap the operation error since it's not meant to be sent for Temporal->Temporal completions. + if failure.Cause != nil { + failure = *failure.Cause + } + apiFailure, err := commonnexus.NexusFailureToTemporalFailure(failure) + if err != nil { + return nil, fmt.Errorf("failed to convert failure type: %w", err) } req = &historyservice.CompleteNexusOperationChasmRequest{ @@ -124,10 +122,8 @@ func (c chasmInvocation) getHistoryRequest( Outcome: &historyservice.CompleteNexusOperationChasmRequest_Failure{ Failure: apiFailure, }, - CloseTime: timestamppb.New(op.CloseTime), + CloseTime: timestamppb.New(c.completion.CloseTime), } - default: - return nil, fmt.Errorf("unexpected nexus.OperationCompletion: %v", completion) } return req, nil diff --git a/components/callbacks/executors_test.go b/components/callbacks/executors_test.go index 75c8c3010d..1a0015b58c 100644 --- a/components/callbacks/executors_test.go +++ b/components/callbacks/executors_test.go @@ -54,11 +54,11 @@ func (fakeEnv) Now() time.Time { var _ hsm.Environment = fakeEnv{} type mutableState struct { - completionNexus nexusrpc.OperationCompletion + completionNexus nexusrpc.CompleteOperationOptions completionHsm *persistencespb.HSMCompletionCallbackArg } -func (ms mutableState) GetNexusCompletion(ctx context.Context, requestID string) (nexusrpc.OperationCompletion, error) { +func (ms mutableState) GetNexusCompletion(ctx context.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) { return ms.completionNexus, nil } @@ -80,7 +80,7 @@ func TestProcessInvocationTaskNexus_Outcomes(t *testing.T) { return &http.Response{StatusCode: 200, Body: http.NoBody}, nil }, retryable: false, - expectedMetricOutcome: "status:200", + expectedMetricOutcome: 
"success", assertOutcome: func(t *testing.T, cb callbacks.Callback) { require.Equal(t, enumsspb.CALLBACK_STATE_SUCCEEDED, cb.State()) }, @@ -102,7 +102,7 @@ func TestProcessInvocationTaskNexus_Outcomes(t *testing.T) { return &http.Response{StatusCode: 500, Body: http.NoBody}, nil }, retryable: true, - expectedMetricOutcome: "status:500", + expectedMetricOutcome: "handler-error:INTERNAL", assertOutcome: func(t *testing.T, cb callbacks.Callback) { require.Equal(t, enumsspb.CALLBACK_STATE_BACKING_OFF, cb.State()) }, @@ -113,7 +113,7 @@ func TestProcessInvocationTaskNexus_Outcomes(t *testing.T) { return &http.Response{StatusCode: 400, Body: http.NoBody}, nil }, retryable: false, - expectedMetricOutcome: "status:400", + expectedMetricOutcome: "handler-error:BAD_REQUEST", assertOutcome: func(t *testing.T, cb callbacks.Callback) { require.Equal(t, enumsspb.CALLBACK_STATE_FAILED, cb.State()) }, @@ -267,8 +267,7 @@ func TestProcessBackoffTask(t *testing.T) { } func newMutableState(t *testing.T) mutableState { - completionNexus, err := nexusrpc.NewOperationCompletionSuccessful(nil, nexusrpc.OperationCompletionSuccessfulOptions{}) - require.NoError(t, err) + completionNexus := nexusrpc.CompleteOperationOptions{} hsmCallbackArg := &persistencespb.HSMCompletionCallbackArg{ NamespaceId: "mynsid", WorkflowId: "mywid", @@ -315,7 +314,7 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { cases := []struct { name string setupHistoryClient func(*testing.T, *gomock.Controller) *historyservicemock.MockHistoryServiceClient - completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions headerValue string expectsInternalError bool assertOutcome func(*testing.T, callbacks.Callback) @@ -348,16 +347,11 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { }) return client }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayload([]byte("result-data")), - 
nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - CloseTime: dummyTime, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Result: createPayload([]byte("result-data")), + CloseTime: dummyTime, + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb callbacks.Callback) { @@ -380,18 +374,14 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { }) return client }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Error: &nexus.OperationError{ State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "operation failed"}}, }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - CloseTime: dummyTime, - }, - ) - require.NoError(t, err) - return comp + CloseTime: dummyTime, + } }(), headerValue: encodedRef, assertOutcome: func(t *testing.T, cb callbacks.Callback) { @@ -408,15 +398,10 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { ).Return(nil, status.Error(codes.Unavailable, "service unavailable")) return client }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayload([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Result: createPayload([]byte("result-data")), + } }(), headerValue: encodedRef, expectsInternalError: true, @@ -434,15 +419,10 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { ).Return(nil, status.Error(codes.InvalidArgument, "invalid request")) return 
client }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayload([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Result: createPayload([]byte("result-data")), + } }(), headerValue: encodedRef, expectsInternalError: true, @@ -456,15 +436,10 @@ func TestProcessInvocationTaskChasm_Outcomes(t *testing.T) { // No RPC call expected return historyservicemock.NewMockHistoryServiceClient(ctrl) }, - completion: func() nexusrpc.OperationCompletion { - comp, err := nexusrpc.NewOperationCompletionSuccessful( - createPayload([]byte("result-data")), - nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }, - ) - require.NoError(t, err) - return comp + completion: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{ + Result: createPayload([]byte("result-data")), + } }(), headerValue: "invalid-base64!!!", expectsInternalError: true, diff --git a/components/callbacks/nexus_invocation.go b/components/callbacks/nexus_invocation.go index 0d5c99da35..e24f070f45 100644 --- a/components/callbacks/nexus_invocation.go +++ b/components/callbacks/nexus_invocation.go @@ -1,15 +1,10 @@ package callbacks import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "mime" + "errors" "net/http" "net/http/httptrace" - "slices" "time" "github.com/nexus-rpc/sdk-go/nexus" @@ -30,30 +25,16 @@ var retryable4xxErrorTypes = []int{ } type CanGetNexusCompletion interface { - GetNexusCompletion(ctx context.Context, requestID string) (nexusrpc.OperationCompletion, error) + GetNexusCompletion(ctx context.Context, requestID string) (nexusrpc.CompleteOperationOptions, error) } type nexusInvocation struct { nexus *persistencespb.Callback_Nexus - 
completion nexusrpc.OperationCompletion + completion nexusrpc.CompleteOperationOptions workflowID, runID string attempt int32 } -func isRetryableHTTPResponse(response *http.Response) bool { - return response.StatusCode >= 500 || slices.Contains(retryable4xxErrorTypes, response.StatusCode) -} - -func outcomeTag(callCtx context.Context, response *http.Response, callErr error) string { - if callErr != nil { - if callCtx.Err() != nil { - return "request-timeout" - } - return "unknown-error" - } - return fmt.Sprintf("status:%d", response.StatusCode) -} - func (n nexusInvocation) WrapError(result invocationResult, err error) error { if failure, ok := result.(invocationResultRetry); ok { return queueserrors.NewDestinationDownError(failure.err.Error(), err) @@ -77,107 +58,54 @@ func (n nexusInvocation) Invoke(ctx context.Context, ns *namespace.Namespace, e } } - request, err := nexusrpc.NewCompletionHTTPRequest(ctx, n.nexus.Url, n.completion) - if err != nil { - return invocationResultFail{queueserrors.NewUnprocessableTaskError( - fmt.Sprintf("failed to construct Nexus request: %v", err), - )} - } - if request.Header == nil { - request.Header = make(http.Header) - } - for k, v := range n.nexus.Header { - request.Header.Set(k, v) - } - - caller := e.HTTPCallerProvider(queuescommon.NamespaceIDAndDestination{ - NamespaceID: ns.ID().String(), - Destination: task.Destination(), + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + HTTPCaller: e.HTTPCallerProvider(queuescommon.NamespaceIDAndDestination{ + NamespaceID: ns.ID().String(), + Destination: task.Destination(), + }), + Serializer: commonnexus.PayloadSerializer, }) // Make the call and record metrics. 
startTime := time.Now() - response, err := caller(request) + + n.completion.Header = n.nexus.Header + err := client.CompleteOperation(ctx, n.nexus.Url, n.completion) namespaceTag := metrics.NamespaceTag(ns.Name().String()) destTag := metrics.DestinationTag(task.Destination()) - statusCodeTag := metrics.OutcomeTag(outcomeTag(ctx, response, err)) + statusCodeTag := metrics.OutcomeTag(outcomeTag(ctx, err)) e.MetricsHandler.Counter(RequestCounter.Name()).Record(1, namespaceTag, destTag, statusCodeTag) e.MetricsHandler.Timer(RequestLatencyHistogram.Name()).Record(time.Since(startTime), namespaceTag, destTag, statusCodeTag) if err != nil { - e.Logger.Error("Callback request failed with error", tag.Error(err)) - return invocationResultRetry{err} - } - - if response.StatusCode >= 200 && response.StatusCode < 300 { - // Body is not read but should be discarded to keep the underlying TCP connection alive. - // Just in case something unexpected happens while discarding or closing the body, - // propagate errors to the machine. - if _, err = io.Copy(io.Discard, response.Body); err == nil { - if err = response.Body.Close(); err != nil { - e.Logger.Error("Callback request failed with error", tag.Error(err)) - return invocationResultRetry{err} - } + retryable := isRetryableCallError(err) + e.Logger.Error("Callback request failed", tag.Error(err), tag.Bool("retryable", retryable)) + if retryable { + return invocationResultRetry{err} } - return invocationResultOK{} + return invocationResultFail{err} } - - retryable := isRetryableHTTPResponse(response) - err = readHandlerErrFromResponse(response, e.Logger) - e.Logger.Error("Callback request failed", tag.Error(err), tag.String("status", response.Status), tag.Bool("retryable", retryable)) - if retryable { - return invocationResultRetry{err} - } - return invocationResultFail{err} + return invocationResultOK{} } -// Reads and replaces the http response body and attempts to deserialize it into a Nexus failure. 
If successful, -// returns a nexus.HandlerError with the deserialized failure as the Cause. If there is an error reading the body or -// during deserialization, returns a nexus.HandlerError with a generic Cause based on response status. -// TODO: This logic is duplicated in the frontend handler for forwarded requests. Eventually it should live in the Nexus SDK. -func readHandlerErrFromResponse(response *http.Response, logger log.Logger) error { - handlerErr := &nexus.HandlerError{ - Type: commonnexus.HandlerErrorTypeFromHTTPStatus(response.StatusCode), - Cause: fmt.Errorf("request failed with: %v", response.Status), - } - - body, err := readAndReplaceBody(response) - if err != nil { - logger.Error("Error reading response body for non-ok callback request", tag.Error(err), tag.String("status", response.Status)) - return err - } - - if !isMediaTypeJSON(response.Header.Get("Content-Type")) { - logger.Error("received invalid content-type header for non-OK HTTP response to CompleteOperation request", tag.Value(response.Header.Get("Content-Type"))) - return handlerErr - } - - var failure nexus.Failure - err = json.Unmarshal(body, &failure) - if err != nil { - logger.Error("failed to deserialize Nexus Failure from HTTP response to CompleteOperation request", tag.Error(err)) - return handlerErr +func outcomeTag(callCtx context.Context, callErr error) string { + if callErr != nil { + if callCtx.Err() != nil { + return "request-timeout" + } + var handlerErr *nexus.HandlerError + if errors.As(callErr, &handlerErr) { + return "handler-error:" + string(handlerErr.Type) + } + return "unknown-error" } - - handlerErr.Cause = &nexus.FailureError{Failure: failure} - return handlerErr -} - -// readAndReplaceBody reads the response body in its entirety and closes it, and then replaces the original response -// body with an in-memory buffer. -// The body is replaced even when there was an error reading the entire body. 
-func readAndReplaceBody(response *http.Response) ([]byte, error) { - responseBody := response.Body - body, err := io.ReadAll(responseBody) - _ = responseBody.Close() - response.Body = io.NopCloser(bytes.NewReader(body)) - return body, err + return "success" } -func isMediaTypeJSON(contentType string) bool { - if contentType == "" { - return false +func isRetryableCallError(err error) bool { + var handlerError *nexus.HandlerError + if errors.As(err, &handlerError) { + return handlerError.Retryable() } - mediaType, _, err := mime.ParseMediaType(contentType) - return err == nil && mediaType == "application/json" + return true } diff --git a/components/nexusoperations/completion.go b/components/nexusoperations/completion.go index 90d59bbecf..cf37247030 100644 --- a/components/nexusoperations/completion.go +++ b/components/nexusoperations/completion.go @@ -8,6 +8,7 @@ import ( "github.com/nexus-rpc/sdk-go/nexus" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" + failurepb "go.temporal.io/api/failure/v1" historypb "go.temporal.io/api/history/v1" "go.temporal.io/api/serviceerror" commonnexus "go.temporal.io/server/common/nexus" @@ -43,25 +44,40 @@ func handleSuccessfulOperationResult( func handleOperationError( node *hsm.Node, operation Operation, - opFailedError *nexus.OperationError, + opErr *nexus.OperationError, ) error { eventID, err := hsm.EventIDFromToken(operation.ScheduledEventToken) if err != nil { return err } - failure, err := commonnexus.OperationErrorToTemporalFailure(opFailedError) - if err != nil { - return err + var originalCause *failurepb.Failure + // Special marker for Temporal->Temporal calls to indicate that the original failure should be unwrapped. + // Temporal uses a wrapper operation error with no additional information to transmit the OperationError over the network. + // The meaningful information is in the operation error's cause. 
+ unwrapError := opErr.OriginalFailure.Metadata["unwrap-error"] == "true" + + if unwrapError && opErr.OriginalFailure.Cause != nil { + var err error + originalCause, err = commonnexus.NexusFailureToTemporalFailure(*opErr.OriginalFailure.Cause) + if err != nil { + return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) + } + } else { + // Transform the OperationError to either ApplicationFailure or CanceledFailure based on the operation error state. + originalCause, err = commonnexus.NexusFailureToTemporalFailure(*opErr.OriginalFailure) + if err != nil { + return serviceerror.NewInvalidArgumentf("Malformed failure: %v", err) + } } - switch opFailedError.State { // nolint:exhaustive + switch opErr.State { // nolint:exhaustive case nexus.OperationStateFailed: event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_FAILED, func(e *historypb.HistoryEvent) { // We must assign to this property, linter doesn't like this. // nolint:revive e.Attributes = &historypb.HistoryEvent_NexusOperationFailedEventAttributes{ NexusOperationFailedEventAttributes: &historypb.NexusOperationFailedEventAttributes{ - Failure: nexusOperationFailure(operation, eventID, failure), + Failure: createNexusOperationFailure(operation, eventID, originalCause), ScheduledEventId: eventID, RequestId: operation.RequestId, }, @@ -70,12 +86,23 @@ func handleOperationError( return FailedEventDefinition{}.Apply(node.Parent, event) case nexus.OperationStateCanceled: + if originalCause.GetCanceledFailureInfo() == nil { + // Old SDKs may send an ApplicationFailure for canceled operation causes. 
+ originalCause = &failurepb.Failure{ + Message: originalCause.GetMessage(), + StackTrace: originalCause.GetStackTrace(), + FailureInfo: &failurepb.Failure_CanceledFailureInfo{ + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, + }, + Cause: originalCause.GetCause(), + } + } event := node.AddHistoryEvent(enumspb.EVENT_TYPE_NEXUS_OPERATION_CANCELED, func(e *historypb.HistoryEvent) { // We must assign to this property, linter doesn't like this. // nolint:revive e.Attributes = &historypb.HistoryEvent_NexusOperationCanceledEventAttributes{ NexusOperationCanceledEventAttributes: &historypb.NexusOperationCanceledEventAttributes{ - Failure: nexusOperationFailure(operation, eventID, failure), + Failure: createNexusOperationFailure(operation, eventID, originalCause), ScheduledEventId: eventID, RequestId: operation.RequestId, }, @@ -86,7 +113,7 @@ func handleOperationError( default: // Both the Nexus Client and CompletionHandler reject invalid states, but just in case, we return this as a // transition error. - return fmt.Errorf("unexpected operation state: %v", opFailedError.State) + return fmt.Errorf("unexpected operation state: %v", opErr.State) } } diff --git a/components/nexusoperations/executors.go b/components/nexusoperations/executors.go index 0b980e948c..2a97803817 100644 --- a/components/nexusoperations/executors.go +++ b/components/nexusoperations/executors.go @@ -177,7 +177,7 @@ func (e taskExecutor) executeInvocationTask(ctx context.Context, env hsm.Environ // This happens when we accept the ScheduleNexusOperation command when the endpoint is not found in the registry as // indicated by the EndpointNotFoundAlwaysNonRetryable dynamic config. 
if args.endpointID == "" { - handlerError := nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") + handlerError := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") return e.saveResult(ctx, env, ref, nil, handlerError) } @@ -185,7 +185,7 @@ func (e taskExecutor) executeInvocationTask(ctx context.Context, env hsm.Environ if err != nil { if errors.As(err, new(*serviceerror.NotFound)) { // The endpoint is not registered, immediately fail the invocation. - handlerError := nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") + handlerError := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") return e.saveResult(ctx, env, ref, nil, handlerError) } return err @@ -479,12 +479,12 @@ func (e taskExecutor) saveResult(ctx context.Context, env hsm.Environment, ref h func (e taskExecutor) handleStartOperationError(env hsm.Environment, node *hsm.Node, operation Operation, callErr error) error { var handlerErr *nexus.HandlerError - var opFailedErr *nexus.OperationError + var opErr *nexus.OperationError var opTimeoutBelowMinErr *operationTimeoutBelowMinError switch { - case errors.As(callErr, &opFailedErr): - return handleOperationError(node, operation, opFailedErr) + case errors.As(callErr, &opErr): + return handleOperationError(node, operation, opErr) case errors.As(callErr, &handlerErr) && !handlerErr.Retryable(): // The StartOperation request got an unexpected response that is not retryable, fail the operation. 
// Although Failure is nullable, Nexus SDK is expected to always populate this field @@ -526,15 +526,15 @@ func handleNonRetryableStartOperationError(node *hsm.Node, operation Operation, if err != nil { return err } - failure, err := callErrToFailure(callErr, true) + cause, err := callErrToFailure(callErr, false) if err != nil { return err } attrs := &historypb.NexusOperationFailedEventAttributes{ - Failure: nexusOperationFailure( + Failure: createNexusOperationFailure( operation, eventID, - failure, + cause, ), ScheduledEventId: eventID, RequestId: operation.RequestId, @@ -579,7 +579,7 @@ func (e taskExecutor) recordOperationTimeout(node *hsm.Node, timeoutType enumspb // nolint:revive // We must mutate here even if the linter doesn't like it. e.Attributes = &historypb.HistoryEvent_NexusOperationTimedOutEventAttributes{ NexusOperationTimedOutEventAttributes: &historypb.NexusOperationTimedOutEventAttributes{ - Failure: nexusOperationFailure( + Failure: createNexusOperationFailure( op, eventID, &failurepb.Failure{ @@ -629,7 +629,7 @@ func (e taskExecutor) executeCancelationTask(ctx context.Context, env hsm.Enviro endpoint, err := e.lookupEndpoint(ctx, namespace.ID(ref.WorkflowKey.NamespaceID), args.endpointID, args.endpointName) if err != nil { if errors.As(err, new(*serviceerror.NotFound)) { - handlerError := nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") + handlerError := nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "endpoint not registered") // The endpoint is not registered, immediately fail the invocation. return e.saveCancelationResult(ctx, env, ref, handlerError, args.scheduledEventID) @@ -841,7 +841,7 @@ func (e taskExecutor) lookupEndpoint(ctx context.Context, namespaceID namespace. 
return entry, nil } -func nexusOperationFailure(operation Operation, scheduledEventID int64, cause *failurepb.Failure) *failurepb.Failure { +func createNexusOperationFailure(operation Operation, scheduledEventID int64, cause *failurepb.Failure) *failurepb.Failure { return &failurepb.Failure{ Message: "nexus operation completed unsuccessfully", FailureInfo: &failurepb.Failure_NexusOperationExecutionFailureInfo{ @@ -924,44 +924,21 @@ func isDestinationDown(err error) bool { func callErrToFailure(callErr error, retryable bool) (*failurepb.Failure, error) { var handlerErr *nexus.HandlerError if errors.As(callErr, &handlerErr) { - var retryBehavior enumspb.NexusHandlerErrorRetryBehavior - // nolint:exhaustive // unspecified is the default - switch handlerErr.RetryBehavior { - case nexus.HandlerErrorRetryBehaviorRetryable: - retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_RETRYABLE - case nexus.HandlerErrorRetryBehaviorNonRetryable: - retryBehavior = enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE - } - failure := &failurepb.Failure{ - Message: handlerErr.Error(), - FailureInfo: &failurepb.Failure_NexusHandlerFailureInfo{ - NexusHandlerFailureInfo: &failurepb.NexusHandlerFailureInfo{ - Type: string(handlerErr.Type), - RetryBehavior: retryBehavior, - }, - }, - } - var failureError *nexus.FailureError - if errors.As(handlerErr.Cause, &failureError) { + var nf nexus.Failure + if handlerErr.OriginalFailure != nil { + nf = *handlerErr.OriginalFailure + } else { var err error - failure.Cause, err = commonnexus.NexusFailureToAPIFailure(failureError.Failure, retryable) + nf, err = nexusrpc.DefaultFailureConverter().ErrorToFailure(handlerErr) if err != nil { return nil, err } - } else { - cause := handlerErr.Cause - if cause == nil { - cause = errors.New("unknown cause") - } - failure.Cause = &failurepb.Failure{ - Message: cause.Error(), - FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ - ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{}, - 
}, - } } - - return failure, nil + f, err := commonnexus.NexusFailureToTemporalFailure(nf) + if err != nil { + return nil, err + } + return f, nil } return &failurepb.Failure{ diff --git a/components/nexusoperations/executors_test.go b/components/nexusoperations/executors_test.go index afa3347a5a..b6fef19b21 100644 --- a/components/nexusoperations/executors_test.go +++ b/components/nexusoperations/executors_test.go @@ -3,7 +3,6 @@ package nexusoperations_test import ( "context" "encoding/json" - "errors" "testing" "time" @@ -155,23 +154,23 @@ func TestProcessInvocationTask(t *testing.T) { onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { // Also use this test case to check the input and options provided. if service != "service" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation name") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation name") } if operation != "operation" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation name") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation name") } if options.CallbackHeader.Get("temporal-callback-token") == "" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "empty callback token") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "empty callback token") } if options.CallbackURL != "http://localhost/callback" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback URL") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback URL") } if options.Header.Get(nexus.HeaderOperationTimeout) != "1ms" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", 
options.Header.Get(nexus.HeaderOperationTimeout)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) } var v string if err := input.Consume(&v); err != nil || v != "input" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid input") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid input") } return &nexus.HandlerStartOperationResultSync[any]{Value: "result"}, nil }, @@ -194,9 +193,10 @@ func TestProcessInvocationTask(t *testing.T) { destinationDown: false, onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { return nil, &nexus.OperationError{ - State: nexus.OperationStateFailed, + State: nexus.OperationStateFailed, + Message: "operation failed from handler", Cause: &nexus.FailureError{ - Failure: nexus.Failure{Message: "operation failed from handler", Metadata: map[string]string{"encoding": "json/plain"}, Details: json.RawMessage("\"details\"")}, + Failure: nexus.Failure{Message: "cause", Metadata: map[string]string{"encoding": "json/plain"}, Details: json.RawMessage("\"details\"")}, }, } }, @@ -222,13 +222,23 @@ func TestProcessInvocationTask(t *testing.T) { Message: "operation failed from handler", FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ - Type: "NexusFailure", - Details: &commonpb.Payloads{ - Payloads: []*commonpb.Payload{ - mustToPayload(t, nexus.Failure{Metadata: map[string]string{"encoding": "json/plain"}, Details: []byte(`"details"`)}), + Type: "OperationError", + NonRetryable: true, + }, + }, + Cause: &failurepb.Failure{ + Message: "cause", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Details: &commonpb.Payloads{ + 
Payloads: []*commonpb.Payload{ + mustToPayload(t, nexus.Failure{ + Metadata: map[string]string{"encoding": "json/plain"}, + Details: json.RawMessage("\"details\""), + }), + }, }, }, - NonRetryable: true, }, }, }, @@ -243,9 +253,10 @@ func TestProcessInvocationTask(t *testing.T) { destinationDown: false, onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { return nil, &nexus.OperationError{ - State: nexus.OperationStateCanceled, + State: nexus.OperationStateCanceled, + Message: "operation canceled from handler", Cause: &nexus.FailureError{ - Failure: nexus.Failure{Message: "operation canceled from handler", Metadata: map[string]string{"encoding": "json/plain"}, Details: json.RawMessage("\"details\"")}, + Failure: nexus.Failure{Message: "cause", Metadata: map[string]string{"encoding": "json/plain"}, Details: json.RawMessage("\"details\"")}, }, } }, @@ -270,10 +281,19 @@ func TestProcessInvocationTask(t *testing.T) { Cause: &failurepb.Failure{ Message: "operation canceled from handler", FailureInfo: &failurepb.Failure_CanceledFailureInfo{ - CanceledFailureInfo: &failurepb.CanceledFailureInfo{ - Details: &commonpb.Payloads{ - Payloads: []*commonpb.Payload{ - mustToPayload(t, nexus.Failure{Metadata: map[string]string{"encoding": "json/plain"}, Details: []byte(`"details"`)}), + CanceledFailureInfo: &failurepb.CanceledFailureInfo{}, + }, + Cause: &failurepb.Failure{ + Message: "cause", + FailureInfo: &failurepb.Failure_ApplicationFailureInfo{ + ApplicationFailureInfo: &failurepb.ApplicationFailureInfo{ + Details: &commonpb.Payloads{ + Payloads: []*commonpb.Payload{ + mustToPayload(t, nexus.Failure{ + Metadata: map[string]string{"encoding": "json/plain"}, + Details: json.RawMessage("\"details\""), + }), + }, }, }, }, @@ -289,13 +309,13 @@ func TestProcessInvocationTask(t *testing.T) { requestTimeout: time.Hour, destinationDown: true, 
onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") }, expectedMetricOutcome: "handler-error:INTERNAL", checkOutcome: func(t *testing.T, op nexusoperations.Operation, events []*historypb.HistoryEvent) { require.Equal(t, enumsspb.NEXUS_OPERATION_STATE_BACKING_OFF, op.State()) - require.NotNil(t, op.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (INTERNAL): internal server error", op.LastAttemptFailure.Message) + require.Equal(t, string(nexus.HandlerErrorTypeInternal), op.LastAttemptFailure.GetNexusHandlerFailureInfo().GetType()) + require.Equal(t, "internal server error", op.LastAttemptFailure.Message) require.Equal(t, 0, len(events)) }, }, @@ -325,7 +345,7 @@ func TestProcessInvocationTask(t *testing.T) { onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { opTimeout, err := time.ParseDuration(options.Header.Get(nexus.HeaderOperationTimeout)) if err != nil || opTimeout > 10*time.Millisecond { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) } time.Sleep(time.Millisecond * 100) //nolint:forbidigo // Allow time.Sleep for timeout tests return &nexus.HandlerStartOperationResultAsync{OperationToken: "op-token"}, nil @@ -345,7 +365,7 @@ func TestProcessInvocationTask(t *testing.T) { expectedMetricOutcome: "request-timeout", 
onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { if options.Header.Get(nexus.HeaderOperationTimeout) != "" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation timeout header should not be set, got: %s", options.Header.Get(nexus.HeaderOperationTimeout)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation timeout header should not be set, got: %s", options.Header.Get(nexus.HeaderOperationTimeout)) } time.Sleep(time.Millisecond * 100) //nolint:forbidigo // Allow time.Sleep for timeout tests return &nexus.HandlerStartOperationResultAsync{OperationToken: "op-token"}, nil @@ -365,7 +385,7 @@ func TestProcessInvocationTask(t *testing.T) { expectedMetricOutcome: "pending", onStartOperation: func(ctx context.Context, service, operation string, input *nexus.LazyValue, options nexus.StartOperationOptions) (nexus.HandlerStartOperationResult[any], error) { if options.Header.Get(nexus.HeaderOperationTimeout) != "60000ms" { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation timeout header: %s", options.Header.Get(nexus.HeaderOperationTimeout)) } return &nexus.HandlerStartOperationResultAsync{OperationToken: "op-token"}, nil }, @@ -399,8 +419,8 @@ func TestProcessInvocationTask(t *testing.T) { require.Equal(t, enumsspb.NEXUS_OPERATION_STATE_FAILED, op.State()) require.Equal(t, 1, len(events)) failure := events[0].GetNexusOperationFailedEventAttributes().Failure.Cause - require.NotNil(t, failure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (NOT_FOUND): endpoint not registered", failure.Message) + require.Equal(t, string(nexus.HandlerErrorTypeNotFound), 
failure.GetNexusHandlerFailureInfo().GetType()) + require.Equal(t, "endpoint not registered", failure.Message) }, }, { @@ -413,10 +433,8 @@ func TestProcessInvocationTask(t *testing.T) { require.Equal(t, enumsspb.NEXUS_OPERATION_STATE_FAILED, op.State()) require.Equal(t, 1, len(events)) failure := events[0].GetNexusOperationFailedEventAttributes().Failure.Cause - require.NotNil(t, failure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (NOT_FOUND): endpoint not registered", failure.Message) - require.NotNil(t, failure.Cause.GetApplicationFailureInfo()) - require.Equal(t, "endpoint not registered", failure.Cause.Message) + require.Equal(t, string(nexus.HandlerErrorTypeNotFound), failure.GetNexusHandlerFailureInfo().GetType()) + require.Equal(t, "endpoint not registered", failure.Message) }, }, { @@ -799,15 +817,15 @@ func TestProcessCancelationTask(t *testing.T) { // Check non retryable internal error. return &nexus.HandlerError{ Type: nexus.HandlerErrorTypeInternal, - Cause: errors.New("operation not found"), + Message: "operation not found", RetryBehavior: nexus.HandlerErrorRetryBehaviorNonRetryable, } }, expectedMetricOutcome: "handler-error:INTERNAL", checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_FAILED, c.State()) - require.NotNil(t, c.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (INTERNAL): operation not found", c.LastAttemptFailure.Message) + require.Equal(t, string(nexus.HandlerErrorTypeInternal), c.LastAttemptFailure.GetNexusHandlerFailureInfo().GetType()) + require.Equal(t, "operation not found", c.LastAttemptFailure.Message) }, }, { @@ -830,7 +848,7 @@ func TestProcessCancelationTask(t *testing.T) { header: map[string]string{"key": "value"}, onCancelOperation: func(ctx context.Context, service, operation, token string, options nexus.CancelOperationOptions) error { if options.Header["key"] != "value" { - return 
nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, `"key" header is not equal to "value"`) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, `"key" header is not equal to "value"`) } return nil }, @@ -845,13 +863,13 @@ func TestProcessCancelationTask(t *testing.T) { requestTimeout: time.Hour, destinationDown: true, onCancelOperation: func(ctx context.Context, service, operation, token string, options nexus.CancelOperationOptions) error { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") }, expectedMetricOutcome: "handler-error:INTERNAL", checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_BACKING_OFF, c.State()) require.NotNil(t, c.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (INTERNAL): internal server error", c.LastAttemptFailure.Message) + require.Equal(t, "internal server error", c.LastAttemptFailure.Message) }, }, { @@ -903,8 +921,8 @@ func TestProcessCancelationTask(t *testing.T) { onCancelOperation: nil, // This should not be called if the endpoint is not found. 
checkOutcome: func(t *testing.T, c nexusoperations.Cancelation) { require.Equal(t, enumspb.NEXUS_OPERATION_CANCELLATION_STATE_FAILED, c.State()) - require.NotNil(t, c.LastAttemptFailure.GetNexusHandlerFailureInfo()) - require.Equal(t, "handler error (NOT_FOUND): endpoint not registered", c.LastAttemptFailure.Message) + require.Equal(t, string(nexus.HandlerErrorTypeNotFound), c.LastAttemptFailure.GetNexusHandlerFailureInfo().GetType()) + require.Equal(t, "endpoint not registered", c.LastAttemptFailure.Message) }, }, } diff --git a/components/nexusoperations/frontend/handler.go b/components/nexusoperations/frontend/handler.go index 4e457785a3..1c95fe439e 100644 --- a/components/nexusoperations/frontend/handler.go +++ b/components/nexusoperations/frontend/handler.go @@ -1,13 +1,9 @@ package frontend import ( - "bytes" "context" - "encoding/json" "errors" "fmt" - "io" - "mime" "net/http" "net/http/httptrace" "net/url" @@ -21,6 +17,7 @@ import ( commonpb "go.temporal.io/api/common/v1" "go.temporal.io/api/serviceerror" "go.temporal.io/server/api/historyservice/v1" + "go.temporal.io/server/common" "go.temporal.io/server/common/authorization" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/dynamicconfig" @@ -91,18 +88,18 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C startTime := time.Now() if !h.Config.Enabled() { h.preProcessErrorsCounter.Record(1) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "Nexus APIs are disabled") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "Nexus APIs are disabled") } token, err := commonnexus.DecodeCallbackToken(r.HTTPRequest.Header.Get(commonnexus.CallbackTokenHeader)) if err != nil { h.Logger.Error("failed to decode callback token", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") } completion, err := 
h.CallbackTokenGenerator.DecodeCompletion(token) if err != nil { h.Logger.Error("failed to decode completion from token", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") } ns, err := h.NamespaceRegistry.GetNamespaceByID(namespace.ID(completion.NamespaceId)) if err != nil { @@ -110,7 +107,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C h.preProcessErrorsCounter.Record(1) var nfe *serviceerror.NamespaceNotFound if errors.As(err, &nfe) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", completion.NamespaceId) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace %q not found", completion.NamespaceId) } return commonnexus.ConvertGRPCError(err, false) } @@ -143,7 +140,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C if err != nil { h.Logger.Error("failed to extract namespace from request", tag.Error(err)) h.preProcessErrorsCounter.Record(1) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL") } if nsName != ns.Name().String() { logger.Error( @@ -152,7 +149,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C tag.Error(err), tag.String("completion-namespace-id", completion.GetNamespaceId()), ) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid callback token") } } @@ -165,7 +162,7 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C } tokenLimit := h.Config.MaxOperationTokenLength(ns.Name().String()) if len(r.OperationToken) > tokenLimit { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation 
token length exceeds allowed limit (%d/%d)", len(r.OperationToken), tokenLimit) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "operation token length exceeds allowed limit (%d/%d)", len(r.OperationToken), tokenLimit) } var links []*commonpb.Link @@ -201,25 +198,18 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C } switch r.State { // nolint:exhaustive case nexus.OperationStateFailed, nexus.OperationStateCanceled: - failureErr, ok := r.Error.(*nexus.FailureError) - if !ok { - // This shouldn't happen as the Nexus SDK is always expected to convert Failures from the wire to - // FailureErrors. - logger.Error("result error is not a FailureError", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal server error") - } hr.Outcome = &historyservice.CompleteNexusOperationRequest_Failure{ - Failure: commonnexus.NexusFailureToProtoFailure(failureErr.Failure), + Failure: commonnexus.NexusFailureToProtoFailure(*r.Error.OriginalFailure), } case nexus.OperationStateSucceeded: var result *commonpb.Payload if err := r.Result.Consume(&result); err != nil { logger.Error("cannot deserialize payload from completion result", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result content") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid result content") } if result.Size() > h.Config.PayloadSizeLimit(ns.Name().String()) { logger.Error("payload size exceeds error limit for Nexus CompleteOperation request", tag.WorkflowNamespace(ns.Name().String())) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "result exceeds size limit") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "result exceeds size limit") } hr.Outcome = &historyservice.CompleteNexusOperationRequest_Success{ Success: result, @@ -227,14 +217,14 @@ func (h *completionHandler) CompleteOperation(ctx context.Context, r *nexusrpc.C default: // The 
Nexus SDK ensures this never happens but just in case... logger.Error("invalid operation state in completion request", tag.String("state", string(r.State)), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid completion state") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid completion state") } _, err = h.HistoryClient.CompleteNexusOperation(ctx, hr) if err != nil { logger.Error("failed to process nexus completion request", tag.Error(err)) var namespaceInactiveErr *serviceerror.NamespaceNotActive if errors.As(err, &namespaceInactiveErr) { - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") } var notFoundErr *serviceerror.NotFound if errors.As(err, &notFoundErr) { @@ -249,13 +239,13 @@ func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nex client, err := h.ForwardingClients.Get(rCtx.namespace.ActiveClusterName(rCtx.workflowID)) if err != nil { h.Logger.Error("unable to get HTTP client for forward request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err), tag.SourceCluster(h.ClusterMetadata.GetCurrentClusterName()), tag.TargetCluster(rCtx.namespace.ActiveClusterName(rCtx.workflowID))) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } forwardURL, err := url.JoinPath(client.BaseURL(), commonnexus.RouteCompletionCallback.Path(rCtx.namespace.Name().String())) if err != nil { h.Logger.Error("failed to construct forwarding request URL", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err), tag.TargetCluster(rCtx.namespace.ActiveClusterName(rCtx.workflowID))) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return 
nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } if h.HTTPTraceProvider != nil { @@ -271,121 +261,55 @@ func (h *completionHandler) forwardCompleteOperation(ctx context.Context, r *nex } } - var forwardReq *http.Request + var completion nexusrpc.CompleteOperationOptions + switch r.State { case nexus.OperationStateSucceeded: - // For successful operations, the Nexus framework streams the result as a LazyValue, so we can reuse the - // incoming request body. - forwardReq, err = http.NewRequestWithContext(ctx, r.HTTPRequest.Method, forwardURL, r.HTTPRequest.Body) - if err != nil { - h.Logger.Error("failed to construct forwarding HTTP request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + completion = nexusrpc.CompleteOperationOptions{ + Result: r.Result.Reader, + OperationToken: r.OperationToken, + StartTime: r.StartTime, + CloseTime: r.CloseTime, + Links: r.Links, } case nexus.OperationStateFailed, nexus.OperationStateCanceled: // For unsuccessful operations, the Nexus framework reads and closes the original request body to deserialize // the failure, so we must construct a new completion to forward. - var failureErr *nexus.FailureError - if !errors.As(r.Error, &failureErr) { - // This shouldn't happen as the Nexus SDK is always expected to convert Failures from the wire to - // FailureErrors. 
- h.Logger.Error("received unexpected error type when trying to forward Nexus operation completion", tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } - c := &nexusrpc.OperationCompletionUnsuccessful{ - State: r.State, + completion = nexusrpc.CompleteOperationOptions{ + Error: r.Error, OperationToken: r.OperationToken, StartTime: r.StartTime, + CloseTime: r.CloseTime, Links: r.Links, - Failure: failureErr.Failure, - } - forwardReq, err = nexusrpc.NewCompletionHTTPRequest(ctx, forwardURL, c) - if err != nil { - h.Logger.Error("failed to construct forwarding HTTP request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } default: - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation state: %q", r.State) - } - - forwardReq.Header = rCtx.originalHeaders - if forwardReq.Header == nil { - forwardReq.Header = make(http.Header, 1) - } - forwardReq.Header.Set(interceptor.DCRedirectionApiHeaderName, "true") - - resp, err := client.Do(forwardReq) - if err != nil { - h.Logger.Error("received error from HTTP client when forwarding request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } - - // TODO: The following response handling logic is duplicated in the nexus_invocation executor. Eventually it should live in the Nexus SDK. 
- body, err := readAndReplaceBody(resp) - if err != nil { - h.Logger.Error("unable to read HTTP response for forwarded request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } - - if resp.StatusCode == http.StatusOK { - return nil - } - - if !isMediaTypeJSON(resp.Header.Get("Content-Type")) { - h.Logger.Error("received invalid content-type header for non-OK HTTP response to forwarded request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Value(resp.Header.Get("Content-Type"))) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation state: %q", r.State) } - var failure nexus.Failure - err = json.Unmarshal(body, &failure) - if err != nil { - h.Logger.Error("failed to deserialize Nexus Failure from HTTP response to forwarded request", tag.Operation(apiName), tag.WorkflowNamespace(rCtx.namespace.Name().String()), tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") - } - - // TODO: Upgrade Nexus SDK in order to reduce HTTP exposure - handlerErr := &nexus.HandlerError{ - Type: commonnexus.HandlerErrorTypeFromHTTPStatus(resp.StatusCode), - Cause: &nexus.FailureError{Failure: failure}, - } - - if handlerErr.Type == nexus.HandlerErrorTypeInternal && resp.StatusCode != http.StatusInternalServerError { - h.Logger.Warn("received unknown status code on Nexus client unexpected response error", tag.Value(resp.StatusCode)) - handlerErr.Cause = errors.New("internal error") - } - - return handlerErr -} - -// readAndReplaceBody reads the response body in its entirety and closes it, and then replaces the original response -// body with an in-memory buffer. -// The body is replaced even when there was an error reading the entire body. 
-func readAndReplaceBody(response *http.Response) ([]byte, error) { - responseBody := response.Body - body, err := io.ReadAll(responseBody) - _ = responseBody.Close() - response.Body = io.NopCloser(bytes.NewReader(body)) - return body, err + rCtx.originalHeaders.Set(interceptor.DCRedirectionApiHeaderName, "true") + cc := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + HTTPCaller: (&forwardingHTTPHeaderWrapper{ + client: client, + originalRequestHeaders: rCtx.originalHeaders, + }).Do, + }) + return cc.CompleteOperation(ctx, forwardURL, completion) } -func isMediaTypeJSON(contentType string) bool { - if contentType == "" { - return false - } - mediaType, _, err := mime.ParseMediaType(contentType) - return err == nil && mediaType == "application/json" +type forwardingHTTPHeaderWrapper struct { + client *common.FrontendHTTPClient + originalRequestHeaders http.Header } -// Copies HTTP request headers to Nexus headers except those starting with content- since those will be added by the client. -func httpHeaderToNexusHeader(httpHeader http.Header) nexus.Header { - header := nexus.Header{} - for k, v := range httpHeader { - lowerK := strings.ToLower(k) - if !strings.HasPrefix(lowerK, "content-") { - // Nexus headers can only have single values, ignore multiple values. - header[lowerK] = v[0] +func (f *forwardingHTTPHeaderWrapper) Do(req *http.Request) (*http.Response, error) { + // For forwarded requests, copy the original HTTP headers without sanitization. 
+ for k, v := range f.originalRequestHeaders { + if req.Header.Get(k) == "" { + req.Header.Set(k, v[0]) } } - return header + + return f.client.Do(req) } type requestContext struct { @@ -532,7 +456,7 @@ func (c *requestContext) interceptRequest(ctx context.Context, request *nexusrpc return serviceerror.NewNamespaceNotActive(c.namespace.Name().String(), c.ClusterMetadata.GetCurrentClusterName(), c.namespace.ActiveClusterName(c.workflowID)) } c.metricsHandler = c.metricsHandler.WithTags(metrics.OutcomeTag("namespace_inactive_forwarding_disabled")) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") } c.cleanupFunctions = append(c.cleanupFunctions, func(retErr error) { diff --git a/go.mod b/go.mod index 5c6dc7173f..3c3afcb14e 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/lib/pq v1.10.9 github.com/maruel/panicparse/v2 v2.4.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/nexus-rpc/sdk-go v0.5.1 + github.com/nexus-rpc/sdk-go v0.5.2-0.20260211051645-26b0b4c584e5 github.com/olekukonko/tablewriter v0.0.5 github.com/olivere/elastic/v7 v7.0.32 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index d81b84166e..5845e34a55 100644 --- a/go.sum +++ b/go.sum @@ -236,8 +236,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/nexus-rpc/sdk-go v0.5.1 h1:UFYYfoHlQc+Pn9gQpmn9QE7xluewAn2AO1OSkAh7YFU= -github.com/nexus-rpc/sdk-go v0.5.1/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= +github.com/nexus-rpc/sdk-go v0.5.2-0.20260211051645-26b0b4c584e5 
h1:Van9KGGs8lcDgxzSNFbDhEMNeJ80TbBxwZ45f9iBk9U= +github.com/nexus-rpc/sdk-go v0.5.2-0.20260211051645-26b0b4c584e5/go.mod h1:FHdPfVQwRuJFZFTF0Y2GOAxCrbIBNrcPna9slkGKPYk= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= diff --git a/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto b/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto index 49bb428fb5..b4dae9af1c 100644 --- a/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto +++ b/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto @@ -9,6 +9,7 @@ import "google/protobuf/timestamp.proto"; import "temporal/api/common/v1/message.proto"; import "temporal/api/deployment/v1/message.proto"; import "temporal/api/enums/v1/task_queue.proto"; +import "temporal/api/failure/v1/message.proto"; import "temporal/api/history/v1/message.proto"; import "temporal/api/taskqueue/v1/message.proto"; import "temporal/api/query/v1/message.proto"; @@ -554,11 +555,13 @@ message DispatchNexusTaskResponse { message Timeout {} oneof outcome { - // Set if the worker's handler failed the nexus task. - temporal.api.nexus.v1.HandlerError handler_error = 1; + // Deprecated. Use failure field instead. + temporal.api.nexus.v1.HandlerError handler_error = 1 [deprecated = true]; // Set if the worker's handler responded successfully to the nexus task. temporal.api.nexus.v1.Response response = 2; Timeout request_timeout = 3; + // Set if the worker's handler failed the nexus task. Must contain a NexusHandlerFailureInfo object. 
+ temporal.api.failure.v1.Failure failure = 4; } } diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index ef81703993..d245447c28 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -46,6 +46,9 @@ const ( // Generic Nexus context that is not bound to a specific operation. // Includes fields extracted from an incoming Nexus request before being handled by the Nexus HTTP handler. type nexusContext struct { + // Whether to use the new Temporal failure responses path. + // Set from the incoming nexus request's "temporal-nexus-failure-support" header. + callerFailureSupport bool requestStartTime time.Time apiName string namespaceName string @@ -196,7 +199,7 @@ func (c *operationContext) interceptRequest( ) } c.metricsHandler = c.metricsHandler.WithTags(metrics.OutcomeTag("namespace_inactive_forwarding_disabled")) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "cluster inactive") } c.cleanupFunctions = append(c.cleanupFunctions, func(respHeaders map[string]string, retErr error) { @@ -367,7 +370,7 @@ func (h *nexusHandler) getOperationContext(ctx context.Context, method string) ( var nfe *serviceerror.NamespaceNotFound if errors.As(err, &nfe) { - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace not found: %q", nc.namespaceName) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "namespace not found: %q", nc.namespaceName) } return nil, commonnexus.ConvertGRPCError(err, false) } @@ -413,6 +416,9 @@ func (h *nexusHandler) StartOperation( Variant: &nexuspb.Request_StartOperation{ StartOperation: &startOperationRequest, }, + Capabilities: &nexuspb.Request_Capabilities{ + TemporalFailureResponses: oc.callerFailureSupport, + }, }) if err := oc.interceptRequest(ctx, request, options.Header); err != nil { @@ -426,11 +432,11 @@ func (h *nexusHandler) StartOperation( 
// Transform nexus Content to temporal Payload with common/nexus PayloadSerializer. if err = input.Consume(&startOperationRequest.Payload); err != nil { oc.logger.Warn("invalid input", tag.Error(err)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid input") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid input") } if startOperationRequest.Payload.Size() > h.payloadSizeLimit(oc.namespaceName) { oc.logger.Error("payload size exceeds error limit for Nexus StartOperation request", tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "input exceeds size limit") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "input exceeds size limit") } // Dispatch the request to be sync matched with a worker polling on the nexusContext taskQueue. @@ -443,20 +449,35 @@ func (h *nexusHandler) StartOperation( } // Convert to standard Nexus SDK response. switch t := response.GetOutcome().(type) { + case *matchingservice.DispatchNexusTaskResponse_Failure: + // Set the failure source to "worker" if we've reached this case. + // Failure conversions errors below are the user's fault, as it implies that malformed completions were sent from + // the worker. 
+ oc.setFailureSource(commonnexus.FailureSourceWorker) + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) + nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) + if err != nil { + oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + he, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) + if err != nil { + oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + return nil, he + case *matchingservice.DispatchNexusTaskResponse_HandlerError: + // Deprecated case. Replaced with DispatchNexusTaskResponse_Failure oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.HandlerError.GetErrorType())) - - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - - err := h.convertOutcomeToNexusHandlerError(t) + oc.setFailureSource(commonnexus.FailureSourceWorker) + err := convertOutcomeToNexusHandlerError(t) return nil, err case *matchingservice.DispatchNexusTaskResponse_RequestTimeout: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_timeout")) - oc.setFailureSource(commonnexus.FailureSourceWorker) - - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") case *matchingservice.DispatchNexusTaskResponse_Response: switch t := t.Response.GetStartOperation().GetVariant().(type) { @@ -470,7 +491,6 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_AsyncSuccess: 
oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("async_success")) - token := t.AsyncSuccess.GetOperationToken() if token == "" { token = t.AsyncSuccess.GetOperationId() @@ -483,24 +503,57 @@ func (h *nexusHandler) StartOperation( case *nexuspb.StartOperationResponse_OperationError: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("operation_error")) - - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - - err := &nexus.OperationError{ - State: nexus.OperationState(t.OperationError.GetOperationState()), + oc.setFailureSource(commonnexus.FailureSourceWorker) + opErr := &nexus.OperationError{ + Message: "operation error", + State: nexus.OperationState(t.OperationError.GetOperationState()), Cause: &nexus.FailureError{ Failure: commonnexus.ProtoFailureToNexusFailure(t.OperationError.GetFailure()), }, } - return nil, err + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + oc.logger.Error("error converting OperationError to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + return nil, opErr + + case *nexuspb.StartOperationResponse_Failure: + // Set the failure source to "worker" if we've reached this case. + // Failure conversions errors below are the user's fault, as it implies that malformed completions were sent from + // the worker. 
+ oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("failure")) + oc.setFailureSource(commonnexus.FailureSourceWorker) + nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) + if err != nil { + oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + cause, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) + if err != nil { + oc.logger.Error("error converting Nexus failure to Nexus OperationError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + state := nexus.OperationStateFailed + if t.Failure.GetCanceledFailureInfo() != nil { + state = nexus.OperationStateCanceled + } + opErr := &nexus.OperationError{ + State: state, + Message: "operation error", + Cause: cause, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + oc.logger.Error("error converting OperationError to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + return nil, opErr } } // This is the worker's fault. 
oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:EMPTY_OUTCOME")) + oc.setFailureSource(commonnexus.FailureSourceWorker) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") } func parseLinks(links []*nexuspb.Link, logger log.Logger) []nexus.Link { @@ -589,6 +642,9 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, OperationId: token, }, }, + Capabilities: &nexuspb.Request_Capabilities{ + TemporalFailureResponses: oc.callerFailureSupport, + }, }) if err := oc.interceptRequest(ctx, request, options.Header); err != nil { var notActiveErr *serviceerror.NamespaceNotActive @@ -608,20 +664,35 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, } // Convert to standard Nexus SDK response. switch t := response.GetOutcome().(type) { + case *matchingservice.DispatchNexusTaskResponse_Failure: + oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.Failure.GetNexusHandlerFailureInfo().GetType())) + // Set the failure source to "worker" if we've reached this case. + // Failure conversions errors below are the user's fault, as it implies that malformed completions were sent from + // the worker. 
+ oc.setFailureSource(commonnexus.FailureSourceWorker) + nf, err := commonnexus.TemporalFailureToNexusFailure(t.Failure) + if err != nil { + oc.logger.Error("error converting Temporal failure to Nexus failure", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + he, err := nexusrpc.DefaultFailureConverter().FailureToError(nf) + if err != nil { + oc.logger.Error("error converting Nexus failure to Nexus HandlerError", tag.Error(err), tag.Operation(operation), tag.WorkflowNamespace(oc.namespaceName)) + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + } + return he + case *matchingservice.DispatchNexusTaskResponse_HandlerError: + // Deprecated case. Replaced with DispatchNexusTaskResponse_Failure oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:" + t.HandlerError.GetErrorType())) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - - err := h.convertOutcomeToNexusHandlerError(t) + err := convertOutcomeToNexusHandlerError(t) return err case *matchingservice.DispatchNexusTaskResponse_RequestTimeout: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_timeout")) - oc.setFailureSource(commonnexus.FailureSourceWorker) - - return nexus.HandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUpstreamTimeout, "upstream timeout") case *matchingservice.DispatchNexusTaskResponse_Response: oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("success")) @@ -629,10 +700,9 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, } // This is the worker's fault. 
oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("handler_error:EMPTY_OUTCOME")) - oc.nexusContext.setFailureSource(commonnexus.FailureSourceWorker) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "empty outcome") } func (h *nexusHandler) forwardCancelOperation( @@ -653,7 +723,7 @@ func (h *nexusHandler) forwardCancelOperation( handle, err := client.NewOperationHandle(operation, id) if err != nil { oc.logger.Warn("invalid Nexus cancel operation.", tag.Error(err)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation") + return nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid operation") } if h.httpTraceProvider != nil { @@ -686,7 +756,7 @@ func (h *nexusHandler) nexusClientForActiveCluster(oc *operationContext, service if err != nil { oc.logger.Error("failed to forward Nexus request. error creating HTTP client", tag.Error(err), tag.SourceCluster(oc.namespace.ActiveClusterName(namespace.EmptyBusinessID)), tag.TargetCluster(oc.namespace.ActiveClusterName(""))) oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("request_forwarding_failed")) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") } httpCaller := &forwardingHttpHeaderWrapper{ @@ -718,7 +788,7 @@ func (h *nexusHandler) nexusClientForActiveCluster(oc *operationContext, service tag.WorkflowTaskQueueName(oc.taskQueue), tag.Error(err)) oc.metricsHandler = oc.metricsHandler.WithTags(metrics.OutcomeTag("request_forwarding_failed")) - return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") + return nil, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "request forwarding failed") } return nexusrpc.NewHTTPClient(nexusrpc.HTTPClientOptions{ @@ -728,7 +798,7 @@ func (h 
*nexusHandler) nexusClientForActiveCluster(oc *operationContext, service }) } -func (h *nexusHandler) convertOutcomeToNexusHandlerError(resp *matchingservice.DispatchNexusTaskResponse_HandlerError) *nexus.HandlerError { +func convertOutcomeToNexusHandlerError(resp *matchingservice.DispatchNexusTaskResponse_HandlerError) *nexus.HandlerError { var retryBehavior nexus.HandlerErrorRetryBehavior // nolint:exhaustive // unspecified is the default switch resp.HandlerError.RetryBehavior { @@ -737,28 +807,13 @@ func (h *nexusHandler) convertOutcomeToNexusHandlerError(resp *matchingservice.D case enumspb.NEXUS_HANDLER_ERROR_RETRY_BEHAVIOR_NON_RETRYABLE: retryBehavior = nexus.HandlerErrorRetryBehaviorNonRetryable } - handlerError := &nexus.HandlerError{ - Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), - Cause: &nexus.FailureError{ - Failure: commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()), - }, + // nolint:staticcheck // Deprecated function still in use for backward compatibility. + cause := commonnexus.ProtoFailureToNexusFailure(resp.HandlerError.GetFailure()) + return &nexus.HandlerError{ + // nolint:staticcheck // Deprecated function still in use for backward compatibility. 
+ Type: nexus.HandlerErrorType(resp.HandlerError.GetErrorType()), RetryBehavior: retryBehavior, - } - - switch handlerError.Type { - case nexus.HandlerErrorTypeUpstreamTimeout, - nexus.HandlerErrorTypeUnauthenticated, - nexus.HandlerErrorTypeUnauthorized, - nexus.HandlerErrorTypeBadRequest, - nexus.HandlerErrorTypeResourceExhausted, - nexus.HandlerErrorTypeNotFound, - nexus.HandlerErrorTypeNotImplemented, - nexus.HandlerErrorTypeUnavailable, - nexus.HandlerErrorTypeInternal: - return handlerError - default: - h.logger.Warn("received unknown or unset Nexus handler error type", tag.Value(handlerError.Type)) - return nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") + Cause: &nexus.FailureError{Failure: cause}, } } diff --git a/service/frontend/nexus_http_handler.go b/service/frontend/nexus_http_handler.go index 7f40edbdac..b5943d5f01 100644 --- a/service/frontend/nexus_http_handler.go +++ b/service/frontend/nexus_http_handler.go @@ -2,7 +2,6 @@ package frontend import ( "context" - "encoding/json" "errors" "net/http" "net/url" @@ -34,6 +33,7 @@ import ( // Small wrapper that does some pre-processing before handing requests over to the Nexus SDK's HTTP handler. 
type NexusHTTPHandler struct { + base nexusrpc.BaseHTTPHandler logger log.Logger nexusHandler http.Handler enpointRegistry commonnexus.EndpointRegistry @@ -67,6 +67,10 @@ func NewNexusHTTPHandler( httpTraceProvider commonnexus.HTTPClientTraceProvider, ) *NexusHTTPHandler { return &NexusHTTPHandler{ + base: nexusrpc.BaseHTTPHandler{ + Logger: log.NewSlogLogger(logger), + FailureConverter: nexusrpc.DefaultFailureConverter(), + }, logger: logger, enpointRegistry: endpointRegistry, namespaceRegistry: namespaceRegistry, @@ -110,66 +114,50 @@ func (h *NexusHTTPHandler) RegisterRoutes(r *mux.Router) { HandlerFunc(h.dispatchNexusTaskByEndpoint) } -func (h *NexusHTTPHandler) writeNexusFailure(writer http.ResponseWriter, statusCode int, failure *nexus.Failure) { +func (h *NexusHTTPHandler) writeFailure(writer http.ResponseWriter, r *http.Request, err error) { h.preprocessErrorCounter.Record(1) - - if failure == nil { - writer.WriteHeader(statusCode) - return - } - - bytes, err := json.Marshal(failure) - if err != nil { - h.logger.Error("failed to marshal failure", tag.Error(err)) - writer.WriteHeader(http.StatusInternalServerError) - return - } - writer.Header().Set("Content-Type", "application/json") - writer.WriteHeader(statusCode) - - if _, err := writer.Write(bytes); err != nil { - h.logger.Error("failed to write response body", tag.Error(err)) - } + h.base.WriteFailure(writer, r, err) } // Handler for [nexushttp.RouteSet.DispatchNexusTaskByNamespaceAndTaskQueue]. 
func (h *NexusHTTPHandler) dispatchNexusTaskByNamespaceAndTaskQueue(w http.ResponseWriter, r *http.Request) { if !h.enabled() { - h.writeNexusFailure(w, http.StatusNotFound, &nexus.Failure{Message: "nexus endpoints disabled"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "nexus endpoints disabled")) return } var err error - nc := h.baseNexusContext(configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName) + nc := h.baseNexusContext(configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName, r.Header) params := prepareRequest(commonnexus.RouteDispatchNexusTaskByNamespaceAndTaskQueue, w, r) if nc.taskQueue, err = url.PathUnescape(params.TaskQueue); err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: "invalid URL"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL")) return } if nc.namespaceName, err = url.PathUnescape(params.Namespace); err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: "invalid URL"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL")) return } if err = h.namespaceValidationInterceptor.ValidateName(nc.namespaceName); err != nil { h.logger.Error("invalid namespace name", tag.Error(err)) - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: err.Error()}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "%v", err.Error())) return } - r, err = h.parseTlsAndAuthInfo(r, nc) + rWithAuthCtx, err := h.parseTlsAndAuthInfo(r, nc) if err != nil { h.logger.Error("failed to get claims", tag.Error(err)) - h.writeNexusFailure(w, http.StatusUnauthorized, &nexus.Failure{Message: "unauthorized"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnauthenticated, "unauthorized")) return } + r = rWithAuthCtx u, err := 
mux.CurrentRoute(r).URL("namespace", params.Namespace, "task_queue", params.TaskQueue) if err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusInternalServerError, &nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) return } @@ -179,7 +167,7 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByNamespaceAndTaskQueue(w http.Respo // Handler for [nexushttp.RouteSet.DispatchNexusTaskByEndpoint]. func (h *NexusHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r *http.Request) { if !h.enabled() { - h.writeNexusFailure(w, http.StatusNotFound, &nexus.Failure{Message: "nexus endpoints disabled"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "nexus endpoints disabled")) return } @@ -188,7 +176,7 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r endpointID, err := url.PathUnescape(endpointIDEscaped) if err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: "invalid URL"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid URL")) return } endpointEntry, err := h.enpointRegistry.GetByID(r.Context(), endpointID) @@ -207,39 +195,40 @@ func (h *NexusHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r w.Header().Set("nexus-request-retryable", "false") } } - h.writeNexusFailure(w, http.StatusNotFound, &nexus.Failure{Message: "nexus endpoint not found"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "nexus endpoint not found")) case codes.DeadlineExceeded: - h.writeNexusFailure(w, http.StatusRequestTimeout, &nexus.Failure{Message: "request timed out"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeRequestTimeout, "request timed out")) default: - h.writeNexusFailure(w, http.StatusInternalServerError, 
&nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) } return } - nc, ok := h.nexusContextFromEndpoint(endpointEntry, w) + nc, ok := h.nexusContextFromEndpoint(endpointEntry, w, r) if !ok { // nexusContextFromEndpoint already writes the failure response. return } - r, err = h.parseTlsAndAuthInfo(r, nc) + rWithAuthCtx, err := h.parseTlsAndAuthInfo(r, nc) if err != nil { h.logger.Error("failed to get claims", tag.Error(err)) - h.writeNexusFailure(w, http.StatusUnauthorized, &nexus.Failure{Message: "unauthorized"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnauthenticated, "unauthorized")) return } + r = rWithAuthCtx u, err := mux.CurrentRoute(r).URL("endpoint", endpointIDEscaped) if err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusInternalServerError, &nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) return } h.serveResolvedURL(w, r, u, nc) } -func (h *NexusHTTPHandler) baseNexusContext(apiName string) *nexusContext { +func (h *NexusHTTPHandler) baseNexusContext(apiName string, header http.Header) *nexusContext { return &nexusContext{ namespaceValidationInterceptor: h.namespaceValidationInterceptor, namespaceRateLimitInterceptor: h.namespaceRateLimitInterceptor, @@ -248,6 +237,7 @@ func (h *NexusHTTPHandler) baseNexusContext(apiName string) *nexusContext { apiName: apiName, requestStartTime: time.Now(), responseHeaders: make(map[string]string), + callerFailureSupport: header.Get("temporal-nexus-failure-support") == "true", } } @@ -255,7 +245,7 @@ func (h *NexusHTTPHandler) baseNexusContext(apiName string) *nexusContext { // endpoint is valid for dispatching. // For security reasons, at the moment only worker target endpoints are considered valid, in the future external // endpoints may also be supported. 
-func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusEndpointEntry, w http.ResponseWriter) (*nexusContext, bool) { +func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusEndpointEntry, w http.ResponseWriter, r *http.Request) (*nexusContext, bool) { switch v := entry.Endpoint.Spec.GetTarget().GetVariant().(type) { case *persistencespb.NexusEndpointTarget_Worker_: nsName, err := h.namespaceRegistry.GetNamespaceName(namespace.ID(v.Worker.GetNamespaceId())) @@ -264,20 +254,20 @@ func (h *NexusHTTPHandler) nexusContextFromEndpoint(entry *persistencespb.NexusE var notFoundErr *serviceerror.NamespaceNotFound if errors.As(err, ¬FoundErr) { w.Header().Set("nexus-request-retryable", "true") - h.writeNexusFailure(w, http.StatusNotFound, &nexus.Failure{Message: "invalid endpoint target"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeNotFound, "invalid endpoint target")) } else { - h.writeNexusFailure(w, http.StatusInternalServerError, &nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) } return nil, false } - nc := h.baseNexusContext(configs.DispatchNexusTaskByEndpointAPIName) + nc := h.baseNexusContext(configs.DispatchNexusTaskByEndpointAPIName, r.Header) nc.namespaceName = nsName.String() nc.taskQueue = v.Worker.GetTaskQueue() nc.endpointName = entry.Endpoint.Spec.Name nc.endpointID = entry.Id return nc, true default: - h.writeNexusFailure(w, http.StatusBadRequest, &nexus.Failure{Message: "invalid endpoint target"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeBadRequest, "invalid endpoint target")) return nil, false } } @@ -328,7 +318,7 @@ func (h *NexusHTTPHandler) serveResolvedURL(w http.ResponseWriter, r *http.Reque prefix, err := url.PathUnescape(u.Path) if err != nil { h.logger.Error("invalid URL", tag.Error(err)) - h.writeNexusFailure(w, http.StatusInternalServerError, 
&nexus.Failure{Message: "internal error"}) + h.writeFailure(w, r, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error")) return } prefix = path.Dir(prefix) diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index a9d5e42212..9991a0239e 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -5582,8 +5582,10 @@ func (wh *WorkflowHandler) RespondNexusTaskCompleted(ctx context.Context, reques // doesn't go into workflow history, and the Nexus request caller is unknown, there doesn't seem like there's a // good reason to fail at this point. - if details := request.GetResponse().GetStartOperation().GetOperationError().GetFailure().GetDetails(); details != nil && !json.Valid(details) { - return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") + if opErr := request.GetResponse().GetStartOperation().GetOperationError(); opErr != nil { + if details := opErr.GetFailure().GetDetails(); details != nil && !json.Valid(details) { + return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") + } } matchingRequest := &matchingservice.RespondNexusTaskCompletedRequest{ @@ -5623,8 +5625,16 @@ func (wh *WorkflowHandler) RespondNexusTaskFailed(ctx context.Context, request * } namespaceId := namespace.ID(tt.GetNamespaceId()) - if details := request.GetError().GetFailure().GetDetails(); details != nil && !json.Valid(details) { - return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") + if request.Error == nil && request.Failure == nil { // nolint:staticcheck // checking deprecated field for backwards compatibility + return nil, serviceerror.NewInvalidArgument("request must contain error or failure") + } + if request.GetError() != nil { // nolint:staticcheck // checking deprecated field for backwards compatibility + if details := request.GetError().GetFailure().GetDetails(); details != nil && 
!json.Valid(details) { // nolint:staticcheck // checking deprecated field for backwards compatibility + return nil, serviceerror.NewInvalidArgument("failure details must be JSON serializable") + } + } + if request.GetFailure() != nil && request.GetFailure().GetNexusHandlerFailureInfo() == nil { + return nil, serviceerror.NewInvalidArgument("request Failure must contain error or failure with NexusHandlerFailureInfo") } // NOTE: Not checking blob size limit here as we already enforce the 4 MB gRPC request limit and since this diff --git a/service/history/handler.go b/service/history/handler.go index 734d8a0b26..3409602ac4 100644 --- a/service/history/handler.go +++ b/service/history/handler.go @@ -36,6 +36,7 @@ import ( "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" commonnexus "go.temporal.io/server/common/nexus" + "go.temporal.io/server/common/nexus/nexusrpc" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/persistence/serialization" "go.temporal.io/server/common/persistence/visibility/manager" @@ -2157,11 +2158,22 @@ func (h *Handler) CompleteNexusOperation(ctx context.Context, request *historyse } var opErr *nexus.OperationError if request.State != string(nexus.OperationStateSucceeded) { - opErr = &nexus.OperationError{ - State: nexus.OperationState(request.GetState()), - Cause: &nexus.FailureError{ - Failure: commonnexus.ProtoFailureToNexusFailure(request.GetFailure()), - }, + failure := commonnexus.ProtoFailureToNexusFailure(request.GetFailure()) + recvdErr, err := nexusrpc.DefaultFailureConverter().FailureToError(failure) + if err != nil { + return nil, serviceerror.NewInvalidArgument("unable to convert failure to error") + } + // Backward compatibility: if the received error is not of type OperationError, wrap the error in OperationError. 
+ var ok bool + if opErr, ok = recvdErr.(*nexus.OperationError); !ok { + opErr = &nexus.OperationError{ + State: nexus.OperationState(request.GetState()), + Message: "nexus operation completed unsuccessfully", + Cause: recvdErr, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nil, serviceerror.NewInvalidArgument("unable to convert operation error to failure") + } } } err = nexusoperations.CompletionHandler( diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index dd59c02102..6492ee6c47 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -722,10 +722,10 @@ func (ms *MutableStateImpl) ChasmWorkflowComponentReadOnly(ctx context.Context) func (ms *MutableStateImpl) GetNexusCompletion( ctx context.Context, requestID string, -) (nexusrpc.OperationCompletion, error) { +) (nexusrpc.CompleteOperationOptions, error) { ce, err := ms.GetCompletionEvent(ctx) if err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } // Create the link information about the workflow to be attached to fabricated started event if completion is @@ -767,30 +767,33 @@ func (ms *MutableStateImpl) GetNexusCompletion( // Nexus does not support it. 
p = payloads[0] } - completion, err := nexusrpc.NewOperationCompletionSuccessful(p, nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) - if err != nil { - return nil, serviceerror.NewInternalf("failed to construct Nexus completion: %v", err) - } - return completion, nil + return nexusrpc.CompleteOperationOptions{ + Result: p, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_FAILED: - f, err := commonnexus.APIFailureToNexusFailure(ce.GetWorkflowExecutionFailedEventAttributes().GetFailure()) + f, err := commonnexus.TemporalFailureToNexusFailure(ce.GetWorkflowExecutionFailedEventAttributes().GetFailure()) if err != nil { - return nil, err - } - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: f}}, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) + return nexusrpc.CompleteOperationOptions{}, err + } + opErr := &nexus.OperationError{ + Message: "operation failed", + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: f}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nexusrpc.CompleteOperationOptions{}, err + } + return nexusrpc.CompleteOperationOptions{ + Error: opErr, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED: - f, err := commonnexus.APIFailureToNexusFailure(&failurepb.Failure{ + f, err := 
commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation canceled", FailureInfo: &failurepb.Failure_CanceledFailureInfo{ CanceledFailureInfo: &failurepb.CanceledFailureInfo{ @@ -799,37 +802,48 @@ func (ms *MutableStateImpl) GetNexusCompletion( }, }) if err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ - State: nexus.OperationStateCanceled, - Cause: &nexus.FailureError{Failure: f}, - }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) + opErr := &nexus.OperationError{ + State: nexus.OperationStateCanceled, + Message: "operation canceled", + Cause: &nexus.FailureError{Failure: f}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nexusrpc.CompleteOperationOptions{}, err + } + return nexusrpc.CompleteOperationOptions{ + Error: opErr, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED: - f, err := commonnexus.APIFailureToNexusFailure(&failurepb.Failure{ + f, err := commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation terminated", FailureInfo: &failurepb.Failure_TerminatedFailureInfo{ TerminatedFailureInfo: &failurepb.TerminatedFailureInfo{}, }, }) if err != nil { - return nil, err - } - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: f}}, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) + return nexusrpc.CompleteOperationOptions{}, err + } 
+ opErr := &nexus.OperationError{ + State: nexus.OperationStateFailed, + Message: "operation failed", + Cause: &nexus.FailureError{Failure: f}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nexusrpc.CompleteOperationOptions{}, err + } + return nexusrpc.CompleteOperationOptions{ + Error: opErr, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil case enumspb.EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT: - f, err := commonnexus.APIFailureToNexusFailure(&failurepb.Failure{ + f, err := commonnexus.TemporalFailureToNexusFailure(&failurepb.Failure{ Message: "operation exceeded internal timeout", FailureInfo: &failurepb.Failure_TimeoutFailureInfo{ TimeoutFailureInfo: &failurepb.TimeoutFailureInfo{ @@ -839,20 +853,24 @@ func (ms *MutableStateImpl) GetNexusCompletion( }, }) if err != nil { - return nil, err + return nexusrpc.CompleteOperationOptions{}, err } - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{ - State: nexus.OperationStateFailed, - Cause: &nexus.FailureError{Failure: f}, - }, - nexusrpc.OperationCompletionUnsuccessfulOptions{ - StartTime: ms.executionState.GetStartTime().AsTime(), - CloseTime: ce.GetEventTime().AsTime(), - Links: []nexus.Link{startLink}, - }) + opErr := &nexus.OperationError{ + State: nexus.OperationStateFailed, + Message: "operation failed", + Cause: &nexus.FailureError{Failure: f}, + } + if err := nexusrpc.MarkAsWrapperError(nexusrpc.DefaultFailureConverter(), opErr); err != nil { + return nexusrpc.CompleteOperationOptions{}, err + } + return nexusrpc.CompleteOperationOptions{ + Error: opErr, + StartTime: ms.executionState.GetStartTime().AsTime(), + CloseTime: ce.GetEventTime().AsTime(), + Links: []nexus.Link{startLink}, + }, nil } - return nil, serviceerror.NewInternalf("invalid workflow execution status: %v", ce.GetEventType()) + return 
nexusrpc.CompleteOperationOptions{}, serviceerror.NewInternalf("invalid workflow execution status: %v", ce.GetEventType()) } // GetHSMCallbackArg converts a workflow completion event into a [persistencespb.HSMCallbackArg]. diff --git a/service/history/workflow/workflow_test/mutable_state_impl_test.go b/service/history/workflow/workflow_test/mutable_state_impl_test.go index 1e06c10238..fe97c59f4b 100644 --- a/service/history/workflow/workflow_test/mutable_state_impl_test.go +++ b/service/history/workflow/workflow_test/mutable_state_impl_test.go @@ -6,7 +6,6 @@ package workflow_test import ( "context" "fmt" - "io" "math" "testing" "time" @@ -363,7 +362,7 @@ func TestGetNexusCompletion(t *testing.T) { cases := []struct { name string mutateState func(historyi.MutableState) (*historypb.HistoryEvent, error) - verifyCompletion func(*testing.T, *historypb.HistoryEvent, nexusrpc.OperationCompletion) + verifyCompletion func(*testing.T, *historypb.HistoryEvent, nexusrpc.CompleteOperationOptions) }{ { name: "success", @@ -379,15 +378,13 @@ func TestGetNexusCompletion(t *testing.T) { }, }, "") }, - verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { - success, ok := completion.(*nexusrpc.OperationCompletionSuccessful) - require.True(t, ok) - require.Equal(t, "application/json", success.Reader.Header.Get("type")) - require.Equal(t, "1", success.Reader.Header.Get("length")) - buf, err := io.ReadAll(success.Reader) - require.NoError(t, err) - require.Equal(t, []byte("3"), buf) - require.Equal(t, event.GetEventTime().AsTime(), success.CloseTime) + verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { + require.Nil(t, completion.Error) + require.Equal(t, &commonpb.Payload{ + Metadata: map[string][]byte{"encoding": []byte("json/plain")}, + Data: []byte("3"), + }, completion.Result) + require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, { @@ 
-399,12 +396,11 @@ func TestGetNexusCompletion(t *testing.T) { }, }, "") }, - verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { - failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) - require.True(t, ok) - require.Equal(t, nexus.OperationStateFailed, failure.State) - require.Equal(t, "workflow failed", failure.Failure.Message) - require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) + verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { + require.NotNil(t, completion.Error) + require.Equal(t, nexus.OperationStateFailed, completion.Error.State) + require.Equal(t, "workflow failed", completion.Error.Cause.Error()) + require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, { @@ -412,12 +408,11 @@ func TestGetNexusCompletion(t *testing.T) { mutateState: func(mutableState historyi.MutableState) (*historypb.HistoryEvent, error) { return mutableState.AddWorkflowExecutionTerminatedEvent(mutableState.GetNextEventID(), "dont care", nil, "identity", false, nil) }, - verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { - failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) - require.True(t, ok) - require.Equal(t, nexus.OperationStateFailed, failure.State) - require.Equal(t, "operation terminated", failure.Failure.Message) - require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) + verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { + require.NotNil(t, completion.Error) + require.Equal(t, nexus.OperationStateFailed, completion.Error.State) + require.Equal(t, "operation terminated", completion.Error.Cause.Error()) + require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, { @@ -425,12 +420,11 @@ func TestGetNexusCompletion(t *testing.T) { 
mutateState: func(mutableState historyi.MutableState) (*historypb.HistoryEvent, error) { return mutableState.AddWorkflowExecutionCanceledEvent(mutableState.GetNextEventID(), &commandpb.CancelWorkflowExecutionCommandAttributes{}) }, - verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.OperationCompletion) { - failure, ok := completion.(*nexusrpc.OperationCompletionUnsuccessful) - require.True(t, ok) - require.Equal(t, nexus.OperationStateCanceled, failure.State) - require.Equal(t, "operation canceled", failure.Failure.Message) - require.Equal(t, event.GetEventTime().AsTime(), failure.CloseTime) + verifyCompletion: func(t *testing.T, event *historypb.HistoryEvent, completion nexusrpc.CompleteOperationOptions) { + require.NotNil(t, completion.Error) + require.Equal(t, nexus.OperationStateCanceled, completion.Error.State) + require.Equal(t, "operation canceled", completion.Error.Cause.Error()) + require.Equal(t, event.GetEventTime().AsTime(), completion.CloseTime) }, }, } diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index f146c30d0b..b3c535b919 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -2500,8 +2500,14 @@ func (e *matchingEngineImpl) DispatchNexusTask(ctx context.Context, request *mat return nil, result.internalError } if result.failedWorkerResponse != nil { - return &matchingservice.DispatchNexusTaskResponse{Outcome: &matchingservice.DispatchNexusTaskResponse_HandlerError{ - HandlerError: result.failedWorkerResponse.GetRequest().GetError(), + if result.failedWorkerResponse.GetRequest().GetError() != nil { // nolint:staticcheck // checking deprecated field for backwards compatibility + // Deprecated case. Kept for backwards-compatibility with older SDKs that are sending errors instead of failures. 
+ return &matchingservice.DispatchNexusTaskResponse{Outcome: &matchingservice.DispatchNexusTaskResponse_HandlerError{ + HandlerError: result.failedWorkerResponse.GetRequest().GetError(), // nolint:staticcheck // checking deprecated field for backwards compatibility + }}, nil + } + return &matchingservice.DispatchNexusTaskResponse{Outcome: &matchingservice.DispatchNexusTaskResponse_Failure{ + Failure: result.failedWorkerResponse.GetRequest().GetFailure(), }}, nil } diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index ac9d4a2309..5247903a7f 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -364,7 +364,7 @@ func (s *CallbacksSuite) TestWorkflowNexusCallbacks_CarriedOver() { var err error if attempt < numAttempts { // force retry - err = nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "intentional error") + err = nexus.NewHandlerErrorf(nexus.HandlerErrorTypeInternal, "intentional error") } ch.requestCompleteCh <- err } diff --git a/tests/nexus_api_test.go b/tests/nexus_api_test.go index 5bca780c1e..566ac2ea04 100644 --- a/tests/nexus_api_test.go +++ b/tests/nexus_api_test.go @@ -203,6 +203,8 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, nexus.HandlerErrorRetryBehaviorUnspecified, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) + require.Empty(t, handlerErr.Message) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, }, @@ -223,6 +225,8 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, nexus.HandlerErrorRetryBehaviorNonRetryable, handlerErr.RetryBehavior) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) + require.Empty(t, handlerErr.Message) + require.Error(t, handlerErr.Cause) require.Equal(t, 
"deliberate internal failure", handlerErr.Cause.Error()) }, }, @@ -248,7 +252,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Outcomes() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUpstreamTimeout, handlerErr.Type) - require.Equal(t, "upstream timeout", handlerErr.Cause.Error()) + require.Equal(t, "upstream timeout", handlerErr.Message) }, }, } @@ -343,7 +347,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_Na var handlerError *nexus.HandlerError s.ErrorAs(err, &handlerError) s.Equal(nexus.HandlerErrorTypeNotFound, handlerError.Type) - s.Equal(fmt.Sprintf("namespace not found: %q", namespace), handlerError.Cause.Error()) + s.Equal(fmt.Sprintf("namespace not found: %q", namespace), handlerError.Message) snap := capture.Snapshot() @@ -371,7 +375,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_WithNamespaceAndTaskQueue_Na s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) // I wish we'd never put periods in error messages :( - s.Equal("Namespace length exceeds limit.", handlerErr.Cause.Error()) + s.Equal("Namespace length exceeds limit.", handlerErr.Message) snap := capture.Snapshot() @@ -406,7 +410,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { }, checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied: unauthorized in test", handlerErr.Cause.Error()) + require.Equal(t, "permission denied: unauthorized in test", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, @@ -427,7 +431,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { }, checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", 
handlerErr.Cause.Error()) + require.Equal(t, "permission denied", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, @@ -448,7 +452,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { }, checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", handlerErr.Cause.Error()) + require.Equal(t, "permission denied", handlerErr.Message) }, expectedOutcomeMetric: "unauthorized", exposeAuthorizerErrors: false, @@ -457,19 +461,19 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Forbidden() { name: "deny with exposed error", onAuthorize: func(ctx context.Context, c *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { if ct.APIName == configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName { - return authorization.Result{}, nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") + return authorization.Result{}, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") } if ct.APIName == configs.DispatchNexusTaskByEndpointAPIName { if ct.NexusEndpointName != testEndpoint.Spec.Name { panic("expected nexus endpoint name") } - return authorization.Result{}, nexus.HandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") + return authorization.Result{}, nexus.NewHandlerErrorf(nexus.HandlerErrorTypeUnavailable, "exposed error") } return authorization.Result{Decision: authorization.DecisionAllow}, nil }, checkFailure: func(t *testing.T, handlerErr *nexus.HandlerError) { require.Equal(t, nexus.HandlerErrorTypeUnavailable, handlerErr.Type) - require.Equal(t, "exposed error", handlerErr.Cause.Error()) + require.Equal(t, "exposed error", handlerErr.Message) }, expectedOutcomeMetric: "internal_auth_error", exposeAuthorizerErrors: true, @@ -536,7 +540,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { var handlerErr 
*nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) - require.Equal(t, "permission denied", handlerErr.Cause.Error()) + require.Equal(t, "permission denied", handlerErr.Message) require.Equal(t, 0, len(snap["nexus_request_preprocess_errors"])) }, }, @@ -549,7 +553,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnauthenticated, handlerErr.Type) - require.Equal(t, "unauthorized", handlerErr.Cause.Error()) + require.Equal(t, "unauthorized", handlerErr.Message) require.Equal(t, 1, len(snap["nexus_request_preprocess_errors"])) }, }, @@ -598,21 +602,12 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_Claims() { go s.nexusTaskPoller(ctx, taskQueue, tc.handler) } - var result *nexusrpc.ClientStartOperationResponse[string] - var snap map[string][]*metricstest.CapturedRecording - - // Wait until the endpoint is loaded into the registry. 
- s.Eventually(func() bool { - capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() - defer s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - - result, err = nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{ - Header: tc.header, - }) - snap = capture.Snapshot() - var handlerErr *nexus.HandlerError - return err == nil || !(errors.As(err, &handlerErr) && handlerErr.Type == nexus.HandlerErrorTypeNotFound) - }, 10*time.Second, 1*time.Second) + capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() + result, err := nexusrpc.StartOperation(ctx, client, op, "input", nexus.StartOperationOptions{ + Header: tc.header, + }) + snap := capture.Snapshot() + s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) tc.assertion(t, result, err, snap) } @@ -662,7 +657,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_PayloadSizeLimit() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeBadRequest, handlerErr.Type) - require.Equal(t, "input exceeds size limit", handlerErr.Cause.Error()) + require.Equal(t, "input exceeds size limit", handlerErr.Message) } s.T().Run("ByNamespaceAndTaskQueue", func(t *testing.T) { @@ -722,6 +717,8 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) require.Equal(t, "worker", headers.Get("Temporal-Nexus-Failure-Source")) + require.Empty(t, handlerErr.Message) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) }, }, @@ -741,7 +738,7 @@ func (s *NexusApiTestSuite) TestNexusCancelOperation_Outcomes() { var handlerErr *nexus.HandlerError require.ErrorAs(t, err, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUpstreamTimeout, handlerErr.Type) - require.Equal(t, "upstream timeout", 
handlerErr.Cause.Error()) + require.Equal(t, "upstream timeout", handlerErr.Message) }, }, } @@ -978,7 +975,7 @@ func (s *NexusApiTestSuite) TestNexusStartOperation_ByEndpoint_EndpointNotFound( var handlerErr *nexus.HandlerError s.ErrorAs(err, &handlerErr) s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) - s.Equal("nexus endpoint not found", handlerErr.Cause.Error()) + s.Equal("nexus endpoint not found", handlerErr.Message) snap := capture.Snapshot() s.Equal(1, len(snap["nexus_request_preprocess_errors"])) } diff --git a/tests/nexus_workflow_test.go b/tests/nexus_workflow_test.go index 0f76a45d7d..b1958bd90a 100644 --- a/tests/nexus_workflow_test.go +++ b/tests/nexus_workflow_test.go @@ -1,13 +1,10 @@ package tests import ( - "bytes" "context" "encoding/json" "errors" "fmt" - "io" - "net/http" "slices" "strings" "testing" @@ -656,23 +653,20 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { protorequire.ProtoEqual(s.T(), handlerLink, l) // Completion request fails if the result payload is too large. - largeCompletion, err := nexusrpc.NewOperationCompletionSuccessful( + largeCompletion := nexusrpc.CompleteOperationOptions{ // Use -10 to avoid hitting MaxNexusAPIRequestBodyBytes. Actual payload will still exceed limit because of // additional Content headers. 
See common/rpc/grpc.go:66 - s.mustToPayload(strings.Repeat("a", (2*1024*1024)-10)), - nexusrpc.OperationCompletionSuccessfulOptions{Serializer: commonnexus.PayloadSerializer}, - ) + Result: s.mustToPayload(strings.Repeat("a", (2*1024*1024)-10)), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } s.NoError(err) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, largeCompletion, callbackToken) - s.Equal(http.StatusBadRequest, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, largeCompletion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload(nil), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - invalidNamespace := testcore.RandomizeStr("ns") _, err = s.FrontendClient().RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{ Namespace: invalidNamespace, @@ -683,11 +677,15 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { // Send an invalid completion request and verify that we get an error that the namespace in the URL doesn't match the namespace in the token. 
invalidCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(invalidNamespace) - res, _, body := s.sendNexusCompletionRequest(ctx, s.T(), invalidCallbackURL, completion, callbackToken) - + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + _, err = s.sendNexusCompletionRequest(ctx, invalidCallbackURL, completion) // Verify we get the correct error response - s.Equal(http.StatusBadRequest, res.StatusCode) - s.Contains(body, "invalid callback token", "Response should indicate namespace mismatch") + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + s.Contains(handlerErr.Error(), "invalid callback token", "Response should indicate namespace mismatch") // Manipulate the token to verify we get the expected errors in the API. gen := &commonnexus.CallbackTokenGenerator{} @@ -701,9 +699,11 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { workflowNotFoundToken.WorkflowId = "not-found" callbackToken, err = gen.Tokenize(workflowNotFoundToken) s.NoError(err) + completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - res, snap, _ = s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) @@ -712,23 +712,20 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { staleToken.Ref.MachineInitialVersionedTransition.NamespaceFailoverVersion++ callbackToken, err = gen.Tokenize(staleToken) 
s.NoError(err) + completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - res, snap, _ = s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) - // Send a valid - successful completion request. - completion, err = nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - callbackToken, err = gen.Tokenize(completionToken) s.NoError(err) + completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} - res, snap, _ = s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusOK, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + s.NoError(err) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "success"}) // Ensure that CompleteOperation request is tracked as part of normal service telemetry metrics @@ -739,8 +736,9 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletion() { s.Greater(idx, -1) // Resend the request and verify we get a not found error since the operation has already completed. 
- res, snap, _ = s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_not_found"}) @@ -987,7 +985,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() s.NoError(err) expectedLinks := []*commonpb.Link_WorkflowEvent{ - &commonpb.Link_WorkflowEvent{ + { Namespace: s.Namespace().String(), WorkflowId: completionWFID, RunId: completionWfRunIDs[0], @@ -998,7 +996,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionBeforeStart() }, }, }, - &commonpb.Link_WorkflowEvent{ + { Namespace: s.Namespace().String(), WorkflowId: completionWFID, RunId: completionWfRunIDs[1], @@ -1169,10 +1167,12 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { s.Greater(startedEventIdx, 0) // Send a valid - failed completion request. 
- completion, err := nexusrpc.NewOperationCompletionUnsuccessful(nexus.NewOperationFailedError("test operation failed"), nexusrpc.OperationCompletionUnsuccessfulOptions{}) + completion := nexusrpc.CompleteOperationOptions{ + Error: nexus.NewOperationFailedErrorf("test operation failed"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) s.NoError(err) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusOK, res.StatusCode) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "success"}) @@ -1209,32 +1209,39 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncFailure() { var result string err = run.Get(ctx, &result) var wee *temporal.WorkflowExecutionError - s.ErrorAs(err, &wee) - s.True(strings.HasPrefix(wee.Unwrap().Error(), "nexus operation completed unsuccessfully")) + + var noe *temporal.NexusOperationError + s.ErrorAs(wee, &noe) + s.Contains(noe.Error(), "test operation failed") } func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { ctx := testcore.NewContext() - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - s.Run("ConfigDisabled", func() { s.OverrideDynamicConfig(dynamicconfig.EnableNexus, false) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + } publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, "") - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err := 
s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) s.Run("ConfigDisabledNoIdentifier", func() { s.OverrideDynamicConfig(dynamicconfig.EnableNexus, false) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + } publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, "") - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) @@ -1244,8 +1251,14 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.NoError(err) publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path("namespace-doesnt-exist") - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, tokenWithBadNamespace) - s.Equal(http.StatusNotFound, res.StatusCode) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: tokenWithBadNamespace}, + } + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) @@ -1255,26 +1268,34 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.NoError(err) publicCallbackURL := "http://" + s.HttpAPIAddress() + 
commonnexus.PathCompletionCallbackNoIdentifier - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, tokenWithBadNamespace) - s.Equal(http.StatusNotFound, res.StatusCode) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: tokenWithBadNamespace}, + } + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_request_preprocess_errors"])) }) s.Run("OperationTokenTooLong", func() { publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - OperationToken: strings.Repeat("long", 2000), - }) - s.NoError(err) // Generate a valid callback token to get past initial validation namespaceID := s.GetNamespaceID(s.Namespace().String()) validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + OperationToken: strings.Repeat("long", 2000), + Header: nexus.Header{commonnexus.CallbackTokenHeader: validToken}, + } - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, validToken) - s.Equal(http.StatusBadRequest, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(0, len(snap["nexus_completion_request_preprocess_errors"])) s.Equal(1, 
len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) @@ -1282,44 +1303,54 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { s.Run("OperationTokenTooLongNoIdentifier", func() { publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - OperationToken: strings.Repeat("long", 2000), - }) - s.NoError(err) - // Generate a valid callback token to get past initial validation namespaceID := s.GetNamespaceID(s.Namespace().String()) validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, validToken) - s.Equal(http.StatusBadRequest, res.StatusCode) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + OperationToken: strings.Repeat("long", 2000), + Header: nexus.Header{commonnexus.CallbackTokenHeader: validToken}, + } + + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(0, len(snap["nexus_completion_request_preprocess_errors"])) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "error_bad_request"}) }) s.Run("InvalidCallbackToken", func() { + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + } publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + 
commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) // metrics collection is not initialized before callback validation // Send request without callback token, helper does not add token if blank - res, _, body := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, "") - + _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) // Verify we get the correct error response - s.Equal(http.StatusBadRequest, res.StatusCode) - s.Contains(string(body), "invalid callback token", "Response should indicate invalid callback token") + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + s.Contains(handlerErr.Error(), "invalid callback token", "Response should indicate invalid callback token") }) s.Run("InvalidCallbackTokenNoIdentifier", func() { + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + } publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier // metrics collection is not initialized before callback validation // Send request without callback token, helper does not add token if blank - res, _, body := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, "") - + _, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) // Verify we get the correct error response - s.Equal(http.StatusBadRequest, res.StatusCode) - s.Contains(string(body), "invalid callback token", "Response should indicate invalid callback token") + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + s.Contains(handlerErr.Error(), "invalid callback token", "Response should indicate invalid callback token") }) s.Run("InvalidClientVersion", func() { @@ -1332,22 +1363,21 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { validToken, err := 
s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, publicCallbackURL, completion) - s.NoError(err) - req.Header.Set("User-Agent", "Nexus-go-sdk/v99.0.0") - req.Header.Add(commonnexus.CallbackTokenHeader, validToken) - - res, err := http.DefaultClient.Do(req) - s.NoError(err) - _, err = io.ReadAll(res.Body) - s.NoError(err) - defer func() { - err := res.Body.Close() - s.NoError(err) - }() - + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + Header: nexus.Header{ + commonnexus.CallbackTokenHeader: validToken, + "user-agent": "Nexus-go-sdk/v99.0.0", + }, + } + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: commonnexus.PayloadSerializer, + }) + err = client.CompleteOperation(ctx, publicCallbackURL, completion) snap := capture.Snapshot() - s.Equal(http.StatusBadRequest, res.StatusCode) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unsupported_client"}) }) @@ -1362,19 +1392,22 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionErrors() { validToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, publicCallbackURL, completion) - s.NoError(err) - req.Header.Set("User-Agent", "Nexus-go-sdk/v99.0.0") - req.Header.Add(commonnexus.CallbackTokenHeader, validToken) - - res, err := http.DefaultClient.Do(req) - s.NoError(err) - _, err = io.ReadAll(res.Body) - s.NoError(err) - defer res.Body.Close() + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + Header: 
nexus.Header{ + commonnexus.CallbackTokenHeader: validToken, + "user-agent": "Nexus-go-sdk/v99.0.0", + }, + } + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: commonnexus.PayloadSerializer, + }) + err = client.CompleteOperation(ctx, publicCallbackURL, completion) snap := capture.Snapshot() - s.Equal(http.StatusBadRequest, res.StatusCode) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unsupported_client"}) }) @@ -1392,19 +1425,21 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrors() { s.GetTestCluster().Host().SetOnAuthorize(onAuthorize) defer s.GetTestCluster().Host().SetOnAuthorize(nil) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - // Generate a valid callback token for testing namespaceID := s.GetNamespaceID(s.Namespace().String()) callbackToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + publicCallbackURL := "http://" + s.HttpAPIAddress() + "/" + commonnexus.RouteCompletionCallback.Path(s.Namespace().String()) - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusForbidden, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeUnauthorized, 
handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unauthorized"}) } @@ -1421,19 +1456,20 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAuthErrorsNoId s.GetTestCluster().Host().SetOnAuthorize(onAuthorize) defer s.GetTestCluster().Host().SetOnAuthorize(nil) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) - s.NoError(err) - // Generate a valid callback token for testing namespaceID := s.GetNamespaceID(s.Namespace().String()) callbackToken, err := s.generateValidCallbackToken(namespaceID, testcore.RandomizeStr("workflow"), uuid.NewString()) s.NoError(err) + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } publicCallbackURL := "http://" + s.HttpAPIAddress() + commonnexus.PathCompletionCallbackNoIdentifier - res, snap, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackURL, completion, callbackToken) - s.Equal(http.StatusForbidden, res.StatusCode) + snap, err := s.sendNexusCompletionRequest(ctx, publicCallbackURL, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeUnauthorized, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": s.Namespace().String(), "outcome": "unauthorized"}) } @@ -1873,14 +1909,13 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationAsyncCompletionAfterReset() { } } s.True(seenStartedEvent) - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: commonnexus.PayloadSerializer, - }) + 
completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload("result"), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + _, err = s.sendNexusCompletionRequest(ctx, publicCallbackUrl, completion) s.NoError(err) - res, _, _ := s.sendNexusCompletionRequest(ctx, s.T(), publicCallbackUrl, completion, callbackToken) - s.Equal(http.StatusOK, res.StatusCode) - // Poll again and verify the completion is recorded and triggers workflow progress. pollResp, err = s.FrontendClient().PollWorkflowTaskQueue(ctx, &workflowservice.PollWorkflowTaskQueueRequest{ Namespace: s.Namespace().String(), @@ -2131,8 +2166,9 @@ func (s *NexusWorkflowTestSuite) TestNexusSyncOperationErrorRehydration() { checkWorkflowError: func(t *testing.T, wfErr error) { var opErr *temporal.NexusOperationError require.ErrorAs(t, wfErr, &opErr) + require.Equal(t, "nexus operation completed unsuccessfully", opErr.Message) var appErr *temporal.ApplicationError - require.ErrorAs(t, opErr, &appErr) + require.ErrorAs(t, opErr.Cause, &appErr) require.Equal(t, "some error", appErr.Message()) }, }, @@ -2501,6 +2537,7 @@ func (s *NexusWorkflowTestSuite) TestNexusOperationSyncNexusFailure() { var handlerErr *nexus.HandlerError s.ErrorAs(wfErr, &handlerErr) s.Equal(nexus.HandlerErrorTypeBadRequest, handlerErr.Type) + // Old SDK path var appErr *temporal.ApplicationError s.ErrorAs(handlerErr.Cause, &appErr) s.Equal(appErr.Message(), "fail me") @@ -3058,25 +3095,15 @@ func (s *NexusWorkflowTestSuite) generateValidCallbackToken(namespaceID, workflo func (s *NexusWorkflowTestSuite) sendNexusCompletionRequest( ctx context.Context, - t *testing.T, url string, - completion nexusrpc.OperationCompletion, - callbackToken string, -) (*http.Response, map[string][]*metricstest.CapturedRecording, string) { + completion nexusrpc.CompleteOperationOptions, +) (map[string][]*metricstest.CapturedRecording, error) { capture := s.GetTestCluster().Host().CaptureMetricsHandler().StartCapture() defer 
s.GetTestCluster().Host().CaptureMetricsHandler().StopCapture(capture) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, url, completion) - require.NoError(t, err) - if callbackToken != "" { - req.Header.Add(commonnexus.CallbackTokenHeader, callbackToken) - } - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) - responseBody := res.Body - body, err := io.ReadAll(responseBody) - require.NoError(t, err) - require.NoError(t, res.Body.Close()) - res.Body = io.NopCloser(bytes.NewReader(body)) - return res, capture.Snapshot(), string(body) + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: commonnexus.PayloadSerializer, + }) + err := c.CompleteOperation(ctx, url, completion) + return capture.Snapshot(), err } diff --git a/tests/xdc/nexus_request_forwarding_test.go b/tests/xdc/nexus_request_forwarding_test.go index de8670a14c..610c4c62c6 100644 --- a/tests/xdc/nexus_request_forwarding_test.go +++ b/tests/xdc/nexus_request_forwarding_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "slices" "testing" @@ -191,6 +190,8 @@ func (s *NexusRequestForwardingSuite) TestStartOperationForwardedFromStandbyToAc var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) + require.Empty(t, handlerErr.Message) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "StartNexusOperation", "handler_error:INTERNAL") requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartNexusOperation", "forwarded_request_error") @@ -211,7 +212,7 @@ func (s *NexusRequestForwardingSuite) TestStartOperationForwardedFromStandbyToAc var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnavailable, handlerErr.Type) - require.Equal(t, "cluster inactive", 
handlerErr.Cause.Error()) + require.Equal(t, "cluster inactive", handlerErr.Message) requireExpectedMetricsCaptured(t, passiveSnap, ns, "StartNexusOperation", "namespace_inactive_forwarding_disabled") }, }, @@ -315,6 +316,8 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeInternal, handlerErr.Type) + require.Empty(t, handlerErr.Message) + require.Error(t, handlerErr.Cause) require.Equal(t, "deliberate internal failure", handlerErr.Cause.Error()) requireExpectedMetricsCaptured(t, activeSnap, ns, "CancelNexusOperation", "handler_error:INTERNAL") requireExpectedMetricsCaptured(t, passiveSnap, ns, "CancelNexusOperation", "forwarded_request_error") @@ -335,7 +338,7 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA var handlerErr *nexus.HandlerError require.ErrorAs(t, retErr, &handlerErr) require.Equal(t, nexus.HandlerErrorTypeUnavailable, handlerErr.Type) - require.Equal(t, "cluster inactive", handlerErr.Cause.Error()) + require.Equal(t, "cluster inactive", handlerErr.Message) requireExpectedMetricsCaptured(t, passiveSnap, ns, "CancelNexusOperation", "namespace_inactive_forwarding_disabled") }, }, @@ -378,16 +381,14 @@ func (s *NexusRequestForwardingSuite) TestCancelOperationForwardedFromStandbyToA func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandbyToActive() { testCases := []struct { name string - getCompletionFn func() (nexusrpc.OperationCompletion, error) + getCompletionFn func() nexusrpc.CompleteOperationOptions assertHistoryAndGetCompleteWF func(*testing.T, []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest assertResult func(*testing.T, string) }{ { name: "success", - getCompletionFn: func() (nexusrpc.OperationCompletion, error) { - return nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload("result"), 
nexusrpc.OperationCompletionSuccessfulOptions{ - Serializer: cnexus.PayloadSerializer, - }) + getCompletionFn: func() nexusrpc.CompleteOperationOptions { + return nexusrpc.CompleteOperationOptions{Result: s.mustToPayload("result")} }, assertHistoryAndGetCompleteWF: func(t *testing.T, events []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest { completedEventIdx := slices.IndexFunc(events, func(e *historypb.HistoryEvent) bool { @@ -411,11 +412,14 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }, { name: "operation error", - getCompletionFn: func() (nexusrpc.OperationCompletion, error) { - f := nexus.Failure{Message: "intentional operation failure"} - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{State: nexus.OperationStateFailed, Cause: &nexus.FailureError{Failure: f}}, - nexusrpc.OperationCompletionUnsuccessfulOptions{}) + getCompletionFn: func() nexusrpc.CompleteOperationOptions { + opErr := &nexus.OperationError{ + State: nexus.OperationStateFailed, + Cause: &nexus.FailureError{Failure: nexus.Failure{Message: "intentional operation failure"}}, + } + return nexusrpc.CompleteOperationOptions{ + Error: opErr, + } }, assertHistoryAndGetCompleteWF: func(t *testing.T, events []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest { failedEventIdx := slices.IndexFunc(events, func(e *historypb.HistoryEvent) bool { @@ -439,11 +443,11 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }, { name: "canceled", - getCompletionFn: func() (nexusrpc.OperationCompletion, error) { + getCompletionFn: func() nexusrpc.CompleteOperationOptions { f := nexus.Failure{Message: "operation canceled"} - return nexusrpc.NewOperationCompletionUnsuccessful( - &nexus.OperationError{State: nexus.OperationStateCanceled, Cause: &nexus.FailureError{Failure: f}}, - nexusrpc.OperationCompletionUnsuccessfulOptions{}) + return 
nexusrpc.CompleteOperationOptions{ + Error: &nexus.OperationError{State: nexus.OperationStateCanceled, Cause: &nexus.FailureError{Failure: f}}, + } }, assertHistoryAndGetCompleteWF: func(t *testing.T, events []*historypb.HistoryEvent) *workflowservice.RespondWorkflowTaskCompletedRequest { canceledEventIdx := slices.IndexFunc(events, func(e *historypb.HistoryEvent) bool { @@ -592,10 +596,11 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb return err == nil && len(resp.PendingNexusOperations) > 0 }, 5*time.Second, 500*time.Millisecond) - completion, err := tc.getCompletionFn() + completion := tc.getCompletionFn() + completion.Header = make(nexus.Header, 1) + completion.Header.Set(cnexus.CallbackTokenHeader, callbackToken) + snap, err := s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[1], publicCallbackUrl, completion) s.NoError(err) - res, snap := s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[1], publicCallbackUrl, completion, callbackToken) - s.Equal(http.StatusOK, res.StatusCode) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": ns, "outcome": "request_forwarded"}) @@ -610,8 +615,10 @@ func (s *NexusRequestForwardingSuite) TestOperationCompletionForwardedFromStandb }) // Resend the request and verify we get a not found error since the operation has already completed. 
- res, snap = s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[0], publicCallbackUrl, completion, callbackToken) - s.Equal(http.StatusNotFound, res.StatusCode) + snap, err = s.sendNexusCompletionRequest(ctx, s.T(), s.clusters[0], publicCallbackUrl, completion) + var handlerErr *nexus.HandlerError + s.ErrorAs(err, &handlerErr) + s.Equal(nexus.HandlerErrorTypeNotFound, handlerErr.Type) s.Equal(1, len(snap["nexus_completion_requests"])) s.Subset(snap["nexus_completion_requests"][0].Tags, map[string]string{"namespace": ns, "outcome": "error_not_found"}) @@ -678,26 +685,19 @@ func (s *NexusRequestForwardingSuite) sendNexusCompletionRequest( t *testing.T, testCluster *testcore.TestCluster, url string, - completion nexusrpc.OperationCompletion, - callbackToken string, -) (*http.Response, map[string][]*metricstest.CapturedRecording) { + completion nexusrpc.CompleteOperationOptions, +) (map[string][]*metricstest.CapturedRecording, error) { metricsHandler, ok := testCluster.Host().GetMetricsHandler().(*metricstest.CaptureHandler) s.True(ok) capture := metricsHandler.StartCapture() defer metricsHandler.StopCapture(capture) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, url, completion) - require.NoError(t, err) - if callbackToken != "" { - req.Header.Add(cnexus.CallbackTokenHeader, callbackToken) - } + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: cnexus.PayloadSerializer, + }) + err := c.CompleteOperation(ctx, url, completion) - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) - _, err = io.ReadAll(res.Body) - require.NoError(t, err) - defer res.Body.Close() - return res, capture.Snapshot() + return capture.Snapshot(), err } func requireExpectedMetricsCaptured(t *testing.T, snap map[string][]*metricstest.CapturedRecording, ns string, method string, expectedOutcome string) { diff --git a/tests/xdc/nexus_state_replication_test.go b/tests/xdc/nexus_state_replication_test.go index 4d84b5cc7c..aa52046bed 100644 
--- a/tests/xdc/nexus_state_replication_test.go +++ b/tests/xdc/nexus_state_replication_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "net/http" "net/http/httptest" "slices" @@ -715,40 +714,27 @@ func (s *NexusStateReplicationSuite) waitCallback( } func (s *NexusStateReplicationSuite) completeNexusOperation(ctx context.Context, result any, callbackUrl, callbackToken string) { - completion, err := nexusrpc.NewOperationCompletionSuccessful(s.mustToPayload(result), nexusrpc.OperationCompletionSuccessfulOptions{ + completion := nexusrpc.CompleteOperationOptions{ + Result: s.mustToPayload(result), + Header: nexus.Header{commonnexus.CallbackTokenHeader: callbackToken}, + } + client := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ Serializer: commonnexus.PayloadSerializer, }) + err := client.CompleteOperation(ctx, callbackUrl, completion) s.NoError(err) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackUrl, completion) - s.NoError(err) - if callbackToken != "" { - req.Header.Add(commonnexus.CallbackTokenHeader, callbackToken) - } - - res, err := http.DefaultClient.Do(req) - s.NoError(err) - defer res.Body.Close() - _, err = io.ReadAll(res.Body) - s.NoError(err) - s.Equal(http.StatusOK, res.StatusCode) } func (s *NexusStateReplicationSuite) cancelNexusOperation(ctx context.Context, callbackUrl, callbackToken string) { - completion, err := nexusrpc.NewOperationCompletionUnsuccessful( - nexus.NewOperationCanceledError("operation canceled"), - nexusrpc.OperationCompletionUnsuccessfulOptions{}, - ) - s.NoError(err) - req, err := nexusrpc.NewCompletionHTTPRequest(ctx, callbackUrl, completion) - s.NoError(err) + completion := nexusrpc.CompleteOperationOptions{ + Error: nexus.NewOperationCanceledErrorf("operation canceled"), + } if callbackToken != "" { - req.Header.Add(commonnexus.CallbackTokenHeader, callbackToken) + completion.Header = nexus.Header{commonnexus.CallbackTokenHeader: callbackToken} } - - res, err := 
http.DefaultClient.Do(req) - s.NoError(err) - defer res.Body.Close() - _, err = io.ReadAll(res.Body) + c := nexusrpc.NewCompletionHTTPClient(nexusrpc.CompletionHTTPClientOptions{ + Serializer: commonnexus.PayloadSerializer, + }) + err := c.CompleteOperation(ctx, callbackUrl, completion) s.NoError(err) - s.Equal(http.StatusOK, res.StatusCode) }