Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions wool/src/wool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from wool.runtime.worker.metadata import WorkerMetadata
from wool.runtime.worker.pool import WorkerPool
from wool.runtime.worker.proxy import WorkerProxy
from wool.runtime.worker.service import BackpressureContext
from wool.runtime.worker.service import BackpressureLike
from wool.runtime.worker.service import WorkerService

pickling_support.install()
Expand Down Expand Up @@ -80,6 +82,9 @@
"TaskException",
"current_task",
"routine",
# Backpressure
"BackpressureContext",
"BackpressureLike",
# Workers
"LocalWorker",
"Worker",
Expand Down
15 changes: 15 additions & 0 deletions wool/src/wool/runtime/worker/local.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING
from typing import Any

import grpc.aio
Expand All @@ -11,6 +12,9 @@
from wool.runtime.worker.base import WorkerOptions
from wool.runtime.worker.process import WorkerProcess

if TYPE_CHECKING:
from wool.runtime.worker.service import BackpressureLike


# public
class LocalWorker(Worker):
Expand Down Expand Up @@ -60,6 +64,15 @@ class LocalWorker(Worker):
:param options:
gRPC message size options. Defaults to
:class:`WorkerOptions` with 100 MB limits.
:param backpressure:
Optional admission control hook. A callable receiving a
:class:`~wool.runtime.worker.service.BackpressureContext`
and returning ``True`` to **reject** the task or ``False``
to **accept** it. Both sync and async callables are
supported. When a task is rejected the worker responds with
gRPC ``RESOURCE_EXHAUSTED``, causing the load balancer to
skip to the next worker. ``None`` (default) accepts all
tasks unconditionally.
:param extra:
Additional metadata as key-value pairs.
"""
Expand All @@ -76,6 +89,7 @@ def __init__(
proxy_pool_ttl: float = 60.0,
credentials: WorkerCredentials | None = None,
options: WorkerOptions | None = None,
backpressure: BackpressureLike | None = None,
**extra: Any,
):
super().__init__(*tags, **extra)
Expand All @@ -90,6 +104,7 @@ def __init__(
options=options,
tags=frozenset(self._tags),
extra=self._extra,
backpressure=backpressure,
)

@property
Expand Down
29 changes: 27 additions & 2 deletions wool/src/wool/runtime/worker/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any
from typing import Final

import cloudpickle
import grpc.aio

import wool
Expand All @@ -32,6 +33,7 @@

if TYPE_CHECKING:
from wool.runtime.worker.proxy import WorkerProxy
from wool.runtime.worker.service import BackpressureLike

_ctx = _mp.get_context("spawn")
Pipe = _ctx.Pipe
Expand Down Expand Up @@ -64,6 +66,18 @@ class WorkerProcess(Process):
:param options:
gRPC message size options. Defaults to
:class:`WorkerOptions` with 100 MB limits.
:param uid:
Unique identifier for this worker. Auto-generated if not
provided.
:param tags:
Capability tags for filtering and selection.
:param extra:
Additional metadata as key-value pairs.
:param backpressure:
Optional admission control hook. See
:class:`~wool.runtime.worker.service.BackpressureLike`.
Serialized with ``cloudpickle`` for transfer to the
subprocess.
:param args:
Additional args for :class:`multiprocessing.Process`.
:param kwargs:
Expand Down Expand Up @@ -91,6 +105,7 @@ def __init__(
options: WorkerOptions | None = None,
tags: frozenset[str] = frozenset(),
extra: dict[str, Any] | None = None,
backpressure: BackpressureLike | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
Expand All @@ -112,6 +127,9 @@ def __init__(
self._tags = tags
self._extra = extra if extra is not None else {}
self._metadata = None
self._backpressure = (
cloudpickle.dumps(backpressure) if backpressure is not None else None
)
self._get_metadata, self._set_metadata = Pipe(duplex=False)

@property
Expand Down Expand Up @@ -294,7 +312,12 @@ async def _serve(self):
server.add_insecure_port(uds_target)
uds_address = uds_target

service = WorkerService()
backpressure = (
cloudpickle.loads(self._backpressure)
if self._backpressure is not None
else None
)
service = WorkerService(backpressure=backpressure)
protocol.add_to_server[protocol.WorkerServicer](service, server)

with _signal_handlers(service):
Expand Down Expand Up @@ -336,8 +359,10 @@ async def _serve(self):
os.unlink(uds_path)

def _address(self, host, port) -> str:
"""Format network address for the given port.
"""Format network address for the given host and port.

:param host:
Host address to include in the address.
:param port:
Port number to include in the address.
:returns:
Expand Down
103 changes: 86 additions & 17 deletions wool/src/wool/runtime/worker/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@
import contextvars
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from inspect import isasyncgen
from inspect import isasyncgenfunction
from inspect import isawaitable
from inspect import iscoroutinefunction
from typing import AsyncGenerator
from typing import AsyncIterator
from typing import Awaitable
from typing import Protocol
from typing import runtime_checkable

import cloudpickle
from grpc import StatusCode
Expand All @@ -25,6 +30,59 @@
_SENTINEL = object()


# public
@dataclass(frozen=True)
class BackpressureContext:
    """Immutable snapshot of worker state handed to backpressure hooks.

    Built by the worker service at admission time, so a hook always
    observes the load as it stood when the task arrived.

    :param active_task_count:
        Number of tasks currently executing on this worker.
    :param task:
        The incoming :class:`~wool.runtime.routine.task.Task` being
        evaluated for admission.
    """

    # Count of tasks the worker is running at the moment of admission.
    active_task_count: int
    # The task awaiting an accept/reject decision.
    task: Task


# public
@runtime_checkable
class BackpressureLike(Protocol):
    """Protocol for admission-control (backpressure) hooks.

    An implementation inspects the supplied context and decides whether
    the incoming task should be turned away: return ``True`` to
    **reject** the task (apply backpressure) or ``False`` to **accept**
    it.

    A rejected task makes the worker respond with gRPC
    ``RESOURCE_EXHAUSTED``; the load balancer treats that status as
    transient and skips to the next worker.

    Supplying ``None`` (the default) instead of a hook accepts every
    task unconditionally.

    Hooks may be synchronous or asynchronous::

        def sync_hook(ctx: BackpressureContext) -> bool:
            return ctx.active_task_count >= 4


        async def async_hook(ctx: BackpressureContext) -> bool:
            return ctx.active_task_count >= 4
    """

    def __call__(self, ctx: BackpressureContext) -> bool | Awaitable[bool]:
        """Decide whether the incoming task must be rejected.

        :param ctx:
            Snapshot of the worker's current state together with the
            incoming task.
        :returns:
            ``True`` to reject the task, ``False`` to accept it.
        """
        ...


class _Task:
def __init__(self, task: asyncio.Task):
self._work = task
Expand Down Expand Up @@ -78,6 +136,11 @@ class WorkerService(protocol.WorkerServicer):
Handles graceful shutdown by rejecting new tasks while allowing
in-flight tasks to complete. Exposes :attr:`stopping` and
:attr:`stopped` events for lifecycle monitoring.

:param backpressure:
Optional admission control hook. See
:class:`BackpressureLike`. ``None`` (default) accepts all
tasks unconditionally.
"""

_docket: set[_Task | _AsyncGen]
Expand All @@ -86,11 +149,12 @@ class WorkerService(protocol.WorkerServicer):
_task_completed: asyncio.Event
_loop_pool: ResourcePool[tuple[asyncio.AbstractEventLoop, threading.Thread]]

def __init__(self):
def __init__(self, *, backpressure: BackpressureLike | None = None):
self._stopped = asyncio.Event()
self._stopping = asyncio.Event()
self._task_completed = asyncio.Event()
self._docket = set()
self._backpressure = backpressure
self._loop_pool = ResourcePool(
factory=self._create_worker_loop,
finalizer=self._destroy_worker_loop,
Expand Down Expand Up @@ -143,7 +207,24 @@ async def dispatch(
)

response = await anext(aiter(request_iterator))
with self._tracker(Task.from_protobuf(response.task), request_iterator) as task:
work_task = Task.from_protobuf(response.task)

if self._backpressure is not None:
decision = self._backpressure(
BackpressureContext(
active_task_count=len(self._docket),
task=work_task,
)
)
if isawaitable(decision):
decision = await decision
if decision:
await context.abort(
StatusCode.RESOURCE_EXHAUSTED,
"Task rejected by backpressure hook",
)

with self._tracker(work_task, request_iterator) as task:
ack = protocol.Ack(version=protocol.__version__)
yield protocol.Response(ack=ack)
try:
Expand Down Expand Up @@ -471,21 +552,9 @@ async def _await(self):
await asyncio.sleep(0)

async def _cancel(self):
"""Cancel multiple tasks safely.
"""Cancel all tracked tasks in the docket.

Cancels the provided tasks while performing safety checks to
avoid canceling the current task or already completed tasks.
Waits for all cancelled tasks to complete in parallel and handles
cancellation exceptions.

:param tasks:
The :class:`asyncio.Task` instances to cancel.

.. note::
This method performs the following safety checks:
- Avoids canceling the current task (would cause deadlock)
- Only cancels tasks that are not already done
- Properly handles :exc:`asyncio.CancelledError`
exceptions.
Cancels every entry in :attr:`_docket` and waits for them to
finish, handling cancellation exceptions gracefully.
"""
await asyncio.gather(*(w.cancel() for w in self._docket), return_exceptions=True)
Loading
Loading