Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions temporalio/testing/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import google.protobuf.empty_pb2
from typing_extensions import Self

import temporalio.api.nexus.v1
import temporalio.api.operatorservice.v1
import temporalio.api.testservice.v1
import temporalio.bridge.testing
import temporalio.client
Expand Down Expand Up @@ -396,6 +398,48 @@ def supports_time_skipping(self) -> bool:
"""Whether this environment supports time skipping."""
return False

async def create_nexus_endpoint(
self, endpoint_name: str, task_queue: str
) -> temporalio.api.nexus.v1.Endpoint:
"""Create a Nexus endpoint with the given name and task queue.

Args:
endpoint_name: The name of the Nexus endpoint to create.
task_queue: The task queue to associate with the endpoint.

Returns:
The created Nexus endpoint.
"""
response = await self._client.operator_service.create_nexus_endpoint(
temporalio.api.operatorservice.v1.CreateNexusEndpointRequest(
spec=temporalio.api.nexus.v1.EndpointSpec(
name=endpoint_name,
target=temporalio.api.nexus.v1.EndpointTarget(
worker=temporalio.api.nexus.v1.EndpointTarget.Worker(
namespace=self._client.namespace,
task_queue=task_queue,
)
),
)
)
)
return response.endpoint

async def delete_nexus_endpoint(
self, endpoint: temporalio.api.nexus.v1.Endpoint
) -> None:
"""Delete a Nexus endpoint.

Args:
endpoint: The Nexus endpoint to delete.
"""
await self._client.operator_service.delete_nexus_endpoint(
temporalio.api.operatorservice.v1.DeleteNexusEndpointRequest(
id=endpoint.id,
version=endpoint.version,
)
)

@contextmanager
def auto_time_skipping_disabled(self) -> Iterator[None]:
"""Disable any automatic time skipping if this is a time-skipping
Expand Down
37 changes: 25 additions & 12 deletions tests/nexus/test_workflow_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from temporalio.worker import Worker
from tests.helpers import find_free_port, new_worker
from tests.helpers.metrics import PromMetricMatcher
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
from tests.helpers.nexus import make_nexus_endpoint_name

# TODO(nexus-prerelease): test availability of Temporal client etc in async context set by worker
# TODO(nexus-preview): test worker shutdown, wait_all_completed, drain etc
Expand Down Expand Up @@ -434,7 +434,8 @@ async def test_sync_operation_happy_path(client: Client, env: WorkflowEnvironmen
task_queue=task_queue,
workflow_failure_exception_types=[Exception],
):
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
wf_output = await client.execute_workflow(
CallerWorkflow.run,
args=[
Expand Down Expand Up @@ -471,7 +472,8 @@ async def test_workflow_run_operation_happy_path(
task_queue=task_queue,
workflow_failure_exception_types=[Exception],
):
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
wf_output = await client.execute_workflow(
CallerWorkflow.run,
args=[
Expand Down Expand Up @@ -527,7 +529,8 @@ async def test_sync_response(
task_queue=task_queue,
workflow_failure_exception_types=[Exception],
):
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
caller_wf_handle = await client.start_workflow(
CallerWorkflow.run,
args=[
Expand Down Expand Up @@ -603,6 +606,7 @@ async def test_async_response(
workflow_failure_exception_types=[Exception],
):
caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op(
env,
client,
task_queue,
exception_in_operation_start,
Expand Down Expand Up @@ -676,6 +680,7 @@ async def test_async_response(


async def _start_wf_and_nexus_op(
env: WorkflowEnvironment,
client: Client,
task_queue: str,
exception_in_operation_start: bool,
Expand All @@ -689,7 +694,8 @@ async def _start_wf_and_nexus_op(
"""
Start the caller workflow and wait until the Nexus operation has started.
"""
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
operation_workflow_id = str(uuid.uuid4())

# Start the caller workflow and wait until it confirms the Nexus operation has started.
Expand Down Expand Up @@ -774,7 +780,8 @@ async def test_untyped_caller(
op_definition_type=op_definition_type,
exception_in_operation_start=exception_in_operation_start,
)
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
caller_wf_handle = await client.start_workflow(
UntypedCallerWorkflow.run,
args=[
Expand Down Expand Up @@ -941,7 +948,8 @@ async def test_service_interface_and_implementation_names(
task_queue=task_queue,
workflow_failure_exception_types=[Exception],
):
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
assert await client.execute_workflow(
ServiceInterfaceAndImplCallerWorkflow.run,
args=(CallerReference.INTERFACE, NameOverride.YES, task_queue),
Expand Down Expand Up @@ -1057,7 +1065,8 @@ async def test_workflow_run_operation_can_execute_workflow_before_starting_backi
],
task_queue=task_queue,
):
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
result = await client.execute_workflow(
WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow.run,
args=("result-1", task_queue),
Expand Down Expand Up @@ -1109,7 +1118,8 @@ async def test_nexus_operation_summary(
],
task_queue=task_queue,
):
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
wf_id = f"wf-{uuid.uuid4()}"
handle = await client.start_workflow(
ExecuteNexusOperationWithSummaryWorkflow.run,
Expand Down Expand Up @@ -1405,7 +1415,8 @@ async def test_workflow_run_operation_overloads(
],
nexus_service_handlers=[OverloadTestServiceHandler()],
):
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
res = await client.execute_workflow(
OverloadTestCallerWorkflow.run,
args=[op, OverloadTestValue(value=2)],
Expand Down Expand Up @@ -1465,7 +1476,8 @@ async def test_workflow_caller_custom_metrics(client: Client, env: WorkflowEnvir
pytest.skip("Nexus tests don't work with time-skipping server")

task_queue = str(uuid.uuid4())
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)

# Create new runtime with Prom server
prom_addr = f"127.0.0.1:{find_free_port()}"
Expand Down Expand Up @@ -1558,7 +1570,8 @@ async def test_workflow_caller_buffered_metrics(
runtime=runtime,
)
task_queue = str(uuid.uuid4())
await create_nexus_endpoint(task_queue, client)
endpoint_name = make_nexus_endpoint_name(task_queue)
await env.create_nexus_endpoint(endpoint_name, task_queue)
async with new_worker(
client,
CustomMetricsWorkflow,
Expand Down
Loading