From 71c8912c47f4bf54f85a0d74337c7cfa35926d6c Mon Sep 17 00:00:00 2001 From: Quinn Klassen Date: Thu, 15 Jan 2026 18:29:35 -0800 Subject: [PATCH] Add utility to create Nexus endpoint --- temporalio/testing/_workflow.py | 44 +++++++++++++++++++++++++++++ tests/nexus/test_workflow_caller.py | 37 ++++++++++++++++-------- 2 files changed, 69 insertions(+), 12 deletions(-) diff --git a/temporalio/testing/_workflow.py b/temporalio/testing/_workflow.py index 62da91b6e..aff127431 100644 --- a/temporalio/testing/_workflow.py +++ b/temporalio/testing/_workflow.py @@ -15,6 +15,8 @@ import google.protobuf.empty_pb2 from typing_extensions import Self +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice.v1 import temporalio.api.testservice.v1 import temporalio.bridge.testing import temporalio.client @@ -396,6 +398,48 @@ def supports_time_skipping(self) -> bool: """Whether this environment supports time skipping.""" return False + async def create_nexus_endpoint( + self, endpoint_name: str, task_queue: str + ) -> temporalio.api.nexus.v1.Endpoint: + """Create a Nexus endpoint with the given name and task queue. + + Args: + endpoint_name: The name of the Nexus endpoint to create. + task_queue: The task queue to associate with the endpoint. + + Returns: + The created Nexus endpoint. + """ + response = await self._client.operator_service.create_nexus_endpoint( + temporalio.api.operatorservice.v1.CreateNexusEndpointRequest( + spec=temporalio.api.nexus.v1.EndpointSpec( + name=endpoint_name, + target=temporalio.api.nexus.v1.EndpointTarget( + worker=temporalio.api.nexus.v1.EndpointTarget.Worker( + namespace=self._client.namespace, + task_queue=task_queue, + ) + ), + ) + ) + ) + return response.endpoint + + async def delete_nexus_endpoint( + self, endpoint: temporalio.api.nexus.v1.Endpoint + ) -> None: + """Delete a Nexus endpoint. + + Args: + endpoint: The Nexus endpoint to delete. + """ + await self._client.operator_service.delete_nexus_endpoint( + temporalio.api.operatorservice.v1.DeleteNexusEndpointRequest( + id=endpoint.id, + version=endpoint.version, + ) + ) + @contextmanager def auto_time_skipping_disabled(self) -> Iterator[None]: """Disable any automatic time skipping if this is a time-skipping diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 07c22e688..ef3a396ea 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -52,7 +52,7 @@ from temporalio.worker import Worker from tests.helpers import find_free_port, new_worker from tests.helpers.metrics import PromMetricMatcher -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name # TODO(nexus-prerelease): test availability of Temporal client etc in async context set by worker # TODO(nexus-preview): test worker shutdown, wait_all_completed, drain etc @@ -434,7 +434,8 @@ async def test_sync_operation_happy_path(client: Client, env: WorkflowEnvironmen task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_output = await client.execute_workflow( CallerWorkflow.run, args=[ @@ -471,7 +472,8 @@ async def test_workflow_run_operation_happy_path( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_output = await client.execute_workflow( CallerWorkflow.run, args=[ @@ -527,7 +529,8 @@ async def test_sync_response( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) caller_wf_handle = await client.start_workflow( CallerWorkflow.run, args=[ @@ -603,6 +606,7 @@ async def test_async_response( workflow_failure_exception_types=[Exception], ): caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op( + env, client, task_queue, exception_in_operation_start, @@ -676,6 +680,7 @@ async def test_async_response( async def _start_wf_and_nexus_op( + env: WorkflowEnvironment, client: Client, task_queue: str, exception_in_operation_start: bool, @@ -689,7 +694,8 @@ async def _start_wf_and_nexus_op( """ Start the caller workflow and wait until the Nexus operation has started. """ - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) operation_workflow_id = str(uuid.uuid4()) # Start the caller workflow and wait until it confirms the Nexus operation has started. @@ -774,7 +780,8 @@ async def test_untyped_caller( op_definition_type=op_definition_type, exception_in_operation_start=exception_in_operation_start, ) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) caller_wf_handle = await client.start_workflow( UntypedCallerWorkflow.run, args=[ @@ -941,7 +948,8 @@ async def test_service_interface_and_implementation_names( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) assert await client.execute_workflow( ServiceInterfaceAndImplCallerWorkflow.run, args=(CallerReference.INTERFACE, NameOverride.YES, task_queue), @@ -1057,7 +1065,8 @@ async def test_workflow_run_operation_can_execute_workflow_before_starting_backi ], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) result = await client.execute_workflow( WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow.run, args=("result-1", task_queue), @@ -1109,7 +1118,8 @@ async def test_nexus_operation_summary( ], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_id = f"wf-{uuid.uuid4()}" handle = await client.start_workflow( ExecuteNexusOperationWithSummaryWorkflow.run, @@ -1405,7 +1415,8 @@ async def test_workflow_run_operation_overloads( ], nexus_service_handlers=[OverloadTestServiceHandler()], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) res = await client.execute_workflow( OverloadTestCallerWorkflow.run, args=[op, OverloadTestValue(value=2)], @@ -1465,7 +1476,8 @@ async def test_workflow_caller_custom_metrics(client: Client, env: WorkflowEnvir pytest.skip("Nexus tests don't work with time-skipping server") task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) # Create new runtime with Prom server prom_addr = f"127.0.0.1:{find_free_port()}" @@ -1558,7 +1570,8 @@ async def test_workflow_caller_buffered_metrics( runtime=runtime, ) task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) async with new_worker( client, CustomMetricsWorkflow,