diff --git a/wool/src/wool/runtime/worker/README.md b/wool/src/wool/runtime/worker/README.md index c4a4d80..13b3af2 100644 --- a/wool/src/wool/runtime/worker/README.md +++ b/wool/src/wool/runtime/worker/README.md @@ -220,6 +220,20 @@ Proxies on worker subprocesses are lazy by default — the `WorkerPool` propagat When `lazy=True`, concurrent `dispatch()` calls use a double-checked lock to ensure the proxy starts exactly once. The `lazy` flag is preserved through `cloudpickle` serialization, so proxies sent to worker subprocesses as part of a task retain their laziness setting. +### Context lifecycle + +Both `WorkerPool` and `WorkerProxy` are **single-use** async context managers. Once entered and exited, the same instance cannot be entered again — create a new instance instead. Attempting to call `enter()` or `__aenter__()` a second time raises `RuntimeError`. This prevents silent state corruption from reentrant or repeated context usage (e.g., accidentally nesting `async with proxy:` blocks or calling `enter()` in a retry loop). + +```python +# Correct — one instance per context +async with wool.WorkerPool(spawn=4): + await my_routine() + +# Need another pool? Create a new instance. +async with wool.WorkerPool(spawn=4): + await my_routine() +``` + ### Self-describing connections Workers are self-describing: each worker advertises its gRPC transport configuration via `ChannelOptions` in its `WorkerMetadata`. When a client discovers a worker, it reads the advertised options and configures its channel to match — message sizes, keepalive intervals, concurrency limits, and compression are all set automatically. There is no separate client-side configuration step; the worker's metadata is the single source of truth for how to connect to it. diff --git a/wool/src/wool/runtime/worker/pool.py b/wool/src/wool/runtime/worker/pool.py index d42d893..a5ced29 100644 --- a/wool/src/wool/runtime/worker/pool.py +++ b/wool/src/wool/runtime/worker/pool.py @@ -26,6 +26,7 @@ from wool.runtime.worker.proxy import LoadBalancerLike from wool.runtime.worker.proxy import RoundRobinLoadBalancer from wool.runtime.worker.proxy import WorkerProxy +from wool.utilities.noreentry import noreentry # public @@ -397,6 +398,7 @@ async def create_proxy(): self._proxy_factory = create_proxy + @noreentry async def __aenter__(self) -> WorkerPool: """Starts the worker pool and its services, returning a session. @@ -405,6 +407,10 @@ async def __aenter__(self) -> WorkerPool: :returns: The :class:`WorkerPool` instance itself for method chaining. + :raises RuntimeError: + If the pool has already been entered. ``WorkerPool`` + contexts are single-use — create a new instance instead + of re-entering. """ self._proxy_context = self._proxy_factory() await self._proxy_context.__aenter__() diff --git a/wool/src/wool/runtime/worker/proxy.py b/wool/src/wool/runtime/worker/proxy.py index 7be103e..a92636c 100644 --- a/wool/src/wool/runtime/worker/proxy.py +++ b/wool/src/wool/runtime/worker/proxy.py @@ -33,6 +33,7 @@ from wool.runtime.worker.auth import WorkerCredentials from wool.runtime.worker.connection import WorkerConnection from wool.runtime.worker.metadata import WorkerMetadata +from wool.utilities.noreentry import noreentry if TYPE_CHECKING: from contextvars import Token @@ -413,6 +414,7 @@ def workers(self) -> list[WorkerMetadata]: else: return [] + @noreentry async def enter(self) -> None: """Enter the proxy context. @@ -421,6 +423,10 @@ async def enter(self) -> None: :meth:`dispatch` is first called. When ``lazy=False``, calls :meth:`start` eagerly. + :raises RuntimeError: + If the proxy has already been entered. ``WorkerProxy`` + contexts are single-use — create a new instance instead + of re-entering. :raises RuntimeError: If the proxy has already been started and ``lazy`` is ``False``. diff --git a/wool/src/wool/utilities/noreentry.py b/wool/src/wool/utilities/noreentry.py new file mode 100644 index 0000000..f497e02 --- /dev/null +++ b/wool/src/wool/utilities/noreentry.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import asyncio.coroutines +import functools +import inspect +import sys +import weakref +from typing import Never + + +class _Token: + """Hashable, weakly-referenceable token for instance tracking.""" + + pass + + +class NoReentryBoundMethod: + """Descriptor implementing single-use guard for bound methods. + + On the first call the decorated method executes normally. Any + subsequent call raises :class:`RuntimeError`. + + Guard state uses a per-instance token stored on the instance under + ``__noreentry_token__``. The token is unique, hashable, and tied to the + instance's lifetime. The descriptor tracks tokens in a WeakSet to auto-clean + when instances are garbage collected. + + Works with both synchronous and asynchronous methods. Only supports + bound methods; using @noreentry on bare functions raises TypeError. + """ + + def __init__(self, fn, /): + functools.update_wrapper(self, fn) + if inspect.iscoroutinefunction(fn): + if sys.version_info >= (3, 12): + inspect.markcoroutinefunction(self) + else: + self._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined] + self._fn = fn + self._invocations = weakref.WeakSet() + + def __call__(self, *args, **kwargs) -> Never: + raise TypeError("@noreentry only decorates methods, not bare functions") + + def __get__(self, obj, objtype=None): + if obj is None: + return self + + # Cache the wrapper on the instance to avoid recreating it on every access. + cache_key = f"__noreentry_wrapper_{id(self)}__" + return obj.__dict__.setdefault(cache_key, self._make_wrapper(obj)) + + def _make_wrapper(self, obj): + """Create the bound wrapper for an instance.""" + guard = self._guard + fn = self._fn + if inspect.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def async_wrapper(*args, **kwargs): + guard(obj) + return await fn(obj, *args, **kwargs) + + return async_wrapper + else: + + @functools.wraps(fn) + def sync_wrapper(*args, **kwargs): + guard(obj) + return fn(obj, *args, **kwargs) + + return sync_wrapper + + def _guard(self, obj): + """Check and record invocation on the specified object.""" + # Get or create a unique token for this object. + token = obj.__dict__.setdefault("__noreentry_token__", _Token()) + + # Check if this descriptor was invoked on this object already. + if token in self._invocations: + raise RuntimeError( + f"'{self._fn.__qualname__}' cannot be invoked more than once" + ) + + # Track invocation. + self._invocations.add(token) + + +def noreentry(fn): + """Mark a method as single-use. + + On the first call the decorated method executes normally. Any + subsequent call raises :class:`RuntimeError`. + + Guard state uses a per-instance token stored on the instance under + ``__noreentry_token__``. The token is unique, hashable, and tied to the + instance's lifetime. The descriptor tracks tokens in a :class:`WeakSet` + to auto-clean when instances are garbage collected. + + Works with both synchronous and asynchronous methods. Only supports + bound methods; using @noreentry on bare functions raises :class:`TypeError` + on the first invocation. + + :param fn: + The bound instance or class method to decorate. + """ + return NoReentryBoundMethod(fn) diff --git a/wool/tests/runtime/worker/test_pool.py b/wool/tests/runtime/worker/test_pool.py index 102be55..4b39fc3 100644 --- a/wool/tests/runtime/worker/test_pool.py +++ b/wool/tests/runtime/worker/test_pool.py @@ -554,6 +554,47 @@ async def test___aexit___cleanup_on_error(self, mock_worker_factory): assert pool_created, "Pool should have been created before exception" assert exception_caught, "Exception should have been propagated" + @pytest.mark.asyncio + async def test___aenter___already_entered_raises_error(self, mock_worker_factory): + """Test pool raises on reentrant entry. + + Given: + A WorkerPool that is already entered via async with. + When: + The pool is entered a second time via async with. + Then: + It should raise RuntimeError. + """ + # Arrange + pool = WorkerPool(worker=mock_worker_factory, spawn=2) + + # Act & assert + async with pool: + with pytest.raises(RuntimeError, match="cannot be invoked more than once"): + async with pool: + pass + + @pytest.mark.asyncio + async def test___aenter___after_exit_raises_error(self, mock_worker_factory): + """Test pool raises when re-entered after exit. + + Given: + A WorkerPool that has been entered and exited. + When: + The pool is entered again via async with. + Then: + It should raise RuntimeError because the context is single-use. + """ + # Arrange + pool = WorkerPool(worker=mock_worker_factory, spawn=2) + async with pool: + pass + + # Act & assert + with pytest.raises(RuntimeError, match="cannot be invoked more than once"): + async with pool: + pass + @pytest.mark.asyncio async def test___aenter___lifecycle_returns_pool_instance( self, diff --git a/wool/tests/runtime/worker/test_proxy.py b/wool/tests/runtime/worker/test_proxy.py index 3f51a4a..63f9d62 100644 --- a/wool/tests/runtime/worker/test_proxy.py +++ b/wool/tests/runtime/worker/test_proxy.py @@ -797,6 +797,45 @@ async def test_enter_with_non_lazy_proxy( # Assert assert proxy.started + @pytest.mark.asyncio + async def test_enter_already_entered_raises_error(self, mock_discovery_service): + """Test enter raises on reentrant call. + + Given: + A lazy WorkerProxy that has already been entered. + When: + enter() is called a second time. + Then: + It should raise RuntimeError. + """ + # Arrange + proxy = WorkerProxy(discovery=mock_discovery_service) + await proxy.enter() + + # Act & assert + with pytest.raises(RuntimeError, match="cannot be invoked more than once"): + await proxy.enter() + + @pytest.mark.asyncio + async def test_enter_after_exit_raises_error(self, mock_discovery_service): + """Test enter raises after a full enter/exit cycle. + + Given: + A lazy WorkerProxy that has been entered and exited. + When: + enter() is called again. + Then: + It should raise RuntimeError because the context is single-use. + """ + # Arrange + proxy = WorkerProxy(discovery=mock_discovery_service) + await proxy.enter() + await proxy.exit() + + # Act & assert + with pytest.raises(RuntimeError, match="cannot be invoked more than once"): + await proxy.enter() + @pytest.mark.asyncio async def test_stop_clears_state(self, mock_discovery_service, mock_proxy_session): """Test clear workers and reset the started flag to False. diff --git a/wool/tests/utilities/test_noreentry.py b/wool/tests/utilities/test_noreentry.py new file mode 100644 index 0000000..1e0b08c --- /dev/null +++ b/wool/tests/utilities/test_noreentry.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from wool.utilities.noreentry import noreentry + + +# Test helpers (not fixtures) +class _SyncDummy: + """Class with a sync @noreentry method.""" + + @noreentry + def run(self): + return "ok" + + +class _AsyncDummy: + """Class with an async @noreentry method.""" + + @noreentry + async def run(self): + return "ok" + + +class _MultiMethodDummy: + """Class with multiple @noreentry methods.""" + + @noreentry + def first(self): + return "first" + + @noreentry + def second(self): + return "second" + + +class TestNoReentry: + """Tests for the noreentry decorator.""" + + def test_noreentry_sync_method_first_invocation(self): + """Test sync method executes normally on first invocation. + + Given: + A class with a sync @noreentry method. + When: + The method is called for the first time. + Then: + It should return normally. + """ + # Arrange + obj = _SyncDummy() + + # Act + result = obj.run() + + # Assert + assert result == "ok" + + def test_noreentry_sync_method_second_invocation_raises(self): + """Test sync method raises RuntimeError on second invocation. + + Given: + A class with a sync @noreentry method called once. + When: + The method is called a second time on the same instance. + Then: + It should raise RuntimeError. + """ + # Arrange + obj = _SyncDummy() + obj.run() + + # Act & assert + with pytest.raises(RuntimeError, match="cannot be invoked more than once"): + obj.run() + + @pytest.mark.asyncio + async def test_noreentry_async_method_first_invocation(self): + """Test async method executes normally on first invocation. + + Given: + A class with an async @noreentry method. + When: + The method is awaited for the first time. + Then: + It should return normally. + """ + # Arrange + obj = _AsyncDummy() + + # Act + result = await obj.run() + + # Assert + assert result == "ok" + + @pytest.mark.asyncio + async def test_noreentry_async_method_second_invocation_raises(self): + """Test async method raises RuntimeError on second invocation. + + Given: + A class with an async @noreentry method awaited once. + When: + The method is awaited a second time on the same instance. + Then: + It should raise RuntimeError. + """ + # Arrange + obj = _AsyncDummy() + await obj.run() + + # Act & assert + with pytest.raises(RuntimeError, match="cannot be invoked more than once"): + await obj.run() + + def test_noreentry_separate_instances_independent(self): + """Test instances track guard state independently. + + Given: + Two instances of a class with a @noreentry method. + When: + The method is called on the first instance. + Then: + The method remains callable on the second instance. + """ + # Arrange + a = _SyncDummy() + b = _SyncDummy() + a.run() + + # Act + result = b.run() + + # Assert + assert result == "ok" + + def test_noreentry_error_message_qualname(self): + """Test RuntimeError includes the method's qualified name. + + Given: + A class with a @noreentry method called once. + When: + The method is called a second time. + Then: + The RuntimeError message should include the method's __qualname__. + """ + # Arrange + obj = _SyncDummy() + obj.run() + + # Act & assert + with pytest.raises(RuntimeError, match="_SyncDummy.run"): + obj.run() + + def test_noreentry_preserves_coroutinefunction_check(self): + """Test decorator preserves async function detection. + + Given: + A class with an async @noreentry method. + When: + asyncio.iscoroutinefunction is called on the method. + Then: + It should return True. + """ + # Act & assert + assert asyncio.iscoroutinefunction(_AsyncDummy.run) + + def test_noreentry_preserves_wrapped_function_name(self): + """Test decorator preserves the original function name. + + Given: + A class with a @noreentry method. + When: + The decorated method's __name__ is inspected. + Then: + It should equal the original function name. + """ + # Act & assert + assert _SyncDummy.run.__name__ == "run" + + def test_noreentry_multiple_methods_independent(self): + """Test guard on one method does not affect other methods. + + Given: + A class with two @noreentry methods where the first + has been guarded (called twice). + When: + The second method is called. + Then: + The second method executes normally. + """ + # Arrange + obj = _MultiMethodDummy() + obj.first() + with pytest.raises(RuntimeError): + obj.first() + + # Act + result = obj.second() + + # Assert + assert result == "second" + + def test_noreentry_unbound_access_returns_descriptor(self): + """Test accessing decorated method via class returns descriptor. + + Given: + A class with a @noreentry method. + When: + The method is accessed through the class (unbound). + Then: + It should return the descriptor itself. + """ + # Act + unbound = _SyncDummy.run + + # Assert + assert unbound is not None + + def test_noreentry_bare_function_raises_error(self): + """Test decorator rejects application to bare functions. + + Given: + A @noreentry-decorated function called without a bound instance. + When: + The descriptor is called directly. + Then: + It should raise TypeError. + """ + # Arrange + unbound = _SyncDummy.run + + # Act & assert + with pytest.raises( + TypeError, match="only decorates methods, not bare functions" + ): + unbound()