diff --git a/wool/src/wool/__init__.py b/wool/src/wool/__init__.py index 859b6e9..3540580 100644 --- a/wool/src/wool/__init__.py +++ b/wool/src/wool/__init__.py @@ -38,6 +38,8 @@ from wool.runtime.worker.metadata import WorkerMetadata from wool.runtime.worker.pool import WorkerPool from wool.runtime.worker.proxy import WorkerProxy +from wool.runtime.worker.service import BackpressureContext +from wool.runtime.worker.service import BackpressureLike from wool.runtime.worker.service import WorkerService pickling_support.install() @@ -80,6 +82,9 @@ "TaskException", "current_task", "routine", + # Backpressure + "BackpressureContext", + "BackpressureLike", # Workers "LocalWorker", "Worker", diff --git a/wool/src/wool/runtime/worker/local.py b/wool/src/wool/runtime/worker/local.py index 0f6a61f..e84e703 100644 --- a/wool/src/wool/runtime/worker/local.py +++ b/wool/src/wool/runtime/worker/local.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from typing import TYPE_CHECKING from typing import Any import grpc.aio @@ -11,6 +12,9 @@ from wool.runtime.worker.base import WorkerOptions from wool.runtime.worker.process import WorkerProcess +if TYPE_CHECKING: + from wool.runtime.worker.service import BackpressureLike + # public class LocalWorker(Worker): @@ -60,6 +64,15 @@ class LocalWorker(Worker): :param options: gRPC message size options. Defaults to :class:`WorkerOptions` with 100 MB limits. + :param backpressure: + Optional admission control hook. A callable receiving a + :class:`~wool.runtime.worker.service.BackpressureContext` + and returning ``True`` to **reject** the task or ``False`` + to **accept** it. Both sync and async callables are + supported. When a task is rejected the worker responds with + gRPC ``RESOURCE_EXHAUSTED``, causing the load balancer to + skip to the next worker. ``None`` (default) accepts all + tasks unconditionally. :param extra: Additional metadata as key-value pairs. """ @@ -76,6 +89,7 @@ def __init__( proxy_pool_ttl: float = 60.0, credentials: WorkerCredentials | None = None, options: WorkerOptions | None = None, + backpressure: BackpressureLike | None = None, **extra: Any, ): super().__init__(*tags, **extra) @@ -90,6 +104,7 @@ def __init__( options=options, tags=frozenset(self._tags), extra=self._extra, + backpressure=backpressure, ) @property diff --git a/wool/src/wool/runtime/worker/process.py b/wool/src/wool/runtime/worker/process.py index 4f12cb8..3ff5a8e 100644 --- a/wool/src/wool/runtime/worker/process.py +++ b/wool/src/wool/runtime/worker/process.py @@ -18,6 +18,7 @@ from typing import Any from typing import Final +import cloudpickle import grpc.aio import wool @@ -32,6 +33,7 @@ if TYPE_CHECKING: from wool.runtime.worker.proxy import WorkerProxy + from wool.runtime.worker.service import BackpressureLike _ctx = _mp.get_context("spawn") Pipe = _ctx.Pipe @@ -64,6 +66,18 @@ class WorkerProcess(Process): :param options: gRPC message size options. Defaults to :class:`WorkerOptions` with 100 MB limits. + :param uid: + Unique identifier for this worker. Auto-generated if not + provided. + :param tags: + Capability tags for filtering and selection. + :param extra: + Additional metadata as key-value pairs. + :param backpressure: + Optional admission control hook. See + :class:`~wool.runtime.worker.service.BackpressureLike`. + Serialized with ``cloudpickle`` for transfer to the + subprocess. :param args: Additional args for :class:`multiprocessing.Process`. :param kwargs: @@ -91,6 +105,7 @@ def __init__( options: WorkerOptions | None = None, tags: frozenset[str] = frozenset(), extra: dict[str, Any] | None = None, + backpressure: BackpressureLike | None = None, **kwargs, ): super().__init__(*args, **kwargs) @@ -112,6 +127,9 @@ def __init__( self._tags = tags self._extra = extra if extra is not None else {} self._metadata = None + self._backpressure = ( + cloudpickle.dumps(backpressure) if backpressure is not None else None + ) self._get_metadata, self._set_metadata = Pipe(duplex=False) @property @@ -294,7 +312,12 @@ async def _serve(self): server.add_insecure_port(uds_target) uds_address = uds_target - service = WorkerService() + backpressure = ( + cloudpickle.loads(self._backpressure) + if self._backpressure is not None + else None + ) + service = WorkerService(backpressure=backpressure) protocol.add_to_server[protocol.WorkerServicer](service, server) with _signal_handlers(service): @@ -336,8 +359,10 @@ async def _serve(self): os.unlink(uds_path) def _address(self, host, port) -> str: - """Format network address for the given port. + """Format network address for the given host and port. + :param host: + Host address to include in the address. :param port: Port number to include in the address. :returns: diff --git a/wool/src/wool/runtime/worker/service.py b/wool/src/wool/runtime/worker/service.py index a866d89..e76743a 100644 --- a/wool/src/wool/runtime/worker/service.py +++ b/wool/src/wool/runtime/worker/service.py @@ -5,11 +5,16 @@ import contextvars import threading from contextlib import contextmanager +from dataclasses import dataclass from inspect import isasyncgen from inspect import isasyncgenfunction +from inspect import isawaitable from inspect import iscoroutinefunction from typing import AsyncGenerator from typing import AsyncIterator +from typing import Awaitable +from typing import Protocol +from typing import runtime_checkable import cloudpickle from grpc import StatusCode @@ -25,6 +30,59 @@ _SENTINEL = object() +# public +@dataclass(frozen=True) +class BackpressureContext: + """Snapshot of worker state provided to backpressure hooks. + + :param active_task_count: + Number of tasks currently executing on this worker. + :param task: + The incoming :class:`~wool.runtime.routine.task.Task` being + evaluated for admission. + """ + + active_task_count: int + task: Task + + +# public +@runtime_checkable +class BackpressureLike(Protocol): + """Protocol for backpressure hooks. + + A backpressure hook determines whether an incoming task should be + rejected. Return ``True`` to **reject** the task (apply + backpressure) or ``False`` to **accept** it. + + When a task is rejected the worker responds with gRPC + ``RESOURCE_EXHAUSTED``, which the load balancer treats as + transient and skips to the next worker. + + Pass ``None`` (the default) to accept all tasks unconditionally. + + Both sync and async implementations are supported:: + + def sync_hook(ctx: BackpressureContext) -> bool: + return ctx.active_task_count >= 4 + + + async def async_hook(ctx: BackpressureContext) -> bool: + return ctx.active_task_count >= 4 + """ + + def __call__(self, ctx: BackpressureContext) -> bool | Awaitable[bool]: + """Evaluate whether to reject the incoming task. + + :param ctx: + Snapshot of the worker's current state and the incoming + task. + :returns: + ``True`` to reject the task, ``False`` to accept it. + """ + ... + + class _Task: def __init__(self, task: asyncio.Task): self._work = task @@ -78,6 +136,11 @@ class WorkerService(protocol.WorkerServicer): Handles graceful shutdown by rejecting new tasks while allowing in-flight tasks to complete. Exposes :attr:`stopping` and :attr:`stopped` events for lifecycle monitoring. + + :param backpressure: + Optional admission control hook. See + :class:`BackpressureLike`. ``None`` (default) accepts all + tasks unconditionally. """ _docket: set[_Task | _AsyncGen] @@ -86,11 +149,12 @@ class WorkerService(protocol.WorkerServicer): _task_completed: asyncio.Event _loop_pool: ResourcePool[tuple[asyncio.AbstractEventLoop, threading.Thread]] - def __init__(self): + def __init__(self, *, backpressure: BackpressureLike | None = None): self._stopped = asyncio.Event() self._stopping = asyncio.Event() self._task_completed = asyncio.Event() self._docket = set() + self._backpressure = backpressure self._loop_pool = ResourcePool( factory=self._create_worker_loop, finalizer=self._destroy_worker_loop, @@ -143,7 +207,24 @@ async def dispatch( ) response = await anext(aiter(request_iterator)) - with self._tracker(Task.from_protobuf(response.task), request_iterator) as task: + work_task = Task.from_protobuf(response.task) + + if self._backpressure is not None: + decision = self._backpressure( + BackpressureContext( + active_task_count=len(self._docket), + task=work_task, + ) + ) + if isawaitable(decision): + decision = await decision + if decision: + await context.abort( + StatusCode.RESOURCE_EXHAUSTED, + "Task rejected by backpressure hook", + ) + + with self._tracker(work_task, request_iterator) as task: ack = protocol.Ack(version=protocol.__version__) yield protocol.Response(ack=ack) try: @@ -471,21 +552,9 @@ async def _await(self): await asyncio.sleep(0) async def _cancel(self): - """Cancel multiple tasks safely. + """Cancel all tracked tasks in the docket. - Cancels the provided tasks while performing safety checks to - avoid canceling the current task or already completed tasks. - Waits for all cancelled tasks to complete in parallel and handles - cancellation exceptions. - - :param tasks: - The :class:`asyncio.Task` instances to cancel. - - .. note:: - This method performs the following safety checks: - - Avoids canceling the current task (would cause deadlock) - - Only cancels tasks that are not already done - - Properly handles :exc:`asyncio.CancelledError` - exceptions. + Cancels every entry in :attr:`_docket` and waits for them to + finish, handling cancellation exceptions gracefully. """ await asyncio.gather(*(w.cancel() for w in self._docket), return_exceptions=True) diff --git a/wool/tests/integration/conftest.py b/wool/tests/integration/conftest.py index 05e3592..321b5aa 100644 --- a/wool/tests/integration/conftest.py +++ b/wool/tests/integration/conftest.py @@ -105,11 +105,27 @@ class RoutineBinding(Enum): STATICMETHOD = auto() +class BackpressureMode(Enum): + NONE = auto() + SYNC = auto() + ASYNC = auto() + + class LazyMode(Enum): LAZY = auto() EAGER = auto() +def _sync_accept_hook(ctx): + """Sync backpressure hook that accepts all tasks.""" + return False + + +async def _async_accept_hook(ctx): + """Async backpressure hook that accepts all tasks.""" + return False + + @dataclass(frozen=True) class Scenario: """Composable scenario describing one integration test configuration. @@ -127,6 +143,7 @@ class Scenario: timeout: TimeoutKind | None = None binding: RoutineBinding | None = None lazy: LazyMode | None = None + backpressure: BackpressureMode | None = None def __or__(self, other: Scenario) -> Scenario: """Merge two partial scenarios. Right side wins on ``None`` fields. @@ -147,7 +164,7 @@ def __or__(self, other: Scenario) -> Scenario: @property def is_complete(self) -> bool: - """True when all 9 dimensions are set.""" + """True when all 10 dimensions are set.""" return all(getattr(self, f.name) is not None for f in fields(self)) def __str__(self) -> str: @@ -285,29 +302,46 @@ async def _lan_async_cm(): lazy = scenario.lazy is LazyMode.LAZY + match scenario.backpressure: + case BackpressureMode.SYNC: + bp_hook = _sync_accept_hook + case BackpressureMode.ASYNC: + bp_hook = _async_accept_hook + case _: + bp_hook = None + try: if runtime_ctx is not None: runtime_ctx.__enter__() try: if scenario.pool_mode is PoolMode.DURABLE: - async with _durable_pool_context(lb, creds, options, lazy) as pool: + async with _durable_pool_context( + lb, creds, options, lazy, backpressure=bp_hook + ) as pool: yield pool elif scenario.pool_mode is PoolMode.DURABLE_SHARED: async with _durable_shared_pool_context( - lb, creds, options, lazy + lb, creds, options, lazy, backpressure=bp_hook ) as pool: yield pool elif scenario.pool_mode is PoolMode.DURABLE_JOINED: async with _durable_joined_pool_context( - scenario.discovery, lb, creds, options, lazy + scenario.discovery, + lb, + creds, + options, + lazy, + backpressure=bp_hook, ) as pool: yield pool else: pool_kwargs = { "loadbalancer": lb, "credentials": creds, - "worker": partial(LocalWorker, options=options), + "worker": partial( + LocalWorker, options=options, backpressure=bp_hook + ), "lazy": lazy, } match scenario.pool_mode: @@ -356,7 +390,7 @@ async def _lan_async_cm(): @asynccontextmanager -async def _durable_pool_context(lb, creds, options, lazy): +async def _durable_pool_context(lb, creds, options, lazy, *, backpressure=None): """Manually start a worker, register it, then create a DURABLE pool. DURABLE pools don't spawn workers — they only discover external @@ -368,7 +402,9 @@ async def _durable_pool_context(lb, creds, options, lazy): """ namespace = f"durable-{uuid.uuid4().hex[:12]}" with LocalDiscovery(namespace) as discovery: - worker = LocalWorker(credentials=creds, options=options) + worker = LocalWorker( + credentials=creds, options=options, backpressure=backpressure + ) await worker.start() try: publisher = discovery.publisher @@ -390,7 +426,7 @@ async def _durable_pool_context(lb, creds, options, lazy): @asynccontextmanager -async def _durable_shared_pool_context(lb, creds, options, lazy): +async def _durable_shared_pool_context(lb, creds, options, lazy, *, backpressure=None): """Create two pools sharing the same LocalDiscovery subscriber. Exercises ``SubscriberMeta`` singleton caching and @@ -401,7 +437,9 @@ async def _durable_shared_pool_context(lb, creds, options, lazy): """ namespace = f"shared-{uuid.uuid4().hex[:12]}" with LocalDiscovery(namespace) as discovery: - worker = LocalWorker(credentials=creds, options=options) + worker = LocalWorker( + credentials=creds, options=options, backpressure=backpressure + ) await worker.start() try: publisher = discovery.publisher @@ -469,7 +507,9 @@ async def _acm(): @asynccontextmanager -async def _durable_joined_pool_context(discovery_factory, lb, creds, options, lazy): +async def _durable_joined_pool_context( + discovery_factory, lb, creds, options, lazy, *, backpressure=None +): """Create a DURABLE pool that joins an externally owned namespace. Sets up an owner ``LocalDiscovery`` that creates workers and publishes @@ -479,7 +519,7 @@ async def _durable_joined_pool_context(discovery_factory, lb, creds, options, la """ namespace = f"joined-{uuid.uuid4().hex[:12]}" - worker = LocalWorker(credentials=creds, options=options) + worker = LocalWorker(credentials=creds, options=options, backpressure=backpressure) await worker.start() try: owner = LocalDiscovery(namespace) @@ -720,6 +760,7 @@ def _pairwise_filter(row): timeout=row[6], binding=row[7], lazy=row[8], + backpressure=row[9], ) for row in AllPairs( [ @@ -732,6 +773,7 @@ def _pairwise_filter(row): list(TimeoutKind), list(RoutineBinding), list(LazyMode), + list(BackpressureMode), ], filter_func=_pairwise_filter, ) @@ -792,6 +834,7 @@ def scenarios_strategy(draw): binding = draw(st.sampled_from(RoutineBinding)) lazy = draw(st.sampled_from(LazyMode)) + backpressure = draw(st.sampled_from(BackpressureMode)) return Scenario( shape=shape, @@ -803,6 +846,7 @@ def scenarios_strategy(draw): timeout=timeout, binding=binding, lazy=lazy, + backpressure=backpressure, ) diff --git a/wool/tests/integration/test_integration.py b/wool/tests/integration/test_integration.py index 70d0852..ccd2dfb 100644 --- a/wool/tests/integration/test_integration.py +++ b/wool/tests/integration/test_integration.py @@ -9,6 +9,7 @@ from hypothesis import settings from .conftest import PAIRWISE_SCENARIOS +from .conftest import BackpressureMode from .conftest import CredentialType from .conftest import DiscoveryFactory from .conftest import LazyMode @@ -70,6 +71,7 @@ async def body(): TimeoutKind.NONE, RoutineBinding.MODULE_FUNCTION, LazyMode.LAZY, + BackpressureMode.NONE, ) ) @example( @@ -83,6 +85,7 @@ async def body(): TimeoutKind.NONE, RoutineBinding.MODULE_FUNCTION, LazyMode.LAZY, + BackpressureMode.SYNC, ) ) @example( @@ -96,6 +99,7 @@ async def body(): TimeoutKind.NONE, RoutineBinding.MODULE_FUNCTION, LazyMode.LAZY, + BackpressureMode.ASYNC, ) ) @given(scenario=scenarios_strategy()) diff --git a/wool/tests/integration/test_pool_composition.py b/wool/tests/integration/test_pool_composition.py index 55bf2d5..57cafc4 100644 --- a/wool/tests/integration/test_pool_composition.py +++ b/wool/tests/integration/test_pool_composition.py @@ -1,7 +1,18 @@ """Tests for pool composition via build_pool_from_scenario.""" +import uuid +from functools import partial + import pytest +from wool.runtime.discovery.local import LocalDiscovery +from wool.runtime.loadbalancer.base import NoWorkersAvailable +from wool.runtime.loadbalancer.roundrobin import RoundRobinLoadBalancer +from wool.runtime.worker.local import LocalWorker +from wool.runtime.worker.pool import WorkerPool + +from . import routines +from .conftest import BackpressureMode from .conftest import CredentialType from .conftest import DiscoveryFactory from .conftest import LazyMode @@ -12,6 +23,7 @@ from .conftest import Scenario from .conftest import TimeoutKind from .conftest import WorkerOptionsKind +from .conftest import _DirectDiscovery from .conftest import build_pool_from_scenario from .conftest import invoke_routine @@ -41,6 +53,7 @@ async def test_build_pool_from_scenario_with_default_mode(self, credentials_map) timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -73,6 +86,7 @@ async def test_build_pool_from_scenario_with_ephemeral_mode(self, credentials_ma timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.EAGER, + backpressure=BackpressureMode.NONE, ) # Act @@ -105,6 +119,7 @@ async def test_build_pool_from_scenario_with_durable_mode(self, credentials_map) timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -137,6 +152,7 @@ async def test_build_pool_from_scenario_with_hybrid_mode(self, credentials_map): timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -173,6 +189,7 @@ async def test_build_pool_from_scenario_with_durable_joined_local( timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -205,6 +222,7 @@ async def test_build_pool_from_scenario_with_restrictive_opts(self, credentials_ timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -238,6 +256,7 @@ async def test_build_pool_from_scenario_with_keepalive_opts(self, credentials_ma timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -270,6 +289,7 @@ async def test_build_pool_from_scenario_with_dispatch_timeout(self, credentials_ timeout=TimeoutKind.VIA_RUNTIME_CONTEXT, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act @@ -305,6 +325,79 @@ async def test_build_pool_from_scenario_with_shared_discovery(self, credentials_ timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, + ) + + # Act + async with build_pool_from_scenario(scenario, credentials_map): + result = await invoke_routine(scenario) + + # Assert + assert result == 3 + + @pytest.mark.asyncio + async def test_build_pool_from_scenario_with_sync_backpressure( + self, credentials_map + ): + """Test building a pool with a sync backpressure accept hook. + + Given: + A complete scenario using a SYNC backpressure hook that + accepts all tasks (survives cloudpickle serialization to + the subprocess). + When: + A pool is built and a coroutine routine is dispatched. + Then: + It should return the correct result. + """ + # Arrange + scenario = Scenario( + shape=RoutineShape.COROUTINE, + pool_mode=PoolMode.DEFAULT, + discovery=DiscoveryFactory.NONE, + lb=LbFactory.CLASS_REF, + credential=CredentialType.INSECURE, + options=WorkerOptionsKind.DEFAULT, + timeout=TimeoutKind.NONE, + binding=RoutineBinding.MODULE_FUNCTION, + lazy=LazyMode.LAZY, + backpressure=BackpressureMode.SYNC, + ) + + # Act + async with build_pool_from_scenario(scenario, credentials_map): + result = await invoke_routine(scenario) + + # Assert + assert result == 3 + + @pytest.mark.asyncio + async def test_build_pool_from_scenario_with_async_backpressure( + self, credentials_map + ): + """Test building a pool with an async backpressure accept hook. + + Given: + A complete scenario using an ASYNC backpressure hook that + accepts all tasks (async hook survives cloudpickle + serialization to the subprocess). + When: + A pool is built and a coroutine routine is dispatched. + Then: + It should return the correct result. + """ + # Arrange + scenario = Scenario( + shape=RoutineShape.COROUTINE, + pool_mode=PoolMode.DEFAULT, + discovery=DiscoveryFactory.NONE, + lb=LbFactory.CLASS_REF, + credential=CredentialType.INSECURE, + options=WorkerOptionsKind.DEFAULT, + timeout=TimeoutKind.NONE, + binding=RoutineBinding.MODULE_FUNCTION, + lazy=LazyMode.LAZY, + backpressure=BackpressureMode.ASYNC, ) # Act @@ -313,3 +406,111 @@ async def test_build_pool_from_scenario_with_shared_discovery(self, credentials_ # Assert assert result == 3 + + +def _sync_reject_hook(ctx): + """Sync backpressure hook that rejects all tasks.""" + return True + + +async def _async_reject_hook(ctx): + """Async backpressure hook that rejects all tasks.""" + return True + + +@pytest.mark.integration +class TestBackpressureRejection: + @pytest.mark.asyncio + async def test_sync_backpressure_rejection(self): + """Test sync backpressure hook rejects task end-to-end. + + Given: + A single-worker pool with a sync backpressure hook that + rejects all tasks. + When: + A coroutine routine is dispatched. + Then: + It should raise NoWorkersAvailable because the only worker + rejects with RESOURCE_EXHAUSTED. + """ + # Arrange + pool = WorkerPool( + size=1, + loadbalancer=RoundRobinLoadBalancer, + worker=partial(LocalWorker, backpressure=_sync_reject_hook), + ) + + # Act & assert + async with pool: + with pytest.raises(NoWorkersAvailable): + await routines.add(1, 2) + + @pytest.mark.asyncio + async def test_async_backpressure_rejection(self): + """Test async backpressure hook rejects task end-to-end. + + Given: + A single-worker pool with an async backpressure hook that + rejects all tasks. + When: + A coroutine routine is dispatched. + Then: + It should raise NoWorkersAvailable because the only worker + rejects with RESOURCE_EXHAUSTED. + """ + # Arrange + pool = WorkerPool( + size=1, + loadbalancer=RoundRobinLoadBalancer, + worker=partial(LocalWorker, backpressure=_async_reject_hook), + ) + + # Act & assert + async with pool: + with pytest.raises(NoWorkersAvailable): + await routines.add(1, 2) + + @pytest.mark.asyncio + async def test_backpressure_fallback_to_accepting_worker(self): + """Test load balancer falls through to an accepting worker. + + Given: + A durable pool with two workers: one rejecting all tasks + via backpressure and one accepting all tasks. + When: + A coroutine routine is dispatched. + Then: + It should succeed by falling through to the accepting + worker after the rejecting worker returns + RESOURCE_EXHAUSTED. + """ + # Arrange + namespace = f"bp-fallback-{uuid.uuid4().hex[:12]}" + rejecting_worker = LocalWorker(backpressure=_sync_reject_hook) + accepting_worker = LocalWorker() + + await rejecting_worker.start() + await accepting_worker.start() + try: + with LocalDiscovery(namespace) as discovery: + publisher = discovery.publisher + async with publisher: + await publisher.publish("worker-added", rejecting_worker.metadata) + await publisher.publish("worker-added", accepting_worker.metadata) + pool = WorkerPool( + discovery=_DirectDiscovery(discovery), + loadbalancer=RoundRobinLoadBalancer, + ) + + # Act + async with pool: + result = await routines.add(1, 2) + + # Assert + assert result == 3 + + await publisher.publish("worker-dropped", rejecting_worker.metadata) + await publisher.publish("worker-dropped", accepting_worker.metadata) + finally: + await accepting_worker.stop() + await rejecting_worker.stop() diff --git a/wool/tests/integration/test_scenario.py b/wool/tests/integration/test_scenario.py index 81159b2..162a6e0 100644 --- a/wool/tests/integration/test_scenario.py +++ b/wool/tests/integration/test_scenario.py @@ -2,6 +2,7 @@ import pytest +from .conftest import BackpressureMode from .conftest import CredentialType from .conftest import DiscoveryFactory from .conftest import LazyMode @@ -103,7 +104,7 @@ def test_is_complete_with_all_fields(self): """Test that a fully populated scenario reports complete. Given: - A scenario with all 9 dimensions set. + A scenario with all 10 dimensions set. When: ``is_complete`` is checked. Then: @@ -120,6 +121,7 @@ def test_is_complete_with_all_fields(self): timeout=TimeoutKind.NONE, binding=RoutineBinding.MODULE_FUNCTION, lazy=LazyMode.LAZY, + backpressure=BackpressureMode.NONE, ) # Act & assert @@ -165,4 +167,4 @@ def test___str___with_partial_fields(self): result = str(scenario) # Assert - assert result == "COROUTINE-DEFAULT-_-_-_-_-_-_-_" + assert result == "COROUTINE-DEFAULT-_-_-_-_-_-_-_-_" diff --git a/wool/tests/runtime/worker/test_backpressure.py b/wool/tests/runtime/worker/test_backpressure.py new file mode 100644 index 0000000..6d19f94 --- /dev/null +++ b/wool/tests/runtime/worker/test_backpressure.py @@ -0,0 +1,194 @@ +from uuid import uuid4 + +import pytest +from hypothesis import given +from hypothesis import settings +from hypothesis import strategies as st + +from wool.runtime.routine.task import Task +from wool.runtime.routine.task import WorkerProxyLike +from wool.runtime.worker.service import BackpressureContext +from wool.runtime.worker.service import BackpressureLike + +from .conftest import PicklableMock + + +def _make_task(): + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + return Task( + id=uuid4(), + callable=lambda: None, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + + +class TestBackpressureContext: + """Tests for :class:`BackpressureContext` dataclass.""" + + def test___init___with_valid_fields(self): + """Test BackpressureContext instantiation. + + Given: + An active task count and a Task instance + When: + BackpressureContext is instantiated + Then: + It should store both fields correctly + """ + # Arrange + task = _make_task() + + # Act + ctx = BackpressureContext(active_task_count=3, task=task) + + # Assert + assert ctx.active_task_count == 3 + assert ctx.task is task + + def test___init___is_frozen(self): + """Test BackpressureContext is immutable. + + Given: + A BackpressureContext instance + When: + An attribute is reassigned + Then: + It should raise FrozenInstanceError + """ + # Arrange + ctx = BackpressureContext(active_task_count=0, task=_make_task()) + + # Act & Assert + with pytest.raises(AttributeError): + ctx.active_task_count = 5 + + def test___eq___with_equal_fields(self): + """Test BackpressureContext equality with identical fields. + + Given: + Two BackpressureContext instances with the same active_task_count and task + When: + Compared with == + Then: + It should return True + """ + # Arrange + task = _make_task() + ctx_a = BackpressureContext(active_task_count=2, task=task) + ctx_b = BackpressureContext(active_task_count=2, task=task) + + # Act & Assert + assert ctx_a == ctx_b + + def test___eq___with_different_task_count(self): + """Test BackpressureContext inequality with different active_task_count. + + Given: + Two BackpressureContext instances with different active_task_count + When: + Compared with == + Then: + It should return False + """ + # Arrange + task = _make_task() + ctx_a = BackpressureContext(active_task_count=0, task=task) + ctx_b = BackpressureContext(active_task_count=1, task=task) + + # Act & Assert + assert ctx_a != ctx_b + + @given(count=st.integers(min_value=0, max_value=10_000)) + @settings(max_examples=50) + def test___init___with_arbitrary_task_count(self, count): + """Test BackpressureContext stores arbitrary non-negative active_task_count. + + Given: + Any non-negative integer for active_task_count + When: + BackpressureContext is instantiated + Then: + It should store the value exactly + """ + # Arrange + task = _make_task() + + # Act + ctx = BackpressureContext(active_task_count=count, task=task) + + # Assert + assert ctx.active_task_count == count + + +class TestBackpressureLike: + """Tests for :class:`BackpressureLike` protocol.""" + + def test_sync_callable_satisfies_protocol(self): + """Test sync callable satisfies BackpressureLike. + + Given: + A sync function with the correct signature + When: + Checked against BackpressureLike + Then: + It should be recognized as an instance + """ + + # Arrange + def hook(ctx: BackpressureContext) -> bool: + return True + + # Act & Assert + assert isinstance(hook, BackpressureLike) + + def test_async_callable_satisfies_protocol(self): + """Test async callable satisfies BackpressureLike. + + Given: + An async function with the correct signature + When: + Checked against BackpressureLike + Then: + It should be recognized as an instance + """ + + # Arrange + async def hook(ctx: BackpressureContext) -> bool: + return True + + # Act & Assert + assert isinstance(hook, BackpressureLike) + + def test_callable_class_satisfies_protocol(self): + """Test callable class instance satisfies BackpressureLike. + + Given: + A class with a __call__ method + When: + An instance is checked against BackpressureLike + Then: + It should be recognized as an instance + """ + + # Arrange + class Hook: + def __call__(self, ctx: BackpressureContext) -> bool: + return ctx.active_task_count >= 4 + + # Act & Assert + assert isinstance(Hook(), BackpressureLike) + + def test_non_callable_does_not_satisfy_protocol(self): + """Test non-callable does not satisfy BackpressureLike. + + Given: + A plain string (not callable) + When: + Checked against BackpressureLike + Then: + It should not be recognized as an instance + """ + # Act & Assert + assert not isinstance("not-a-hook", BackpressureLike) diff --git a/wool/tests/runtime/worker/test_local.py b/wool/tests/runtime/worker/test_local.py index 222da4d..0db75de 100644 --- a/wool/tests/runtime/worker/test_local.py +++ b/wool/tests/runtime/worker/test_local.py @@ -704,3 +704,26 @@ async def test_worker_lifecycle_with_credentials( # Assert assert worker.metadata is not None + + def test___init___with_backpressure(self, mocker): + """Test LocalWorker initialization with backpressure hook. + + Given: + A callable backpressure hook + When: + LocalWorker is instantiated with backpressure=hook + Then: + It should forward the hook to WorkerProcess + """ + # Arrange + MockWorkerProcess = mocker.patch.object(local_module, "WorkerProcess") + + def hook(ctx): + return ctx.active_task_count >= 4 + + # Act + LocalWorker(backpressure=hook) + + # Assert + MockWorkerProcess.assert_called_once() + assert MockWorkerProcess.call_args.kwargs["backpressure"] is hook diff --git a/wool/tests/runtime/worker/test_service.py b/wool/tests/runtime/worker/test_service.py index 2b65c01..49dd904 100644 --- a/wool/tests/runtime/worker/test_service.py +++ b/wool/tests/runtime/worker/test_service.py @@ -2172,3 +2172,335 @@ async def resilient_generator(): # Assert assert result["do_dispatch"] is False assert result["has_proxy"] is True + + +class TestWorkerServiceBackpressure: + """Tests for :class:`WorkerService` backpressure hook.""" + + def test___init___with_backpressure_hook(self): + """Test WorkerService initialization with a backpressure hook. + + Given: + A callable backpressure hook + When: + WorkerService is instantiated with backpressure=hook + Then: + It should initialize successfully with stopping and stopped events unset + """ + + # Arrange + def hook(ctx): + return False + + # Act + service = WorkerService(backpressure=hook) + + # Assert + assert not service.stopping.is_set() + assert not service.stopped.is_set() + + @pytest.mark.asyncio + async def test_dispatch_with_sync_backpressure_accepting( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test dispatch succeeds when sync backpressure hook returns False. + + Given: + A :class:`WorkerService` with a sync backpressure hook that returns False + When: + Dispatch RPC is called + Then: + It should accept the task and return the result normally + """ + + # Arrange + async def sample_task(): + return "accepted" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + def hook(ctx): + return False + + service = WorkerService(backpressure=hook) + + # Act + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + responses = [r async for r in stream] + + # Assert + ack, response = responses + assert ack.HasField("ack") + assert response.HasField("result") + assert cloudpickle.loads(response.result.dump) == "accepted" + + @pytest.mark.asyncio + async def test_dispatch_with_sync_backpressure_rejecting( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test dispatch aborts when sync backpressure hook returns True. + + Given: + A :class:`WorkerService` with a sync backpressure hook that returns True + When: + Dispatch RPC is called + Then: + It should reject the task with RESOURCE_EXHAUSTED status + """ + + # Arrange + async def sample_task(): + return "should_not_reach" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + def hook(ctx): + return True + + service = WorkerService(backpressure=hook) + + # Act & assert + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + with pytest.raises(grpc.RpcError) as exc_info: + async for _ in stream: + pass + assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED + + @pytest.mark.asyncio + async def test_dispatch_with_async_backpressure_rejecting( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test dispatch aborts when async backpressure hook returns True. + + Given: + A :class:`WorkerService` with an async backpressure hook that returns True + When: + Dispatch RPC is called + Then: + It should reject the task with RESOURCE_EXHAUSTED status + """ + + # Arrange + async def sample_task(): + return "should_not_reach" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + async def async_hook(ctx): + return True + + service = WorkerService(backpressure=async_hook) + + # Act & assert + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + with pytest.raises(grpc.RpcError) as exc_info: + async for _ in stream: + pass + assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED + + @pytest.mark.asyncio + async def test_dispatch_with_async_backpressure_accepting( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test dispatch succeeds when async backpressure hook returns False. + + Given: + A :class:`WorkerService` with an async backpressure hook that returns False + When: + Dispatch RPC is called + Then: + It should accept the task and return the result normally + """ + + # Arrange + async def sample_task(): + return "async_accepted" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + async def async_hook(ctx): + return False + + service = WorkerService(backpressure=async_hook) + + # Act + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + responses = [r async for r in stream] + + # Assert + ack, response = responses + assert ack.HasField("ack") + assert response.HasField("result") + assert cloudpickle.loads(response.result.dump) == "async_accepted" + + @pytest.mark.asyncio + async def test_dispatch_with_backpressure_receiving_context( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test backpressure hook receives correct context. + + Given: + A :class:`WorkerService` with a backpressure hook that captures its argument + When: + Dispatch RPC is called + Then: + It should pass a BackpressureContext with active_task_count and task fields + """ + # Arrange + from wool.runtime.worker.service import BackpressureContext + + async def sample_task(): + return "result" + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + wool_task = Task( + id=uuid4(), + callable=sample_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + request = protocol.Request(task=wool_task.to_protobuf()) + + captured = [] + + def hook(ctx): + captured.append(ctx) + return False + + service = WorkerService(backpressure=hook) + + # Act + async with grpc_aio_stub(servicer=service) as stub: + stream = stub.dispatch() + await stream.write(request) + await stream.done_writing() + [r async for r in stream] + + # Assert + assert len(captured) == 1 + ctx = captured[0] + assert isinstance(ctx, BackpressureContext) + assert ctx.active_task_count == 0 + assert ctx.task.id == wool_task.id + + @pytest.mark.asyncio + async def test_dispatch_with_backpressure_and_active_tasks( + self, grpc_aio_stub, mock_worker_proxy_cache + ): + """Test backpressure hook sees correct active task count. + + Given: + A :class:`WorkerService` with one active task already dispatched + When: + A second dispatch RPC is called with a backpressure hook + Then: + It should see active_task_count == 1 + """ + # Arrange + global _control_event + _control_event = threading.Event() + + mock_proxy = PicklableMock(spec=WorkerProxyLike, id="test-proxy-id") + + first_task = Task( + id=uuid4(), + callable=_controllable_task, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + first_request = protocol.Request(task=first_task.to_protobuf()) + + async def second_fn(): + return "second" + + second_task = Task( + id=uuid4(), + callable=second_fn, + args=(), + kwargs={}, + proxy=mock_proxy, + ) + second_request = protocol.Request(task=second_task.to_protobuf()) + + captured_count = [] + + def hook(ctx): + captured_count.append(ctx.active_task_count) + return False + + service = WorkerService(backpressure=hook) + + # Act + try: + async with grpc_aio_stub(servicer=service) as stub: + # Dispatch first task (blocks on control event) + stream1 = stub.dispatch() + await stream1.write(first_request) + await stream1.done_writing() + # Wait for ack to confirm first task is tracked + async for response in stream1: + assert response.HasField("ack") + break + + # Dispatch second task — hook should see 1 active task + stream2 = stub.dispatch() + await stream2.write(second_request) + await stream2.done_writing() + [r async for r in stream2] + + # Release first task + _control_event.set() + [r async for r in stream1] + finally: + if _control_event and not _control_event.is_set(): + _control_event.set() + _control_event = None + + # Assert — first dispatch sees 0 active, second sees 1 + assert captured_count == [0, 1] diff --git a/wool/tests/test_public.py b/wool/tests/test_public.py index b3c6891..b5bd0b5 100644 --- a/wool/tests/test_public.py +++ b/wool/tests/test_public.py @@ -45,6 +45,8 @@ def test_public_api_completeness(): "TaskException", "current_task", "routine", + "BackpressureContext", + "BackpressureLike", "LocalWorker", "Worker", "WorkerCredentials",