From 307f478ad32acb3a5661b699c333d2a2733dba0f Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 12:27:20 -0400 Subject: [PATCH 1/4] feat: Add backpressure hook for server-side admission control Add BackpressureLike protocol and BackpressureContext dataclass that let users configure a per-worker hook to reject incoming tasks. The hook receives a snapshot of active task count and the incoming task, returning True to reject (triggering gRPC RESOURCE_EXHAUSTED) or False to accept. The load balancer already treats RESOURCE_EXHAUSTED as transient and skips to the next worker. The hook is threaded through LocalWorker -> WorkerProcess -> WorkerService, serialized via cloudpickle across the spawn boundary to support lambdas and closures. --- wool/src/wool/__init__.py | 5 ++ wool/src/wool/runtime/worker/local.py | 6 ++ wool/src/wool/runtime/worker/process.py | 13 +++- wool/src/wool/runtime/worker/service.py | 80 ++++++++++++++++++++++++- 4 files changed, 101 insertions(+), 3 deletions(-) diff --git a/wool/src/wool/__init__.py b/wool/src/wool/__init__.py index 859b6e9..3540580 100644 --- a/wool/src/wool/__init__.py +++ b/wool/src/wool/__init__.py @@ -38,6 +38,8 @@ from wool.runtime.worker.metadata import WorkerMetadata from wool.runtime.worker.pool import WorkerPool from wool.runtime.worker.proxy import WorkerProxy +from wool.runtime.worker.service import BackpressureContext +from wool.runtime.worker.service import BackpressureLike from wool.runtime.worker.service import WorkerService pickling_support.install() @@ -80,6 +82,9 @@ "TaskException", "current_task", "routine", + # Backpressure + "BackpressureContext", + "BackpressureLike", # Workers "LocalWorker", "Worker", diff --git a/wool/src/wool/runtime/worker/local.py b/wool/src/wool/runtime/worker/local.py index 0f6a61f..e3a000e 100644 --- a/wool/src/wool/runtime/worker/local.py +++ b/wool/src/wool/runtime/worker/local.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from typing import TYPE_CHECKING from typing import Any import grpc.aio @@ -11,6 +12,9 @@ from wool.runtime.worker.base import WorkerOptions from wool.runtime.worker.process import WorkerProcess +if TYPE_CHECKING: + from wool.runtime.worker.service import BackpressureLike + # public class LocalWorker(Worker): @@ -76,6 +80,7 @@ def __init__( proxy_pool_ttl: float = 60.0, credentials: WorkerCredentials | None = None, options: WorkerOptions | None = None, + backpressure: BackpressureLike | None = None, **extra: Any, ): super().__init__(*tags, **extra) @@ -90,6 +95,7 @@ def __init__( options=options, tags=frozenset(self._tags), extra=self._extra, + backpressure=backpressure, ) @property diff --git a/wool/src/wool/runtime/worker/process.py b/wool/src/wool/runtime/worker/process.py index 4f12cb8..6675030 100644 --- a/wool/src/wool/runtime/worker/process.py +++ b/wool/src/wool/runtime/worker/process.py @@ -18,6 +18,7 @@ from typing import Any from typing import Final +import cloudpickle import grpc.aio import wool @@ -32,6 +33,7 @@ if TYPE_CHECKING: from wool.runtime.worker.proxy import WorkerProxy + from wool.runtime.worker.service import BackpressureLike _ctx = _mp.get_context("spawn") Pipe = _ctx.Pipe @@ -91,6 +93,7 @@ def __init__( options: WorkerOptions | None = None, tags: frozenset[str] = frozenset(), extra: dict[str, Any] | None = None, + backpressure: BackpressureLike | None = None, **kwargs, ): super().__init__(*args, **kwargs) @@ -112,6 +115,9 @@ def __init__( self._tags = tags self._extra = extra if extra is not None else {} self._metadata = None + self._backpressure = ( + cloudpickle.dumps(backpressure) if backpressure is not None else None + ) self._get_metadata, self._set_metadata = Pipe(duplex=False) @property @@ -294,7 +300,12 @@ async def _serve(self): server.add_insecure_port(uds_target) uds_address = uds_target - service = WorkerService() + backpressure = ( + cloudpickle.loads(self._backpressure) + if self._backpressure is not None + else None + ) + service = WorkerService(backpressure=backpressure) protocol.add_to_server[protocol.WorkerServicer](service, server) with _signal_handlers(service): diff --git a/wool/src/wool/runtime/worker/service.py b/wool/src/wool/runtime/worker/service.py index a866d89..0b429f2 100644 --- a/wool/src/wool/runtime/worker/service.py +++ b/wool/src/wool/runtime/worker/service.py @@ -5,11 +5,16 @@ import contextvars import threading from contextlib import contextmanager +from dataclasses import dataclass from inspect import isasyncgen from inspect import isasyncgenfunction +from inspect import isawaitable from inspect import iscoroutinefunction from typing import AsyncGenerator from typing import AsyncIterator +from typing import Awaitable +from typing import Protocol +from typing import runtime_checkable import cloudpickle from grpc import StatusCode @@ -25,6 +30,59 @@ _SENTINEL = object() +# public +@dataclass(frozen=True) +class BackpressureContext: + """Snapshot of worker state provided to backpressure hooks. + + :param active_task_count: + Number of tasks currently executing on this worker. + :param task: + The incoming :class:`~wool.runtime.routine.task.Task` being + evaluated for admission. + """ + + active_task_count: int + task: Task + + +# public +@runtime_checkable +class BackpressureLike(Protocol): + """Protocol for backpressure hooks. + + A backpressure hook determines whether an incoming task should be + rejected. Return ``True`` to **reject** the task (apply + backpressure) or ``False`` to **accept** it. + + When a task is rejected the worker responds with gRPC + ``RESOURCE_EXHAUSTED``, which the load balancer treats as + transient and skips to the next worker. + + Pass ``None`` (the default) to accept all tasks unconditionally. + + Both sync and async implementations are supported:: + + def sync_hook(ctx: BackpressureContext) -> bool: + return ctx.active_task_count >= 4 + + + async def async_hook(ctx: BackpressureContext) -> bool: + return ctx.active_task_count >= 4 + """ + + def __call__(self, ctx: BackpressureContext) -> bool | Awaitable[bool]: + """Evaluate whether to reject the incoming task. + + :param ctx: + Snapshot of the worker's current state and the incoming + task. + :returns: + ``True`` to reject the task, ``False`` to accept it. + """ + ... + + class _Task: def __init__(self, task: asyncio.Task): self._work = task @@ -86,11 +144,12 @@ class WorkerService(protocol.WorkerServicer): _task_completed: asyncio.Event _loop_pool: ResourcePool[tuple[asyncio.AbstractEventLoop, threading.Thread]] - def __init__(self): + def __init__(self, *, backpressure: BackpressureLike | None = None): self._stopped = asyncio.Event() self._stopping = asyncio.Event() self._task_completed = asyncio.Event() self._docket = set() + self._backpressure = backpressure self._loop_pool = ResourcePool( factory=self._create_worker_loop, finalizer=self._destroy_worker_loop, @@ -143,7 +202,24 @@ async def dispatch( ) response = await anext(aiter(request_iterator)) - with self._tracker(Task.from_protobuf(response.task), request_iterator) as task: + work_task = Task.from_protobuf(response.task) + + if self._backpressure is not None: + decision = self._backpressure( + BackpressureContext( + active_task_count=len(self._docket), + task=work_task, + ) + ) + if isawaitable(decision): + decision = await decision + if decision: + await context.abort( + StatusCode.RESOURCE_EXHAUSTED, + "Task rejected by backpressure hook", + ) + + with self._tracker(work_task, request_iterator) as task: ack = protocol.Ack(version=protocol.__version__) yield protocol.Response(ack=ack) try: From 3b0a441db69bd4bb6d11c4dd3c87ca9e55716fd2 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 12:27:34 -0400 Subject: [PATCH 2/4] test: Add backpressure hook tests Cover BackpressureContext (instantiation, frozen immutability, equality, Hypothesis property test), BackpressureLike protocol (sync, async, callable class, negative case), WorkerService dispatch with backpressure (sync/async accept and reject, context field validation, active task counting), and LocalWorker construction with the hook. Update public API completeness test for the new exports. --- .../tests/runtime/worker/test_backpressure.py | 194 ++++++++++ wool/tests/runtime/worker/test_local.py | 23 ++ wool/tests/runtime/worker/test_service.py | 332 ++++++++++++++++++ wool/tests/test_public.py | 2 + 4 files changed, 551 insertions(+) create mode 100644 wool/tests/runtime/worker/test_backpressure.py diff --git a/wool/tests/runtime/worker/test_backpressure.py b/wool/tests/runtime/worker/test_backpressure.py new file mode 100644 index 0000000..6d19f94 --- /dev/null +++ b/wool/tests/runtime/worker/test_backpressure.py @@ -0,0 +1,194 @@ +from uuid import uuid4 + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from wool.runtime.routine.task import Task +from wool.runtime.routine.task import WorkerProxyLike +from wool.runtime.worker.service import BackpressureContext +from wool.runtime.worker.service import BackpressureLike + +from .conftest import PicklableMock + + +def _make_task(): + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + return Task( + id=uuid4(), + callable=lambda: None, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + + +class TestBackpressureContext: + """Tests for :class:`BackpressureContext` dataclass.""" + + def test___init___with_valid_fields(self): + """Test BackpressureContext instantiation. + + Given: + An active task count and a Task instance + When: + BackpressureContext is instantiated + Then: + It should store both fields correctly + """ + # Arrange + task = _make_task() + + # Act + ctx = BackpressureContext(active_task_count=3, task=task) + + # Assert + assert ctx.active_task_count == 3 + assert ctx.task is task + + def test___init___is_frozen(self): + """Test BackpressureContext is immutable. + + Given: + A BackpressureContext instance + When: + An attribute is reassigned + Then: + It should raise FrozenInstanceError + """ + # Arrange + ctx = BackpressureContext(active_task_count=0, task=_make_task()) + + # Act & Assert + with pytest.raises(AttributeError): + ctx.active_task_count = 5 + + def test___eq___with_equal_fields(self): + """Test BackpressureContext equality with identical fields. + + Given: + Two BackpressureContext instances with the same active_task_count and task + When: + Compared with == + Then: + It should return True + """ + # Arrange + task = _make_task() + ctx_a = BackpressureContext(active_task_count=2, task=task) + ctx_b = BackpressureContext(active_task_count=2, task=task) + + # Act & Assert + assert ctx_a == ctx_b + + def test___eq___with_different_task_count(self): + """Test BackpressureContext inequality with different active_task_count. + + Given: + Two BackpressureContext instances with different active_task_count + When: + Compared with == + Then: + It should return False + """ + # Arrange + task = _make_task() + ctx_a = BackpressureContext(active_task_count=0, task=task) + ctx_b = BackpressureContext(active_task_count=1, task=task) + + # Act & Assert + assert ctx_a != ctx_b + + @given(count=st.integers(min_value=0, max_value=10_000)) + @settings(max_examples=50) + def test___init___with_arbitrary_task_count(self, count): + """Test BackpressureContext stores arbitrary non-negative active_task_count. + + Given: + Any non-negative integer for active_task_count + When: + BackpressureContext is instantiated + Then: + It should store the value exactly + """ + # Arrange + task = _make_task() + + # Act + ctx = BackpressureContext(active_task_count=count, task=task) + + # Assert + assert ctx.active_task_count == count + + +class TestBackpressureLike: + """Tests for :class:`BackpressureLike` protocol.""" + + def test_sync_callable_satisfies_protocol(self): + """Test sync callable satisfies BackpressureLike. + + Given: + A sync function with the correct signature + When: + Checked against BackpressureLike + Then: + It should be recognized as an instance + """ + + # Arrange + def hook(ctx: BackpressureContext) -> bool: + return True + + # Act & Assert + assert isinstance(hook, BackpressureLike) + + def test_async_callable_satisfies_protocol(self): + """Test async callable satisfies BackpressureLike. + + Given: + An async function with the correct signature + When: + Checked against BackpressureLike + Then: + It should be recognized as an instance + """ + + # Arrange + async def hook(ctx: BackpressureContext) -> bool: + return True + + # Act & Assert + assert isinstance(hook, BackpressureLike) + + def test_callable_class_satisfies_protocol(self): + """Test callable class instance satisfies BackpressureLike. + + Given: + A class with a __call__ method + When: + An instance is checked against BackpressureLike + Then: + It should be recognized as an instance + """ + + # Arrange + class Hook: + def __call__(self, ctx: BackpressureContext) -> bool: + return ctx.active_task_count >= 4 + + # Act & Assert + assert isinstance(Hook(), BackpressureLike) + + def test_non_callable_does_not_satisfy_protocol(self): + """Test non-callable does not satisfy BackpressureLike. + + Given: + A plain string (not callable) + When: + Checked against BackpressureLike + Then: + It should not be recognized as an instance + """ + # Act & Assert + assert not isinstance("not-a-hook", BackpressureLike) diff --git a/wool/tests/runtime/worker/test_local.py b/wool/tests/runtime/worker/test_local.py index 222da4d..0db75de 100644 --- a/wool/tests/runtime/worker/test_local.py +++ b/wool/tests/runtime/worker/test_local.py @@ -704,3 +704,26 @@ async def test_worker_lifecycle_with_credentials( # Assert assert worker.metadata is not None + + def test___init___with_backpressure(self, mocker): + """Test LocalWorker initialization with backpressure hook. + + Given: + A callable backpressure hook + When: + LocalWorker is instantiated with backpressure=hook + Then: + It should forward the hook to WorkerProcess + """ + # Arrange + MockWorkerProcess = mocker.patch.object(local_module, "WorkerProcess") + + def hook(ctx): + return ctx.active_task_count >= 4 + + # Act + LocalWorker(backpressure=hook) + + # Assert + MockWorkerProcess.assert_called_once() + assert MockWorkerProcess.call_args.kwargs["backpressure"] is hook diff --git a/wool/tests/runtime/worker/test_service.py b/wool/tests/runtime/worker/test_service.py index 2b65c01..49dd904 100644 --- a/wool/tests/runtime/worker/test_service.py +++ b/wool/tests/runtime/worker/test_service.py @@ -2172,3 +2172,335 @@ async def resilient_generator(): # Assert assert result["do_dispatch"] is False assert result["has_proxy"] is True + + +class TestWorkerServiceBackpressure: + """Tests for :class:`WorkerService` backpressure hook.""" + + def test___init___with_backpressure_hook(self): + """Test WorkerService initialization with a backpressure hook. + + Given: + A callable backpressure hook + When: + WorkerService is instantiated with backpressure=hook + Then: + It should initialize successfully with stopping and stopped events unset + """ + + # Arrange + def hook(ctx): + return False + + # Act + service = WorkerService(backpressure=hook) + + # Assert + assert not service.stopping.is_set() + assert not service.stopped.is_set() + + @pytest.mark.asyncio + async def test_dispatch_with_sync_backpressure_accepting( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test dispatch succeeds when sync backpressure hook returns False. + + Given: + A :class:`WorkerService` with a sync backpressure hook that returns False + When: + Dispatch RPC is called + Then: + It should accept the task and return the result normally + """ + + # Arrange + async def sample_task(): + return "accepted" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + def hook(ctx): + return False + + service = WorkerService(backpressure=hook) + + # Act + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + responses = [r async for r in stream] + + # Assert + ack, response = responses + assert ack.HasField("ack") + assert response.HasField("result") + assert cloudpickle.loads(response.result.dump) == "accepted" + + @pytest.mark.asyncio + async def test_dispatch_with_sync_backpressure_rejecting( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test dispatch aborts when sync backpressure hook returns True. + + Given: + A :class:`WorkerService` with a sync backpressure hook that returns True + When: + Dispatch RPC is called + Then: + It should reject the task with RESOURCE_EXHAUSTED status + """ + + # Arrange + async def sample_task(): + return "should_not_reach" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + def hook(ctx): + return True + + service = WorkerService(backpressure=hook) + + # Act & assert + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + with pytest.raises(grpc.RpcError) as exc_info: + async for _ in stream: + pass + assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED + + @pytest.mark.asyncio + async def test_dispatch_with_async_backpressure_rejecting( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test dispatch aborts when async backpressure hook returns True. + + Given: + A :class:`WorkerService` with an async backpressure hook that returns True + When: + Dispatch RPC is called + Then: + It should reject the task with RESOURCE_EXHAUSTED status + """ + + # Arrange + async def sample_task(): + return "should_not_reach" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + async def async_hook(ctx): + return True + + service = WorkerService(backpressure=async_hook) + + # Act & assert + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + with pytest.raises(grpc.RpcError) as exc_info: + async for _ in stream: + pass + assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED + + @pytest.mark.asyncio + async def test_dispatch_with_async_backpressure_accepting( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test dispatch succeeds when async backpressure hook returns False. + + Given: + A :class:`WorkerService` with an async backpressure hook that returns False + When: + Dispatch RPC is called + Then: + It should accept the task and return the result normally + """ + + # Arrange + async def sample_task(): + return "async_accepted" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + async def async_hook(ctx): + return False + + service = WorkerService(backpressure=async_hook) + + # Act + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + responses = [r async for r in stream] + + # Assert + ack, response = responses + assert ack.HasField("ack") + assert response.HasField("result") + assert cloudpickle.loads(response.result.dump) == "async_accepted" + + @pytest.mark.asyncio + async def test_dispatch_with_backpressure_receiving_context( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test backpressure hook receives correct context. + + Given: + A :class:`WorkerService` with a backpressure hook that captures its argument + When: + Dispatch RPC is called + Then: + It should pass a BackpressureContext with active_task_count and task fields + """ + # Arrange + from wool.runtime.worker.service import BackpressureContext + + async def sample_task(): + return "result" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + captured = [] + + def hook(ctx): + captured.append(ctx) + return False + + service = WorkerService(backpressure=hook) + + # Act + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + [r async for r in stream] + + # Assert + assert len(captured) == 1 + ctx = captured[0] + assert isinstance(ctx, BackpressureContext) + assert ctx.active_task_count == 0 + assert ctx.task.id == wool_task.id + + @pytest.mark.asyncio + async def test_dispatch_with_backpressure_and_active_tasks( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test backpressure hook sees correct active task count. + + Given: + A :class:`WorkerService` with one active task already dispatched + When: + A second dispatch RPC is called with a backpressure hook + Then: + It should see active_task_count == 1 + """ + # Arrange + global _control_event + _control_event = threading.Event() + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + + first_task = Task( + id=uuid4(), + callable=_controllable_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + first_request = protocol.Request(task=first_task.to_protobuf()) + + async def second_fn(): + return "second" + + second_task = Task( + id=uuid4(), + callable=second_fn, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + second_request = protocol.Request(task=second_task.to_protobuf()) + + captured_count = [] + + def hook(ctx): + captured_count.append(ctx.active_task_count) + return False + + service = WorkerService(backpressure=hook) + + # Act + try: + async with grpc_aio_stub(servicer=service) as stub: + # Dispatch first task (blocks on control event) + stream1 = stub.dispatch() + await stream1.write(first_request) + await stream1.done_writing() + # Wait for ack to confirm first task is tracked + async for response in stream1: + assert response.HasField("ack") + break + + # Dispatch second task — hook should see 1 active task + stream2 = stub.dispatch() + await stream2.write(second_request) + await stream2.done_writing() + [r async for r in stream2] + + # Release first task + _control_event.set() + [r async for r in stream1] + finally: + if _control_event and not _control_event.is_set(): + _control_event.set() + _control_event = None + + # Assert — first dispatch sees 0 active, second sees 1 + assert captured_count == [0, 1] diff --git a/wool/tests/test_public.py b/wool/tests/test_public.py index b3c6891..b5bd0b5 100644 --- a/wool/tests/test_public.py +++ b/wool/tests/test_public.py @@ -45,6 +45,8 @@ def test_public_api_completeness(): "TaskException", "current_task", "routine", + "BackpressureContext", + "BackpressureLike", "LocalWorker", "Worker", "WorkerCredentials", From 97cb8da3a27e0d86df3ba841857e34c3c685ce2d Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 15:59:47 -0400 Subject: [PATCH 3/4] docs: Add missing param documentation to backpressure classes LocalWorker, WorkerProcess, and WorkerService were missing docstring entries for parameters introduced by the backpressure feature. WorkerProcess was also missing uid, tags, and extra params that predated this PR. Removes a stale :param tasks: entry from WorkerService._cancel which takes no arguments. --- wool/src/wool/runtime/worker/local.py | 9 +++++++++ wool/src/wool/runtime/worker/process.py | 16 +++++++++++++++- wool/src/wool/runtime/worker/service.py | 23 ++++++++--------------- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/wool/src/wool/runtime/worker/local.py b/wool/src/wool/runtime/worker/local.py index e3a000e..e84e703 100644 --- a/wool/src/wool/runtime/worker/local.py +++ b/wool/src/wool/runtime/worker/local.py @@ -64,6 +64,15 @@ class LocalWorker(Worker): :param options: gRPC message size options. Defaults to :class:`WorkerOptions` with 100 MB limits. + :param backpressure: + Optional admission control hook. A callable receiving a + :class:`~wool.runtime.worker.service.BackpressureContext` + and returning ``True`` to **reject** the task or ``False`` + to **accept** it. Both sync and async callables are + supported. When a task is rejected the worker responds with + gRPC ``RESOURCE_EXHAUSTED``, causing the load balancer to + skip to the next worker. ``None`` (default) accepts all + tasks unconditionally. :param extra: Additional metadata as key-value pairs. """ diff --git a/wool/src/wool/runtime/worker/process.py b/wool/src/wool/runtime/worker/process.py index 6675030..3ff5a8e 100644 --- a/wool/src/wool/runtime/worker/process.py +++ b/wool/src/wool/runtime/worker/process.py @@ -66,6 +66,18 @@ class WorkerProcess(Process): :param options: gRPC message size options. Defaults to :class:`WorkerOptions` with 100 MB limits. + :param uid: + Unique identifier for this worker. Auto-generated if not + provided. + :param tags: + Capability tags for filtering and selection. + :param extra: + Additional metadata as key-value pairs. + :param backpressure: + Optional admission control hook. See + :class:`~wool.runtime.worker.service.BackpressureLike`. + Serialized with ``cloudpickle`` for transfer to the + subprocess. :param args: Additional args for :class:`multiprocessing.Process`. :param kwargs: @@ -347,8 +359,10 @@ async def _serve(self): os.unlink(uds_path) def _address(self, host, port) -> str: - """Format network address for the given port. + """Format network address for the given host and port. + :param host: + Host address to include in the address. :param port: Port number to include in the address. :returns: diff --git a/wool/src/wool/runtime/worker/service.py b/wool/src/wool/runtime/worker/service.py index 0b429f2..e76743a 100644 --- a/wool/src/wool/runtime/worker/service.py +++ b/wool/src/wool/runtime/worker/service.py @@ -136,6 +136,11 @@ class WorkerService(protocol.WorkerServicer): Handles graceful shutdown by rejecting new tasks while allowing in-flight tasks to complete. Exposes :attr:`stopping` and :attr:`stopped` events for lifecycle monitoring. + + :param backpressure: + Optional admission control hook. See + :class:`BackpressureLike`. ``None`` (default) accepts all + tasks unconditionally. """ _docket: set[_Task | _AsyncGen] @@ -547,21 +552,9 @@ async def _await(self): await asyncio.sleep(0) async def _cancel(self): - """Cancel multiple tasks safely. - - Cancels the provided tasks while performing safety checks to - avoid canceling the current task or already completed tasks. - Waits for all cancelled tasks to complete in parallel and handles - cancellation exceptions. + """Cancel all tracked tasks in the docket. - :param tasks: - The :class:`asyncio.Task` instances to cancel. - - .. note:: - This method performs the following safety checks: - - Avoids canceling the current task (would cause deadlock) - - Only cancels tasks that are not already done - - Properly handles :exc:`asyncio.CancelledError` - exceptions. + Cancels every entry in :attr:`_docket` and waits for them to + finish, handling cancellation exceptions gracefully. """ await asyncio.gather(*(w.cancel() for w in self._docket), return_exceptions=True) From 9198542848dd663bc10d70eeba1f14888d9894a6 Mon Sep 17 00:00:00 2001 From: Conrad Date: Wed, 1 Apr 2026 16:34:15 -0400 Subject: [PATCH 4/4] test: Add BackpressureMode as 10th integration dimension Extends the pairwise covering array and Hypothesis strategy with a BackpressureMode dimension (NONE, SYNC, ASYNC) to verify that backpressure hooks survive cloudpickle serialization through real subprocesses. Adds rejection composition tests validating the end-to-end RESOURCE_EXHAUSTED path and load balancer fallback to an accepting worker. --- wool/tests/integration/conftest.py | 66 +++++- wool/tests/integration/test_integration.py | 4 + .../integration/test_pool_composition.py | 201 ++++++++++++++++++ wool/tests/integration/test_scenario.py | 6 +- 4 files changed, 264 insertions(+), 13 deletions(-) diff --git a/wool/tests/integration/conftest.py b/wool/tests/integration/conftest.py index 05e3592..321b5aa 100644 --- a/wool/tests/integration/conftest.py +++ b/wool/tests/integration/conftest.py @@ -105,11 +105,27 @@ class RoutineBinding(Enum): STATICMETHOD = auto() +class BackpressureMode(Enum): + NONE = auto() + SYNC = auto() + ASYNC = auto() + + class LazyMode(Enum): LAZY = auto() EAGER = auto() +def _sync_accept_hook(ctx): + """Sync backpressure hook that accepts all tasks.""" + return False + + +async def _async_accept_hook(ctx): + """Async backpressure hook that accepts all tasks.""" + return False + + @dataclass(frozen=True) class Scenario: """Composable scenario describing one integration test configuration. @@ -127,6 +143,7 @@ class Scenario: timeout: TimeoutKind | None = None binding: RoutineBinding | None = None lazy: LazyMode | None = None + backpressure: BackpressureMode | None = None def __or__(self, other: Scenario) -> Scenario: """Merge two partial scenarios. Right side wins on ``None`` fields. @@ -147,7 +164,7 @@ def __or__(self, other: Scenario) -> Scenario: @property def is_complete(self) -> bool: - """True when all 9 dimensions are set.""" + """True when all 10 dimensions are set.""" return all(getattr(self, f.name) is not None for f in fields(self)) def __str__(self) -> str: @@ -285,29 +302,46 @@ async def _lan_async_cm(): lazy = scenario.lazy is LazyMode.LAZY + match scenario.backpressure: + case BackpressureMode.SYNC: + bp_hook = _sync_accept_hook + case BackpressureMode.ASYNC: + bp_hook = _async_accept_hook + case _: + bp_hook = None + try: if runtime_ctx is not None: runtime_ctx.__enter__() try: if scenario.pool_mode is PoolMode.DURABLE: - async with _durable_pool_context(lb, creds, options, lazy) as pool: + async with _durable_pool_context( + lb, creds, options, lazy, backpressure=bp_hook + ) as pool: yield pool elif scenario.pool_mode is PoolMode.DURABLE_SHARED: async with _durable_shared_pool_context( - lb, creds, options, lazy + lb, creds, options, lazy, backpressure=bp_hook ) as pool: yield pool elif scenario.pool_mode is PoolMode.DURABLE_JOINED: async with _durable_joined_pool_context( - scenario.discovery, lb, creds, options, lazy + scenario.discovery, + lb, + creds, + options, + lazy, + backpressure=bp_hook, ) as pool: yield pool else: pool_kwargs = { "loadbalancer": lb, "credentials": creds, - "worker": partial(LocalWorker, options=options), + "worker": partial( + LocalWorker, options=options, backpressure=bp_hook + ), "lazy": lazy, } match scenario.pool_mode: @@ -356,7 +390,7 @@ async def _lan_async_cm(): @asynccontextmanager -async def _durable_pool_context(lb, creds, options, lazy): +async def _durable_pool_context(lb, creds, options, lazy, *, backpressure=None): """Manually start a worker, register it, then create a DURABLE pool. DURABLE pools don't spawn workers — they only discover external @@ -368,7 +402,9 @@ async def _durable_pool_context(lb, creds, options, lazy): """ namespace = f"durable-{uuid.uuid4().hex[:12]}" with LocalDiscovery(namespace) as discovery: - worker = LocalWorker(credentials=creds, options=options) + worker = LocalWorker( + credentials=creds, options=options, backpressure=backpressure + ) await worker.start() try: publisher = discovery.publisher @@ -390,7 +426,7 @@ async def _durable_pool_context(lb, creds, options, lazy): @asynccontextmanager -async def _durable_shared_pool_context(lb, creds, options, lazy): +async def _durable_shared_pool_context(lb, creds, options, lazy, *, backpressure=None): """Create two pools sharing the same LocalDiscovery subscriber. Exercises ``SubscriberMeta`` singleton caching and @@ -401,7 +437,9 @@ async def _durable_shared_pool_context(lb, creds, options, lazy): """ namespace = f"shared-{uuid.uuid4().hex[:12]}" with LocalDiscovery(namespace) as discovery: - worker = LocalWorker(credentials=creds, options=options) + worker = LocalWorker( + credentials=creds, options=options, backpressure=backpressure + ) await worker.start() try: publisher = discovery.publisher @@ -469,7 +507,9 @@ async def _acm(): @asynccontextmanager -async def _durable_joined_pool_context(discovery_factory, lb, creds, options, lazy): +async def _durable_joined_pool_context( + discovery_factory, lb, creds, options, lazy, *, backpressure=None +): """Create a DURABLE pool that joins an externally owned namespace. Sets up an owner ``LocalDiscovery`` that creates workers and publishes @@ -479,7 +519,7 @@ async def _durable_joined_pool_context(discovery_factory, lb, creds, options, la """ namespace = f"joined-{uuid.uuid4().hex[:12]}" - worker = LocalWorker(credentials=creds, options=options) + worker = LocalWorker(credentials=creds, options=options, backpressure=backpressure) await worker.start() try: owner = LocalDiscovery(namespace) @@ -720,6 +760,7 @@ def _pairwise_filter(row): timeout=row[6], binding=row[7], lazy=row[8], + backpressure=row[9], ) for row in AllPairs( [ @@ -732,6 +773,7 @@ def _pairwise_filter(row): list(TimeoutKind), list(RoutineBinding), list(LazyMode), + list(BackpressureMode), ], filter_func=_pairwise_filter, ) @@ -792,6 +834,7 @@ def scenarios_strategy(draw): binding = draw(st.sampled_from(RoutineBinding)) lazy = draw(st.sampled_from(LazyMode)) + backpressure = draw(st.sampled_from(BackpressureMode)) return Scenario( shape=shape, @@ -803,6 +846,7 @@ def scenarios_strategy(draw): timeout=timeout, binding=binding, lazy=lazy, + backpressure=backpressure, ) diff --git a/wool/tests/integration/test_integration.py b/wool/tests/integration/test_integration.py index 70d0852..ccd2dfb 100644 --- a/wool/tests/integration/test_integration.py +++ b/wool/tests/integration/test_integration.py @@ -9,6 +9,7 @@ from hypothesis import settings from .conftest import PAIRWISE_SCENARIOS +from .conftest import BackpressureMode from .conftest import CredentialType from .conftest import DiscoveryFactory from .conftest import LazyMode @@ -70,6 +71,7 @@ async def body(): TimeoutKind.NONE, RoutineBinding.MODULE_FUNCTION, LazyMode.LAZY, + BackpressureMode.NONE, ) ) @example( @@ -83,6 +85,7 @@ async def body(): TimeoutKind.NONE, RoutineBinding.MODULE_FUNCTION, LazyMode.LAZY, + BackpressureMode.SYNC, ) ) @example( @@ -96,6 +99,7 @@ async def body(): TimeoutKind.NONE, RoutineBinding.MODULE_FUNCTION, LazyMode.LAZY, + BackpressureMode.ASYNC, ) ) @given(scenario=scenarios_strategy()) diff --git a/wool/tests/integration/test_pool_composition.py b/wool/tests/integration/test_pool_composition.py index 55bf2d5..57cafc4 100644 --- a/wool/tests/integration/test_pool_composition.py +++ b/wool/tests/integration/test_pool_composition.py @@ -1,7 +1,18 @@ """Tests for pool composition via build_pool_from_scenario.""" +import uuid +from functools import partial + import pytest +from wool.runtime.discovery.local import LocalDiscovery +from wool.runtime.loadbalancer.base import NoWorkersAvailable +from wool.runtime.loadbalancer.roundrobin import RoundRobinLoadBalancer +from wool.runtime.worker.local import LocalWorker +from wool.runtime.worker.pool import WorkerPool + +from . import routines +from .conftest import BackpressureMode from .conftest import CredentialType from .conftest import DiscoveryFactory from .conftest import LazyMode @@ -12,6 +23,7 @@ from .conftest import Scenario from .conftest import TimeoutKind from .conftest import WorkerOptionsKind +from .conftest import _DirectDiscovery from .conftest import build_pool_from_scenario from .conftest import invoke_routine @@ -41,6 +53,7 @@ async def test_build_pool_from_scenario_with_default_mode(self, credentials_map) timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -73,6 +86,7 @@ async def test_build_pool_from_scenario_with_ephemeral_mode(self, credentials_ma timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.EAGER, + backpressure=BackpressureMode.NONE, ) # Act @@ -105,6 +119,7 @@ async def test_build_pool_from_scenario_with_durable_mode(self, credentials_map) timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -137,6 +152,7 @@ async def test_build_pool_from_scenario_with_hybrid_mode(self, credentials_map): timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -173,6 +189,7 @@ async def test_build_pool_from_scenario_with_durable_joined_local( timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -205,6 +222,7 @@ async def test_build_pool_from_scenario_with_restrictive_opts(self, credentials_ timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -238,6 +256,7 @@ async def test_build_pool_from_scenario_with_keepalive_opts(self, credentials_ma timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -270,6 +289,7 @@ async def test_build_pool_from_scenario_with_dispatch_timeout(self, credentials_ timeout=TimeoutKind.VIA_RUNTIME_CONTEXT, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -305,6 +325,79 @@ async def test_build_pool_from_scenario_with_shared_discovery(self, credentials_ timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, + ) + + # Act + async with build_pool_from_scenario(scenario, credentials_map): + result = await invoke_routine(scenario) + + # Assert + assert result == 3 + + @pytest.mark.asyncio + async def test_build_pool_from_scenario_with_sync_backpressure( + self, credentials_map + ): + """Test building a pool with a sync backpressure accept hook. + + Given: + A complete scenario using a SYNC backpressure hook that + accepts all tasks (survives cloudpickle serialization to + the subprocess). + When: + A pool is built and a coroutine routine is dispatched. + Then: + It should return the correct result. + """ + # Arrange + scenario = Scenario( + shape=RoutineShape.COROUTINE, + pool_mode=PoolMode.DEFAULT, + discovery=DiscoveryFactory.NONE, + lb=LbFactory.CLASS_REF, + credential=CredentialType.INSECURE, + options=WorkerOptionsKind.DEFAULT, + timeout=TimeoutKind.NONE, + binding=RoutineBinding.MODULE_FUNCTION, + lazy=LazyMode.LAZY, + backpressure=BackpressureMode.SYNC, + ) + + # Act + async with build_pool_from_scenario(scenario, credentials_map): + result = await invoke_routine(scenario) + + # Assert + assert result == 3 + + @pytest.mark.asyncio + async def test_build_pool_from_scenario_with_async_backpressure( + self, credentials_map + ): + """Test building a pool with an async backpressure accept hook. + + Given: + A complete scenario using an ASYNC backpressure hook that + accepts all tasks (async hook survives cloudpickle + serialization to the subprocess). + When: + A pool is built and a coroutine routine is dispatched. + Then: + It should return the correct result. + """ + # Arrange + scenario = Scenario( + shape=RoutineShape.COROUTINE, + pool_mode=PoolMode.DEFAULT, + discovery=DiscoveryFactory.NONE, + lb=LbFactory.CLASS_REF, + credential=CredentialType.INSECURE, + options=WorkerOptionsKind.DEFAULT, + timeout=TimeoutKind.NONE, + binding=RoutineBinding.MODULE_FUNCTION, + lazy=LazyMode.LAZY, + backpressure=BackpressureMode.ASYNC, ) # Act @@ -313,3 +406,111 @@ async def test_build_pool_from_scenario_with_shared_discovery(self, credentials_ # Assert assert result == 3 + + +def _sync_reject_hook(ctx): + """Sync backpressure hook that rejects all tasks.""" + return True + + +async def _async_reject_hook(ctx): + """Async backpressure hook that rejects all tasks.""" + return True + + +@pytest.mark.integration +class TestBackpressureRejection: + @pytest.mark.asyncio + async def test_sync_backpressure_rejection(self): + """Test sync backpressure hook rejects task end-to-end. + + Given: + A single-worker pool with a sync backpressure hook that + rejects all tasks. + When: + A coroutine routine is dispatched. + Then: + It should raise NoWorkersAvailable because the only worker + rejects with RESOURCE_EXHAUSTED. + """ + # Arrange + pool = WorkerPool( + size=1, + loadbalancer=RoundRobinLoadBalancer, + worker=partial(LocalWorker, backpressure=_sync_reject_hook), + ) + + # Act & assert + async with pool: + with pytest.raises(NoWorkersAvailable): + await routines.add(1, 2) + + @pytest.mark.asyncio + async def test_async_backpressure_rejection(self): + """Test async backpressure hook rejects task end-to-end. + + Given: + A single-worker pool with an async backpressure hook that + rejects all tasks. + When: + A coroutine routine is dispatched. + Then: + It should raise NoWorkersAvailable because the only worker + rejects with RESOURCE_EXHAUSTED. + """ + # Arrange + pool = WorkerPool( + size=1, + loadbalancer=RoundRobinLoadBalancer, + worker=partial(LocalWorker, backpressure=_async_reject_hook), + ) + + # Act & assert + async with pool: + with pytest.raises(NoWorkersAvailable): + await routines.add(1, 2) + + @pytest.mark.asyncio + async def test_backpressure_fallback_to_accepting_worker(self): + """Test load balancer falls through to an accepting worker. + + Given: + A durable pool with two workers: one rejecting all tasks + via backpressure and one accepting all tasks. + When: + A coroutine routine is dispatched. + Then: + It should succeed by falling through to the accepting + worker after the rejecting worker returns + RESOURCE_EXHAUSTED. + """ + # Arrange + namespace = f"bp-fallback-{uuid.uuid4().hex[:12]}" + rejecting_worker = LocalWorker(backpressure=_sync_reject_hook) + accepting_worker = LocalWorker() + + await rejecting_worker.start() + await accepting_worker.start() + try: + with LocalDiscovery(namespace) as discovery: + publisher = discovery.publisher + async with publisher: + await publisher.publish("worker-added", rejecting_worker.metadata) + await publisher.publish("worker-added", accepting_worker.metadata) + pool = WorkerPool( + discovery=_DirectDiscovery(discovery), + loadbalancer=RoundRobinLoadBalancer, + ) + + # Act + async with pool: + result = await routines.add(1, 2) + + # Assert + assert result == 3 + + await publisher.publish("worker-dropped", rejecting_worker.metadata) + await publisher.publish("worker-dropped", accepting_worker.metadata) + finally: + await accepting_worker.stop() + await rejecting_worker.stop() diff --git a/wool/tests/integration/test_scenario.py b/wool/tests/integration/test_scenario.py index 81159b2..162a6e0 100644 --- a/wool/tests/integration/test_scenario.py +++ b/wool/tests/integration/test_scenario.py @@ -2,6 +2,7 @@ import pytest +from .conftest import BackpressureMode from .conftest import CredentialType from .conftest import DiscoveryFactory from .conftest import LazyMode @@ -103,7 +104,7 @@ def test_is_complete_with_all_fields(self): """Test that a fully populated scenario reports complete. Given: - A scenario with all 9 dimensions set. + A scenario with all 10 dimensions set. When: ``is_complete`` is checked. Then: @@ -120,6 +121,7 @@ def test_is_complete_with_all_fields(self): timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act & assert @@ -165,4 +167,4 @@ def test___str___with_partial_fields(self): result = str(scenario) # Assert - assert result == "COROUTINE-DEFAULT-_-_-_-_-_-_-_" + assert result == "COROUTINE-DEFAULT-_-_-_-_-_-_-_-_"